In [90]:
# Re-import edited modules (e.g. the project's src.* helpers) automatically before
# each cell executes, so changes under src/ are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [108]:

import math
import time
from functools import partial

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import (TransformerEncoder, TransformerDecoder,
                      TransformerEncoderLayer, TransformerDecoderLayer)
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.prepare_data import download_data, build_train_vocab, get_train_test_val, check_tokens,tokens_to_sentence , generate_batch , visualize_iter_data , get_embed
from src.LSTM import RNNdecoder, RNNencoder,Seq2SeqRNN
from src.train import create_mask,generate_square_subsequent_mask, train_epoch, bleu_calculate , evaluate
from src.transformer import Seq2SeqTransformer,PositionalEncoding,TokenEmbedding


In [3]:
# Fetch the parallel corpus files, build German/English vocabularies and tokenizers
# from the training split, and turn each split into a dataset.
# The split sizes printed below (29000 / 1014 / 1000) look like Multi30k de->en.
train_filepaths , val_filepaths , test_filepaths = download_data()
de_vocab, en_vocab, de_tokenizer, en_tokenizer = build_train_vocab(train_filepaths)
print( 'De vocab En vocab: ',len(de_vocab), len(en_vocab))
# NOTE(review): filepaths are passed in (train, test, val) order, but the result is
# unpacked as (train, val, test). The recorded output shows len(test_data) == 1014
# and len(val_data) == 1000. If this is Multi30k (val: 1014, test2016: 1000
# sentences), then validation and test are swapped, and valid_iter is built from
# the test set. Check get_train_test_val's parameter and return order.
train_data , val_data , test_data = get_train_test_val(train_filepaths, test_filepaths, val_filepaths , de_vocab , en_vocab ,de_tokenizer,en_tokenizer )
print('Train Test Val: ',len(train_data),len(test_data) , len(val_data))

De vocab En vocab:  19215 10838
Train Test Val:  29000 1014 1000


In [4]:
BATCH_SIZE = 128
# Special-token indices. They are read from the German vocab but are also applied
# to the English side: the loss ignore_index, the BOS/EOS used by bleu_calculate,
# and (presumably) target padding in generate_batch. Fail fast if the two
# vocabularies disagree on any of them.
PAD_IDX = de_vocab['<pad>']
BOS_IDX = de_vocab['<bos>']
EOS_IDX = de_vocab['<eos>']
for special_token in ('<pad>', '<bos>', '<eos>'):
    assert en_vocab[special_token] == de_vocab[special_token], (
        f"{special_token} has different indices in de_vocab and en_vocab")
print(PAD_IDX , BOS_IDX , EOS_IDX)

1 2 3


In [5]:
# A single collate function shared by all loaders: generate_batch with the
# special-token indices bound. Unlike a lambda, functools.partial of an imported
# function is picklable, so the loaders keep working if num_workers > 0 is enabled.
collate_batch = partial(generate_batch, BOS_IDX=BOS_IDX, PAD_IDX=PAD_IDX, EOS_IDX=EOS_IDX)

# Only the training loader is shuffled. Shuffling evaluation data adds
# nondeterminism without any benefit.
train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=collate_batch)
# batch_size=1 is kept as in the original. valid_iter is consumed by evaluate and
# bleu_calculate. NOTE(review): confirm whether bleu_calculate requires
# single-sentence batches.
valid_iter = DataLoader(val_data, batch_size=1,
                        shuffle=False, collate_fn=collate_batch)
test_iter = DataLoader(test_data, batch_size=BATCH_SIZE,
                       shuffle=False, collate_fn=collate_batch)

In [6]:
# Compute device for models and tensors: the GPU when torch can see one, else the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'
DEVICE

'cpu'

# TEST TRANSFORMER

In [13]:
# Transformer hyperparameters and model construction.
SRC_VOCAB_SIZE = len(de_vocab)
TGT_VOCAB_SIZE = len(en_vocab)
EMB_SIZE = 256
NHEAD = 8
FFN_HID_DIM = 512
# BATCH_SIZE is deliberately not reassigned here. train_iter / test_iter were
# already built with BATCH_SIZE = 128, so a new value would not affect them and
# would only mislead the reader.
NUM_ENCODER_LAYERS = 2
NUM_DECODER_LAYERS = 2

transformer = Seq2SeqTransformer(num_encoder_layers=NUM_ENCODER_LAYERS,
                                 num_decoder_layers=NUM_DECODER_LAYERS,
                                 emb_size=EMB_SIZE, src_vocab_size=SRC_VOCAB_SIZE,
                                 tgt_vocab_size=TGT_VOCAB_SIZE,
                                 dim_feedforward=FFN_HID_DIM, NHEAD=NHEAD)
# Xavier-uniform init for every parameter with more than one dimension.
# 1-D parameters (e.g. biases) keep their default initialisation.
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
# Place the model on DEVICE, consistent with the RNN model below and with the
# masks create_mask builds for DEVICE.
transformer = transformer.to(DEVICE)

In [None]:
# Smoke test: push one training batch through the transformer and print tensor shapes.
# Inputs are moved to the model's own device, so this cell works whether or not
# the model has been moved to DEVICE.
model_device = next(transformer.parameters()).device
src, tgt = next(iter(train_iter))
src, tgt = src.to(model_device), tgt.to(model_device)

# Forward pass. The decoder input is the target minus its last step along dim 0.
# tgt appears to be (seq_len, batch), judging by tgt_emb's shape further down
# (TODO confirm against generate_batch).
tgt_input = tgt[:-1, :]
src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input, pad_idx = PAD_IDX, device=model_device)
# The printed labels are in Russian. In order they mean: batch of German sentences,
# batch of English translations, source mask, translation mask, source padding
# mask, translation padding mask.
print('Батч предложений deutch: ',src.shape)
print('Батч переводов english: ',tgt_input.shape)
print('Маска предложений входных: ', src_mask.shape)
print('Маска переводов: ', tgt_mask.shape )
print('Падинг маска предложений входных: ',src_padding_mask.shape)
print('Падинг маска переводов: ',tgt_padding_mask.shape)
with torch.no_grad():  # shape check only: no autograd graph needed
    # The last argument reuses the source padding mask, presumably as the memory
    # key padding mask for decoder-encoder attention.
    logits = transformer(src,
                         tgt_input,
                         src_mask,
                         tgt_mask,
                         src_padding_mask,
                         tgt_padding_mask,
                         src_padding_mask)
# Printed label: predicted word scores at every target position.
print('Предсказания вероятностей слов на каждой позиции: ',logits.shape)

# TEST RNN ENCODER DECODER

In [8]:
# Hyperparameters for the RNN encoder/decoder shape checks below.
# NHEAD and the layer counts are the transformer settings, repeated here.
# BATCH_SIZE is deliberately not reassigned: train_iter was already built with
# BATCH_SIZE = 128, so setting it to 32 here had no effect and only misled.
SRC_VOCAB_SIZE = len(de_vocab)
TGT_VOCAB_SIZE = len(en_vocab)
EMB_SIZE = 256
NHEAD = 8
FFN_HID_DIM = 512
NUM_ENCODER_LAYERS = 2
NUM_DECODER_LAYERS = 2


In [9]:

# Get a prepared batch (token ids, masks and embeddings) from get_embed and expose
# each piece as a notebook global for the encoder/decoder checks below.
data_object = get_embed(
    train_iter=train_iter,
    EMB_SIZE=EMB_SIZE,
    DEVICE=DEVICE,
    SRC_VOCAB_SIZE=SRC_VOCAB_SIZE,
    PAD_IDX=PAD_IDX,
)
(src, tgt,
 src_mask, tgt_mask,
 src_padding_mask, tgt_padding_mask,
 src_emb, tgt_emb) = (data_object[key] for key in (
    'src', 'tgt',
    'src_mask', 'tgt_mask',
    'src_padding_mask', 'tgt_padding_mask',
    'src_emb', 'tgt_emb',
))



In [30]:
# Scratch cell: one-off shape/mask inspections kept for reference, all commented out.
# Several of them use variables from later cells (memory, hidden from the encoder
# cell), so they only make sense after those cells have run.
# src.permute(1,0)[0]
# PositionalEncoding(EMB_SIZE, dropout=0.1)(TokenEmbedding(SRC_VOCAB_SIZE, EMB_SIZE)(src)).permute(1,0,2)[0][:,5]
# src_padding_mask[0]
# src.permute(1,0)[src_padding_mask].shape
# src.permute(1,0)[0][~src_padding_mask[0]]
# src_emb.shape , src_mask.shape , src_padding_mask.shape , tgt_emb.shape , tgt_mask.shape , tgt_padding_mask.shape
# torch.sum(src_mask , dim=1) ,torch.sum(src_padding_mask , dim=1)  # src_mask is all FALSE, but src_padding_mask marks the padding elements
# memory[:,-1:,:] == hidden[0].permute(1,0,2)
# tgt_emb.shape , memory.shape , hidden[0].shape , hidden[1].shape

## encoder

In [92]:
# Encode the embedded source batch. memory holds per-position outputs; hidden is
# presumably the final LSTM state pair (h_n, c_n). Print all three shapes.
encoder = RNNencoder(EMB_SIZE, EMB_SIZE)
memory, hidden = encoder(src_emb, src_padding_mask)
encoder_shapes = [t.shape for t in (memory, hidden[0], hidden[1])]
print(*encoder_shapes)

torch.Size([128, 27, 256]) torch.Size([1, 128, 256]) torch.Size([1, 128, 256])


## decoder

In [87]:
# Count the non-padding target positions for each sentence (one entry per batch row).
# tgt_padding_mask marks padded positions, so invert it before summing.
tgt_not_padding = ~tgt_padding_mask
lengths = tgt_not_padding.sum(dim=1).to(torch.int32)
lengths

tensor([16,  9, 11, 15, 16, 13, 19, 14, 12, 11, 13, 18, 19, 15, 12, 11, 13, 11,
        14, 12, 20, 13, 15, 10, 17, 17, 11, 11, 15, 14, 13, 14, 13, 18, 23, 15,
        23, 10, 19,  9, 13, 16, 15, 15, 13, 12, 17, 13, 16, 16, 16, 15, 16, 16,
        13, 19, 10, 26, 14, 10, 11, 18, 15, 21, 11, 14, 22, 13, 24, 15, 17, 10,
        17, 15, 13, 13, 17, 13, 21, 17, 13, 14, 12, 12, 19, 17, 18, 12, 18, 14,
        13, 14, 10, 10, 15, 15, 12, 11, 20, 25, 16, 14, 12, 24, 10, 16, 18, 18,
        14, 14, 26, 23, 12, 16, 15, 12, 19, 13,  8, 13, 23, 16, 20, 19, 12, 14,
        13, 11], dtype=torch.int32)

In [88]:
# Target embedding shape. The recorded output suggests (tgt_len, batch, EMB_SIZE).
tgt_emb.size()

torch.Size([26, 128, 256])

In [96]:

# Shape probe: run a plain batch-first LSTM over the target embeddings.
# tgt_emb is sequence-first (see its shape above), so swap the first two axes.
# The input is named lstm_input rather than `input` so the builtin input() is
# not shadowed for the rest of the session.
lstm_input = tgt_emb
# Build the LSTM on the same device as the embeddings it will consume.
lstm = torch.nn.LSTM(EMB_SIZE, EMB_SIZE, batch_first=True).to(lstm_input.device)

result, _ = lstm(lstm_input.permute(1, 0, 2))
print(result.shape)

torch.Size([128, 26, 256])


In [99]:
# Decode the target embeddings, conditioned on the encoder outputs (memory) and
# state (hidden). The target padding mask is passed along (presumably so padded
# steps are ignored).
decoder = RNNdecoder(EMB_SIZE, EMB_SIZE)
decode_output = decoder(
    tgt_emb,
    memory,
    hidden,
    tgt_padding_mask,
)
print(decode_output.size())


torch.Size([128, 26, 256])


# Seq2Seq RNN: setup and training

In [107]:
# Configuration, model, loss and optimizer for training the seq2seq RNN.
SRC_VOCAB_SIZE = len(de_vocab)
TGT_VOCAB_SIZE = len(en_vocab)
EMB_SIZE = 256
HIDDEN_SIZE = 512  # Seq2SeqRNN hidden size (previously passed via the transformer's FFN_HID_DIM)
NUM_EPOCHS = 10
SEED = 0
# BATCH_SIZE and NHEAD are not set here. The RNN does not use NHEAD, and
# train_iter was already built with BATCH_SIZE = 128, so reassigning BATCH_SIZE
# would not change the batches actually used.

# Seed torch before weight init, so the initial weights and later torch RNG draws
# (e.g. train_iter's shuffling) repeat across runs on the same setup.
torch.manual_seed(SEED)

rnn = Seq2SeqRNN(emb_size=EMB_SIZE,
                 src_vocab_size=SRC_VOCAB_SIZE,
                 tgt_vocab_size=TGT_VOCAB_SIZE,
                 hidden_size=HIDDEN_SIZE)
# Xavier-uniform init for every parameter with more than one dimension.
# 1-D parameters keep their defaults.
for p in rnn.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
rnn = rnn.to(DEVICE)

# Padded target positions do not contribute to the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(
    rnn.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9
)

In [106]:
# Shape check of the full seq2seq RNN on the get_embed batch. The recorded output
# suggests (batch, tgt_len, TGT_VOCAB_SIZE). Call the module (not .forward) so any
# registered hooks run, and skip autograd since this is inspection only.
with torch.no_grad():
    out_forward = rnn(src, tgt, src_padding_mask, tgt_padding_mask)
print(out_forward.shape)

torch.Size([128, 26, 10838])


In [None]:
# Train for NUM_EPOCHS epochs. After each epoch, report validation loss and BLEU
# on valid_iter.
# time.perf_counter is used because it is a monotonic clock meant for measuring
# intervals.
for epoch in range(1, NUM_EPOCHS + 1):
    start_time = time.perf_counter()
    print('train')
    train_loss = train_epoch(rnn, train_iter, optimizer, DEVICE=DEVICE, loss_fn=loss_fn, pad_idx=PAD_IDX)
    end_time = time.perf_counter()  # end of the training phase only
    print('eval')
    val_loss = evaluate(rnn, valid_iter, DEVICE=DEVICE, loss_fn=loss_fn, pad_idx=PAD_IDX)
    print('bleu')
    bleu = bleu_calculate(rnn, valid_iter, en_vocab=en_vocab, de_vocab=de_vocab, de_tokenizer=de_tokenizer,
                          DEVICE=DEVICE, EOS_IDX=EOS_IDX, BOS_IDX=BOS_IDX)
    all_time = time.perf_counter()  # end of train + eval + BLEU
    print(f"Epoch: {epoch}, "
          f"Train loss: {train_loss:.3f}, "
          f"Val loss: {val_loss:.3f}, "
          f"BLEU: {bleu:.3f}, "
          f"Train time = {(end_time - start_time):.3f}s, "
          f"All time = {(all_time - start_time):.3f}s")