In [None]:
# Move the working directory up one level so that the relative paths used below
# (e.g. "cache/") are resolved from the parent directory.
# NOTE: this is not idempotent - re-running the cell moves up one more level.
%cd ..

In [2]:
import os

# Exported before the remaining imports run. The bot's log later prints
# "SEED: 42", so this is presumably read by pytorch_helper_bot - confirm.
os.environ['SEED'] = "42"

import dataclasses
from pathlib import Path
import warnings

import nlp
import torch
import numpy as np
import torch.nn.functional as F
from transformers import (
    BertForSequenceClassification, 
    DistilBertConfig,
    DistilBertForSequenceClassification
)
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.model_selection import train_test_split

from pytorch_helper_bot import (
    BaseBot, MovingAverageStatsTrackerCallback,  CheckpointCallback,
    LearningRateSchedulerCallback, MultiStageScheduler, Top1Accuracy,
    LinearLR, Callback
)

# apex is an optional dependency: record whether it could be imported so the
# later cells can enable mixed precision only when it is installed.
try:
    from apex import amp
    APEX_AVAILABLE = True
except ModuleNotFoundError:
    APEX_AVAILABLE = False


In [3]:
# Relative directory used by the cells below for cached data, the saved
# teacher model, logs and checkpoints.
CACHE_DIR = Path("cache/")
# Create it on first use; an already existing directory is not an error.
CACHE_DIR.mkdir(parents=False, exist_ok=True)

In [4]:
class SST2Dataset(torch.utils.data.Dataset):
    """Map-style dataset over a dict of pre-computed entries with teacher logits.

    Parameters
    ----------
    entries_dict : mapping
        Must provide the keys "input_ids", "attention_mask", "label" and
        "logits". Each value is indexed by integer position; the length of
        the "label" value defines the dataset length.
    temperature : float, default 1
        Strictly positive divisor applied to the stored logits when an item
        is fetched (larger values soften the resulting soft targets).

    Raises
    ------
    ValueError
        If ``temperature`` is not strictly positive.
    """

    def __init__(self, entries_dict, temperature=1):
        super().__init__()
        # A zero or negative temperature would silently yield inf/NaN or
        # sign-flipped logits, so reject it up front (NaN fails this check too).
        if not temperature > 0:
            raise ValueError(
                "temperature must be strictly positive, got %r" % (temperature,)
            )
        self.entries_dict = entries_dict
        self.temperature = temperature

    def __len__(self):
        """Return the number of entries (length of the "label" column)."""
        return len(self.entries_dict["label"])

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, targets)`` for position ``idx``.

        ``targets`` is a dict with the hard "label" and the stored "logits"
        divided by the configured temperature.
        """
        return (
            self.entries_dict["input_ids"][idx],
            self.entries_dict["attention_mask"][idx],
            {
                "label": self.entries_dict["label"][idx],
                "logits": self.entries_dict["logits"][idx] / self.temperature
            }
        )

In [5]:
# Load the three cached split dictionaries (train / validation / test).
# NOTE: torch.load unpickles the file, which can execute arbitrary code -
# only load cache files that come from a trusted source.
distill_dicts_path = CACHE_DIR / "distill-dicts-augmented.jbl"
train_dict, valid_dict, test_dict = torch.load(str(distill_dicts_path))

In [6]:
# Wrap every split in a DataLoader with batches of 64. Only the training
# loader shuffles. Train and validation logits are divided by TEMPERATURE,
# while the test split keeps its logits unscaled (temperature=1.).
TEMPERATURE = 2.
train_loader = torch.utils.data.DataLoader(
    SST2Dataset(train_dict, temperature=TEMPERATURE),
    batch_size=64,
    shuffle=True,
)
valid_loader = torch.utils.data.DataLoader(
    SST2Dataset(valid_dict, temperature=TEMPERATURE),
    batch_size=64,
    shuffle=False,
    drop_last=False,
)
test_loader = torch.utils.data.DataLoader(
    SST2Dataset(test_dict, temperature=1.),
    batch_size=64,
    shuffle=False,
    drop_last=False,
)

In [7]:
# Weight of the hard-label term in distill_loss; the soft-target term gets
# (1 - ALPHA). With 0 the hard-label term is multiplied by zero.
ALPHA = 0
# NOTE(review): not referenced by any cell visible in this notebook
# (distill_loss calls cross_entropy instead) - confirm before removing.
DISTILL_OBJECTIVE = torch.nn.MSELoss()

def cross_entropy(logits, targets):
    """Cross entropy between predicted logits and soft targets given as logits.

    Both arguments are unnormalised scores whose last axis is the class axis.
    ``targets`` is turned into a probability distribution with softmax and
    compared against ``log_softmax(logits)``.

    Parameters
    ----------
    logits : torch.Tensor
        Predicted scores, shape ``(..., n_classes)``.
    targets : torch.Tensor
        Target scores with the same shape as ``logits``.

    Returns
    -------
    torch.Tensor
        Scalar tensor: the per-example cross entropy averaged over all
        leading axes.
    """
    soft_targets = F.softmax(targets, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)
    # Reduce over the same (last) axis the softmaxes were taken over. The
    # previous hard-coded dim=1 only agreed with that for 2-D inputs.
    return -(soft_targets * log_probs).sum(dim=-1).mean()

def distill_loss(logits, targets):
    """Weighted sum of the hard-label loss and the soft-target loss.

    Parameters
    ----------
    logits : torch.Tensor
        Scores predicted by the model being trained.
    targets : dict
        Must contain "label" (class indices for ``F.cross_entropy``) and
        "logits" (scores turned into soft targets by ``cross_entropy``).

    Returns
    -------
    torch.Tensor
        ``ALPHA * hard_loss + (1 - ALPHA) * soft_loss``.
    """
    hard_loss = F.cross_entropy(logits, targets["label"])
    soft_loss = cross_entropy(logits, targets["logits"])
    return ALPHA * hard_loss + (1 - ALPHA) * soft_loss

In [8]:
# Load the saved BERT classifier from cache/sst2_bert_uncased and keep it on
# the CPU; the cells below only read its embedding tables.
bert_model = BertForSequenceClassification.from_pretrained(str(CACHE_DIR / "sst2_bert_uncased")).cpu()

In [9]:
# Inspect the shape of the loaded model's word-embedding table
# (recorded output: 30522 x 768).
bert_model.bert.embeddings.word_embeddings.weight.shape

torch.Size([30522, 768])

In [10]:
# Architecture of the small model that will be trained. vocab_size and dim
# equal the 30522 x 768 word-embedding table inspected above; only 128
# positions are used.
config = DistilBertConfig(
    # sizes
    vocab_size=30522,
    max_position_embeddings=128,
    dim=768,
    hidden_dim=1536,
    n_layers=2,
    n_heads=6,
    # embeddings / activation / initialisation
    sinusoidal_pos_embds=False,
    activation='gelu',
    initializer_range=0.02,
    # dropout rates
    dropout=0.1,
    attention_dropout=0.1,
    qa_dropout=0.1,
    seq_classif_dropout=0.5,
)
# Built from the config only, i.e. without loading pretrained weights here.
distill_bert_model = DistilBertForSequenceClassification(config)

In [11]:
# Inspect the shape of the loaded model's position-embedding table
# (recorded output: 512 x 768).
bert_model.bert.embeddings.position_embeddings.weight.data.shape

torch.Size([512, 768])

In [12]:
# Initialise the small model's embedding tables from the loaded BERT model.
# Values are copied in place (instead of re-pointing `.data` at the source
# tensors) so that the two models do not share storage and a shape mismatch
# raises an error instead of silently changing the parameter's shape.
student_embeddings = distill_bert_model.distilbert.embeddings
teacher_embeddings = bert_model.bert.embeddings
# Number of position rows the small model actually has (128 with the config
# used in this notebook); only that many leading rows are taken.
n_positions = student_embeddings.position_embeddings.weight.shape[0]
with torch.no_grad():
    student_embeddings.word_embeddings.weight.copy_(
        teacher_embeddings.word_embeddings.weight
    )
    student_embeddings.position_embeddings.weight.copy_(
        teacher_embeddings.position_embeddings.weight[:n_positions]
    )

In [13]:
# Freeze the embedding layer
for param in distill_bert_model.distilbert.embeddings.parameters():
    param.requires_grad = False

In [14]:
# Move the model to the GPU before the optimizer is created.
distill_bert_model = distill_bert_model.cuda()

In [15]:
# Drop the notebook's reference to the loaded BERT model; it is not used
# again after the embedding transfer.
del bert_model

In [16]:
# Adam over everything returned by .parameters(), which includes the
# parameters whose requires_grad was switched off above.
optimizer = torch.optim.Adam(
    distill_bert_model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.99),
)

In [17]:
# Only when apex was importable: let amp wrap the model and optimizer for
# mixed precision. "O1" is reported by apex in the output below as inserting
# automatic casts around PyTorch functions and Tensor methods.
if APEX_AVAILABLE:
    distill_bert_model, optimizer = amp.initialize(
        distill_bert_model, optimizer, opt_level="O1"
    )

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [18]:
class DistillTop1Accuracy(Top1Accuracy):
    """Top1Accuracy for targets that arrive as a dict with a "label" entry."""

    def __call__(self, truth, pred):
        # The parent metric is given only the hard labels; any other entries
        # of the target dict are ignored here.
        return super().__call__(truth["label"], pred)

In [19]:
@dataclasses.dataclass
class SST2Bot(BaseBot):
    """BaseBot subclass used below to train the distilled classifier."""
    # NOTE(review): this attribute has no type annotation, so the dataclass
    # decorator does not register it as a field of SST2Bot. Whether it
    # actually overrides BaseBot's log_dir default depends on BaseBot -
    # confirm. The notebook passes log_dir explicitly when building the bot.
    log_dir = CACHE_DIR / "logs"
    
    def __post_init__(self):
        # Run the parent's post-init first, then set the loss format to six
        # decimals (the recorded log prints losses as e.g. "0.676484").
        super().__post_init__()
        self.loss_format = "%.6f"

    @staticmethod
    def extract_prediction(output):
        """Return the first element of the model output.

        NOTE(review): presumably the classification logits - confirm against
        the model's forward() return value.
        """
        return output[0]

In [20]:
# Train for five passes over the training loader.
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * 5

# Keep a single checkpoint, ranked by the "loss" metric.
checkpoints = CheckpointCallback(
    keep_n_checkpoints=1,
    checkpoint_dir=CACHE_DIR / "distill_model_cache/",
    monitor_metric="loss"
)

# Two learning-rate stages: the first 20% of the steps use LinearLR, the
# remaining 80% use cosine annealing.
first_stage_steps = int(total_steps*0.2)
second_stage_steps = int(np.ceil(total_steps*0.8))
lr_durations = [first_stage_steps, second_stage_steps]
# Each stage starts at the cumulative length of the stages before it.
break_points = [0] + list(np.cumsum(lr_durations))[:-1]

stats_tracker = MovingAverageStatsTrackerCallback(
    avg_window=steps_per_epoch // 8,
    log_interval=steps_per_epoch // 10
)
multi_stage_scheduler = MultiStageScheduler(
    [
        LinearLR(optimizer, 0.01, lr_durations[0]),
        CosineAnnealingLR(optimizer, lr_durations[1])
    ],
    start_at_epochs=break_points
)
callbacks = [
    stats_tracker,
    LearningRateSchedulerCallback(multi_stage_scheduler),
    checkpoints
]
    
# Assemble the training bot from the model, data loaders, loss, metric and
# callbacks defined in the cells above.
bot = SST2Bot(
    model=distill_bert_model,
    optimizer=optimizer,
    criterion=distill_loss,
    train_loader=train_loader,
    valid_loader=valid_loader,
    metrics=(DistillTop1Accuracy(),),
    callbacks=callbacks,
    clip_grad=10.,
    use_amp=APEX_AVAILABLE,
    log_dir=CACHE_DIR / "distill_logs",
    echo=True,
    pbar=False,
    use_tensorboard=False,
)

[INFO][07/01/2020 20:13:59] SEED: 42
[INFO][07/01/2020 20:13:59] # of parameters: 33,586,946
[INFO][07/01/2020 20:13:59] # of trainable parameters: 10,046,210


In [21]:
# Show the planned number of optimisation steps.
print(total_steps)

# Run training. checkpoint_interval is half of len(train_loader); the
# recorded log reports validation metrics at that spacing (step 1578, ...).
bot.train(
    total_steps=total_steps,
    checkpoint_interval=len(train_loader) // 2
)
# Reload the first entry of the checkpoint callback's best_performers list.
# NOTE(review): element [1] is presumably the checkpoint path - confirm.
# The callback monitors "loss", so the reloaded step is the one with the best
# validation loss, not necessarily the one with the best accuracy.
bot.load_model(checkpoints.best_performers[0][1])
# Must run after load_model: asks the callback to keep zero checkpoints.
checkpoints.remove_checkpoints(keep=0)

[INFO][07/01/2020 20:13:59] Optimizer Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.99)
    eps: 1e-08
    initial_lr: 0.0001
    lr: 0.0001
    weight_decay: 0
)
[INFO][07/01/2020 20:13:59] Batches per epoch: 3157


15785


[INFO][07/01/2020 20:14:07] Step   315 | loss 0.676484 | lr: 1.09e-05 | 0.024s per step
[INFO][07/01/2020 20:14:15] Step   630 | loss 0.514955 | lr: 2.08e-05 | 0.024s per step
[INFO][07/01/2020 20:14:22] Step   945 | loss 0.436636 | lr: 3.06e-05 | 0.023s per step
[INFO][07/01/2020 20:14:30] Step  1260 | loss 0.425980 | lr: 4.05e-05 | 0.024s per step
[INFO][07/01/2020 20:14:37] Step  1575 | loss 0.414432 | lr: 5.04e-05 | 0.024s per step
[INFO][07/01/2020 20:14:37] Metrics at step 1578:
[INFO][07/01/2020 20:14:37] loss: 0.466530
[INFO][07/01/2020 20:14:37] accuracy: 83.03%
[INFO][07/01/2020 20:14:45] Step  1890 | loss 0.404962 | lr: 6.03e-05 | 0.026s per step
[INFO][07/01/2020 20:14:53] Step  2205 | loss 0.394814 | lr: 7.02e-05 | 0.025s per step
[INFO][07/01/2020 20:15:01] Step  2520 | loss 0.390542 | lr: 8.00e-05 | 0.024s per step
[INFO][07/01/2020 20:15:08] Step  2835 | loss 0.378836 | lr: 8.99e-05 | 0.024s per step
[INFO][07/01/2020 20:15:16] Step  3150 | loss 0.371571 | lr: 9.98e-05 

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 131072.0


[INFO][07/01/2020 20:15:54] Step  4725 | loss 0.345210 | lr: 9.63e-05 | 0.023s per step
[INFO][07/01/2020 20:15:54] Metrics at step 4734:
[INFO][07/01/2020 20:15:54] loss: 0.442633
[INFO][07/01/2020 20:15:54] accuracy: 84.17%
[INFO][07/01/2020 20:16:01] Step  5040 | loss 0.341085 | lr: 9.47e-05 | 0.024s per step
[INFO][07/01/2020 20:16:09] Step  5355 | loss 0.338211 | lr: 9.28e-05 | 0.024s per step
[INFO][07/01/2020 20:16:16] Step  5670 | loss 0.335036 | lr: 9.06e-05 | 0.023s per step
[INFO][07/01/2020 20:16:24] Step  5985 | loss 0.332460 | lr: 8.82e-05 | 0.024s per step
[INFO][07/01/2020 20:16:31] Step  6300 | loss 0.329597 | lr: 8.55e-05 | 0.024s per step
[INFO][07/01/2020 20:16:32] Metrics at step 6312:
[INFO][07/01/2020 20:16:32] loss: 0.457789
[INFO][07/01/2020 20:16:32] accuracy: 83.49%
[INFO][07/01/2020 20:16:39] Step  6615 | loss 0.317121 | lr: 8.27e-05 | 0.024s per step
[INFO][07/01/2020 20:16:46] Step  6930 | loss 0.316913 | lr: 7.96e-05 | 0.024s per step
[INFO][07/01/2020 20

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 131072.0


[INFO][07/01/2020 20:17:02] Step  7560 | loss 0.313495 | lr: 7.29e-05 | 0.025s per step
[INFO][07/01/2020 20:17:10] Step  7875 | loss 0.312311 | lr: 6.94e-05 | 0.025s per step
[INFO][07/01/2020 20:17:10] Metrics at step 7890:
[INFO][07/01/2020 20:17:10] loss: 0.436516
[INFO][07/01/2020 20:17:10] accuracy: 82.57%
[INFO][07/01/2020 20:17:18] Step  8190 | loss 0.313470 | lr: 6.57e-05 | 0.025s per step
[INFO][07/01/2020 20:17:25] Step  8505 | loss 0.308677 | lr: 6.20e-05 | 0.024s per step
[INFO][07/01/2020 20:17:33] Step  8820 | loss 0.307265 | lr: 5.81e-05 | 0.024s per step
[INFO][07/01/2020 20:17:40] Step  9135 | loss 0.303628 | lr: 5.42e-05 | 0.024s per step
[INFO][07/01/2020 20:17:48] Step  9450 | loss 0.304623 | lr: 5.03e-05 | 0.024s per step
[INFO][07/01/2020 20:17:48] Metrics at step 9468:
[INFO][07/01/2020 20:17:48] loss: 0.429833
[INFO][07/01/2020 20:17:48] accuracy: 84.86%
[INFO][07/01/2020 20:17:56] Step  9765 | loss 0.300152 | lr: 4.64e-05 | 0.025s per step
[INFO][07/01/2020 20

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 524288.0


[INFO][07/01/2020 20:19:28] Step 13545 | loss 0.289868 | lr: 7.58e-06 | 0.024s per step
[INFO][07/01/2020 20:19:35] Step 13860 | loss 0.290081 | lr: 5.63e-06 | 0.024s per step


Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 262144.0


[INFO][07/01/2020 20:19:43] Step 14175 | loss 0.289177 | lr: 3.96e-06 | 0.025s per step
[INFO][07/01/2020 20:19:44] Metrics at step 14202:
[INFO][07/01/2020 20:19:44] loss: 0.441592
[INFO][07/01/2020 20:19:44] accuracy: 85.09%
[INFO][07/01/2020 20:19:51] Step 14490 | loss 0.288174 | lr: 2.58e-06 | 0.025s per step
[INFO][07/01/2020 20:19:59] Step 14805 | loss 0.286645 | lr: 1.48e-06 | 0.024s per step
[INFO][07/01/2020 20:20:06] Step 15120 | loss 0.288812 | lr: 6.85e-07 | 0.025s per step
[INFO][07/01/2020 20:20:14] Step 15435 | loss 0.290974 | lr: 1.91e-07 | 0.025s per step
[INFO][07/01/2020 20:20:22] Step 15750 | loss 0.289079 | lr: 2.01e-09 | 0.025s per step
[INFO][07/01/2020 20:20:23] Metrics at step 15780:
[INFO][07/01/2020 20:20:23] loss: 0.441398
[INFO][07/01/2020 20:20:23] accuracy: 85.09%
[INFO][07/01/2020 20:20:23] Training finished. Best step(s):
[INFO][07/01/2020 20:20:23] loss: 0.429833 @ step 9468
[INFO][07/01/2020 20:20:23] accuracy: 85.55% @ step 11046


In [22]:
# Evaluate the reloaded checkpoint on the validation loader. In the recorded
# output the raw accuracy value is negative (-0.8486...) next to its
# formatted string "84.86%".
bot.eval(valid_loader)

{'loss': (0.4298326703933401, '0.429833'),
 'accuracy': (-0.8486238532110092, '84.86%')}

In [23]:
# Evaluate the reloaded checkpoint on the test loader. test_loader was built
# with temperature=1. while valid_loader uses TEMPERATURE, so the soft targets
# behind this loss are scaled differently from the validation loss above.
bot.eval(test_loader)

{'loss': (0.3775219143530644, '0.377522'),
 'accuracy': (-0.8325688073394495, '83.26%')}