In [1]:
import pandas as pd
import numpy as np

# Data Processing

In [2]:
df = pd.read_csv('snli_1.0_train.txt', sep = '\t')

In [3]:
#excluding examples with no majority label, according to dataset description

df = df[df['gold_label'] != '-']

In [4]:
df['gold_label'].value_counts()

gold_label
entailment       183416
contradiction    183187
neutral          182764
Name: count, dtype: int64

In [7]:
df['gold_label'].replace({"entailment": 0, "neutral": 1, "contradiction": 2}, inplace = True)

In [23]:
#training on small subset of the data for now
batch_1 = df.sample(n = 5000)

In [6]:
from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
import torch
from torch.utils.data import DataLoader, TensorDataset, RandomSampler, SequentialSampler

In [24]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


def tokenize_premise_hypothesis(premise, hypothesis, tokenizer, max_length):
    
    combined_text = f"{premise} $ {hypothesis}" #similar to Radford paper
    
    out = tokenizer(combined_text, add_special_tokens=True, truncation=True, 
                    max_length=max_length, padding="max_length", return_tensors="pt")
    return out

max_length = 128 


inputs = [tokenize_premise_hypothesis(p, h, tokenizer, max_length) 
          for p, h in zip(batch_1['sentence1'], batch_1['sentence2'])]




labels = batch_1['gold_label'].to_numpy()


input_ids = torch.cat([entry['input_ids'] for entry in inputs], dim=0)
attention_mask = torch.cat([entry['attention_mask'] for entry in inputs], dim=0)
labels = torch.tensor(labels)


dataset = TensorDataset(input_ids, attention_mask, labels)

In [12]:
if torch.cuda.is_available():
    print ('cuda')
    device = 'cuda'
else:
    device = 'cpu'

cuda


In [25]:
# Load pre-trained BERT model
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3).to(device)



batch_size = 32
num_epochs = 3

train_dataloader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=batch_size)
#val_dataloader = DataLoader(val_dataset, sampler=SequentialSampler(val_dataset), batch_size=batch_size)


optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, 
                                            num_training_steps=len(train_dataloader) * num_epochs)

#fine-tuning BERT
for epoch in range(num_epochs):
    model.train()
    for batch in train_dataloader:
        input_ids, attention_mask, labels = batch
        outputs = model(input_ids.to(device), attention_mask=attention_mask.to(device), 
                        labels=labels.to(device))
        
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [31]:
batch_2 = df.sample(n = 10000)

In [32]:
val_inputs = [tokenize_premise_hypothesis(p, h, tokenizer, max_length) 
          for p, h in zip(batch_2['sentence1'], batch_2['sentence2'])]




val_labels = batch_2['gold_label'].to_numpy()

# Create tensors from the inputs and labels
val_input_ids = torch.cat([entry['input_ids'] for entry in val_inputs], dim=0)
val_attention_mask = torch.cat([entry['attention_mask'] for entry in val_inputs], dim=0)
val_labels = torch.tensor(val_labels)


val_dataset = TensorDataset(val_input_ids, val_attention_mask, val_labels)

In [33]:
valid_dataloader = DataLoader(val_dataset, batch_size=batch_size)


In [34]:
model.eval()


predictions = [] 

with torch.no_grad():
    for batch in valid_dataloader:
        input_ids, attention_mask, labels = batch
        outputs = model(input_ids.to(device), attention_mask=attention_mask.to(device))
        logits = outputs.logits
        predicted_labels = logits.argmax(dim=1)
        predictions.extend(predicted_labels.tolist())

In [35]:
(np.array(predictions) == val_labels.numpy()).mean()

0.7601