In [1]:
# Move the kernel's working directory one level up so that the relative paths used
# later (e.g. "cache/") resolve against the parent folder.
# NOTE(review): not idempotent — re-running this cell climbs one more directory.
%cd ..

/mnt/SSD_Data/active_projects/transformer_to_lstm


In [2]:
import os

os.environ['SEED'] = "42"

import dataclasses
from pathlib import Path
import warnings

import nlp
import torch
import numpy as np
import torch.nn.functional as F
from transformers import (
    BertForSequenceClassification,
    DistilBertForSequenceClassification
)
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.model_selection import train_test_split

from pytorch_helper_bot import (
    BaseBot, MovingAverageStatsTrackerCallback,  CheckpointCallback,
    LearningRateSchedulerCallback, MultiStageScheduler, Top1Accuracy,
    LinearLR, Callback
)

try:
    from apex import amp
    APEX_AVAILABLE = True
except ModuleNotFoundError:
    APEX_AVAILABLE = False


In [3]:
# Directory for cached artefacts (datasets, fine-tuned weights, logs, checkpoints).
# The path is relative, so it depends on the working directory set by the first cell.
CACHE_DIR = Path("cache/")
# Create the folder if it is missing; an existing folder is left untouched.
CACHE_DIR.mkdir(exist_ok=True)

In [4]:
class SST2Dataset(torch.utils.data.Dataset):
    """Map-style dataset backed by a dict of index-aligned columns.

    ``entries_dict`` must provide the keys ``"input_ids"``, ``"attention_mask"``,
    ``"label"`` and ``"logits"``; every column is indexed with the same ``idx``.
    The stored logits are divided by ``temperature`` each time an item is read.
    """

    def __init__(self, entries_dict, temperature=1):
        """Keep a reference to the column dict and the logit divisor."""
        super().__init__()
        self.entries_dict = entries_dict
        self.temperature = temperature

    def __len__(self):
        """Number of examples, taken from the length of the ``"label"`` column."""
        labels = self.entries_dict["label"]
        return len(labels)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, targets)`` for example ``idx``.

        ``targets`` is a dict with the raw ``"label"`` and the
        temperature-scaled ``"logits"``.
        """
        columns = self.entries_dict
        input_ids = columns["input_ids"][idx]
        attention_mask = columns["attention_mask"][idx]
        targets = {
            "label": columns["label"][idx],
            "logits": columns["logits"][idx] / self.temperature,
        }
        return input_ids, attention_mask, targets

In [5]:
# Load the cached train/valid/test column dicts; the file must unpack into exactly
# three objects.
# NOTE(review): torch.load unpickles arbitrary objects — only load cache files you
# produced yourself. The ".jbl" suffix hints at joblib, but the file is read with
# torch.load here; confirm it was written with torch.save.
train_dict, valid_dict, test_dict = torch.load(str(CACHE_DIR / "distill-dicts-augmented.jbl"))

In [6]:
# Wrap each split in a DataLoader with batch size 64; only the training loader shuffles.
# Train and validation logits are divided by TEMPERATURE inside SST2Dataset, whereas
# the test split uses temperature=1., so a loss computed on the test loader is not on
# the same scale as the train/validation loss.
TEMPERATURE = 2.
train_loader = torch.utils.data.DataLoader(SST2Dataset(train_dict, temperature=TEMPERATURE), batch_size=64, shuffle=True)
valid_loader = torch.utils.data.DataLoader(SST2Dataset(valid_dict, temperature=TEMPERATURE), batch_size=64, drop_last=False)
test_loader = torch.utils.data.DataLoader(SST2Dataset(test_dict, temperature=1.), batch_size=64, drop_last=False)

In [7]:
# Weight of the hard-label term in distill_loss; the soft-target term gets (1 - ALPHA).
# With 0, only the soft-target term contributes to the loss value.
ALPHA = 0
# NOTE(review): not referenced by any visible cell — distill_loss calls cross_entropy()
# rather than this MSE objective.
DISTILL_OBJECTIVE = torch.nn.MSELoss()

def cross_entropy(logits, targets):
    """Soft-target cross-entropy between two sets of logits.

    ``targets`` are converted to a probability distribution with a softmax over
    the last dimension, and ``logits`` to log-probabilities over the same
    dimension. The per-example cross-entropy is the negative sum of their
    product over that class dimension; the result is the mean over all
    remaining dimensions.

    Args:
        logits: Tensor of predicted logits, classes on the last dimension.
        targets: Tensor of target logits with the same shape as ``logits``.

    Returns:
        A scalar tensor holding the mean cross-entropy.
    """
    soft_targets = F.softmax(targets, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)
    # Reduce over the class axis, i.e. the same axis the softmaxes were taken on.
    # (The reduction used to be hard-coded to dim=1, which only matches for 2-D input.)
    per_example = -(soft_targets * log_probs).sum(dim=-1)
    return per_example.mean()

def distill_loss(logits, targets):
    """Weighted sum of a soft-target loss and a hard-label loss.

    Args:
        logits: Model output logits for the batch.
        targets: Dict with ``"logits"`` (reference logits used as soft targets —
            presumably the teacher's, confirm against the cached data) and
            ``"label"`` (class indices for the hard-label term).

    Returns:
        ``ALPHA * hard_term + (1 - ALPHA) * soft_term`` using the module-level
        ``ALPHA``.
    """
    # An earlier variant used F.binary_cross_entropy_with_logits on column 1 of
    # both logit tensors; the softmax-based cross_entropy replaced it.
    soft_term = cross_entropy(logits, targets["logits"])
    hard_term = F.cross_entropy(logits, targets["label"])
    return ALPHA * hard_term + (1 - ALPHA) * soft_term

In [8]:
# Load the BERT sequence classifier stored under cache/sst2_bert_uncased and keep it on CPU.
# In the visible cells it is only used to inspect embedding shapes before being deleted.
bert_model = BertForSequenceClassification.from_pretrained(str(CACHE_DIR / "sst2_bert_uncased")).cpu()

In [9]:
# Display the shape of BERT's word-embedding matrix.
bert_model.bert.embeddings.word_embeddings.weight.shape

torch.Size([30522, 768])

In [10]:
# Instantiate the model that will be trained with distill_loss, starting from the
# pretrained "distilbert-base-uncased" checkpoint.
distill_bert_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")

In [11]:
# Display the shape of BERT's position-embedding matrix.
bert_model.bert.embeddings.position_embeddings.weight.data.shape

torch.Size([512, 768])

In [12]:
# distill_bert_model.distilbert.embeddings.word_embeddings.weight.data = bert_model.bert.embeddings.word_embeddings.weight.data
# distill_bert_model.distilbert.embeddings.position_embeddings.weight.data = bert_model.bert.embeddings.position_embeddings.weight.data[:128]

In [13]:
# Freeze the embedding layer
# for param in distill_bert_model.distilbert.embeddings.parameters():
#     param.requires_grad = False

In [14]:
# Move the DistilBERT model to the GPU; this raises on a machine without CUDA.
distill_bert_model =distill_bert_model.cuda()

In [15]:
# Drop the reference to the BERT model; it is not used by any later cell.
del bert_model

In [16]:
# Adam over all DistilBERT parameters with a base learning rate of 2e-5.
# beta2 is set to 0.99 here instead of torch's default of 0.999.
optimizer = torch.optim.Adam(distill_bert_model.parameters(), lr=2e-5, betas=(0.9, 0.99))

In [17]:
# When apex was importable, let amp wrap the model and optimizer at opt_level "O1";
# both names are rebound to the objects amp returns. Without apex this cell is a no-op.
if APEX_AVAILABLE:
    distill_bert_model, optimizer = amp.initialize(
        distill_bert_model, optimizer, opt_level="O1"
    )

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [18]:
class DistillTop1Accuracy(Top1Accuracy):
    """Top1Accuracy for targets that arrive as a dict containing a ``"label"`` entry."""

    def __call__(self, truth, pred):
        """Pick the ``"label"`` entry out of ``truth`` and delegate to Top1Accuracy."""
        return super().__call__(truth["label"], pred)

In [19]:
@dataclasses.dataclass
class SST2Bot(BaseBot):
    """BaseBot specialisation used for the SST-2 runs in this notebook.

    It changes the default log directory, prints losses with six decimals, and
    treats the first element of the model output as the prediction.
    """

    # The annotation is required: dataclasses only treats annotated class
    # attributes as fields. Without it, an inherited ``log_dir`` field would keep
    # the parent's default in the generated __init__ and this value would be
    # ignored whenever ``log_dir`` is not passed explicitly.
    log_dir: Path = CACHE_DIR / "logs"

    def __post_init__(self):
        """Run the parent's post-init, then switch the loss format to six decimals."""
        super().__post_init__()
        self.loss_format = "%.6f"

    @staticmethod
    def extract_prediction(output):
        """Return the first element of the model output (the logits for the
        transformers classification models used in this notebook)."""
        return output[0]

In [20]:
# Training budget: five passes over the training loader, expressed in optimizer steps.
total_steps = len(train_loader) * 5

# Checkpointing callback that monitors the "loss" metric and keeps one checkpoint.
checkpoints = CheckpointCallback(
    keep_n_checkpoints=1,
    checkpoint_dir=CACHE_DIR / "distill_model_cache/",
    monitor_metric="loss"
)
# Two learning-rate stages: the first 20% of the steps, then the remaining 80%
# (rounded up so the two stages cover at least total_steps).
lr_durations = [
    int(total_steps*0.2),
    int(np.ceil(total_steps*0.8))
]
# Step at which each stage starts: 0 for the first, the first stage's length for the second.
break_points = [0] + list(np.cumsum(lr_durations))[:-1]
callbacks = [
    # Loss tracking: averaging window of 1/8 epoch, one log line every 1/10 epoch.
    MovingAverageStatsTrackerCallback(
        avg_window=len(train_loader) // 8,
        log_interval=len(train_loader) // 10
    ),
    # NOTE(review): the schedulers are constructed in stage order on the same optimizer;
    # keep this order. LinearLR(optimizer, 0.01, n) presumably warms up linearly from 1%
    # of the base LR over n steps (the logged LR is consistent with that) — confirm in
    # pytorch_helper_bot. The second stage is a cosine decay over lr_durations[1] steps.
    # Despite its name, start_at_epochs receives step counts here.
    LearningRateSchedulerCallback(
        MultiStageScheduler(
            [
                LinearLR(optimizer, 0.01, lr_durations[0]),
                CosineAnnealingLR(optimizer, lr_durations[1])
            ],
            start_at_epochs=break_points
        )
    ),
    checkpoints
]
    
# Assemble the bot: log_dir is passed explicitly (cache/distill_logs), the loss is
# distill_loss, and accuracy is computed from the "label" entry of the targets.
# AMP is enabled only when apex was importable.
# clip_grad=10. presumably bounds the gradient norm — confirm in BaseBot.
bot = SST2Bot(
    log_dir = CACHE_DIR / "distill_logs",
    model=distill_bert_model, 
    train_loader=train_loader,
    valid_loader=valid_loader, 
    clip_grad=10.,
    optimizer=optimizer, echo=True,
    criterion=distill_loss,
    callbacks=callbacks,
    pbar=False, use_tensorboard=False,
    use_amp=APEX_AVAILABLE,
    metrics=(DistillTop1Accuracy(),)
)

[INFO][07/01/2020 21:13:08] SEED: 42
[INFO][07/01/2020 21:13:08] # of parameters: 66,955,010
[INFO][07/01/2020 21:13:08] # of trainable parameters: 66,955,010


In [21]:
# Show the step budget before the (long-running) training call.
print(total_steps)

# Train for total_steps; checkpoint_interval is half an epoch, which matches the
# "Metrics at step ..." lines in the log appearing every len(train_loader) // 2 steps.
bot.train(
    total_steps=total_steps,
    checkpoint_interval=len(train_loader) // 2
)
# Restore the weights of the first entry in best_performers (presumably the checkpoint
# with the best monitored loss; index 1 of the entry is what load_model receives),
# then remove the saved checkpoints (keep=0).
bot.load_model(checkpoints.best_performers[0][1])
checkpoints.remove_checkpoints(keep=0)

[INFO][07/01/2020 21:13:08] Optimizer Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.99)
    eps: 1e-08
    initial_lr: 2e-05
    lr: 2e-05
    weight_decay: 0
)
[INFO][07/01/2020 21:13:08] Batches per epoch: 3157


15785


[INFO][07/01/2020 21:13:38] Step   315 | loss 0.632218 | lr: 2.18e-06 | 0.095s per step


Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0


[INFO][07/01/2020 21:14:07] Step   630 | loss 0.437308 | lr: 4.15e-06 | 0.093s per step
[INFO][07/01/2020 21:14:36] Step   945 | loss 0.383704 | lr: 6.13e-06 | 0.092s per step
[INFO][07/01/2020 21:15:05] Step  1260 | loss 0.364518 | lr: 8.10e-06 | 0.092s per step
[INFO][07/01/2020 21:15:34] Step  1575 | loss 0.350856 | lr: 1.01e-05 | 0.092s per step
[INFO][07/01/2020 21:15:35] Metrics at step 1578:
[INFO][07/01/2020 21:15:35] loss: 0.306696
[INFO][07/01/2020 21:15:35] accuracy: 88.76%
[INFO][07/01/2020 21:16:04] Step  1890 | loss 0.333983 | lr: 1.21e-05 | 0.094s per step
[INFO][07/01/2020 21:16:33] Step  2205 | loss 0.322254 | lr: 1.40e-05 | 0.092s per step
[INFO][07/01/2020 21:17:01] Step  2520 | loss 0.312666 | lr: 1.60e-05 | 0.092s per step
[INFO][07/01/2020 21:17:30] Step  2835 | loss 0.308547 | lr: 1.80e-05 | 0.092s per step


Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0


[INFO][07/01/2020 21:17:59] Step  3150 | loss 0.303859 | lr: 2.00e-05 | 0.092s per step
[INFO][07/01/2020 21:18:00] Metrics at step 3156:
[INFO][07/01/2020 21:18:00] loss: 0.292737
[INFO][07/01/2020 21:18:00] accuracy: 90.14%
[INFO][07/01/2020 21:18:29] Step  3465 | loss 0.294506 | lr: 2.00e-05 | 0.094s per step
[INFO][07/01/2020 21:18:58] Step  3780 | loss 0.292304 | lr: 1.99e-05 | 0.092s per step
[INFO][07/01/2020 21:19:27] Step  4095 | loss 0.291308 | lr: 1.97e-05 | 0.092s per step
[INFO][07/01/2020 21:19:56] Step  4410 | loss 0.289131 | lr: 1.95e-05 | 0.092s per step
[INFO][07/01/2020 21:20:24] Step  4725 | loss 0.288863 | lr: 1.93e-05 | 0.092s per step
[INFO][07/01/2020 21:20:25] Metrics at step 4734:
[INFO][07/01/2020 21:20:25] loss: 0.284858
[INFO][07/01/2020 21:20:25] accuracy: 90.60%
[INFO][07/01/2020 21:20:54] Step  5040 | loss 0.286733 | lr: 1.89e-05 | 0.094s per step
[INFO][07/01/2020 21:21:23] Step  5355 | loss 0.283754 | lr: 1.86e-05 | 0.092s per step
[INFO][07/01/2020 21

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 131072.0


[INFO][07/01/2020 21:30:09] Step 11025 | loss 0.272045 | lr: 6.24e-06 | 0.092s per step
[INFO][07/01/2020 21:30:11] Metrics at step 11046:
[INFO][07/01/2020 21:30:11] loss: 0.272128
[INFO][07/01/2020 21:30:11] accuracy: 92.20%
[INFO][07/01/2020 21:30:39] Step 11340 | loss 0.272244 | lr: 5.52e-06 | 0.095s per step
[INFO][07/01/2020 21:31:08] Step 11655 | loss 0.271524 | lr: 4.84e-06 | 0.092s per step
[INFO][07/01/2020 21:31:37] Step 11970 | loss 0.270636 | lr: 4.18e-06 | 0.092s per step
[INFO][07/01/2020 21:32:06] Step 12285 | loss 0.271408 | lr: 3.56e-06 | 0.092s per step
[INFO][07/01/2020 21:32:35] Step 12600 | loss 0.271894 | lr: 2.98e-06 | 0.092s per step
[INFO][07/01/2020 21:32:37] Metrics at step 12624:
[INFO][07/01/2020 21:32:37] loss: 0.270088
[INFO][07/01/2020 21:32:37] accuracy: 91.97%
[INFO][07/01/2020 21:33:04] Step 12915 | loss 0.270533 | lr: 2.45e-06 | 0.095s per step
[INFO][07/01/2020 21:33:33] Step 13230 | loss 0.270737 | lr: 1.96e-06 | 0.092s per step
[INFO][07/01/2020 

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 262144.0


[INFO][07/01/2020 21:36:27] Step 15120 | loss 0.270291 | lr: 1.37e-07 | 0.092s per step


Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 131072.0


[INFO][07/01/2020 21:36:56] Step 15435 | loss 0.272322 | lr: 3.81e-08 | 0.092s per step
[INFO][07/01/2020 21:37:25] Step 15750 | loss 0.269410 | lr: 4.01e-10 | 0.092s per step
[INFO][07/01/2020 21:37:28] Metrics at step 15780:
[INFO][07/01/2020 21:37:28] loss: 0.271371
[INFO][07/01/2020 21:37:28] accuracy: 91.74%
[INFO][07/01/2020 21:37:29] Training finished. Best step(s):
[INFO][07/01/2020 21:37:29] loss: 0.270088 @ step 12624
[INFO][07/01/2020 21:37:29] accuracy: 92.20% @ step 11046


In [22]:
# Evaluate the restored model on the validation loader (logits softened with TEMPERATURE).
bot.eval(valid_loader)

{'loss': (0.27008805581189077, '0.270088'),
 'accuracy': (-0.9197247706422018, '91.97%')}

In [23]:
# Evaluate on the test loader. It was built with temperature=1., so the loss reported
# here is not on the same scale as the validation loss; compare accuracy instead.
bot.eval(test_loader)

{'loss': (0.1434098145830522, '0.143410'),
 'accuracy': (-0.9105504587155964, '91.06%')}