---   
# HW3 - Transfer learning

#### Due October 30, 2019

In this assignment you will learn about transfer learning. This technique is perhaps one of the most important in industry practice. When the problem you want to solve does not have enough data, you can use a different (larger) dataset to learn representations that help solve the task on the smaller dataset.

The general steps to transfer learning are as follows:

1. Find a huge dataset with similar characteristics to the problem you are interested in.
2. Choose a model powerful enough to extract meaningful representations from the huge dataset.
3. Train this model on the huge dataset.
4. Use this model to train on the smaller dataset.


### This homework has the following sections:
1. Question 1: MNIST fine-tuning (Parts A, B, C, D)
2. Question 2: Pretrain on Wikitext-2 (Parts A, B, C, D)
3. Question 3: Fine-tune on MNLI (Parts A, B, C, D)
4. Question 4: Fine-tune using pretrained BERT (Parts A, B, C)

---   
## Question 1 (MNIST transfer learning)
To grasp the high-level approach to transfer learning, let's first do a simple example using computer vision. 

The torchvision library provides models (ResNets, VGG nets, etc.) pretrained on the ImageNet dataset. The ImageNet version used for
these models has about 1.28 million training images covering 1000 object classes. When you load one of these models with `pretrained=True`,
its weights are initialized with the weights saved from training on ImageNet.

In this task we will:
1. Choose a pretrained model.
2. Freeze the model so that the weights don't change.
3. Fine-tune on a few labels of MNIST.   

#### Choose a model
Here we pick one of the models from torchvision (resnet18).

In [1]:
import torch
import torchvision.models as models

class Identity(torch.nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x

# init the pretrained feature extractor
pretrained_resnet18 = models.resnet18(pretrained=True)

# we don't want the built in last layer, we're going to modify it ourselves
pretrained_resnet18.fc = Identity()

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth






#### Freeze the model
Here we freeze the weights of the model. Freezing means no gradients are computed for these weights,
so the optimizer never updates them.

By doing this you can think about the model as a feature extractor. This feature extractor outputs
a **representation** of an input. For resnet18 with its last layer removed, the representation is a 512-dimensional vector per image that encodes information about the input.
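
To make the effect of freezing concrete, here is a minimal sketch on a toy model. The `backbone` and `head` modules are hypothetical stand-ins (not the ResNet used in this homework); the point is only that, after a backward pass, frozen weights have no gradient while the new head does.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins for a pretrained backbone and a new task head.
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
head = nn.Linear(16, 3)

# Freeze the backbone: autograd will not compute gradients for these weights.
for param in backbone.parameters():
    param.requires_grad = False

x = torch.randn(4, 8)               # batch of 4 toy inputs
y = torch.tensor([0, 1, 2, 1])      # toy labels
loss = nn.functional.cross_entropy(head(backbone(x)), y)
loss.backward()

frozen_grads = [p.grad for p in backbone.parameters()]
head_grads = [p.grad for p in head.parameters()]
print(all(g is None for g in frozen_grads))      # frozen weights received no gradient
print(all(g is not None for g in head_grads))    # the head's weights did
```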

In [2]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False
        
def unfreeze_model(model):
    for param in model.parameters():
        param.requires_grad = True
        
freeze_model(pretrained_resnet18)

#### Init target dataset
Here we define the dataset we are actually interested in.

In [None]:
import os
from torchvision import transforms
from torchvision.datasets import  MNIST
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

#  train/val  split
mnist_dataset = MNIST(os.getcwd(), train=True, download=True, transform = transforms.Compose([transforms.Grayscale(3), transforms.ToTensor()]))
mnist_train, mnist_val = random_split(mnist_dataset, [55000, 5000])

mnist_train = DataLoader(mnist_train, batch_size=32)
mnist_val = DataLoader(mnist_val, batch_size=32)

# test split
mnist_test = MNIST(os.getcwd(), train=False, download=True, transform = transforms.Compose([transforms.Grayscale(3), transforms.ToTensor()]))
mnist_test = DataLoader(mnist_test, batch_size=32)

### Part A (init fine-tune model)
Decide what model to use for fine-tuning.

In [None]:
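def init_fine_tune_model():
    # YOUR CODE HERE
    # Reuse the pretrained ResNet (the original referenced an undefined name here).
    # The training functions attach a 10-class `fc` head to this same object.
    fine_tune_model = pretrained_resnet18
    return fine_tune_model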

### Part B (Fine-tune (Frozen))

The actual problem we care about solving likely has a different number of classes or is a different task altogether. Fine-tuning is the process of using the extracted representations (features) to solve this downstream task  (the task you're interested in).

To illustrate this, we'll use our pretrained model (on Imagenet), to solve the MNIST classification task.

There are two types of finetuning. 

#### 1. Frozen feature_extractor
In the first type we fine-tune with the FROZEN feature_extractor and NEVER unfreeze it during finetuning.


#### 2. Unfrozen feature_extractor
In the second, we finetune with a FROZEN feature_extractor for a few epochs, then unfreeze the feature extractor and finish training.


In this part we will use the first version.
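
As a rough illustration of the frozen variant, the sketch below trains only a new head on top of a frozen toy backbone. The modules, data, and hyperparameters are illustrative assumptions, not the homework's ResNet setup. Note that the backbone is put in `eval()` mode here; if a frozen model is left in `train()` mode, layers such as BatchNorm still update their running statistics.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins: a "pretrained" backbone and a fresh 10-class head.
backbone = nn.Sequential(nn.Linear(20, 32), nn.ReLU())
head = nn.Linear(32, 10)

for param in backbone.parameters():
    param.requires_grad = False
backbone.eval()  # keep layers such as BatchNorm/Dropout in inference mode

# Only the head's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(head.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()

x = torch.randn(64, 20)                 # toy inputs
y = torch.randint(0, 10, (64,))         # toy labels
backbone_before = [p.detach().clone() for p in backbone.parameters()]

losses = []
for step in range(50):
    optimizer.zero_grad()
    with torch.no_grad():               # no graph is needed through frozen layers
        features = backbone(x)
    loss = criterion(head(features), y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# The backbone weights are bit-for-bit identical after training the head.
backbone_unchanged = all(
    torch.equal(before, after.detach())
    for before, after in zip(backbone_before, backbone.parameters())
)
```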

In [None]:
import torch.optim as optim
from torch import nn
from copy import deepcopy
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def FROZEN_fine_tune_mnist(feature_extractor, fine_tune_model, mnist_train, mnist_val):
    """
    model is a feature extractor (resnet).
    Create a new model which uses those features to finetune on MNIST
    
    return the fine_tune model
    """     
    
    # INSERT YOUR CODE: (train the fine_tune model using features extracted by feature_extractor)
    for param in feature_extractor.parameters():
        param.requires_grad = False
    feature_extractor.fc = nn.Linear(512, 10)
    feature_extractor.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(feature_extractor.parameters(), lr = 0.0001)
    #do train
    highest_acc = 0
    for epoch in range(30):
        feature_extractor.train()
        correct_train = 0
        total_train = 0
        correct_val = 0
        total_val = 0
        for i, (imgs, labels) in enumerate(mnist_train):
            optimizer.zero_grad()
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = feature_extractor(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            predicted = outputs.max(1, keepdim=True)[1]
            total_train += labels.size(0)
            correct_train += predicted.eq(labels.view_as(predicted)).sum().item()
        print('Training accuracy after {} epoch = {:.{prec}f}'.format(epoch, 100 * correct_train / total_train, prec=4))
            
        # do eval
        feature_extractor.eval()
        with torch.no_grad():
            for i, (imgs, labels) in enumerate(mnist_val):
                imgs = imgs.to(device)
                labels = labels.to(device)
                outputs = feature_extractor(imgs)
                predicted = outputs.max(1, keepdim=True)[1]
                total_val += labels.size(0)
                correct_val += predicted.eq(labels.view_as(predicted)).sum().item()
                loss = criterion(outputs, labels)
        print('Validation accuracy after {} epoch = {:.{prec}f}'.format(epoch, 100 * correct_val / total_val, prec=4))
        # keep a copy of the weights with the best validation accuracy so far
        if 100 * correct_val / total_val > highest_acc:
            highest_acc = 100 * correct_val / total_val
            fine_tune_model = deepcopy(feature_extractor)
    return fine_tune_model
 

### Part C (compute test accuracy)
Compute the test accuracy of the fine-tuned model on MNIST.

In [None]:
def calculate_mnist_test_accuracy(feature_extractor, fine_tune_model, mnist_test):
    fine_tune_model.eval()
    total_test, correct_test = 0, 0
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for i, (imgs, labels) in enumerate(mnist_test):
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = fine_tune_model(imgs)
            predicted = outputs.max(1, keepdim=True)[1]
            total_test += labels.size(0)
            correct_test += predicted.eq(labels.view_as(predicted)).sum().item()
            loss = criterion(outputs, labels)
    test_accuracy = 100 * correct_test / total_test
    return test_accuracy

### Grade!
Let's see how you did

In [None]:
def grade_mnist_frozen():
    
    # init a ft model
    fine_tune_model = init_fine_tune_model()
    
    # run the transfer learning routine
    FROZEN_fine_tune_mnist(pretrained_resnet18, fine_tune_model, mnist_train, mnist_val)
    
    # calculate test accuracy
    test_accuracy = calculate_mnist_test_accuracy(pretrained_resnet18, fine_tune_model, mnist_test)
    
    # the real threshold will be released by Oct 11 
    assert test_accuracy > 0.0, 'your accuracy is too low...'
    
    return test_accuracy
    
frozen_test_accuracy = grade_mnist_frozen()

Training accuracy after 0 epoch = 52.2527
Validation accuracy after 0 epoch = 67.8000
Training accuracy after 1 epoch = 67.3145
Validation accuracy after 1 epoch = 72.1400
Training accuracy after 2 epoch = 70.1982
Validation accuracy after 2 epoch = 74.1800
Training accuracy after 3 epoch = 71.5491
Validation accuracy after 3 epoch = 75.1800
Training accuracy after 4 epoch = 72.4091
Validation accuracy after 4 epoch = 75.7600
Training accuracy after 5 epoch = 72.9673
Validation accuracy after 5 epoch = 76.2800
Training accuracy after 6 epoch = 73.4473
Validation accuracy after 6 epoch = 76.5400
Training accuracy after 7 epoch = 73.7582
Validation accuracy after 7 epoch = 76.7800
Training accuracy after 8 epoch = 74.0382
Validation accuracy after 8 epoch = 77.0400
Training accuracy after 9 epoch = 74.2145
Validation accuracy after 9 epoch = 77.1000
Training accuracy after 10 epoch = 74.3782
Validation accuracy after 10 epoch = 77.2800
Training accuracy after 11 epoch = 74.5600
Validatio

In [None]:
frozen_test_accuracy

78.45

### Part D (Fine-tune Unfrozen)
Now we'll learn how to train using the "unfrozen" approach.

In this approach we'll:
1. Keep the feature_extractor frozen for a few epochs (10).
2. Unfreeze it.
3. Finish training.
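
The three steps above can be sketched on a toy model as follows. Everything here (the layer sizes, the helper names `set_requires_grad` and `train_epochs`, the random data) is a hypothetical illustration rather than the homework solution. It relies on the fact that `torch.optim.Adam` skips parameters whose gradient is `None`, so one optimizer can be reused across both phases.

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Sequential(
    nn.Linear(20, 32), nn.ReLU(),   # hypothetical "pretrained" backbone (layers 0-1)
    nn.Linear(32, 10),              # new task head (layer 2)
)
backbone_params = list(model[0].parameters())

def set_requires_grad(params, flag):
    for p in params:
        p.requires_grad = flag

def train_epochs(model, optimizer, x, y, num_epochs):
    criterion = nn.CrossEntropyLoss()
    for _ in range(num_epochs):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    return loss.item()

x = torch.randn(64, 20)
y = torch.randint(0, 10, (64,))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
start = [p.detach().clone() for p in backbone_params]

# Phase 1: backbone frozen, only the head learns.
set_requires_grad(backbone_params, False)
train_epochs(model, optimizer, x, y, num_epochs=10)
changed_while_frozen = any(
    not torch.equal(a, b.detach()) for a, b in zip(start, backbone_params))

# Phase 2: unfreeze and keep training everything.
set_requires_grad(backbone_params, True)
train_epochs(model, optimizer, x, y, num_epochs=10)
changed_after_unfreeze = any(
    not torch.equal(a, b.detach()) for a, b in zip(start, backbone_params))
```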

In [None]:
def UNFROZEN_fine_tune_mnist(feature_extractor, fine_tune_model, mnist_train, mnist_val):
    """
    model is a feature extractor (resnet).
    Create a new model which uses those features to finetune on MNIST
    
    return the fine_tune model
    """     
    
    # INSERT YOUR CODE:
    # keep frozen for 10 epochs
    # ... train
    # unfreeze
    # train for rest of the time
    for param in feature_extractor.parameters():
        param.requires_grad = False
    feature_extractor.fc = nn.Linear(512, 10)
    feature_extractor.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(feature_extractor.parameters(), lr = 0.00001)
    #do train
    highest_acc = 0
    for epoch in range(10):
        correct_train = 0
        total_train = 0
        correct_val = 0
        total_val = 0
        feature_extractor.train()
        for i, (imgs, labels) in enumerate(mnist_train):
            optimizer.zero_grad()
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = feature_extractor(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            predicted = outputs.max(1, keepdim=True)[1]
            total_train += labels.size(0)
            correct_train += predicted.eq(labels.view_as(predicted)).sum().item()
        print('Training accuracy after {} epoch = {:.{prec}f}'.format(epoch, 100 * correct_train / total_train, prec=4))
        # do eval
        feature_extractor.eval()
        with torch.no_grad():
            for i, (imgs, labels) in enumerate(mnist_val):
                imgs = imgs.to(device)
                labels = labels.to(device)
                outputs = feature_extractor(imgs)
                predicted = outputs.max(1, keepdim=True)[1]
                total_val += labels.size(0)
                correct_val += predicted.eq(labels.view_as(predicted)).sum().item()
                loss = criterion(outputs, labels)
        # compare once per epoch, after the whole validation set has been scored
        if 100 * correct_val / total_val > highest_acc:
            highest_acc = 100 * correct_val / total_val
            fine_tune_model = deepcopy(feature_extractor)
        print('Validation accuracy after {} epoch = {:.{prec}f}'.format(epoch, 100 * correct_val / total_val, prec=4))
    for param in feature_extractor.parameters():
        param.requires_grad = True
    for epoch in range(20):
        correct_train = 0
        total_train = 0
        correct_val = 0
        total_val = 0
        feature_extractor.train()
        for i, (imgs, labels) in enumerate(mnist_train):
            optimizer.zero_grad()
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = feature_extractor(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            predicted = outputs.max(1, keepdim=True)[1]
            total_train += labels.size(0)
            correct_train += predicted.eq(labels.view_as(predicted)).sum().item()
        print('Training accuracy after {} epoch = {:.{prec}f}'.format(epoch + 10, 100 * correct_train / total_train, prec=4))
        # do eval
        feature_extractor.eval()
        with torch.no_grad():
            for i, (imgs, labels) in enumerate(mnist_val):
                imgs = imgs.to(device)
                labels = labels.to(device)
                outputs = feature_extractor(imgs)
                predicted = outputs.max(1, keepdim=True)[1]
                total_val += labels.size(0)
                correct_val += predicted.eq(labels.view_as(predicted)).sum().item()
                loss = criterion(outputs, labels)
        # compare once per epoch, after the whole validation set has been scored
        if 100 * correct_val / total_val > highest_acc:
            highest_acc = 100 * correct_val / total_val
            fine_tune_model = deepcopy(feature_extractor)
        print('Validation accuracy after {} epoch = {:.{prec}f}'.format(epoch + 10, 100 * correct_val / total_val, prec=4))
    return fine_tune_model

### Grade UNFROZEN
Let's see if there's a difference in accuracy!

In [None]:
def grade_mnist_unfrozen():
    
    # init a ft model
    fine_tune_model = init_fine_tune_model()
    
    # run the transfer learning routine
    UNFROZEN_fine_tune_mnist(pretrained_resnet18, fine_tune_model, mnist_train, mnist_val)
    
    # calculate test accuracy
    test_accuracy = calculate_mnist_test_accuracy(pretrained_resnet18, fine_tune_model, mnist_test)
    
    # the real threshold will be released by Oct 11 
    assert test_accuracy > 0.0, 'your accuracy is too low...'
    
    return test_accuracy
    
unfrozen_test_accuracy = grade_mnist_unfrozen()

Training accuracy after 0 epoch = 19.3218
Validation accuracy after 0 epoch = 30.9200
Training accuracy after 1 epoch = 37.1855
Validation accuracy after 1 epoch = 45.9400
Training accuracy after 2 epoch = 47.7109
Validation accuracy after 2 epoch = 54.3200
Training accuracy after 3 epoch = 53.6727
Validation accuracy after 3 epoch = 58.9800
Training accuracy after 4 epoch = 57.2782
Validation accuracy after 4 epoch = 62.3800
Training accuracy after 5 epoch = 59.7345
Validation accuracy after 5 epoch = 64.2800
Training accuracy after 6 epoch = 61.5836
Validation accuracy after 6 epoch = 65.9000
Training accuracy after 7 epoch = 62.9618
Validation accuracy after 7 epoch = 66.9800
Training accuracy after 8 epoch = 64.0964
Validation accuracy after 8 epoch = 68.0800
Training accuracy after 9 epoch = 65.0364
Validation accuracy after 9 epoch = 68.8200
Training accuracy after 0 epoch = 89.6564
Validation accuracy after 0 epoch = 97.1800
Training accuracy after 1 epoch = 97.0436
Validation a

In [None]:
unfrozen_test_accuracy

98.79

In [None]:
assert unfrozen_test_accuracy > frozen_test_accuracy, 'the unfrozen model should be better'

--- 
## Question 2 (train a model on Wikitext-2)

Here we'll apply what we just learned to NLP. In this section we'll make our own feature extractor and pretrain it on Wikitext-2.

The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. Wikitext-2, the small variant used here, contains roughly 2 million training tokens. The dataset is available under the Creative Commons Attribution-ShareAlike License.

#### Part A
In this section you need to generate the training, validation and test split. Feel free to use code from your previous lectures.
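
For language modeling, each sentence becomes an (input, target) pair in which the target is the input shifted by one token, and batches are padded to a common length. The sketch below shows this with plain Python lists; the helper names `make_lm_pair` and `pad_batch` and the toy vocabulary are illustrative assumptions.

```python
def make_lm_pair(tokens, token_to_id, unk_id):
    """Wrap a sentence in <bos>/<eos>, encode it, and shift by one position."""
    wrapped = ['<bos>'] + tokens + ['<eos>']
    ids = [token_to_id.get(tok, unk_id) for tok in wrapped]
    return ids[:-1], ids[1:]          # input, target


def pad_batch(sequences, pad_id):
    """Right-pad every sequence to the length of the longest one."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]


# Toy vocabulary; the ids are illustrative.
token_to_id = {'<bos>': 0, '<eos>': 1, '<pad>': 2, '<unk>': 3,
               'the': 4, 'cat': 5, 'sat': 6}

inp_a, tgt_a = make_lm_pair(['the', 'cat', 'sat'], token_to_id, unk_id=3)
inp_b, tgt_b = make_lm_pair(['the', 'dog'], token_to_id, unk_id=3)

print(inp_a, tgt_a)                          # [0, 4, 5, 6] [4, 5, 6, 1]
print(pad_batch([inp_a, inp_b], pad_id=2))   # [[0, 4, 5, 6], [0, 4, 3, 2]]
```

The out-of-vocabulary word `dog` maps to the `<unk>` id, and the shorter sequence is padded with the `<pad>` id.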

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
!pip install jsonlines

Collecting jsonlines
  Downloading https://files.pythonhosted.org/packages/d4/58/06f430ff7607a2929f80f07bfd820acbc508a4e977542fefcc522cde9dff/jsonlines-2.0.0-py3-none-any.whl
Installing collected packages: jsonlines
Successfully installed jsonlines-2.0.0


In [5]:
import os
import json
import pickle
from collections import defaultdict

import jsonlines
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm import tqdm

In [6]:
import io
def load_vectors(fname):
    fin = io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore')
    n, d = map(int, fin.readline().split())
    embedding_size = 300
    max_vocab_size = 35000
    embedding_dict = np.random.randn(max_vocab_size+2, embedding_size)
    all_train_tokens = []
    i = 0
    
    for line in fin:
        tokens = line.rstrip().split(' ')
        all_train_tokens.append(tokens[0])
        embedding_dict[i+2] = list(map(float, tokens[1:]))
        i += 1
        if i == max_vocab_size:
            break
            
    return embedding_dict, all_train_tokens
  
# download the vectors yourself
fasttext_embedding_dict, all_fasttext_tokens = load_vectors('/content/drive/My Drive/wiki-news-300d-1M.vec')
  
class LMDataset(Dataset):
    def __init__(self, list_of_token_lists):
        self.input_tensors = []
        self.target_tensors = []

        for sample in list_of_token_lists:
            self.input_tensors.append(torch.tensor([sample[:-1]], dtype=torch.long))
            self.target_tensors.append(torch.tensor([sample[1:]], dtype=torch.long))

    def __len__(self):
        return len(self.input_tensors)

    def __getitem__(self, idx):
        return (self.input_tensors[idx], self.target_tensors[idx])


def tokenize_dataset(datasets, dictionary):
    tokenized_datasets = {}
    for split, dataset in datasets.items():
        _current_dictified = []
        for l in tqdm(dataset):
            l = ['<bos>'] + l + ['<eos>']
            encoded_l = dictionary.encode_token_seq(l)
            _current_dictified.append(encoded_l)
        tokenized_datasets[split] = _current_dictified
    return tokenized_datasets

def tokenize_mnli_dataset(datasets, dictionary):
    tokenized_datasets = {}
    for split, dataset in datasets.items():
        _current_dictified = []
        for s1, s2 in tqdm(dataset):
            s1 = ['<bos>'] + s1 + ['<eos>']
            s2 = ['<bos>'] + s2 + ['<eos>']
            encoded_s1 = dictionary.encode_token_seq(s1)            
            encoded_s2 = dictionary.encode_token_seq(s2)
            _current_dictified.append([encoded_s1, encoded_s2])
        tokenized_datasets[split] = _current_dictified
    return tokenized_datasets

def pad_list_of_tensors(list_of_tensors, pad_token):
    max_length = max([t.size(-1) for t in list_of_tensors])
    padded_list = []
    for t in list_of_tensors:
        padded_tensor = torch.cat(
            [t, torch.tensor([[pad_token] * (max_length - t.size(-1))], dtype=torch.long)], dim=-1)
        padded_list.append(padded_tensor)

    padded_tensor = torch.cat(padded_list, dim=0)
    return padded_tensor


def pad_collate_fn(pad_idx, batch):
    input_list = [s[0] for s in batch]
    target_list = [s[1] for s in batch]
    input_tensor = pad_list_of_tensors(input_list, pad_idx)
    target_tensor = pad_list_of_tensors(target_list, pad_idx)
    return input_tensor, target_tensor


def load_wikitext(data_dir):
    import subprocess
    filename = os.path.join(data_dir, 'wikitext2-sentencized.json')
    if not os.path.exists(filename):
        os.makedirs(data_dir, exist_ok=True)
        url = "https://nyu.box.com/shared/static/9kb7l7ci30hb6uahhbssjlq0kctr5ii4.json"
        args = ['wget', '-O', filename, url]
        subprocess.call(args)
    raw_datasets = json.load(open(filename, 'r'))
    for name in raw_datasets:
        raw_datasets[name] = [x.split() for x in raw_datasets[name]]

    if os.path.exists(os.path.join(data_dir, 'vocab.pkl')):
        vocab = pickle.load(open(os.path.join(data_dir, 'vocab.pkl'), 'rb'))
    else:
        vocab = Dictionary(raw_datasets, include_valid=False)
        pickle.dump(vocab, open(os.path.join(data_dir, 'vocab.pkl'), 'wb'))

    tokenized_datasets = tokenize_dataset(raw_datasets, vocab)
    datasets = {name: LMDataset(ds) for name, ds in tokenized_datasets.items()}
    print("Vocab size: %d" % (len(vocab)))
    return raw_datasets, datasets, vocab


class Dictionary(object):
    def __init__(self, datasets, include_valid=False):
        self.tokens = []
        self.ids = {}
        self.counts = {}
        
        # add special tokens
        self.add_token('<bos>')
        self.add_token('<eos>')
        self.add_token('<pad>')
        self.add_token('<unk>')
        
        for line in tqdm(datasets['train']):
            for w in line:
                self.add_token(w)
                    
        if include_valid is True:
            for line in tqdm(datasets['valid']):
                for w in line:
                    self.add_token(w)
                            
    def add_token(self, w):
        # membership test on the dict is O(1); scanning the token list is O(n)
        if w not in self.ids:
            self.tokens.append(w)
            _w_id = len(self.tokens) - 1
            self.ids[w] = _w_id
            self.counts[w] = 1
        else:
            self.counts[w] += 1

    def get_id(self, w):
        return self.ids[w]
    
    def get_token(self, idx):
        return self.tokens[idx]
    
    def decode_idx_seq(self, l):
        return [self.tokens[i] for i in l]
    
    def encode_token_seq(self, l):
        return [self.ids[i] if i in self.ids else self.ids['<unk>'] for i in l]
    
    def __len__(self):
        return len(self.tokens)
 
raw_datasets, datasets, vocab = load_wikitext(os.getcwd())

data_loaders = {name: DataLoader(datasets[name], batch_size=32, shuffle=True,
                                     collate_fn=lambda x: pad_collate_fn(vocab.get_id('<pad>'), x))
                    for name in datasets}
wk2_train_dataloader = data_loaders['train']
wk2_val_dataloader = data_loaders['valid']
wk2_test_dataloader = data_loaders['test']

100%|██████████| 78274/78274 [02:00<00:00, 651.65it/s]
100%|██████████| 78274/78274 [00:00<00:00, 105856.78it/s]
100%|██████████| 8464/8464 [00:00<00:00, 124147.03it/s]
100%|██████████| 9708/9708 [00:00<00:00, 124416.17it/s]


Vocab size: 33178


In [7]:
def init_wikitext_dataset():
    """
    Return the Wikitext-2 train, validation and test DataLoaders built above.
    """
    wikitext_val = wk2_val_dataloader
    wikitext_train = wk2_train_dataloader
    wikitext_test = wk2_test_dataloader
    
    return wikitext_train, wikitext_val, wikitext_test

#### Part B   
Here we design our own feature extractor. In MNIST that was a resnet because we were dealing with images. Now we need to pick a model that can model sequences better. Design an RNN-based model here.
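
Before wiring up a full model, it can help to check the tensor shapes that flow through an embedding, a GRU, and a projection layer. The sizes below are small illustrative values, not the ones used in this homework.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Small illustrative sizes (assumptions, chosen only to keep the example fast).
vocab_size, embedding_dim, hidden_size = 50, 8, 16

lookup = nn.Embedding(vocab_size, embedding_dim, padding_idx=2)
gru = nn.GRU(embedding_dim, hidden_size, num_layers=2, batch_first=True)
projection = nn.Linear(hidden_size, vocab_size)

token_ids = torch.randint(0, vocab_size, (4, 6))   # (batch, seq_len)
embeddings = lookup(token_ids)                     # (batch, seq_len, embedding_dim)
outputs, last_hidden = gru(embeddings)             # per-step outputs, final hidden states
logits = projection(outputs)                       # (batch, seq_len, vocab_size)

print(embeddings.shape, outputs.shape, last_hidden.shape, logits.shape)
```

With `batch_first=True` the per-step outputs are batch-major, while the final hidden state keeps the layout `(num_layers, batch, hidden_size)`. The per-step outputs are the representations a downstream task can reuse.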

In [None]:
import torch.nn as nn

class GRU_LM(nn.Module):
    """
    This model combines embedding, gru and projection layer into a single model
    """
    def __init__(self, options):
        super().__init__()
        
        # create each LM part here 
        self.lookup = nn.Embedding(num_embeddings=options['num_embeddings'], embedding_dim=options['embedding_dim'], padding_idx=options['padding_idx'])
        self.gru = nn.GRU(options['input_size'], options['hidden_size'], options['num_layers'], dropout=options['gru_dropout'], batch_first=True)
        self.projection = nn.Linear(options['hidden_size'], options['num_embeddings'])
        
    def forward(self, encoded_input_sequence):
        """
        Forward method process the input from token ids to logits
        """
        embeddings = self.lookup(encoded_input_sequence)
        #self.gru.flatten_parameters()
        gru_outputs = self.gru(embeddings)
        logits = self.projection(gru_outputs[0])
        
        return logits, gru_outputs[0]

In [None]:
def init_feature_extractor():
    embedding_size = 300
    hidden_size = 1024
    num_layers = 2
    gru_dropout = 0.2

    options = {
        'num_embeddings': len(vocab),
        'embedding_dim': embedding_size,
        'padding_idx': vocab.get_id('<pad>'),
        'input_size': embedding_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'gru_dropout': gru_dropout,
    }

    feature_extractor = GRU_LM((options)).to(device)
    return feature_extractor

#### Part C
Pretrain the feature extractor
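
When batches are padded, the language-modeling loss should be averaged over real tokens only. The sketch below, on random toy logits, checks that summing the loss with `ignore_index` and dividing by the number of non-pad tokens matches `reduction='mean'` with the same `ignore_index`. The padding id and tensor sizes are illustrative assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)

pad_id = 2                                    # illustrative padding id
logits = torch.randn(2, 5, 7)                 # (batch, seq_len, vocab_size)
targets = torch.tensor([[4, 5, 6, 1, pad_id],
                        [3, 1, pad_id, pad_id, pad_id]])

flat_logits = logits.view(-1, logits.size(-1))
flat_targets = targets.view(-1)

# Sum the loss over real tokens, then divide by how many there are.
sum_criterion = nn.CrossEntropyLoss(ignore_index=pad_id, reduction='sum')
num_tokens = flat_targets.ne(pad_id).sum().item()
per_token_loss = sum_criterion(flat_logits, flat_targets) / num_tokens

# reduction='mean' with ignore_index averages over the same non-pad tokens.
mean_criterion = nn.CrossEntropyLoss(ignore_index=pad_id, reduction='mean')
mean_loss = mean_criterion(flat_logits, flat_targets)

print(num_tokens, torch.allclose(per_token_loss, mean_loss))
```

Keeping the summed loss and the token count separately, as in the first variant, also makes it easy to accumulate an exact corpus-level average across batches of different sizes.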

In [None]:
num_gpus = torch.cuda.device_count()
if num_gpus > 0:
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
import numpy as np
import torch.optim as optim
from copy import deepcopy
def fit_feature_extractor(feature_extractor, wikitext_train, wikitext_val):
    best_feature_extractor = None
    lowest_loss = float('inf')
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.get_id('<pad>'), reduction='sum')
    feature_extractor_parameters = [p for p in feature_extractor.parameters() if p.requires_grad]
    optimizer = optim.Adam(feature_extractor_parameters, lr=0.0001)
    for epoch_number in range(20):
      avg_loss=0
      feature_extractor.train()
      train_loss_cache = 0
      train_non_pad_tokens_cache = 0
      for i, (inp, target) in enumerate(wikitext_train):
          optimizer.zero_grad()
          inp = inp.to(device)
          target = target.to(device)
          logits = feature_extractor(inp)[0]
          loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
          train_loss_cache += loss.item()
          non_pad_tokens = target.view(-1).ne(vocab.get_id('<pad>')).sum().item()
          train_non_pad_tokens_cache += non_pad_tokens
          loss /= non_pad_tokens
          loss.backward()
          optimizer.step()
          if i % 100 == 0:
              avg_loss = train_loss_cache / train_non_pad_tokens_cache
              print('Step {} avg train loss = {:.{prec}f}'.format(i, avg_loss, prec=4))
              train_log_cache = []
            
      #do valid
      valid_loss_cache = 0
      valid_non_pad_tokens_cache = 0
      feature_extractor.eval()
      with torch.no_grad():
          for i, (inp, target) in enumerate(wikitext_val):
              inp = inp.to(device)
              target = target.to(device)
              logits = feature_extractor(inp)[0]
              loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
              valid_loss_cache += loss.item()
              non_pad_tokens = target.view(-1).ne(vocab.get_id('<pad>')).sum().item()
              valid_non_pad_tokens_cache += non_pad_tokens
          avg_val_loss = valid_loss_cache / valid_non_pad_tokens_cache
          if avg_val_loss < lowest_loss:
              lowest_loss = avg_val_loss
              best_feature_extractor = deepcopy(feature_extractor)
          print('Validation loss after {} epoch = {:.{prec}f}'.format(epoch_number, avg_val_loss, prec=4))
    return best_feature_extractor
          


In [None]:
wikitext_train, wikitext_val, wikitext_test = init_wikitext_dataset()
feature_extractor = fit_feature_extractor(init_feature_extractor(), wikitext_train, wikitext_val)

Step 0 avg train loss = 10.4097
Step 100 avg train loss = 7.9352
Step 200 avg train loss = 7.4803
Step 300 avg train loss = 7.2839
Step 400 avg train loss = 7.1609
Step 500 avg train loss = 7.0709
Step 600 avg train loss = 6.9968
Step 700 avg train loss = 6.9388
Step 800 avg train loss = 6.8864
Step 900 avg train loss = 6.8423
Step 1000 avg train loss = 6.8006
Step 1100 avg train loss = 6.7628
Step 1200 avg train loss = 6.7271
Step 1300 avg train loss = 6.6957
Step 1400 avg train loss = 6.6670
Step 1500 avg train loss = 6.6370
Step 1600 avg train loss = 6.6100
Step 1700 avg train loss = 6.5830
Step 1800 avg train loss = 6.5592
Step 1900 avg train loss = 6.5364
Step 2000 avg train loss = 6.5136
Step 2100 avg train loss = 6.4926
Step 2200 avg train loss = 6.4721
Step 2300 avg train loss = 6.4526
Step 2400 avg train loss = 6.4346
Validation loss after 0 epoch = 5.8344
Step 0 avg train loss = 5.8996
Step 100 avg train loss = 5.9147
Step 200 avg train loss = 5.9128
Step 300 avg train loss =

In [None]:
torch.save(feature_extractor.state_dict(), "/content/drive/My Drive/feature_extractor")

In [None]:
feature_extractor = init_feature_extractor()
feature_extractor.load_state_dict(torch.load("/content/drive/My Drive/feature_extractor"))

<All keys matched successfully>

#### Part D
Calculate the test perplexity on wikitext2. Feel free to recycle code from previous assignments from this class. 
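
Perplexity is the exponential of the average per-token negative log-likelihood. The small sketch below uses a made-up uniform model (probability 1/4 per token) to show the computation, and that the natural-log and base-2 forms agree; the function name `perplexity` is just for illustration.

```python
import math


def perplexity(total_nll, num_tokens):
    """Perplexity from a summed negative log-likelihood measured in nats."""
    return math.exp(total_nll / num_tokens)


# A toy model that assigns probability 1/4 to each of 10 tokens.
num_tokens = 10
total_nll = num_tokens * -math.log(0.25)
ppl = perplexity(total_nll, num_tokens)
print(round(ppl, 6))   # 4.0

# Same value in base 2: convert nats to bits, then exponentiate with 2.
ppl_base2 = 2 ** ((total_nll / num_tokens) / math.log(2))
```

A perplexity of 4 here means the model is as uncertain as if it were choosing uniformly among 4 tokens at every step.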

In [None]:
def calculate_wiki2_test_perplexity(feature_extractor, wikitext_test):
    test_loss_cache = 0
    test_non_pad_tokens_cache = 0
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.get_id('<pad>'), reduction='sum')
    feature_extractor.eval()
    with torch.no_grad():
          for i, (inp, target) in enumerate(wikitext_test):
              inp = inp.to(device)
              target = target.to(device)
              logits = feature_extractor(inp)[0]
              loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
              test_loss_cache += loss.item()
              non_pad_tokens = target.view(-1).ne(vocab.get_id('<pad>')).sum().item()
              test_non_pad_tokens_cache += non_pad_tokens
          avg_test_loss = test_loss_cache / test_non_pad_tokens_cache
          # perplexity = exp(average per-token loss in nats), written here in base 2
          test_ppl = 2**(avg_test_loss/np.log(2))
    return test_ppl

#### Let's grade your results!
(don't touch this part)

In [None]:
def grade_wikitext2():
    # load data
    wikitext_train, wikitext_val, wikitext_test = init_wikitext_dataset()

    # load feature extractor
    #feature_extractor = init_feature_extractor()

    # pretrain using the feature extractor
    #feature_extractor = fit_feature_extractor(feature_extractor, wikitext_train, wikitext_val)

    # check test perplexity
    test_ppl = calculate_wiki2_test_perplexity(feature_extractor, wikitext_test)

    # the real threshold will be released by Oct 11 
    assert test_ppl < 10000, 'ummm... your perplexity is too high...'
    print(test_ppl)
grade_wikitext2()

173.39674515304023


---   
## Question 3 (fine-tune on MNLI)
In this question you will use your feature_extractor from question 2
to fine-tune on MNLI.

(From the website):
The Multi-Genre Natural Language Inference (MultiNLI) corpus is a crowd-sourced collection of 433k sentence pairs annotated with textual entailment information. The corpus is modeled on the SNLI corpus, but differs in that it covers a range of genres of spoken and written text, and supports a distinctive cross-genre generalization evaluation. The corpus served as the basis for the shared task of the RepEval 2017 Workshop at EMNLP in Copenhagen.

MNLI has 3 classes (contradiction, neutral, entailment); each example is also tagged with a genre.
The goal of this question is to maximize the test accuracy on MNLI.

### Part A
In this section you need to generate the training, validation and test split. Feel free to use code from your previous lectures.
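
Each MNLI example is a pair of sentences plus a label, so preprocessing has to encode both sentences and map the label to an integer. The sketch below shows one way to do that with a toy vocabulary; `encode_pair` and the example sentences are illustrative and not taken from MNLI.

```python
label_map = {'contradiction': 0, 'neutral': 1, 'entailment': 2}


def encode_pair(premise, hypothesis, label, token_to_id, unk_id):
    """Whitespace-tokenize both sentences and map tokens and label to ids."""
    premise_ids = [token_to_id.get(tok, unk_id) for tok in premise.split()]
    hypothesis_ids = [token_to_id.get(tok, unk_id) for tok in hypothesis.split()]
    return premise_ids, hypothesis_ids, label_map[label]


# Toy vocabulary and example; both are made up for illustration.
token_to_id = {'<pad>': 0, '<unk>': 1, 'a': 2, 'dog': 3, 'runs': 4,
               'an': 5, 'animal': 6, 'moves': 7}
example = encode_pair('a dog runs', 'an animal moves quickly',
                      'entailment', token_to_id, unk_id=1)
print(example)   # ([2, 3, 4], [5, 6, 7, 1], 2)
```

Words missing from the vocabulary (here `quickly`) fall back to the `<unk>` id, which matters when the vocabulary comes from a different corpus than the fine-tuning data.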

In [None]:
y_label_map = {'contradiction':0,'neutral':1,'entailment':2}
import pandas as pd
import io
from collections import Counter

def get_string_tokenized_data(data):
    
    tokenized_data_x = []
    y_labels = []
    all_tokens = []
    
    for i,x in enumerate(data):
        label = x[2]
        if label == 'nan':
          continue
        
        label = y_label_map[label]
        y_labels.append(label)
        
        dp = [x[0].split(), x[1].split()]
        tokenized_data_x.append(dp)
        all_tokens += (dp[0] + dp[1])
        

    return all_tokens, tokenized_data_x, y_labels
        


# convert token to id in the dataset
def token2index_dataset(tokens_data, token2id):
    indices_data = []
    for tokens1, tokens2 in tokens_data:
        index_list1 = [token2id[token] if token in token2id else UNK_IDX for token in tokens1]
        index_list2 = [token2id[token] if token in token2id else UNK_IDX for token in tokens2]
        indices_data.append([index_list1, index_list2])
    return indices_data
  
# convert token to id using the Wikitext-2 vocabulary from question 2
def token2index_using_wikitext2_dict(tokens_data, vocab):
    indices_data = []
    for tokens1, tokens2 in tokens_data:
        index_list1 = vocab.encode_token_seq(tokens1)
        index_list2 = vocab.encode_token_seq(tokens2)
        indices_data.append([index_list1, index_list2])
    return indices_data


def build_vocab(all_tokens):
    # Requires MAX_VOCAB_SIZE, PAD_IDX and UNK_IDX to be defined.
    # Returns:
    # token2id: dictionary where keys represent tokens and corresponding values represent indices
    # id2token: list of tokens, where id2token[i] returns the token that corresponds to index i
    token_counter = Counter(all_tokens)
    vocab, count = zip(*token_counter.most_common(MAX_VOCAB_SIZE))
    id2token = list(vocab)
    # ids 0 and 1 are reserved for <pad> and <unk>
    token2id = dict(zip(vocab, range(2, 2 + len(vocab))))
    id2token = ['<pad>', '<unk>'] + id2token
    token2id['<pad>'] = PAD_IDX
    token2id['<unk>'] = UNK_IDX
    return token2id, id2token
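
As a quick sanity check of the vocabulary logic above, here is a minimal, self-contained sketch on a toy corpus. The sentences and the names `build_toy_vocab`, `encode`, and `toy_pairs` are hypothetical, and `PAD_IDX = 0` / `UNK_IDX = 1` are assumed values; this is not the assignment's reference implementation.

```python
from collections import Counter

PAD_IDX, UNK_IDX = 0, 1  # assumed special-token ids


def build_toy_vocab(all_tokens, max_vocab_size):
    """Keep the most frequent tokens; ids 0 and 1 are reserved for <pad>/<unk>."""
    counts = Counter(all_tokens)
    most_common = [tok for tok, _ in counts.most_common(max_vocab_size)]
    id2token = ['<pad>', '<unk>'] + most_common
    token2id = {tok: idx for idx, tok in enumerate(id2token)}
    return token2id, id2token


def encode(tokens, token2id):
    """Map tokens to ids, falling back to UNK_IDX for unseen tokens."""
    return [token2id.get(tok, UNK_IDX) for tok in tokens]


# hypothetical (premise, hypothesis) pairs, already whitespace-tokenized
toy_pairs = [
    ("a man is sleeping".split(), "a man is awake".split()),
    ("a dog runs".split(), "a dog is moving".split()),
]
all_tokens = [tok for prem, hyp in toy_pairs for tok in prem + hyp]

token2id, id2token = build_toy_vocab(all_tokens, max_vocab_size=5)
encoded = encode("a cat is sleeping".split(), token2id)
print(encoded)  # "cat" never appears in the toy corpus, so it maps to UNK_IDX
```

Reserving the first two ids keeps padding and unknown tokens distinct from real words, which matters once sequences are padded with index 0.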

In [None]:
# LOAD VAL
val_df = pd.read_csv('/content/drive/My Drive/mnli_val.tsv', sep="\t")
print(val_df.head(2))

val_df  = np.array(val_df)
val_df = val_df.astype(str)
val_genre_list = val_df[:, 3]

                                           sentence1  ...      genre
0  'Not entirely , ' I snapped , harsher than int...  ...    fiction
1  cook and then the next time it would be my tur...  ...  telephone

[2 rows x 4 columns]


In [None]:
_, val_data_x, val_data_y = get_string_tokenized_data(val_df)
del val_df

In [None]:
train_df = pd.read_csv('/content/drive/My Drive/mnli_train.tsv', sep="\t", chunksize=5000)
train_genre_list = []
train_data_x = []
train_data_y = []

for chunk in train_df:

  chunk  = np.array(chunk)
  chunk = chunk.astype(str)
  train_genre_list_chunk = chunk[:, 3]
  train_genre_list.append(train_genre_list_chunk)
  
  _, train_data_x_chunk, train_data_y_chunk = get_string_tokenized_data(chunk)
  train_data_x.append(train_data_x_chunk)
  train_data_y.append(train_data_y_chunk)

train_genre_list = np.concatenate(train_genre_list)
train_data_x = np.concatenate(train_data_x)
train_data_y = np.concatenate(train_data_y)

In [None]:
mnli_raw_datasets = {'train': train_data_x, 'val': val_data_x}
mnli_tokenized_datasets = tokenize_mnli_dataset(mnli_raw_datasets, vocab)

train_data_indices = mnli_tokenized_datasets['train']
val_data_indices = mnli_tokenized_datasets['val']


# double checking
print('\n')
print ("Train dataset size is {}".format(len(train_data_indices)))
print ("Val dataset size is {}".format(len(val_data_indices)))

# del train_data_x
# del train_data_y
# del val_data_x
# del val_data_y
del mnli_tokenized_datasets

100%|██████████| 20000/20000 [00:00<00:00, 62570.08it/s]
100%|██████████| 5000/5000 [00:00<00:00, 67275.49it/s]



Train dataset size is 20000
Val dataset size is 5000





In [None]:
unique_genre = list(set(val_genre_list))
nb_classes = len(y_label_map)

In [None]:
from torch.utils.data import Dataset

MAX_SENTENCE_LENGTH = 200

class MNLIDataset(Dataset):
    """
    Class that represents a train/validation/test dataset that's readable for PyTorch
    Note that this class inherits torch.utils.data.Dataset
    """

    def __init__(self, data_x, target_list):
        """
        @param data_x: list of [premise_token_ids, hypothesis_token_ids] pairs
        @param target_list: list of integer labels (see y_label_map)

        """
        self.data_x = data_x
        self.target_list = target_list

        assert(len(data_x) == len(target_list))

    def __len__(self):
        return len(self.target_list)

    def __getitem__(self, key):
        """
        Triggered when you call dataset[i]
        """
        prem_token_idx = self.data_x[key][0][:MAX_SENTENCE_LENGTH]
        hyp_token_idx = self.data_x[key][1][:MAX_SENTENCE_LENGTH]
        label = self.target_list[key]
        return [prem_token_idx, hyp_token_idx, label]


def encode_collate_func(batch):
    """
    Customized function for DataLoader that pads every premise and hypothesis
    to MAX_SENTENCE_LENGTH so that all data in the batch have the same length
    """
    prem_data_list = []
    hyp_data_list = []
    label_list = []
    for datum in batch:
        label_list.append(datum[2])
        # right-pad with index 0; sequences were already truncated in __getitem__
        prem_padded_vec = np.pad(np.array(datum[0]),
                                 pad_width=(0, MAX_SENTENCE_LENGTH - len(datum[0])),
                                 mode="constant", constant_values=0)
        hyp_padded_vec = np.pad(np.array(datum[1]),
                                pad_width=(0, MAX_SENTENCE_LENGTH - len(datum[1])),
                                mode="constant", constant_values=0)
        prem_data_list.append(prem_padded_vec)
        hyp_data_list.append(hyp_padded_vec)
    return [torch.from_numpy(np.array(prem_data_list)), torch.from_numpy(np.array(hyp_data_list)),
            torch.LongTensor(label_list)]
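
The truncate-then-pad idea behind the dataset and collate function can be sketched without any tensor library. The names `pad_or_truncate` and `collate`, the tiny `MAX_LEN = 6`, and the toy batch are assumptions chosen for readability; the real code uses `MAX_SENTENCE_LENGTH = 200` and NumPy.

```python
MAX_LEN = 6   # hypothetical small value so the result is easy to read
PAD_IDX = 0   # assumed padding index


def pad_or_truncate(ids, max_len=MAX_LEN, pad_idx=PAD_IDX):
    """Clip to max_len, then right-pad with pad_idx up to max_len."""
    clipped = ids[:max_len]
    return clipped + [pad_idx] * (max_len - len(clipped))


def collate(batch):
    """Turn a list of (premise_ids, hypothesis_ids, label) into three aligned lists."""
    prems = [pad_or_truncate(p) for p, _, _ in batch]
    hyps = [pad_or_truncate(h) for _, h, _ in batch]
    labels = [y for _, _, y in batch]
    return prems, hyps, labels


# toy batch: the second premise is longer than MAX_LEN and gets truncated
batch = [
    ([5, 8, 2], [7, 3], 1),
    ([4, 9, 6, 2, 11, 13, 7, 3], [5], 0),
]
prems, hyps, labels = collate(batch)
print(prems)   # every row now has exactly MAX_LEN entries
print(hyps)
print(labels)
```

Padding to a fixed length is simple but wastes computation on short batches; padding to the longest sequence in each batch is a common alternative.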

In [None]:
from torch.utils.data import random_split

BATCH_SIZE = 32
nb_train_samples = int(0.95 * len(train_data_indices))
nb_val_samples = len(train_data_indices) - nb_train_samples

# train/val split
train_val_dataset = MNLIDataset(train_data_indices, train_data_y)
train_dataset, val_dataset = random_split(train_val_dataset, [nb_train_samples, nb_val_samples])

# train loader
train_mnli_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=BATCH_SIZE,
                                           collate_fn=encode_collate_func,
                                           shuffle=True)

# val loader
val_mnli_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                           batch_size=BATCH_SIZE,
                                           collate_fn=encode_collate_func,
                                           shuffle=True)

# test loader
test_dataset = MNLIDataset(val_data_indices, val_data_y)
test_mnli_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                           batch_size=BATCH_SIZE,
                                           collate_fn=encode_collate_func,
                                           shuffle=True)


In [None]:
# note: len(loader) is the number of batches (of size BATCH_SIZE), not the number of samples
print('MNLI dataset statistics')
print(f'training samples:{len(train_mnli_loader)}')
print(f'val samples:{len(val_mnli_loader)}')
print(f'test samples:{len(test_mnli_loader)}')

MNLI dataset statistics
training samples:594
val samples:32
test samples:157
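
The numbers printed above are batch counts rather than sample counts, because `len()` of a `DataLoader` returns the number of batches. A small sketch of the arithmetic, using the dataset sizes reported earlier (20000 training pairs, 5000 validation pairs), the 95/5 split, and a batch size of 32; the helper names `split_sizes` and `num_batches` are hypothetical.

```python
import math


def split_sizes(n_total, train_fraction):
    """Sizes of the two parts, computed the same way as in the split cell."""
    n_train = int(train_fraction * n_total)
    return n_train, n_total - n_train


def num_batches(n_samples, batch_size):
    """Number of batches when the last, smaller batch is kept."""
    return math.ceil(n_samples / batch_size)


n_train, n_val = split_sizes(20000, 0.95)
n_test = 5000  # the provided validation file is used as the test set

print(num_batches(n_train, 32))  # 594
print(num_batches(n_val, 32))    # 32
print(num_batches(n_test, 32))   # 157
```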


In [None]:
from torchtext.datasets import MultiNLI

def init_mnli_dataset():
    """
    Fill in the details
    """
    mnli_val = val_mnli_loader
    mnli_train = train_mnli_loader
    mnli_test = test_mnli_loader
    
    return mnli_train, mnli_val, mnli_test

### Part B
Here we again design a model for finetuning. Use the output of your feature-extractor as the input to this model. This should be a powerful classifier (up to you).

In [None]:
import torch.nn as nn

class MultiNLI_Classifier(nn.Module):
    def __init__(self, options):
        super().__init__()
        # two-layer MLP classifier on top of the feature extractor output
        self.fc1 = nn.Linear(768, 256)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(256, options['num_classes'])
        
    def forward(self, feature_extractor_output):
        feature_extractor_output = feature_extractor_output.view(feature_extractor_output.size(0), -1)
        output = self.fc1(feature_extractor_output)
        output = self.relu(output)
        output = self.fc2(output)
        return output

In [None]:
def init_finetune_model():
    options = {
        #'hidden_size': 1024,
        'num_classes': 3,
    }
    fine_tune_model = MultiNLI_Classifier(options).to(device)
    return fine_tune_model

### Part C
Use the feature_extractor and your fine_tune_model to fine-tune on MNLI.

In [None]:
from copy import deepcopy

def fine_tune_mnli(feature_extractor, fine_tune_model, mnli_train, mnli_val):
    # YOUR CODE HERE
    highest_acc = 0
    best_feature_extractor = feature_extractor
    best_classifier = None
    # the feature extractor is fine-tuned together with the classifier
    for param in feature_extractor.parameters():
        param.requires_grad = True
    # targets are class labels (0-2), so no ignore_index is used here: ignoring the
    # <pad> id would silently drop every example whose label equals that id
    criterion = nn.CrossEntropyLoss(reduction='sum')
    fine_tune_model_parameters = [p for p in fine_tune_model.parameters() if p.requires_grad]
    feature_extractor_parameters = [p for p in feature_extractor.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(fine_tune_model_parameters + feature_extractor_parameters, lr=0.0001)
    # do train
    for epoch in range(10):
        # both models go back to training mode (eval mode is set at the end of each epoch)
        feature_extractor.train()
        fine_tune_model.train()
        correct_train = 0
        total_train = 0
        correct_val = 0
        total_val = 0
        for i, (inp1, inp2, target) in enumerate(mnli_train):
            optimizer.zero_grad()
            inp1 = inp1.to(device)
            inp2 = inp2.to(device)
            target = target.to(device)
            # encode premise and hypothesis separately, then concatenate the two representations
            feature_extractor_output1, feature_extractor_output2 = feature_extractor(inp1)[1], feature_extractor(inp2)[1]
            feature_extractor_output = torch.cat([feature_extractor_output1, feature_extractor_output2], dim=1)
            outputs = fine_tune_model(feature_extractor_output)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()
            predicted = outputs.max(1, keepdim=True)[1]
            total_train += target.size(0)
            correct_train += predicted.eq(target.view_as(predicted)).sum().item()
        print('Training accuracy after {} epoch = {:.{prec}f}'.format(epoch, 100 * correct_train / total_train, prec=4))
        # do eval
        feature_extractor.eval()
        fine_tune_model.eval()
        with torch.no_grad():
            for i, (inp1, inp2, target) in enumerate(mnli_val):
                inp1 = inp1.to(device)
                inp2 = inp2.to(device)
                target = target.to(device)
                feature_extractor_output1, feature_extractor_output2 = feature_extractor(inp1)[1], feature_extractor(inp2)[1]
                feature_extractor_output = torch.cat([feature_extractor_output1, feature_extractor_output2], dim=1)
                outputs = fine_tune_model(feature_extractor_output)
                predicted = outputs.max(1, keepdim=True)[1]
                total_val += target.size(0)
                correct_val += predicted.eq(target.view_as(predicted)).sum().item()
            # keep a copy of the models with the best validation accuracy so far
            if correct_val/total_val > highest_acc:
                highest_acc = correct_val/total_val
                best_feature_extractor = deepcopy(feature_extractor)
                best_classifier = deepcopy(fine_tune_model)
        print('Validation accuracy after {} epoch = {:.{prec}f}'.format(epoch, 100 * correct_val / total_val, prec=4))
    return best_feature_extractor, best_classifier
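
The reason the loop above stores a `deepcopy` of the best models, rather than a plain reference, can be shown with a toy stand-in. In this sketch the "model" is a plain dict and `update_fn` mimics one epoch of training; both, along with the accuracy values, are hypothetical.

```python
from copy import deepcopy


def select_best(model, val_accuracies, update_fn):
    """Train for len(val_accuracies) epochs and return a snapshot of the best epoch."""
    best_acc, best_model = 0.0, None
    for epoch, acc in enumerate(val_accuracies):
        update_fn(model, epoch)            # stands in for one epoch of training
        if acc > best_acc:
            best_acc = acc
            best_model = deepcopy(model)   # snapshot; a plain reference would keep changing
    return best_acc, best_model


def fake_epoch(model, epoch):
    """Hypothetical training step that overwrites the weights in place."""
    model['weights'][0] = float(epoch + 1)


model = {'weights': [0.0]}
best_acc, best_model = select_best(model, [0.55, 0.57, 0.56], fake_epoch)

print(best_acc)               # accuracy of the best epoch
print(best_model['weights'])  # weights as they were at the best epoch
print(model['weights'])       # weights after the final epoch
```

Without the copy, `best_model` would alias the live model and end up holding the last epoch's weights instead of the best ones.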

### Part D
Evaluate the test accuracy

In [None]:
def calculate_mnli_test_accuracy(feature_extractor, fine_tune_model, mnli_test):
    correct_test = 0
    total_test = 0
    feature_extractor.eval()
    fine_tune_model.eval()
    # no optimizer or loss is needed at test time, only predictions
    with torch.no_grad():
        for i, (inp1, inp2, target) in enumerate(mnli_test):
            inp1 = inp1.to(device)
            inp2 = inp2.to(device)
            target = target.to(device)
            feature_extractor_output1, feature_extractor_output2 = feature_extractor(inp1)[1], feature_extractor(inp2)[1]
            feature_extractor_output = torch.cat([feature_extractor_output1, feature_extractor_output2], dim=1)
            outputs = fine_tune_model(feature_extractor_output)
            # predicted class = index of the largest logit
            predicted = outputs.max(1, keepdim=True)[1]
            total_test += target.size(0)
            correct_test += predicted.eq(target.view_as(predicted)).sum().item()
    return correct_test/total_test
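
The accuracy computation above reduces to "take the argmax of each row of logits and compare it with the label". A dependency-free sketch of that step, with made-up logits and targets; the helper names `argmax` and `accuracy` are hypothetical.

```python
def argmax(row):
    """Index of the largest value in a list of scores."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy(logits, targets):
    """Fraction of rows whose highest-scoring class equals the target."""
    correct = sum(1 for row, y in zip(logits, targets) if argmax(row) == y)
    return correct / len(targets)


# hypothetical scores for 4 examples over the 3 MNLI classes
logits = [
    [2.0, 0.1, -1.0],   # predicts class 0
    [0.3, 0.2, 1.5],    # predicts class 2
    [0.0, 0.9, 0.4],    # predicts class 1
    [1.2, 1.1, 0.0],    # predicts class 0
]
targets = [0, 2, 2, 0]

print(accuracy(logits, targets))  # 0.75
```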

### Let's grade your results

In [None]:
def grade_mnli():
    # load data
    mnli_train, mnli_val, mnli_test = init_mnli_dataset()

    # load the feature extractor from question 2; it is fine-tuned below
    feature_extractor = init_feature_extractor()

    # init the fine_tune model
    fine_tune_model = init_finetune_model()

    # finetune and keep the models with the best validation accuracy
    feature_extractor, fine_tune_model = fine_tune_mnli(feature_extractor, fine_tune_model, mnli_train, mnli_val)

    # check test accuracy
    test_accuracy = calculate_mnli_test_accuracy(feature_extractor, fine_tune_model, mnli_test)

    # the real threshold will be released by Oct 11
    assert test_accuracy > 0.00, 'ummm... your accuracy is too low...'
    print(test_accuracy)

grade_mnli()

RuntimeError: ignored

---  
## Question 4 (BERT)

A major direction in research came from a model called BERT, released in 2018.  

In this question you'll use BERT as your feature_extractor instead of the model you
designed yourself.

To get BERT, head on over to (https://github.com/huggingface/transformers) and load your BERT model here

In [None]:
!pip install transformers



### Part A (init BERT)
In this section you need to create an instance of BERT and return it from the function.

In [None]:
from transformers.data.processors.glue import MnliProcessor
import torch
import pandas as pd
import os
import sys
import shutil
import argparse
import tempfile
import urllib.request
import zipfile
from transformers import glue_convert_examples_to_features as convert_examples_to_features
from transformers import BertModel, BertTokenizer
# random_split is needed for the train/val split in generate_mnli_bert_dataloaders
from torch.utils.data import TensorDataset, RandomSampler, DataLoader, random_split

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
bert = BertModel.from_pretrained('bert-base-cased', output_attentions=True)

In [None]:
TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
TASK2PATH = {
    "CoLA": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4",  # noqa
    "SST": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8",  # noqa
    "MRPC": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc",  # noqa
    "QQP": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP-clean.zip?alt=media&token=11a647cb-ecd3-49c9-9d31-79f8ca8fe277",  # noqa
    "STS": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5",  # noqa
    "MNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce",  # noqa
    "SNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df",  # noqa
    "QNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601",  # noqa
    "RTE": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb",  # noqa
    "WNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf",  # noqa
    "diagnostic": [
        "https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D",  # noqa
        "https://www.dropbox.com/s/ju7d95ifb072q9f/diagnostic-full.tsv?dl=1",
    ],
}

MRPC_TRAIN = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt"
MRPC_TEST = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt"


def download_and_extract(task, data_dir):
    print("Downloading and extracting %s..." % task)
    data_file = "%s.zip" % task
    urllib.request.urlretrieve(TASK2PATH[task], data_file)
    with zipfile.ZipFile(data_file) as zip_ref:
        zip_ref.extractall(data_dir)
    os.remove(data_file)
    print("\tCompleted!")

In [None]:
download_and_extract('MNLI', '.')

Downloading and extracting MNLI...
	Completed!


In [None]:
processor = MnliProcessor()

In [None]:

def generate_mnli_bert_dataloaders():
  # ----------------------
  # TRAIN/VAL DATALOADERS
  # ----------------------
  train = processor.get_train_examples('MNLI')
  features = convert_examples_to_features(train,
                                          tokenizer,
                                          label_list=['contradiction','neutral','entailment'],
                                          max_length=128,
                                          output_mode='classification',
                                          pad_on_left=False,
                                          pad_token=tokenizer.pad_token_id,
                                          pad_token_segment_id=0)
  train_dataset = TensorDataset(torch.tensor([f.input_ids for f in features], dtype=torch.long), 
                                torch.tensor([f.attention_mask for f in features], dtype=torch.long), 
                                torch.tensor([f.token_type_ids for f in features], dtype=torch.long), 
                                torch.tensor([f.label for f in features], dtype=torch.long))

  nb_train_samples = int(0.95 * len(train_dataset))
  nb_val_samples = len(train_dataset) - nb_train_samples

  bert_mnli_train_dataset, bert_mnli_val_dataset = random_split(train_dataset, [nb_train_samples, nb_val_samples])

  # train loader
  train_sampler = RandomSampler(bert_mnli_train_dataset)
  bert_mnli_train_dataloader = DataLoader(bert_mnli_train_dataset, sampler=train_sampler, batch_size=32)

  # val loader
  val_sampler = RandomSampler(bert_mnli_val_dataset)
  bert_mnli_val_dataloader = DataLoader(bert_mnli_val_dataset, sampler=val_sampler, batch_size=32)


  # ----------------------
  # TEST DATALOADERS
  # ----------------------
  dev = processor.get_dev_examples('MNLI')
  features = convert_examples_to_features(dev,
                                          tokenizer,
                                          label_list=['contradiction','neutral','entailment'],
                                          max_length=128,
                                          output_mode='classification',
                                          pad_on_left=False,
                                          pad_token=tokenizer.pad_token_id,
                                          pad_token_segment_id=0)

  bert_mnli_test_dataset = TensorDataset(torch.tensor([f.input_ids for f in features], dtype=torch.long), 
                                torch.tensor([f.attention_mask for f in features], dtype=torch.long), 
                                torch.tensor([f.token_type_ids for f in features], dtype=torch.long), 
                                torch.tensor([f.label for f in features], dtype=torch.long))

  # test dataset
  test_sampler = RandomSampler(bert_mnli_test_dataset)
  bert_mnli_test_dataloader = DataLoader(bert_mnli_test_dataset, sampler=test_sampler, batch_size=32)
  
  return bert_mnli_train_dataloader, bert_mnli_val_dataloader, bert_mnli_test_dataloader

In [None]:
bert_mnli_train_dataloader, bert_mnli_val_dataloader, bert_mnli_test_dataloader = generate_mnli_bert_dataloaders()

In [None]:
class BERTSequence(nn.Module):
    def __init__(self, bert):
        super().__init__()
        self.bert = bert
        
    def forward(self, input_ids, attention_mask, token_type_ids):
        h, _, attn = self.bert(input_ids=input_ids, 
                               attention_mask=attention_mask, 
                               token_type_ids=token_type_ids)
        h_cls = h[:, 0]
        return h_cls, attn

In [None]:
from transformers import BertModel

def init_bert():
    # load the pretrained weights and wrap them so that forward returns the [CLS] vector
    bert = BertModel.from_pretrained('bert-base-cased', output_attentions=True)
    BERT = BERTSequence(bert)
    return BERT

In [None]:
# fall back to the CPU when no GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
BERT_feature_extractor = init_bert()

In [None]:
BERT_feature_extractor

BERTSequence(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)

In [None]:
BERT_feature_extractor = BERT_feature_extractor.to(device)

### Part B (fine-tune with BERT)

Use BERT as your feature extractor to finetune MNLI. Use a new finetune model (reset weights).

In [None]:
from copy import deepcopy
def fine_tune_mnli_BERT(BERT_feature_extractor, fine_tune_model, bert_mnli_train_dataloader, bert_mnli_val_dataloader):
    highest_acc = 0
    best_feature_extractor = BERT_feature_extractor
    best_classifier = None
    BERT_feature_extractor = BERT_feature_extractor.to(device)
    fine_tune_model = fine_tune_model.to(device)
    # freeze BERT: only the classifier on top is trained
    for param in BERT_feature_extractor.parameters():
        param.requires_grad = False
    # keep the frozen BERT in eval mode so that dropout does not perturb its features
    BERT_feature_extractor.eval()
    criterion = nn.CrossEntropyLoss()
    fine_tune_model_parameters = [p for p in fine_tune_model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(fine_tune_model_parameters, lr=0.0001)
    # do train
    for epoch in range(5):
        fine_tune_model.train()
        correct_train = 0
        total_train = 0
        correct_val = 0
        total_val = 0
        for i, (inp, att_mask, token_type_ids, target) in enumerate(bert_mnli_train_dataloader):
            optimizer.zero_grad()
            inp, att_mask, token_type_ids, target = inp.to(device), att_mask.to(device), token_type_ids.to(device), target.to(device)
            # no gradients are needed through the frozen feature extractor
            with torch.no_grad():
                feature_extractor_output = BERT_feature_extractor(inp, att_mask, token_type_ids)[0]
            outputs = fine_tune_model(feature_extractor_output)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()
            predicted = outputs.max(1, keepdim=True)[1]
            total_train += target.size(0)
            correct_train += predicted.eq(target.view_as(predicted)).sum().item()
        print('Training accuracy after {} epoch = {:.{prec}f}'.format(epoch, 100 * correct_train / total_train, prec=4))
        # do eval
        fine_tune_model.eval()
        with torch.no_grad():
            for i, (inp, att_mask, token_type_ids, target) in enumerate(bert_mnli_val_dataloader):
                inp, att_mask, token_type_ids, target = inp.to(device), att_mask.to(device), token_type_ids.to(device), target.to(device)
                feature_extractor_output = BERT_feature_extractor(inp, att_mask, token_type_ids)[0]
                outputs = fine_tune_model(feature_extractor_output)
                predicted = outputs.max(1, keepdim=True)[1]
                total_val += target.size(0)
                correct_val += predicted.eq(target.view_as(predicted)).sum().item()
            # BERT is frozen, so only the classifier needs to be snapshotted
            if correct_val/total_val > highest_acc:
                highest_acc = correct_val/total_val
                best_classifier = deepcopy(fine_tune_model)
        print('Validation accuracy after {} epoch = {:.{prec}f}'.format(epoch, 100 * correct_val / total_val, prec=4))
    return best_feature_extractor, best_classifier

### Part C
Evaluate how well we did

In [None]:
def calculate_mnli_test_accuracy_BERT(feature_extractor, fine_tune_model, mnli_test):

    # YOUR CODE HERE...
    correct_test = 0
    total_test = 0
    feature_extractor.eval()
    fine_tune_model.eval()
    # use the function arguments rather than globals; no optimizer or loss is needed here
    with torch.no_grad():
        for i, (inp, att_mask, token_type_ids, target) in enumerate(mnli_test):
            inp, att_mask, token_type_ids, target = inp.to(device), att_mask.to(device), token_type_ids.to(device), target.to(device)
            feature_extractor_output = feature_extractor(inp, att_mask, token_type_ids)[0]
            outputs = fine_tune_model(feature_extractor_output)
            predicted = outputs.max(1, keepdim=True)[1]
            total_test += target.size(0)
            correct_test += predicted.eq(target.view_as(predicted)).sum().item()
    return correct_test/total_test

### Let's grade your BERT results!

In [None]:
def grade_mnli_BERT():
    BERT_feature_extractor = init_bert()

    # load data
    mnli_train, mnli_val, mnli_test = bert_mnli_train_dataloader, bert_mnli_val_dataloader, bert_mnli_test_dataloader

    # init the fine_tune model
    fine_tune_model = init_finetune_model()

    # finetune and keep the classifier with the best validation accuracy
    BERT_feature_extractor, fine_tune_model = fine_tune_mnli_BERT(BERT_feature_extractor, fine_tune_model, mnli_train, mnli_val)

    # check test accuracy
    test_accuracy = calculate_mnli_test_accuracy_BERT(BERT_feature_extractor, fine_tune_model, mnli_test)

    # the real threshold will be released by Oct 11
    assert test_accuracy > 0.0, 'ummm... your accuracy is too low...'
    print(test_accuracy)

grade_mnli_BERT()

Training accuracy after 0 epoch = 51.5584
Validation accuracy after 0 epoch = 55.0163
Training accuracy after 1 epoch = 53.6993
Validation accuracy after 1 epoch = 54.7056
Training accuracy after 2 epoch = 54.3212
Validation accuracy after 2 epoch = 56.9719
