#### https://github.com/bentrevett/pytorch-seq2seq/blob/master/6%20-%20Attention%20is%20All%20You%20Need.ipynb

In [None]:
# Reload edited local modules (Encoder, Decoder, Seq2Seq, helpers) before each
# cell runs, so changes to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Set random seed for reproducibility

In [None]:
# Single seed shared by every random number generator seeded in the cells below.
SEED: int = 1234

In [None]:
import random
# Seed Python's built-in random module.
random.seed(SEED)

In [None]:
import numpy as np
# Seed NumPy's global random generator.
np.random.seed(SEED)

In [None]:
import torch

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed torch's CPU generator and the current CUDA device's generator. Also ask
# cuDNN for deterministic algorithms so that reruns reproduce the same numbers.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [None]:
# Characters split off as separate tokens, each mapped to its space-padded form.
# '!' and ',' are included so that sentence-final '!' is handled like '.' and '?'
# (previously "Go!" stayed a single token).
_PUNCT_PADDING = str.maketrans({
    '.': ' .',
    '?': ' ?',
    '!': ' !',
    ',': ' , ',
    "'": " ' ",
    '-': ' - ',
})


def tokenize(x):
    """Split a sentence into whitespace-separated tokens, detaching punctuation.

    Apostrophes, hyphens, commas and sentence punctuation become their own
    tokens; the text's case is left untouched.

    Args:
        x: The raw sentence.

    Returns:
        The list of tokens (empty for an empty or all-whitespace string).
    """
    # One translate pass instead of chained str.replace calls. The result is the
    # same because no replacement text contains another character being replaced.
    return x.translate(_PUNCT_PADDING).split()

### Load the full dataset, split it into train/validation/test sets, and construct the iterators

In [None]:
from torchtext.data import Field

# Source and target sides share the tokenizer and the <sos>/<eos> markers, and
# both keep the original case. batch_first=True makes batch tensors
# [batch size, seq len]. torchtext's default is [seq len, batch size], while the
# Transformer in the tutorial linked at the top of this notebook expects
# batch-first input.
# NOTE(review): confirm that the local Encoder/Decoder/Seq2Seq/transformer_train
# modules index batches as [batch, seq].
SRC = Field(sequential=True,
            use_vocab=True,
            tokenize=tokenize,
            lower=False,
            init_token='<sos>',
            eos_token='<eos>',
            batch_first=True)

TRG = Field(sequential=True,
            use_vocab=True,
            tokenize=tokenize,
            lower=False,
            init_token='<sos>',
            eos_token='<eos>',
            batch_first=True)

In [None]:
import os

from torchtext.data import TabularDataset

# Location of the tab-separated sentence pairs (presumably English/French, per
# the file name). Set the ENG_FRA_TSV environment variable to point elsewhere
# instead of editing this cell; the default is the original hard-coded path.
DATA_PATH = os.environ.get(
    'ENG_FRA_TSV',
    '/home/catskills/Desktop/openasr/torch_tutorial/data/eng-fra.tsv')

# The first column becomes `trg` and the second becomes `src`, following the
# order of the fields list.
# NOTE(review): TSV rows are read via the csv module; if some sentences start
# with a double quote they may be mis-split. Check the data, or pass
# csv_reader_params={'quoting': csv.QUOTE_NONE} if your torchtext version
# supports it.
all_data = TabularDataset(
    path=DATA_PATH,
    format='tsv',
    fields=[('trg', TRG), ('src', SRC)])

In [None]:
# Split the examples 60/20/20 into training, validation and test sets.
train_data, valid_data, test_data = all_data.split(split_ratio=[0.6, 0.2, 0.2])

In [None]:
# Length of the longest tokenised sentence on either side of any pair, plus 10
# tokens of headroom. The Fields also add <sos>/<eos> to every sequence. This
# is computed over the full dataset on purpose, since any split may be fed in.
MAX_LENGTH = 10 + max(
    len(tokens)
    for example in all_data.examples
    for tokens in (example.src, example.trg)
)
MAX_LENGTH

In [None]:
# Tokens seen fewer than MIN_FREQ times are left out of the vocabulary.
MIN_FREQ = 12

# Build both vocabularies from the training split only. Validation/test
# sentences must not decide which tokens are known, otherwise the test-set
# vocabulary and frequencies leak into the model.
# NOTE: counts now come from ~60% of the data, so fewer tokens pass this
# threshold than before; retune MIN_FREQ if the vocabularies become too small.
SRC.build_vocab(train_data, min_freq=MIN_FREQ)
TRG.build_vocab(train_data, min_freq=MIN_FREQ)

INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)
INPUT_DIM, OUTPUT_DIM

In [None]:
# Vocabulary index of each Field's padding token. These are passed to Seq2Seq
# and to the loss's ignore_index further down.
SRC_PAD_IDX, TRG_PAD_IDX = (field.vocab.stoi[field.pad_token] for field in (SRC, TRG))
SRC_PAD_IDX, TRG_PAD_IDX

In [None]:
from torchtext.data import BucketIterator

# NOTE(review): with a batch size of 1 there is nothing to bucket and training
# is slow. Confirm that this value is intentional.
BATCH_SIZE = 1

# TabularDataset defines no sort_key of its own, and BucketIterator sorts the
# validation/test splits by default. Without a key, that sort has to compare
# Example objects directly and fails, so sort and bucket by source-sentence
# length instead.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    sort_key=lambda example: len(example.src),
    device=device)

### Build model

In [None]:
# Transformer hyperparameters. Encoder (ENC_*) and decoder (DEC_*) settings
# currently share every value, hence the paired assignments; split a pair to
# tune one side independently.
HID_DIM = 256
ENC_LAYERS = DEC_LAYERS = 3
ENC_HEADS = DEC_HEADS = 8
ENC_PF_DIM = DEC_PF_DIM = 512
ENC_DROPOUT = DEC_DROPOUT = 0.1

In [None]:
from Encoder import Encoder
# Build the encoder from the local Encoder module. The arguments are positional,
# so their order must match Encoder's signature: presumably (input_dim,
# hid_dim, n_layers, n_heads, pf_dim, dropout, device, max_length), as in the
# tutorial linked at the top. TODO confirm. MAX_LENGTH presumably bounds the
# sequence length the encoder can handle.
enc = Encoder(INPUT_DIM, 
              HID_DIM, 
              ENC_LAYERS, 
              ENC_HEADS, 
              ENC_PF_DIM, 
              ENC_DROPOUT, 
              device,
              MAX_LENGTH)

In [None]:
from Decoder import Decoder
# Build the decoder from the local Decoder module. It takes the same positional
# layout as the encoder, but with the target vocabulary size and the DEC_*
# settings. TODO confirm the argument order against Decoder's signature.
dec = Decoder(OUTPUT_DIM, 
              HID_DIM, 
              DEC_LAYERS, 
              DEC_HEADS, 
              DEC_PF_DIM, 
              DEC_DROPOUT, 
              device,
              MAX_LENGTH)

In [None]:
from Seq2Seq import Seq2Seq
# Combine encoder and decoder. The pad indices are presumably used to build
# attention masks (TODO confirm in Seq2Seq). .to(device) moves all parameters.
model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)

In [None]:
from count_parameters import count_parameters
# Report the number returned by the local count_parameters helper, printed with
# thousands separators.
print(f'The model has {count_parameters(model):,} trainable parameters')

In [None]:
from initialize_weights import initialize_weights
# nn.Module.apply calls initialize_weights on every submodule and on the model
# itself. The trailing ';' only stops the notebook from echoing the returned model.
model.apply(initialize_weights);

In [None]:
# Adam over every model parameter with a fixed learning rate (no scheduler is used).
LEARNING_RATE = 0.0005
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

In [None]:
import torch.nn as nn
# Cross-entropy over output tokens. Target positions equal to TRG_PAD_IDX are
# ignored, so padding does not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)

In [None]:
from transformer_train import transformer_train
from transformer_evaluate import transformer_evaluate

In [None]:
import math
import time

N_EPOCHS = 10
# Passed to transformer_train; presumably a gradient-norm clipping threshold (confirm).
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()

    train_loss = transformer_train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = transformer_evaluate(model, valid_iterator, criterion)

    end_time = time.time()

    # Whole minutes and remaining whole seconds of this epoch's wall-clock time.
    # This replaces a call to `epoch_time`, which was never defined or imported here.
    epoch_mins, epoch_secs = divmod(int(end_time - start_time), 60)

    # Checkpoint only when validation loss improves, so the file on disk always
    # holds the best model seen so far.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut6-model.pt')

    # PPL = exp(loss), assuming the helpers return a mean per-token cross-entropy.
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

In [None]:
import math

# Restore the checkpoint with the best validation loss, saved by the training
# loop. Otherwise the test score would reflect whatever weights the final epoch
# left in memory.
model.load_state_dict(torch.load('tut6-model.pt', map_location=device))

test_loss = transformer_evaluate(model, test_iterator, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

In [None]:
example_idx = 8

# Token lists of one training pair; `src` is what the following cells translate.
src, trg = (vars(train_data.examples[example_idx])[column] for column in ('src', 'trg'))

print(f'src = {src}')
print(f'trg = {trg}')

In [None]:
# Translate the chosen source sentence. The second return value is presumably
# the attention weights, which are plotted in the next cell.
# NOTE(review): translate_sentence is neither defined nor imported anywhere in
# this notebook, so this cell raises NameError as written. Import it (like the
# other local helpers) before running.
translation, attention = translate_sentence(src, SRC, TRG, model, device)

print(f'predicted trg = {translation}')

In [None]:
# NOTE(review): display_attention is neither defined nor imported in this
# notebook; import it before running this cell.
display_attention(src, translation, attention)

In [None]:
example_idx = 6

# Token lists of one validation pair; `src` is what the following cells translate.
src, trg = (vars(valid_data.examples[example_idx])[column] for column in ('src', 'trg'))

print(f'src = {src}')
print(f'trg = {trg}')

In [None]:
# Translate the chosen validation sentence (see the NOTE below about availability).
# NOTE(review): translate_sentence is neither defined nor imported anywhere in
# this notebook, so this cell raises NameError as written. Import it before running.
translation, attention = translate_sentence(src, SRC, TRG, model, device)

print(f'predicted trg = {translation}')

In [None]:
# NOTE(review): display_attention is neither defined nor imported in this
# notebook; import it before running this cell.
display_attention(src, translation, attention)

In [None]:
example_idx = 10

# Token lists of one test pair; `src` is what the following cells translate.
src, trg = (vars(test_data.examples[example_idx])[column] for column in ('src', 'trg'))

print(f'src = {src}')
print(f'trg = {trg}')

In [None]:
# Translate the chosen test sentence.
# NOTE(review): translate_sentence is neither defined nor imported anywhere in
# this notebook, so this cell raises NameError as written. Import it before running.
translation, attention = translate_sentence(src, SRC, TRG, model, device)

print(f'predicted trg = {translation}')

In [None]:
# NOTE(review): display_attention is neither defined nor imported in this
# notebook; import it before running this cell.
display_attention(src, translation, attention)

In [None]:
# Corpus BLEU over the test split.
# NOTE(review): calculate_bleu is neither defined nor imported in this notebook;
# import it before running. The *100 assumes it returns a score in [0, 1] (confirm).
bleu_score = calculate_bleu(test_data, SRC, TRG, model, device)

print(f'BLEU score = {bleu_score*100:.2f}')