---   
# HW3 - Transfer learning

#### Due October 30, 2019

In this assignment you will learn about transfer learning. This technique is perhaps one of the most important techniques for industry. When a problem you want to solve does not have enough data, we use a different (larger) dataset to learn representations which can help us solve our task using the smaller task.

The general steps to transfer learning are as follows:

1. Find a huge dataset with similar characteristics to the problem you are interested in.
2. Choose a model powerful enough to extract meaningful representations from the huge dataset.
3. Train this model on the huge dataset.
4. Use this model to train on the smaller dataset.


### This homework has the following sections:
1. Question 1: MNIST fine-tuning (Parts A, B, C, D).
2. Question 2: Pretrain on Wikitext2 (Part A, B, C, D)
3. Question 3: Finetune on MNLI (Part A, B, C, D)
4. Question 4: Finetune using pretrained BERT (Part A, B, C)

--- 
# Question 2 (train a model on Wikitext-2)

Here we'll apply what we just learned to NLP. In this section we'll make our own feature extractor and pretrain it on Wikitext-2.

The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. The dataset is available under the Creative Commons Attribution-ShareAlike License.

#### Part A
In this section you need to generate the training, validation and test split. Feel free to use code from your previous lectures.

In [0]:
!pip install jsonlines
import torchtext
from torchtext import data
import spacy
import os
import json
import jsonlines
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import torch
import torch.nn as nn



In [0]:
class Dictionary(object):
    def __init__(self, datasets, include_valid=False):
        self.tokens = []
        self.ids = {}
        self.counts = {}
        
        # add special tokens
        self.add_token('<bos>')
        self.add_token('<eos>')
        self.add_token('<pad>')
        self.add_token('<unk>')
        
        for line in tqdm(datasets['train']):
            for w in line:
                self.add_token(w)
                    
        if include_valid is True:
            for line in tqdm(datasets['valid']):
                for w in line:
                    self.add_token(w)
        # include test
        for line in tqdm(datasets['test']):
            for w in line:
                self.add_token(w)
        
        
    def add_token(self, w):
        if w not in self.tokens:
            self.tokens.append(w)
            _w_id = len(self.tokens) - 1
            self.ids[w] = _w_id
            self.counts[w] = 1
        else:
            self.counts[w] += 1

    def get_id(self, w):
        return self.ids[w]
    
    def get_token(self, idx):
        return self.tokens[idx]
    
    def decode_idx_seq(self, l):
        return [self.tokens[i] for i in l]
    
    def encode_token_seq(self, l):
        return [self.ids[i] if i in self.ids else self.ids['<unk>'] for i in l]
    
    def __len__(self):
        return len(self.tokens)

In [0]:
def tokenize_dataset_wikitext(datasets, dictionary, ngram_order=2):
    tokenized_datasets = {}
    for split, dataset in datasets.items():
        _current_dictified = []
        for l in tqdm(dataset):
            l = ['<bos>']*(ngram_order-1) + l + ['<eos>']
            encoded_l = dictionary.encode_token_seq(l)
            _current_dictified.append(encoded_l)
        tokenized_datasets[split] = _current_dictified
        
    return tokenized_datasets

In [0]:
from torchtext.datasets import WikiText2
def load_wikitext(filename='wikitext2-sentencized.json'):
      if not os.path.exists(filename):
        !wget "https://nyu.box.com/shared/static/9kb7l7ci30hb6uahhbssjlq0kctr5ii4.json" -O $filename
    
      datasets = json.load(open(filename, 'r'))
      for name in datasets:
          datasets[name] = [x.split() for x in datasets[name]]
      vocab = list(set([t for ts in datasets['train'] for t in ts]))      
      print("Vocab size: %d" % (len(vocab)))
      return datasets, vocab
    
datasets,vocab = load_wikitext()
wikitext_dict = Dictionary(datasets, include_valid=True)


--2019-10-15 19:37:04--  https://nyu.box.com/shared/static/9kb7l7ci30hb6uahhbssjlq0kctr5ii4.json
Resolving nyu.box.com (nyu.box.com)... 107.152.26.197
Connecting to nyu.box.com (nyu.box.com)|107.152.26.197|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /public/static/9kb7l7ci30hb6uahhbssjlq0kctr5ii4.json [following]
--2019-10-15 19:37:04--  https://nyu.box.com/public/static/9kb7l7ci30hb6uahhbssjlq0kctr5ii4.json
Reusing existing connection to nyu.box.com:443.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://nyu.app.box.com/public/static/9kb7l7ci30hb6uahhbssjlq0kctr5ii4.json [following]
--2019-10-15 19:37:04--  https://nyu.app.box.com/public/static/9kb7l7ci30hb6uahhbssjlq0kctr5ii4.json
Resolving nyu.app.box.com (nyu.app.box.com)... 107.152.26.199
Connecting to nyu.app.box.com (nyu.app.box.com)|107.152.26.199|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://public.boxcloud.com/

  1%|          | 457/78274 [00:00<00:17, 4568.70it/s]

Vocab size: 33175


100%|██████████| 78274/78274 [02:25<00:00, 538.47it/s]
100%|██████████| 8464/8464 [00:12<00:00, 695.00it/s]
100%|██████████| 9708/9708 [00:12<00:00, 794.32it/s]


In [0]:
def init_wikitext_dataset(datasets):
    """
    Fill in the details
    """
    raw_train = datasets["train"]
    raw_val = datasets["valid"]
    raw_test = datasets["test"]
    
    return raw_train,raw_val,raw_test


In [0]:
wikitext_train,wikitext_val,wikitext_test = init_wikitext_dataset(datasets)

In [0]:
# checking some example
print(' '.join(wikitext_train[3010]))

encoded = wikitext_dict.encode_token_seq(wikitext_train[3010])
print(f'\n encoded - {encoded}')
decoded = wikitext_dict.decode_idx_seq(encoded)
print(f'\n decoded - {decoded}')

The Nataraja and Ardhanarishvara sculptures are also attributed to the Rashtrakutas .

 encoded - [75, 8816, 30, 8817, 8732, 70, 91, 2960, 13, 6, 8806, 39]

 decoded - ['The', 'Nataraja', 'and', 'Ardhanarishvara', 'sculptures', 'are', 'also', 'attributed', 'to', 'the', 'Rashtrakutas', '.']


In [0]:
len(vocab)

33175

#### Part B   
Here we design our own feature extractor. In MNIST that was a resnet because we were dealing with images. Now we need to pick a model that can model sequences better. Design an RNN-based model here.

In [0]:
LOAD_PRETRAINED = True
IGNORE_PROJECTION = False

In [0]:
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x

In [0]:
class LSTMLanguageModel(nn.Module):
    """
    This model combines embedding, rnn and projection layer into a single model
    """
    def __init__(self, options):
        super().__init__()
        
        # create each LM part here 
        self.lookup = nn.Embedding(num_embeddings=options['num_embeddings'], embedding_dim=options['embedding_dim'], padding_idx=options['padding_idx'])
        self.lstm = nn.LSTM(options['input_size'], options['hidden_size'], options['num_layers'], dropout=options['lstm_dropout'], batch_first=True)
        self.projection = nn.Linear(options['hidden_size'], options['num_embeddings'])
        
    def forward(self, encoded_input_sequence):
        """
        Forward method process the input from token ids to logits
        """
        # |V| -> emb_dim
        embeddings = self.lookup(encoded_input_sequence)
        # emb_dim, hidden -> output, (h_n,c_n)     (hidden,hidden*num_layers)
        lstm_outputs = self.lstm(embeddings)
        logits = self.projection(lstm_outputs[0])
        
        return logits

In [0]:
num_gpus = torch.cuda.device_count()
if num_gpus > 0:
    current_device = 'cuda'
else:
    current_device = 'cpu'

if LOAD_PRETRAINED:
  model_dict = torch.load("LSTM_model3.pt")
  model_weights = torch.load('LSTM_checkpoint3.pt')
  options = model_dict['options']
  model_LSTM = LSTMLanguageModel(options).to(current_device)
  model_LSTM.load_state_dict(model_weights)

else:
  embedding_size = 256
  hidden_size = 256
  num_layers = 3
  lstm_dropout = 0.3
  options = {
          'num_embeddings': len(wikitext_dict),
          'embedding_dim': embedding_size,
          'padding_idx': wikitext_dict.get_id('<pad>'),
          'input_size': embedding_size,
          'hidden_size': hidden_size,
          'num_layers': num_layers,
          'lstm_dropout': lstm_dropout,
      }
  model_LSTM = LSTMLanguageModel(options).to(current_device)



In [0]:
def init_feature_extractor(model):
    feature_extractor = model
    
    return feature_extractor

In [0]:
feature_extractor = init_feature_extractor(model_LSTM)
feature_extractor.named_children

<bound method Module.named_children of LSTMLanguageModel(
  (lookup): Embedding(33186, 256, padding_idx=2)
  (lstm): LSTM(256, 256, num_layers=3, batch_first=True, dropout=0.3)
  (projection): Linear(in_features=256, out_features=33186, bias=True)
)>

#### Part C
Pretrain the feature extractor

In [0]:
import torch
from torch.utils.data import Dataset, RandomSampler, SequentialSampler, DataLoader

class TensoredDataset(Dataset):
    def __init__(self, list_of_lists_of_tokens):
        self.input_tensors = []
        self.target_tensors = []
        
        for sample in list_of_lists_of_tokens:
            self.input_tensors.append(torch.tensor([sample[:-1]], dtype=torch.long))
            self.target_tensors.append(torch.tensor([sample[1:]], dtype=torch.long))
    
    def __len__(self):
        return len(self.input_tensors)
    
    def __getitem__(self, idx):
        # return a (input, target) tuple
        return (self.input_tensors[idx], self.target_tensors[idx])

In [0]:
def pad_list_of_tensors(list_of_tensors, pad_token):
    max_length = max([t.size(-1) for t in list_of_tensors])
    padded_list = []
    
    for t in list_of_tensors:
        padded_tensor = torch.cat([t, torch.tensor([[pad_token]*(max_length - t.size(-1))], dtype=torch.long)], dim = -1)
        padded_list.append(padded_tensor)
        
    padded_tensor = torch.cat(padded_list, dim=0)
    
    return padded_tensor

def pad_collate_fn(batch):
    # batch is a list of sample tuples
    input_list = [s[0] for s in batch]
    target_list = [s[1] for s in batch]
    
    pad_token = wikitext_dict.get_id('<pad>')
    #pad_token = 2
    
    input_tensor = pad_list_of_tensors(input_list, pad_token)
    target_tensor = pad_list_of_tensors(target_list, pad_token)
    
    return input_tensor, target_tensor

In [0]:
wikitext_tokenized_datasets = tokenize_dataset_wikitext(datasets, wikitext_dict)
wikitext_tensor_dataset = {}

for split, listoflists in wikitext_tokenized_datasets.items():
    wikitext_tensor_dataset[split] = TensoredDataset(listoflists)
    
# check the first example
wikitext_tensor_dataset['train'][0]

100%|██████████| 78274/78274 [00:00<00:00, 94013.99it/s]
100%|██████████| 8464/8464 [00:00<00:00, 120218.32it/s]
100%|██████████| 9708/9708 [00:00<00:00, 36790.38it/s]


(tensor([[ 0,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,  4, 15, 16, 17, 18, 10,
          19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]]),
 tensor([[ 4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,  4, 15, 16, 17, 18, 10, 19,
          20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,  1]]))

In [0]:
wikitext_loaders = {}
batch_size = 256 #64

for split, wikitext_dataset in wikitext_tensor_dataset.items():
    wikitext_loaders[split] = DataLoader(wikitext_dataset, batch_size=batch_size, shuffle=True, collate_fn=pad_collate_fn)

In [0]:
import numpy as np

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0,name = 'LSTM_'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.name = name

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score - self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.name+'checkpoint.pt')
        self.val_loss_min = val_loss

In [0]:
patience = 5
early_stopping = EarlyStopping(patience=patience, verbose=True,name="LSTM_")
early_stopping

<__main__.EarlyStopping at 0x7f733c9a4128>

In [0]:
import torch.optim as optim
criterion = nn.CrossEntropyLoss(ignore_index=wikitext_dict.get_id('<pad>'))

model_parameters = [p for p in feature_extractor.parameters() if p.requires_grad]
optimizer = optim.SGD(model_parameters, lr=0.001, momentum=0.999)

In [0]:
def fit_feature_extractor(feature_extractor, wikitext_train, wikitext_val):
    # FILL IN THE DETAILS
    plot_cache = []

    for epoch_number in range(100):
        avg_loss=0
        if not LOAD_PRETRAINED:
            # do train
            feature_extractor.train()
            train_log_cache = []
            for i, (inp, target) in enumerate(wikitext_loaders['train']):
                optimizer.zero_grad()
                inp = inp.to(current_device)
                target = target.to(current_device)
                logits = feature_extractor(inp)

                loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))

                loss.backward()
                optimizer.step()

                train_log_cache.append(loss.item())

                if i % 100 == 0:
                    avg_loss = sum(train_log_cache)/len(train_log_cache)
                    print('Step {} avg train loss = {:.{prec}f}'.format(i, avg_loss, prec=4))
                    train_log_cache = []

        #do valid
        valid_losses = []
        feature_extractor.eval()
        with torch.no_grad():
            for i, (inp, target) in enumerate(wikitext_loaders['valid']):
                inp = inp.to(current_device)
                target = target.to(current_device)
                logits = feature_extractor(inp)

                loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
                valid_losses.append(loss.item())
            avg_val_loss = sum(valid_losses) / len(valid_losses)
            print('Validation loss after {} epoch = {:.{prec}f}'.format(epoch_number, avg_val_loss, prec=4))

        plot_cache.append((avg_loss, avg_val_loss))
        
        if LOAD_PRETRAINED:
            print("Validation PPL:",2**(avg_val_loss/np.log(2)))
            break

      

In [0]:
fit_feature_extractor(feature_extractor, wikitext_train, wikitext_val)

Validation loss after 0 epoch = 5.0082
Validation PPL: 149.63181996722716


#### Part D
Calculate the test perplexity on wikitext2. Feel free to recycle code from previous assignments from this class. 

In [0]:
def calculate_wiki2_test_perplexity(feature_extractor, wikitext_test):
    
    # FILL IN DETAILS
    plot_cache = []
    
    #do valid
    test_losses = []
    feature_extractor.eval()
    with torch.no_grad():
        for i, (inp, target) in enumerate(wikitext_loaders['test']):
            inp = inp.to(current_device)
            target = target.to(current_device)
            logits = feature_extractor(inp)

            loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
            test_losses.append(loss.item())
        avg_test_loss = sum(test_losses) / len(test_losses)
        print('Test loss = {:.{prec}f}'.format(avg_test_loss, prec=4))

    plot_cache.append(avg_test_loss)
    test_ppl = 2**(avg_test_loss/np.log(2))   
    print('Test PPL:', test_ppl)
    return test_ppl

#### Let's grade your results!
(don't touch this part)

In [0]:
def grade_wikitext2():
    # load data
    wikitext_train, wikitext_val, wikitext_test = init_wikitext_dataset(datasets)

    # load feature extractor
    feature_extractor = init_feature_extractor(model_LSTM)

    # pretrain using the feature extractor
    fit_feature_extractor(feature_extractor, wikitext_train, wikitext_val)

    # check test accuracy
    test_ppl = calculate_wiki2_test_perplexity(feature_extractor, wikitext_test)

    # the real threshold will be released by Oct 11 
    assert test_ppl < 10000, 'ummm... your perplexity is too high...'
    
grade_wikitext2()

Validation loss after 0 epoch = 5.0188
Validation PPL: 151.2367787907184
Test loss = 4.9733
Test PPL: 144.50948081333289
