---   
# HW3 - Transfer learning

#### Due October 30, 2019

In this assignment you will learn about transfer learning. This technique is perhaps one of the most important techniques for industry. When a problem you want to solve does not have enough data, we use a different (larger) dataset to learn representations which can help us solve our task using the smaller task.

The general steps to transfer learning are as follows:

1. Find a huge dataset with similar characteristics to the problem you are interested in.
2. Choose a model powerful enough to extract meaningful representations from the huge dataset.
3. Train this model on the huge dataset.
4. Use this model to train on the smaller dataset.


### This homework has the following sections:
1. Question 1: MNIST fine-tuning (Parts A, B, C, D).
2. Question 2: Pretrain on Wikitext2 (Part A, B, C, D)
3. Question 3: Finetune on MNLI (Part A, B, C, D)
4. Question 4: Finetune using pretrained BERT (Part A, B, C)

---   
## Question 1 (MNIST transfer learning)
To grasp the high-level approach to transfer learning, let's first do a simple example using computer vision. 

The torchvision library has pretrained models (resnets, vggnets, etc) on the Imagenet dataset. Imagenet is a dataset
with 1.3 million images covering over 1000 classes of objects. When you use one of these models, the weights of the model initialize
with the weights saved from training on imagenet.

In this task we will:
1. Choose a pretrained model.
2. Freeze the model so that the weights don't change.
3. Fine-tune on a few labels of MNIST.   

#### Choose a model
Here we pick any of the models from torchvision

In [1]:
import torch
import torchvision.models as models

class Identity(torch.nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x

# init the pretrained feature extractor
pretrained_resnet18 = models.resnet18(pretrained=True)

# we don't want the built in last layer, we're going to modify it ourselves
pretrained_resnet18.fc = Identity()

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/checkpoints/resnet18-5c106cde.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 107MB/s]


#### Freeze the model
Here we freeze the weights of the model. Freezing means the gradients will not backpropagate
into these weights.

By doing this you can think about the model as a feature extractor. This feature extractor outputs
a **representation** of an input. This representation is a matrix that encodes information about the input.

In [0]:
def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False
        
def unfreeze_model(model):
    for param in model.parameters():
        param.requires_grad = True
        
freeze_model(pretrained_resnet18)

#### Init target dataset
Here we define the dataset we are actually interested in.

In [0]:
import os
from torchvision import transforms
from torchvision.datasets import  MNIST
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

#  train/val  split
mnist_dataset = MNIST(os.getcwd(), train=True, download=True, transform = transforms.Compose([transforms.Grayscale(3),transforms.ToTensor()]))
mnist_train, mnist_val = random_split(mnist_dataset, [55000, 5000])

mnist_train = DataLoader(mnist_train, batch_size=32)
mnist_val = DataLoader(mnist_val, batch_size=32)

# test split
mnist_test = DataLoader(MNIST(os.getcwd(), train=False, download=True, transform = transforms.Compose([transforms.Grayscale(3),transforms.ToTensor()])), batch_size=32)

### Part A (init fine-tune model)
decide what model to use for fine-tuning

In [0]:
def init_fine_tune_model():
  # we will use the structure from resnet18 as our fine_tune_model
  feature_extractor = models.resnet18(pretrained=True)
  num_ftrs = feature_extractor.fc.in_features
  # freeze the model
  freeze_model(feature_extractor)
  fine_tune_model = feature_extractor 
  # add the last layer 
  fine_tune_model.fc = nn.Linear(num_ftrs,10)
  fine_tune_model= fine_tune_model.to(current_device)
  
  return fine_tune_model

### Part B (Fine-tune (Frozen))

The actual problem we care about solving likely has a different number of classes or is a different task altogether. Fine-tuning is the process of using the extracted representations (features) to solve this downstream task  (the task you're interested in).

To illustrate this, we'll use our pretrained model (on Imagenet), to solve the MNIST classification task.

There are two types of finetuning. 

#### 1. Frozen feature_extractor
In the first type we pretrain with the FROZEN feature_extractor and NEVER unfreeze it during finetuning.


#### 2. Unfrozen feature_extractor
In the second, we finetune with a FROZEN feature_extractor for a few epochs, then unfreeze the feature extractor and finish training.


In this part we will use the first version

In [0]:
# defining what device to use
num_gpus = torch.cuda.device_count()
if num_gpus > 0:
    current_device = 'cuda'
else:
    current_device = 'cpu'

In [0]:
import torch.optim as optim
import torch.nn as nn

def FROZEN_fine_tune_mnist(feature_extractor, fine_tune_model, mnist_train, mnist_val,model_name,hyperparameters,save):

  model =  fine_tune_model 

  # criterion
  criterion = nn.CrossEntropyLoss()
  # only update the parameters that are not frozen 
  model_params = [p for p in model.parameters() if p.requires_grad]
  optimizer = optim.Adam(model_params, lr=hyperparameters['lr'], weight_decay=hyperparameters['weight_decay'])
  num_epochs = hyperparameters['num_epochs']
  PATH = model_name + '.pth'

  for epoch in range(num_epochs):
    avg_loss = 0
    train_log_cache = []

    # training phase
    model.train()
    for i,(inp,target) in enumerate(mnist_train):
      optimizer.zero_grad()
      #inp = inp.repeat(1,3,1,1)
      inp = inp.to(current_device)
      target = target.to(current_device)
      logits = model(inp)
      # compute loss 
      loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
      # back-propogation 
      loss.backward()
      # gradient clipping 
      nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
      optimizer.step()
      train_log_cache.append(loss.item())

      if i % 1000 == 0:
        avg_train_loss = sum(train_log_cache)/len(train_log_cache)
        print('Step {} avg train loss = {:.{prec}f}'.format(i, avg_train_loss, prec=4))
        train_log_cache = []

    # validation phase 
    valid_losses = []
    model.eval()
    with torch.no_grad():
      for i,(inp,target) in enumerate(mnist_val):
        #inp = inp.repeat(1,3,1,1)
        inp = inp.to(current_device)
        target = target.to(current_device)
        logits = model(inp)
        # compute loss 
        loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
        valid_losses.append(loss.item())
    avg_val_loss = sum(valid_losses) / len(valid_losses)
    print('Validation loss after {} epoch = {:.{prec}f}'.format(epoch+1, avg_val_loss, prec=4))

  print('Finish training!')

  if save:
    torch.save(model,PATH)


### Part C (compute test accuracy)
Compute the test accuracy of fine-tuned model on MNIST

In [0]:
  def calculate_mnist_test_accuracy(feature_extractor, fine_tune_model, mnist_test):
    model = fine_tune_model
    model.eval()
    test_accuracy = []
    batch_num = 0
    with torch.no_grad():
        for i, (inp, target) in enumerate(mnist_test):
          batch_num += 1
          #inp = inp.repeat(1,3,1,1)
          inp = inp.to(current_device)
          target = target.to(current_device)
          logits = model(inp)
          softmax_output = F.softmax(logits,dim = 1)
          predictions = torch.argmax(softmax_output, dim=1)
          for i in range(len(predictions)):
            if predictions[i] == target[i]:
              test_accuracy.append(1)
            else:
              test_accuracy.append(0)
    
    return(sum(test_accuracy) / len(test_accuracy))

### Grade!
Let's see how you did

In [18]:
def grade_mnist_frozen():

    # init a ft model
    fine_tune_model = init_fine_tune_model()
    
    # run the transfer learning routine
    FROZEN_fine_tune_mnist(feature_extractor = pretrained_resnet18,
                          fine_tune_model = fine_tune_model,
                          mnist_train =  mnist_train,
                          mnist_val = mnist_val,
                          model_name = 'frozen_resnet18',
                          hyperparameters = hyperparameters,
                          save = True)
    
    # calculate test accuracy
    PATH = 'frozen_resnet18.pth'
    fine_tune_model = torch.load(PATH, map_location=current_device)
    test_accuracy = calculate_mnist_test_accuracy(pretrained_resnet18, fine_tune_model, mnist_test)
    
    # the real threshold will be released by Oct 11 
    assert test_accuracy > 0.0, 'your accuracy is too low...'
    
    return test_accuracy
    
hyperparameters= {
    'lr': 0.001,
    'num_epochs': 15,
    'weight_decay': 0         
 }    
frozen_test_accuracy = grade_mnist_frozen()

Step 0 avg train loss = 2.7019
Step 1000 avg train loss = 1.0716
Validation loss after 1 epoch = 0.7967
Step 0 avg train loss = 1.1307
Step 1000 avg train loss = 0.8440
Validation loss after 2 epoch = 0.7699
Step 0 avg train loss = 1.1272
Step 1000 avg train loss = 0.8186
Validation loss after 3 epoch = 0.7608
Step 0 avg train loss = 1.1222
Step 1000 avg train loss = 0.8084
Validation loss after 4 epoch = 0.7568
Step 0 avg train loss = 1.1179
Step 1000 avg train loss = 0.8033
Validation loss after 5 epoch = 0.7550
Step 0 avg train loss = 1.1145
Step 1000 avg train loss = 0.8005
Validation loss after 6 epoch = 0.7543
Step 0 avg train loss = 1.1118
Step 1000 avg train loss = 0.7988
Validation loss after 7 epoch = 0.7542
Step 0 avg train loss = 1.1096
Step 1000 avg train loss = 0.7978
Validation loss after 8 epoch = 0.7543
Step 0 avg train loss = 1.1078
Step 1000 avg train loss = 0.7971
Validation loss after 9 epoch = 0.7546
Step 0 avg train loss = 1.1062
Step 1000 avg train loss = 0.7967

In [19]:
print("The test accuracy for frozen model is {}".format(frozen_test_accuracy))

The test accuracy for frozen model is 0.7691


### Part D (Fine-tune Unfrozen)
Now we'll learn how to train using the "unfrozen" approach.

In this approach we'll:
1. keep the feature_extract frozen for a few epochs (10)
2. Unfreeze it.
3. Finish training

In [0]:
def UNFROZEN_fine_tune_mnist(feature_extractor,fine_tune_model,model_name,mnist_train, mnist_val,hyperparameters,save):
  model = fine_tune_model
  criterion = nn.CrossEntropyLoss()
  # only update the parameters that are not frozen 
  model_params = [p for p in model.parameters() if p.requires_grad]
  optimizer = optim.Adam(model_params, lr=hyperparameters['lr'], weight_decay=hyperparameters['weight_decay'])
  num_epochs = hyperparameters['num_epochs']
  PATH = model_name + '.pth'
  
  for epoch in range(num_epochs):
    if epoch >10:
      # unfreeze the model 
      unfreeze_model(model)
    avg_loss = 0
    train_log_cache = []

    # training phase
    model.train()
    for i,(inp,target) in enumerate(mnist_train):
      optimizer.zero_grad()
      #inp = inp.repeat(1,3,1,1)
      inp = inp.to(current_device)
      target = target.to(current_device)
      logits = model(inp)
      # compute loss 
      loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
      # back-propogation 
      loss.backward()
      # gradient clipping 
      nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
      optimizer.step()
      train_log_cache.append(loss.item())

      if i % 1000 == 0:
        avg_train_loss = sum(train_log_cache)/len(train_log_cache)
        print('Step {} avg train loss = {:.{prec}f}'.format(i, avg_train_loss, prec=4))
        train_log_cache = []

    # validation phase 
    valid_losses = []
    model.eval()
    with torch.no_grad():
      for i,(inp,target) in enumerate(mnist_val):
        #inp = inp.repeat(1,3,1,1)
        inp = inp.to(current_device)
        target = target.to(current_device)
        logits = model(inp)
        # compute loss 
        loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
        valid_losses.append(loss.item())
    avg_val_loss = sum(valid_losses) / len(valid_losses)
    print('Validation loss after {} epoch = {:.{prec}f}'.format(epoch+1, avg_val_loss, prec=4))

  print('Finish training!')

  if save:
    torch.save(model,PATH)

### Grade UNFROZEN
Let's see if there's a difference in accuracy!

In [21]:
def grade_mnist_unfrozen():
  # init a ft model
    fine_tune_model = init_fine_tune_model()
    
    # run the transfer learning routine
    UNFROZEN_fine_tune_mnist(feature_extractor = pretrained_resnet18,
                              fine_tune_model = fine_tune_model,
                              mnist_train =  mnist_train,
                              mnist_val = mnist_val,
                              model_name = 'unfrozen_resnet18',
                              hyperparameters = hyperparameters,
                              save = True)
    
    # calculate test accuracy
    PATH = 'unfrozen_resnet18.pth'
    fine_tune_model = torch.load(PATH, map_location=current_device)
    test_accuracy = calculate_mnist_test_accuracy(pretrained_resnet18, fine_tune_model, mnist_test)
    
    # the real threshold will be released by Oct 11 
    assert test_accuracy > 0.0, 'your accuracy is too low...'
    
    return test_accuracy

hyperparameters= {
    'lr': 0.001,
    'num_epochs': 15,
    'weight_decay': 0         
 }
    
unfrozen_test_accuracy = grade_mnist_unfrozen()

Step 0 avg train loss = 2.4754
Step 1000 avg train loss = 1.0709
Validation loss after 1 epoch = 0.7985
Step 0 avg train loss = 1.1337
Step 1000 avg train loss = 0.8445
Validation loss after 2 epoch = 0.7709
Step 0 avg train loss = 1.1287
Step 1000 avg train loss = 0.8191
Validation loss after 3 epoch = 0.7616
Step 0 avg train loss = 1.1243
Step 1000 avg train loss = 0.8088
Validation loss after 4 epoch = 0.7575
Step 0 avg train loss = 1.1201
Step 1000 avg train loss = 0.8036
Validation loss after 5 epoch = 0.7557
Step 0 avg train loss = 1.1167
Step 1000 avg train loss = 0.8007
Validation loss after 6 epoch = 0.7550
Step 0 avg train loss = 1.1139
Step 1000 avg train loss = 0.7989
Validation loss after 7 epoch = 0.7548
Step 0 avg train loss = 1.1116
Step 1000 avg train loss = 0.7979
Validation loss after 8 epoch = 0.7548
Step 0 avg train loss = 1.1096
Step 1000 avg train loss = 0.7972
Validation loss after 9 epoch = 0.7550
Step 0 avg train loss = 1.1079
Step 1000 avg train loss = 0.7967

In [22]:
print("The test accuracy for unfrozen model is {}".format(unfrozen_test_accuracy))

The test accuracy for unfrozen model is 0.7681


In [23]:
assert unfrozen_test_accuracy > frozen_test_accuracy, 'the unfrozen model should be better'

AssertionError: ignored

--- 
# Question 2 (train a model on Wikitext-2)

Here we'll apply what we just learned to NLP. In this section we'll make our own feature extractor and pretrain it on Wikitext-2.

The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. The dataset is available under the Creative Commons Attribution-ShareAlike License.

#### Part A
In this section you need to generate the training, validation and test split. Feel free to use code from your previous lectures.

In [0]:
from torchtext.datasets import WikiText2


def init_wikitext_dataset():
    """
    Fill in the details
    """
    wikitext_val = None # YOUR CODE HERE
    wikitext_train = None # YOUR CODE HERE
    wikitext_test = None # YOUR CODE HERE
    
    return wikitext_train, wikitext_val, wikitext_test

In [9]:
type(WikiText2)

type

#### Part B   
Here we design our own feature extractor. In MNIST that was a resnet because we were dealing with images. Now we need to pick a model that can model sequences better. Design an RNN-based model here.

In [0]:
def init_feature_extractor():
    
    feature_extractor = None #  YOUR CODE
    
    return feature_extractor

#### Part C
Pretrain the feature extractor

In [0]:
def fit_feature_extractor(feature_extractor, wikitext_train, wikitext_val):
    # FILL IN THE DETAILS
    pass

#### Part D
Calculate the test perplexity on wikitext2. Feel free to recycle code from previous assignments from this class. 

In [0]:
def calculate_wiki2_test_perplexity(feature_extractor, wikitext_test):
    
    # FILL IN DETAILS
    
    return test_ppl

#### Let's grade your results!
(don't touch this part)

In [0]:
def grade_wikitext2():
    # load data
    wikitext_train, wikitext_val, wikitext_test = init_wikitext_dataset()

    # load feature extractor
    feature_extractor = init_feature_extractor()

    # pretrain using the feature extractor
    fit_feature_extractor(feature_extractor, wikitext_train, wikitext_val)

    # check test accuracy
    test_ppl = calculate_wiki2_test_perplexity(feature_extractor, wikitext_test)

    # the real threshold will be released by Oct 11 
    assert test_ppl < 10000, 'ummm... your perplexity is too high...'
    
grade_wikitext2()

---   
## Question 3 (fine-tune on MNLI)
In this question you will use your feature_extractor from question 2
to fine-tune on MNLI.

(From the website):
The Multi-Genre Natural Language Inference (MultiNLI) corpus is a crowd-sourced collection of 433k sentence pairs annotated with textual entailment information. The corpus is modeled on the SNLI corpus, but differs in that covers a range of genres of spoken and written text, and supports a distinctive cross-genre generalization evaluation. The corpus served as the basis for the shared task of the RepEval 2017 Workshop at EMNLP in Copenhagen.

MNLI has 3 genres (3 classes).
The goal of this question is to maximize the test accuracy in MNLI. 

### Part A
In this section you need to generate the training, validation and test split. Feel free to use code from your previous lectures.

In [0]:
from torchtext.datasets import MultiNLI

def init_mnli_dataset():
    """
    Fill in the details
    """
    mnli_val = None # TODO
    mnli_train = None # TODO
    mnli_test = None # TODO
    
    return mnli_train, mnli_val, mnli_test

### Part B
Here we again design a model for finetuning. Use the output of your feature-extractor as the input to this model. This should be a powerful classifier (up to you).

In [0]:
def init_finetune_model():
    
    # TODO FILL IN THE DETAILS
    fine_tune_model = ...
    
    return fine_tune_model

### Part C
Use the feature_extractor and your fine_tune_model to fine_tune MNLI

In [0]:
def fine_tune_mnli(feature_extractor, fine_tune_model, mnli_train, mnli_val):
    # YOUR CODE HERE
    pass

### Part D
Evaluate the test accuracy

In [0]:
def calculate_mnli_test_accuracy(feature_extractor, fine_tune_model, mnli_test):
    
    # YOUR CODE HERE...
    
    return test_ppl

### Let's grade your results

In [0]:
def grade_mnli():
    # load data
    mnli_train, mnli_val, mnli_test = init_mnli_dataset()

    # no need to load feature extractor because it is fine-tuned
    feature_extractor = feature_extractor

    # init the fine_tune model
    fine_tune_model = init_finetune_model()
    
    # finetune
    fine_tune_mnli(feature_extractor, fine_tune_model, mnli_train, mnli_val)

    # check test accuracy
    test_accuracy = calculate_mnli_test_accuracy(feature_extractor, wikitext_test)

    # the real threshold will be released by Oct 11 
    assert test_ppl > 0.00, 'ummm... your accuracy is too low...'
    
grade_mnli()

---  
### Question 4 (BERT)

A major direction in research came from a model called BERT, released last year.  

In this question you'll use BERT as your feature_extractor instead of the model you
designed yourself.

To get BERT, head on over to (https://github.com/huggingface/transformers) and load your BERT model here

In [0]:
!pip install transformers

### Part A (init BERT)
In this section you need to create an instance of BERT and return if from the function

In [0]:
from transformers import BertTokenizer, BertModel, BertForMaskedLM

def init_bert():
    
    BERT = None # ... YOUR CODE HERE
    
    return BERT

## Part B (fine-tune with BERT)

Use BERT as your feature extractor to finetune MNLI. Use a new finetune model (reset weights).

In [0]:
def fine_tune_mnli_BERT(BERT_feature_extractor, fine_tune_model, mnli_train, mnli_val):
    # YOUR CODE HERE
    pass

## Part C
Evaluate how well we did

In [0]:
def calculate_mnli_test_accuracy_BERT(feature_extractor, fine_tune_model, mnli_test):
    
    # YOUR CODE HERE...
    
    return test_ppl

## Let's grade your BERT results!

In [0]:
def grade_mnli_BERT():
    BERT_feature_extractor = init_bert()
    
    # load data
    mnli_train, mnli_val, mnli_test = init_mnli_dataset()

    # init the fine_tune model
    fine_tune_model = init_finetune_model()
    
    # finetune
    fine_tune_mnli(BERT_feature_extractor, fine_tune_model, mnli_train, mnli_val)

    # check test accuracy
    test_accuracy = calculate_mnli_test_accuracy(feature_extractor, wikitext_test)
    
    # the real threshold will be released by Oct 11 
    assert test_ppl > 0.0, 'ummm... your accuracy is too low...'
    
grade_mnli_BERT()