# NLP intro

In [None]:
!pip -q install torch==1.7 transformers==4.0.0 datasets catalyst==21.03

In [None]:
from datasets import load_dataset

# Load the IMDB dataset through the `datasets` library; its "train" and "test"
# splits are used to build the data loaders later in the notebook.
imdb_dataset = load_dataset("imdb")

In [None]:
# Show the loaded dataset object (the notebook echoes the last expression of a cell).
imdb_dataset

In [None]:
# Peek at the first record of the training split.
imdb_dataset["train"][0]

In [None]:
# Keep the text of the first training record as a sample string for the tokenizer demos.
test = imdb_dataset["train"][0]["text"]

In [None]:
# Q: what about text preprocessing?

In [None]:
from transformers import BertTokenizer

# Tokenizer of a small BERT checkpoint; the same checkpoint name is used for the model later on.
tokenizer = BertTokenizer.from_pretrained("google/bert_uncased_L-6_H-256_A-4")

In [None]:
# Print the sample text split into tokens.
print(tokenizer.tokenize(test))

In [None]:
# Print the sample text encoded as token ids.
print(tokenizer.encode(test))

The tokenizer has additional functions that create attention masks, offset mappings and token type ids, all of which are needed to train transformer models.

In [None]:
# Print the full `encode_plus` output for the sample text with default arguments.
print(tokenizer.encode_plus(test))

In [None]:
# Same call, but truncated/padded to max_length=64 and returned as PyTorch tensors
# (return_tensors="pt") — the settings used by the data transform.
print(tokenizer.encode_plus(test, max_length=64, truncation=True, padding="max_length", return_tensors="pt"))

In [None]:
from typing import Dict, Any
import torch
from catalyst.utils import get_loader


def text_data_transforms(row: Dict[str, Any], max_length: int = 64) -> Dict[str, torch.Tensor]:
    """Tokenize one dataset row and attach its label.

    Args:
        row: a dataset record providing the review under ``"text"`` and its
            class under ``"label"``.
        max_length: length every sequence is truncated or padded to.

    Returns:
        The tokenizer outputs for ``row["text"]`` with the leading
        single-sample dimension removed, plus the row label stored under
        ``"targets"``.
    """
    encoded = tokenizer.encode_plus(
        # Encode the text of *this* row; encoding the global sample string
        # would give every example identical features.
        row["text"],
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    # With return_tensors="pt" each value carries a leading dimension for the
    # single encoded text; drop it so the loader can batch the samples itself.
    features = {name: tensor[0] for name, tensor in encoded.items()}
    features["targets"] = row["label"]
    return features
    

# Training loader: shuffled, and the last incomplete batch is dropped so every
# training step sees a full batch.
train_dataloader = get_loader(
    imdb_dataset["train"],
    open_fn=lambda x: x,  # rows are already dicts, nothing to open
    dict_transform=text_data_transforms,
    batch_size=256,
    num_workers=4,
    shuffle=True,
    drop_last=True,
)

# Validation loader: keep a fixed order and keep every sample, otherwise the
# reported validation metrics silently ignore the last partial batch.
valid_dataloader = get_loader(
    imdb_dataset["test"],
    open_fn=lambda x: x,  # rows are already dicts, nothing to open
    dict_transform=text_data_transforms,
    batch_size=256,
    num_workers=4,
    shuffle=False,
    drop_last=False,
)

In [None]:
# Fetch a single batch to inspect what the training loader yields.
next(iter(train_dataloader))

In [None]:
# Loaders keyed by the stage name the runner reports them under.
loaders = dict(train=train_dataloader, valid=valid_dataloader)

In [None]:
# Q: what about BERT?

Load a BERT model for sequence classification. We need a model smaller than bert-base-uncased; see the list of all available model names.

In [None]:
from torch import nn, optim
from torch.nn import functional as F
from transformers import BertForSequenceClassification


# Pretrained small BERT (same checkpoint name as the tokenizer) wrapped for sequence classification.
model = BertForSequenceClassification.from_pretrained("google/bert_uncased_L-6_H-256_A-4")
# Swap in a freshly initialised head: 256 input features -> 2 output classes.
# NOTE(review): 256 presumably matches this checkpoint's hidden size (H-256) — confirm.
model.classifier = nn.Linear(256, 2)

# Loss and optimizer for fine-tuning; the optimizer covers all parameters of `model`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4)

In [None]:
from catalyst import dl, metrics, utils

In [None]:
class BertRunner(dl.SupervisedRunner):
    """Supervised runner that calls the model with selected batch fields as keyword arguments."""

    def handle_batch(self, batch: Dict[str, torch.Tensor]):
        """Run the model on the configured input keys and merge its outputs into ``self.batch``."""
        model_kwargs = {}
        for key in self._input_key:
            model_kwargs[key] = batch[key]
        # The model output is unpacked like a mapping; the accuracy callback
        # later reads its "logits" entry from the batch.
        outputs = self.model(return_dict=True, **model_kwargs)
        self.batch = {**self.batch, **outputs}

# These batch fields are forwarded to the model as keyword arguments by `handle_batch`.
runner = BertRunner(input_key=["input_ids", "attention_mask"])

In [None]:
from datetime import datetime
from pathlib import Path

# Timestamped log directory, so each run gets its own folder under "logs".
logdir = Path("logs", datetime.now().strftime("%Y%m%d-%H%M%S"))

# Fine-tune the BERT classifier for three epochs, tracking accuracy on the logits.
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=3,
    logdir=logdir,
    verbose=True,
    callbacks=[
        dl.AccuracyCallback(input_key="logits", target_key="targets", num_classes=2),
        # Optional extras, disabled by default:
        # dl.CheckpointCallback(logdir=logdir, loader_key="valid", metric_key="loss", minimize=True, save_n_best=3),
        # dl.CheckRunCallback(),
    ],
)

In [None]:
# Show the log directory that was passed to the training run.
logdir

In [None]:
# model stochastic weight averaging
# model.load_state_dict(utils.get_averaged_weights_by_path_mask(path_mask="logs/20210316-110940/train.*.pth"))

----

In [None]:
# Q: how could we have it smaller?

In [None]:
from collections import OrderedDict

NUM_TOKENS = tokenizer.vocab_size

class StudentModel(nn.Module):
    """Small transformer classifier used as the distillation student.

    Token ids are embedded, passed through a stack of
    ``nn.TransformerEncoderLayer`` blocks, mean-pooled over the sequence and
    mapped to class logits by a two-layer head. Padding positions are not
    masked: they take part in attention and in the mean pooling.

    Args:
        num_tokens: vocabulary size; when ``None`` the module-level
            ``NUM_TOKENS`` is used.
        embedding_size: width of the token embeddings and of the encoder.
        num_layers: number of stacked encoder blocks.
        num_heads: attention heads per encoder block.
        hidden_size: feed-forward width inside the blocks and width of the
            hidden layer of the classification head.
        num_classes: number of output logits.
        dropout: dropout rate inside the encoder blocks.
    """

    def __init__(
        self,
        num_tokens=None,
        embedding_size: int = 128,
        num_layers: int = 2,
        num_heads: int = 2,
        hidden_size: int = 128,
        num_classes: int = 2,
        dropout: float = 0.2,
    ):
        super().__init__()

        vocab_size = NUM_TOKENS if num_tokens is None else num_tokens
        # Use embedding_size here as well, so the embedding always matches d_model.
        self.embedding = nn.Embedding(vocab_size, embedding_size)

        transformer_blocks = [
            (
                f"transformer_block_{i}",
                nn.TransformerEncoderLayer(
                    d_model=embedding_size,
                    nhead=num_heads,
                    dim_feedforward=hidden_size,
                    dropout=dropout,
                ),
            )
            for i in range(num_layers)
        ]
        self.transformer_encoder = nn.Sequential(OrderedDict(transformer_blocks))

        self.linear = nn.Linear(in_features=embedding_size, out_features=hidden_size)
        self.scorer = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, input_ids):
        """Return class logits of shape (batch, num_classes) for (batch, seq) token ids."""
        embeddings = self.embedding(input_ids)  # (batch, seq, embedding)

        # The encoder layers are built without batch_first, so they expect
        # (seq, batch, embedding). Feeding batch-first data would make
        # attention mix different samples of the batch instead of the tokens
        # of one sequence, so switch layouts around the encoder.
        encoded = self.transformer_encoder(embeddings.transpose(0, 1))
        encoded = encoded.transpose(0, 1)  # back to (batch, seq, embedding)

        pooled = torch.mean(encoded, dim=1)  # average over the sequence
        hidden = torch.relu(self.linear(pooled))
        return self.scorer(hidden)

In [None]:
class DistilRunner(dl.SupervisedRunner):
    """Runner for distillation: evaluates the teacher and the student on every batch.

    ``self.model`` is expected to be a mapping with ``"teacher"`` and
    ``"student"`` entries. The teacher is called with the configured input
    keys, the student with ``input_ids`` only.
    """

    def handle_batch(self, batch: Dict[str, torch.Tensor]):
        """Store teacher logits (``"t_logits"``) and student logits (``"s_logits"``) in ``self.batch``."""
        teacher, student = self.model["teacher"], self.model["student"]

        teacher.eval()  # manually set teacher model to eval mode
        teacher_inputs = {key: batch[key] for key in self._input_key}
        # The teacher only provides targets for the student, so no gradients are needed.
        with torch.no_grad():
            t_logits = teacher(**teacher_inputs, return_dict=True)["logits"]

        s_logits = student(batch["input_ids"])
        self.batch = {**self.batch, "t_logits": t_logits, "s_logits": s_logits}

In [None]:
from catalyst.core.callback import CallbackOrder

class KLDivLossCallback(dl.Callback):
    """Compute the distillation KL-divergence between student and teacher logits.

    The per-batch value is written to ``runner.batch_metrics["kl_div_loss"]``
    and its running mean over the loader to
    ``runner.loader_metrics["kl_div_loss"]``.

    Args:
        temperature: softmax temperature applied to both sets of logits;
            must be positive.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """

    def __init__(self, temperature: float = 1.0):
        super().__init__(order=CallbackOrder.metric)
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.temperature = temperature
        self._criterion = nn.KLDivLoss(reduction="batchmean")
        self._metric = metrics.AdditiveValueMetric()

    def on_loader_start(self, runner):
        """Reset the running mean at the start of every loader."""
        self._metric.reset()

    def on_batch_end(self, runner):
        """Compute the KL loss for the current batch and accumulate it."""
        s_logits, t_logits = runner.batch["s_logits"], runner.batch["t_logits"]
        # nn.KLDivLoss takes log-probabilities as input and (by default)
        # probabilities as target, hence log_softmax for the student and
        # softmax for the teacher.
        student_log_probs = F.log_softmax(s_logits / self.temperature, dim=-1)
        teacher_probs = F.softmax(t_logits / self.temperature, dim=-1)
        # The loss is scaled by temperature ** 2.
        loss_kl = self._criterion(student_log_probs, teacher_probs) * self.temperature ** 2

        runner.batch_metrics["kl_div_loss"] = loss_kl
        # Weight the running mean by the number of samples in the batch.
        self._metric.update(loss_kl.item(), len(s_logits))

    def on_loader_end(self, runner):
        """Publish the loader-level mean of the KL loss."""
        mean, _ = self._metric.compute()
        mean = torch.tensor(mean, device=runner.device)
        runner.loader_metrics["kl_div_loss"] = mean

In [None]:
runner = DistilRunner(input_key=["input_ids", "attention_mask"])

from datetime import datetime
from pathlib import Path

# Only the student is trained here, so the optimizer must be built over the
# student's parameters. An optimizer created for the teacher would never
# update the student, and the distillation would not learn anything.
student = StudentModel()
student_optimizer = optim.Adam(student.parameters(), lr=3e-4)

logdir = Path("logs") / datetime.now().strftime("%Y%m%d-%H%M%S")
runner.train(
    model={"teacher": model, "student": student},
    optimizer=student_optimizer,
    criterion=criterion,
    loaders=loaders,
    logdir=logdir,
    num_epochs=3,
    verbose=True,
    callbacks=[
        # Accuracy of both models, reported under separate prefixes.
        dl.AccuracyCallback(input_key="t_logits", target_key="targets", num_classes=2, prefix="teacher_"),
        dl.AccuracyCallback(input_key="s_logits", target_key="targets", num_classes=2, prefix="student_"),
        # Supervised loss of the student on the true labels.
        dl.CriterionCallback(input_key="s_logits", target_key="targets", metric_key="cls_loss"),
        # Distillation loss between student and teacher logits.
        KLDivLossCallback(),
        # Final training loss is the mean of the two losses above.
        dl.MetricAggregationCallback(prefix="loss", metrics=["kl_div_loss", "cls_loss"], mode="mean"),
        dl.OptimizerCallback(metric_key="loss", model_key="student"),
        dl.CheckpointCallback(logdir=logdir, loader_key="valid", metric_key="loss", minimize=True, save_n_best=3),
        dl.CheckRunCallback(),
    ],
)

In [None]:
# Q: what about model post processing?

In [None]:
# Token ids of one validation batch, used as example input for the utilities below.
features_batch = next(iter(loaders["valid"]))["input_ids"]
# From here on `model` refers to the distilled student, no longer to the BERT teacher.
model = runner.model["student"]

In [None]:
# Show the example batch of token ids.
features_batch

In [None]:
# model tracing
# The student is moved to CPU and traced with the example batch.
# NOTE(review): the return value is not assigned — presumably the notebook only displays it; confirm.
utils.trace_model(model=model.cpu(), batch=features_batch)

In [None]:
# model quantization
# NOTE(review): the return value is not assigned — if the utility returns a new model
# rather than modifying `model` in place, the quantized model is only displayed; confirm.
utils.quantize_model(model=model)

In [None]:
# model pruning
# Uses the "l1_unstructured" pruning function with amount=0.8.
# NOTE(review): presumably this prunes 80% of the weights with the smallest magnitude — confirm.
utils.prune_model(model=model, pruning_fn="l1_unstructured", amount=0.8)

In [None]:
# onnx export
# utils.onnx_export(model=model, batch=features_batch, file="./logs/mnist.onnx", verbose=True)