In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers import BertModel
from datasets import load_dataset

from torch.optim import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm

In [None]:
#mnli = load_dataset('glue','mnli')

In [None]:
#mnli['train']

In [2]:
import csv
import os
import sys

# Directory holding the MultiNLI 1.0 tab-separated files. Set the
# MNLI_DATA_DIR environment variable to point somewhere else; the default is
# the original location. os.path.join(..., '') guarantees a trailing path
# separator, which load_data relies on because it concatenates
# data_path + file name.
data_path = os.path.join(
    os.environ.get(
        'MNLI_DATA_DIR',
        '/home/ubuntu/NLP-brain-biased-robustness/data/mnli/multinli_1.0/',
    ),
    '',
)

# Raise the csv module's per-field size limit as far as the platform allows.
maxInt = sys.maxsize

while True:
    # csv.field_size_limit raises OverflowError when the value does not fit
    # the platform's C long, so shrink by a factor of 10 until it is accepted.
    try:
        csv.field_size_limit(maxInt)
        break
    except OverflowError:
        # Floor division keeps this exact integer arithmetic (no float round-trip).
        maxInt //= 10

def load_data(data_file, base_dir=None):
    """Read one tab-separated file and return its rows.

    Args:
        data_file: File name, appended to ``base_dir``.
        base_dir: Directory prefix, which must end with a path separator
            because it is concatenated with ``data_file``. Defaults to the
            module-level ``data_path``.

    Returns:
        A list of rows, each a list of the string fields of one line. Any
        header line in the file is returned as an ordinary row.
    """
    if base_dir is None:
        base_dir = data_path
    # newline='' is what the csv module expects from its input file object.
    with open(base_dir + data_file, encoding='utf-8', newline='') as file:
        # QUOTE_NONE: treat '"' as an ordinary character. With the default
        # quoting, a field that starts with an unbalanced double quote swallows
        # the following tabs and newlines, silently merging several rows into
        # one oversized field.
        tsv_file = csv.reader(file, delimiter="\t", quoting=csv.QUOTE_NONE)
        return list(tsv_file)


# Read the raw rows of the training split and both development splits.
train_set, dev_matched, dev_mismatched = (
    load_data(file_name)
    for file_name in (
        'multinli_1.0_train.txt',
        'multinli_1.0_dev_matched.txt',
        'multinli_1.0_dev_mismatched.txt',
    )
)

In [3]:
def split_data(datasets=None):
    """Group raw rows by genre for the three genres of interest.

    Args:
        datasets: Iterable of row collections to scan, where each row is a
            sequence whose element 9 is the genre string. Defaults to the
            module-level ``train_set``, ``dev_matched`` and ``dev_mismatched``.

    Returns:
        A tuple ``(telephone, letters, facetoface)`` of row lists, in scan order.
        Rows of any other genre are dropped.
    """
    if datasets is None:
        datasets = (train_set, dev_matched, dev_mismatched)

    by_genre = {'telephone': [], 'letters': [], 'facetoface': []}

    for dataset in datasets:
        for ex in dataset:
            try:
                genre = ex[9]
            except IndexError:
                # Blank or truncated line: no genre column, nothing to classify.
                continue
            if genre in by_genre:
                by_genre[genre].append(ex)

    return by_genre['telephone'], by_genre['letters'], by_genre['facetoface']

# Collect the raw rows of the three genres of interest from all loaded splits.
telephone, letters, facetoface = split_data()

In [5]:
def simplify_data(dataset):
    """Reduce raw rows to sentence-pair dicts with one-hot labels.

    Each row contributes its elements 5 and 6 as 'sentence_1' / 'sentence_2'.
    Element 0 is the gold label: 'entailment' -> [0,0,1], 'neutral' ->
    [0,1,0], 'contradiction' -> [1,0,0]. Rows with any other first element
    are left out of the result.
    """
    one_hot_by_label = {
        'entailment': (0, 0, 1),
        'neutral': (0, 1, 0),
        'contradiction': (1, 0, 0),
    }
    kept = []
    for row in dataset:
        # Sentences are read before the label is checked, as in the original.
        record = {'sentence_1': row[5], 'sentence_2': row[6]}
        encoding = one_hot_by_label.get(row[0])
        if encoding is None:
            continue
        # Fresh list per example so callers can mutate labels independently.
        record['labels'] = list(encoding)
        kept.append(record)
    return kept
        
# simplify_data() only keeps rows whose first column is one of the three gold
# labels, so a header row is already filtered out. Slicing the result with
# [1:] would throw away the first real example of each split.
#
# NOTE: this cell rebinds the raw-row variables to their simplified form, so
# re-running it without first re-running the loading cells feeds dicts back
# into simplify_data.
train_set = simplify_data(train_set)
dev_matched = simplify_data(dev_matched)
dev_mismatched = simplify_data(dev_mismatched)

telephone = simplify_data(telephone)
letters = simplify_data(letters)
facetoface = simplify_data(facetoface)

In [None]:
# Display one simplified example to sanity-check the dict layout.
telephone[0]

In [6]:
# Shuffled loaders with batches of 8 for the three standard splits ...
train_set_dataloader, dev_matched_dataloader, dev_mismatched_dataloader = (
    DataLoader(examples, shuffle=True, batch_size=8)
    for examples in (train_set, dev_matched, dev_mismatched)
)

# ... and for the three per-genre subsets.
telephone_dataloader, letters_dataloader, facetoface_dataloader = (
    DataLoader(examples, shuffle=True, batch_size=8)
    for examples in (telephone, letters, facetoface)
)

In [14]:
class PlaceHolderBERT(nn.Module):
    """Sentence-pair classifier on top of 'bert-base-cased'.

    Each sentence of a pair is tokenized and encoded by BERT separately. The
    first-token ([CLS]) hidden states of the two encodings are concatenated
    and projected to three logits by a linear layer.
    """

    def __init__(self):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
        self.bert = BertModel.from_pretrained('bert-base-cased')
        # Two concatenated first-token vectors (768 features each) -> 3 logits.
        self.linear = nn.Linear(768*2, 3)
        # Device detected at construction time. Kept for callers that read it;
        # forward() uses the device the weights are actually on.
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    def forward(self, x, y):
        """Return logits for sentence pairs.

        Args:
            x: Batch of first sentences, in a form the tokenizer accepts.
            y: Batch of second sentences, aligned with ``x``.

        Returns:
            Tensor with one row of 3 logits per pair.
        """
        # Follow the model's current device so inputs and weights always
        # agree, even after model.to(...) moved it away from self.device.
        device = self.linear.weight.device
        x_encoded = self.tokenizer(x, return_tensors='pt', padding=True, truncation=True).to(device)
        y_encoded = self.tokenizer(y, return_tensors='pt', padding=True, truncation=True).to(device)
        # Position 0 of the sequence axis is the first token of each input.
        x_cls_representation = self.bert(**x_encoded).last_hidden_state[:, 0, :]
        y_cls_representation = self.bert(**y_encoded).last_hidden_state[:, 0, :]
        input_vec = torch.cat((x_cls_representation, y_cls_representation), dim=1)
        return self.linear(input_vec)
    
    
def train(model, dataloader, num_epochs=1): #can scrap keyword
    """Fine-tune ``model`` in place on ``dataloader``.

    Uses AdamW (lr 5e-5) with a linear learning-rate schedule without warmup,
    and mean-squared error between the model's logits and the one-hot labels.

    Args:
        model: Module called as ``model(sentences_1, sentences_2)``.
        dataloader: Yields dict batches with 'sentence_1', 'sentence_2' and
            'labels'. 'labels' is a sequence with one tensor per class
            position, each holding that position's value for every example.
        num_epochs: Number of passes over ``dataloader``.
    """
    optimizer = AdamW(model.parameters(), lr=5e-5)
    # NOTE(review): MSE against one-hot targets is kept as-is; cross-entropy
    # is the usual loss for 3-way classification — confirm the choice is
    # deliberate.
    loss_function = torch.nn.MSELoss()

    # The scheduler decays the learning rate over the total number of steps.
    num_training_steps = num_epochs * len(dataloader)
    lr_scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    model.train()
    # The context manager closes the progress bar even if a step raises.
    with tqdm(total=num_training_steps) as progress_bar:
        for _ in range(num_epochs):
            for batch in dataloader:
                pred = model(batch['sentence_1'], batch['sentence_2'])
                # Stack the per-position tensors along dim 1 so each row is
                # one example's one-hot label.
                targets = torch.stack(tuple(batch['labels']), dim=1).to(device)
                loss = loss_function(pred, targets.float())
                loss.backward()

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
            

def evaluate(model, dataloader):
    """Return the classification accuracy of ``model`` as a percentage.

    The model is moved to CUDA when available (otherwise CPU) and left in
    eval mode.

    Args:
        model: Module called as ``model(sentences_1, sentences_2)`` that
            returns one row of class scores per example.
        dataloader: Yields dict batches with 'sentence_1', 'sentence_2' and
            'labels'. 'labels' is a sequence with one tensor per class
            position, each holding that position's value for every example.

    Returns:
        Accuracy in the range 0-100 as a float, or NaN when ``dataloader``
        yields no examples.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    model.eval()
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for batch in dataloader:
            pred = torch.argmax(model(batch['sentence_1'], batch['sentence_2']), dim=1)
            # Stack the per-position tensors along dim 1 so each row is one
            # example's one-hot label, then recover the class index.
            targets = torch.stack(tuple(batch['labels']), dim=1).to(device)
            labels = torch.argmax(targets, dim=1)
            # .item() keeps the running count a plain Python int.
            num_correct += (pred == labels).sum().item()
            num_samples += pred.size(0)
    if num_samples == 0:
        # Accuracy over zero examples is undefined; avoid ZeroDivisionError.
        return float('nan')
    return num_correct / num_samples * 100


In [17]:
# Build a fresh classifier and fine-tune it for one epoch (train's default
# num_epochs) on the 'letters' genre subset.
model = PlaceHolderBERT()
train(model, letters_dataloader)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


  0%|          | 0/248 [00:00<?, ?it/s]

In [15]:
# Accuracy (in percent) of the current `model` on the 'facetoface' genre subset.
# NOTE(review): in the saved notebook this cell's execution count is lower than
# the training cell's, so the stored output may come from an earlier model
# state. Re-run top to bottom before quoting the number.
evaluate(model, facetoface_dataloader)

42.806484295846