In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell execution so edits to source files on disk are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

PROJECT_NAME = "compression-text-models"

# Make the project root the working directory so the relative paths used by
# later cells ("data/...", "configs/...", "models/...") resolve no matter which
# sub-directory the notebook was started from.
# The path is split and re-joined with os.sep rather than a literal "/" so the
# same logic works with the platform's native separator.
curdir = os.path.abspath(os.path.curdir).split(os.sep)
if PROJECT_NAME not in curdir:
    # Same exception type `list.index` would raise, but with an actionable message.
    raise ValueError(
        f"Expected the notebook to run from inside a '{PROJECT_NAME}' directory, "
        f"but the working directory is {os.path.abspath(os.path.curdir)!r}."
    )
project_index = curdir.index(PROJECT_NAME)
os.chdir(os.sep.join(curdir[:project_index + 1]))

In [3]:
import torch
import transformers
import torchinfo

# Device used for every model and tensor in this notebook. Prefer the GPU, but
# fall back to the CPU so the notebook can still run on a machine without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Optional: uncomment to cap the share of GPU memory this process may allocate.
#torch.cuda.set_per_process_memory_fraction(0.5)

# Only surface error-level messages from the transformers library.
transformers.logging.set_verbosity_error()

In [4]:
# Location of the serialized token tensor, relative to the working directory.
tensor_path = "data/bin/tokenized-tensor.pt"
# NOTE(review): later cells slice `t[:16]` and report a (16, 512) shape, so `t`
# is presumably a 2-D tensor of token ids, one row per sequence — confirm.
t = torch.load(tensor_path)

# Loading Teacher Model

In [5]:
# The tokenizer and the teacher weights are loaded from the same pretrained
# checkpoint.
teacher_name = "neuralmind/bert-base-portuguese-cased"
tokenizer = transformers.BertTokenizer.from_pretrained(teacher_name)

# Hidden states are requested because a later cell compares the last hidden
# layer of teacher and student. The teacher is moved to DEVICE and put in
# evaluation mode.
teacher_model = (
    transformers.BertForMaskedLM
    .from_pretrained(teacher_name, output_hidden_states=True)
    .to(DEVICE)
    .eval()
)

# Display a per-layer parameter count.
torchinfo.summary(teacher_model)

Layer (type:depth-idx)                                  Param #
BertForMaskedLM                                         --
├─BertModel: 1-1                                        --
│    └─BertEmbeddings: 2-1                              --
│    │    └─Embedding: 3-1                              22,881,792
│    │    └─Embedding: 3-2                              393,216
│    │    └─Embedding: 3-3                              1,536
│    │    └─LayerNorm: 3-4                              1,536
│    │    └─Dropout: 3-5                                --
│    └─BertEncoder: 2-2                                 --
│    │    └─ModuleList: 3-6                             85,054,464
├─BertOnlyMLMHead: 1-2                                  --
│    └─BertLMPredictionHead: 2-3                        --
│    │    └─BertPredictionHeadTransform: 3-7            592,128
│    │    └─Linear: 3-8                                 22,911,586
Total params: 108,954,466
Trainable params: 108,954,466
Non-trainable 

# Loading Student Model

In [6]:
# Architecture description of the student and the weights used to initialise it.
student_configuration_path = "configs/distillation/distilbert-base-cased.json"
extracted_base_model = "models/artifacts/model-extraction/default-model.pth"

# Hidden states are requested because a later cell compares the last hidden
# layer of teacher and student.
student_config = transformers.DistilBertConfig.from_pretrained(student_configuration_path)
student_config.output_hidden_states = True

# The student is moved to DEVICE and put in training mode.
student_model = (
    transformers.DistilBertForMaskedLM
    .from_pretrained(extracted_base_model, config=student_config)
    .to(DEVICE)
    .train()
)

# Display a per-layer parameter count.
torchinfo.summary(student_model)

Layer (type:depth-idx)                                  Param #
DistilBertForMaskedLM                                   --
├─GELUActivation: 1-1                                   --
├─DistilBertModel: 1-2                                  --
│    └─Embeddings: 2-1                                  --
│    │    └─Embedding: 3-1                              22,881,792
│    │    └─Embedding: 3-2                              (393,216)
│    │    └─LayerNorm: 3-3                              1,536
│    │    └─Dropout: 3-4                                --
│    └─Transformer: 2-2                                 --
│    │    └─ModuleList: 3-5                             42,527,232
├─Linear: 1-3                                           590,592
├─LayerNorm: 1-4                                        1,536
├─Linear: 1-5                                           22,911,586
├─CrossEntropyLoss: 1-6                                 --
Total params: 66,425,698
Trainable params: 66,032,482
Non-trainable 

# MLM Smoothing

In [7]:
import pickle

# Exponent applied to the token counts: weight = count ** -mlm_smoothing.
mlm_smoothing = 0.7
token_counts_path = "data/processed/tokenized/separated-token-counts.pickle"

# NOTE(review): pickle.load executes arbitrary code from the file; only use it
# on a file produced by this project.
with open(token_counts_path, "rb") as counts_file:
    token_counts = pickle.load(counts_file)
counts_tensor = torch.LongTensor(token_counts)

# Counts are floored at 1 before the negative power is taken.
floored_counts = torch.maximum(counts_tensor, torch.ones(counts_tensor.shape))
token_probs = floored_counts.pow(-mlm_smoothing)

# Special tokens get weight 0.
token_probs[tokenizer.all_special_ids] = 0

# Show the weights and whether there is exactly one weight per vocabulary entry.
token_probs, token_probs.size(0) == tokenizer.vocab_size

(tensor([0., 1., 1.,  ..., 1., 1., 1.]), True)

# Generating Batch Example

In [8]:
# Toy batch: the first 16 sequences of the tokenized corpus, moved to DEVICE.
example = t[:16]
example = example.to(DEVICE)

# A batch is the pair (token ids, attention mask). The mask is 1.0 on real
# tokens and 0.0 on padding. The pad id comes from the tokenizer rather than a
# literal 0, matching how the masking cell identifies padding.
batch = example, (example != tokenizer.pad_token_id).float()

In [9]:
# Sum the attention mask along the last axis: number of non-padding tokens per sequence.
batch[1].sum(axis=-1)

tensor([196., 178.,  75., 257., 238., 118., 183., 169., 255., 208., 148., 111.,
         89., 124., 336., 234.], device='cuda:0')

# prepare_batch_mlm

In [10]:
import math

# Manual re-implementation of the MLM batch preparation, step by step.
# Inputs:  `batch` = (token ids, attention mask), `token_probs` = per-token
#          sampling weights, `tokenizer`.
# Outputs: `token_ids` (corrupted inputs), `attn_mask`, `mlm_labels` (original
#          id on the selected positions, -100 everywhere else).

mlm_mask_prop = 0.15  # fraction of the non-padding tokens selected as prediction targets
word_mask, word_keep, word_rand = 0.8, 0.1, 0.1  # P(replace by [MASK]), P(keep), P(replace by random id)

token_ids, attn_mask = batch
token_ids = token_ids.to(DEVICE)
attn_mask = attn_mask.to(DEVICE)

lengths = attn_mask.sum(axis=-1)  # non-padding tokens per sequence
mlm_labels = token_ids.clone().detach()  # untouched copy; becomes the labels
bs, max_seq_len = token_ids.size()

# Sampling weight of every position in the flattened batch. `token_probs` is
# first placed on the same device as the indices: mixing a CPU tensor with
# CUDA indices is not reliably supported.
x_prob = token_probs.to(token_ids.device)[token_ids.flatten()]
n_tgt = math.ceil(mlm_mask_prop * lengths.sum().item())
tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)

# Boolean map of the positions to predict; padding is never a target.
pred_mask = torch.zeros(
    bs * max_seq_len, dtype=torch.bool, device=token_ids.device
)
pred_mask[tgt_ids] = 1
pred_mask = pred_mask.view(bs, max_seq_len)
pred_mask[token_ids == tokenizer.pad_token_id] = 0

# For every selected position draw one of three actions:
# 0 -> [MASK], 1 -> keep the original token, 2 -> random vocabulary id.
pred_probs = torch.tensor([word_mask, word_keep, word_rand], device=token_ids.device)
_token_ids_real = token_ids[pred_mask]
_token_ids_rand = _token_ids_real.clone().random_(tokenizer.vocab_size)
_token_ids_mask = _token_ids_real.clone().fill_(tokenizer.mask_token_id)
probs = torch.multinomial(pred_probs, len(_token_ids_real), replacement=True)
_token_ids = (
    _token_ids_mask * (probs == 0).long()
    + _token_ids_real * (probs == 1).long()
    + _token_ids_rand * (probs == 2).long()
)
token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

# -100 marks positions that are not prediction targets; it is the value the
# MLM CrossEntropyLoss in this notebook is configured to ignore.
mlm_labels[~pred_mask] = -100
mlm_labels

tensor([[ -100,  -100,  -100,  ...,  -100,  -100,  -100],
        [ -100,  -100,  -100,  ...,  -100,  -100,  -100],
        [ -100,  -100, 19747,  ...,  -100,  -100,  -100],
        ...,
        [ -100,  -100,  -100,  ...,  -100,  -100,  -100],
        [ -100, 18116,  -100,  ...,  -100,  -100,  -100],
        [ -100,  -100,  -100,  ...,  -100,  -100,  -100]], device='cuda:0')

In [11]:
# Shapes of the token-id tensor and of the attention mask in the toy batch.
batch[0].shape, batch[1].shape

(torch.Size([16, 512]), torch.Size([16, 512]))

# Model Step

In [12]:
# Student forward pass: gradients are needed here.
student_outputs = student_model(
    input_ids=token_ids,
    attention_mask=attn_mask,
)

# Teacher forward pass: the teacher only provides targets, so it runs without
# autograd. Otherwise its activations are kept alive and loss.backward() would
# also backpropagate through the teacher.
with torch.no_grad():
    teacher_outputs = teacher_model(
        input_ids=token_ids,
        attention_mask=attn_mask,
    )

s_logits, s_hidden_states = student_outputs['logits'], student_outputs['hidden_states']
t_logits, t_hidden_states = teacher_outputs['logits'], teacher_outputs['hidden_states']
assert s_logits.size() == t_logits.size()

# True on the positions selected for prediction (labels are -100 elsewhere),
# broadcast to the shape of the logits.
mask = (mlm_labels > -1).unsqueeze(-1).expand_as(s_logits)

# Loss Definition

In [13]:
# Divisor applied to both student and teacher logits before the softmax in the
# distillation loss; that loss is then multiplied by temperature ** 2.
temperature = 2.0
# Weights of the three loss terms in the final sum (distillation, MLM, cosine).
# Note that they add up to 0.99, not 1.
alpha_ce = 0.33
alpha_mlm = 0.33
alpha_cos = 0.33

In [14]:
# MLM loss: cross-entropy against the original token ids; labels equal to -100
# (the non-target positions) are ignored.
lm_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
# Distillation loss: KL divergence with "batchmean" reduction.
ce_loss_fct = torch.nn.KLDivLoss(reduction="batchmean")
# Hidden-state alignment loss: cosine embedding loss averaged over its inputs.
cosine_loss_fct = torch.nn.CosineEmbeddingLoss(reduction="mean")

In [15]:
# Keep only the logits of the positions flagged in `mask` and lay them out as
# one row per selected position: (n_selected, voc_size).
voc_size = s_logits.size(-1)
s_logits_slct, t_logits_slct = (
    torch.masked_select(all_logits, mask).view(-1, voc_size)
    for all_logits in (s_logits, t_logits)
)
assert t_logits_slct.size() == s_logits_slct.size()

In [16]:
# Distillation loss: KL divergence between the temperature-softened student
# and teacher distributions, scaled by temperature ** 2.
student_log_probs = torch.nn.functional.log_softmax(s_logits_slct / temperature, dim=-1)
teacher_probs = torch.nn.functional.softmax(t_logits_slct / temperature, dim=-1)
loss_ce = ce_loss_fct(student_log_probs, teacher_probs) * temperature ** 2

In [17]:
# MLM loss of the student: logits are flattened to one row per position and
# labels to one entry per position before being passed to the loss.
loss_mlm = lm_loss_fct(s_logits.view(-1, s_logits.size(-1)), mlm_labels.view(-1))

In [18]:
# Cosine loss between the last hidden layer of student and teacher, computed
# only on the positions where the attention mask is set.
s_hidden_states_ = s_hidden_states[-1]  # (bs, seq_length, dim)
t_hidden_states_ = t_hidden_states[-1]  # (bs, seq_length, dim)
assert s_hidden_states_.size() == t_hidden_states_.size()
dim = s_hidden_states_.size(-1)

# A dedicated name is used so the logits `mask` from the model-step cell is not
# overwritten; the logits-selection cell can then be re-run after this one.
hidden_mask = attn_mask.unsqueeze(-1).expand_as(s_hidden_states_).bool()  # (bs, seq_length, dim)

s_hidden_states_slct = torch.masked_select(s_hidden_states_, hidden_mask)  # (n_kept * dim)
s_hidden_states_slct = s_hidden_states_slct.view(-1, dim)  # (n_kept, dim)
t_hidden_states_slct = torch.masked_select(t_hidden_states_, hidden_mask)  # (n_kept * dim)
t_hidden_states_slct = t_hidden_states_slct.view(-1, dim)  # (n_kept, dim)

# A target of 1 for every row asks the loss to pull each pair of vectors together.
target = s_hidden_states_slct.new_ones(s_hidden_states_slct.size(0))  # (n_kept,)
loss_cos = cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)

In [19]:
# Total loss: weighted sum of the distillation, MLM and cosine terms, then
# backpropagation.
loss = (
    loss_ce * alpha_ce
    + loss_mlm * alpha_mlm
    + loss_cos * alpha_cos
)
loss.backward()

# Defining Optimizer

In [20]:
# AdamW settings: weight decay for the regularised parameter group, base
# learning rate, and epsilon.
weight_decay = 0.0
learning_rate = 5e-4
adam_epsilon = 1e-6

# Fraction of the optimisation steps used for learning-rate warmup.
warmup_prop = 0.05
# The three values below are only used to derive the number of optimisation
# steps given to the scheduler.
gradient_accumulation_steps = 50
n_epoch = 1
num_steps_epoch = 32

In [21]:
# Trainable student parameters are split into two groups: those whose name
# contains one of the `no_decay` substrings never get weight decay, all the
# others use `weight_decay`.
# NOTE(review): the match is by substring, so only parameters whose name
# literally contains "LayerNorm.weight" are exempted — confirm this covers all
# of the student's layer-norm parameters.
no_decay = ["bias", "LayerNorm.weight"]

decayed_params, undecayed_params = [], []
for p_name, p in student_model.named_parameters():
    if not p.requires_grad:
        continue
    if any(nd in p_name for nd in no_decay):
        undecayed_params.append(p)
    else:
        decayed_params.append(p)

optimizer_grouped_parameters = [
    {"params": decayed_params, "weight_decay": weight_decay},
    {"params": undecayed_params, "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(
    optimizer_grouped_parameters,
    lr=learning_rate,
    eps=adam_epsilon,
    betas=(0.9, 0.98),
)

In [22]:
from transformers import get_linear_schedule_with_warmup

# Number of optimizer updates: batches per epoch divided by the accumulation
# factor, times the number of epochs, plus one.
num_train_optimization_steps = (int(num_steps_epoch / gradient_accumulation_steps * n_epoch) + 1)
warmup_steps = math.ceil(num_train_optimization_steps * warmup_prop)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=num_train_optimization_steps,
)

# One full update: apply the gradients, advance the learning-rate schedule
# (after the optimizer step, as PyTorch expects), then clear the gradients.
optimizer.step()
scheduler.step()
optimizer.zero_grad()

# Comparing to new code

In [23]:
from src.fastdistillation import distiller
from src.fastdistillation import lm_seqs_dataset

In [24]:
# Map every special-token name ("unk_token", "pad_token", ...) to its id.
# The ids are looked up per entry, so the result does not depend on the order
# in which `special_tokens_map` lists its keys. The `token_ids` tensor from the
# batch-preparation cell is also left untouched.
token_name_id_map = {
    token_name: tokenizer.convert_tokens_to_ids(token_symbol)
    for token_name, token_symbol in tokenizer.special_tokens_map.items()
}

In [25]:
# Distiller Params
from types import SimpleNamespace

# The hyperparameters reuse the variables of the manual step above, so the
# Distiller and the hand-written code are compared under identical settings.
params = SimpleNamespace(**{
    "dump_path": "test.pth",
    "batch_size": 5,
    "temperature": temperature,
    "alpha_ce": alpha_ce,
    "alpha_mlm": alpha_mlm,
    "alpha_cos": alpha_cos,

    "mlm_mask_prop": mlm_mask_prop,
    "word_mask": word_mask,
    "word_keep": word_keep,
    "word_rand": word_rand,

    "gradient_accumulation_steps": gradient_accumulation_steps,
    "n_epoch": n_epoch,
    "max_grad_norm": 5.0,

    "weight_decay": weight_decay,
    "learning_rate": learning_rate,
    "adam_epsilon": adam_epsilon,

    "warmup_prop": warmup_prop,
    "special_tok_ids": token_name_id_map,

    # Same device as the models and tensors created earlier.
    "device": DEVICE,
    "log_interval": 500,
    "checkpoint_interval": 4000,
})


# Dataset Params
dataset_data = batch[0]
# NOTE(review): `dataset_params` is not passed to anything in this cell, and it
# maps token names to token strings rather than ids — confirm whether it is
# still needed.
dataset_params = SimpleNamespace(**{
    "special_tok_ids": tokenizer.special_tokens_map
})

# The toy batch and its attention mask become the training dataset.
dataset = lm_seqs_dataset.LmSeqsDataset(dataset_data, attn_mask)
student = student_model
teacher = teacher_model

dist = distiller.Distiller(
    params,
    dataset,
    token_probs,
    student,
    teacher,
)

09/07/2022 18:22:32 - INFO - src.fastdistillation.distiller - PID: 565822 -  Initializing Distiller
09/07/2022 18:22:32 - INFO - src.fastdistillation.distiller - PID: 565822 -  Using MLM loss for LM step.
09/07/2022 18:22:32 - INFO - src.fastdistillation.distiller - PID: 565822 -  --- Initializing model optimizer
09/07/2022 18:22:32 - INFO - src.fastdistillation.distiller - PID: 565822 -  ------ Number of trainable parameters (student): 66032482
09/07/2022 18:22:32 - INFO - src.fastdistillation.distiller - PID: 565822 -  ------ Number of parameters (student): 66425698
09/07/2022 18:22:32 - INFO - src.fastdistillation.distiller - PID: 565822 -  --- Initializing Tensorboard


In [26]:
# Run the Distiller's training loop on the toy dataset built above.
# NOTE(review): `student_model` was handed to the Distiller as-is, so its
# weights are presumably updated by this call — confirm.
dist.train()

09/07/2022 18:22:33 - INFO - src.fastdistillation.distiller - PID: 565822 -  Starting training
09/07/2022 18:22:33 - INFO - src.fastdistillation.distiller - PID: 565822 -  --- Starting epoch 0/0
-Iter: 100%|██████████| 4/4 [00:00<00:00,  6.29it/s, Last_loss=4.79, Avg_cum_loss=4.65]
09/07/2022 18:22:33 - INFO - src.fastdistillation.distiller - PID: 565822 -  --- Ending epoch 0/0
09/07/2022 18:22:33 - INFO - src.fastdistillation.distiller - PID: 565822 -  16 sequences have been trained during this epoch.
09/07/2022 18:22:34 - INFO - src.fastdistillation.distiller - PID: 565822 -  Save very last checkpoint as `pytorch_model.bin`.
09/07/2022 18:22:34 - INFO - src.fastdistillation.distiller - PID: 565822 -  Training is finished


# Extracting inputs and outputs from distiller

In [27]:
# Mask the toy batch with the Distiller's own implementation. This rebinds
# `token_ids`, `attn_mask` and `mlm_labels`, replacing the values produced by
# the manual cell.
token_ids, attn_mask, mlm_labels = dist.prepare_batch_mlm(batch)

In [38]:
# Prediction mask rebuilt from the Distiller's labels.
# NOTE(review): `s_logits` is not recomputed here, so it still comes from the
# forward pass on the manually masked batch; it is only used for its shape.
# The execution counts of this and the following cells are out of order.
mask = (mlm_labels > -1).unsqueeze(-1).expand_as(s_logits)

In [49]:
# Student logits at index 1 of the first sequence.
s_logits[0][1]

tensor([-6.0699, -4.7514, -5.4603,  ..., -6.0048, -4.1862, -5.3527],
       device='cuda:0', grad_fn=<SelectBackward0>)

In [44]:
# Flat 1-D tensor of the student logits at the positions flagged in `mask`.
torch.masked_select(s_logits, mask)

tensor([-6.0699, -4.7514, -5.4603,  ..., -5.0317, -6.4731, -5.6466],
       device='cuda:0', grad_fn=<MaskedSelectBackward0>)

In [28]:
# First ten entries of the first sequence: input ids, attention mask, labels.
token_ids[0][:10], attn_mask[0][:10], mlm_labels[0][:10]

(tensor([  101, 11433, 22332,   243,  6240,  1445, 22341,  4790, 18471, 22322],
        device='cuda:0'),
 tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0'),
 tensor([ -100,  -100,  -100,  -100,  -100, 20697,  -100,  -100,  -100,  -100],
        device='cuda:0'))

# Passing data through teacher model and extracting teacher hidden_state

In [29]:
from torch.utils.data import DataLoader

dl = DataLoader(t, batch_size=128, num_workers=12, pin_memory=True)
# Evaluation mode for the teacher (the original bound the result to a
# misspelled, unused name).
teacher_model = teacher_model.eval()

for idx, token_ids in enumerate(dl):
    # Attention mask from the tokenizer's pad id rather than a literal 0.
    attn_mask = (token_ids != tokenizer.pad_token_id).float()
    token_ids = token_ids.to(DEVICE)
    attn_mask = attn_mask.to(DEVICE)
    batch = token_ids, attn_mask
    token_ids, attn_mask, mlm_labels = dist.prepare_batch_mlm(batch)
    # The teacher is only queried for its outputs, so no autograd graph is built.
    with torch.no_grad():
        results = teacher_model(input_ids=token_ids, attention_mask=attn_mask)
    if idx % 100 == 0:
        print(f"Index {idx} of {len(dl)}")
        # Smoke test only: idx 0 satisfies the condition, so the loop stops
        # after the first batch. Remove the break to process the whole corpus.
        break

Index 0 of 75600


In [31]:
# Display the teacher output for the last processed batch.
# NOTE(review): this prints the full model output; showing only the shapes
# would keep the saved notebook smaller.
results

MaskedLMOutput(loss=None, logits=tensor([[[ -7.2039,  -8.3236,  -7.0194,  ...,  -8.7136,  -6.3941,  -7.2889],
         [ -9.7674,  -5.1069,  -9.3473,  ...,  -9.1819,  -5.3219,  -8.2620],
         [-11.9612,  -8.1032, -10.0777,  ..., -10.6829,  -9.6714, -11.1105],
         ...,
         [ -7.1679,  -7.3371,  -7.6079,  ...,  -6.5680,  -7.7695,  -7.7533],
         [ -6.8212,  -8.3886,  -8.9212,  ...,  -8.1491,  -6.4727,  -8.6142],
         [ -7.2082,  -8.3268,  -7.0258,  ...,  -8.7107,  -6.3904,  -7.2802]],

        [[-11.1519,  -9.8615,  -7.9989,  ..., -10.5612,  -9.0526,  -9.4292],
         [ -5.9810,  -4.4497,  -4.4526,  ...,  -5.0530,  -5.1152,  -6.9403],
         [ -3.6801,  -3.5615,  -4.8330,  ...,  -4.4447,  -4.1416,  -3.1035],
         ...,
         [ -8.0703,  -8.5968,  -6.9216,  ...,  -7.1585,  -6.3272,  -5.3865],
         [ -8.1839,  -8.5917,  -6.9312,  ...,  -7.0328,  -6.3574,  -5.6031],
         [-11.1485,  -9.8577,  -8.0014,  ..., -10.5544,  -9.0517,  -9.4273]],

        [[ 