In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [0]:
import os
os.chdir("/content/drive/My Drive/hw3")

---   
# HW3 - Transfer learning

#### Due October 30, 2019

In this assignment you will learn about transfer learning, one of the most widely used techniques in industry. When the problem you want to solve does not have enough data, we use a different (larger) dataset to learn representations that help us solve the task on the smaller dataset.

The general steps to transfer learning are as follows:

1. Find a huge dataset with similar characteristics to the problem you are interested in.
2. Choose a model powerful enough to extract meaningful representations from the huge dataset.
3. Train this model on the huge dataset.
4. Fine-tune this model on the smaller dataset.


### This homework has the following sections:
1. Question 1: MNIST fine-tuning (Parts A, B, C, D)
2. Question 2: Pretrain on WikiText-2 (Parts A, B, C, D)
3. Question 3: Fine-tune on MNLI (Parts A, B, C, D)
4. Question 4: Fine-tune using pretrained BERT (Parts A, B, C)

---   
## Question 1 (MNIST transfer learning)
To grasp the high-level approach to transfer learning, let's first do a simple example using computer vision. 

The torchvision library provides models (ResNets, VGG nets, etc.) pretrained on the ImageNet dataset. The ImageNet version used for these models contains about 1.3 million training images covering 1000 object classes. When you load one of these models with `pretrained=True`, its weights are initialized from a checkpoint trained on ImageNet.

In this task we will:
1. Choose a pretrained model.
2. Freeze the model so that the weights don't change.
3. Fine-tune on a few labels of MNIST.   

#### Choose a model
Here we pick one of the pretrained models from torchvision (ResNet-18).

In [0]:
import torch
num_gpus = torch.cuda.device_count()
if num_gpus > 0:
    current_device = 'cuda'
else:
    current_device = 'cpu'

In [4]:
import torchvision.models as models

class Identity(torch.nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x

# init the pretrained feature extractor
pretrained_resnet18 = models.resnet18(pretrained=True).to(current_device)

# we don't want the built in last layer, we're going to modify it ourselves
pretrained_resnet18.fc = Identity()

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/checkpoints/resnet18-5c106cde.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 56.8MB/s]


In [5]:
pretrained_resnet18.fc

Identity()

#### Freeze the model
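Here we freeze the weights of the model. Freezing sets `requires_grad = False` on every parameter, so no gradients are computed for these weights and the optimizer never updates them.

By doing this you can think of the model as a fixed feature extractor. This feature extractor outputs a **representation** of an input. With the final `fc` layer replaced by `Identity`, ResNet-18 maps each image to a 512-dimensional vector, so a batch of images becomes a `(batch_size, 512)` matrix that encodes information about the inputs.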
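To make "freezing" concrete, here is a minimal sketch that uses a small hypothetical `nn.Sequential` as a stand-in for the pretrained ResNet-18 (so it runs without downloading weights). After one optimizer step it checks that the frozen extractor received no gradients and kept its weights, while the batch representation is a `(batch_size, 512)` matrix.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the pretrained ResNet-18 (fc replaced by Identity):
# any module that maps an input to a 512-dim feature vector behaves the same way.
feature_extractor = nn.Sequential(nn.Linear(784, 512), nn.ReLU())
head = nn.Linear(512, 10)  # new task-specific layer

# Freeze: no gradients will be computed for the extractor's parameters
for p in feature_extractor.parameters():
    p.requires_grad = False

before = [p.clone() for p in feature_extractor.parameters()]

# Only trainable parameters are handed to the optimizer
trainable = [p for p in list(feature_extractor.parameters()) + list(head.parameters())
             if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)

x = torch.randn(32, 784)         # a batch of 32 flattened "images"
y = torch.randint(0, 10, (32,))  # random labels

features = feature_extractor(x)  # (32, 512) representation matrix
loss = nn.functional.cross_entropy(head(features), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(features.shape)                                                # torch.Size([32, 512])
print(all(p.grad is None for p in feature_extractor.parameters()))   # True
print(all(torch.equal(b, p) for b, p in zip(before, feature_extractor.parameters())))  # True
```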

In [0]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False
        
def unfreeze_model(model):
    for param in model.parameters():
        param.requires_grad = True
        
freeze_model(pretrained_resnet18)

#### Init target dataset
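Here we define the dataset we are actually interested in: MNIST. Its images are single-channel (grayscale), so `transforms.Grayscale(3)` replicates the channel three times to match the 3-channel input that ResNet-18 expects.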

In [0]:
import os
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

# train/val split
transform = transforms.Compose([transforms.Grayscale(3),
    transforms.ToTensor()
])
mnist_dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform)
mnist_train, mnist_val = random_split(mnist_dataset, [55000, 5000])

mnist_train = DataLoader(mnist_train, batch_size=32)
mnist_val = DataLoader(mnist_val, batch_size=32)

# test split
mnist_test = DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transform), batch_size=32)

In [8]:
for images, labels in mnist_train:  
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break

Image batch dimensions: torch.Size([32, 3, 28, 28])
Image label dimensions: torch.Size([32])


### Part A (init fine-tune model)
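Decide what model to use for fine-tuning. Here we use a small MLP head that maps the 512-dimensional ResNet-18 features to the 10 MNIST classes.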

In [0]:
class MLP(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super(MLP, self).__init__()
        self.fc1 = torch.nn.Linear(D_in, H)
        self.relu = torch.nn.ReLU()
        self.projection = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        out = self.fc1(x)
        out = self.relu(out)
        out = self.projection(out)
        return out

In [0]:
def init_fine_tune_model():
    
    # YOUR CODE HERE
    fine_tune_model = MLP(512,256,10).to(current_device)
    return fine_tune_model

### Part B (Fine-tune (Frozen))

The actual problem we care about solving likely has a different number of classes, or is a different task altogether. Fine-tuning is the process of using the extracted representations (features) to solve this downstream task (the task you're interested in).

To illustrate this, we'll use our model pretrained on ImageNet to solve the MNIST classification task.

There are two types of fine-tuning.

#### 1. Frozen feature_extractor
In the first type we fine-tune with the FROZEN feature_extractor and NEVER unfreeze it.


#### 2. Unfrozen feature_extractor
In the second, we fine-tune with a FROZEN feature_extractor for a few epochs, then unfreeze the feature extractor and finish training.


In this part we will use the first version.

In [0]:
import numpy as np

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, unfrozen, patience=5, verbose=False, delta=0, name='RNN_'):
        """
        Args:
            unfrozen (bool): If True, also save the feature extractor's state_dict
                            alongside the fine-tune model.
            patience (int): How long to wait after last time validation loss improved.
                            Default: 5
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            name (str): Prefix for the checkpoint file names.
                            Default: 'RNN_'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
        self.delta = delta
        self.name = name
        self.unfrozen = unfrozen

    def __call__(self, val_loss, model, extractor):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, extractor)
        elif score < self.best_score + self.delta:  # no improvement of at least `delta`
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, extractor)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, extractor):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.name+'checkpoint.pt')
        if self.unfrozen:
          torch.save(extractor.state_dict(), self.name+'extractor_checkpoint.pt')
        self.val_loss_min = val_loss
        
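patience = 5
# unfrozen=True also saves the extractor checkpoint: its weights are frozen, but its
# BatchNorm running statistics still change while it is in train() mode
early_stopping = EarlyStopping(unfrozen=True, patience=patience, verbose=True, name="MLP_frozen_")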


In [0]:
import torch.optim as optim
plot_cache = []
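def FROZEN_fine_tune_mnist(feature_extractor, fine_tune_model, mnist_train, mnist_val, num_epochs=10):
    """
    feature_extractor is a frozen pretrained ResNet-18.
    Train fine_tune_model on top of its features to classify MNIST.
    Uses the global `optimizer` and `criterion` defined in the next cell.

    Returns the fine-tuned model.
    """

    for epoch_number in range(num_epochs):
        avg_loss = 0
        # do train
        # train() keeps BatchNorm in training mode: the extractor's weights stay
        # frozen, but its running mean/variance still adapt to MNIST
        feature_extractor.train()
        fine_tune_model.train()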
        train_log_cache = []
        for i, (inp, target) in enumerate(mnist_train):
            optimizer.zero_grad()
            inp = inp.to(current_device)
            target = target.to(current_device)
            out = feature_extractor(inp)
            logits = fine_tune_model(out)
            
            loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
            
            loss.backward()
            optimizer.step()
            
            train_log_cache.append(loss.item())
            
            if i % 100 == 0:
                avg_loss = sum(train_log_cache)/len(train_log_cache)
                print('Step {} avg train loss = {:.{prec}f}'.format(i, avg_loss, prec=4))
                train_log_cache = []
            
        #do valid
        valid_losses = []
        feature_extractor.eval()
        fine_tune_model.eval()
        with torch.no_grad():
            for i, (inp, target) in enumerate(mnist_val):
                inp = inp.to(current_device)
                target = target.to(current_device)
                out = feature_extractor(inp)
                logits = fine_tune_model(out)

                loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
                valid_losses.append(loss.item())
            avg_val_loss = sum(valid_losses) / len(valid_losses)
            print('Validation loss after {} epoch = {:.{prec}f}'.format(epoch_number, avg_val_loss, prec=4))
        
        plot_cache.append((avg_loss, avg_val_loss))
    
        early_stopping(avg_val_loss, fine_tune_model,feature_extractor)
        torch.save({
            'loss_cache': plot_cache,
        }, './fine_tune_model_frozen.pt')

        if early_stopping.early_stop:
            print("Early stopping")
            break

    return fine_tune_model

In [0]:
fine_tune_model = init_fine_tune_model()
fine_tune_model.to(current_device)
pretrained_resnet18 = models.resnet18(pretrained=True).to(current_device)
pretrained_resnet18.fc = Identity()
freeze_model(pretrained_resnet18)
feature_extractor = pretrained_resnet18.to(current_device)
criterion = torch.nn.CrossEntropyLoss()
fine_tune_model_parameters = list(feature_extractor.parameters()) + list(fine_tune_model.parameters())
optimizer = optim.Adam(filter(lambda p: p.requires_grad, fine_tune_model_parameters))

### Part C (compute test accuracy)
Compute the test accuracy of fine-tuned model on MNIST
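The accuracy computation only needs the index of the largest logit in each row. A small sketch with made-up logits (not MNIST outputs) shows that applying softmax first does not change the prediction, because softmax is monotonic within a row:

```python
import torch

# Hypothetical logits for 4 examples and 3 classes, plus their true labels
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0],
                       [1.0, 4.0, 0.0],
                       [0.3, 0.2, 0.1]])
labels = torch.tensor([0, 2, 0, 0])

# softmax preserves the ordering within each row, so the winning class is the same
pred_from_logits = logits.argmax(dim=1)
pred_from_probs = torch.softmax(logits, dim=1).argmax(dim=1)

correct = (pred_from_logits == labels).sum().item()
accuracy = 100 * correct / labels.size(0)
print(pred_from_logits.tolist(), accuracy)  # [0, 2, 1, 0] 75.0
```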

In [0]:
def calculate_mnist_test_accuracy(feature_extractor, fine_tune_model, mnist_test):
    """Return the test accuracy (in percent) of feature_extractor + fine_tune_model."""
    correct = 0
    total = 0
    feature_extractor.eval()
    fine_tune_model.eval()
    # no gradients are needed for evaluation
    with torch.no_grad():
        for data, labels in mnist_test:
            data = data.to(current_device)
            labels = labels.to(current_device)
            logits = fine_tune_model(feature_extractor(data))
            # softmax is monotonic, so the argmax of the logits is the predicted class
            predicted = logits.argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

In [15]:
# before fine-tuning, the MLP head is randomly initialized, so accuracy is near chance (10%)
calculate_mnist_test_accuracy(feature_extractor, fine_tune_model, mnist_test)

8.68

### Grade!
Let's see how you did

In [0]:
load_pretrained = True

In [0]:
def grade_mnist_frozen():
    
    # # init a ft model
    # fine_tune_model = init_fine_tune_model()
    
    # run the transfer learning routine
    if not load_pretrained:
        FROZEN_fine_tune_mnist(feature_extractor, fine_tune_model, mnist_train, mnist_val,num_epochs=10)
    
    feature_extractor_load = models.resnet18(pretrained=False)
    feature_extractor_load.fc = Identity()
    # map_location lets GPU-saved checkpoints load on a CPU-only machine
    feature_extractor_load.load_state_dict(torch.load("MLP_frozen_extractor_checkpoint.pt", map_location=current_device))
    feature_extractor_load.to(current_device)

    fine_tune_model_load = init_fine_tune_model()
    fine_tune_model_load.load_state_dict(torch.load("MLP_frozen_checkpoint.pt", map_location=current_device))
    fine_tune_model_load.to(current_device)
    # calculate test accuracy
    test_accuracy = calculate_mnist_test_accuracy(feature_extractor_load, fine_tune_model_load, mnist_test)
    
    # the real threshold will be released by Oct 11 
    assert test_accuracy > 0.0, 'your accuracy is too low...'
    
    return test_accuracy
    
frozen_test_accuracy = grade_mnist_frozen()

In [18]:
frozen_test_accuracy

78.99

In [19]:
!md5sum "MLP_frozen_extractor_checkpoint.pt"
!md5sum "MLP_frozen_checkpoint.pt"

f8e8d97515bc8b3458ff69158b9e2c15  MLP_frozen_extractor_checkpoint.pt
2f6dafdf2a7fedc33287032f638f5c20  MLP_frozen_checkpoint.pt


### Part D (Fine-tune Unfrozen)
Now we'll learn how to train using the "unfrozen" approach.

In this approach we'll:
1. Keep the feature_extractor frozen for a few epochs (10).
2. Unfreeze it.
3. Finish training with the whole network (feature_extractor + MLP) trainable.
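The frozen-then-unfrozen schedule can be sketched as below, assuming hypothetical `nn.Linear` stand-ins for ResNet-18 and the MLP head and hypothetical helper names. The key detail is that an optimizer built from only the parameters with `requires_grad=True` knows nothing about parameters that were frozen at creation time, so it has to be rebuilt (or extended) after unfreezing.

```python
import torch
import torch.nn as nn

def set_requires_grad(module, flag):
    for p in module.parameters():
        p.requires_grad = flag

def make_optimizer(*modules, lr=1e-3):
    # Only parameters that currently require gradients are optimized
    params = [p for m in modules for p in m.parameters() if p.requires_grad]
    return torch.optim.Adam(params, lr=lr)

def count_trainable(*modules):
    return sum(p.numel() for m in modules for p in m.parameters() if p.requires_grad)

extractor = nn.Linear(784, 512)  # hypothetical stand-in for ResNet-18
head = nn.Linear(512, 10)        # hypothetical stand-in for the MLP head

# Phase 1: frozen extractor, only the head is trained
set_requires_grad(extractor, False)
optimizer = make_optimizer(extractor, head)
phase1 = count_trainable(extractor, head)

# ... train for a few epochs with `optimizer` ...

# Phase 2: unfreeze, then rebuild the optimizer so it also updates the extractor
set_requires_grad(extractor, True)
optimizer = make_optimizer(extractor, head)
phase2 = count_trainable(extractor, head)

print(phase1, phase2)  # 5130 407050
```

Rebuilding matches what the notebook does; `torch.optim.Optimizer.add_param_group` is an alternative that keeps the head's existing optimizer state. A smaller learning rate for phase 2 is a common variation, although the notebook keeps Adam's default.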

In [0]:
pretrained_resnet18_unfrozen = models.resnet18(pretrained=True).to(current_device)
pretrained_resnet18_unfrozen.fc = Identity()
freeze_model(pretrained_resnet18_unfrozen)


fine_tune_model_2 = init_fine_tune_model()
fine_tune_model_2.to(current_device)
feature_extractor_2 = pretrained_resnet18_unfrozen.to(current_device)
patience = 5
early_stopping_unfrozen = EarlyStopping(unfrozen=True, patience=patience, verbose=True,name="MLP_unfrozen_")

In [0]:
def UNFROZEN_fine_tune_mnist(feature_extractor, fine_tune_model, mnist_train, mnist_val,num_epochs=50):
    """
    model is a feature extractor (resnet).
    Create a new model which uses those features to finetune on MNIST
    
    return the fine_tune model
    """     
    
    # INSERT YOUR CODE:
    # keep frozen for 10 epochs
    # ... train
    # unfreeze
    # train for rest of the time
    criterion_unfrozen = torch.nn.CrossEntropyLoss()
    fine_tune_model_parameters_unfrozen = list(feature_extractor.parameters()) + list(fine_tune_model.parameters())
    optimizer_unfrozen = optim.Adam(filter(lambda p: p.requires_grad, fine_tune_model_parameters_unfrozen))
    for epoch_number in range(10):
        avg_loss=0
        fine_tune_model.train()
        feature_extractor.train()
        train_log_cache = []
        for i, (inp, target) in enumerate(mnist_train):
            optimizer_unfrozen.zero_grad()
            inp = inp.to(current_device)
            target = target.to(current_device)
            inp = feature_extractor(inp)
            logits = fine_tune_model(inp)
            
            loss = criterion_unfrozen(logits.view(-1, logits.size(-1)), target.view(-1))
            
            loss.backward()
            optimizer_unfrozen.step()
            
            train_log_cache.append(loss.item())
            
            if i % 100 == 0:
                avg_loss = sum(train_log_cache)/len(train_log_cache)
                print('Step {} avg train loss = {:.{prec}f}'.format(i, avg_loss, prec=4))
                train_log_cache = []
            
        #do valid
        valid_losses = []
        fine_tune_model.eval()
        feature_extractor.eval()
        with torch.no_grad():
            for i, (inp, target) in enumerate(mnist_val):
                inp = inp.to(current_device)
                target = target.to(current_device)
                inp = feature_extractor(inp)
                logits = fine_tune_model(inp)

                loss = criterion_unfrozen(logits.view(-1, logits.size(-1)), target.view(-1))
                valid_losses.append(loss.item())
            avg_val_loss = sum(valid_losses) / len(valid_losses)
            print('Validation loss after {} epoch = {:.{prec}f}'.format(epoch_number, avg_val_loss, prec=4))
        
        plot_cache.append((avg_loss, avg_val_loss))
    
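    # unfreeze, then rebuild the optimizer: the first one only received the
    # parameters that required gradients at the time (the MLP head)
    unfreeze_model(feature_extractor)
    fine_tune_model_parameters_unfrozen2 = list(feature_extractor.parameters()) + list(fine_tune_model.parameters())
    optimizer_unfrozen2 = optim.Adam(filter(lambda p: p.requires_grad, fine_tune_model_parameters_unfrozen2))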
    
    
    for epoch_number in range(num_epochs):
        avg_loss=0
        fine_tune_model.train()
        feature_extractor.train()
        train_log_cache = []
        for i, (inp, target) in enumerate(mnist_train):
            optimizer_unfrozen2.zero_grad()
            inp = inp.to(current_device)
            target = target.to(current_device)
            inp = feature_extractor(inp)
            logits = fine_tune_model(inp)
            
            loss = criterion_unfrozen(logits.view(-1, logits.size(-1)), target.view(-1))
            
            loss.backward()
            optimizer_unfrozen2.step()
            
            train_log_cache.append(loss.item())
            
            if i % 100 == 0:
                avg_loss = sum(train_log_cache)/len(train_log_cache)
                print('Step {} avg train loss = {:.{prec}f}'.format(i, avg_loss, prec=4))
                train_log_cache = []
        #do valid
        valid_losses = []
        fine_tune_model.eval()
        feature_extractor.eval()
        with torch.no_grad():
            for i, (inp, target) in enumerate(mnist_val):
                inp = inp.to(current_device)
                target = target.to(current_device)
                inp = feature_extractor(inp)
                logits = fine_tune_model(inp)

                loss = criterion_unfrozen(logits.view(-1, logits.size(-1)), target.view(-1))
                valid_losses.append(loss.item())
            avg_val_loss = sum(valid_losses) / len(valid_losses)
            print('Validation loss after {} epoch = {:.{prec}f}'.format(epoch_number, avg_val_loss, prec=4))
        
        plot_cache.append((avg_loss, avg_val_loss))
                
        early_stopping_unfrozen(avg_val_loss, fine_tune_model,feature_extractor)
        torch.save({
            'loss_cache': plot_cache
        }, './fine_tune_model_unfrozen.pt')
    
    
        if early_stopping_unfrozen.early_stop:
            print("Early stopping")
            break
    

### Grade UNFROZEN
Let's see if there's a difference in accuracy!

In [0]:
load_pretrained=True

In [0]:
def grade_mnist_unfrozen():
    
    # init a ft model
    # fine_tune_model = init_fine_tune_model()
    
    # run the transfer learning routine
    if not load_pretrained:
        UNFROZEN_fine_tune_mnist(feature_extractor_2, fine_tune_model_2, mnist_train, mnist_val)
    
    feature_extractor_2_load = models.resnet18(pretrained=False)
    feature_extractor_2_load.fc = Identity()
    # map_location lets GPU-saved checkpoints load on a CPU-only machine
    feature_extractor_2_load.load_state_dict(torch.load("MLP_unfrozen_extractor_checkpoint.pt", map_location=current_device))
    feature_extractor_2_load.to(current_device)

    fine_tune_model_2_load = init_fine_tune_model()
    fine_tune_model_2_load.load_state_dict(torch.load("MLP_unfrozen_checkpoint.pt", map_location=current_device))
    fine_tune_model_2_load.to(current_device)
    # calculate test accuracy
    test_accuracy = calculate_mnist_test_accuracy(feature_extractor_2_load, fine_tune_model_2_load, mnist_test)
    
    # the real threshold will be released by Oct 11 
    assert test_accuracy > 0.0, 'your accuracy is too low...'
    
    return test_accuracy
    
unfrozen_test_accuracy = grade_mnist_unfrozen()

In [0]:
assert unfrozen_test_accuracy > frozen_test_accuracy, 'the unfrozen model should be better'

In [25]:
print("unfrozen test accuracy is {0}".format(unfrozen_test_accuracy))

unfrozen test accuracy is 99.16


In [26]:
!md5sum "MLP_unfrozen_extractor_checkpoint.pt"
!md5sum "MLP_unfrozen_checkpoint.pt"

f3d6386b92407b070d43f1746d7f0879  MLP_unfrozen_extractor_checkpoint.pt
d52bad4deb6fc31f323ce48780cb99e6  MLP_unfrozen_checkpoint.pt


--- 
# Question 2 (train a model on Wikitext-2)

Here we'll apply what we just learned to NLP. In this section we'll build our own feature extractor and pretrain it on WikiText-2.

The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. WikiText-2 is its small version, with roughly 2 million training tokens. The dataset is available under the Creative Commons Attribution-ShareAlike License.

#### Part A
In this section you need to generate the training, validation and test splits. Feel free to reuse code from previous lectures.
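A minimal sketch of the vocabulary and encoding steps on toy data (the `splits` dict and the helper names are hypothetical). The special-token order matches the notebook's `Dictionary` (`<bos>`=0, `<eos>`=1, `<pad>`=2, `<unk>`=3), and one `<bos>` is prepended as `tokenize_dataset_wikitext` does with `ngram_order=2`. Unlike the notebook's `Dictionary`, which also adds validation and test tokens, this version builds the vocabulary from the training split only, so unseen words map to `<unk>`.

```python
from collections import Counter

# Hypothetical toy splits: each split is a list of tokenized sentences
splits = {
    "train": [["the", "cat", "sat"], ["the", "dog", "ran"]],
    "valid": [["the", "cat", "ran"]],
    "test":  [["a", "bird", "sang"]],
}

SPECIALS = ["<bos>", "<eos>", "<pad>", "<unk>"]

def build_vocab(sentences):
    # A dict gives O(1) membership checks, unlike scanning a Python list
    token_to_id = {tok: i for i, tok in enumerate(SPECIALS)}
    counts = Counter()
    for sent in sentences:
        for tok in sent:
            counts[tok] += 1
            if tok not in token_to_id:
                token_to_id[tok] = len(token_to_id)
    return token_to_id, counts

def encode(sent, token_to_id):
    # Prepend <bos>, append <eos>, and map out-of-vocabulary words to <unk>
    unk = token_to_id["<unk>"]
    return [token_to_id.get(t, unk) for t in ["<bos>"] + sent + ["<eos>"]]

vocab, counts = build_vocab(splits["train"])
encoded = {name: [encode(s, vocab) for s in sents] for name, sents in splits.items()}
print(encoded["valid"])  # [[0, 4, 5, 8, 1]]
print(encoded["test"])   # [[0, 3, 3, 3, 1]]
```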

In [27]:
!pip install jsonlines
import pickle
import os
import io
import json
import jsonlines
from tqdm import tqdm
from collections import defaultdict
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch
import torch.optim

Collecting jsonlines
  Downloading https://files.pythonhosted.org/packages/4f/9a/ab96291470e305504aa4b7a2e0ec132e930da89eb3ca7a82fbe03167c131/jsonlines-1.2.0-py2.py3-none-any.whl
Installing collected packages: jsonlines
Successfully installed jsonlines-1.2.0


In [0]:
class Dictionary(object):
    def __init__(self, datasets, include_valid=False):
        self.tokens = []
        self.ids = {}
        self.counts = {}
        
        # add special tokens
        self.add_token('<bos>')
        self.add_token('<eos>')
        self.add_token('<pad>')
        self.add_token('<unk>')
        
        for line in tqdm(datasets['train']):
            for w in line:
                self.add_token(w)
                    
        if include_valid is True:
            for line in tqdm(datasets['valid']):
                for w in line:
                    self.add_token(w)
        # include test
        for line in tqdm(datasets['test']):
            for w in line:
                self.add_token(w)
        
        
    def add_token(self, w):
        if w not in self.ids:  # O(1) dict lookup instead of scanning the token list
            self.tokens.append(w)
            _w_id = len(self.tokens) - 1
            self.ids[w] = _w_id
            self.counts[w] = 1
        else:
            self.counts[w] += 1

    def get_id(self, w):
        return self.ids[w]
    
    def get_token(self, idx):
        return self.tokens[idx]
    
    def decode_idx_seq(self, l):
        return [self.tokens[i] for i in l]
    
    def encode_token_seq(self, l):
        return [self.ids[i] if i in self.ids else self.ids['<unk>'] for i in l]
    
    def __len__(self):
        return len(self.tokens)

In [0]:
def tokenize_dataset_wikitext(datasets, dictionary, ngram_order=2):
    tokenized_datasets = {}
    for split, dataset in datasets.items():
        _current_dictified = []
        for l in tqdm(dataset):
            l = ['<bos>']*(ngram_order-1) + l + ['<eos>']
            encoded_l = dictionary.encode_token_seq(l)
            _current_dictified.append(encoded_l)
        tokenized_datasets[split] = _current_dictified
        
    return tokenized_datasets

In [30]:
def load_wikitext(filename='wikitext2-sentencized.json'):
    if not os.path.exists(filename):
        !wget "https://nyu.box.com/shared/static/9kb7l7ci30hb6uahhbssjlq0kctr5ii4.json" -O $filename

    with open(filename, 'r') as f:
        datasets = json.load(f)
    for name in datasets:
        datasets[name] = [x.split() for x in datasets[name]]
    vocab = list(set([t for ts in datasets['train'] for t in ts]))
    print("Vocab size: %d" % (len(vocab)))
    return datasets, vocab

datasets, vocab = load_wikitext()
wikitext_dict = Dictionary(datasets, include_valid=True)
wikitext_dict = Dictionary(datasets, include_valid=True)


Vocab size: 33175


100%|██████████| 78274/78274 [02:00<00:00, 647.92it/s]
100%|██████████| 8464/8464 [00:09<00:00, 880.13it/s]
100%|██████████| 9708/9708 [00:10<00:00, 919.04it/s]


In [0]:
def init_wikitext_dataset(datasets):
    """
    Return the raw train, validation and test splits (each a list of tokenized sentences).
    """
    raw_train = datasets["train"]
    raw_val = datasets["valid"]
    raw_test = datasets["test"]

    return raw_train, raw_val, raw_test


In [0]:
wikitext_train,wikitext_val,wikitext_test = init_wikitext_dataset(datasets)

#### Part B   
Here we design our own feature extractor. For MNIST we used a ResNet because we were dealing with images; now we need a model suited to sequences. Design an RNN-based model here.
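A small sketch of the tensor shapes inside an LSTM language model like `LSTMLanguageModel`, using hypothetical toy sizes. It also shows which tensor plays the role of the "feature": the per-position LSTM outputs, before the vocabulary projection.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy sizes; the notebook uses |V|=33186, emb=hidden=256, 3 layers
vocab_size, emb_dim, hidden, num_layers = 50, 16, 32, 2
lookup = nn.Embedding(vocab_size, emb_dim, padding_idx=2)
lstm = nn.LSTM(emb_dim, hidden, num_layers, batch_first=True)
projection = nn.Linear(hidden, vocab_size)

tokens = torch.randint(0, vocab_size, (4, 7))  # (batch, seq_len) token ids
output, (h_n, c_n) = lstm(lookup(tokens))

# `output` holds the top layer's hidden state at every position: these are the
# contextual features a downstream task can reuse
logits = projection(output)  # next-token scores at every position

print(output.shape)  # torch.Size([4, 7, 32])
print(h_n.shape)     # torch.Size([2, 4, 32])
print(logits.shape)  # torch.Size([4, 7, 50])
# for a unidirectional LSTM, the last time step of `output` equals the top layer of h_n
print(torch.allclose(output[:, -1], h_n[-1]))  # True
```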

In [0]:
class LSTMLanguageModel(nn.Module):
    """
    This model combines embedding, rnn and projection layer into a single model
    """
    def __init__(self, options):
        super().__init__()
        
        # create each LM part here 
        self.lookup = nn.Embedding(num_embeddings=options['num_embeddings'], embedding_dim=options['embedding_dim'], padding_idx=options['padding_idx'])
        self.lstm = nn.LSTM(options['input_size'], options['hidden_size'], options['num_layers'], dropout=options['lstm_dropout'], batch_first=True)
        self.projection = nn.Linear(options['hidden_size'], options['num_embeddings'])
        
    def forward(self, encoded_input_sequence):
        """
        Forward method process the input from token ids to logits
        """
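        # token ids (batch, seq_len) -> embeddings (batch, seq_len, emb_dim)
        embeddings = self.lookup(encoded_input_sequence)
        # nn.LSTM returns (output, (h_n, c_n)); output is (batch, seq_len, hidden_size),
        # h_n and c_n are each (num_layers, batch, hidden_size)
        lstm_outputs = self.lstm(embeddings)
        # project every time step's hidden state to vocabulary logits (batch, seq_len, |V|)
        logits = self.projection(lstm_outputs[0])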
        
        return logits

In [0]:
load_pretrained = True

In [0]:
num_gpus = torch.cuda.device_count()
if num_gpus > 0:
    current_device = 'cuda'
else:
    current_device = 'cpu'

if load_pretrained:
    # map_location lets GPU-saved checkpoints load on a CPU-only machine
    model_dict = torch.load("LSTM_model3.pt", map_location=current_device)
    model_weights = torch.load('LSTM_checkpoint3.pt', map_location=current_device)
    options = model_dict['options']
    model_LSTM = LSTMLanguageModel(options).to(current_device)
    model_LSTM.load_state_dict(model_weights)

else:
    embedding_size = 256
    hidden_size = 256
    num_layers = 3
    lstm_dropout = 0.3
    options = {
        'num_embeddings': len(wikitext_dict),
        'embedding_dim': embedding_size,
        'padding_idx': wikitext_dict.get_id('<pad>'),
        'input_size': embedding_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'lstm_dropout': lstm_dropout,
    }
    model_LSTM = LSTMLanguageModel(options).to(current_device)



In [61]:
!md5sum "LSTM_model3.pt"
!md5sum "LSTM_checkpoint3.pt"

66992fe0a285d584e9e0f3e15370ac2f  LSTM_model3.pt
cb722931c9bc81aba31cc4870951edf6  LSTM_checkpoint3.pt


In [0]:
def init_feature_extractor(model):
    feature_extractor = model
    
    return feature_extractor

In [37]:
feature_extractor = init_feature_extractor(model_LSTM)
feature_extractor

LSTMLanguageModel(
  (lookup): Embedding(33186, 256, padding_idx=2)
  (lstm): LSTM(256, 256, num_layers=3, batch_first=True, dropout=0.3)
  (projection): Linear(in_features=256, out_features=33186, bias=True)
)

#### Part C
Pretrain the feature extractor on the language-modeling objective: at every position, predict the next token.
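Next-token prediction works by pairing `sample[:-1]` (input) with `sample[1:]` (target), as `TensoredDataset` does, and padding each batch with `<pad>`, as `pad_collate_fn` does. The sketch below reproduces both steps on toy token ids with `torch.nn.utils.rnn.pad_sequence` (an alternative to the notebook's `pad_list_of_tensors`) and checks that `ignore_index` removes padded positions from the loss average.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

torch.manual_seed(0)
PAD = 2  # id of '<pad>' in the notebook's dictionary

# Two hypothetical encoded sentences (<bos>=0 ... <eos>=1) of different lengths
sentences = [[0, 4, 5, 6, 1], [0, 7, 1]]

# Language-modeling pairs: the target is the input shifted one step to the left
inputs = [torch.tensor(s[:-1]) for s in sentences]
targets = [torch.tensor(s[1:]) for s in sentences]

# Pad to the longest sequence in the batch -> (batch, max_len)
inp = pad_sequence(inputs, batch_first=True, padding_value=PAD)
tgt = pad_sequence(targets, batch_first=True, padding_value=PAD)
print(inp.tolist())  # [[0, 4, 5, 6], [0, 7, 2, 2]]
print(tgt.tolist())  # [[4, 5, 6, 1], [7, 1, 2, 2]]

# ignore_index drops padded target positions from the loss average
vocab_size = 8
logits = torch.randn(2, 4, vocab_size)
criterion = nn.CrossEntropyLoss(ignore_index=PAD)
loss = criterion(logits.view(-1, vocab_size), tgt.view(-1))

# Same value computed by hand over the non-pad positions only
mask = tgt.view(-1) != PAD
manual = nn.functional.cross_entropy(logits.view(-1, vocab_size)[mask], tgt.view(-1)[mask])
print(torch.allclose(loss, manual))  # True
```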

In [0]:
import torch
from torch.utils.data import Dataset, RandomSampler, SequentialSampler, DataLoader

class TensoredDataset(Dataset):
    def __init__(self, list_of_lists_of_tokens):
        self.input_tensors = []
        self.target_tensors = []
        
        for sample in list_of_lists_of_tokens:
            self.input_tensors.append(torch.tensor([sample[:-1]], dtype=torch.long))
            self.target_tensors.append(torch.tensor([sample[1:]], dtype=torch.long))
    
    def __len__(self):
        return len(self.input_tensors)
    
    def __getitem__(self, idx):
        # return a (input, target) tuple
        return (self.input_tensors[idx], self.target_tensors[idx])

In [0]:
def pad_list_of_tensors(list_of_tensors, pad_token):
    max_length = max([t.size(-1) for t in list_of_tensors])
    padded_list = []
    
    for t in list_of_tensors:
        padded_tensor = torch.cat([t, torch.tensor([[pad_token]*(max_length - t.size(-1))], dtype=torch.long)], dim = -1)
        padded_list.append(padded_tensor)
        
    padded_tensor = torch.cat(padded_list, dim=0)
    
    return padded_tensor

def pad_collate_fn(batch):
    # batch is a list of sample tuples
    input_list = [s[0] for s in batch]
    target_list = [s[1] for s in batch]
    
    pad_token = wikitext_dict.get_id('<pad>')
    
    input_tensor = pad_list_of_tensors(input_list, pad_token)
    target_tensor = pad_list_of_tensors(target_list, pad_token)
    
    return input_tensor, target_tensor

In [40]:
wikitext_tokenized_datasets = tokenize_dataset_wikitext(datasets, wikitext_dict)
wikitext_tensor_dataset = {}

for split, listoflists in wikitext_tokenized_datasets.items():
    wikitext_tensor_dataset[split] = TensoredDataset(listoflists)
    
# check the first example
wikitext_tensor_dataset['train'][0]

100%|██████████| 78274/78274 [00:00<00:00, 91285.18it/s]
100%|██████████| 8464/8464 [00:00<00:00, 118123.86it/s]
100%|██████████| 9708/9708 [00:00<00:00, 123000.79it/s]


(tensor([[ 0,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,  4, 15, 16, 17, 18, 10,
          19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]]),
 tensor([[ 4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,  4, 15, 16, 17, 18, 10, 19,
          20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,  1]]))

In [0]:
wikitext_loaders = {}
batch_size = 256 #64

for split, wikitext_dataset in wikitext_tensor_dataset.items():
    wikitext_loaders[split] = DataLoader(wikitext_dataset, batch_size=batch_size, shuffle=True, collate_fn=pad_collate_fn)

In [0]:
import numpy as np

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0,name = 'LSTM_'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
        self.delta = delta
        self.name = name

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:  # no improvement of at least `delta`
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.name+'checkpoint.pt')
        self.val_loss_min = val_loss

In [43]:
patience = 5
early_stopping = EarlyStopping(patience=patience, verbose=True,name="LSTM_")
early_stopping

<__main__.EarlyStopping at 0x7f739d676f98>

In [0]:
plot_cache = []
criterion = torch.nn.CrossEntropyLoss(ignore_index=wikitext_dict.get_id('<pad>'))
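def fit_feature_extractor(feature_extractor, wikitext_train, wikitext_val):
        # Batches come from the global `wikitext_loaders`; when the global `load_pretrained`
        # flag is set, training is skipped and only the validation perplexity is reported.
        model_parameters = [p for p in feature_extractor.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(model_parameters, lr=0.001, momentum=0.999)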

        for epoch_number in range(100):
            avg_loss=0
            if not load_pretrained:
                # do train
                feature_extractor.train()
                train_log_cache = []
                for i, (inp, target) in enumerate(wikitext_loaders["train"]):
                    optimizer.zero_grad()
                    inp = inp.to(current_device)
                    target = target.to(current_device)
                    logits = feature_extractor(inp)
                    
                    loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
                    
                    loss.backward()
                    optimizer.step()
                    
                    train_log_cache.append(loss.item())
                    
                    if i % 100 == 0:
                        avg_loss = sum(train_log_cache)/len(train_log_cache)
                        print('Step {} avg train loss = {:.{prec}f}'.format(i, avg_loss, prec=4))
                        train_log_cache = []
                
            #do valid
            valid_losses = []
            feature_extractor.eval()
            with torch.no_grad():
                for i, (inp, target) in enumerate(wikitext_loaders["valid"]):
                    inp = inp.to(current_device)
                    target = target.to(current_device)
                    logits = feature_extractor(inp)

                    loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
                    valid_losses.append(loss.item())
                avg_val_loss = sum(valid_losses) / len(valid_losses)
                print('Validation loss after {} epoch = {:.{prec}f}'.format(epoch_number, avg_val_loss, prec=4))
            
            plot_cache.append((avg_loss, avg_val_loss))
        
            if not load_pretrained:
                early_stopping(avg_val_loss,feature_extractor)
                torch.save({
                    'loss_cache': plot_cache,
                }, './lstm.pt')

                if early_stopping.early_stop:
                    print("Early stopping")
                    break
                    
            if load_pretrained:
                # perplexity = exp(mean cross-entropy in nats), same as 2**(loss / ln 2)
                print("Validation PPL:", np.exp(avg_val_loss))
                break
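`CrossEntropyLoss(ignore_index=...)` with its default `reduction='mean'` averages the loss only over targets that are not the ignored (padding) id. The NumPy sketch below reproduces that computation on flattened logits, the same shape the loop above produces with `logits.view(-1, logits.size(-1))`. The function name `masked_cross_entropy` and the toy values are illustrative assumptions.

```python
import numpy as np


def masked_cross_entropy(logits, targets, ignore_index):
    """Mean negative log-likelihood over positions whose target != ignore_index.

    logits: (N, V) unnormalized scores; targets: (N,) integer class ids.
    """
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    keep = targets != ignore_index
    # Pick log p(target) only at the kept (non-pad) positions.
    nll = -log_probs[np.arange(len(targets))[keep], targets[keep]]
    return nll.mean()


PAD = 0
logits = np.zeros((3, 4))          # uniform predictions over a 4-word vocabulary
targets = np.array([2, 3, PAD])    # the last position is padding
loss = masked_cross_entropy(logits, targets, ignore_index=PAD)
# Uniform predictions give log(4) no matter how many pad positions are present.
```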

#### Part D
Calculate the test perplexity on wikitext2. Feel free to recycle code from previous assignments from this class. 
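Perplexity is the exponential of the average per-token cross-entropy measured in nats, PPL = exp(mean NLL). The expression `2**(loss/np.log(2))` is the same quantity, because 2^(x / ln 2) = e^x. One subtlety: averaging per-batch mean losses weights every batch equally, whereas the exact corpus perplexity weights every non-pad token equally. The sketch below assumes each batch reports its summed token loss and its non-pad token count (for example via `reduction='sum'`); the function names and numbers are made up for illustration.

```python
import math


def corpus_perplexity(batch_stats):
    """Token-weighted perplexity from (summed NLL in nats, non-pad token count) pairs."""
    total_nll = sum(nll for nll, _ in batch_stats)
    total_tokens = sum(n for _, n in batch_stats)
    return math.exp(total_nll / total_tokens)


def batch_mean_perplexity(batch_stats):
    """Perplexity from the unweighted mean of per-batch average losses."""
    mean_losses = [nll / n for nll, n in batch_stats]
    return math.exp(sum(mean_losses) / len(mean_losses))


# Two toy batches: a long one with low loss and a short one with high loss.
stats = [(400.0, 100), (60.0, 10)]
exact = corpus_perplexity(stats)        # exp(460 / 110)
approx = batch_mean_perplexity(stats)   # exp((4.0 + 6.0) / 2) = exp(5.0)

# 2 ** (x / ln 2) and exp(x) agree, so either formula yields a perplexity.
x = 4.9736
same = math.isclose(2 ** (x / math.log(2)), math.exp(x))
```

When batches have similar numbers of real tokens the two estimates are close; they diverge when padding varies a lot between batches.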

In [0]:
def calculate_wiki2_test_perplexity(feature_extractor, wikitext_test):
    # Evaluates on the global wikitext_loaders['test'], built from the test split.
    test_losses = []
    feature_extractor.eval()
    with torch.no_grad():
        for inp, target in wikitext_loaders['test']:
            inp = inp.to(current_device)
            target = target.to(current_device)
            logits = feature_extractor(inp)

            # flatten logits to (num_positions, vocab_size); pad targets are ignored by criterion
            loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
            test_losses.append(loss.item())
    avg_test_loss = sum(test_losses) / len(test_losses)
    print('Test loss = {:.{prec}f}'.format(avg_test_loss, prec=4))

    # perplexity = exp(mean cross-entropy in nats), equivalent to 2**(loss / ln 2)
    test_ppl = np.exp(avg_test_loss)
    print('Test PPL:', test_ppl)
    return test_ppl

#### Let's grade your results!
(don't touch this part)

In [46]:
def grade_wikitext2():
    # load data
    wikitext_train, wikitext_val, wikitext_test = init_wikitext_dataset(datasets)

    # load feature extractor
    feature_extractor = init_feature_extractor(model_LSTM)

    # pretrain using the feature extractor
    fit_feature_extractor(feature_extractor, wikitext_train, wikitext_val)

    # check test accuracy
    test_ppl = calculate_wiki2_test_perplexity(feature_extractor, wikitext_test)

    # the real threshold will be released by Oct 11 
    assert test_ppl < 10000, 'ummm... your perplexity is too high...'
    
grade_wikitext2()

Validation loss after 0 epoch = 5.0225
Validation PPL: 151.78610281596377
Test loss = 4.9736
Test PPL: 144.55170551222892


---   
## Question 3 (fine-tune on MNLI)
In this question you will use your feature_extractor from question 2
to fine-tune on MNLI.

(From the website):
The Multi-Genre Natural Language Inference (MultiNLI) corpus is a crowd-sourced collection of 433k sentence pairs annotated with textual entailment information. The corpus is modeled on the SNLI corpus, but differs in that covers a range of genres of spoken and written text, and supports a distinctive cross-genre generalization evaluation. The corpus served as the basis for the shared task of the RepEval 2017 Workshop at EMNLP in Copenhagen.

MNLI has 3 classes: entailment, neutral, and contradiction (the genres describe where the text comes from, not the label).
The goal of this question is to maximize the test accuracy in MNLI. 

### Part A
In this section you need to generate the training, validation, and test splits. Feel free to use code from previous lectures.
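MNLI's test labels are not public, so the matched dev set is commonly used as the test split, and a validation set has to be carved out of the training examples. A minimal pure-Python sketch of a seeded split; the helper `split_indices` and the 75/25 ratio are assumptions, chosen to mirror the `random_split` call used later in this notebook.

```python
import random


def split_indices(n_examples, train_fraction=0.75, seed=0):
    """Return disjoint (train, val) index lists that together cover range(n_examples)."""
    indices = list(range(n_examples))
    random.Random(seed).shuffle(indices)  # seeded, so the split is reproducible
    n_train = int(train_fraction * n_examples)
    return indices[:n_train], indices[n_train:]


train_idx, val_idx = split_indices(1000)
# The index lists can then select examples, e.g. [examples[i] for i in train_idx].
```

Fixing the seed means every rerun (and every teammate) evaluates on the same validation examples.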

In [0]:
from torchtext.datasets import MultiNLI

def init_mnli_dataset():
    """
    Fill in the details
    """
    mnli_val = None # TODO
    mnli_train = None # TODO
    mnli_test = None # TODO
    
    return mnli_train, mnli_val, mnli_test

### Part B
Here we again design a model for finetuning. Use the output of your feature-extractor as the input to this model. This should be a powerful classifier (up to you).
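The head's design is up to you. One option for sentence pairs is to mean-pool the feature extractor's per-token outputs (ignoring padding) separately for premise and hypothesis, combine the two vectors InferSent-style as `[u, v, |u - v|, u * v]`, and feed the result to a small MLP that produces one score per MNLI class. The NumPy forward pass below only illustrates the shapes and operations; the dimensions and random weights are assumptions, and a real head would be an `nn.Module` trained together with (or on top of) the feature extractor.

```python
import numpy as np

rng = np.random.default_rng(0)


def masked_mean_pool(hidden, mask):
    """Average hidden states over non-pad positions.

    hidden: (batch, seq_len, dim); mask: (batch, seq_len) with 1 for real tokens.
    """
    mask = mask[:, :, None].astype(hidden.dtype)
    return (hidden * mask).sum(axis=1) / np.maximum(mask.sum(axis=1), 1.0)


def pair_features(u, v):
    """InferSent-style combination of premise (u) and hypothesis (v) vectors."""
    return np.concatenate([u, v, np.abs(u - v), u * v], axis=1)


def mlp_head(x, W1, b1, W2, b2):
    """Linear -> ReLU -> Linear, returning unnormalized class scores."""
    return np.maximum(x @ W1 + b1, 0.0) @ W2 + b2


batch, seq_len, dim, hidden_size, num_classes = 2, 5, 8, 16, 3
premise_h = rng.normal(size=(batch, seq_len, dim))     # stand-in for extractor outputs
hypothesis_h = rng.normal(size=(batch, seq_len, dim))
mask = np.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])    # first example has 2 pad tokens

x = pair_features(masked_mean_pool(premise_h, mask), masked_mean_pool(hypothesis_h, mask))
W1 = rng.normal(scale=0.1, size=(4 * dim, hidden_size))
b1 = np.zeros(hidden_size)
W2 = rng.normal(scale=0.1, size=(hidden_size, num_classes))
b2 = np.zeros(num_classes)
logits = mlp_head(x, W1, b1, W2, b2)                   # shape (batch, num_classes)
```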

In [0]:
def init_finetune_model():
    
    # TODO FILL IN THE DETAILS
    fine_tune_model = ...
    
    return fine_tune_model

### Part C
Use the feature_extractor and your fine_tune_model to fine_tune MNLI

In [0]:
def fine_tune_mnli(feature_extractor, fine_tune_model, mnli_train, mnli_val):
    # YOUR CODE HERE
    pass

### Part D
Evaluate the test accuracy

In [0]:
def calculate_mnli_test_accuracy(feature_extractor, fine_tune_model, mnli_test):
    
    # YOUR CODE HERE...
    test_accuracy = None  # TODO: accuracy of fine_tune_model on mnli_test
    
    return test_accuracy

### Let's grade your results

In [0]:
def grade_mnli():
    # load data
    mnli_train, mnli_val, mnli_test = init_mnli_dataset()

    # no need to load a new feature extractor: reuse the one pretrained in Question 2
    # (assumes `feature_extractor` is available at notebook scope)

    # init the fine_tune model
    fine_tune_model = init_finetune_model()
    
    # finetune
    fine_tune_mnli(feature_extractor, fine_tune_model, mnli_train, mnli_val)

    # check test accuracy
    test_accuracy = calculate_mnli_test_accuracy(feature_extractor, fine_tune_model, mnli_test)

    # the real threshold will be released by Oct 11 
    assert test_accuracy > 0.00, 'ummm... your accuracy is too low...'
    
grade_mnli()

---  
### Question 4 (BERT)

BERT, a model released in 2018, opened a major new direction in NLP research.

In this question you'll use BERT as your feature_extractor instead of the model you
designed yourself.

To get BERT, head over to Hugging Face's `transformers` library (https://github.com/huggingface/transformers) and load your BERT model here.

In [47]:
!pip install jsonlines
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import pickle
import torchvision.models as models
import os
from torchvision import transforms
from torchvision.datasets import  MNIST
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F
import json
import jsonlines
import numpy as np
from collections import defaultdict
from torch import nn
import numpy



In [48]:
!pip install transformers
from transformers.data.processors.glue import MnliProcessor
import pandas as pd
import os
import sys
import shutil
import argparse
import tempfile
import urllib.request
import zipfile
from transformers import glue_convert_examples_to_features as convert_examples_to_features
from transformers import BertTokenizer
from torch.utils.data import TensorDataset, RandomSampler, DataLoader
import torch.optim as optim
from tqdm import trange
from tqdm import tqdm_notebook as tqdm

from transformers import BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Collecting transformers
  Downloading https://files.pythonhosted.org/packages/fd/f9/51824e40f0a23a49eab4fcaa45c1c797cbf9761adedd0b558dab7c958b34/transformers-2.1.1-py3-none-any.whl (311kB)

100%|██████████| 231508/231508 [00:00<00:00, 2632875.27B/s]


In [49]:
TASK2PATH = {
    "MNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce"
}
def download_and_extract(task, data_dir):
    print("Downloading and extracting %s..." % task)
    data_file = "%s.zip" % task
    urllib.request.urlretrieve(TASK2PATH[task], data_file)
    with zipfile.ZipFile(data_file) as zip_ref:
        zip_ref.extractall(data_dir)
    os.remove(data_file)
    print("\tCompleted!")
download_and_extract('MNLI', '.')

Downloading and extracting MNLI...
	Completed!


### Part A (init BERT)
In this section you need to create an instance of BERT and return it from the function.

In [0]:
def init_mnli_dataset():
  # ----------------------
  # TRAIN/VAL DATALOADERS
  # ----------------------
  train = processor.get_train_examples('MNLI')
  features = convert_examples_to_features(train,
                                          tokenizer,
                                          label_list=['contradiction','neutral','entailment'],
                                          max_length=64,
                                          output_mode='classification',
                                          pad_on_left=False,
                                          pad_token=tokenizer.pad_token_id,
                                          pad_token_segment_id=0)
  train_dataset = TensorDataset(torch.tensor([f.input_ids for f in features], dtype=torch.long), 
                                torch.tensor([f.attention_mask for f in features], dtype=torch.long), 
                                torch.tensor([f.token_type_ids for f in features], dtype=torch.long), 
                                torch.tensor([f.label for f in features], dtype=torch.long))

  nb_train_samples = int(0.75 * len(train_dataset))
  nb_val_samples = len(train_dataset) - nb_train_samples

  bert_mnli_train_dataset, bert_mnli_val_dataset = random_split(train_dataset, [nb_train_samples, nb_val_samples])

  # train loader
  train_sampler = RandomSampler(bert_mnli_train_dataset)
  # wrap the train subset itself, so sampled positions map only to training examples
  bert_mnli_train_dataloader = DataLoader(bert_mnli_train_dataset, sampler=train_sampler, batch_size=32)

  # val loader
  val_sampler = RandomSampler(bert_mnli_val_dataset)
  bert_mnli_val_dataloader = DataLoader(bert_mnli_val_dataset, sampler=val_sampler, batch_size=32)


  # ----------------------
  # TEST DATALOADERS
  # ----------------------
  # MNLI test labels are not public, so the matched dev set serves as the test set here.
  dev = processor.get_dev_examples('MNLI')
  features = convert_examples_to_features(dev,
                                          tokenizer,
                                          label_list=['contradiction','neutral','entailment'],
                                          max_length=64,
                                          output_mode='classification',
                                          pad_on_left=False,
                                          pad_token=tokenizer.pad_token_id,
                                          pad_token_segment_id=0)

  bert_mnli_test_dataset = TensorDataset(torch.tensor([f.input_ids for f in features], dtype=torch.long), 
                                torch.tensor([f.attention_mask for f in features], dtype=torch.long), 
                                torch.tensor([f.token_type_ids for f in features], dtype=torch.long), 
                                torch.tensor([f.label for f in features], dtype=torch.long))

  # test dataset
  test_sampler = RandomSampler(bert_mnli_test_dataset)
  bert_mnli_test_dataloader = DataLoader(bert_mnli_test_dataset, sampler=test_sampler, batch_size=32)
  
  return bert_mnli_train_dataloader, bert_mnli_val_dataloader, bert_mnli_test_dataloader
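Why a loader must wrap the `random_split` subset rather than the full dataset: a sampler built on the subset yields positions `0 .. len(subset) - 1`, and the subset maps position `i` to its own underlying index. Handing those positions to the full dataset instead reads the first 75% of the unshuffled data, which overlaps the validation subset. The pure-Python sketch below imitates that index bookkeeping, with a seeded permutation standing in for `random_split` (an assumption; PyTorch draws its own permutation).

```python
import random

n = 1000
perm = list(range(n))
random.Random(0).shuffle(perm)         # stand-in for random_split's permutation
n_train = int(0.75 * n)
train_subset_indices = perm[:n_train]  # what Subset position i actually reads
val_subset_indices = set(perm[n_train:])

sampler_positions = range(n_train)     # what a sampler over the train subset yields

# Correct: positions go through the subset's index mapping.
correct = {train_subset_indices[i] for i in sampler_positions}
# Buggy: positions are used directly as indices into the full dataset.
buggy = set(sampler_positions)

leak_correct = len(correct & val_subset_indices)
leak_buggy = len(buggy & val_subset_indices)
```

A leak like this inflates validation accuracy, because some "validation" examples were also trained on.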

In [0]:
processor = MnliProcessor()

In [0]:
from transformers import BertTokenizer, BertModel, BertForMaskedLM

def init_bert():
    pretrained_weights = "bert-base-uncased"
    bert = BertModel.from_pretrained(pretrained_weights, output_attentions=True)
    return bert

### Part B (fine-tune with BERT)

Use BERT as your feature extractor to fine-tune on MNLI. Use a new fine-tuning model (with freshly initialized weights).

In [0]:
class BERTSequenceClassifier(nn.Module):
    def __init__(self, bert, num_classes,hidden_size):
        super().__init__()
        self.bert = bert
        self.W = nn.Linear(bert.config.hidden_size, hidden_size)
        self.num_classes = num_classes
        self.relu = nn.ReLU()
        self.linear = nn.Linear(hidden_size, num_classes)
        
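    def forward(self, input_ids, attention_mask, token_type_ids):
        # BertModel built with output_attentions=True returns
        # (sequence_output, pooled_output, attentions)
        h, _, attn = self.bert(input_ids=input_ids, 
                               attention_mask=attention_mask, 
                               token_type_ids=token_type_ids)
        h_cls = h[:, 0]                    # hidden state of the [CLS] token
        hidden = self.relu(self.W(h_cls))  # project to hidden_size, then ReLU
        output = self.linear(hidden)       # one logit per class
        return output, attn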
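The classifier reads the hidden state at position 0 (the `[CLS]` token) and passes `token_type_ids` so BERT can tell the premise from the hypothesis. For reference, here is a pure-Python sketch of the pair layout those inputs follow: `[CLS] premise [SEP] hypothesis [SEP]`, right-padded to a fixed length. `encode_pair` is hypothetical; it works on already-tokenized toy ids and skips the WordPiece tokenization and truncation that `glue_convert_examples_to_features` performs.

```python
def encode_pair(premise_ids, hypothesis_ids, cls_id, sep_id, pad_id, max_length):
    """Build BERT-style pair inputs: [CLS] premise [SEP] hypothesis [SEP] + padding.

    Returns (input_ids, attention_mask, token_type_ids), each of length max_length.
    Assumes the pair already fits; real tokenizers truncate long pairs instead.
    """
    input_ids = [cls_id] + premise_ids + [sep_id] + hypothesis_ids + [sep_id]
    # Segment 0: [CLS], premise, first [SEP]. Segment 1: hypothesis, final [SEP].
    token_type_ids = [0] * (len(premise_ids) + 2) + [1] * (len(hypothesis_ids) + 1)
    attention_mask = [1] * len(input_ids)
    if len(input_ids) > max_length:
        raise ValueError("pair too long for max_length in this sketch")
    n_pad = max_length - len(input_ids)
    # Right-padding, matching pad_on_left=False and pad_token_segment_id=0.
    input_ids += [pad_id] * n_pad
    attention_mask += [0] * n_pad
    token_type_ids += [0] * n_pad
    return input_ids, attention_mask, token_type_ids


# Special-token ids follow bert-base-uncased ([CLS]=101, [SEP]=102, [PAD]=0);
# the word ids 7, 8, 9 are placeholders.
ids, mask, segs = encode_pair([7, 8], [9], cls_id=101, sep_id=102, pad_id=0, max_length=8)
```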

In [0]:
def init_finetune_model(bert, num, hidden):
  model = BERTSequenceClassifier(bert, num, hidden)
  return model

In [0]:
plot_cache = []
num_gpus = torch.cuda.device_count()
if num_gpus > 0:
    current_device = 'cuda'
else:
    current_device = 'cpu'

In [0]:
def fine_tune_mnli(model, train_loader, val_loader):
    criterion = nn.CrossEntropyLoss(ignore_index=-1).to(current_device)
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=2e-5, eps=1e-08)
    plot_cache = []
    for epoch_number in range(1):
        avg_loss=0
        model.train()
        train_log_cache = []
        epoch_iter = tqdm(train_loader, desc='training')
        for i, (inp, attention_masks , token_type_ids ,target) in enumerate(epoch_iter):
            optimizer.zero_grad()
            inp = inp.to(current_device)
            target = target.to(current_device)
            token_type_ids = token_type_ids.to(current_device)
            attention_masks = attention_masks.to(current_device)
            logits, _ = model(inp,attention_masks,token_type_ids)

            loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))

            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_log_cache.append(loss.item())

            if i % 1000 == 0:
                avg_loss = sum(train_log_cache)/len(train_log_cache)
                print('Step {} avg train loss = {:.{prec}f}'.format(i, avg_loss, prec=4))
                train_log_cache = []

        # do valid
        model.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            val_epoch_iter = tqdm(val_loader, desc='validating')
            for inp, attention_masks, token_type_ids, target in val_epoch_iter:
                inp = inp.to(current_device)
                target = target.to(current_device)
                token_type_ids = token_type_ids.to(current_device)
                attention_masks = attention_masks.to(current_device)
                logits, _ = model(inp, attention_masks, token_type_ids)

                # argmax of the logits equals argmax of the softmax probabilities
                predicted = logits.argmax(dim=1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()
        print("Validation acc: ", 100 * correct / total)
    print('saving') 
    torch.save(model,"Bert_model.pt")
    torch.save(model.state_dict(),"Bert_model_state_dict.pt")

### Part C
Evaluate how well we did
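Test accuracy is the fraction of examples whose highest-scoring class matches the label. Because softmax is monotonic, the argmax of the raw logits is the same as the argmax after a softmax, so the softmax can be skipped when only predictions are needed. A NumPy sketch with made-up logits and labels:

```python
import numpy as np


def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def accuracy_percent(logits, labels):
    """Percentage of rows whose argmax over classes equals the label."""
    predicted = logits.argmax(axis=1)
    return 100.0 * (predicted == labels).mean()


logits = np.array([[2.0, 0.1, -1.0],   # predicts class 0
                   [0.2, 0.3, 1.5],    # predicts class 2
                   [1.0, 3.0, 0.0],    # predicts class 1
                   [0.5, 0.4, 0.6]])   # predicts class 2
labels = np.array([0, 2, 0, 2])
acc = accuracy_percent(logits, labels)  # 3 of 4 correct
same_predictions = np.array_equal(softmax(logits).argmax(axis=1), logits.argmax(axis=1))
```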

In [0]:
def calculate_mnli_test_accuracy(model, test_loader):
    correct = 0
    total = 0
    model.eval()

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            batch = tuple(t.to(current_device) for t in batch)
            labels = batch[-1]
            # inputs are (input_ids, attention_mask, token_type_ids); the model returns (logits, attentions)
            logits, _ = model(*batch[:-1])
            predicted = logits.argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    accuracy = 100 * correct / total
    print()
    print("Test acc: " + str(accuracy))
    return accuracy

### Let's grade your BERT results!

In [0]:
LOAD_PRETRAINED = True

In [59]:
def grade_mnli_BERT():
    BERT_feature_extractor = init_bert()
    # load data
    mnli_train, mnli_val, mnli_test = init_mnli_dataset()

    if not LOAD_PRETRAINED:
        # init the fine_tune model
        fine_tune_model = init_finetune_model(BERT_feature_extractor, 3, 256).to(current_device)
        
        # finetune
        fine_tune_mnli(fine_tune_model, mnli_train, mnli_val)
    else:
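        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
        fine_tune_model = torch.load("Bert_model.pt", map_location=current_device)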

    # check test accuracy
    test_accuracy = calculate_mnli_test_accuracy(fine_tune_model, mnli_test)
    
    # the real threshold will be released by Oct 11 
    assert test_accuracy > 0.0, 'ummm... your accuracy is too low...'
    return test_accuracy
    
acc = grade_mnli_BERT()

100%|██████████| 313/313 [00:00<00:00, 99463.38B/s]
100%|██████████| 440473133/440473133 [00:09<00:00, 47644889.20B/s]





Test acc: 82.31278655119715


In [60]:
!md5sum "Bert_model.pt"

1ac88f0b1cf52a658d4f1997375b5972  Bert_model.pt
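The same checksum can be computed from Python with the standard-library `hashlib`, reading the file in chunks so a checkpoint of several hundred megabytes never has to sit in memory at once. A sketch; the helper name `file_md5` is hypothetical.

```python
import hashlib


def file_md5(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of a file, reading it 1 MiB at a time."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Usage, with the path and digest from the cell above:
# assert file_md5("Bert_model.pt") == "1ac88f0b1cf52a658d4f1997375b5972"
```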
