In [1]:
import os
from pathlib import Path

PROJECT_NAME = "compression-text-models"

# Move the working directory to the project root so that every relative path
# used below ("data/...", "configs/...", "models/...") resolves the same way
# regardless of where the notebook was launched from.
# Path.parts splits on the platform's own separator, unlike a manual split("/").
cwd_parts = Path.cwd().parts
if PROJECT_NAME not in cwd_parts:
    raise ValueError(
        f"Current directory {Path.cwd()} is not inside a '{PROJECT_NAME}' directory"
    )
# index() returns the first (outermost) match, same as the original lookup.
project_index = cwd_parts.index(PROJECT_NAME)
os.chdir(Path(*cwd_parts[: project_index + 1]))

In [2]:
import torch
import transformers
import torchinfo

# Prefer the GPU, but keep the notebook runnable on CPU-only machines.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Only show errors from transformers (hides warnings and info messages).
transformers.logging.set_verbosity_error()

In [3]:
tensor_path = "data/bin/tokenized-tensor.pt"
# Load onto the CPU regardless of the device the tensor was saved from;
# the slices that are actually used get moved to DEVICE later.
t = torch.load(tensor_path, map_location="cpu")

# Loading Teacher Model

In [4]:
teacher_name = "neuralmind/bert-base-portuguese-cased"

# Tokenizer and masked-LM teacher both come from the same checkpoint name.
tokenizer = transformers.BertTokenizer.from_pretrained(teacher_name)

# The teacher is only queried, never trained: request hidden states in its
# outputs, move it to DEVICE and switch it to evaluation mode.
teacher_model = (
    transformers.BertForMaskedLM
    .from_pretrained(teacher_name, output_hidden_states=True)
    .to(DEVICE)
    .eval()
)
torchinfo.summary(teacher_model)

Layer (type:depth-idx)                                  Param #
BertForMaskedLM                                         --
├─BertModel: 1-1                                        --
│    └─BertEmbeddings: 2-1                              --
│    │    └─Embedding: 3-1                              22,881,792
│    │    └─Embedding: 3-2                              393,216
│    │    └─Embedding: 3-3                              1,536
│    │    └─LayerNorm: 3-4                              1,536
│    │    └─Dropout: 3-5                                --
│    └─BertEncoder: 2-2                                 --
│    │    └─ModuleList: 3-6                             85,054,464
├─BertOnlyMLMHead: 1-2                                  --
│    └─BertLMPredictionHead: 2-3                        --
│    │    └─BertPredictionHeadTransform: 3-7            592,128
│    │    └─Linear: 3-8                                 22,911,586
Total params: 108,954,466
Trainable params: 108,954,466
Non-trainable params: 0

# Loading Student Model

In [5]:
student_configuration_path = "configs/distillation/distilbert-base-cased.json"
extracted_base_model = "models/artifacts/model-extraction/default-model.pth"

# Build the student configuration from the JSON file and ask the model to
# return hidden states in its outputs.
student_config = transformers.DistilBertConfig.from_pretrained(student_configuration_path)
student_config.output_hidden_states = True

# Initialise the student from the extracted weights, move it to DEVICE and
# put it in training mode.
student_model = (
    transformers.DistilBertForMaskedLM
    .from_pretrained(extracted_base_model, config=student_config)
    .to(DEVICE)
    .train()
)
torchinfo.summary(student_model)

Layer (type:depth-idx)                                  Param #
DistilBertForMaskedLM                                   --
├─GELUActivation: 1-1                                   --
├─DistilBertModel: 1-2                                  --
│    └─Embeddings: 2-1                                  --
│    │    └─Embedding: 3-1                              22,881,792
│    │    └─Embedding: 3-2                              (393,216)
│    │    └─LayerNorm: 3-3                              1,536
│    │    └─Dropout: 3-4                                --
│    └─Transformer: 2-2                                 --
│    │    └─ModuleList: 3-5                             42,527,232
├─Linear: 1-3                                           590,592
├─LayerNorm: 1-4                                        1,536
├─Linear: 1-5                                           22,911,586
├─CrossEntropyLoss: 1-6                                 --
Total params: 66,425,698
Trainable params: 66,032,482
Non-trainable params: 393,216

# MLM Smoothing

In [6]:
import pickle

mlm_smoothing = 0.7
token_counts_path = "data/processed/tokenized/separated-token-counts.pickle"

# NOTE(security): pickle.load can execute arbitrary code; only load count
# files produced by this project.
with open(token_counts_path, 'rb') as f:
    token_counts = pickle.load(f)
counts_tensor = torch.LongTensor(token_counts)

# token_probs is indexed by token id later on, so it needs one entry per
# vocabulary item. Fail loudly here instead of mis-indexing downstream.
assert counts_tensor.size(0) == tokenizer.vocab_size, (
    f"token counts have {counts_tensor.size(0)} entries, "
    f"tokenizer vocabulary has {tokenizer.vocab_size}"
)

# Sampling weight per token id: count ** -mlm_smoothing, so rarer tokens get
# a larger weight. Counts are floored at 1 to keep the power finite for
# tokens that never occur.
token_probs = counts_tensor.clamp(min=1).float()
token_probs = torch.pow(token_probs, -mlm_smoothing)
# Special tokens get weight 0 so they are never sampled as MLM targets.
token_probs[tokenizer.all_special_ids] = 0
token_probs, token_probs.size(0) == tokenizer.vocab_size

(tensor([0., 1., 1.,  ..., 1., 1., 1.]), True)

# Generating Batch Example

In [7]:
# Take the first 16 rows of the tokenized tensor as a toy batch.
example = t[:16].to(DEVICE)
# Attention mask: 1.0 on real tokens, 0.0 on padding. The padding id comes
# from the tokenizer (as in the MLM preparation below) rather than a literal 0.
batch = example, (example != tokenizer.pad_token_id).float()
# Reference teacher pass on the unmasked batch; no gradients are needed.
with torch.no_grad():
    results = teacher_model(batch[0], attention_mask=batch[1])
    t_logits, t_hidden_states = results['logits'], results['hidden_states']

# prepare_batch_mlm

In [8]:
import math

mlm_mask_prop = 0.15

token_ids, attn_mask = batch
token_ids = token_ids.to(DEVICE)
attn_mask = attn_mask.to(DEVICE)

# Number of real (non-padding) tokens per sequence.
lengths = attn_mask.sum(axis=-1)

# Labels start as a copy of the inputs; non-target positions are blanked at the end.
mlm_labels = token_ids.clone().detach()

bs, max_seq_len = token_ids.size()

# --- choose which positions become prediction targets -----------------------
# Look the sampling weights up on the same device as the indices.
# NOTE(review): recent PyTorch releases appear to reject indexing a CPU tensor
# with CUDA indices; co-locating both avoids relying on that behaviour.
x_prob = token_probs.to(token_ids.device)[token_ids.flatten()]
n_tgt = math.ceil(mlm_mask_prop * lengths.sum().item())
tgt_ids = torch.multinomial(x_prob / x_prob.sum(), n_tgt, replacement=False)
pred_mask = torch.zeros(
    bs * max_seq_len, dtype=torch.bool, device=token_ids.device
)
pred_mask[tgt_ids] = True
pred_mask = pred_mask.view(bs, max_seq_len)
# Padding positions are never prediction targets.
pred_mask[token_ids == tokenizer.pad_token_id] = False

# --- corrupt the chosen positions -------------------------------------------
# Each target is replaced by [MASK] (80%), kept as is (10%) or swapped for a
# random vocabulary id (10%).
word_mask, word_keep, word_rand = 0.8, 0.1, 0.1
pred_probs = torch.tensor([word_mask, word_keep, word_rand], device=token_ids.device)

_token_ids_real = token_ids[pred_mask]
_token_ids_rand = _token_ids_real.clone().random_(tokenizer.vocab_size)
_token_ids_mask = _token_ids_real.clone().fill_(tokenizer.mask_token_id)
# One draw per target: 0 -> mask, 1 -> keep, 2 -> random.
probs = torch.multinomial(pred_probs, len(_token_ids_real), replacement=True)
_token_ids = (
    _token_ids_mask * (probs == 0).long()
    + _token_ids_real * (probs == 1).long()
    + _token_ids_rand * (probs == 2).long()
)
token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

# -100 matches the ignore_index of the MLM cross-entropy loss defined below.
mlm_labels[~pred_mask] = -100  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
mlm_labels

tensor([[ -100, 11433,  -100,  ...,  -100,  -100,  -100],
        [ -100,  -100, 13859,  ...,  -100,  -100,  -100],
        [ -100,  2627, 19747,  ...,  -100,  -100,  -100],
        ...,
        [ -100,  -100,  -100,  ...,  -100,  -100,  -100],
        [ -100, 18116,  -100,  ...,  -100,  -100,  -100],
        [ -100,  -100,  2812,  ...,  -100,  -100,  -100]], device='cuda:0')

# Model Step

In [9]:
# Student forward pass on the corrupted batch (gradients enabled).
student_outputs = student_model(
    input_ids=token_ids,
    attention_mask=attn_mask,
)

# Teacher forward pass on the same corrupted batch. The teacher is a fixed
# reference, so no autograd graph is built for it: this saves memory and
# keeps loss.backward() from accumulating gradients on the teacher's parameters.
with torch.no_grad():
    teacher_outputs = teacher_model(
        input_ids=token_ids,
        attention_mask=attn_mask,
    )

s_logits, s_hidden_states = student_outputs['logits'], student_outputs['hidden_states']
# Bind to `t_hidden_states` (plural): that is the name the cosine loss reads.
# The former `t_hidden_state` left it using hidden states from an earlier forward pass.
t_logits, t_hidden_states = teacher_outputs['logits'], teacher_outputs['hidden_states']
assert s_logits.size() == t_logits.size()

# True at MLM target positions (labels other than -100), broadcast over the vocabulary axis.
mask = (mlm_labels > -1).unsqueeze(-1).expand_as(s_logits)

# Loss Definition

In [10]:
# Softmax temperature used when comparing student and teacher distributions.
temperature = 2.0
# Equal weights for the three loss terms: distillation (ce), masked-LM (mlm)
# and hidden-state cosine (cos).
alpha_ce = alpha_mlm = alpha_cos = 0.33

In [11]:
# One criterion per loss term.
# Hidden-state alignment between student and teacher.
cosine_loss_fct = torch.nn.CosineEmbeddingLoss(reduction="mean")
# Distillation: KL divergence between softened output distributions.
ce_loss_fct = torch.nn.KLDivLoss(reduction="batchmean")
# Masked-LM: cross-entropy that skips positions labelled -100.
lm_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)

In [12]:
# Keep only the logits at the positions selected by `mask`, then restore the
# vocabulary axis so each row is one selected position: (n_selected, voc_size).
voc_size = s_logits.size(-1)
s_logits_slct = s_logits.masked_select(mask).view(-1, voc_size)
t_logits_slct = t_logits.masked_select(mask).view(-1, voc_size)
assert t_logits_slct.size() == s_logits_slct.size()

In [21]:
# Distillation term: KL divergence between the temperature-softened student
# (log-probabilities) and teacher (probabilities) distributions, rescaled by
# temperature squared.
student_log_probs = torch.nn.functional.log_softmax(s_logits_slct / temperature, dim=-1)
teacher_probs = torch.nn.functional.softmax(t_logits_slct / temperature, dim=-1)
loss_ce = ce_loss_fct(student_log_probs, teacher_probs) * temperature ** 2

In [22]:
# Masked-LM term: flatten batch and sequence axes so every position is one
# classification example; positions labelled -100 are ignored by the criterion.
flat_student_logits = s_logits.view(-1, s_logits.size(-1))
flat_mlm_labels = mlm_labels.view(-1)
loss_mlm = lm_loss_fct(flat_student_logits, flat_mlm_labels)

In [23]:
# Cosine term: align the last-layer hidden states of student and teacher.
s_hidden_states_ = s_hidden_states[-1]  # (bs, seq_length, dim)
t_hidden_states_ = t_hidden_states[-1]  # (bs, seq_length, dim)
assert s_hidden_states_.size() == t_hidden_states_.size()
dim = s_hidden_states_.size(-1)

# Boolean mask of non-padding positions, broadcast over the hidden axis.
# It has its own name so the global `mask` used by the logits-selection cell
# is not overwritten (re-running that cell afterwards would otherwise pick up
# the wrong mask).
hidden_mask = attn_mask.bool().unsqueeze(-1).expand_as(s_hidden_states_)  # (bs, seq_length, dim)

# Drop padding positions and restore the hidden axis: (n_real_tokens, dim).
s_hidden_states_slct = torch.masked_select(s_hidden_states_, hidden_mask).view(-1, dim)
t_hidden_states_slct = torch.masked_select(t_hidden_states_, hidden_mask).view(-1, dim)

# A target of 1 asks the criterion to pull each pair of vectors together.
target = s_hidden_states_slct.new_ones(s_hidden_states_slct.size(0))  # (n_real_tokens,)
loss_cos = cosine_loss_fct(s_hidden_states_slct, t_hidden_states_slct, target)

In [24]:
# Weighted sum of the three loss terms.
loss = (
    alpha_ce * loss_ce
    + alpha_mlm * loss_mlm
    + alpha_cos * loss_cos
)
# backward() frees the graph's saved tensors, so this cell can run only once
# per forward pass: re-run the model step and loss cells before running it again.
loss.backward()

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

# Defining Optimizer

In [25]:
# AdamW settings.
learning_rate = 5e-4
adam_epsilon = 1e-6
weight_decay = 0.0

# Schedule settings: fraction of optimization steps spent warming up, number
# of batches accumulated per optimizer step, and the size of the toy run.
warmup_prop = 0.05
gradient_accumulation_steps = 50
n_epoch = 1
num_steps_epoch = 32

In [18]:
# Name fragments of parameters that must not receive weight decay.
# "LayerNorm.weight" matches the embedding LayerNorm; "layer_norm.weight" is
# added because the transformers DistilBERT implementation names its other
# norms `sa_layer_norm`, `output_layer_norm` and `vocab_layer_norm`, which the
# first fragment does not match.
no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]

# Single pass over the trainable parameters, split into the two groups.
decay_params, no_decay_params = [], []
for param_name, param in student_model.named_parameters():
    if not param.requires_grad:
        continue  # frozen parameters are left out of the optimizer entirely
    if any(fragment in param_name for fragment in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer_grouped_parameters = [
    {"params": decay_params, "weight_decay": weight_decay},
    {"params": no_decay_params, "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon, betas=(0.9, 0.98))

In [20]:
from transformers import get_linear_schedule_with_warmup

# Optimizer steps in the whole run: batches / accumulation steps, plus one.
# With the values configured above this evaluates to a single step.
num_train_optimization_steps = int(num_steps_epoch / gradient_accumulation_steps * n_epoch) + 1
warmup_steps = math.ceil(num_train_optimization_steps * warmup_prop)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=num_train_optimization_steps,
)

# One optimization step: apply the gradients, advance the learning-rate
# schedule (it was created but never stepped before), then clear the
# gradients for the next accumulation window.
optimizer.step()
scheduler.step()
optimizer.zero_grad()