In [None]:
%matplotlib inline
import math
import nltk
import numpy as np
import os
import random
import torch
import torch.nn as nn
import unittest

from collections import Counter
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup

def set_seed(seed, device='cpu'):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  if device == 'cuda':
    torch.cuda.manual_seed_all(seed)

In [30]:
def read_sst2_file(path):
  """Read an SST-2 TSV file into a list of [sentence, label] pairs.

  The first line is treated as a header and skipped. Every other non-blank
  line must hold a sentence and an integer label separated by a tab.

  Args:
    path: Path of the TSV file to read.

  Returns:
    A list of `[sentence, label]` lists, where `label` is an int.

  Raises:
    ValueError: If a non-blank line has no tab or its label is not an integer.
  """
  examples = []
  # Fix the encoding so the result does not depend on the platform's locale.
  with open(path, encoding='utf-8') as f:
    f.readline()  # Skip the header row.
    for line in f:
      if not line.strip():
        continue  # Tolerate blank lines, e.g. a trailing newline at end of file.
      # Split on the last tab only, so a tab inside the sentence cannot break parsing.
      sent, label = line.rsplit('\t', 1)
      examples.append([sent, int(label)])
  return examples

class SST2Dataset(Dataset):
  """Dataset over one SST-2 split stored in the `datasets/SST-2` folder."""

  def __init__(self, filename):
    """Load `filename` from `datasets/SST-2` under the current working directory."""
    data_dir = os.getcwd() + '/datasets/SST-2'
    self.examples = read_sst2_file(os.path.join(data_dir, filename))

  def __getitem__(self, index):
    """Return the [sentence, label] example stored at `index`."""
    return self.examples[index]

  def __len__(self):
    """Return the number of loaded examples."""
    return len(self.examples)

# Load the train and dev splits, then report how many sentences each one holds.
dataset_train, dataset_val = SST2Dataset('train.tsv'), SST2Dataset('dev.tsv')
split_summary = '{} train sents, {} val sents'
print(split_summary.format(len(dataset_train), len(dataset_val)))

67349 train sents, 872 val sents


In [31]:
# Load the pretrained tokenizer that belongs to the bert-base-uncased checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
vocab_report = 'The pretrained tokenizer in bert-base-uncased has vocab size {:d}\n'
print(vocab_report.format(tokenizer.vocab_size))

# DataLoader needs a custom collate function because the sentences in a batch have varying lengths.
def collate_fn(batch):
  """Turn a list of (sentence, label) examples into padded batch tensors.

  Args:
    batch: Sequence of `(sentence, label)` pairs as produced by the dataset.

  Returns:
    A tuple `(input_ids, attention_mask, labels)`. The first two come from the
    tokenizer, padded to the longest sentence in the batch; `labels` is a
    float tensor with one entry per example.
  """
  # https://huggingface.co/transformers/preprocessing.html
  sents, labels = zip(*batch)
  labels = torch.tensor(labels, dtype=torch.float)
  # truncation=True caps every sentence at the tokenizer's maximum input length,
  # so an unusually long sentence cannot overflow the encoder's position embeddings.
  encoded = tokenizer(list(sents), padding=True, truncation=True, return_tensors='pt')
  return encoded['input_ids'], encoded['attention_mask'], labels

# Seed the RNGs, then build the validation loader (fixed order, two sentences per batch).
set_seed(42)
dataloader_val = DataLoader(
    dataset_val,
    batch_size=2,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn,
)

The pretrained tokenizer in bert-base-uncased has vocab size 30522



In [32]:
def get_init_transformer(transformer):
  """
  Build a per-module initializer (suitable for `nn.Module.apply`) that follows
  the initialization scheme used for transformers:
  https://huggingface.co/transformers/_modules/transformers/modeling_bert.html

  Linear and embedding weights are drawn from a normal distribution with mean 0
  and std `transformer.config.initializer_range`, linear biases are zeroed, and
  LayerNorm is reset to weight 1 and bias 0. Other module types are left untouched.
  """
  def init_transformer(module):
    if not isinstance(module, (nn.Linear, nn.Embedding)):
      # Only LayerNorm needs handling among the remaining module types.
      if isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
      return

    # The std is read from the config at call time, exactly when a weight is drawn.
    module.weight.data.normal_(mean=0.0, std=transformer.config.initializer_range)
    if isinstance(module, nn.Linear) and module.bias is not None:
      module.bias.data.zero_()

  return init_transformer


class BertClassifier(nn.Module):
  """Binary sentence classifier: pretrained BERT encoder plus a linear scoring head on [CLS]."""

  def __init__(self, drop=0.1):
    """
    Args:
      drop: Dropout probability applied to the [CLS] embedding before the linear layer.
    """
    super().__init__()
    self.encoder = BertModel.from_pretrained('bert-base-uncased')
    self.score = nn.Sequential(nn.Dropout(drop),
                               nn.Linear(self.encoder.config.hidden_size, 1))
    self.score.apply(get_init_transformer(self.encoder))  # Important to initialize any additional weights the same way as pretrained encoder.
    self.loss = nn.BCEWithLogitsLoss(reduction='sum')  # Summed (not averaged) over the batch.

  def forward(self, input_ids, attention_mask, labels=None):
    """Score a batch of tokenized sentences.

    Args:
      input_ids: Token ids, one row per sentence.
      attention_mask: Mask passed through to the encoder alongside `input_ids`.
      labels: Optional float targets, one per sentence. When omitted (e.g. for
        inference) no loss is computed.

    Returns:
      A tuple `(logits, loss_total)`: `logits` has one entry per sentence, and
      `loss_total` is the binary cross-entropy summed over the batch, or `None`
      when `labels` is `None`.
    """
    hiddens_last = self.encoder(input_ids, attention_mask=attention_mask)[0]  # (batch_size, length, dim), these are last layer embeddings
    embs = hiddens_last[:, 0, :]  # [CLS] token embeddings
    logits = self.score(embs).squeeze(1)  # (batch_size,)
    loss_total = None if labels is None else self.loss(logits, labels)
    return logits, loss_total

In [33]:
def count_params(model):
  """Return the total number of elements across all of `model`'s parameters."""
  total = 0
  for param in model.parameters():
    total += param.numel()
  return total

# Instantiate the classifier and report how many parameters it has.
model = BertClassifier()
num_params = count_params(model)
print('Model has {} parameters\n'.format(num_params))

Model has 109483009 parameters



In [34]:
def configure_optimization(model, num_train_steps, num_warmup_steps, lr, weight_decay=0.01):
  """Create the AdamW optimizer and the linear warmup schedule for `model`.

  Parameters whose name contains 'bias' or 'LayerNorm.weight' are placed in a
  group with zero weight decay; all others use `weight_decay`.

  Returns:
    A tuple `(optimizer, scheduler)`.
  """
  # Copied from: https://huggingface.co/transformers/training.html
  no_decay = ('bias', 'LayerNorm.weight')
  decayed, exempt = [], []
  for name, param in model.named_parameters():
    is_exempt = any(marker in name for marker in no_decay)
    (exempt if is_exempt else decayed).append(param)

  param_groups = [
      {'params': decayed, 'weight_decay': weight_decay},
      {'params': exempt, 'weight_decay': 0.},
  ]
  optimizer = AdamW(param_groups, lr=lr)
  scheduler = get_linear_schedule_with_warmup(
      optimizer,
      num_warmup_steps=num_warmup_steps,
      num_training_steps=num_train_steps,
  )
  return optimizer, scheduler

In [35]:
def get_acc_val(model, device, dataloader=None):
  """Compute classification accuracy (in percent) over a validation set.

  Args:
    model: Module called as `model(input_ids, attention_mask, labels)` that
      returns `(logits, loss)`.
    device: Device every batch is moved to before the forward pass.
    dataloader: Iterable of `(input_ids, attention_mask, labels)` batches that
      exposes a sized `.dataset`. Defaults to the module-level `dataloader_val`.

  Returns:
    Accuracy as a Python float between 0 and 100; 0.0 if the dataset is empty.

  Note:
    The model is switched to eval mode and left in that mode.
  """
  if dataloader is None:
    dataloader = dataloader_val

  model.eval()
  num_examples = len(dataloader.dataset)
  if num_examples == 0:
    return 0.

  num_correct_val = 0
  with torch.no_grad():
    for input_ids, attention_mask, labels in dataloader:
      input_ids = input_ids.to(device)
      attention_mask = attention_mask.to(device)
      labels = labels.to(device)
      logits, _ = model(input_ids, attention_mask, labels)
      preds = (logits > 0.).to(labels.dtype)  # 1 if p(1|x) > 0.5, 0 else
      # .item() keeps the running count a plain Python int instead of a device tensor.
      num_correct_val += (preds == labels).sum().item()
  return num_correct_val / num_examples * 100.

In [37]:
def train(model, batch_size=32, num_warmup_steps=10, lr=0.00005, num_epochs=3, clip=1., verbose=True, device='cpu'):
  """Fine-tune `model` on `dataset_train` and track accuracy on the validation set.

  Args:
    model: Classifier called as `model(input_ids, attention_mask, labels)` that
      returns `(logits, summed_loss)`.
    batch_size: Number of training sentences per optimizer step.
    num_warmup_steps: Number of warmup steps handed to the learning-rate schedule.
    lr: Learning rate passed to the optimizer.
    num_epochs: Number of passes over the training set.
    clip: Maximum gradient norm; a value <= 0 disables gradient clipping.
    verbose: Whether to print a summary line per epoch and a final summary.
    device: Device the model and every batch are moved to.

  Returns:
    The best validation accuracy (in percent) observed at the end of any epoch,
    as a Python float.
  """
  model = model.to(device)  # Move the model to device.
  dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=collate_fn)
  # The scheduler is stepped once per batch, so the schedule length must be the
  # number of batches (which includes the final partial batch), not
  # len(dataset) // batch_size. Otherwise the learning rate reaches zero before
  # training is over.
  num_train_steps = len(dataloader_train) * num_epochs
  optimizer, scheduler = configure_optimization(model, num_train_steps, num_warmup_steps, lr)
  num_train_examples = len(dataloader_train.dataset)

  loss_avg = float('inf')
  acc_train = 0.
  best_acc_val = 0.
  for epoch in range(num_epochs):
    model.train()  # This turns on the training mode (e.g., enable dropout).
    loss_total = 0.
    num_correct_train = 0
    for batch_ind, (input_ids, attention_mask, labels) in enumerate(dataloader_train):
      if (batch_ind + 1) % 200 == 0:
        print(batch_ind + 1, '/', len(dataloader_train), 'batches done')
      input_ids = input_ids.to(device).long()
      attention_mask = attention_mask.to(device)
      labels = labels.to(device)
      logits, loss_batch_total = model(input_ids, attention_mask, labels)
      preds = (logits > 0.).to(labels.dtype)  # 1 if p(1|x) > 0.5, 0 else
      # .item() keeps the running statistics as plain Python numbers.
      num_correct_train += (preds == labels).sum().item()
      loss_total += loss_batch_total.item()

      # The model returns a summed loss; backpropagate the per-example average.
      loss_batch_avg = loss_batch_total / input_ids.size(0)
      loss_batch_avg.backward()

      if clip > 0.:  # Optional gradient clipping
        nn.utils.clip_grad_norm_(model.parameters(), clip)

      optimizer.step()  # optimizer updates model weights based on stored gradients
      scheduler.step()  # Update lr.
      optimizer.zero_grad()  # Reset gradient slots to zero

    # Useful training information
    loss_avg = loss_total / num_train_examples
    acc_train = num_correct_train / num_train_examples * 100.

    # Check validation performance at the end of every epoch.
    # float() accepts both a Python number and a 0-dim tensor.
    acc_val = float(get_acc_val(model, device))

    if verbose:
      print('Epoch {:3d} | avg loss {:8.4f} | train acc {:2.2f} | val acc {:2.2f}'.format(epoch + 1, loss_avg, acc_train, acc_val))

    if acc_val > best_acc_val:
      best_acc_val = acc_val

  if verbose:
    print('Final avg loss {:8.4f} | final train acc {:2.2f} | best val acc {:2.2f}'.format(loss_avg, acc_train, best_acc_val))

  return best_acc_val

## Run BERT on SST-2

In [38]:
if True:  # Set to False to skip the fine-tuning run.
  # Use a GPU whenever one is available; otherwise fall back to the CPU.
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  set_seed(42, device)
  model = BertClassifier()
  best_acc_val = train(model, batch_size=32, device=device)

200 / 2105 batches done
400 / 2105 batches done


KeyboardInterrupt: 