# NLP Model Distillation
The aim of this task was to distill a small sentiment classifier from a pretrained RoBERTa teacher model (delivered via the `bert-fast` library). The student model feeds GloVe embeddings into an LSTM-based model. All training commands (given below) are executed from the command line. This notebook presents results rather than training the models, but if you have checkpoints available you can run the evaluation below.

## Student Architecture
The student architecture design choices were as follows:
* Avoid attention-based residual blocks/models, as these would significantly increase student inference time, which defeats the point of distillation.
* Use an LSTM for sequence processing, with hidden sizes chosen such that there are < `1M` parameters.
* For the final representation, take the concatenation of a max-pool over the input sequence and the hidden state of the final token in the sequence. Note that the latter requires keeping track of token sequence lengths.
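
The design choices above can be sketched in PyTorch as follows. This is a minimal illustration rather than the repository's `LSTMClassifier`: the class name, vocabulary size, and layer sizes are hypothetical, the embedding layer is a randomly initialized stand-in for GloVe vectors, and the max-pool is assumed to be taken over the LSTM outputs across time. The parameter count here excludes the frozen embeddings, which is also an assumption.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class LSTMStudentSketch(nn.Module):
    """Illustrative student: embeddings -> LSTM -> [max-pool ; final state] -> linear."""

    def __init__(self, vocab_size=1000, embed_dim=50, hidden_dim=128, num_classes=3):
        super().__init__()
        # Stand-in for pretrained GloVe vectors; frozen, index 0 reserved for padding.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.embedding.weight.requires_grad_(False)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, token_ids, lengths):
        # token_ids: (batch, max_len) padded with 0; lengths: (batch,) true lengths.
        max_len = token_ids.size(1)
        embedded = self.embedding(token_ids)
        packed = pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, (h_n, _) = self.lstm(packed)
        outputs, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=max_len)

        # Max-pool over time, masking padded steps so they can never win the max.
        positions = torch.arange(max_len, device=token_ids.device)
        is_real = positions[None, :] < lengths.to(token_ids.device)[:, None]
        max_pooled = outputs.masked_fill(~is_real[:, :, None], float("-inf")).max(dim=1).values

        # Because the input was packed, h_n holds the state at each sequence's last real token.
        last_hidden = h_n[-1]
        return self.classifier(torch.cat([max_pooled, last_hidden], dim=1))


def count_trainable_parameters(model):
    """Number of parameters that receive gradients (frozen embeddings excluded)."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```

Packing the padded batch is what makes the sequence lengths necessary: without them, the "final" hidden state of a short sequence would be computed over padding tokens.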


## Student from scratch
To train the student from scratch run the following command:

```bash
   python distill/train/train_student.py \
        --expt_name from-scratch \
        --input_csv input-csv-containing-labelled-text.csv \
        --model_type lstm
```

This trains the model on 60% of `--input_csv` and validates it on a further 20%; the final 20% is set aside as a test set.
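
A 60/20/20 split of this kind could be produced as sketched below. The function name, the shuffling, and the fixed seed are assumptions for illustration; the actual splitting logic in `train_student.py` may differ.

```python
import random


def split_indices(n_rows, train_frac=0.6, val_frac=0.2, seed=0):
    """Shuffle row indices and cut them into train / validation / test lists."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)  # local RNG, so the split is reproducible
    n_train = round(n_rows * train_frac)
    n_val = round(n_rows * val_frac)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # whatever remains is held out
    return train, val, test
```

Fixing the seed matters here: the held-out rows must be the same on every run, otherwise test examples could leak into a later training run.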

## Distillation
To distill a provided `bert-fast` model directory into a randomly initialized student model run the following command:

```bash
   python distill/train/distill_from_teacher.py \
        --expt_name teacher \
        --input_csv input-csv-containing-unlabelled-text.csv \
        --model_type lstm
```
This will train the student model on the teacher's softmax outputs.  
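
One common way to train on soft targets is a cross-entropy between the teacher's probabilities and the student's log-probabilities, sketched below. The function name is hypothetical, and the loss actually used in `distill_from_teacher.py` is not stated here; it may instead be a KL divergence, use a temperature, or differ in other ways.

```python
import torch
import torch.nn.functional as F


def soft_target_loss(student_logits, teacher_probs):
    """Cross-entropy of student predictions against teacher probabilities.

    student_logits: (batch, num_classes) raw student scores.
    teacher_probs:  (batch, num_classes) teacher softmax outputs, rows sum to 1.
    """
    log_probs = F.log_softmax(student_logits, dim=-1)
    return -(teacher_probs * log_probs).sum(dim=-1).mean()
```

With one-hot teacher probabilities this reduces to the usual hard-label cross-entropy, so the same training loop can in principle serve both settings.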

## Aside: Potential Optimization
The distillation is slow, and unnecessarily so, because the teacher outputs are regenerated on every epoch. A simple optimization would be to compute them once at the start of training and reuse them.
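
A minimal sketch of that optimization is shown below, assuming the teacher can be wrapped as a callable that maps a list of texts to a list of probability vectors. The names `precompute_teacher_probs` and `teacher_fn` are hypothetical and not part of the repository.

```python
def precompute_teacher_probs(teacher_fn, texts, batch_size=16):
    """Run the teacher once over all texts and cache its outputs by text.

    teacher_fn: callable taking a list of texts and returning one
                probability vector per text (an assumed interface).
    """
    cache = {}
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        for text, probs in zip(batch, teacher_fn(batch)):
            cache[text] = probs
    return cache
```

Keying the cache by text rather than by batch index means the training loader can still be shuffled between epochs; each epoch then only performs dictionary lookups instead of teacher forward passes.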

# Evaluation
I load local versions of these models to generate the results below. To run the cells yourself, first train the models, then inspect the training printouts and manually select the best epoch on the validation set. Both student models are evaluated on my held-out test set of labelled headlines.

In [1]:
cd ..

In [2]:
ls logs/lstm9

LSTMClassifier_100.pt  LSTMClassifier_40.pt  LSTMClassifier_80.pt
LSTMClassifier_10.pt   LSTMClassifier_50.pt  LSTMClassifier_90.pt
LSTMClassifier_20.pt   LSTMClassifier_60.pt
LSTMClassifier_30.pt   LSTMClassifier_70.pt


In [3]:
student_from_scratch_fp = './logs/lstm6/LSTMClassifier_90.pt'
student_distilled_fp = './logs/distill/lstm1/LSTMClassifier_9.pt'
teacher_dir = './model'

In [4]:
%load_ext autoreload
%autoreload 2

In [5]:
import argparse 
import copy 
import time 

import torch 
from torch.utils.data import DataLoader

from distill.evaluate import evaluate, print_eval_res
from distill.train.train_student import add_train_args, train_init
from distill.data import CSVTextDataset
from distill.labels import probs_to_labels, all_labels
from distill.teacher import TeacherNLPClassifier
from distill.train.train_teacher import unpack_batch_send_to_device as unpack_batch_teacher

In [6]:
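def init_and_load_student_model(ckpt_path, model_type='lstm'):
    """Build a student via the training entry point and load checkpoint weights into it."""
    parser = argparse.ArgumentParser()
    parser = add_train_args(parser)
    # Relies on the parser defaults; note that parse_args() reads sys.argv.
    args = parser.parse_args()
    args.model_type = model_type
    # train_init returns a dict holding, among other things, the model and the test loader.
    student_dict = train_init(args)
    student = student_dict['model']
    student.load_state_dict(torch.load(ckpt_path)['model'])
    return student_dict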

In [7]:
scratch_dict = init_and_load_student_model(student_from_scratch_fp)
distilled_dict = init_and_load_student_model(student_distilled_fp)
teacher = TeacherNLPClassifier(teacher_dir)

test_loader = scratch_dict['test_loader']

### Evaluate student trained from scratch

In [8]:
results = evaluate(
    **scratch_dict, 
    loader=test_loader, 
    subset='test',
    probs_to_labels=probs_to_labels, 
    all_labels=all_labels
)
print_eval_res(results)

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Accuracy:       	av=77.6%  
F1 Scores:      	negative=0.625  neutral=0.845  positive=0.675  av=0.715  av_weight=0.772  micro=0.776  
Confusion [negative,neutral,positive]
[[ 70  36   8]
 [ 20 516  51]
 [ 20  82 167]]
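
The printed accuracy and F1 scores can be reproduced from the confusion matrix alone. The sketch below assumes rows are true labels and columns are predictions, a reading that is consistent with the reported weighted F1; the helper name is hypothetical and this is not the code behind `print_eval_res`.

```python
def metrics_from_confusion(confusion):
    """Accuracy and F1 scores from a square confusion matrix (rows = true labels)."""
    n_classes = len(confusion)
    total = sum(sum(row) for row in confusion)
    correct = sum(confusion[i][i] for i in range(n_classes))
    f1_scores, supports = [], []
    for i in range(n_classes):
        true_count = sum(confusion[i])                 # examples whose label is class i
        pred_count = sum(row[i] for row in confusion)  # examples predicted as class i
        denom = true_count + pred_count
        # F1 = 2*TP / (2*TP + FP + FN) = 2*TP / (true_count + pred_count)
        f1_scores.append(2 * confusion[i][i] / denom if denom else 0.0)
        supports.append(true_count)
    return {
        "accuracy": correct / total,
        "f1": f1_scores,
        "macro_f1": sum(f1_scores) / n_classes,
        "weighted_f1": sum(f * s for f, s in zip(f1_scores, supports)) / total,
    }


# Confusion matrix of the student trained from scratch, copied from the output above.
scratch_metrics = metrics_from_confusion([[70, 36, 8], [20, 516, 51], [20, 82, 167]])
```

The gap between the macro average (`av`) and the support-weighted average (`av_weight`) reflects the class imbalance: the neutral class is both the largest and the best classified.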


### Evaluate distilled student
The distilled student is evaluated on two subsets:
1. The test subset used above, to enable comparison with the student trained from scratch.
2. The validation subset used when the teacher was trained, to enable comparison with the teacher.

In [9]:
# 1: initial test subset
results = evaluate(
    **distilled_dict, 
    loader=test_loader, 
    subset='test',
    probs_to_labels=probs_to_labels, 
    all_labels=all_labels
)
print_eval_res(results)

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Accuracy:       	av=83.5%  
F1 Scores:      	negative=0.742  neutral=0.884  positive=0.759  av=0.795  av_weight=0.832  micro=0.835  
Confusion [negative,neutral,positive]
[[ 82  26   6]
 [ 12 536  39]
 [ 13  64 192]]


In [10]:
# init teacher's validation loader
dataset = CSVTextDataset(csv_file='./data/val.csv', headers=['text', 'negative', 'neutral', 'positive'])
teacher_dataloader = DataLoader(
        dataset,
        batch_size=16,
        shuffle=False,
        collate_fn=dataset.collate_batch,
)

In [13]:
# 2: teacher's validation set for distilled student
t1 = time.time()
results = evaluate(
    **distilled_dict, 
    loader=teacher_dataloader, 
    subset='test',
    probs_to_labels=probs_to_labels, 
    all_labels=all_labels
)
t2 = time.time()
print_eval_res(results)
print(f'Time taken for evaluation = {(t2 - t1):.3f}s')

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Accuracy:       	av=80.9%  
F1 Scores:      	negative=0.766  neutral=0.859  positive=0.715  av=0.780  av_weight=0.807  micro=0.809  
Confusion [negative,neutral,positive]
[[ 72  21   3]
 [  8 378  45]
 [ 12  50 138]]
Time taken for evaluation = 1.810s


### Evaluate teacher

In [12]:
t1 = time.time()
results = evaluate(
    model=teacher,
    unpack_batch_fn=unpack_batch_teacher,
    loader=teacher_dataloader, 
    subset='test',
    probs_to_labels=probs_to_labels, 
    all_labels=all_labels
)
t2 = time.time()
print_eval_res(results)
print(f'Time taken for evaluation = {(t2 - t1):.3f}s')

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Accuracy:       	av=87.9%  
F1 Scores:      	negative=0.863  neutral=0.902  positive=0.838  av=0.868  av_weight=0.879  micro=0.879  
Confusion [negative,neutral,positive]
[[ 82  12   2]
 [  9 384  38]
 [  3  24 173]]
Time taken for evaluation = 17.568s


## Summary
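- Distillation improved accuracy on the held-out test set from 77.6% to 83.5% compared with the student trained from scratch.
- On the teacher's validation set the distilled student reaches 80.9% accuracy versus 87.9% for the teacher, but evaluates roughly 10x faster (1.81s versus 17.57s on a large GPU; the improvement on CPU is likely greater).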
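
The speed-up and the accuracy cost can be worked out directly from the printouts above. The variable names below are illustrative; the example count is the sum of the entries of the confusion matrix on the teacher's validation set.

```python
# Wall-clock evaluation times on the teacher's validation set, from the cells above.
teacher_seconds = 17.568
student_seconds = 1.810

# Number of evaluated examples: row sums of the validation confusion matrix.
n_examples = 96 + 431 + 200

speedup = teacher_seconds / student_seconds
student_throughput = n_examples / student_seconds  # examples per second
teacher_throughput = n_examples / teacher_seconds

# Accuracy given up in exchange for the speed-up, in percentage points.
accuracy_gap_points = 87.9 - 80.9

print(f"speed-up: {speedup:.1f}x")
print(f"student: {student_throughput:.0f} ex/s, teacher: {teacher_throughput:.0f} ex/s")
print(f"accuracy gap: {accuracy_gap_points:.1f} points")
```

These are single-run timings, so the ratio is best read as an order-of-magnitude estimate rather than a precise benchmark.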
