In [None]:
from typing import NamedTuple, Tuple, Dict, List, Union
import sys
import math
import torch
import random
import pandas as pd
from tqdm import tqdm


DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# Seed every RNG the notebook draws from (stdlib `random` for picking samples,
# torch for weight init / dropout / shuffling) so Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

<center><img src="media/rnn-transformer-meme.png" style="width: 500px;"/></center>

## Что же такое этот Transformer?
Transformer — модель, полностью основанная на механизме Attention (без рекуррентных и свёрточных слоёв).
<table>
    <td> <img src="media/transformer-architecture.png" style="width: 500px;"/> </td>
    <td> <img src="media/dot-product-multi-head.png" style="width: 600px;"/> </td>
</table>
$B$ - batch size, $T$ - sequence length, $E$ - embedding size
$$
\begin{eqnarray}
    &Input &&\sim &&B \times T \times E \\
    &Q = InputW^{Q} &&\sim &&B \times T \times QK_{dim} \\
    &K = InputW^{K} &&\sim &&B \times T \times QK_{dim} \\
    &V = InputW^{V} &&\sim &&B \times T \times V_{dim} \\
    &QK^{T} &&\sim &&B \times T \times T
\end{eqnarray}
$$
$$
\begin{equation}
    Self \hbox{-} Attention(Q,\,K,\,V) = Softmax\biggl(\frac{QK^{T}}{\sqrt{QK_{dim}}}\biggr)V
\end{equation}
$$
Где $W^{Q} \in R^{E \times QK_{dim}},\, W^{K} \in R^{E \times QK_{dim}},\, W^{V} \in R^{E \times V_{dim}}$

В качестве Norm используют [LayerNorm (Ba, Kiros, Hinton)](https://arxiv.org/abs/1607.06450), так как он значительно лучше работает для NLP задач и не зависит от размера батча, благодаря чему можно делать онлайн-обучение. Кроме того, в отличие от BatchNorm, для LayerNorm не нужно хранить скользящее среднее для expectation и variance.
Теперь нарисуем модель (например, на доске) и напишем псевдокод для каждой её части.

## Преимущества Transformer
<center><img src="media/lstm-schmidhuber-meme.jpg" style="width: 500px;"/></center>

1. Быстрее.
2. Проще распараллелить (все позиции последовательности обрабатываются одновременно).
3. Лучше работает с длинными последовательностями (длина пути между любыми двумя токенами равна `O(1)`, в то время как у RNN — `O(n)`).
4. Интерпретируемость.

## Напишем свой MultiHead Self-Attention

In [None]:
class MultiHeadAttentionOutput(NamedTuple):
    """Result of MultiHeadAttention Module call."""

    # values ~ (batch size, query length, output size)
    values: torch.Tensor
    # attention ~ (batch size, num heads, query length, key length): softmax weights (before dropout)
    attention: torch.Tensor


class MultiHeadAttention(torch.nn.Module):
    """
    Compute Multi-Head Attention like in "Attention Is All You Need" paper.
    For simplicity assume that value dimension equals query and key dimension.

    Parameters
    ----------
    num_heads : `int`, required
        Number of heads for Self-Attention. Must divide `hidden_size`.
    hidden_size : `int`, required
        Hidden size for projection in Self-Attention.
    bias : `bool`, optional (default = `True`)
        Whether to include bias for the query/key/value projections or not.
    dropout : `float`, optional (default = `0.1`)
        Dropout probability for Self-Attention after softmax.
    output_size : `int`, optional (default = `None`)
        Size for output projection. If None hidden size is used.
    attention_fill_value : `float`, optional (default = `-1e32`)
        Value written into the attention scores of masked (padding) keys before softmax
        when a mask is passed to forward. It must be a large negative number so that
        masked keys get ~zero probability. A finite value (instead of -inf) keeps a
        fully masked row from turning into NaN.
    """

    def __init__(
        self,
        num_heads: int,
        hidden_size: int,
        bias: bool = True,
        dropout: float = 0.1,
        output_size: int = None,
        attention_fill_value: float = -1e32
    ) -> None:
        super().__init__()
        assert hidden_size % num_heads == 0, (
            f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
        )
        self.attn_size = hidden_size // num_heads
        self.num_heads = num_heads
        self.projections = torch.nn.ModuleDict({
            "query": torch.nn.Linear(hidden_size, hidden_size, bias=bias),
            "key": torch.nn.Linear(hidden_size, hidden_size, bias=bias),
            "value": torch.nn.Linear(hidden_size, hidden_size, bias=bias)
        })
        self.output = torch.nn.Linear(hidden_size, output_size or hidden_size)
        self.attention_fill_value = attention_fill_value
        # torch.nn.Identity (rather than a lambda) keeps the module picklable / torch.save-able
        self.dropout = torch.nn.Dropout(p=dropout) if dropout else torch.nn.Identity()

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = None
    ) -> MultiHeadAttentionOutput:
        """
        Attend from `query` to `key` / `value`.

        Parameters
        ----------
        query : `torch.Tensor` ~ (batch size, query length, hidden size)
        key : `torch.Tensor` ~ (batch size, key length, hidden size)
        value : `torch.Tensor` ~ (batch size, key length, hidden size)
        mask : `torch.Tensor`, optional ~ (batch size, key length)
            Key padding mask: True / non-zero for real tokens, False / zero for padding.

        Returns
        -------
        `MultiHeadAttentionOutput` whose `values` has shape (batch size, query length, output size)
        and whose `attention` has shape (batch size, num heads, query length, key length).
        """
        batch_size, query_length, _ = query.size()
        # 1) Linear projections in batch from
        # (batch size, seq length, hidden size) => (batch size, num heads, seq length, attn size).
        query_heads = self._split_heads(self.projections["query"](query))
        key_heads = self._split_heads(self.projections["key"](key))
        value_heads = self._split_heads(self.projections["value"](value))
        # 2) Apply scaled dot-product attention.
        # scores ~ (batch size, num heads, query length, key length)
        scores = torch.matmul(query_heads, key_heads.transpose(-2, -1)) / math.sqrt(self.attn_size)
        if mask is not None:
            # Broadcast the key padding mask over heads and query positions
            padding = ~mask.bool()[:, None, None, :]
            scores = scores.masked_fill(padding, self.attention_fill_value)
        attention = torch.softmax(scores, dim=-1)
        # output ~ (batch size, num heads, query length, attn size)
        output = torch.matmul(self.dropout(attention), value_heads)
        # 3) Rearrange back to normal.
        # output ~ (batch size, query length, hidden size)
        output = output.transpose(1, 2).contiguous().view(
            batch_size, query_length, self.num_heads * self.attn_size
        )
        return MultiHeadAttentionOutput(self.output(output), attention)

    def _split_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        """Reshape (batch size, length, hidden size) -> (batch size, num heads, length, attn size)."""
        batch_size, length, _ = tensor.size()
        return tensor.view(batch_size, length, self.num_heads, self.attn_size).transpose(1, 2)

Зафиксируем seed и посмотрим, работает ли модуль.

In [None]:
# Fix the seed so that the random input and the module's weight initialisation are reproducible
torch.manual_seed(0)
tensor = torch.randn(10, 15, 256)
module = MultiHeadAttention(num_heads=4, hidden_size=256, dropout=0.0, bias=False)
module(tensor, tensor, tensor).values.size()

Ну а теперь проверим с PyTorch MultiheadAttention

In [None]:
# Public import path of PyTorch's reference implementation
# (the same class as torch.nn.modules.activation.MultiheadAttention)
from torch.nn import MultiheadAttention as MultiheadAttentionTorch
# Instantiate Module
torch_attention = MultiheadAttentionTorch(embed_dim=256, num_heads=4, dropout=0.0, bias=False)
# Change parameters: switch to separate q/k/v projection weights (private flag)
# so that our three projection matrices can be plugged in
torch_attention._qkv_same_embed_dim = False
# Projections
torch_attention.q_proj_weight = module.projections["query"].weight
torch_attention.k_proj_weight = module.projections["key"].weight
torch_attention.v_proj_weight = module.projections["value"].weight
torch_attention.out_proj = module.output
# Permute tensor because PyTorch MultiheadAttention
# accepts (sequence length, batch size, embedding size)
torch_attn_tensor = tensor.permute(1, 0, 2).contiguous()
torch_output = torch_attention(torch_attn_tensor, torch_attn_tensor, torch_attn_tensor)[0].permute(1, 0, 2)
# Two float32 implementations differ by ~1e-7 per element; allclose's default
# atol (1e-8) is too strict for entries close to zero, so use an explicit tolerance.
assert torch.allclose(torch_output, module(tensor, tensor, tensor).values, atol=1e-5)

Сделать модель на основе Transformer в PyTorch.

Загрузим данные для NMT French-English

In [None]:
# Create data dir if not exists
!mkdir -p data
# Load dataset
!wget https://www.manythings.org/anki/fra-eng.zip -O data/fra-eng.zip
# Unarchive dataset
!tar -xvf data/fra-eng.zip --directory data/
# Rename for build_dataset to work
!mv data/fra.txt data/eng-fra.txt

Ниже код для предобработки данных. В него можно не вникать: он взят из [туториала PyTorch](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html).

In [None]:
# `build_dataset` comes from the local `utils` module (not shown in this notebook).
# It presumably turns data/eng-fra.txt into ./data/train.txt and ./data/test.txt,
# which are read below — TODO confirm in utils.py.
from utils import build_dataset


build_dataset()

### Подготовим LabelEncoder и Dataset в стиле PyTorch

In [None]:
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader


SOS = "<start>"
EOS = "<end>"
PAD = "<pad>"


class Vocabulary:
    """
    Two-way label encoder built from one or more datasets.

    For every field ("source", "target") it keeps a token -> index mapping
    (`word2index`) and the inverse index -> token mapping (`index2word`).
    PAD, SOS and EOS get indices 0, 1 and 2 in both fields; every other token
    gets the next free index in order of first appearance.
    """

    def __init__(self, **datasets) -> None:
        self._word2index = defaultdict(dict)
        self._index2word = defaultdict(dict)
        self._setup_indexers()
        for dataset_name in datasets:
            self._iterate_dataset(dataset_name, datasets[dataset_name])

    @property
    def word2index(self):
        """Per-field mapping token -> index."""
        return self._word2index

    @property
    def index2word(self):
        """Per-field mapping index -> token."""
        return self._index2word

    def get_size(self, field: str) -> int:
        """Number of known tokens (special tokens included) for `field`."""
        known_tokens = self._word2index[field]
        return len(known_tokens)

    def _iterate_dataset(self, name: str, dataset: Dataset) -> None:
        progress = tqdm(dataset, desc=f"Building vocab from {name}", total=len(dataset))
        for sample in progress:
            self._iterate_sample(sample)

    def _iterate_sample(self, sample: Dict[str, List[str]]) -> None:
        for field, sentence in sample.items():
            for word in sentence:
                known = self._word2index[field]
                if word in known:
                    continue
                # The right-hand side is evaluated before insertion, so this is the next free index
                known[word] = len(known)
                self._index2word[field][known[word]] = word

    def _setup_indexers(self) -> None:
        for position, special_token in enumerate((PAD, SOS, EOS)):
            for field in ("source", "target"):
                self._word2index[field][special_token] = position
                self._index2word[field][position] = special_token


class NMTDataset(Dataset):
    """
    PyTorch Dataset for the NMT task, read from a tab-separated file.

    Each line holds `source<TAB>target`; sentences are split on whitespace.
    The source gets a trailing EOS, the target is wrapped in SOS ... EOS.
    Samples are token lists until `index_with` is called, index lists afterwards.

    Parameters
    ----------
    data_path : `str`, required
        Path to dataset.
    """

    def __init__(self, data_path: str) -> None:
        self._dataset = []
        self._indexer = None
        with open(data_path, "r", encoding="utf-8") as file:
            for line in tqdm(file, desc=f"Reading dataset at {data_path} in memory"):
                source, target = line.rstrip().split('\t')
                self._dataset.append({
                    "source": [*source.split(), EOS],
                    "target": [SOS, *target.split(), EOS],
                })

    def __getitem__(self, idx: int) -> Dict[str, List[str]]:
        sample = self._dataset[idx]
        if not self._indexer:
            return {field: list(tokens) for field, tokens in sample.items()}
        return {
            field: [self._indexer[field][token] for token in tokens]
            for field, tokens in sample.items()
        }

    def __len__(self) -> int:
        return len(self._dataset)

    def index_with(self, vocab: Vocabulary) -> None:
        """Make `__getitem__` return vocabulary indices instead of raw tokens."""
        self._indexer = vocab.word2index

In [None]:
from torch.nn.utils.rnn import pad_sequence


# DataLoader calls collate_fn at the very end of assembling a single batch,
# which makes it a convenient hook for extra preprocessing:
# here we pad the sequences and build the padding masks.
class CollateBatch:
    """
    Collate Function for DataLoader as Class to perform postprocessing.
    A class keeps the code structured and lets the batch object provide
    extra hooks such as `pin_memory`.

    Parameters
    ----------
    instances : `List[Dict[str, List[int]]]`
        List of samples as dicts from PyTorch Dataset, each with "source" and
        "target" lists of token indices.

    Attributes
    ----------
    source_tokens, target_tokens : `torch.LongTensor` ~ (batch size, max sequence length)
        Token indices, right-padded with 0.
    source_mask, target_mask : `torch.BoolTensor` ~ (batch size, max sequence length)
        True where the token index is non-zero (0 is the padding index).
    """

    def __init__(self, instances: List[Dict[str, List[int]]]) -> None:
        batch = self._form_batch(instances)
        # Pad sentences as each sentence is of different size
        self.source_tokens = self._pad(batch['source'])
        self.target_tokens = self._pad(batch['target'])
        # Construct mask to identify padding
        self.source_mask = self.source_tokens.ne(0)
        self.target_mask = self.target_tokens.ne(0)

    def pin_memory(self):
        """Pin memory for fast data transfer on CUDA."""
        self.__dict__ = {
            prop: tensor.pin_memory()
            for prop, tensor in self.__dict__.items()
        }
        return self

    def to_device(
        self,
        device: Union[str, torch.device],
        **extra_params
    ) -> Dict[str, torch.Tensor]:
        """Helper function to send batch to device and convert it to dict."""
        return {
            prop: tensor.to(device=device, **extra_params)
            for prop, tensor in self.__dict__.items()
        }

    @staticmethod
    def _pad(sequences: List[List[int]]) -> torch.Tensor:
        """Right-pad lists of token indices with 0 into a (batch, max length) int64 tensor."""
        # Build int64 tensors directly: torch.Tensor(x).long() goes through float32,
        # which silently rounds indices above 2**24.
        return pad_sequence(
            [torch.tensor(sequence, dtype=torch.long) for sequence in sequences],
            batch_first=True,
        )

    @staticmethod
    def _form_batch(instances: List[Dict[str, List[int]]]) -> Dict[str, List[List]]:
        """Construct normal batched data as dict from list of dicts."""
        tensor_dict = defaultdict(list)
        for instance in instances:
            for field, tensor in instance.items():
                tensor_dict[field].append(tensor)
        return tensor_dict

In [None]:
# Read both splits into memory (as token lists, not yet indexed)
train_dataset = NMTDataset('./data/train.txt')
validation_dataset = NMTDataset('./data/test.txt')
# The vocabulary has no unknown-token entry, so it is built from both splits:
# every validation token must be indexable.
vocab = Vocabulary(train_dataset=train_dataset, validation_dataset=validation_dataset)
# From now on the datasets yield vocabulary indices instead of tokens
for _split_dataset in (train_dataset, validation_dataset):
    _split_dataset.index_with(vocab)
# The two loaders share every setting except shuffling
_loader_kwargs = dict(batch_size=32, num_workers=1, pin_memory=True, collate_fn=CollateBatch)
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, **_loader_kwargs)
validation_dataloader = DataLoader(dataset=validation_dataset, **_loader_kwargs)

### Напишем модель на основе LSTM для NMT задачи

In [None]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class LSTMEncoderOutput(NamedTuple):
    """Final states of LSTMEncoder, usable directly as the `(h_0, c_0)` tuple of LSTMDecoder."""

    # final_state ~ (1, batch size, hidden size): last hidden state averaged over layers and directions
    final_state: torch.Tensor
    # cell_state ~ (1, batch size, hidden size): last cell state averaged the same way
    cell_state: torch.Tensor


class LSTMEncoder(torch.nn.Module):
    """
    Encode text data with a bidirectional LSTM Module.

    Parameters
    ----------
    input_size : `int`, required
        Hidden size of input tokens.
    hidden_size : `int`, required
        Hidden size for LSTM.
    num_layers : `int`, optional (default = `1`)
        Number of stacked LSTM modules.
    dropout : `float`, optional (default = `0.2`)
        Dropout between stacked LSTM layers; only has an effect when `num_layers > 1`.
    """
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        dropout: float = 0.2
    ) -> None:
        super().__init__()
        self._input_size = input_size
        self._hidden_size = hidden_size
        self._encoder = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            # torch.nn.LSTM applies dropout only between layers and warns when it is
            # non-zero for a single layer, so pass it only when it can take effect
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, source_tokens: torch.Tensor, mask: torch.Tensor) -> LSTMEncoderOutput:
        # source_tokens ~ (batch size, sequence length, input size)
        # mask ~ (batch size, sequence length), non-zero on real tokens
        # 1) Pack sequence to skip processing of padding in tokens tensor
        lengths = mask.long().sum(dim=-1)
        packed_sequence_input = pack_padded_sequence(
            input=source_tokens,
            # A Python list keeps the lengths on the CPU, as pack_padded_sequence requires
            lengths=lengths.tolist(),
            batch_first=True,
            # Let pack_padded do the sorting for ourselves
            enforce_sorted=False,
        )
        # 2) Encode PackedSequence. Only the final states are used, so the
        # per-token outputs are not unpacked.
        # final_hidden_state, cell_state ~ (num layers * 2, batch size, hidden size)
        _, (final_hidden_state, cell_state) = self._encoder(packed_sequence_input)
        # 3) Average over layers and both directions -> (1, batch size, hidden size).
        # NOTE(review): this matches the initial-state shape of a single-layer decoder only.
        return LSTMEncoderOutput(
            final_hidden_state.mean(0, keepdim=True),
            cell_state.mean(0, keepdim=True),
        )

    def get_input_size(self):
        return self._input_size

    def get_output_size(self):
        # Per-token output size of the bidirectional LSTM (hence * 2);
        # note that forward() itself returns averaged final states of size hidden_size
        return self._hidden_size * 2


class LSTMDecoder(torch.nn.Module):
    """
    Decode sequence for NMT task with LSTM Module.

    Parameters
    ----------
    input_size : `int`, required
        Hidden size of input tokens.
    hidden_size : `int`, required
        Hidden size for LSTM.
    num_layers : `int`, optional (default = `1`)
        Number of stacked LSTM modules.
    dropout : `float`, optional (default = `0.2`)
        Dropout between stacked LSTM layers; only has an effect when `num_layers > 1`.
    """
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        dropout: float = 0.2
    ) -> None:
        super().__init__()
        self._input_size = input_size
        self._hidden_size = hidden_size
        self._decoder = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            # See LSTMEncoder: dropout is only meaningful between stacked layers
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
        )

    def forward(
        self,
        target_tokens: torch.Tensor,
        encoder_output: LSTMEncoderOutput,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        # target_tokens ~ (batch size, sequence length, input size)
        # mask ~ (batch size, sequence length), non-zero on real tokens
        # encoder_output is used as the initial (h_0, c_0) of the LSTM
        # 1) Pack sequence to skip processing of padding in tokens tensor
        lengths = mask.long().sum(dim=-1)
        packed_sequence_input = pack_padded_sequence(
            input=target_tokens,
            lengths=lengths.tolist(),
            batch_first=True,
            # Let pack_padded do the sorting for ourselves
            enforce_sorted=False,
        )
        # 2) Decode PackedSequence starting from the encoder's final states
        packed_decoded_sequence = self._decoder(packed_sequence_input, encoder_output)[0]
        # 3) Unpack; padded positions are filled with zeros
        # decoded_sequence ~ (batch size, max length, LSTM hidden size)
        decoded_sequence, _ = pad_packed_sequence(packed_decoded_sequence, batch_first=True)
        return decoded_sequence

    def get_input_size(self):
        return self._input_size

    def get_output_size(self):
        return self._hidden_size

### Напишем модель на основе Transformer для NMT задачи
Формулы для Positional Encoding, где $t$ — номер позиции токена, $i \in \{0, 1, \dots, \frac{d}{2} - 1\}$, а $d$ — размерность входного токена (размерность эмбеддинга)
$$
\begin{gather}
    PE_{t,\,2i} = sin(t/10000^{\frac{2i}{d}}),\\
    PE_{t,\,2i + 1} = cos(t/10000^{\frac{2i}{d}})
\end{gather}
$$
Или в векторном виде, где $c_{i}$ — частота, на которую умножается номер позиции $t$ (для Transformer $c_{i} = 1 / 10000^{\frac{2i}{d}}$)
$$
\begin{equation}
    PE_{t} =
    \begin{bmatrix}
        sin(c_{0}t)\\
        cos(c_{0}t)\\
        .\\
        .\\
        .\\
        sin(c_{\frac{d}{2} - 1}t)\\
        cos(c_{\frac{d}{2} - 1}t)
    \end{bmatrix}
\end{equation}
$$

In [None]:
class TransformerEncoderOutput(NamedTuple):
    """Output of TransformerEncoder, consumed by TransformerDecoder."""

    # memory ~ (batch size, source length, hidden size)
    memory: torch.Tensor
    # mask ~ (batch size, source length), the encoder's input padding mask (True on real tokens)
    mask: torch.Tensor


class PositionalEncoding(torch.nn.Module):
    """
    Sinusoidal positional encoding for Transformer ("Attention Is All You Need").

    Adds PE[t] to the token at position t, where
        PE[t, 2i]     = sin(t / 10000^(2i / d)),
        PE[t, 2i + 1] = cos(t / 10000^(2i / d)).

    Parameters
    ----------
    hidden_size : `int`, required
        Hidden size of positional encoding.
        Must match hidden size of input tokens.
    dropout : `float`, required
        Dropout probability after positional encoding addition.
        If None dropout is not considered.
    max_len : `int`, optional (default = `5000`)
        Maximum sequence length to construct Positional Encoding.
    """

    def __init__(self, hidden_size: int, dropout: float, max_len: int = 5000):
        super().__init__()
        self._dropout = torch.nn.Dropout(p=dropout) if dropout else torch.nn.Identity()
        # positions ~ (max len, 1)
        positions = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        # Compute the frequencies 10000^(-2i/d) once in log space.
        # frequencies ~ (ceil(hidden size / 2),)
        frequencies = torch.exp(
            torch.arange(0, hidden_size, 2, dtype=torch.float) * (-math.log(10000.0) / hidden_size)
        )
        angles = positions * frequencies
        encoding = torch.zeros(max_len, hidden_size)
        encoding[:, 0::2] = torch.sin(angles)
        # For an odd hidden_size there is one cosine column fewer than sine columns
        encoding[:, 1::2] = torch.cos(angles)[:, : hidden_size // 2]
        # encoding ~ (1, max len, hidden size) to broadcast over the batch.
        # Non-persistent buffer: follows .to(device) but is not stored in state_dict.
        self.register_buffer("encoding", encoding.unsqueeze(0), persistent=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens ~ (batch size, sequence length, hidden size)
        sequence_length = tokens.size(1)
        if sequence_length > self.encoding.size(1):
            raise ValueError(
                f"Sequence length {sequence_length} exceeds max_len {self.encoding.size(1)} "
                f"of PositionalEncoding"
            )
        return self._dropout(tokens + self.encoding[:, :sequence_length])


# PyTorch Transformer is somewhat weirdly implemented
# so you better write one yourself but for now let's use PyTorch Transformer
class TransformerEncoder(torch.nn.Module):
    """
    Encode input tokens with Transformer.

    Parameters
    ----------
    input_size : `int`, required
        Hidden size of input tokens.
    num_layers : `int`, required
        Number of stacked Transformers.
    feedforward_hidden_dim : `int`, optional (default = `2048`)
        Hidden size for feedforward layer in Transformer.
    num_attention_heads : `int`, optional (default = `8`)
        Number of attention heads in Transformer.
    dropout_prob : `float`, optional (default = `0.1`)
        Dropout probability for Transformer.
    activation : `str`, optional (default = `"relu"`)
        The activation function of intermediate layer, relu or gelu.
    """

    def __init__(
        self,
        input_size: int,
        num_layers: int,
        feedforward_hidden_dim: int = 2048,
        num_attention_heads: int = 8,
        dropout_prob: float = 0.1,
        activation: str = "relu",
    ) -> None:
        super().__init__()
        layer = torch.nn.TransformerEncoderLayer(
            d_model=input_size,
            nhead=num_attention_heads,
            dim_feedforward=feedforward_hidden_dim,
            dropout=dropout_prob,
            activation=activation,
        )
        self._transformer = torch.nn.TransformerEncoder(layer, num_layers)
        self._positional_encoding = PositionalEncoding(
            hidden_size=input_size,
            dropout=None,
        )
        self._input_size = input_size

    def forward(self, source_tokens: torch.Tensor, mask: torch.Tensor) -> TransformerEncoderOutput:
        # source_tokens ~ (batch size, sequence length, hidden size)
        # mask ~ (batch size, sequence length), True on real tokens
        tokens = self._positional_encoding(source_tokens)
        # The torch transformer here expects (sequence, batch, features)
        # tokens ~ (sequence length, batch size, hidden size)
        tokens = tokens.permute(1, 0, 2)
        # The torch transformer expects True on *padding* positions, hence the inversion.
        tokens_mask = ~mask
        # output ~ (sequence length, batch size, hidden size)
        output = self._transformer(tokens, src_key_padding_mask=tokens_mask)
        # output ~ (batch size, sequence length, hidden size)
        output = output.permute(1, 0, 2)
        return TransformerEncoderOutput(output, mask)

    def get_input_size(self) -> int:
        return self._input_size

    def get_output_size(self) -> int:
        return self._input_size


class TransformerDecoder(torch.nn.Module):
    """
    Decode sequence for NMT task with Transformer.

    Target positions get sinusoidal positional encoding and a causal self-attention
    mask, so position t only attends to target positions <= t.

    Parameters
    ----------
    input_size : `int`, required
        Hidden size of input tokens.
    num_layers : `int`, required
        Number of stacked Transformers.
    feedforward_hidden_dim : `int`, optional (default = `2048`)
        Hidden size for feedforward layer in Transformer.
    num_attention_heads : `int`, optional (default = `8`)
        Number of attention heads in Transformer.
    dropout_prob : `float`, optional (default = `0.1`)
        Dropout probability for Transformer.
    activation : `str`, optional (default = `"relu"`)
        The activation function of intermediate layer, relu or gelu.
    """

    def __init__(
        self,
        input_size: int,
        num_layers: int,
        feedforward_hidden_dim: int = 2048,
        num_attention_heads: int = 8,
        dropout_prob: float = 0.1,
        activation: str = "relu",
    ) -> None:
        super().__init__()
        layer = torch.nn.TransformerDecoderLayer(
            d_model=input_size,
            nhead=num_attention_heads,
            dim_feedforward=feedforward_hidden_dim,
            dropout=dropout_prob,
            activation=activation
        )
        self._transformer = torch.nn.TransformerDecoder(layer, num_layers=num_layers)
        # Attention is permutation invariant, so target tokens need positions as well
        self._positional_encoding = PositionalEncoding(
            hidden_size=input_size,
            dropout=None,
        )
        self._input_size = input_size

    def forward(
        self,
        target_tokens: torch.Tensor,
        encoder_output: TransformerEncoderOutput,
        mask: torch.Tensor
    ) -> torch.Tensor:
        # target_tokens ~ (batch size, sequence length, hidden size)
        # mask ~ (batch size, sequence length), True on real tokens
        target_tokens = self._positional_encoding(target_tokens)
        sequence_length = target_tokens.size(1)
        # The torch transformer here expects (sequence, batch, features)
        # target_tokens ~ (sequence length, batch size, hidden size)
        # memory ~ (source length, batch size, hidden size)
        target_tokens = target_tokens.permute(1, 0, 2)
        memory = encoder_output.memory.permute(1, 0, 2)
        # The torch transformer expects True on *padding* positions, hence the inversion.
        target_mask = ~mask
        memory_mask = ~encoder_output.mask
        # Causal mask: True above the diagonal forbids position t from attending to
        # positions > t. Without it the decoder sees the very tokens it must predict.
        causal_mask = torch.triu(
            torch.ones(sequence_length, sequence_length, dtype=torch.bool, device=target_tokens.device),
            diagonal=1,
        )
        # output ~ (sequence length, batch size, hidden size)
        output = self._transformer(
            target_tokens, memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=target_mask,
            memory_key_padding_mask=memory_mask,
        )
        # output ~ (batch size, sequence length, hidden size)
        output = output.permute(1, 0, 2)
        return output

    def get_input_size(self) -> int:
        return self._input_size

    def get_output_size(self) -> int:
        return self._input_size

### Напишем Loss под нашу задачу

До этого мы использовали стандартные лоссы PyTorch, но теперь нам нужно написать свой. Мы хотим посчитать `CrossEntropyLoss` по тензору логитов из 3 размерностей `(batch size, sequence length, num classes)`, а `torch.nn.CrossEntropyLoss` ожидает классы во второй размерности: `(N, C)` или `(N, C, d_1, ...)`.

Возможно, в примерах в интернете вы видели, как пишут свой лосс через `torch.autograd.Function`. Писать backward руками — хорошая практика и отличная тренировка, однако такой подход лучше всего подходит, когда лосс сложный и в нём есть циклы `for`. Для нашего лосса это не нужно, поэтому сделаем всё через знакомый нам `torch.nn.Module`.

Очевидное решение — просто сделать reshape и объединить `batch size` и `sequence length`, однако в таком случае мы будем считать лосс ещё и по `padding`, что нам совсем не нужно.

Поэтому напишем свой `CrossEntropyLoss`, который не будет учитывать `padding`. Идея в том, чтобы для каждого токена взять log-вероятность правильного класса, а позиции с `padding` обнулить весами из маски. Для выбора значений по индексам очень хорошо подходит операция [torch.gather](https://pytorch.org/docs/stable/generated/torch.gather.html#torch-gather), которая буквально переводится как "собрать".

In [None]:
class SequenceNLLLoss(torch.nn.Module):
    """
    NLL Loss for NMT task that ignores computing loss on padding.

    Each token's negative log-likelihood is multiplied by its weight (0 on padding).
    Per sequence, the weighted sum is divided by the sum of that sequence's weights.

    Parameters
    ----------
    size_average : `bool`, optional (default = `True`)
        If True, return the mean of the per-sequence losses over sequences with a
        non-zero total weight. Otherwise return the per-sequence losses ~ (batch size,).
    """
    def __init__(self, size_average: bool = True) -> None:
        super().__init__()
        self._size_average = size_average

    def forward(
        self,
        logits: torch.Tensor,
        target: torch.Tensor,
        weights: torch.FloatTensor,
    ) -> torch.Tensor:
        # logits ~ (batch size, sequence length, num_classes)
        # target ~ (batch size, sequence length), class indices
        # weights ~ (batch size, sequence length), 0 on padding
        non_batch_dims = tuple(range(1, weights.dim()))
        # weights_batch_sum ~ (batch size,)
        weights_batch_sum = weights.sum(dim=non_batch_dims)
        # reshape (not view): logits / target may be non-contiguous, e.g. a slice target[:, 1:]
        # logits_flat ~ (batch size * sequence length, num_classes)
        logits_flat = logits.reshape(-1, logits.size(-1))
        log_probs_flat = torch.log_softmax(logits_flat, dim=-1)
        # targets_flat ~ (batch size * sequence length, 1)
        targets_flat = target.reshape(-1, 1).long()
        # nll_loss ~ (batch size * sequence length, 1)
        # Gather numbers from log_probs_flat in targets_flat indices
        nll_loss = -torch.gather(log_probs_flat, dim=1, index=targets_flat)
        # nll_loss ~ (batch size, sequence length)
        nll_loss = nll_loss.reshape(target.size()) * weights
        # per_batch_nll_loss ~ (batch size,); the clamp avoids 0 / 0 for empty sequences
        per_batch_nll_loss = nll_loss.sum(non_batch_dims) / torch.clamp(weights_batch_sum, min=1e-13)
        if self._size_average:
            num_non_empty_sequences = torch.clamp(weights_batch_sum.gt(0).float().sum(), min=1e-13)
            return per_batch_nll_loss.sum() / num_non_empty_sequences
        return per_batch_nll_loss

### Напишем в общем виде модель для задачи NMT, которая будет поддерживать общий API моделей LSTM и Transformer описанных выше

In [None]:
class NMTModel(torch.nn.Module):
    """
    Model for NMT task.

    Trained with teacher forcing. The decoder reads the target without its last
    token and is scored against the target without its first (start) token, so
    step t predicts the token that follows the decoder's input at step t.

    Parameters
    ----------
    vocab : `Vocabulary`, required
        Vocabulary with label encoding and label decoding constructed on datasets.
    embeddings : `torch.nn.ModuleDict`, required
        Torch ModuleDict of embeddings. It should have source and target keys.
    encoder : `torch.nn.Module`, required
        Encoder module for input tokens.
    decoder : `torch.nn.Module`, required
        Decoder module for encoder output.
    device : `str`, required
        Device stored on the model; the training loop moves batches to `model.device`.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        embeddings: torch.nn.ModuleDict,
        # Better use separate Interface for encoder and decoder
        # but for now let's keep it simple stupid and use torch.nn.Module
        encoder: torch.nn.Module,
        decoder: torch.nn.Module,
        device: str,
    ) -> None:
        super().__init__()
        self.vocab = vocab
        self.device = device
        self._embeddings = embeddings
        self._encoder = encoder
        self._decoder = decoder
        self._projection = torch.nn.Linear(
            in_features=self._decoder.get_output_size(),
            out_features=vocab.get_size("target")
        )
        self._loss = SequenceNLLLoss()

    def forward(
        self,
        source_tokens: torch.Tensor,
        target_tokens: torch.Tensor,
        source_mask: torch.Tensor,
        target_mask: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Run encoder and decoder and compute the loss.

        Returns a dict with
        - "logits" ~ (batch size, target length - 1, target vocab size),
        - "probs" (softmax of logits), and
        - "prediction" ~ (batch size, target length - 1), aligned with `target_tokens[:, 1:]`,
        - "loss" (scalar).
        """
        # source_tokens ~ (batch size, source length)
        # target_tokens ~ (batch size, target length)
        # source_mask ~ (batch size, source length)
        # target_mask ~ (batch size, target length)
        # Teacher forcing shift. Feeding and scoring the same unshifted tokens would
        # let the decoder simply copy its input.
        decoder_input = target_tokens[:, :-1]
        decoder_mask = target_mask[:, :-1]
        # contiguous(): the loss reshapes its target
        labels = target_tokens[:, 1:].contiguous()
        label_mask = target_mask[:, 1:]
        # 1) Embed source and run encoder
        encoder_output = self._encoder(
            self._embeddings["source"](source_tokens),
            mask=source_mask
        )
        # 2) Embed shifted target and run decoder
        decoder_output = self._decoder(
            self._embeddings["target"](decoder_input),
            encoder_output=encoder_output,
            mask=decoder_mask
        )
        # 3) Project output to vocabulary
        logits = self._projection(decoder_output)
        probs = torch.softmax(logits, dim=-1)
        # 4) Construct output dict
        output_dict = {"logits": logits, "probs": probs, "prediction": torch.argmax(probs, dim=-1)}
        output_dict["loss"] = self._loss(
            logits=logits,
            target=labels,
            weights=label_mask.float(),
        )
        return output_dict

    def decode(self, indices: List[int], field: str) -> List[str]:
        """Decode sequence of `indices` based on `field` in Vocabulary."""
        return [self.vocab.index2word[field][int(idx)] for idx in indices]

### Ниже уже знакомый вам train-evaluation loop

In [None]:
def train_epoch(
    model: torch.nn.Module,
    data_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    return_losses: bool = False,
) -> Union[Dict[str, float], Tuple[Dict[str, float], List[float]]]:
    """
    Train `model` for one pass over `data_loader`.

    Parameters
    ----------
    model : `torch.nn.Module`
        Model with a `device` attribute whose forward takes the batch dict as keyword
        arguments and returns a dict containing a scalar "loss".
    data_loader : `DataLoader`
        Loader yielding batch objects with a `to_device(device, **kwargs)` method.
    optimizer : `torch.optim.Optimizer`
        Optimizer over the model's parameters.
    return_losses : `bool`, optional (default = `False`)
        Also return the list of per-batch losses.

    Returns
    -------
    `{"loss": mean batch loss}` (NaN for an empty loader), plus the per-batch losses
    if `return_losses` is True.
    """
    model.train()
    batch_losses: List[float] = []
    with tqdm(total=len(data_loader), file=sys.stdout) as prbar:
        for batch in data_loader:
            # Move Batch to the model's device
            batch = batch.to_device(model.device, non_blocking=True)
            # Get model results
            output_dict = model(**batch)
            loss = output_dict['loss']
            # Clear gradients *before* backward so stale gradients left by earlier
            # code cannot leak into the first update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() synchronises with the device, so call it once per batch
            loss_value = loss.item()
            prbar.set_description(f"Loss: {loss_value:.4f}")
            prbar.update(1)
            batch_losses.append(loss_value)
    metrics = {"loss": sum(batch_losses) / len(batch_losses) if batch_losses else float("nan")}
    if return_losses:
        return metrics, batch_losses
    return metrics


@torch.no_grad()
def validate(
    model: torch.nn.Module,
    data_loader: DataLoader,
) -> Dict[str, float]:
    """
    Evaluate `model` on `data_loader` without tracking gradients.

    Returns
    -------
    `{"loss": mean batch loss}` (NaN for an empty loader).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with tqdm(total=len(data_loader), file=sys.stdout) as prbar:
        for batch in data_loader:
            # Move Batch to the model's device
            batch = batch.to_device(model.device, non_blocking=True)
            # Get model results
            output_dict = model(**batch)
            loss_value = output_dict['loss'].item()
            prbar.set_description(f"Loss: {loss_value:.4f}")
            prbar.update(1)
            total_loss += loss_value
            num_batches += 1
    metrics = {"loss": total_loss / num_batches if num_batches else float("nan")}
    return metrics


def random_sample_results(
    model: torch.nn.Module,
    batch: "CollateBatch",
) -> None:
    """Print one randomly chosen example from ``batch`` with the model's prediction.

    Args:
        model: Model exposing ``device``, a forward pass returning a dict with a
            ``"prediction"`` entry, and ``decode(indices, field=...)`` that maps
            token indices to strings.
        batch: Batch providing ``to_device(device, non_blocking=...)`` and
            indexable ``"source_tokens"`` / ``"target_tokens"`` entries whose
            first dimension is presumably the batch size — TODO confirm.
    """
    batch = batch.to_device(model.device, non_blocking=True)
    # Inference only: never build an autograd graph, even if called outside
    # of a `torch.no_grad()` block.
    with torch.no_grad():
        result = model(**batch)
    # Pick one row. `random.randrange` excludes its upper bound (unlike
    # `random.randint`, which could return an out-of-range index). The row
    # count is taken from the token entry itself, because `len(batch)` is not
    # guaranteed to be the number of examples.
    idx = random.randrange(len(batch["source_tokens"]))
    # Print results
    print(
        f"#############\n"
        f"RANDOM SAMPLE:\n"
        f"Source sample: {' '.join(model.decode(batch['source_tokens'][idx], field='source'))}\n"
        f"Target sample: {' '.join(model.decode(batch['target_tokens'][idx], field='target'))}\n"
        f"Predict sample: {' '.join(model.decode(result['prediction'][idx], field='target'))}\n"
        f"#############"
    )

In [None]:
class LossInfo(NamedTuple):
    """Loss history returned by `fit` (fields in tuple order)."""

    # Loss of every training batch across all epochs, in the order processed.
    full_train_losses: List[float]
    # Mean training loss of each epoch.
    train_epoch_losses: List[float]
    # Mean validation loss of each epoch.
    eval_epoch_losses: List[float]


# Training hyperparameters shared by the LSTM and Transformer runs below.
EPOCHS: int = 10
LR: float = 1e-3

In [None]:
def fit(
    model: torch.nn.Module,
    epochs: int,
    train_data_loader: DataLoader,
    validation_data_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    show_random: bool = True
) -> LossInfo:
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Each epoch runs `train_epoch` on ``train_data_loader``, then `validate`
    on ``validation_data_loader`` under ``torch.no_grad()``. If
    ``show_random`` is set, it also prints one random example from the first
    validation batch.

    Returns:
        LossInfo with every per-batch training loss, the mean training loss of
        each epoch and the mean validation loss of each epoch.
    """
    history = LossInfo(
        full_train_losses=[],
        train_epoch_losses=[],
        eval_epoch_losses=[],
    )
    for epoch in range(epochs):
        # --- training ---
        print(f"Train Epoch: {epoch}")
        train_metrics, batch_losses = train_epoch(
            model=model,
            data_loader=train_data_loader,
            optimizer=optimizer,
            return_losses=True,
        )
        mean_train_loss = train_metrics['loss']
        print(f"Train step loss: {mean_train_loss:.4f}")
        history.full_train_losses.extend(batch_losses)
        history.train_epoch_losses.append(mean_train_loss)

        # --- validation (no gradients needed) ---
        print(f"Validation Epoch: {epoch}")
        with torch.no_grad():
            mean_eval_loss = validate(
                model=model,
                data_loader=validation_data_loader,
            )['loss']
            print(f"Validation step loss: {mean_eval_loss:.4f}")
            if show_random:
                # A fresh iterator yields the first validation batch.
                random_sample_results(model, next(iter(validation_data_loader)))
        history.eval_epoch_losses.append(mean_eval_loss)
    return history

### Ну что же... Перейдём к обучению моделей

#### Сначала LSTM

In [None]:
# Encoder first.
# input_size is both the source embedding dimension and the encoder's input size.
input_size = 128
encoder_hidden_size = 256
encoder = LSTMEncoder(
    input_size=input_size,
    hidden_size=encoder_hidden_size,
    # Dropout in an LSTM module is only applied between stacked layers, i.e.
    # when num_layers > 1 (assumes LSTMEncoder wraps torch.nn.LSTM — TODO confirm).
    dropout=0.0,
)
# Then the decoder. Its input_size is encoder_hidden_size, which is also the
# dimension of the target embeddings created below.
decoder = LSTMDecoder(
    input_size=encoder_hidden_size,
    hidden_size=encoder_hidden_size,
    # Same caveat as the encoder: LSTM dropout only applies when num_layers > 1.
    dropout=0.0,
)
# Combine everything into a single model.
# padding_idx=0 assumes index 0 is the padding token in both vocabularies
# (TODO confirm against Vocabulary); that embedding row is zero-initialized
# and receives no gradient.
lstm_model = NMTModel(
    vocab=vocab,
    embeddings=torch.nn.ModuleDict({
        "source": torch.nn.Embedding(vocab.get_size("source"), input_size, padding_idx=0),
        "target": torch.nn.Embedding(vocab.get_size("target"), encoder_hidden_size, padding_idx=0),
    }),
    encoder=encoder,
    decoder=decoder,
    device=DEVICE,
).to(device=DEVICE)
# Total number of parameters in the model.
print(f"Количество параметров: {sum(p.numel() for p in lstm_model.parameters())}")

In [None]:
# Keep the LSTM loss history under its own name: the Transformer training
# cell reassigns `loss_info`, which would otherwise discard these curves.
lstm_loss_info = fit(
    model=lstm_model,
    epochs=EPOCHS,
    train_data_loader=train_dataloader,
    validation_data_loader=validation_dataloader,
    optimizer=torch.optim.Adam(lstm_model.parameters(), lr=LR),
)
# Keep `loss_info` bound as before for any cell that reads it.
loss_info = lstm_loss_info

#### Теперь Transformer

In [None]:
# One embedding size shared by the source/target embeddings and the encoder
# input, so the three values cannot drift apart.
transformer_embedding_size = 128
# Encoder first.
encoder = TransformerEncoder(
    input_size=transformer_embedding_size,
    num_layers=1,
    feedforward_hidden_dim=512,
    num_attention_heads=2,
)
# Then the decoder, sized to the encoder's output.
decoder = TransformerDecoder(
    input_size=encoder.get_output_size(),
    num_layers=1,
    feedforward_hidden_dim=512,
    num_attention_heads=2,
)
# Combine everything into a single model.
# padding_idx=0 matches the LSTM model's embeddings (assumes index 0 is the
# padding token in both vocabularies): that row is zero-initialized and
# receives no gradient, so padding does not get a trained embedding.
transformer_model = NMTModel(
    vocab=vocab,
    embeddings=torch.nn.ModuleDict({
        "source": torch.nn.Embedding(vocab.get_size("source"), transformer_embedding_size, padding_idx=0),
        "target": torch.nn.Embedding(vocab.get_size("target"), transformer_embedding_size, padding_idx=0),
    }),
    encoder=encoder,
    decoder=decoder,
    device=DEVICE,
).to(device=DEVICE)
# Total number of parameters in the model.
print(f"Количество параметров: {sum(p.numel() for p in transformer_model.parameters())}")

In [None]:
# Keep the Transformer loss history under its own name so it can be compared
# with the LSTM run instead of replacing it.
transformer_loss_info = fit(
    model=transformer_model,
    epochs=EPOCHS,
    train_data_loader=train_dataloader,
    validation_data_loader=validation_dataloader,
    optimizer=torch.optim.Adam(transformer_model.parameters(), lr=LR),
)
# Keep `loss_info` bound as before for any later cell that reads it.
loss_info = transformer_loss_info

## BERT
BERT - это просто несколько stacked Transformer Encoder слоёв из общей модели Transformer + задача Masked Language Modelling.

**Зачем нужен вообще Masked Language Modelling?**

Языковые модели обычно предобучают на задаче Language Modelling: язык строится последовательно, слово за словом, и модель предсказывает следующее слово по предыдущим. Но BERT — это Transformer Encoder, то есть bidirectional модель: каждый токен видит контекст с обеих сторон. Такая архитектура не подходит для Language Modelling в классической постановке, так как каждое слово будет косвенно видеть само себя в контексте. Поэтому авторы статьи применили задачу Cloze, или же Masked Language Modelling.

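Особенность Masked Language Modelling заключается в том, что мы случайным образом заменяем часть слов на **[MASK]** и пытаемся по контексту предсказать тот токен, который там был изначально. В оригинальной статье для предсказания выбирают 15% токенов: 80% из них заменяют на **[MASK]**, 10% — на случайный токен, а 10% оставляют без изменений.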

## BPE
Кроме того, в BERT активно используется субсловная токенизация в духе Byte Pair Encoding (BPE). С её помощью решают проблему очень большого словаря (вокабуляра) языка: BPE позволяет регулировать его размер. Грубо говоря, BPE-сегментация оставляет самые частотные слова датасета целыми, а редкие разбивает на несколько более коротких токенов (subword units). Это, как можно догадаться, смягчает проблему дисбаланса классов, из-за которой модель почти не предсказывает редкие слова. Притом BPE потенциально позволяет модели научиться распознавать морфологию языка, композицию слов и даже транслитерацию.
* Вот ссылочка на отличную статью на тему BPE: https://arxiv.org/abs/1910.13267
* Для BPE я люблю использовать SentencePiece от Google.

Тут ещё стоит добавить, что некоторые из вас, читая статью от Google про BERT, могли заметить такую вещь как WordPiece — именно её исследователи использовали для субсловной токенизации. WordPiece отличается от классического BPE (его, например, умеет строить SentencePiece) тем, что строит словарь не по частоте пар токенов, а через максимизацию правдоподобия. Теоретически это должно быть эффективнее, однако авторы RoBERTa, перейдя на byte-level BPE, обнаружили лишь небольшую разницу между этими подходами — то есть частотного подхода вполне достаточно.

## GPT
![SegmentLocal](media/gpt-3-example.gif "GPT-3")