In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers import BertModel
from datasets import load_dataset

from torch.optim import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm
import os

In [6]:
# Download and unpack MultiNLI 1.0 the first time the notebook runs.
# The extracted folder is used as the "already downloaded" marker.
dataset_path = '/home/ubuntu/nlp-brain-biased-robustness/data/mnli'
data_path = dataset_path+'/multinli_1.0'
if not os.path.exists(data_path):
    # makedirs also creates missing parent directories and is a no-op when the
    # directory already exists (a shell `mkdir` fails in both situations).
    os.makedirs(dataset_path, exist_ok=True)
    archive_path = dataset_path + '/multinli_1.0.zip'
    shell_commands = (
        'wget https://cims.nyu.edu/~sbowman/multinli/multinli_1.0.zip -P ' + dataset_path,
        'unzip ' + archive_path + ' -d ' + dataset_path + '/',
    )
    for command in shell_commands:
        # os.system returns the command's exit status; stop at the first failure
        # instead of continuing and failing later when the files are opened.
        if os.system(command) != 0:
            raise RuntimeError('command failed: ' + command)

In [None]:
#mnli = load_dataset('glue','mnli')

In [None]:
#mnli['train']

In [7]:
import csv
import sys

# Folder holding the extracted MultiNLI text files (trailing slash is relied on
# by the plain string concatenation used when opening files).
data_path = '/home/ubuntu/nlp-brain-biased-robustness/data/mnli/multinli_1.0/'

# Raise the csv module's per-field size limit as high as the platform accepts:
# start at sys.maxsize and shrink by a factor of 10 each time the value is
# rejected with OverflowError.
maxInt = sys.maxsize
while True:
    try:
        csv.field_size_limit(maxInt)
    except OverflowError:
        maxInt = int(maxInt/10)
    else:
        break

def load_data(data_file):
    """Read one tab-separated file from ``data_path`` into a list of rows.

    Args:
        data_file: File name, appended directly to the module-level
            ``data_path`` string.

    Returns:
        A list with one entry per line; each entry is the list of that line's
        tab-separated fields (the first line of the file is included).
    """
    # newline='' is what the csv module asks for when it is handed a file object;
    # the explicit encoding keeps the result independent of the host locale.
    with open(data_path+data_file, encoding='utf-8', newline='') as file:
        # QUOTE_NONE: treat '"' as an ordinary character. With the default
        # quoting, a field that starts with an unbalanced quote absorbs the
        # following tabs and newlines, silently merging many lines into one field.
        tsv_file = csv.reader(file, delimiter="\t", quoting=csv.QUOTE_NONE)
        return list(tsv_file)


# Load the three MultiNLI files (train, matched dev, mismatched dev) as raw rows.
train_set, dev_matched, dev_mismatched = map(
    load_data,
    (
        'multinli_1.0_train.txt',
        'multinli_1.0_dev_matched.txt',
        'multinli_1.0_dev_mismatched.txt',
    ),
)

In [8]:
def split_data():
    """Bucket raw rows from all three loaded splits by the value in column 9.

    Reads the module-level ``train_set``, ``dev_matched`` and ``dev_mismatched``
    lists (in that order) and keeps only rows whose column 9 is exactly
    'telephone', 'letters' or 'facetoface'.

    Returns:
        A ``(telephone, letters, facetoface)`` tuple of row lists; rows keep the
        order in which they were encountered.
    """
    buckets = {'telephone': [], 'letters': [], 'facetoface': []}

    for corpus in (train_set, dev_matched, dev_mismatched):
        for row in corpus:
            target = buckets.get(row[9])
            # Rows with any other value in column 9 are ignored.
            if target is not None:
                target.append(row)

    return buckets['telephone'], buckets['letters'], buckets['facetoface']

# Raw rows from all three loaded splits, bucketed by the value of column 9.
telephone, letters, facetoface = split_data()

In [9]:
def simplify_data(dataset):
    """Reduce raw rows to sentence-pair dicts with one-hot labels.

    Args:
        dataset: Iterable of rows; column 0 is read as the label string and
            columns 5 and 6 as the two sentences.

    Returns:
        A list of ``{'sentence_1', 'sentence_2', 'labels'}`` dicts. ``labels`` is
        ``[0,0,1]`` for 'entailment', ``[0,1,0]`` for 'neutral' and ``[1,0,0]``
        for 'contradiction'. Rows with any other label are dropped.
    """
    one_hot_by_label = {
        'entailment': (0, 0, 1),
        'neutral': (0, 1, 0),
        'contradiction': (1, 0, 0),
    }
    kept = []
    for row in dataset:
        record = {'sentence_1': row[5], 'sentence_2': row[6]}
        encoding = one_hot_by_label.get(row[0])
        if encoding is None:
            continue
        # Fresh list per example so no two records share a label object.
        record['labels'] = list(encoding)
        kept.append(record)
    return kept
        
# Convert every collection of raw rows to sentence-pair dicts with one-hot labels.
# No `[1:]` slice is applied: simplify_data() only keeps rows whose first column
# is 'entailment', 'neutral' or 'contradiction', so a header row is already
# filtered out and slicing would discard the first real example instead.
# NOTE: this rebinds the raw-row names, so the cell cannot be re-run without
# first re-running the loading cell.
train_set = simplify_data(train_set)
dev_matched = simplify_data(dev_matched)
dev_mismatched = simplify_data(dev_mismatched)

telephone = simplify_data(telephone)
letters = simplify_data(letters)
facetoface = simplify_data(facetoface)

In [10]:
def _join_sentence_pairs(examples):
    """Merge each example's two sentences into a single 'sentence' string.

    Args:
        examples: Iterable of dicts with 'sentence_1', 'sentence_2' and 'labels'.

    Returns:
        A list of ``{'sentence', 'labels'}`` dicts where 'sentence' is
        ``sentence_1 + '. ' + sentence_2`` and 'labels' is the original label
        object (not copied).
    """
    return [
        {
            'sentence': example['sentence_1'] + '. ' + example['sentence_2'],
            'labels': example['labels'],
        }
        for example in examples
    ]


# One merged dataset per bucket; the same helper replaces three copy-pasted loops.
telephone_dataset = _join_sentence_pairs(telephone)
letters_dataset = _join_sentence_pairs(letters)
facetoface_dataset = _join_sentence_pairs(facetoface)

In [12]:
# Number of merged examples in the facetoface bucket.
len(facetoface_dataset)

1974

In [11]:
#train_set_dataloader = DataLoader(train_set, shuffle=True, batch_size=8)
#dev_matched_dataloader = DataLoader(dev_matched, shuffle=True, batch_size=8)
#dev_mismatched_dataloader = DataLoader(dev_mismatched, shuffle=True, batch_size=8)

# All three loaders share the same settings: shuffled, batches of 8.
loader_options = dict(shuffle=True, batch_size=8)

# Only the first 15000 telephone examples are used.
telephone_dataloader = DataLoader(telephone_dataset[:15000], **loader_options)
letters_dataloader = DataLoader(letters_dataset, **loader_options)
facetoface_dataloader = DataLoader(facetoface_dataset, **loader_options)

In [12]:
import wandb

def change_all_keys(pre_odict):
    """Rewrite state-dict keys in place so they can be loaded into a bare encoder.

    Every leading ``'bert.'`` prefix is stripped from a key, and any entry whose
    (stripped) key starts with ``'linear.'`` is removed.

    The rewrite is a single iterative pass, so it works for state dicts of any
    size (no recursion) and for any mutable mapping, not only ``OrderedDict``.
    The surviving entries keep their original relative order.

    Args:
        pre_odict: Mutable mapping from string keys to values. It is modified
            in place.

    Returns:
        The same mapping object that was passed in.
    """
    kept_items = []
    for key, value in pre_odict.items():
        # Strip repeatedly: a key such as 'bert.bert.x' ends up as 'x'.
        while key[:5] == 'bert.':
            key = key[5:]
        if key[:7] == 'linear.':
            continue
        kept_items.append((key, value))

    # Mutate the caller's object (rather than returning a new one) so that any
    # attributes attached to it, and any other references to it, stay valid.
    pre_odict.clear()
    pre_odict.update(kept_items)
    return pre_odict

class PlaceHolderBERT(nn.Module):
    """``bert-base-cased`` encoder with a 3-way linear head.

    The module owns its tokenizer, so ``forward`` takes raw text rather than
    token ids. The head is applied to the hidden state at sequence position 0.

    Args:
        brain: When truthy, replace the pretrained encoder weights with the
            checkpoint at ``state_path``.
        state_path: Location of the checkpoint used when ``brain`` is truthy.
            Its keys are passed through ``change_all_keys`` before loading.
    """
    def __init__(self, brain=True,
                 state_path='/home/ubuntu/nlp-brain-biased-robustness/state_dicts/NSD_model_prime_prime_epoch_10'):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
        self.bert = BertModel.from_pretrained('bert-base-cased')
        if brain:
            # map_location='cpu' lets a checkpoint saved from a GPU be opened on a
            # machine without one; the weights follow the module on a later .to().
            # NOTE(review): torch.load unpickles the file, so only load
            # checkpoints from a trusted source.
            pre_odict = torch.load(state_path, map_location='cpu')
            filtered_odict = change_all_keys(pre_odict)
            self.bert.load_state_dict(filtered_odict, strict=True)
        self.linear = nn.Linear(768, 3)
        # Preferred device; kept as an attribute for existing callers.
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    def forward(self, x):
        """Tokenize ``x`` and return one row of 3 scores per input text.

        Args:
            x: Text input accepted by the tokenizer (the training loop passes a
                list of strings).

        Returns:
            Tensor with 3 columns and one row per input.
        """
        x_embeddings = self.tokenizer(x, return_tensors='pt', padding=True, truncation=True)
        # Send the inputs to wherever the weights actually are, so the module
        # also works when it has not been moved to self.device.
        x_embeddings = x_embeddings.to(self.linear.weight.device)
        x_representations = self.bert(**x_embeddings).last_hidden_state
        x_cls_representation = x_representations[:,0,:]
        pred = self.linear(x_cls_representation)
        return pred
    
    
def train(model, dataloader, num_epochs=10): #can scrap keyword
    """Fine-tune ``model`` on ``dataloader`` and score it after every epoch.

    Uses AdamW (lr 5e-5) with a linear schedule without warmup and MSE loss
    against the one-hot targets. After each epoch the model is evaluated on the
    module-level ``telephone_dataloader``, ``letters_dataloader`` and
    ``facetoface_dataloader``; each score is printed and logged to wandb.

    Args:
        model: Module called with ``batch['sentence']`` that returns one row of
            3 scores per sentence.
        dataloader: Yields dicts with 'sentence' and 'labels' entries.
        num_epochs: Number of passes over ``dataloader``.
    """
    learning_rate = 5e-5
    # Pass the hyperparameters through wandb.init so they are recorded with the
    # run, and report the values actually used instead of fixed literals.
    wandb.init(
        project="preliminary results just in case",
        entity="nlp-brain-biased-robustness",
        config={
            "learning_rate": learning_rate,
            "epochs": num_epochs,
            "batch_size": getattr(dataloader, "batch_size", 8),
        },
    )
    wandb.run.name = 'mnli bb bert fo real e 10'

    #optimizer as usual
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    loss_function = torch.nn.MSELoss()
    #learning rate scheduler
    num_training_steps = num_epochs * len(dataloader)
    lr_scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    #auto logging; progress bar
    progress_bar = tqdm(range(num_training_steps))

    # Evaluated in this order after every epoch.
    eval_loaders = (
        ("telephone", telephone_dataloader),
        ("letters", letters_dataloader),
        ("facetoface", facetoface_dataloader),
    )

    #training loop
    for epoch in range(num_epochs):
        # Re-enter training mode every epoch: evaluate() switches the model to
        # eval mode, which would otherwise leave dropout disabled from the
        # second epoch onwards.
        model.train()
        for batch in dataloader:
            pred = model(batch['sentence'])
            # batch['labels'] is a sequence of per-class tensors; stack and
            # transpose so each row is one example's one-hot target.
            targets = torch.stack(tuple(batch['labels'])).to(device)
            targets = torch.transpose(targets, 0, 1)
            loss = loss_function(pred, targets.float())
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)

        for name, loader in eval_loaders:
            score = evaluate(model, loader)
            print(score)
            wandb.log({name: score})
            

def evaluate(model, dataloader):
    """Return the accuracy of ``model`` on ``dataloader`` as a percentage.

    The model is moved to CUDA when available (CPU otherwise) and run in eval
    mode without gradients. Its previous train/eval mode is restored before
    returning, so calling this in the middle of training does not leave
    dropout disabled.

    Args:
        model: Module called with ``batch['sentence']`` that returns one row of
            class scores per sentence.
        dataloader: Yields dicts with 'sentence' and 'labels'; 'labels' is a
            sequence of per-class tensors (one-hot once stacked and transposed).

    Returns:
        ``100 * correct / total`` as a float, or NaN when the dataloader yields
        no samples.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    was_training = model.training
    model.eval()
    num_correct = 0
    num_samples = 0
    try:
        with torch.no_grad():
            for batch in dataloader:
                pred = torch.argmax(model(batch['sentence']), dim=1)
                targets = torch.stack(tuple(batch['labels'])).to(device)
                targets = torch.transpose(targets, 0, 1)
                labels = torch.argmax(targets, dim=1)
                # .item() keeps the running count a plain Python int.
                num_correct += (pred == labels).sum().item()
                num_samples += pred.size(0)
    finally:
        model.train(was_training)
    if num_samples == 0:
        return float('nan')
    return num_correct / num_samples * 100


In [None]:
# Build the model with its default arguments (brain=True) and fine-tune it on
# the telephone loader; train() evaluates on all three loaders after each epoch.
model = PlaceHolderBERT()
train(model, telephone_dataloader)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[34m[1mwandb[0m: Currently logged in as: [33mjgc239[0m ([33mnlp-brain-biased-robustness[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/18750 [00:00<?, ?it/s]

69.5
67.42539200809307
61.043566362715296
88.53333333333333
66.86899342438038
64.58966565349544
95.48666666666666
65.14921598381386
62.00607902735562
97.67333333333333
63.63176530096105
63.0192502532928
98.48666666666666
66.51492159838138
63.37386018237082
99.03999999999999
67.77946383409206
62.76595744680851
99.45333333333333
66.81841173495194
64.08308004052685
99.60666666666667
66.61608497723824
63.47517730496454


In [15]:
# Accuracy (in percent) of the current model on the facetoface loader.
evaluate(model, facetoface_dataloader)

42.806484295846