In [None]:
# Step up to the parent directory so the relative paths used below (e.g. cache/) resolve from there.
# Re-running this cell moves up one more level each time, so execute it once per kernel.
%cd ..

In [2]:
import os

# Seed exposed through the environment; the bot created later logs "SEED: 42".
# NOTE(review): presumably pytorch_helper_bot reads this variable — confirm that it is read after this assignment.
os.environ['SEED'] = "42"

import dataclasses
from pathlib import Path
import warnings

import nlp
import torch
import numpy as np
import torch.nn.functional as F
from transformers import BertForSequenceClassification
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.model_selection import train_test_split

from pytorch_helper_bot import (
    BaseBot, MovingAverageStatsTrackerCallback,  CheckpointCallback,
    LearningRateSchedulerCallback, MultiStageScheduler, Top1Accuracy,
    LinearLR, Callback
)

from nobita.models import get_sequence_model


# apex is optional: record whether apex.amp (mixed precision) can be imported.
try:
    from apex import amp
    APEX_AVAILABLE = True
except ModuleNotFoundError:
    # apex is not installed: later cells skip amp.initialize and pass use_amp=False.
    APEX_AVAILABLE = False


In [3]:
# Directory, relative to the current working directory, used for cached inputs,
# checkpoints and logs by the cells below.
CACHE_DIR = Path("cache/")
# exist_ok=True makes this cell safe to re-run; missing parent directories are not created.
CACHE_DIR.mkdir(exist_ok=True)

Reference:

    * https://github.com/huggingface/nlp/blob/master/notebooks/Overview.ipynb

In [4]:
class SST2Dataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-tokenized entries that carry teacher logits.

    ``entries_dict`` is read through the keys ``"input_ids"``,
    ``"attention_mask"``, ``"label"`` and ``"logits"``; every value is indexed
    per example. Teacher logits are divided by ``temperature`` on access.
    """

    def __init__(self, entries_dict, temperature=1):
        super().__init__()
        self.temperature = temperature
        self.entries_dict = entries_dict

    def __len__(self):
        # The number of labels defines the dataset size.
        return len(self.entries_dict["label"])

    def __getitem__(self, idx):
        """Return ``(input_ids, input_length, targets)`` for example ``idx``."""
        entries = self.entries_dict
        input_ids = entries["input_ids"][idx]
        # Summing the attention mask gives the input length.
        input_length = entries["attention_mask"][idx].sum()
        targets = {
            "label": entries["label"][idx],
            "logits": entries["logits"][idx] / self.temperature,
        }
        return input_ids, input_length, targets

In [5]:
# Load the cached (train, valid, test) dictionaries that SST2Dataset wraps below.
# NOTE(review): torch.load unpickles the file, so only point it at a trusted cache.
train_dict, valid_dict, test_dict = torch.load(str(CACHE_DIR / "distill-dicts-augmented.jbl"))

In [6]:
# Wrap each split in a DataLoader; only the training loader is shuffled.
# Train/valid teacher logits are divided by TEMPERATURE, while the test split
# keeps the teacher logits as stored (temperature=1.).
TEMPERATURE = 2.
train_loader = torch.utils.data.DataLoader(
    SST2Dataset(train_dict, temperature=TEMPERATURE),
    batch_size=64,
    shuffle=True,
)
valid_loader = torch.utils.data.DataLoader(
    SST2Dataset(valid_dict, temperature=TEMPERATURE),
    batch_size=64,
    drop_last=False,
)
test_loader = torch.utils.data.DataLoader(
    SST2Dataset(test_dict, temperature=1.),
    batch_size=64,
    drop_last=False,
)

In [7]:
# Weight of the hard-label term in distill_loss; the teacher-logit term gets (1 - ALPHA).
# With ALPHA = 0 the hard-label term is multiplied by zero.
ALPHA = 0
# NOTE(review): DISTILL_OBJECTIVE is not referenced by any other cell shown here;
# distill_loss uses cross_entropy for the distillation term.
DISTILL_OBJECTIVE = torch.nn.MSELoss()

def cross_entropy(logits, targets):
    """Cross-entropy between predicted logits and soft targets given as logits.

    ``targets`` is converted to a probability distribution with a softmax over
    the last axis and compared with ``log_softmax(logits)`` over the same axis.

    Args:
        logits: predicted logits, classes on the last axis.
        targets: target logits with the same shape as ``logits``.

    Returns:
        Scalar tensor: the per-example cross-entropy averaged over every
        leading (non-class) axis.
    """
    target_probs = F.softmax(targets, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)
    # Reduce over the class axis, i.e. the same axis both softmaxes use, so the
    # result is also correct for inputs that are not exactly 2-D.
    return -(target_probs * log_probs).sum(dim=-1).mean()

def distill_loss(logits, targets):
    """Blend the hard-label loss with the teacher (soft-target) loss.

    Args:
        logits: model outputs for the batch.
        targets: dict holding ``"label"`` (passed to ``F.cross_entropy``) and
            ``"logits"`` (teacher logits passed to ``cross_entropy``).

    Returns:
        ``ALPHA * hard_label_loss + (1 - ALPHA) * teacher_loss``, with the
        module-level ``ALPHA`` read at call time.
    """
    soft_loss = cross_entropy(logits, targets["logits"])
    hard_loss = F.cross_entropy(logits, targets["label"])
    hard_weight, soft_weight = ALPHA, 1 - ALPHA
    return hard_weight * hard_loss + soft_weight * soft_loss

In [8]:
# Load the BERT sequence classifier saved under CACHE_DIR / "sst2_bert_uncased" and keep it on the CPU.
# The following cells only read its word-embedding matrix before the model is deleted.
bert_model = BertForSequenceClassification.from_pretrained(str(CACHE_DIR / "sst2_bert_uncased")).cpu()

In [9]:
# Shape of BERT's word-embedding matrix; index 0 is used as voc_size and index 1 as emb_size below.
bert_model.bert.embeddings.word_embeddings.weight.shape

torch.Size([30522, 768])

In [10]:
# Build the student network. Vocabulary size and embedding width are taken
# from BERT's word-embedding matrix so that its weights can be copied in later.
# Note: apex does not support weight dropping
model = get_sequence_model(
    voc_size=bert_model.bert.embeddings.word_embeddings.weight.shape[0],
    emb_size=bert_model.bert.embeddings.word_embeddings.weight.shape[1],
    pad_idx=0,
    dropoute=0,
    rnn_hid=768,
    rnn_layers=2,
    bidir=True,
    dropouth=0.25,
    dropouti=0.25,
    wdrop=0,
    unit_type="gru",
    fcn_layers=[512, 2],
    fcn_dropouts=[0.25, 0.25],
    use_attention=True,
)

In [11]:
# Display the module tree to check the layer sizes that were built.
model

SequenceModel(
  (embeddings): BasicEmbeddings(
    (encoder): Embedding(30522, 768, padding_idx=0)
  )
  (encoder): RNNStack(
    (rnns): ModuleList(
      (0): GRU(768, 384, bidirectional=True)
      (1): GRU(768, 384, bidirectional=True)
    )
    (dropouti): LockedDropout()
    (dropouths): ModuleList(
      (0): LockedDropout()
      (1): LockedDropout()
    )
  )
  (fcn): AttentionFCN(
    (attention): Attention(768, return attention=False)
    (layers): ModuleList(
      (0): LinearBlock(
        (lin): Linear(in_features=768, out_features=512, bias=True)
        (drop): Dropout(p=0.25, inplace=False)
        (bn): BatchNorm1d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): LinearBlock(
        (lin): Linear(in_features=512, out_features=2, bias=True)
        (drop): Dropout(p=0.25, inplace=False)
        (bn): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
)

In [12]:
# Copy BERT's word-embedding weights into the student model's embedding layer.
# The tensor is assigned as-is (shared, not cloned). Two attribute layouts are
# tried: an encoder that wraps the embedding in `.emb`, and an encoder that owns
# the weight itself.
try:
    model.embeddings.encoder.emb.weight.data = bert_model.bert.embeddings.word_embeddings.weight.data
except AttributeError:
    # Only a missing `.emb` attribute triggers the fallback; any other error
    # (or a keyboard interrupt) is no longer swallowed.
    model.embeddings.encoder.weight.data = bert_model.bert.embeddings.word_embeddings.weight.data

In [13]:
# Keep the embedding table fixed during training: switch off gradient
# tracking for every parameter of the embedding module.
for embedding_param in model.embeddings.encoder.parameters():
    embedding_param.requires_grad_(False)

In [14]:
# Move the student model to the GPU; this cell needs a CUDA device.
model = model.cuda()

In [15]:
# Collect the tensors handed to the optimizer: leaf parameters that still
# require gradients, so parameters with requires_grad=False are left out.
parameters = list(
    filter(lambda p: p.is_leaf and p.requires_grad, model.parameters())
)

In [16]:
# The BERT model is not used after this point; drop the notebook-level reference to it.
del bert_model

In [17]:
# Adam over the trainable parameters only; the betas differ from PyTorch's defaults of (0.9, 0.999).
optimizer = torch.optim.Adam(parameters, lr=1e-3, betas=(0.8, 0.99))
# Alternative optimizer kept for reference (disabled):
# optimizer = torch.optim.RMSprop(parameters, lr=0.01)

In [18]:
# When apex could be imported, wrap the model and optimizer for mixed
# precision at optimization level "O1"; otherwise both are left untouched.
if APEX_AVAILABLE:
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [19]:
class TransposeCallback(Callback):
    """Batch-input hook that swaps the first two axes of the first input tensor."""

    def on_batch_inputs(self, bot, input_tensors, targets):
        """Return the batch with ``input_tensors[0]`` transposed.

        Only the first two entries of ``input_tensors`` are kept; the second
        entry and ``targets`` are passed through unchanged.
        NOTE(review): presumably this turns (batch, seq) token ids into
        (seq, batch) for the RNN — confirm against the model's forward.
        """
        first, second = input_tensors[0], input_tensors[1]
        return [first.transpose(1, 0), second], targets

In [20]:
class DistillTop1Accuracy(Top1Accuracy):
    """Top-1 accuracy for dict targets: scores predictions against ``truth["label"]``."""

    def __call__(self, truth, pred):
        # The targets are dicts; the parent metric only receives the "label" entry.
        return super().__call__(truth["label"], pred)

In [21]:
# Total optimisation steps: five passes over the training loader.
total_steps = len(train_loader) * 5

# Checkpoint callback: keeps one checkpoint under CACHE_DIR / "distill_model_cache/",
# monitored on the "loss" metric.
checkpoints = CheckpointCallback(
    keep_n_checkpoints=1,
    checkpoint_dir=CACHE_DIR / "distill_model_cache/",
    monitor_metric="loss"
)
# Two learning-rate stages: the first 20% of the steps, then the remaining 80% (rounded up).
lr_durations = [
    int(total_steps*0.2),
    int(np.ceil(total_steps*0.8))
]
# Step at which each stage starts, i.e. [0, lr_durations[0]].
break_points = [0] + list(np.cumsum(lr_durations))[:-1]
callbacks = [
    # Training-loss tracking: window and log interval are fractions of one pass over train_loader.
    MovingAverageStatsTrackerCallback(
        avg_window=len(train_loader) // 8,
        log_interval=len(train_loader) // 10
    ),
    LearningRateSchedulerCallback(
        MultiStageScheduler(
            [
                # Stage 1: LinearLR over the first stage; the logged lr rises towards the base lr here.
                # NOTE(review): 0.01 is presumably the starting fraction of the base lr — confirm.
                LinearLR(optimizer, 0.01, lr_durations[0]),
                # Stage 2: cosine annealing with T_max = lr_durations[1].
                CosineAnnealingLR(optimizer, lr_durations[1])
            ],
            start_at_epochs=break_points
        )
    ),
    checkpoints,
    # Transposes the first input tensor of every batch (see TransposeCallback).
    TransposeCallback()
]

# The bot ties everything together: distill_loss is the criterion, accuracy is
# measured on the hard labels, and amp is only enabled when apex was importable.
bot = BaseBot(
    log_dir = CACHE_DIR / "distill_logs",
    model=model, 
    train_loader=train_loader,
    valid_loader=valid_loader, 
    clip_grad=10.,
    optimizer=optimizer, echo=True,
    criterion=distill_loss,
    callbacks=callbacks,
    pbar=False, use_tensorboard=False,
    use_amp=APEX_AVAILABLE,
    metrics=(DistillTop1Accuracy(),)
)

[INFO][06/22/2020 17:38:45] SEED: 42
[INFO][06/22/2020 17:38:45] # of parameters: 29,156,610
[INFO][06/22/2020 17:38:45] # of trainable parameters: 5,715,714


In [22]:
print(total_steps)

# Train for total_steps; the checkpoint interval is half a pass over train_loader
# (the log reports validation metrics at that interval).
bot.train(
    total_steps=total_steps,
    checkpoint_interval=len(train_loader) // 2
)
# Reload the first entry of checkpoints.best_performers into the bot.
# NOTE(review): element [1] is presumably the checkpoint path — confirm.
bot.load_model(checkpoints.best_performers[0][1])
# keep=0: presumably removes every checkpoint file now that the weights are loaded — confirm.
checkpoints.remove_checkpoints(keep=0)

[INFO][06/22/2020 17:38:45] Optimizer Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.8, 0.99)
    eps: 1e-08
    initial_lr: 0.001
    lr: 0.001
    weight_decay: 0
)
[INFO][06/22/2020 17:38:45] Batches per epoch: 3157


15785


[INFO][06/22/2020 17:38:52] Step   315 | loss 0.90497637 | lr: 1.09e-04 | 0.021s per step
[INFO][06/22/2020 17:38:58] Step   630 | loss 0.70167393 | lr: 2.08e-04 | 0.019s per step
[INFO][06/22/2020 17:39:04] Step   945 | loss 0.60207349 | lr: 3.06e-04 | 0.019s per step
[INFO][06/22/2020 17:39:10] Step  1260 | loss 0.54594113 | lr: 4.05e-04 | 0.019s per step
[INFO][06/22/2020 17:39:16] Step  1575 | loss 0.50735221 | lr: 5.04e-04 | 0.019s per step
[INFO][06/22/2020 17:39:16] Metrics at step 1578:
[INFO][06/22/2020 17:39:16] loss: 0.47315171
[INFO][06/22/2020 17:39:16] accuracy: 80.96%
[INFO][06/22/2020 17:39:22] Step  1890 | loss 0.47063435 | lr: 6.03e-04 | 0.020s per step
[INFO][06/22/2020 17:39:28] Step  2205 | loss 0.44136320 | lr: 7.02e-04 | 0.019s per step
[INFO][06/22/2020 17:39:34] Step  2520 | loss 0.43186608 | lr: 8.00e-04 | 0.019s per step
[INFO][06/22/2020 17:39:40] Step  2835 | loss 0.42349426 | lr: 8.99e-04 | 0.019s per step
[INFO][06/22/2020 17:39:46] Step  3150 | loss 0.41

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 262144.0


[INFO][06/22/2020 17:40:58] Step  6930 | loss 0.34986204 | lr: 7.96e-04 | 0.019s per step
[INFO][06/22/2020 17:41:04] Step  7245 | loss 0.34427714 | lr: 7.64e-04 | 0.019s per step
[INFO][06/22/2020 17:41:10] Step  7560 | loss 0.34157666 | lr: 7.29e-04 | 0.019s per step
[INFO][06/22/2020 17:41:16] Step  7875 | loss 0.34239671 | lr: 6.94e-04 | 0.019s per step
[INFO][06/22/2020 17:41:17] Metrics at step 7890:
[INFO][06/22/2020 17:41:17] loss: 0.37721515
[INFO][06/22/2020 17:41:17] accuracy: 88.30%
[INFO][06/22/2020 17:41:23] Step  8190 | loss 0.34349061 | lr: 6.57e-04 | 0.020s per step
[INFO][06/22/2020 17:41:29] Step  8505 | loss 0.33809089 | lr: 6.20e-04 | 0.019s per step
[INFO][06/22/2020 17:41:35] Step  8820 | loss 0.33803396 | lr: 5.81e-04 | 0.019s per step
[INFO][06/22/2020 17:41:41] Step  9135 | loss 0.33576415 | lr: 5.42e-04 | 0.019s per step


Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 262144.0


[INFO][06/22/2020 17:41:46] Step  9450 | loss 0.32987089 | lr: 5.03e-04 | 0.019s per step
[INFO][06/22/2020 17:41:47] Metrics at step 9468:
[INFO][06/22/2020 17:41:47] loss: 0.36626294
[INFO][06/22/2020 17:41:47] accuracy: 88.76%
[INFO][06/22/2020 17:41:53] Step  9765 | loss 0.32613052 | lr: 4.64e-04 | 0.020s per step
[INFO][06/22/2020 17:41:59] Step 10080 | loss 0.32443664 | lr: 4.25e-04 | 0.019s per step
[INFO][06/22/2020 17:42:05] Step 10395 | loss 0.32083841 | lr: 3.86e-04 | 0.019s per step
[INFO][06/22/2020 17:42:11] Step 10710 | loss 0.32213914 | lr: 3.49e-04 | 0.019s per step
[INFO][06/22/2020 17:42:17] Step 11025 | loss 0.32125724 | lr: 3.12e-04 | 0.019s per step
[INFO][06/22/2020 17:42:17] Metrics at step 11046:
[INFO][06/22/2020 17:42:17] loss: 0.36650688
[INFO][06/22/2020 17:42:17] accuracy: 87.84%
[INFO][06/22/2020 17:42:23] Step 11340 | loss 0.32072177 | lr: 2.76e-04 | 0.019s per step
[INFO][06/22/2020 17:42:29] Step 11655 | loss 0.31864733 | lr: 2.42e-04 | 0.019s per step

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 524288.0


[INFO][06/22/2020 17:43:05] Step 13545 | loss 0.30846991 | lr: 7.58e-05 | 0.019s per step
[INFO][06/22/2020 17:43:11] Step 13860 | loss 0.30989205 | lr: 5.63e-05 | 0.019s per step
[INFO][06/22/2020 17:43:17] Step 14175 | loss 0.30981290 | lr: 3.96e-05 | 0.019s per step
[INFO][06/22/2020 17:43:17] Metrics at step 14202:
[INFO][06/22/2020 17:43:17] loss: 0.35285248
[INFO][06/22/2020 17:43:17] accuracy: 88.30%
[INFO][06/22/2020 17:43:23] Step 14490 | loss 0.31023143 | lr: 2.58e-05 | 0.019s per step
[INFO][06/22/2020 17:43:29] Step 14805 | loss 0.30981392 | lr: 1.48e-05 | 0.019s per step
[INFO][06/22/2020 17:43:35] Step 15120 | loss 0.31230602 | lr: 6.85e-06 | 0.019s per step


Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 524288.0


[INFO][06/22/2020 17:43:41] Step 15435 | loss 0.31075922 | lr: 1.91e-06 | 0.019s per step
[INFO][06/22/2020 17:43:47] Step 15750 | loss 0.30959355 | lr: 2.01e-08 | 0.019s per step
[INFO][06/22/2020 17:43:47] Metrics at step 15780:
[INFO][06/22/2020 17:43:47] loss: 0.35488558
[INFO][06/22/2020 17:43:47] accuracy: 88.76%
[INFO][06/22/2020 17:43:47] Training finished. Best step(s):
[INFO][06/22/2020 17:43:47] loss: 0.35020860 @ step 12624
[INFO][06/22/2020 17:43:47] accuracy: 88.76% @ step 9468


In [23]:
# Evaluate the reloaded model on the validation loader.
# NOTE(review): the raw accuracy in the output is negative — presumably sign-flipped so that lower is better; confirm.
bot.eval(valid_loader)

{'loss': (0.35020703350732085, '0.35020703'),
 'accuracy': (-0.8830275229357798, '88.30%')}

In [24]:
# Evaluate on the test loader. Its teacher logits use temperature=1. while the
# validation loader uses TEMPERATURE, so the two loss values are not directly comparable.
bot.eval(test_loader)

{'loss': (0.28759598021113547, '0.28759598'),
 'accuracy': (-0.8692660550458715, '86.93%')}