# Attention

Привет! Это семинарский ноутбук для курса DL Basic для Тинькофф. В этом ноутбуке мы реализуем внимание, которое мы обсуждали на лекции. Внимание - это один из самых популярных подходов в NLP. В этом ноутбуке мы реализуем его для задачи классификации текстов.

Для начала установим все нужные библиотеки. Если вы используете Google Colab, то просто запустите следующую ячейку. Если вы используете свой компьютер, то установите все библиотеки, перечисленные в следующей ячейке.

In [85]:
#!g1.1
import subprocess
import sys

In [86]:
#!g1.1
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    subprocess.run("pip install datasets nltk gensim transformers einops evaluate", shell=True)
    subprocess.run("python -m nltk.downloader punkt wordnet", shell=True)

In [87]:
#!g1.1
import torch
import nltk
import einops
import evaluate
import math
import copy
import time

from datasets import load_dataset
from tqdm.notebook import tqdm, tnrange

In [88]:
#!g1.1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## SST-2

Вспомним задачу с предыдущего занятия - SST-2. Это задача классификации текстов на два класса: положительный и отрицательный. Для этого мы будем использовать датасет SST-2. Для начала загрузим его, сделаем токенизацию и превращение в эмбеддинги.

In [89]:
#!g1.1
sst2_dataset = load_dataset("sst2")
sst2_dataset

Downloading builder script:   0%|          | 0.00/3.77k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.85k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/5.10k [00:00<?, ?B/s]

Downloading and preparing dataset sst2/default to /tmp/xdg_cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5...


Downloading data:   0%|          | 0.00/7.44M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

  arrays.append(pa.array(typed_sequence))


Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

Dataset sst2 downloaded and prepared to /tmp/xdg_cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 872
    })
    test: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 1821
    })
})

In [90]:
#!g1.1
tokenizer = nltk.WordPunctTokenizer()
lemmatizer = nltk.WordNetLemmatizer()

In [91]:
#!g1.1
nltk.download('wordnet')

[nltk_data] Downloading package wordnet to /home/jupyter/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [92]:
#!g1.1
def tokenize_pipeline(sentence):
    tokens = tokenizer.tokenize(sentence)
    return [lemmatizer.lemmatize(token) for token in tokens if token.isalpha()]

Для работы с эмбеддингами, надо сделать словарь. В нем буду все слова, которые мы знаем и для которых есть эмбеддинг.

In [93]:
#!g1.1
tokenized = (
    [tokenize_pipeline(sentence["sentence"]) for sentence in sst2_dataset["train"]] +
    [tokenize_pipeline(sentence["sentence"]) for sentence in sst2_dataset["validation"]] +
    [tokenize_pipeline(sentence["sentence"]) for sentence in sst2_dataset["test"]]
)

In [94]:
#!g1.1
tokenized[0]

['hide', 'new', 'secretion', 'from', 'the', 'parental', 'unit']

In [95]:
#!g1.1
all_tokenized_words = set(word for words in tokenized for word in words)

In [96]:
#!g1.1
words_to_ids = {word: idx + 16 for idx, word in enumerate(all_tokenized_words)}

In [97]:
#!g1.1
len(all_tokenized_words)

14512

Добавим в наш датасет преврашение в ids:

In [98]:
#!g1.1
class SST2Dataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer, words_to_ids, dataset, max_len=64):
        self.tokenizer = tokenizer
        self.words_to_ids = words_to_ids
        
        def tokenizer_sentece(example):
            return {"tokens": self.tokenizer(example["sentence"])}
        
        def convert_words_to_ids(example):
            return {"ids": [self.words_to_ids[token] for token in example["tokens"]]}

        dataset = dataset.map(tokenizer_sentece)
        
        self.dataset = dataset.map(convert_words_to_ids)
        self.max_len = 64

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        example = self.dataset[index]
        tokens_ids = example["ids"][:self.max_len]
        if len(tokens_ids) < self.max_len:
            tokens_ids += [0 for _ in range(self.max_len - len(tokens_ids))]
        return tokens_ids, example["label"]

In [99]:
#!g1.1
train_dataset = SST2Dataset(tokenize_pipeline, words_to_ids, sst2_dataset["train"])
valid_dataset = SST2Dataset(tokenize_pipeline, words_to_ids, sst2_dataset["validation"])



Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

In [127]:
#!g1.1
def collate_fn(items):
    x = torch.tensor([i[0] for i in items])
    y = torch.tensor([i[1] for i in items])
    return x, y

In [175]:
#!g1.1
BATCH_SIZE = 256

In [130]:
#!g1.1
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    collate_fn=collate_fn,
    num_workers=8
)
valid_dataloader = torch.utils.data.DataLoader(
    valid_dataset, 
    batch_size=BATCH_SIZE, 
    collate_fn=collate_fn,
    num_workers=8
)

In [131]:
#!g1.1
next(iter(train_dataloader))

(tensor([[ 8066,  8746, 10335,  ...,     0,     0,     0],
         [ 1675,  7278,    73,  ...,     0,     0,     0],
         [ 8490, 13824,  2499,  ...,     0,     0,     0],
         ...,
         [ 9512,  1748, 11172,  ...,     0,     0,     0],
         [  928,  7748,  7044,  ...,     0,     0,     0],
         [12694,  9864,  9028,  ...,     0,     0,     0]]),
 tensor([0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1,
         0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0,
         1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
         1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1,
         0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0,
         1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0,
         0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1,
         1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 

## Attention

В этой части мы разберем, как работает механизм внимания и попробуем его написать своими руками. Для этого нам пригодится специальный модуль -- [einops](https://github.com/arogozhnikov/einops)!

Шпаргалка, как работает Attention:

$$
\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{n}})\mathbf{V}
$$

<div>
<img src="https://lilianweng.github.io/posts/2018-06-24-attention/multi-head-attention.png" width="40%">  
</div>

Попробуем вычислить attention на случайном примере. У нас будут 4 головы

In [132]:
#!g1.1
hidden_example = torch.rand((1, 10, 16)) # batch_size, seq_len, hidden_size
hidden_example_headed = hidden_example.reshape(1, 10, 4, 4) # b, t, c, l
q, k, v = hidden_example_headed, hidden_example_headed, hidden_example_headed

Посмотрим, как с помощью `einops` сделать матричное произведение:

In [133]:
#!g1.1
q.shape # b, s, c, l

torch.Size([1, 10, 4, 4])

In [134]:
#!g1.1
k.shape # b, t, c, l

torch.Size([1, 10, 4, 4])

In [135]:
#!g1.1
torch.einsum('bscl,btcl->bstl', [k, q]).shape # 1, 10, 10, 4 (b, s, t, l)

torch.Size([1, 10, 10, 4])

In [136]:
#!g1.1
torch.softmax(torch.einsum('bscl,btcl->bstl', [k, q]), dim=1).shape

torch.Size([1, 10, 10, 4])

Посчитаем внимание и вычислим выход:

In [137]:
#!g1.1
attention = torch.softmax(torch.einsum('bscl,btcl->bstl', [k, q]), dim=1)
result_headed = torch.einsum('bstl,bscl->btcl', [attention, v])
result = result_headed.reshape(1, 10, 16)

In [138]:
#!g1.1
result

tensor([[[0.5128, 0.3561, 0.5493, 0.4876, 0.5267, 0.5736, 0.5587, 0.6351,
          0.7110, 0.5628, 0.6069, 0.4776, 0.5839, 0.5166, 0.6039, 0.5183],
         [0.5502, 0.3253, 0.6234, 0.6324, 0.5322, 0.5694, 0.6101, 0.6259,
          0.7006, 0.5339, 0.5762, 0.5507, 0.5401, 0.4517, 0.5953, 0.6558],
         [0.5696, 0.3107, 0.5928, 0.4874, 0.5834, 0.6037, 0.5990, 0.6125,
          0.7238, 0.5625, 0.5775, 0.4849, 0.5673, 0.4342, 0.5946, 0.5134],
         [0.5044, 0.3339, 0.5901, 0.5852, 0.5628, 0.6086, 0.6307, 0.6587,
          0.7281, 0.5133, 0.6273, 0.4916, 0.6161, 0.4905, 0.5896, 0.6211],
         [0.5478, 0.3389, 0.5761, 0.5700, 0.5145, 0.5832, 0.5947, 0.6117,
          0.7157, 0.5849, 0.5553, 0.5119, 0.5330, 0.4922, 0.6110, 0.5942],
         [0.5001, 0.3420, 0.5876, 0.5376, 0.5130, 0.5963, 0.6045, 0.6703,
          0.6981, 0.5082, 0.6202, 0.4556, 0.5859, 0.4556, 0.6011, 0.5684],
         [0.5129, 0.3206, 0.5531, 0.6008, 0.4935, 0.5991, 0.6103, 0.6180,
          0.6866, 0.5593, 0.5802

Соединим наши операции в одну функцию:

In [139]:
#!g1.1
def attention(K, V, Q, num_head):
    batch_size, seq_len, hidden_dim = Q.size()
    K = K.reshape(batch_size, seq_len, -1, num_head)
    Q = Q.reshape(batch_size, seq_len, -1, num_head)
    V = V.reshape(batch_size, seq_len, -1, num_head)
    attention = torch.softmax(torch.einsum('bscl,btcl->bstl', [K, Q]) / math.sqrt(hidden_dim // num_head), dim=1)
    result_headed = torch.einsum('bstl,btcl->btcl', [attention, V])
    return result_headed.reshape(batch_size, seq_len, hidden_dim)

In [140]:
#!g1.1
hidden_example = torch.rand((1, 10, 16))

attention(hidden_example, hidden_example, hidden_example, 4)

tensor([[[0.3532, 0.3087, 0.8533, 0.3811, 0.3869, 0.9242, 0.4588, 0.0943,
          0.1812, 0.2577, 0.7815, 0.4578, 0.7215, 0.7940, 0.0060, 0.2070],
         [0.1072, 0.0355, 0.0968, 0.8201, 0.3911, 0.2392, 0.4216, 0.7789,
          0.6255, 0.6707, 0.9069, 0.0929, 0.4780, 0.5028, 0.9313, 0.3110],
         [0.1175, 0.8704, 0.8328, 0.7375, 0.2083, 0.7545, 0.3205, 0.6627,
          0.5592, 0.2013, 0.5803, 0.4320, 0.9566, 0.9780, 0.3005, 0.6894],
         [0.0234, 0.6272, 0.7860, 0.0703, 0.7788, 0.5237, 0.7940, 0.9806,
          0.8161, 0.8385, 0.6523, 0.1674, 0.9757, 0.2881, 0.4313, 0.9480],
         [0.2551, 0.9772, 0.2181, 0.4976, 0.5112, 0.2398, 0.0742, 0.4700,
          0.6835, 0.7303, 0.2098, 0.1885, 0.1930, 0.2555, 0.6862, 0.9558],
         [0.5956, 0.7076, 0.3705, 0.7396, 0.2858, 0.9369, 0.6454, 0.6604,
          0.5156, 0.6973, 0.7708, 0.9178, 0.2674, 0.6862, 0.7491, 0.0775],
         [0.6750, 0.7217, 0.9530, 0.4819, 0.9958, 0.2670, 0.0181, 0.7064,
          0.4767, 0.4841, 0.5969

Воспользуемся нашей функцией для создания модуля Attention. Также сделаем Feed-Forward(или MLP) слой и построим всю нашу модель.

In [141]:
#!g1.1
class AttentionModule(torch.nn.Module):
    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()

        self.q_linear = torch.nn.Linear(hidden_dim, hidden_dim)
        self.k_linear = torch.nn.Linear(hidden_dim, hidden_dim)
        self.v_linear = torch.nn.Linear(hidden_dim, hidden_dim)

        self.out_linear = torch.nn.Linear(hidden_dim, hidden_dim)
        self.num_heads = num_heads
    
    def forward(self, hidden_state):
        Q = self.q_linear(hidden_state)
        K = self.k_linear(hidden_state)
        V = self.v_linear(hidden_state)
        attention_output = attention(K, V, Q, self.num_heads)
        return self.out_linear(attention_output) + hidden_state

In [142]:
#!g1.1
class MLP(torch.nn.Module):
    def __init__(self, hidden_dim: int):
        super().__init__()
        
        self.linear_0 = torch.nn.Linear(hidden_dim, 4 * hidden_dim)
        self.linear_1 = torch.nn.Linear(4 * hidden_dim, hidden_dim)
    
    def forward(self, hidden_state):
        return self.linear_1(torch.relu(self.linear_0(hidden_state))) + hidden_state

In [143]:
#!g1.1
class TransformerLayer(torch.nn.Module):
    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()
        
        self.attention_layer = AttentionModule(hidden_dim, num_heads)
        self.mlp_layer = MLP(hidden_dim)
    
    def forward(self, hidden_state):
        attn_output = self.attention_layer(hidden_state)
        mlp_output = self.mlp_layer(attn_output)
        return mlp_output

In [156]:
#!g1.1
class TransformerModel(torch.nn.Module):
    def __init__(self, embedding_size, hidden_dim, num_heads, output_dim, max_seq_len=64):
        super().__init__()
        
        self.word_embedding = torch.nn.Embedding(
            embedding_size, hidden_dim)
        self.pos_embedding = torch.nn.Embedding(
            max_seq_len, hidden_dim)
        
        self.attention_layer_0 = TransformerLayer(hidden_dim, num_heads)
        self.attention_layer_1 = TransformerLayer(hidden_dim, num_heads)
        self.attention_layer_2 = TransformerLayer(hidden_dim, num_heads)
        
        self.cls_head = torch.nn.Linear(hidden_dim, output_dim)
    
    def forward(self, input_ids): # (bs, seq_len)
        input_ids = input_ids
        arange_tensor = torch.arange(input_ids.size(-1)).to(input_ids.device)
      
        word_embs = self.word_embedding(input_ids)
        
        pos_embs = self.pos_embedding(arange_tensor)
        embs = word_embs + pos_embs
        hidden_state = self.attention_layer_0(embs)
        hidden_state = self.attention_layer_1(hidden_state)
        hidden_state = self.attention_layer_2(hidden_state)

        return self.cls_head(hidden_state[:, 0]) # (bs, num_class)

In [157]:
#!g1.1
example = next(iter(train_dataloader))

In [158]:
#!g1.1
example

(tensor([[ 8066,  8746, 10335,  ...,     0,     0,     0],
         [ 1675,  7278,    73,  ...,     0,     0,     0],
         [ 8490, 13824,  2499,  ...,     0,     0,     0],
         ...,
         [ 9512,  1748, 11172,  ...,     0,     0,     0],
         [  928,  7748,  7044,  ...,     0,     0,     0],
         [12694,  9864,  9028,  ...,     0,     0,     0]]),
 tensor([0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1,
         0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0,
         1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
         1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1,
         0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0,
         1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0,
         0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1,
         1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 

In [159]:
#!g1.1
model = TransformerModel(len(all_tokenized_words) + 16, 16, 4, 2)

In [160]:
#!g1.1
model(example[0])

tensor([[-6.3395e-02, -9.7158e-01],
        [-1.6544e-01, -4.7155e-02],
        [-6.5275e-01, -2.2590e-01],
        [ 1.1693e-01, -7.8511e-01],
        [-3.4221e-01, -1.5774e+00],
        [-6.5275e-01, -2.2590e-01],
        [-9.4975e-01, -1.0469e+00],
        [-5.5574e-01,  2.4828e-01],
        [-9.2121e-01,  1.0274e+00],
        [-3.2676e-01, -4.2459e-01],
        [-3.6674e-01,  5.3935e-01],
        [-7.9264e-01, -6.4037e-01],
        [ 2.2017e-02, -9.2330e-01],
        [ 2.7613e-01, -1.7920e+00],
        [-5.6299e-01,  6.9523e-01],
        [ 2.2017e-02, -9.2330e-01],
        [-3.3031e-01, -3.0952e-01],
        [-1.3188e+00, -2.6220e+00],
        [-2.8396e-01, -1.0294e+00],
        [-9.9966e-01, -1.7287e+00],
        [-8.4951e-01,  1.4271e+00],
        [-8.2480e-01, -1.2963e+00],
        [-9.2121e-01,  1.0274e+00],
        [-1.0125e+00, -2.0060e+00],
        [-6.1525e-01,  8.2447e-02],
        [-6.5554e-01, -7.0612e-01],
        [ 4.2182e-01, -1.2997e+00],
        [-6.5053e-01,  5.836

## Training Model!

In [161]:
#!g1.1
device

device(type='cuda')

In [178]:
#!g1.1
model = TransformerModel(len(all_tokenized_words) + 16, 16, 4, 2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

In [179]:
#!g1.1
def train_model(model, criterion, optimizer, num_epochs=10):
    len_train_dataset = len(train_dataset)
    for epoch in tnrange(num_epochs): 
        model.train()

        for batch in tqdm(
            train_dataloader, 
            leave=False, 
            total=len_train_dataset / BATCH_SIZE
        ):
            inputs, labels = batch
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            num_classes = outputs.size(-1)
            loss = criterion(outputs.view(-1, num_classes), labels.view(-1))
            optimizer.step()

    return model

In [180]:
#!g1.1
trained_model = train_model(
    model=model, 
    criterion=criterion,
    optimizer=optimizer, 
    num_epochs=20
)

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

  0%|          | 0/263.08203125 [00:00<?, ?it/s]

In [181]:
#!g1.1
def validate_model(model, criterion):
    current_loss = 0.0 
    current_corrects = 0.0
    best_acc = 0.0
    len_valid_dataloader = len(valid_dataloader)
    since = time.time() 

    with torch.inference_mode():
        model.eval() # Set model to evaluate mode 
        for batch in valid_dataloader:
            inputs, labels = batch
            
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1) 
          
            num_classes = outputs.size(-1)
            current_loss += criterion(outputs.view(-1, num_classes), labels.view(-1))
            current_corrects += torch.sum(preds == labels.data) 

    print('Loss: {:.4f}'.format(current_loss / len_valid_dataloader))
    print('Acc: {:.4f}'.format(current_corrects / len_valid_dataloader / BATCH_SIZE))
    return current_loss

In [182]:
#!g1.1
validate_los = validate_model(model, criterion)

Loss: 0.7881
Acc: 0.4375


## Autoregressive Attention

Попробуем другой способ решения задачи -- заставим модель генерировать токен, который будет означать тот или иной класс. Добавим эти токены и несколько других -- BOS, EOS.

In [183]:
#!g1.1
words_to_ids["[BOS]"] = 1
words_to_ids["[EOS]"] = 2

In [184]:
#!g1.1
class SST2DatasetForSeqCls(torch.utils.data.Dataset):
    def __init__(self, tokenizer, words_to_ids, dataset, max_len=64):
        self.tokenizer = tokenizer
        self.words_to_ids = words_to_ids
        
        def tokenizer_sentece(example):
            return {"tokens": self.tokenizer(example["sentence"])}
        
        def convert_words_to_ids(example):
            # don't forget to add new tokes for classes
            return {"ids": [self.words_to_ids[token] for token in example["tokens"]]}

        dataset = dataset.map(tokenizer_sentece)
        
        self.dataset = dataset.map(convert_words_to_ids)
        self.max_len = max_len

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        tokens = [1] + self.dataset[index]["ids"][:self.max_len - 2] + [2]
        if len(tokens) < self.max_len:
            tokens += [2 for _ in range(self.max_len - len(tokens))]
        return tokens

In [185]:
#!g1.1
def collate_fn(items):
    return torch.tensor(items)

In [211]:
#!g1.1
sst2_dataset["train"]

Dataset({
    features: ['idx', 'sentence', 'label'],
    num_rows: 67349
})

In [186]:
#!g1.1
train_dataset = SST2DatasetForSeqCls(tokenize_pipeline, words_to_ids, sst2_dataset["train"])
valid_dataset = SST2DatasetForSeqCls(tokenize_pipeline, words_to_ids, sst2_dataset["validation"])

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

In [189]:
#!g1.1
train_dataset[0]

[1,
 8066,
 8746,
 10335,
 12718,
 12905,
 10140,
 4109,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2]

In [190]:
#!g1.1
next(iter(train_dataloader))

tensor([[    1,  8066,  8746,  ...,     2,     2,     2],
        [    1,  1675,  7278,  ...,     2,     2,     2],
        [    1,  8490, 13824,  ...,     2,     2,     2],
        ...,
        [    1,  3245,  3805,  ...,     2,     2,     2],
        [    1, 10263,  2029,  ...,     2,     2,     2],
        [    1, 12905,  1965,  ...,     2,     2,     2]])

Также нам надо поменять способ для вычисления внимания -- для генерации текста мы не можем смотреть в будущее, то чего ещё нет. Добавим в вычисление маски!

In [191]:
#!g1.1
# mask t < s
# K bscl
# Q btcl 

In [192]:
#!g1.1
torch.triu(torch.ones((4, 4)))

tensor([[1., 1., 1., 1.],
        [0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.]])

In [193]:
#!g1.1
torch.triu(torch.ones((4, 4))).unsqueeze(0).unsqueeze(-1).shape # (bstl)

torch.Size([1, 4, 4, 1])

In [194]:
#!g1.1
def masked_attention(K, V, Q, num_head):
    batch_size, seq_len, hidden_dim = Q.size()
    K = K.reshape(batch_size, seq_len, -1, num_head)
    Q = Q.reshape(batch_size, seq_len, -1, num_head)
    V = V.reshape(batch_size, seq_len, -1, num_head)
    attention = torch.softmax(torch.einsum('bscl,btcl->bstl', [K, Q]) / math.sqrt(hidden_dim // num_head), dim=1)
    mask = torch.triu(torch.ones((seq_len, seq_len))).unsqueeze(0).unsqueeze(-1)
    masked_attetion = mask * attention
    result_headed = torch.einsum('bstl,bscl->btcl', [masked_attetion, V])
    return result_headed.reshape(batch_size, seq_len, hidden_dim)

# add all other classes

In [195]:
#!g1.1
hidden_example = torch.rand((1, 10, 16))

masked_attention(hidden_example, hidden_example, hidden_example, 4).shape

torch.Size([1, 10, 16])

In [196]:
#!g1.1
hidden_example = torch.eye(4).unsqueeze(0) + 0.01
hidden_example[0, 1]

tensor([0.0100, 1.0100, 0.0100, 0.0100])

In [197]:
#!g1.1
hidden_example[0, 2]

tensor([0.0100, 0.0100, 1.0100, 0.0100])

In [198]:
#!g1.1
output = masked_attention(hidden_example, hidden_example, hidden_example, 4)
output

tensor([[[0.4826, 0.0025, 0.0025, 0.0025],
         [0.2569, 0.4844, 0.0050, 0.0050],
         [0.2594, 0.2594, 0.4861, 0.0075],
         [0.2619, 0.2619, 0.2619, 0.4879]]])

In [199]:
#!g1.1
output[:, 0], output[:, 1]

(tensor([[0.4826, 0.0025, 0.0025, 0.0025]]),
 tensor([[0.2569, 0.4844, 0.0050, 0.0050]]))

In [200]:
#!g1.1
class AutoregressiveAttentionModule(torch.nn.Module):
    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()

        self.q_linear = torch.nn.Linear(hidden_dim, hidden_dim)
        self.k_linear = torch.nn.Linear(hidden_dim, hidden_dim)
        self.v_linear = torch.nn.Linear(hidden_dim, hidden_dim)

        self.out_linear = torch.nn.Linear(hidden_dim, hidden_dim)
        self.num_heads = num_heads
    
    def forward(self, hidden_state):
        Q = self.q_linear(hidden_state)
        K = self.k_linear(hidden_state)
        V = self.v_linear(hidden_state)
        attention_output = masked_attention(K, V, Q, self.num_heads)
        return self.out_linear(attention_output) + hidden_state

In [201]:
#!g1.1
class AutoregressiveTransformerLayer(torch.nn.Module):
    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()
        
        self.attention_layer = AutoregressiveAttentionModule(hidden_dim, num_heads)
        self.mlp_layer = MLP(hidden_dim)
    
    def forward(self, hidden_state):
        attn_output = self.attention_layer(hidden_state)
        mlp_output = self.mlp_layer(attn_output)
        return mlp_output

Заставим нашу модель решать задачу авторегрессивной генерации -- предсказывать токен, которые будет следующим!

In [202]:
#!g1.1
class AutoregressiveTransformerModel(torch.nn.Module):
    def __init__(self, embedding_size, hidden_dim, num_heads, output_dim, max_seq_len=64):
        super().__init__()
        
        self.word_embedding = torch.nn.Embedding(embedding_size, hidden_dim)
        self.pos_embedding = torch.nn.Embedding(max_seq_len, hidden_dim)
        
        self.attention_layer_0 = TransformerLayer(hidden_dim, num_heads)
        self.attention_layer_1 = TransformerLayer(hidden_dim, num_heads)
        self.attention_layer_2 = TransformerLayer(hidden_dim,num_heads)
        
        self.cls_head = torch.nn.Linear(hidden_dim, embedding_size)
        
    def forward(self, input_ids):
        arange_tensor = torch.arange(input_ids.size(-1))

        word_embs = self.word_embedding(input_ids)
        pos_embs = self.pos_embedding(arange_tensor)
        embs = word_embs + pos_embs
        hidden_state = self.attention_layer_0(embs)
        hidden_state = self.attention_layer_1(hidden_state)
        hidden_state = self.attention_layer_2(hidden_state)

        return self.cls_head(hidden_state)

In [203]:
#!g1.1
model = AutoregressiveTransformerModel(len(all_tokenized_words) + 16, 16, 4, 2)

In [204]:
#!g1.1
example = next(iter(train_dataloader))

In [205]:
#!g1.1
model(example).shape

torch.Size([16, 64, 14528])

In [206]:
#!g1.1
model = TransformerModel(len(all_tokenized_words) + 16, 16, 4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

In [208]:
#!g1.1
model.train()

for batch in train_dataloader:
    print(batch)
    break

    inputs, labels = batch

    optimizer.zero_grad()
    outputs = model(inputs)

    num_classes = outputs.size(-1)
    loss = criterion(outputs.view(-1, num_classes), labels.view(-1))
    optimizer.step()

tensor([[    1,  8066,  8746,  ...,     2,     2,     2],
        [    1,  1675,  7278,  ...,     2,     2,     2],
        [    1,  8490, 13824,  ...,     2,     2,     2],
        ...,
        [    1,  3245,  3805,  ...,     2,     2,     2],
        [    1, 10263,  2029,  ...,     2,     2,     2],
        [    1, 12905,  1965,  ...,     2,     2,     2]])


In [209]:
#!g1.1
validate_los = validate_model(model, criterion)

ValueError: too many values to unpack (expected 2)

## `transformers` lib

Попробуем решить нашу задачу с помощью предобученной модели [DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html). Для этого мы будем использовать библиотеку `transformers`.

In [66]:
import transformers

In [67]:
tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

In [68]:
tokenizer("Hi!", return_tensors="pt", padding='max_length', max_length=128)

Подготовим датасет и процесс обучения. Для обучения будем использовать [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer).

In [69]:
class SST2Dataset4DistillBert(torch.utils.data.Dataset):
    def __init__(self, tokenizer, dataset):        
        self.dataset = dataset
        self.tokenizer = tokenizer
        
        # what must be saved?

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        example = self.dataset[index]
        return ... # what must be done?

In [70]:
train_dataset = SST2Dataset4DistillBert(tokenize_pipeline, sst2_dataset["train"])
valid_dataset = SST2Dataset4DistillBert(tokenize_pipeline, sst2_dataset["validation"])

Подготовим метрики для проверки того, что модель обучается.

In [71]:
metric = evaluate.load("accuracy")

In [72]:
import numpy as np

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

И поставим обучение!

In [73]:
!pip install transformers==4.28.0

In [74]:
training_args = transformers.TrainingArguments(output_dir="trainer", evaluation_strategy="epoch")

In [75]:
trainer = transformers.Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    compute_metrics=compute_metrics,
)

In [76]:
trainer.train()

## BLEU

В этом юните мы рассмотрим подсчет BLEU метрики. Это потребуется для вашего домашнего задания.

In [77]:
bleu = evaluate.load("bleu")

In [78]:
references = [
    ["I'm a cat.", "I'm the cat."],
    ["I'm a dog.", "I'm the dog."],
]

In [79]:
predictions = ["I'm a cat.", "I'm the puppy."]

In [80]:
bleu.compute(predictions=predictions, references=references)

In [81]:
# Byte Pair Encoding

In [82]:
# For better NMT transformers:
# - Better Implement Transformers (dropout, Layer Norm, Encoder+Decoder Connection)
# - Teacher Forcing
# - Sequence Sorting (Curriculum Learning)
# - Float16
# - Find papers based on Attention is all Need