## Setup

### Installs
Set up the environment by uncommenting and running these cells once each time the AWS instance is started.

In [None]:
# One-time environment setup: uncomment and run after the AWS instance starts.
# NOTE(review): versions are unpinned; pin them (and prefer `%pip`, which targets the
# running kernel's environment) if runs need to be reproducible.
# !pip install transformers
# #!pip install --upgrade torch   -- open question: is this upgrade really needed?
#   (it makes the deep learning container's environment diverge -- alternatively switch to the latest pytorch_p37 kernel)
# !pip install papermill
# # !pip install lab-ml
# !pip install allennlp
# !pip install tensorboard

In [None]:
# Install the TextBrewer fork from github.com/bobub/TextBrewer (imported below as `textbrewer`).
# Presumably the fork is required because GeneralDistiller is later called with extra
# arguments (n_experts, load_balancing_loss_ceof) -- confirm against the fork.
# %cd /home/ec2-user/SageMaker
# !git clone https://github.com/bobub/TextBrewer.git
# %cd TextBrewer
# !python setup.py install
# # Put the freshly built package on sys.path so the current kernel can import it.
# import os, sys
# sys.path.insert(0, "/home/ec2-user/SageMaker/TextBrewer/build/lib")

In [None]:
# Install the `nn` fork from github.com/bobub/nn. Presumably this provides the
# `labml_nn` package (Switch Transformer, MultiHeadAttention, FeedForward) imported
# below -- confirm against the repository.
# %cd /home/ec2-user/SageMaker
# !git clone https://github.com/bobub/nn.git
# %cd /home/ec2-user/SageMaker/nn
# !python setup.py install
# # Put the freshly built package on sys.path so the current kernel can import it.
# import sys, os
# sys.path.insert(0, "/home/ec2-user/SageMaker/nn/build/lib")

### Imports

In [1]:
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig
from transformers import AutoConfig, AutoTokenizer, AutoModel, AdamW, AutoModelForMaskedLM
from transformers import get_linear_schedule_with_warmup, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm
import seaborn as sns
import pandas as pd
import itertools
import matplotlib.pyplot as plt
import pickle
from sagemaker import get_execution_role
import boto3
import time
import io
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import mean_squared_error
from functools import partial
from labml_nn.transformers.switch import SwitchTransformer, SwitchTransformerLayer, SwitchFeedForward
from labml_nn.transformers import MultiHeadAttention
from labml_nn.transformers.feed_forward import FeedForward
import torch.nn as nn
import datetime
import math
import json
from utils import *
from distil_funcs import *
#import papermill
from sklearn.decomposition import PCA
from transformers import BertModel, BertConfig

### Papermill Arguments
Papermill is used for hyperparameter tuning - see hyperparam_search.ipynb

In [2]:
# Shows the tunable parameters papermill detects during a hyperparameter search.
# To use it, uncomment this line and `import papermill` (commented out in the imports cell).
# papermill.inspect_notebook('distillation.ipynb')

### General Arguments

In [3]:
# GENERAL SETTINGS - for setting individual runs.
# Presumably this is the papermill `parameters` cell: papermill inserts overridden values
# in a cell after it. Values DERIVED in this cell (OUTPUT_DIR, int_matches) are therefore
# computed from the defaults below unless they are overridden as well.
TRAIN_DATA = 10000000 # number of training examples kept (train split is sliced [:TRAIN_DATA])
EVAL_DATA = 5000 # number of validation examples kept for the evaluation callback
BATCH_SIZE = 128 # DataLoader batch size (also passed to the evaluation callback)
SHUFFLE = True # shuffle the training DataLoader
NUM_WORKERS = 4 # DataLoader worker processes

# model
STUDENT = 'switch' # 'switch' = Switch Transformer student; other values presumably a model name, e.g. the alternatives below
#'bert-base-multilingual-cased'
#'nreimers/TinyBERT_L-6_H-768_v2' ,
#"nreimers/TinyBERT_L-4_H-312_v2",
# "distilbert-base-multilingual-cased"

# training config
NUM_STEPS = 80000 # steps of optimisation: num_steps for distiller.train and the scheduler's num_training_steps
NUM_EPOCHS = 5 # recommended 30-50. NOTE(review): not referenced by later cells; training length is set by NUM_STEPS
CKPT_FREQ = 1 # save X times per epoch
CKPT_EPOCH_FREQ = 1 # save every X epochs
CKPT_STEPS = 1 # save every X steps. NOTE(review): unused - TrainingConfig is given ckpt_steps=NUM_STEPS instead

# Earlier log location, superseded by LOG_DIR (computed after papermill injection):
#LOG_DIR = 's3://eu1-sagemaker-bucket/borisbubla/log/run-' + STUDENT + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") # '/home/ec2-user/SageMaker/log'
# Checkpoint directory. NOTE(review): its timestamp is taken separately from LOG_DIR's,
# so the two can differ (e.g. 142359 vs 142400 in the printed TRAIN_CONFIG).
OUTPUT_DIR = 'models/'+STUDENT+'/time-'+datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
FP16 = False # use Apex mixed precision training
FP16_OPT_LEVEL = 'O1' # Apex optimisation level used with FP16
DATA_PARALLEL = False # use DataParallel training
LOCAL_RANK = -1 # Distributed DataParallel, check documentation
LEARNING_RATE = 1e-4 # optimizer learning rate
SCHEDULER_CLASS = get_linear_schedule_with_warmup # LR scheduler factory; its step counts come from scheduler_args
OPTIMIZER = AdamW # optimizer class, instantiated over the student's parameters
EVAL_METRIC = 'cosine_similarity' # metric passed to the evaluation callback (predict)
DEVICE = 'cuda' # torch device used for the teacher, the student and training

# distillation config
TEMPERATURE = 9 # distillation temperature
TEMPERATURE_SCHEDULER = 'constant'
KD_LOSS_TYPE = 'cos'
KD_LOSS_WEIGHT = 10 # weight of the KD loss term (DistillationConfig.kd_loss_weight)
INT_LOSS_WEIGHT = 1 # weight of each intermediate-match loss in INTERMEDIATE_MATCHES
KD_WEIGHT_SCHEDULER = 'none'
IS_CACHING_LOGITS = False

# switch transformer settings (collected into student_config)
D_MODEL = 768 # hidden dim of model
HEADS = 12 # attention heads
DROPOUT = 0.1 # dropout in network except FFN
DROPOUT_FFN = 0.4 # dropout in FFN
D_FF = 12 # num features in FFN hidden layer
N_LAYERS = 2 # num of transformer layers
N_EXPERTS = 320 # number of FFN experts

# intermediate matches - TextBrewer framework: (teacher layer, student layer),
# here teacher layer 11 matched to the student's last layer.
int_matches = (11,N_LAYERS-1)


In [4]:
 # Derived settings: must run AFTER papermill's injected-parameters cell, so everything
# below is computed from the final (possibly overridden) hyperparameters.
# NOTE(review): OUTPUT_DIR and int_matches are computed in the parameters cell itself and
# therefore keep their defaults unless papermill overrides them explicitly.

# Gradient accumulation: calling optimizer.step() only every X backward passes reduces
# GPU memory usage. (Alternative considered: TRAIN_DATA/100000.)
GRADIENT_ACCUMULATION_STEPS = 1

# TextBrewer intermediate matches between teacher layer int_matches[0] and student layer
# int_matches[1]: attention maps and value relations, each weighted by INT_LOSS_WEIGHT.
INTERMEDIATE_MATCHES = [
    {'layer_T': int_matches[0], 'layer_S': int_matches[1], 'feature': 'attention', 'loss': 'attention_ce', 'weight': INT_LOSS_WEIGHT},
    {'layer_T': int_matches[0], 'layer_S': int_matches[1], 'feature': 'value_relation', 'loss': 'value_relation_ce', 'weight': INT_LOSS_WEIGHT},
]

# Experiment log location; the main hyperparameters are encoded in the path.
# The training-set size in thousands is written as an integer when exact (e.g. "10000k")
# instead of the float "10000.0k" that plain division produced.
_train_k = TRAIN_DATA // 1000 if TRAIN_DATA % 1000 == 0 else TRAIN_DATA / 1000
LOG_DIR = (
    f's3://eu1-sagemaker-bucket/borisbubla/experiments/{_train_k}k/{STUDENT}/'
    f'LR{LEARNING_RATE}LAY{N_LAYERS}EXP{N_EXPERTS}D_FF{D_FF}TEMP{TEMPERATURE}'
    f'TIME-{datetime.datetime.now():%Y%m%d-%H%M%S}'
)

In [5]:
# Switch Transformer student configuration, consumed by load_student; the distiller also
# reads 'n_experts' and 'load_balancing_loss_ceof' (key spelling must stay as-is).
student_config = dict(
    d_model=D_MODEL,                # hidden dim of model
    heads=HEADS,                    # attention heads
    dropout=DROPOUT,                # dropout in network except FFN
    dropout_ffn=DROPOUT_FFN,        # dropout in FFN
    d_ff=D_FF,                      # num features in FFN hidden layer
    n_layers=N_LAYERS,              # num of transformer layers
    n_experts=N_EXPERTS,            # number of FFN experts
    load_balancing_loss_ceof=0.01,  # load-balancing coefficient, encourages expert diversity
    is_scale_prob=True,             # scale selected expert outputs by routing probability
    drop_tokens=False,              # whether to drop tokens (disabled)
    capacity_factor=1.25,           # capacity factor - seemed to work best in Switch Transformer
)

### Setup DataLoader

In [6]:
# Pre-encoded OpenSubtitles train/dev splits; the next cell indexes them by
# 'input_ids', 'token_type_ids' and 'attention_mask'.
# NOTE(review): load_pickle presumably unpickles the file - only load trusted files.
train_file, valid_file = (
    '/home/ec2-user/SageMaker/data/train.pkl',
    '/home/ec2-user/SageMaker/data/dev.pkl',
)
encoded_train_set, encoded_valid_set = map(load_pickle, (train_file, valid_file))

In [8]:
def _encoded_subset(encoded, limit):
    """Wrap the first `limit` examples of an encoded split in a CustomDataset."""
    fields = ('input_ids', 'token_type_ids', 'attention_mask')
    return CustomDataset(*(encoded[field][:limit] for field in fields))

# Training set truncated to TRAIN_DATA examples, evaluation set to EVAL_DATA.
open_sub_train = _encoded_subset(encoded_train_set, TRAIN_DATA)
eval_dataset = _encoded_subset(encoded_valid_set, EVAL_DATA)

# Only the training split gets a DataLoader here; eval_dataset is handed to the callback.
data_params = dict(batch_size=BATCH_SIZE, shuffle=SHUFFLE, num_workers=NUM_WORKERS)
dataloader = DataLoader(open_sub_train, **data_params)

In [9]:
# Save space: drop the raw encoded splits now that the Datasets hold the needed slices.
# Popping from globals() (instead of `del`) keeps this cell safe to re-run: a plain
# `del` raises NameError the second time.
# NOTE(review): if the slices are views (e.g. tensors), the datasets still keep the
# underlying storage alive, so less memory may actually be freed.
import gc

for _name in ('encoded_train_set', 'encoded_valid_set'):
    globals().pop(_name, None)
gc.collect()

### Load Models

In [10]:
# Resolve the configured DEVICE ('cuda' or 'cpu') once; later cells reuse `device`.
device = torch.device(DEVICE)

# Teacher model on that device (LaBSE, per the load warning printed below).
teacher_model = load_teacher(device=device)

Some weights of the model checkpoint at sentence-transformers/LaBSE were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
# Build the student. Per the original note, load_student also performs word-embedding
# compression and weight initialisation; it receives the teacher and the int_matches
# layer pair, presumably for that initialisation.
student_model = load_student(
    name=STUDENT,
    student_config=student_config,
    device=device,
    teacher_model=teacher_model,
    int_matches=int_matches,
    N_LAYERS=N_LAYERS,
)

## Distillation

### Distillation Adaptor and Callback

Manually run exactly one of the next two cells: the resume-training cell (restores a saved checkpoint) or the start-training cell (fresh optimizer).

In [15]:
# IF RESUMING TRAINING
# Alternative to the "starting training" cell below - run one or the other, not both.
# Rebuilds the optimizer, restores optimizer and student weights from a saved checkpoint,
# and prepares the scheduler arguments.
optimizer = OPTIMIZER(student_model.parameters(), lr=LEARNING_RATE)

# path of the checkpoint (model + optimizer state) to resume training from.
# NOTE(review): hard-coded to one specific run; update it before resuming.
path = '/home/ec2-user/SageMaker/models/switch/time-20210620-201324/model_3136.pkl' 
# Load onto the CPU first; the load_state_dict calls below copy the values into the
# existing parameters / optimizer state on the model's device.
checkpoint = torch.load(path, map_location = torch.device('cpu'))

# load optimiser and model
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
student_model.load_state_dict(checkpoint['model_state_dict'])

# NOTE(review): despite its name, `epoch` holds the checkpoint's 'step' entry.
# epoch, losses_dict and total_loss are not referenced by the later cells in this notebook.
epoch = checkpoint['step']
losses_dict = checkpoint['losses_dict']
total_loss = checkpoint['total_loss']

# Switch the student back to training mode (e.g. enables dropout).
student_model.train()
num_training_steps = NUM_STEPS

# arguments dict except optimiser: warm up over the first 10% of num_training_steps.
# NOTE(review): these match a fresh start, and distiller.train is later called with
# num_steps=NUM_STEPS, so presumably the LR warmup and step count restart from 0 rather
# than continuing from the checkpoint's step - confirm this is intended.
scheduler_args = {'num_warmup_steps':int(0.1*num_training_steps), 'num_training_steps':num_training_steps}
# Release the CPU copy of the checkpoint.
del checkpoint

In [11]:
# IF STARTING TRAINING
# Fresh start: a new optimizer over the student's parameters plus the keyword
# arguments for SCHEDULER_CLASS. The optimizer itself is passed to distiller.train
# separately, so it is not part of scheduler_args.
num_training_steps = NUM_STEPS
optimizer = OPTIMIZER(student_model.parameters(), lr=LEARNING_RATE)

scheduler_args = dict(
    num_warmup_steps=int(0.1 * num_training_steps),  # warm up over the first 10% of steps
    num_training_steps=num_training_steps,
)

In [16]:
# Evaluation callback for distiller.train: `predict` with its fixed arguments bound here;
# the distiller presumably supplies the remaining ones (e.g. the student) when calling it.
callback_fun = partial(
    predict,
    teacher_model=teacher_model,
    eval_dataset=eval_dataset,
    device=device,
    STUDENT=STUDENT,
    BATCH_SIZE=BATCH_SIZE,
    eval_metric=EVAL_METRIC,
    feedback=True,
)

### Distillation Configurations and Model Overview

In [17]:
# Print per-module parameter counts and memory (up to depth 4) for teacher and student.
for _title, _model in (("\nteacher_model's parameters:", teacher_model),
                       ("student_model's parameters:", student_model)):
    print(_title)
    result, _ = textbrewer.utils.display_parameters(_model, max_level=4)
    print(result)


teacher_model's parameters:

LAYER NAME                  	        #PARAMS	     RATIO	 MEM(MB)
--model:                    	    470,927,360	   100.00%	 1796.45
  --embeddings:             	    385,282,304	    81.81%	 1469.74
    --position_ids:         	            512	     0.00%	    0.00
    --word_embeddings       
      --weight:             	    384,885,504	    81.73%	 1468.22
    --position_embeddings   
      --weight:             	        393,216	     0.08%	    1.50
    --token_type_embeddings 
      --weight:             	          1,536	     0.00%	    0.01
    --LayerNorm:            	          1,536	     0.00%	    0.01
      --weight:             	            768	     0.00%	    0.00
      --bias:               	            768	     0.00%	    0.00
  --encoder                 
    --layer:                	     85,054,464	    18.06%	  324.46
      --0:                  	      7,087,872	     1.51%	   27.04
        --attention:        	      2,363,904	     0.50%	    9.02
        -

In [18]:
# TextBrewer configurations. Training/checkpoint/precision settings come from the general
# arguments; LOG_DIR and INTERMEDIATE_MATCHES are computed after papermill injection.
train_config = TrainingConfig(
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    ckpt_frequency=CKPT_FREQ,
    ckpt_epoch_frequency=CKPT_EPOCH_FREQ,
    # NOTE(review): CKPT_STEPS exists in the general arguments, but NUM_STEPS is what is
    # passed here (ckpt_steps equals the total step count) - confirm which is intended.
    ckpt_steps=NUM_STEPS,
    output_dir=OUTPUT_DIR,
    fp16=FP16,
    fp16_opt_level=FP16_OPT_LEVEL,
    data_parallel=DATA_PARALLEL,
    local_rank=LOCAL_RANK,
    device=device,
    log_dir=LOG_DIR,
)

distill_config = DistillationConfig(
    temperature=TEMPERATURE,
    temperature_scheduler=TEMPERATURE_SCHEDULER,
    hard_label_weight=0,  # hard-label loss weighted 0: learn from the teacher only
    kd_loss_type=KD_LOSS_TYPE,
    kd_loss_weight=KD_LOSS_WEIGHT,
    kd_loss_weight_scheduler=KD_WEIGHT_SCHEDULER,
    is_caching_logits=IS_CACHING_LOGITS,
    intermediate_matches=INTERMEDIATE_MATCHES,
)

In [19]:
# Echo both configurations so this run's settings are recorded in the cell output.
for _label, _cfg in (("TRAIN_CONFIG:", train_config), ("DISTILL_CONFIG:", distill_config)):
    print(_label)
    print(_cfg)

TRAIN_CONFIG:
gradient_accumulation_steps : 1
ckpt_frequency : 1
ckpt_epoch_frequency : 1
ckpt_steps : 28224
log_dir : s3://eu1-sagemaker-bucket/borisbubla/experiments/10000.0k/switch/LR0.0001LAY2EXP320D_FF12TEMP9TIME-20210621-142400
output_dir : models/switch/time-20210621-142359
device : cuda
fp16 : False
fp16_opt_level : O1
data_parallel : False
local_rank : -1

DISTILL_CONFIG:
temperature : 9
temperature_scheduler : <function constant_temperature_scheduler at 0x7facb2ae5ea0>
hard_label_weight : 0
hard_label_weight_scheduler : None
kd_loss_type : cos
kd_loss_weight : 10
kd_loss_weight_scheduler : None
probability_shift : False
intermediate_matches : [
IntermediateMatch: layer_T : 11, layer_S : 1, feature : attention, weight : 1, loss : attention_ce, proj : None, 
IntermediateMatch: layer_T : 11, layer_S : 1, feature : value_relation, weight : 1, loss : value_relation_ce, proj : None]
is_caching_logits : False



### Distillation Training

In [None]:
# Build the distiller. The Switch student gets its own adaptor plus its expert count and
# load-balancing coefficient; any other student uses simple_adaptor on both sides.
# (The two cases are mutually exclusive, so a single if/else replaces two separate ifs.)
if STUDENT == 'switch':
    distiller = GeneralDistiller(
        train_config=train_config, distill_config=distill_config,
        model_T=teacher_model, model_S=student_model,
        adaptor_T=teacher_adaptor, adaptor_S=switch_student_adaptor,
        n_experts=student_config['n_experts'],
        load_balancing_loss_ceof=student_config['load_balancing_loss_ceof'])
else:
    distiller = GeneralDistiller(
        train_config=train_config, distill_config=distill_config,
        model_T=teacher_model, model_S=student_model,
        adaptor_T=simple_adaptor, adaptor_S=simple_adaptor,
        n_experts='none',
        load_balancing_loss_ceof=student_config['load_balancing_loss_ceof'])

# `start`/`end` are wall-clock timestamps for the log. The elapsed training time is
# measured with the monotonic perf_counter, so clock adjustments during a multi-hour
# run cannot skew it.
start = time.time()
_train_t0 = time.perf_counter()
print('Start Time: ',start)
# Start distilling
with distiller:
    distiller.train(optimizer, dataloader, num_steps=NUM_STEPS,
                    scheduler_class=SCHEDULER_CLASS, scheduler_args=scheduler_args,
                    callback=callback_fun)
end = time.time()

training_time = time.perf_counter() - _train_t0
print('DONE! Time taken: ',training_time)

Start Time:  1624285505.461808
Global Step:  1  of  28224
Global Step:  2  of  28224
Global Step:  3  of  28224
Global Step:  4  of  28224
Global Step:  5  of  28224
Global Step:  6  of  28224
Global Step:  7  of  28224
Global Step:  8  of  28224
Global Step:  9  of  28224
Global Step:  10  of  28224
Global Step:  11  of  28224
Global Step:  12  of  28224
Global Step:  13  of  28224
Global Step:  14  of  28224
Global Step:  15  of  28224
Global Step:  16  of  28224
Global Step:  17  of  28224
Global Step:  18  of  28224
Global Step:  19  of  28224
Global Step:  20  of  28224
Global Step:  21  of  28224
Global Step:  22  of  28224
Global Step:  23  of  28224
Global Step:  24  of  28224
Global Step:  25  of  28224
Global Step:  26  of  28224
Global Step:  27  of  28224
Global Step:  28  of  28224
Global Step:  29  of  28224
Global Step:  30  of  28224
Global Step:  31  of  28224
Global Step:  32  of  28224
Global Step:  33  of  28224
Global Step:  34  of  28224
Global Step:  35  of  2822

In [None]:
# Store the important hyperparameters of this run; they are saved to LOG_DIR/PARAMS.json.
# The scheduler factory (a function) and the optimizer (a class) are stored by NAME:
# the raw objects are not JSON-serialisable, and presumably save_json relies on JSON
# encoding.
def _readable_name(obj):
    """Return obj.__name__ for classes/functions, falling back to str(obj)."""
    return getattr(obj, '__name__', str(obj))

important_hyperparams = {
    'training_data': TRAIN_DATA,
    'batch_size': BATCH_SIZE,
    'student': STUDENT,
    'learning_rate': LEARNING_RATE,
    'learning_rate_scheduler': _readable_name(SCHEDULER_CLASS),
    'temperature': TEMPERATURE,
    'temp_scheduler': TEMPERATURE_SCHEDULER,
    'loss_weights: kd, int_match': [KD_LOSS_WEIGHT, INT_LOSS_WEIGHT],
    'kd_weight_scheduler': KD_WEIGHT_SCHEDULER,
    'int_matches': INTERMEDIATE_MATCHES,
    'optimizer': _readable_name(OPTIMIZER),
    'student_config': student_config,
    'training_time': training_time,
}

In [None]:
# display the recorded hyperparameters as this cell's output
important_hyperparams

In [None]:
# Persist the recorded hyperparameters as PARAMS.json inside this run's LOG_DIR.
path = f'{LOG_DIR}/PARAMS.json'
save_json(important_hyperparams, path)

In [None]:
# Save the student's state_dict (weights only) under LOG_DIR. The Unix-timestamp suffix
# gives each save a distinct name.
path = f'{LOG_DIR}/trained_model_{time.time()}'
write_torch(student_model.state_dict(), path)