#  Машинный перевод с использованием рекуррентных нейронных сетей и механизма внимания

__Автор задач: Блохин Н.В. (NVBlokhin@fa.ru)__

Материалы:
* https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html#the-seq2seq-model
* https://www.adeveloperdiary.com/data-science/deep-learning/nlp/machine-translation-using-attention-with-pytorch/
* http://ethen8181.github.io/machine-learning/deep_learning/seq2seq/2_torch_seq2seq_attention.html
* https://tomekkorbak.com/2020/06/26/implementing-attention-in-pytorch/

## Задачи для совместного разбора

1\. Рассмотрите пример работы `torch.nn.utils.rnn.pad_sequence` для генерации батча.

In [240]:
from torch.utils.data import Dataset, DataLoader
import torch as th


class FakeDset(Dataset):
    def __init__(self):
        self.x = [
            th.LongTensor([2, 27, 705, 2327, 5744, 3]),
            th.LongTensor([2, 7, 29, 240, 5669, 2747, 1479, 3]),
            th.LongTensor([2, 27, 705, 2327, 3]),
            th.LongTensor([2, 7, 29, 240, 5669, 2747, 1479, 7, 29, 240, 5669, 3]),
            th.LongTensor([2, 7, 29, 240, 5669, 2747, 1479, 7, 29, 240, 5669, 3]),
            th.LongTensor([2, 7, 29, 240, 5669, 2747, 1479, 7, 29, 240, 5669, 3]),
            th.LongTensor([2, 7, 29, 240, 5669, 2747, 1479, 7, 29, 240, 5669, 3]),
            th.LongTensor([2, 7, 29, 240, 32, 2747, 1479, 7, 29, 240, 5669, 3]),
            th.LongTensor([2, 7, 29, 240, 5669, 2747, 1479, 7, 29, 240, 5669, 3]),
            th.LongTensor([2, 7, 42, 240, 5669, 2747, 1479, 7, 29, 240, 78, 3]),
            th.LongTensor([2, 7, 29, 899, 5669, 2747, 1479, 7, 29, 240, 5669, 3]),
            th.LongTensor([2, 7, 29, 240, 42, 2747, 1479, 7, 29, 240, 53, 3]),
            th.LongTensor([2, 7, 29, 240, 5669, 42, 1479, 7, 29, 240, 5669, 3]),
            th.LongTensor([2, 7, 29, 240, 5669, 2747, 1479, 7, 42, 240, 5669, 3]),
            th.LongTensor([2, 7, 29, 653, 5669, 2747, 1479, 7, 29, 240, 5669, 3]),
        ]

    def __getitem__(self, idx):
        return (self.x[idx],)

    def __len__(self):
        return len(self.x)


In [241]:
dset = FakeDset()
dset[:]

([tensor([   2,   27,  705, 2327, 5744,    3]),
  tensor([   2,    7,   29,  240, 5669, 2747, 1479,    3]),
  tensor([   2,   27,  705, 2327,    3]),
  tensor([   2,    7,   29,  240, 5669, 2747, 1479,    7,   29,  240, 5669,    3]),
  tensor([   2,    7,   29,  240, 5669, 2747, 1479,    7,   29,  240, 5669,    3]),
  tensor([   2,    7,   29,  240, 5669, 2747, 1479,    7,   29,  240, 5669,    3]),
  tensor([   2,    7,   29,  240, 5669, 2747, 1479,    7,   29,  240, 5669,    3]),
  tensor([   2,    7,   29,  240,   32, 2747, 1479,    7,   29,  240, 5669,    3]),
  tensor([   2,    7,   29,  240, 5669, 2747, 1479,    7,   29,  240, 5669,    3]),
  tensor([   2,    7,   42,  240, 5669, 2747, 1479,    7,   29,  240,   78,    3]),
  tensor([   2,    7,   29,  899, 5669, 2747, 1479,    7,   29,  240, 5669,    3]),
  tensor([   2,    7,   29,  240,   42, 2747, 1479,    7,   29,  240,   53,    3]),
  tensor([   2,    7,   29,  240, 5669,   42, 1479,    7,   29,  240, 5669,    3]),
  tensor([

In [242]:
[x[0] for x in dset]

[tensor([   2,   27,  705, 2327, 5744,    3]),
 tensor([   2,    7,   29,  240, 5669, 2747, 1479,    3]),
 tensor([   2,   27,  705, 2327,    3]),
 tensor([   2,    7,   29,  240, 5669, 2747, 1479,    7,   29,  240, 5669,    3]),
 tensor([   2,    7,   29,  240, 5669, 2747, 1479,    7,   29,  240, 5669,    3]),
 tensor([   2,    7,   29,  240, 5669, 2747, 1479,    7,   29,  240, 5669,    3]),
 tensor([   2,    7,   29,  240, 5669, 2747, 1479,    7,   29,  240, 5669,    3]),
 tensor([   2,    7,   29,  240,   32, 2747, 1479,    7,   29,  240, 5669,    3]),
 tensor([   2,    7,   29,  240, 5669, 2747, 1479,    7,   29,  240, 5669,    3]),
 tensor([   2,    7,   42,  240, 5669, 2747, 1479,    7,   29,  240,   78,    3]),
 tensor([   2,    7,   29,  899, 5669, 2747, 1479,    7,   29,  240, 5669,    3]),
 tensor([   2,    7,   29,  240,   42, 2747, 1479,    7,   29,  240,   53,    3]),
 tensor([   2,    7,   29,  240, 5669,   42, 1479,    7,   29,  240, 5669,    3]),
 tensor([   2,    7,   

In [243]:
from torch.nn.utils.rnn import pad_sequence

In [244]:
class Collator:
  def __init__(self, pad_idx, batch_first=True):
    self.pad_idx = pad_idx
    self.batch_first = batch_first

  def __call__(self, batch):
    x = [x[0] for x in batch]
    # print(batch)
    # print(x)
    x = pad_sequence(
        x,
        batch_first=self.batch_first,
        padding_value=self.pad_idx
    )
    return x

In [245]:
dset

<__main__.FakeDset at 0x7a980bdbf2e0>

In [246]:
loader = DataLoader(
    dset,
    batch_size=4,
    collate_fn=Collator(batch_first=True, pad_idx=0)
)

In [247]:
next(iter(loader))

tensor([[   2,   27,  705, 2327, 5744,    3,    0,    0,    0,    0,    0,    0],
        [   2,    7,   29,  240, 5669, 2747, 1479,    3,    0,    0,    0,    0],
        [   2,   27,  705, 2327,    3,    0,    0,    0,    0,    0,    0,    0],
        [   2,    7,   29,  240, 5669, 2747, 1479,    7,   29,  240, 5669,    3]])

In [248]:
for x in loader:
  print(x.shape)

torch.Size([4, 12])
torch.Size([4, 12])
torch.Size([4, 12])
torch.Size([3, 12])


2\. Рассмотрите основные шаги для реализации механизма аддитивного внимания с использованием RNN.

![attention](https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcS2F23YEOESycSImB4FBnFS09C_R18ffuD9luGO4X4cpF4Unqon5-l6fuWwsvottjU_Aj8&usqp=CAU)

$$c_i = \sum_{j=1}^{T_x}\alpha_{ij}h_j$$

$$\alpha_{ij} = \frac{exp(e_{ij})}{\sum_{j=1}^{T_x}exp(e_{ik})}$$

$$ e_{ij} = a(s_{i-1},h_j) = v_a^Ttanh(W_a s_{i-1} + U_a h_j)$$

In [249]:
import torch as th

In [250]:
batch_size, input_seq_len = 16, 5
encoder_hidden_dim = 100

# X: batch_size x input_seq_len
# encoder(X)
encoder_outputs = th.rand(batch_size, input_seq_len, encoder_hidden_dim)
encoder_final_hidden = encoder_outputs[:, -1, :].unsqueeze(0)
# encoder_outputs: batch_size, input_seq_len x encoder_hidden_dim
# encoder_final_hidden: 1 x batch_size x encoder_hidden_dim

In [251]:
encoder_outputs.shape, encoder_final_hidden.shape   # выходной слой после encoder (out, h)

(torch.Size([16, 5, 100]), torch.Size([1, 16, 100]))

In [252]:
import torch.nn as nn

decoder_hidden_dim = 100

decoder_input = th.rand(batch_size, encoder_hidden_dim)  # 16 x 100
decoder_hidden = encoder_final_hidden[0]   # h, выход encoder-- входной слой для decoder
# print(decoder_hidden)
rnn = nn.GRUCell(encoder_hidden_dim, decoder_hidden_dim)
rnn(decoder_input, decoder_hidden).shape

torch.Size([16, 100])

In [253]:
decoder_hidden = encoder_final_hidden
print(decoder_hidden.shape)
decoder_hidden = (
    decoder_hidden
    .repeat(input_seq_len, 1, 1)
    .permute(1, 0, 2)
) # batch x input_seq_len x encoder_hidden
decoder_hidden.shape

torch.Size([1, 16, 100])


torch.Size([16, 5, 100])

In [254]:
encoder_with_hidden = th.cat(
    (decoder_hidden, encoder_outputs),
    dim=2
)   # сложение размерностей 2
encoder_with_hidden.shape

torch.Size([16, 5, 200])

In [255]:
fc1 = nn.Linear(200, decoder_hidden_dim)

attn_hidden = fc1(encoder_with_hidden).tanh()
# attn_hidden : batch x input_seq_len x decoder_hidden_dim

fc2 = nn.Linear(decoder_hidden_dim, 1)
energy = fc2(attn_hidden).squeeze(2)
print(energy.shape)
# energy: batch x input_seq x 1
attn = energy.softmax(dim=1)
# attn : batch x input_seq

torch.Size([16, 5])


In [256]:
# encoder_outputs: batch_size x input_seq_len x encoder_hidden_dim
attn = attn.unsqueeze(1)
# attn : batch_size x 1 x input_seq_len
weighted_encoder_outputs = th.bmm(attn, encoder_outputs).squeeze(1)
weighted_encoder_outputs.shape

torch.Size([16, 100])

In [257]:
# возвращаемся сюда
decoder_input = th.rand(batch_size, encoder_hidden_dim)

rnn_input = th.cat(
    (decoder_input, weighted_encoder_outputs),
    dim=1
)
rnn_input.shape

torch.Size([16, 200])

In [258]:
decoder_hidden = encoder_final_hidden[0]
rnn = nn.GRUCell(200, decoder_hidden_dim)
decoder_hidden = rnn(rnn_input, decoder_hidden)

In [259]:
n_en_words = 1337
fc_out = nn.Linear(100, n_en_words)

logits = fc_out(decoder_hidden)
logits.shape

torch.Size([16, 1337])

In [260]:
logits.argmax(dim=1).shape
# идем наверх

torch.Size([16])

## Задачи для самостоятельного решения

<p class="task" id="1"></p>

1\. Создайте наборы данных для решения задачи машинного перевода на основе файлов `RuBQ_2.0_train.json` (обучающее множество) и `RuBQ_2.0_test.json`. При подготовке набора данных не приводите весь набор данных к одинаковой фиксированной длине. Реализуйте класс `Collator`, который приводит все примеры в батче к одной фиксированной длине, используя `torch.nn.utils.rnn.pad_sequence`. Создайте `DataLoader` с использованием `collate_fn`, получите батч и выведите на экран размер тензоров.

- [ ] Проверено на семинаре

In [261]:
import pandas as pd

test = pd.read_json('/content/drive/MyDrive/пм21_финашка/3 курс/NLP/04_rnn/RuBQ_2.0_test.json')
train = pd.read_json('/content/drive/MyDrive/пм21_финашка/3 курс/NLP/04_rnn/RuBQ_2.0_train.json')
train.head()

Unnamed: 0,uid,question_text,query,answer_text,question_uris,question_props,answers,paragraphs_uids,tags,RuBQ_version,question_eng
0,0,Что может вызвать цунами?,SELECT ?answer \nWHERE {\n wd:Q8070 wdt:P828 ...,Землетрясение,[http://www.wikidata.org/entity/Q8070],[wdt:P828],"[{'type': 'uri', 'label': 'землетрясение', 'va...","{'with_answer': [35622], 'all_related': [35622...",[1-hop],1,What can cause a tsunami?
1,1,Кто написал роман «Хижина дяди Тома»?,SELECT ?answer \nWHERE {\n wd:Q2222 wdt:P50 ?...,Г. Бичер-Стоу,[http://www.wikidata.org/entity/Q2222],[wdt:P50],"[{'type': 'uri', 'label': 'Гарриет Бичер-Стоу'...","{'with_answer': [35652], 'all_related': [35652...",[1-hop],1,"Who wrote the novel ""uncle Tom's Cabin""?"
2,2,Кто автор пьесы «Ромео и Джульетта»?,SELECT ?answer \nWHERE {\n wd:Q83186 wdt:P50 ...,Шекспир,[http://www.wikidata.org/entity/Q83186],[wdt:P50],"[{'type': 'uri', 'label': 'Уильям Шекспир', 'v...","{'with_answer': [35676, 35677], 'all_related':...",[1-hop],1,"Who is the author of the play ""Romeo and Juliet""?"
3,3,Как называется столица Румынии?,SELECT ?answer \nWHERE {\n wd:Q218 wdt:P36 ?a...,Бухарест,[http://www.wikidata.org/entity/Q218],[wdt:P36],"[{'type': 'uri', 'label': 'Бухарест', 'value':...","{'with_answer': [35702, 35703], 'all_related':...",[1-hop],1,What is the name of the capital of Romania?
4,5,На каком инструменте играл Джимми Хендрикс?,SELECT ?answer \nWHERE {\n wd:Q5928 wdt:P1303...,Гитара,[http://www.wikidata.org/entity/Q5928],[wdt:P1303],"[{'type': 'uri', 'label': 'гитара', 'value': '...","{'with_answer': [35728, 35727], 'all_related':...",[1-hop],1,What instrument did Jimi Hendrix play?


In [262]:
train_ru = train.question_text.values
train_en = train.question_eng.values

test_ru = test.question_text.values
test_en = test.question_eng.values
train_ru

array(['Что может вызвать цунами?',
       'Кто написал роман «Хижина дяди Тома»?',
       'Кто автор пьесы «Ромео и Джульетта»?', ...,
       'В каком году сняли с производства Jaguar E-type?',
       'С какого года закончили выпускать Rolls-Royce Silver Ghost?',
       'В каком году сняли с производства Porsche 550?'], dtype=object)

In [263]:
from torchtext.vocab import Vocab
from torchtext.vocab import build_vocab_from_iterator

import nltk
from nltk.tokenize import word_tokenize
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [264]:
import re
pattern_ru = re.compile(r'[^А-Яа-яёЁ]+')
pattern_en = re.compile(r'[^A-Za-z]+')

def del_punctuation(df, pattern):
  res = []
  for i in df:
    stroka = re.sub(pattern,' ', i)
    res.append(word_tokenize(stroka))

  return res

train_ru = del_punctuation(train_ru, pattern_ru)
train_en = del_punctuation(train_en, pattern_en)

test_ru = del_punctuation(test_ru, pattern_ru)
test_en = del_punctuation(test_en, pattern_en)
test_en[:10]

[['Which',
  'country',
  'does',
  'the',
  'famous',
  'Easter',
  'island',
  'belong',
  'to'],
 ['Which',
  'music',
  'group',
  'is',
  'Mick',
  'Jagger',
  's',
  'name',
  'inextricably',
  'linked',
  'to'],
 ['Where', 'is', 'the', 'Summer', 'garden'],
 ['Which', 'city', 'is', 'the', 'capital', 'of', 'Turkmenistan'],
 ['In',
  'which',
  'city',
  'was',
  'the',
  'first',
  'Russian',
  'revolutionary',
  'newspaper',
  'Kolokol',
  'published',
  'since',
  'by',
  'A',
  'Herzen',
  'and',
  'N',
  'Ogarev'],
 ['Which',
  'country',
  'has',
  'the',
  'second',
  'highest',
  'active',
  'volcano',
  'with',
  'the',
  'funny',
  'name',
  'Popocatepetl',
  'm'],
 ['In', 'what', 'sport', 'was', 'Kournikova', 'famous'],
 ['What', 'city', 'was', 'Wolfgang', 'Amadeus', 'Mozart', 'born', 'in'],
 ['What', 'city', 'was', 'John', 'Lennon', 'killed', 'in'],
 ['What',
  'country',
  'was',
  'the',
  'inventor',
  'of',
  'the',
  'Morse',
  'code',
  'Telegraph',
  'a',
  'citi

In [265]:
token = ['<PAD>', '<SOS>', '<EOS>']

vocab_ru = build_vocab_from_iterator(
    train_ru,
    specials = token
)

vocab_en = build_vocab_from_iterator(
    train_en,
    specials = token
)

len(vocab_ru), len(vocab_en)

(5959, 4348)

In [266]:
import torchtext.transforms as T

In [267]:
class Dataset:

  def __init__(self, x, y, vocab_ru, vocab_en):

    self.x = x
    self.y = y
    self.vocab_ru = vocab_ru
    self.vocab_en = vocab_en
    self.transform_ru = self.transforms(self.vocab_ru)
    self.transform_en = self.transforms(self.vocab_en)

  def transforms(self, vocab):
    transf = T.Sequential(
        T.VocabTransform(vocab),
        T.AddToken(begin=True, token= vocab.get_stoi()['<SOS>']),
        T.AddToken(begin=False, token= vocab.get_stoi()['<EOS>']),
        T.ToTensor()
    )
    return transf

  def __getitem__(self, idx):

    if type(idx) == int:
      idx = slice(idx, idx+1)

    ru = self.transform_ru(self.x[idx]).squeeze(0)
    en = self.transform_en(self.y[idx]).squeeze(0)

    return ru, en

  def __len__(self):
    return len(self.x)



In [268]:
ds_train = Dataset(train_ru, train_en, vocab_ru, vocab_en)
ds_train

<__main__.Dataset at 0x7a980bc35d80>

In [269]:
class Collator:
  def __init__(self, pad_idx, batch_first=True):
    self.pad_idx = pad_idx
    self.batch_first = batch_first

  def __call__(self, batch):
    x = [x[0] for x in batch]
    y = [y[1] for y in batch]

    x = pad_sequence(
        x,
        batch_first=self.batch_first,
        padding_value=self.pad_idx
    )

    y = pad_sequence(
        y,
        batch_first=self.batch_first,
        padding_value=self.pad_idx
    )
    return x, y

In [270]:
loader_train = DataLoader(ds_train, batch_size=128, collate_fn=Collator(pad_idx=0, batch_first=True), drop_last= True)
loader_train

<torch.utils.data.dataloader.DataLoader at 0x7a97edf48700>

In [271]:
for x, y in loader_train:
    print(x.shape, y.shape)
    # break

torch.Size([128, 22]) torch.Size([128, 31])
torch.Size([128, 19]) torch.Size([128, 25])
torch.Size([128, 19]) torch.Size([128, 23])
torch.Size([128, 21]) torch.Size([128, 25])
torch.Size([128, 21]) torch.Size([128, 32])
torch.Size([128, 19]) torch.Size([128, 23])
torch.Size([128, 22]) torch.Size([128, 26])
torch.Size([128, 19]) torch.Size([128, 33])
torch.Size([128, 27]) torch.Size([128, 29])
torch.Size([128, 19]) torch.Size([128, 23])
torch.Size([128, 14]) torch.Size([128, 21])
torch.Size([128, 13]) torch.Size([128, 18])
torch.Size([128, 14]) torch.Size([128, 18])
torch.Size([128, 12]) torch.Size([128, 18])
torch.Size([128, 17]) torch.Size([128, 18])
torch.Size([128, 12]) torch.Size([128, 17])
torch.Size([128, 13]) torch.Size([128, 16])
torch.Size([128, 15]) torch.Size([128, 18])


In [272]:
token = ['<PAD>', '<SOS>', '<EOS>']

vocab_ru_test = build_vocab_from_iterator(
    test_ru,
    specials = token
)

vocab_en_test = build_vocab_from_iterator(
    test_en,
    specials = token
)

len(vocab_ru_test), len(vocab_en_test)

(1973, 1598)

In [273]:
ds_test = Dataset(test_ru, test_en, vocab_ru_test, vocab_en_test)

loader_test = DataLoader(
    ds_test,
    batch_size=216,
    collate_fn=Collator(batch_first=True, pad_idx=0)
)

for x,y in loader_test:
  print(x.shape, y.shape)

torch.Size([216, 19]) torch.Size([216, 27])
torch.Size([216, 19]) torch.Size([216, 33])
torch.Size([148, 15]) torch.Size([148, 18])


<p class="task" id="2"></p>

2\. Создайте и обучите модель машинного перевода, используя архитектуру Encoder-Decoder на основе RNN с использованием механизма аддитивного внимания. Во время обучения выводите на экран значения функции потерь для эпохи (на обучающем множестве), значение accuracy по токенам (на обучающем множестве) и пример перевода, сгенерированного моделью. После завершения обучения посчитайте BLEU для тестового множества.

- [ ] Проверено на семинаре

In [274]:
import torch.nn as nn

In [275]:
class Encoder(nn.Module):
  def __init__(self,  embedding_dim, encoder_hidden_dim):
      super().__init__()
      self.emb = nn.Embedding(
          num_embeddings=len(vocab_ru),
          embedding_dim=embedding_dim,
          padding_idx=0
      )
      self.rnn = nn.GRU(embedding_dim, encoder_hidden_dim, batch_first=True)
      self.fc = nn.Linear(encoder_hidden_dim, decoder_hidden_dim)

  def forward(self, X):
    out = self.emb(X)
    out, h = self.rnn(out)
    hidden = th.tanh(self.fc(h))

    return out, hidden

In [290]:
embedding_dim = 300
encoder_hidden_dim = 42
decoder_hidden_dim = 42

In [301]:
model_enc = Encoder(embedding_dim, encoder_hidden_dim)
encoder_output, decoder_hidden = model_enc(next(iter(loader_train))[0])
encoder_output.shape, decoder_hidden.shape
# batch_size x seq_len x encoder_hidden_dim, 1 x batch_size x decoder_hidden_dim

(torch.Size([128, 22, 42]), torch.Size([1, 128, 42]))

In [302]:
class Attention(nn.Module):
  def __init__(self, encoder_hidden_dim, decoder_hidden_dim):
    super().__init__()
    self.fc1 = nn.Linear(encoder_hidden_dim * 2, decoder_hidden_dim)
    self.fc2 = nn.Linear(decoder_hidden_dim, 1)

  def forward(self, encoder_output):
    encoder_final_hidden = encoder_output[:, -1, :].unsqueeze(0) # 1 x batch_size x encoder_hidden_dim
    decoder_hidden = encoder_final_hidden # 1 x batch_size x encoder_hidden_dim

    decoder_hidden = (
        decoder_hidden
        .repeat(encoder_output.shape[1], 1, 1)
        .permute(1, 0, 2)
    )

    encoder_with_hidden = th.cat(
                    (decoder_hidden, encoder_output),
                    dim=2)

    attn_hidden = self.fc1(encoder_with_hidden).tanh()
    energy = self.fc2(attn_hidden).squeeze(2) # batch x input_seq
    attn = energy.softmax(dim=1) # batch x input_seq_len

    return attn.unsqueeze(1)   # batch x 1 x input_seq_len

In [303]:
model_att = Attention(encoder_hidden_dim,  decoder_hidden_dim)
attn = model_att(encoder_output)
attn.shape
# batch_size x 1 x seq_len

torch.Size([128, 1, 22])

In [304]:
class Decoder(nn.Module):
    def __init__(self, embedding_dim, encoder_hidden_dim, decoder_hidden_dim, attention):
        super().__init__()
        self.emb = nn.Embedding(
            num_embeddings=len(vocab_en),
            embedding_dim=embedding_dim,
            padding_idx=0
        )

        self.rnn = nn.GRUCell(embedding_dim + encoder_hidden_dim, decoder_hidden_dim)
        self.fc = nn.Linear(decoder_hidden_dim, len(vocab_en))
        self.attention = attention

    def forward(self, encoder_output, decoder_hidden, labels):
        # labels: batch x seq_len - считаем, что в 0 столбце SOS
        # encoder_output: 1 x batch x encoder_hidden_size
        seq_len = labels.size(1)
        input_tokens = labels[:, 0]
        attn = self.attention(encoder_output)
        weighted_encoder_outputs = th.bmm(attn, encoder_output).squeeze(1)
        outputs = []

        for _ in range(seq_len):
            out = self.emb(input_tokens)
            rnn_input = th.cat(
                              (out, weighted_encoder_outputs),
                              dim=1
            )
            decoder_hidden = self.rnn(rnn_input, decoder_hidden.squeeze(0))
            out = self.fc(decoder_hidden)

            input_tokens = out.argmax(dim=-1)  # входной токен для следующей итерации
            outputs.append(out)

        # вернуть прогнозы для каждого эл-та последовательности
        # batch x seq x n_en_token
        outputs = th.stack(outputs, dim=1)
        return outputs

In [305]:
model_dec = Decoder(embedding_dim, encoder_hidden_dim, decoder_hidden_dim, attention)
decoder_output = model_dec(encoder_output, decoder_hidden, next(iter(loader_train))[1])
decoder_output.shape
 # batch x seq x n_en_token

torch.Size([128, 31, 4348])

In [306]:
class EncoderDecoder(nn.Module):
  def __init__(self, encoder, decoder):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder

  def forward(self, x, y):
    out, h = self.encoder(x)
    out = self.decoder(out, h, y)
    return out.reshape(out.shape[0]*out.shape[1], out.shape[2])

In [307]:
model = EncoderDecoder(encoder = model_enc, decoder = model_dec)
out = model(next(iter(loader_train))[0], next(iter(loader_train))[1])
out.shape


torch.Size([3968, 4348])

In [318]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr = 0.01)
n_epochs = 10
losses = []

for epoch in range(n_epochs):
  for x_train, y_train in loader_train:

    out = model(x_train, y_train)
    loss = loss = criterion(out, y_train.view(-1))

    loss.backward()
    # optimizer.step()
    # optimizer.zero_grad()


<p class="task" id="3"></p>

3\. Сгенерируйте перевод при помощи обученной модели и визуализируйте матрицу внимания, в которой отображено, на какие слова из исходного предложения модель обращала внимание при генерации очередного слова в переводе.

- [ ] Проверено на семинаре

## Обратная связь
- [ ] Хочу получить обратную связь по решению