In [None]:
import io
import math
import time
from tqdm import tqdm
from collections import Counter
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

import torch
import numpy as np
import torch.nn as nn
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.nn import (TransformerEncoder, TransformerDecoder,
                      TransformerEncoderLayer, TransformerDecoderLayer)

import torchtext
from torchtext.vocab import vocab
from torchtext.data.utils import get_tokenizer
from torchtext.utils import download_from_url, extract_archive

In [None]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
DEVICE

'cpu'

In [None]:
# Location of the gzip-compressed Multi30k task-1 files; each split is a
# (German, English) pair of archive names relative to `url_base`.
url_base = 'https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/'
train_urls = ('train.de.gz', 'train.en.gz')
val_urls = ('val.de.gz', 'val.en.gz')
test_urls = ('test_2016_flickr.de.gz', 'test_2016_flickr.en.gz')


def _download_and_extract(file_names):
    """Download every archive in `file_names` and return the first extracted path of each."""
    extracted_paths = []
    for file_name in file_names:
        archive_path = download_from_url(url_base + file_name)
        extracted_paths.append(extract_archive(archive_path)[0])
    return extracted_paths


train_filepaths = _download_and_extract(train_urls)
val_filepaths = _download_and_extract(val_urls)
test_filepaths = _download_and_extract(test_urls)

In [None]:
!python -m spacy info

2023-12-03 20:25:14.753962: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-03 20:25:14.754041: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-03 20:25:14.754093: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[1m

spaCy version    3.6.1                         
Location         /usr/local/lib/python3.10/dist-packages/spacy
Platform         Linux-5.15.120+-x86_64-with-glibc2.35
Python version   3.10.12                       
Pipelines        de_core_news_sm (3.6.0), en_core_web_sm (3.6.0)



In [None]:
!python -m spacy download de_core_news_sm

2023-12-03 20:25:28.516320: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-03 20:25:28.516393: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-03 20:25:28.516439: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Collecting de-core-news-sm==3.6.0
  Downloading https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.6.0/de_core_news_sm-3.6.0-py3-none-any.whl (14.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.6/14.6 MB[0m [31m77.5 MB/s[0m eta [36m0:00:00[0m
[38;5;2m✔ Download and installation successful[0m
You can no

In [None]:
# spaCy-backed tokenizers, German first and English second.
de_tokenizer, en_tokenizer = (
    get_tokenizer('spacy', language=spacy_model)
    for spacy_model in ('de_core_news_sm', 'en_core_web_sm')
)

def build_vocab(filepath, tokenizer):
    """Build a torchtext vocabulary from the tokens of a UTF-8 text file.

    Args:
        filepath: path of a text file with one sentence per line.
        tokenizer: callable mapping a string to a list of string tokens.

    Returns:
        The object returned by ``torchtext.vocab.vocab`` for the token counts,
        with the specials ``<unk>``, ``<pad>``, ``<bos>``, ``<eos>`` added.
    """
    counter = Counter()
    with io.open(filepath, encoding="utf8") as f:
        for string_ in f:
            # Strip the line terminator exactly as data_process does, so the
            # vocabulary is built from the same token stream that is later
            # converted to ids (no stray newline token gets its own slot).
            counter.update(tokenizer(string_.rstrip("\n")))
    return vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
# German vocabulary: built from the training split only; unknown tokens map to <unk>.
de_vocab = build_vocab(train_filepaths[0], de_tokenizer)
de_vocab.set_default_index(de_vocab['<unk>'])

# English vocabulary, configured the same way.
en_vocab = build_vocab(train_filepaths[1], en_tokenizer)
en_vocab.set_default_index(en_vocab['<unk>'])

In [None]:
# Vocabulary sizes as (German, English); the four special tokens are counted too.
len(de_vocab), len(en_vocab)

(19215, 10838)

In [None]:
def data_process(filepaths):
    """Convert a parallel German/English corpus into pairs of token-id tensors.

    Args:
        filepaths: sequence whose first item is the German file path and whose
            second item is the English file path (one sentence per line,
            UTF-8 encoded).

    Returns:
        A list of ``(de_tensor, en_tensor)`` tuples of ``torch.long`` ids, one
        per aligned line pair. Iteration stops at the end of the shorter file.
    """
    data = []
    # Context managers guarantee both files are closed, even if tokenisation fails.
    with io.open(filepaths[0], encoding="utf8") as de_file, \
            io.open(filepaths[1], encoding="utf8") as en_file:
        for raw_de, raw_en in zip(de_file, en_file):
            de_tensor_ = torch.tensor(
                [de_vocab[token] for token in de_tokenizer(raw_de.rstrip("\n"))],
                dtype=torch.long)
            en_tensor_ = torch.tensor(
                [en_vocab[token] for token in en_tokenizer(raw_en.rstrip("\n"))],
                dtype=torch.long)
            data.append((de_tensor_, en_tensor_))
    return data

# Tensorise every split, in the order train -> validation -> test.
train_data, val_data, test_data = (
    data_process(split_paths)
    for split_paths in (train_filepaths, val_filepaths, test_filepaths)
)

In [None]:
# Inspect one processed validation pair: (German token ids, English token ids).
val_data[1]

(tensor([  24,   31,  633,   11,   32,  105,  718,   36,   32, 2900,   16]),
 tensor([ 26,  33, 546,  18,  22,  47, 698,  38,  22, 688,  14]))

In [None]:
# Look up which English token is stored at vocabulary index 3504.
en_vocab.get_itos()[3504]

'Here'

In [None]:
BATCH_SIZE = 128
# Ids of the special tokens, read from the German vocabulary. They are used for
# both languages below; build_vocab registers the specials in the same order
# for every vocabulary.
PAD_IDX, BOS_IDX, EOS_IDX = (
    de_vocab[special] for special in ('<pad>', '<bos>', '<eos>')
)
print(PAD_IDX, BOS_IDX, EOS_IDX)

1 2 3


In [None]:
def generate_batch(data_batch):
    """Collate ``(de_tensor, en_tensor)`` pairs into two padded batch tensors.

    Every sequence is wrapped as ``<bos> ... <eos>`` and the sequences of each
    language are padded with ``PAD_IDX`` by ``pad_sequence`` (default layout,
    i.e. sequence dimension first).

    Returns:
        ``(de_batch, en_batch)``.
    """
    pairs = list(data_batch)
    bos = torch.tensor([BOS_IDX])
    eos = torch.tensor([EOS_IDX])

    de_sequences = [torch.cat((bos, de_item, eos), dim=0) for de_item, _ in pairs]
    en_sequences = [torch.cat((bos, en_item, eos), dim=0) for _, en_item in pairs]

    return (pad_sequence(de_sequences, padding_value=PAD_IDX),
            pad_sequence(en_sequences, padding_value=PAD_IDX))

# Training and test batches hold BATCH_SIZE sentence pairs each.
train_iter = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=generate_batch,
)
# Validation yields one sentence pair per batch, so those batches contain no padding.
valid_iter = DataLoader(
    val_data,
    batch_size=1,
    shuffle=True,
    collate_fn=generate_batch,
)
test_iter = DataLoader(
    test_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=generate_batch,
)

In [None]:

class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder Transformer for sequence-to-sequence translation.

    Token ids are embedded, combined with a positional encoding, passed through
    a ``TransformerEncoder`` / ``TransformerDecoder`` stack and projected onto
    the target vocabulary by ``generator``.

    Args:
        num_encoder_layers: number of encoder layers.
        num_decoder_layers: number of decoder layers.
        emb_size: embedding / model dimension (``d_model``).
        src_vocab_size: size of the source vocabulary.
        tgt_vocab_size: size of the target vocabulary.
        dim_feedforward: hidden size of the feed-forward sub-layers.
        dropout: dropout probability, applied in the positional encoding and in
            every encoder/decoder layer.
        nhead: number of attention heads. When ``None`` (default) the
            module-level ``NHEAD`` constant is used.
    """
    def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, src_vocab_size: int, tgt_vocab_size: int,
                 dim_feedforward:int = 512, dropout:float = 0.1, nhead=None):
        super(Seq2SeqTransformer, self).__init__()
        if nhead is None:
            # Backward-compatible default: fall back to the notebook-wide constant.
            nhead = NHEAD
        # `dropout` is forwarded explicitly so a non-default value reaches the
        # layers as well as the positional encoding.
        encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=nhead,
                                                dim_feedforward=dim_feedforward,
                                                dropout=dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=nhead,
                                                dim_feedforward=dim_feedforward,
                                                dropout=dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor,
                tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor):
        """Run the full encoder-decoder pass and return target-vocabulary logits.

        ``src`` and ``trg`` are token-id tensors; the callers in this notebook
        pass sequence-first batches produced by ``pad_sequence``.
        """
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        memory = self.transformer_encoder(src_emb, src_mask, src_padding_mask)
        # The fourth positional argument is the decoder's memory_mask; none is used.
        outs = self.transformer_decoder(tgt_emb, memory, tgt_mask, None,
                                        tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        """Embed ``src`` and return the encoder output (no padding mask is applied)."""
        return self.transformer_encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """Embed ``tgt`` and return the decoder output given the encoder ``memory``."""
        return self.transformer_decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to token embeddings.

    Args:
        emb_size: embedding dimension; both even and odd sizes are supported.
        dropout: dropout probability applied to the sum.
        maxlen: number of positions to precompute.
    """
    def __init__(self, emb_size: int, dropout, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        # For odd emb_size there is one cosine column fewer than sine columns,
        # so only the first emb_size // 2 frequencies are used (a no-op slice
        # when emb_size is even).
        pos_embedding[:, 1::2] = torch.cos(pos * den[: emb_size // 2])
        # Insert a singleton axis so the table broadcasts over the middle
        # dimension of the input: (maxlen, 1, emb_size).
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        # Buffer: follows the module across devices but is not a trainable parameter.
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        """Return ``dropout(token_embedding + encoding)``.

        The encoding is sliced to the first ``token_embedding.size(0)``
        positions, so dimension 0 of the input is treated as the position axis.
        """
        return self.dropout(token_embedding +
                            self.pos_embedding[:token_embedding.size(0),:])

class TokenEmbedding(nn.Module):
    """Embedding lookup whose output is scaled by ``sqrt(emb_size)``."""

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.emb_size = emb_size
        self.embedding = nn.Embedding(vocab_size, emb_size)

    def forward(self, tokens: Tensor):
        """Return the scaled embeddings of ``tokens`` (cast to ``long`` first)."""
        scale = math.sqrt(self.emb_size)
        vectors = self.embedding(tokens.long())
        return vectors * scale

# Prevent the model from peeking at later target words during training.
def generate_square_subsequent_mask(sz):
    """Return an ``(sz, sz)`` float mask on ``DEVICE``.

    Entries on and below the diagonal are ``0.0``; entries strictly above the
    diagonal are ``-inf``.
    """
    blocked = torch.full((sz, sz), float('-inf'), device=DEVICE)
    # triu(diagonal=1) keeps only the strictly-upper triangle and zeroes the rest.
    return torch.triu(blocked, diagonal=1)

def create_mask(src, tgt):
    """Build the attention and padding masks for one batch.

    Returns:
        ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)`` where
        ``src_mask`` is an all-False boolean square matrix over ``src.shape[0]``,
        ``tgt_mask`` is the subsequent-position mask over ``tgt.shape[0]``, and
        the padding masks mark ``PAD_IDX`` positions with the first two axes of
        the input swapped.
    """
    src_len, tgt_len = src.shape[0], tgt.shape[0]

    src_mask = torch.zeros((src_len, src_len), device=DEVICE).type(torch.bool)
    tgt_mask = generate_square_subsequent_mask(tgt_len)

    src_padding_mask = src.eq(PAD_IDX).transpose(0, 1)
    tgt_padding_mask = tgt.eq(PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [None]:
# Scratch check: shape of a float tensor with two rows and one column.
torch.FloatTensor([[1],[2]]).shape

torch.Size([2, 1])

In [None]:
# Scratch check: concatenating eight copies of a (2, 1) tensor along dim=1.
print(torch.cat([torch.FloatTensor([[1],[2]])]*8 , dim=1).shape)

torch.Size([2, 8])


In [None]:
# Scratch check: a (1, 8) tensor filled with a constant and cast to long.
torch.ones(1, 8).fill_(9).type(torch.long).shape

torch.Size([1, 8])

In [None]:
def _pick_argmax(logits):
    """Return the highest-scoring token id for every row of ``logits``."""
    return torch.max(logits, dim=1)[1]


def _pick_sample(logits):
    """Draw one token id per row from the softmax distribution over ``logits``."""
    return torch.multinomial(torch.nn.functional.softmax(logits, dim=-1), 1)


def _autoregressive_decode(model, src, src_mask, max_len, start_symbol,
                           num_samples, pick_next):
    """Shared decoding loop used by greedy_decode and sampling_decode.

    The source is repeated ``num_samples`` times along dim 1, encoded once, and
    the target is grown one token at a time with ``pick_next`` until it holds
    ``max_len`` tokens (including ``start_symbol``). No end-of-sequence check is
    made here; callers truncate the result themselves.
    """
    src = torch.cat([src.to(DEVICE)] * num_samples, dim=1)
    src_mask = src_mask.to(DEVICE)

    # Inference only: do not record an autograd graph for the growing sequence.
    with torch.no_grad():
        memory = model.encode(src, src_mask)
        ys = torch.ones(1, num_samples).fill_(start_symbol).type(torch.long).to(DEVICE)
        for _ in range(max_len - 1):
            tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                        .type(torch.bool)).to(DEVICE)
            out = model.decode(ys, memory, tgt_mask).transpose(0, 1)
            # Only the representation of the last generated position is scored.
            logits = model.generator(out[:, -1])
            next_word = pick_next(logits)
            ys = torch.cat([ys, next_word.view(1, -1)], dim=0)
    return ys.transpose(0, 1)


def greedy_decode(model, src, src_mask, max_len, start_symbol, num_samples=1):
    """Decode by always taking the most probable next token.

    Args:
        model: object exposing ``encode``, ``decode`` and ``generator``.
        src: source token ids; repeated ``num_samples`` times along dim 1.
        src_mask: attention mask passed to ``model.encode``.
        max_len: length of the returned sequences, start symbol included.
        start_symbol: id used as the first target token.
        num_samples: number of sequences decoded in parallel.

    Returns:
        Long tensor of shape ``(num_samples, max_len)``.
    """
    return _autoregressive_decode(model, src, src_mask, max_len, start_symbol,
                                  num_samples, _pick_argmax)


def sampling_decode(model, src, src_mask, max_len, start_symbol, num_samples=1):
    """Decode by sampling each next token from the model's softmax distribution.

    Takes the same arguments and returns the same shape as ``greedy_decode``.
    """
    return _autoregressive_decode(model, src, src_mask, max_len, start_symbol,
                                  num_samples, _pick_sample)

def translate(model,
              srcs,
              src_vocab,
              tgt_vocab,
              src_tokenizer,
              decoder=greedy_decode,
              ret_tokens=False,
              ret_idx=False,
              max_len_add=10,
              input_idx=False,
              **argv):
    """Translate a collection of source sentences.

    Args:
        model: translation model; switched to eval mode.
        srcs: iterable of raw strings, or of 1-D id tensors when
            ``input_idx`` is True.
        src_vocab: source vocabulary, indexed as ``src_vocab[token]``.
        tgt_vocab: target vocabulary (``get_itos`` and ``tgt_vocab[token]`` are used).
        src_tokenizer: callable splitting a raw string into tokens.
        decoder: decoding function called as
            ``decoder(model, src, src_mask, max_len=..., start_symbol=..., **argv)``.
        ret_tokens: with ``ret_idx`` False, return token lists instead of strings.
        ret_idx: return lists of token ids instead of text.
        max_len_add: extra decoding steps allowed beyond the source length.
        input_idx: set when ``srcs`` already holds id tensors.
        **argv: forwarded to ``decoder``.

    Returns:
        One list per source sentence, holding one entry per decoded sequence.
        Output stops at the first ``<eos>``; ``<bos>`` and ``<pad>`` are dropped.
    """
    model.eval()
    itos = tgt_vocab.get_itos()
    eos_idx = tgt_vocab['<eos>']
    skipped = {eos_idx, tgt_vocab['<bos>'], tgt_vocab['<pad>']}

    global_answers = []
    for src in srcs:
        if not input_idx:
            # Index the vocabulary directly (as data_process does) so that
            # out-of-vocabulary tokens fall back to the default index.
            tokens = [BOS_IDX] + [src_vocab[tok] for tok in src_tokenizer(src)] + [EOS_IDX]
            src = torch.LongTensor(tokens)
        num_tokens = len(src)
        src = src.reshape(num_tokens, 1)

        src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
        with torch.no_grad():  # pure inference
            tgt_tokens = decoder(model, src, src_mask, max_len=num_tokens + max_len_add,
                                 start_symbol=BOS_IDX, **argv)

        answers = []
        for tgt_token in tgt_tokens:
            kept_ids = []
            for tok in tgt_token:
                tok_id = tok.item()
                if tok_id == eos_idx:
                    break
                if tok_id not in skipped:
                    kept_ids.append(tok_id)

            if ret_idx:
                answers.append(kept_ids)
            else:
                text = " ".join(itos[tok_id] for tok_id in kept_ids).strip()
                answers.append(text.split(" ") if ret_tokens else text)
        global_answers.append(answers)
    return global_answers

In [None]:
# Print the first validation batch as text: the German source followed by the
# English decoder input (the target without its final position).
for idx, (src, tgt) in enumerate(valid_iter):
    itos = en_vocab.get_itos()
    de_itos = de_vocab.get_itos()
    tgt_input = tgt[:-1, :]

    german_words = [de_itos[step[0]] for step in src]
    english_words = [itos[step[0]] for step in tgt_input]
    print(' '.join(german_words))
    print(' '.join(english_words))
    break

<bos> Vier Männer in T-Shirts und Shorts blicken aus einem Fenster auf die Straße hinunter . <eos>
<bos> Four men dressed in t - shirts and shorts are looking out a window <unk> the street below .


In [None]:
def evaluate(model, val_iter):
    """Return the mean per-batch loss of ``model`` over ``val_iter``.

    Args:
        model: the sequence-to-sequence model; switched to eval mode.
        val_iter: iterable of ``(src, tgt)`` batches that also supports ``len``.
    """
    model.eval()
    losses = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        # Iterate the loader that was passed in, not a notebook-level global.
        for src, tgt in val_iter:
            src = src.to(DEVICE)
            tgt = tgt.to(DEVICE)

            # Teacher forcing: feed everything but the last position ...
            tgt_input = tgt[:-1, :]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
            logits = model(src, tgt_input, src_mask, tgt_mask,
                           src_padding_mask, tgt_padding_mask, src_padding_mask)
            # ... and score against the sequence shifted by one.
            tgt_out = tgt[1:, :]
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            losses += loss.item()
    return losses / len(val_iter)

def bleu_calculate(model, data_iter, decoder=greedy_decode):
    """Return the mean smoothed sentence-level BLEU of ``model`` over ``data_iter``.

    Args:
        model: translation model; switched to eval mode.
        data_iter: iterable of ``(src, tgt)`` id batches as produced by
            ``generate_batch``.
        decoder: decoding function forwarded to ``translate``.
    """
    model.eval()
    bleus = []
    itos = en_vocab.get_itos()
    eos_idx = en_vocab['<eos>']
    skipped = {eos_idx, en_vocab['<bos>'], en_vocab['<pad>']}
    smoothing = SmoothingFunction().method1

    for src, tgt in data_iter:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        # One row per sentence. Padding is removed first so that, for batches
        # larger than one, PAD ids are not translated as if they were words.
        src_input = [row[row != PAD_IDX] for row in src.transpose(0, 1)]
        tgt_input = tgt.transpose(0, 1)
        tgt_output = translate(
            model,
            src_input,
            de_vocab, en_vocab, de_tokenizer,
            decoder=decoder, ret_tokens=True, ret_idx=False, input_idx=True, num_samples=1)

        for refs, candidates in zip(tgt_input, tgt_output):
            reference = []
            for tok in refs[1:]:
                tok_id = tok.item()
                if tok_id == eos_idx:
                    break
                if tok_id not in skipped:
                    reference.append(itos[tok_id])
            # A single hypothesis is decoded per sentence (num_samples=1).
            bleus.append(
                sentence_bleu(
                    [reference], candidates[0],
                    smoothing_function=smoothing))

    return np.mean(bleus)

In [None]:
# for idx, (src, tgt) in enumerate(valid_iter):
#         src = src.to(DEVICE)
#         tgt = tgt.to(DEVICE)

#         src_input = src.transpose(0, 1)
#         tgt_input = tgt.transpose(0, 1)
#         for refs in tgt_input:
#             reference = []
#             for tok in refs[1:]:
#                 print(tok)
#                 break
#             break
#         break

In [None]:
#

In [None]:
# bleu = bleu_calculate(transformer, valid_iter)

In [None]:
# Model and training hyper-parameters.
SRC_VOCAB_SIZE = len(de_vocab)
TGT_VOCAB_SIZE = len(en_vocab)
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
NUM_EPOCHS = 16

transformer = Seq2SeqTransformer(
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    emb_size=EMB_SIZE,
    src_vocab_size=SRC_VOCAB_SIZE,
    tgt_vocab_size=TGT_VOCAB_SIZE,
    dim_feedforward=FFN_HID_DIM,
)

# Xavier-initialise every parameter with more than one dimension;
# one-dimensional parameters keep their default initialisation.
for p in transformer.parameters():
    if p.dim() <= 1:
        continue
    nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

# Padding positions do not contribute to the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(
    transformer.parameters(),
    lr=1e-4,
    betas=(0.9, 0.98),
    eps=1e-9,
)



In [None]:
def train_epoch(model, train_iter, optimizer):
    """Train ``model`` for one pass over ``train_iter``; return the mean batch loss."""
    model.train()
    total_loss = 0
    for src, tgt in train_iter:
        src, tgt = src.to(DEVICE), tgt.to(DEVICE)

        # Teacher forcing: the decoder reads tgt[:-1] and is scored on tgt[1:].
        decoder_input = tgt[:-1, :]
        decoder_target = tgt[1:, :]

        masks = create_mask(src, decoder_input)
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = masks

        # The source padding mask is passed twice: for the encoder and as the
        # decoder's memory key-padding mask.
        logits = model(src, decoder_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]),
                       decoder_target.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(train_iter)

In [None]:
# Train for NUM_EPOCHS epochs; after each one report the validation loss and
# BLEU. "Epoch time" covers training only, "All time" also includes evaluation.
for epoch in range(1, NUM_EPOCHS+1):
    start_time = time.time()
    train_loss = train_epoch(transformer, train_iter, optimizer)
    end_time = time.time()
    val_loss = evaluate(transformer, valid_iter)
    bleu = bleu_calculate(transformer, valid_iter)
    all_time = time.time()
    print(f"Epoch: {epoch}, "
          f"Train loss: {train_loss:.3f}, "
          f"Val loss: {val_loss:.3f}, "
          f"BLEU: {bleu:.3f}, "
          f"Epoch time = {(end_time - start_time):.3f}s, "
          f"All time = {(all_time - start_time):.3f}s")



KeyboardInterrupt: ignored

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class RNNEncoder(nn.Module):
    """Thin wrapper around a single batch-first ``nn.RNN``."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)

    def forward(self, x, h_prev):
        """Run the RNN over ``x`` from state ``h_prev``.

        Returns:
            ``(output, h_next)`` exactly as produced by ``nn.RNN``.
        """
        return self.rnn(x, h_prev)

In [None]:
# Smoke test: one sequence of five 30-dimensional steps through the encoder.
encoder = RNNEncoder(input_size=30, hidden_size=15)
input_sequence = torch.randn(1, 5, 30)
initial_hidden = torch.randn(1, 1, 15)

output_sequence, next_hidden = encoder(input_sequence, initial_hidden)

print(f"Output sequence shape: {output_sequence.shape}")
print(f"Next hidden state shape: {next_hidden.shape}")

Output sequence shape: torch.Size([1, 5, 15])
Next hidden state shape: torch.Size([1, 1, 15])


In [None]:
class RNNDecoder(nn.Module):
    """Single decoder step built from two ``Linear + ReLU`` heads.

    The new hidden state is computed from ``[hidden, context, y]``; the new
    output is then computed from ``[new_hidden, context, y]``.
    """

    def __init__(self, context_size, hidden_size, prev_size ):
        super().__init__()
        self.hidden_size = hidden_size
        self.context_size = context_size
        self.prev_size = prev_size

        # Both heads read a concatenation of the same three pieces.
        joint_size = self.prev_size + self.context_size + self.hidden_size

        self.get_new_hidden = nn.Sequential(
            nn.Linear(in_features=joint_size, out_features=self.hidden_size),
            nn.ReLU(),
        )
        self.get_y = nn.Sequential(
            nn.Linear(in_features=joint_size, out_features=self.prev_size),
            nn.ReLU(),
        )

    def forward(self, context , hidden , y ):
        """Advance one step and return ``(new_hidden, new_y)``.

        The three inputs are concatenated along dim 2.
        """
        new_hidden = self.get_new_hidden(torch.cat((hidden, context, y), dim=2))
        new_y = self.get_y(torch.cat([new_hidden, context, y], dim=2))
        return new_hidden, new_y

In [None]:
# Smoke test: one decoder step on random inputs.
decoder = RNNDecoder(context_size=20, hidden_size=10, prev_size=15)
context = torch.randn(1, 1, 20)
hidden = torch.randn(1, 1, 10)
y = torch.randn(1, 1, 15)

new_hidden, new_y = decoder(context=context, hidden=hidden, y=y)
print(f"{new_hidden.shape} {new_y.shape}")

torch.Size([1, 1, 10]) torch.Size([1, 1, 15])


In [None]:
class Attention(nn.Module):
    """Additive attention over a sequence of encoder states.

    Note on naming: ``query_size`` is the feature size of ``embs`` (the states
    being attended over) and ``key_size`` is the feature size of ``hidden``.

    Args:
        query_size: last-dimension size of ``embs``.
        key_size: last-dimension size of ``hidden``; also the size of the
            internal energy space.
    """
    def __init__(self, query_size ,key_size ):
        super(Attention, self).__init__()

        self.query_size = query_size
        self.key_size = key_size
        self.linear_query = nn.Linear(query_size, key_size)
        self.linear_key = nn.Linear(key_size, key_size)
        self.v = nn.Parameter(torch.rand(key_size))

    def forward(self, hidden , embs ):
        """Return the attention-weighted sum of ``embs``.

        Args:
            hidden: tensor of shape ``(batch, 1, key_size)``.
            embs: tensor of shape ``(batch, steps, query_size)``.

        Returns:
            Context tensor of shape ``(batch, 1, query_size)``.
        """
        query = self.linear_query(embs)    # (batch, steps, key_size)
        key = self.linear_key(hidden)      # (batch, 1, key_size), broadcast over steps

        energy = torch.tanh(query + key)
        # One score per step, normalised over the step axis: (batch, steps).
        attention_weight = F.softmax(torch.matmul(energy, self.v), dim=1)

        # Add an explicit row axis so the product is taken per batch element.
        # Multiplying the 2-D weights directly would broadcast them across the
        # batch and yield (batch, batch, query_size) whenever batch > 1.
        context = torch.matmul(attention_weight.unsqueeze(1), embs)
        return context

In [None]:
# Smoke test: attend over five 20-dimensional states with a 10-dimensional hidden state.
attention = Attention(query_size=20, key_size=10)
hidden = torch.randn(1, 1, 10)
embs = torch.randn(1, 5, 20)

context = attention(hidden, embs)
print(f"{context.shape}")

torch.Size([1, 1, 20])


In [None]:
MAX_GENERATE = 20
class RNN_Att(nn.Module) :
    """RNN encoder, additive attention and step-wise decoder wired together.

    Args:
        input_size: feature size of each input step.
        encoder_size: hidden size of the encoder RNN.
        decoder_size: hidden size of the decoder state.
        output_size: feature size of each generated output step.
    """
    def __init__(self , input_size , encoder_size ,decoder_size , output_size ) :
        super(RNN_Att, self).__init__()
        self.input_size = input_size
        self.hidden_encoder_size = encoder_size
        self.hidden_decoder_size = decoder_size
        self.output_size = output_size

        self.encoder = RNNEncoder(input_size = self.input_size, hidden_size = self.hidden_encoder_size)

        self.attention = Attention(query_size = self.hidden_encoder_size , key_size = self.hidden_decoder_size)
        self.decoder = RNNDecoder(self.hidden_encoder_size, self.hidden_decoder_size, self.output_size)

    def forward(self , input_sequence ) :
        """Generate ``MAX_GENERATE`` output steps for ``input_sequence``.

        Args:
            input_sequence: batch-first tensor ``(batch, steps, input_size)``.

        Returns:
            Tensor with the generated steps concatenated along dim 1.
        """
        # Size the state tensors from the input so they share its batch size,
        # device and dtype instead of assuming a single CPU sequence.
        batch_size = input_sequence.size(0)
        like_input = dict(device=input_sequence.device, dtype=input_sequence.dtype)

        # nn.RNN expects its initial state as (num_layers, batch, hidden).
        initial_hidden = torch.zeros(1, batch_size, self.hidden_encoder_size, **like_input)
        # Per-step hidden states of the encoder; its final state is not used.
        encoder_hidden, _ = self.encoder(input_sequence, initial_hidden)

        # Initialisation of the decoder state and of the first "previous output".
        current_output = torch.zeros(batch_size, 1, self.output_size, **like_input)
        decoder_hidden = torch.zeros(batch_size, 1, self.hidden_decoder_size, **like_input)

        # Generate a fixed number of steps.
        output_sequence = []
        for j in range(MAX_GENERATE) :
            context = self.attention(hidden=decoder_hidden , embs=encoder_hidden)
            decoder_hidden , current_output = self.decoder(context , decoder_hidden , current_output )
            output_sequence.append(current_output)
            # TODO: stop early once the end token has been generated
            # (condition not implemented yet).

        output_sequence = torch.cat(output_sequence, dim=1)
        return output_sequence

In [None]:
# Smoke test of the full encoder-attention-decoder model on one random sequence.
input_size = 30
encoder_size = 15
decoder_size = 20
output_size = 30

rnn_att = RNN_Att(
    input_size=input_size,
    encoder_size=encoder_size,
    decoder_size=decoder_size,
    output_size=output_size,
)
input_sequence = torch.randn(1, 5, input_size)

ouput_sequence = rnn_att(input_sequence)




In [None]:
# Scratch check: swap the first two axes and the last two axes, then drop the
# trailing singleton axis.
a = torch.randn(5, 4, 1, 15)
b = a.permute(1, 0, 3, 2)
c = b.squeeze(3)
print(b.shape)
print(c.shape)

torch.Size([4, 5, 15, 1])
torch.Size([4, 5, 15])
