In [1]:

## UNCOMMENT THIS CELL IF USING GOOGLE COLAB ###

from google.colab import drive

drive.mount('/content/drive', force_remount=True)

# Enter the path in your Drive..
FOLDERNAME = "Transformer"

assert FOLDERNAME is not None, "[!] Enter the foldername."

# Now that we've mounted your Drive, this ensures that
# the Python interpreter of the Colab VM can load
# python files from within it.
import sys
sys.path.append(f'/content/drive/My Drive/{FOLDERNAME}')

path = f'/content/drive/My Drive/{FOLDERNAME}'

Mounted at /content/drive


In [2]:
%load_ext autoreload
%autoreload 2


import os
import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from preprocessor import WMTPreProcessor
from utils import *
from dataset import *
from model import *
from Scheduler import *


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = "cpu"
print(device)

cuda:0


In [3]:
data_folder = "WMT14 EN-GE/"

vocab_folder = os.path.join(data_folder,"vocab")
train_folder = os.path.join(data_folder,"train")
test_folder = os.path.join(data_folder,"test")
mappings_path = os.path.join(data_folder,"dict.en-de")

folders_map = {"train":train_folder,"test":test_folder}

In [4]:
!pip install wget
import wget

for folder in [data_folder, vocab_folder, train_folder, test_folder]:
  if not(os.path.exists(folder)):
    os.mkdir(folder)

urls_to_load = {"https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/vocab.50K.en":vocab_folder,
                "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/vocab.50K.de":vocab_folder,
                "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.en":train_folder,
                "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2012.en":test_folder,
                "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.de":train_folder,
                "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2012.de":test_folder}  # a list to store the urls

files_namings = {"https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/vocab.50K.en":"vocab.en",
                "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/vocab.50K.de":"vocab.de",
                "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.en":"train.en",
                "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2012.en":"test.en",
                "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.de":"train.de",
                "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2012.de":"test.de"}

# starting to download
print("Starting downloading")
for url,path in urls_to_load.items():
    file_name = path + '/' + files_namings[url] # get the full path to the file
    if not(os.path.exists(file_name)):
     # if doesn't exists, download it.
        file_name = wget.download(url, out=file_name)
        print(file_name)

print("ok")

Starting downloading
ok


In [5]:
wmt_preprocessor = WMTPreProcessor(vocabPath = vocab_folder, mappings_path = mappings_path)
wmt_preprocessor.initialize()

Building Vocab ....:   0%|          | 0/50001 [00:00<?, ?it/s]

Building Vocab ....:   0%|          | 0/50001 [00:00<?, ?it/s]

In [6]:
data_text = {"en":{"train":[],"test":[],"val":[]},
             "de":{"train":[],"test":[],"val":[]}}

data_text = extractTextFromFolders(folders_map,data_text,val_split = 0.75, limit = 5e4)


extracting text from folders ...:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
data_tokens = extractTokens(data_text,wmt_preprocessor)

extracting tokens ...:   0%|          | 0/2 [00:00<?, ?it/s]

tokenizing ...:   0%|          | 0/37500 [00:00<?, ?it/s]

setting max sent len ....:   0%|          | 0/37500 [00:00<?, ?it/s]

tokenizing ...:   0%|          | 0/3004 [00:00<?, ?it/s]

setting max sent len ....:   0%|          | 0/3004 [00:00<?, ?it/s]

tokenizing ...:   0%|          | 0/12500 [00:00<?, ?it/s]

setting max sent len ....:   0%|          | 0/12500 [00:00<?, ?it/s]

tokenizing ...:   0%|          | 0/37500 [00:00<?, ?it/s]

setting max sent len ....:   0%|          | 0/37500 [00:00<?, ?it/s]

tokenizing ...:   0%|          | 0/3004 [00:00<?, ?it/s]

setting max sent len ....:   0%|          | 0/3004 [00:00<?, ?it/s]

tokenizing ...:   0%|          | 0/12500 [00:00<?, ?it/s]

setting max sent len ....:   0%|          | 0/12500 [00:00<?, ?it/s]

In [8]:
data_encodings = extractEncodings(data_tokens,wmt_preprocessor)

extracting encodings ...:   0%|          | 0/2 [00:00<?, ?it/s]

encoding ....:   0%|          | 0/37500 [00:00<?, ?it/s]

padding ...:   0%|          | 0/37500 [00:00<?, ?it/s]

encoding ....:   0%|          | 0/3004 [00:00<?, ?it/s]

padding ...:   0%|          | 0/3004 [00:00<?, ?it/s]

encoding ....:   0%|          | 0/12500 [00:00<?, ?it/s]

padding ...:   0%|          | 0/12500 [00:00<?, ?it/s]

encoding ....:   0%|          | 0/37500 [00:00<?, ?it/s]

padding ...:   0%|          | 0/37500 [00:00<?, ?it/s]

encoding ....:   0%|          | 0/3004 [00:00<?, ?it/s]

padding ...:   0%|          | 0/3004 [00:00<?, ?it/s]

encoding ....:   0%|          | 0/12500 [00:00<?, ?it/s]

padding ...:   0%|          | 0/12500 [00:00<?, ?it/s]

In [9]:
wmt_train = WMT(inpt_encodings=data_encodings["en"]["train"],tgt_encodings=data_encodings["de"]["train"], pad_index = wmt_preprocessor.PAD)
wmt_val = WMT(inpt_encodings=data_encodings["en"]["val"],tgt_encodings=data_encodings["de"]["val"], pad_index = wmt_preprocessor.PAD)
wmt_test = WMT(inpt_encodings=data_encodings["en"]["test"],tgt_encodings=data_encodings["de"]["test"], pad_index = wmt_preprocessor.PAD)

In [10]:
batch_size = 64
trainLoaders = {}

train_dataloader = DataLoader(wmt_train, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(wmt_val, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(wmt_test, batch_size=batch_size, shuffle=True)


trainLoaders["train"] = train_dataloader
trainLoaders["val"] = val_dataloader


## Generating an example ##
for i,batch in enumerate(iter(val_dataloader)):

    inputs,targets = batch["input"],batch["target"]
    idx = np.random.randint(0,len(inputs))
    if i < 1:
        print(f"inputs batch shape: {inputs['encodings'].size()}")
        print(f"targets batch shape: {targets['decoder_input_encodings'].size()}")
        sample = (inputs["encodings"][idx],targets["decoder_input_encodings"][idx])
        print(f"encodings : {sample[0]}")
        print(f"masks : {inputs['masks'][idx]}")
        print(f"masks : {inputs['masks'].shape}")
        print(f"masks : {targets['masks'][idx]}")
        print(f"masks : {targets['masks'].shape}")
        print(wmt_preprocessor.decode([sample[0].tolist()], unpad = True, idx2word = wmt_preprocessor.idx2word_en),"\n",
             wmt_preprocessor.decode([sample[1].tolist()], unpad = True, idx2word = wmt_preprocessor.idx2word_de))
        break

inputs batch shape: torch.Size([64, 138])
targets batch shape: torch.Size([64, 146])
encodings : tensor([   12,     6, 13572, 21152,    15,    26,   106,    17,  5619,     7,
           16,   241,     7,    30,    40,   274,    13,  1211,    12,     3,
           27,  2068,   125,     3,    34,    47,    36,   316,    25,     8,
            1,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,   

decoding ...:   0%|          | 0/1 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/1 [00:00<?, ?it/s]

[['in', 'the', 'wiki', 'markup', 'that', 'we', 'use', 'on', 'wikitravel', ',', 'for', 'example', ',', 'you', 'can', 'put', 'a', 'word', 'in', '<unk>', 'by', 'putting', 'two', '<unk>', '(', '&quot;', ')', 'around', 'it', '.', '<EOS>']] 
 [['<START>', 'hier', 'ein', 'kleines', 'beispiel', 'zur', 'veranschaulichung', ':', 'um', 'mit', 'dem', 'wiki', 'markup', ',', 'das', 'wir', 'bei', 'wikitravel', 'verwenden', ',', 'ein', 'wort', '<unk>', 'zu', 'schreiben', ',', 'stellt', 'man', 'es', 'einfach', 'zwischen', 'je', 'zwei', '<unk>', '(', '&quot;', ')', '.']]


In [11]:
model = Transformer(src_vocabSize = len(wmt_preprocessor.vocab_en),
                    tgt_vocabSize = len(wmt_preprocessor.vocab_de),
                    d_embed = 512,
                    d_model = 512,
                    d_ff = 2048,
                    dropout = 0.1,
                    noEncoder = 1,
                    noDecoder = 1,
                    pad_index = wmt_preprocessor.PAD,
                   device = device).to(device)

In [12]:
optimizer = torch.optim.Adam(model.parameters(),
                             betas = (0.9, 0.98),
                              eps = 1.0e-9)

In [13]:
scheduler = Scheduler(optimizer =  optimizer,
                 dim_embed= 512,
                 warmup_steps=4000)

In [14]:
print(f"Numeber of the model's trainable paramaters : {count_parameters(model)}")


Numeber of the model's trainable paramaters : 74871090


In [35]:
## An example of the untrained model prediction ##

out = model(wmt_train[:1]["input"], wmt_train[:1]["target"])

print(out.shape)
print(f"True label : {wmt_preprocessor.decode(wmt_train[:1]['target']['target_encodings'].tolist(), unpad = True, idx2word = wmt_preprocessor.idx2word_de)}")
print(f"Prediction : {wmt_preprocessor.decode(torch.argmax(out,2).tolist(), unpad = True, idx2word = wmt_preprocessor.idx2word_de)}")

## The prediction loss ##
loss = translationLoss(output=out, target=wmt_train[:1]["target"]["target_encodings"].to(device), pad_index = wmt_preprocessor.PAD, label_smoothing = 0.2)
print(loss)

torch.Size([1, 146, 45362])


decoding ...:   0%|          | 0/1 [00:00<?, ?it/s]

True label : [['iron', '<unk>', 'ist', 'eine', 'gebrauchs', '##at##-##at##', 'fertige', 'paste', ',', 'die', 'mit', 'einem', '<unk>', 'oder', 'den', 'fingern', 'als', '<unk>', 'in', 'die', '<unk>', '(', 'winkel', ')', 'der', '<unk>', '<unk>', 'aufgetragen', 'wird', '.', '<EOS>']]


decoding ...:   0%|          | 0/1 [00:00<?, ?it/s]

Prediction : [['die', '<unk>', '<unk>', 'eine', '<unk>', '##at##-##at##', '<unk>', '<unk>', ',', 'die', '<unk>', 'dem', '<unk>', '<unk>', '<unk>', '<unk>', '.', '<unk>', '<unk>', 'der', '<unk>', '<unk>', '<unk>', '<unk>', '.', '<unk>', '<unk>', '<unk>', ',', '.', '<EOS>']]
tensor(5.6684, device='cuda:0', grad_fn=<AddBackward0>)


In [16]:
if not (os.path.exists("trained")):
        os.mkdir("trained")

filename = "transformer-weights"

In [20]:
trainedModel = train_model(model, trainLoaders, translationLoss, optimizer,
                           pad_index = wmt_preprocessor.PAD, label_smoothing = 0.2,
                           scheduler = scheduler, num_epochs=10, device = device, isSave=True,
                           filename = filename, verbose = False)

Epoch 0/9
----------


Predicting ...:   0%|          | 0/586 [00:00<?, ?it/s]

train prev epoch Loss: 0
train current epoch Loss: 11.25350168377873, lr = 0.0005504098336278543


Predicting ...:   0%|          | 0/196 [00:00<?, ?it/s]

val prev epoch Loss: 0
val current epoch Loss: 13.398886875230438, lr = 0.0005504098336278543
Epoch 1/9
----------


Predicting ...:   0%|          | 0/586 [00:00<?, ?it/s]

train prev epoch Loss: 11.25350168377873
train current epoch Loss: 11.074508714187674, lr = 0.0005269807009541811


Predicting ...:   0%|          | 0/196 [00:00<?, ?it/s]

val prev epoch Loss: 13.398886875230438
val current epoch Loss: 13.409672406254982, lr = 0.0005269807009541811
Epoch 2/9
----------


Predicting ...:   0%|          | 0/586 [00:00<?, ?it/s]

train prev epoch Loss: 11.074508714187674
train current epoch Loss: 10.911967723442833, lr = 0.0005063094492826154


Predicting ...:   0%|          | 0/196 [00:00<?, ?it/s]

val prev epoch Loss: 13.409672406254982
val current epoch Loss: 13.434684768015025, lr = 0.0005063094492826154
Epoch 3/9
----------


Predicting ...:   0%|          | 0/586 [00:00<?, ?it/s]

train prev epoch Loss: 10.911967723442833
train current epoch Loss: 10.762107539909284, lr = 0.0004878942803138293


Predicting ...:   0%|          | 0/196 [00:00<?, ?it/s]

val prev epoch Loss: 13.434684768015025
val current epoch Loss: 13.47486281881527, lr = 0.0004878942803138293
Epoch 4/9
----------


Predicting ...:   0%|          | 0/586 [00:00<?, ?it/s]

train prev epoch Loss: 10.762107539909284
train current epoch Loss: 10.62822029615018, lr = 0.0004713525701262033


Predicting ...:   0%|          | 0/196 [00:00<?, ?it/s]

val prev epoch Loss: 13.47486281881527
val current epoch Loss: 13.499544051228737, lr = 0.0004713525701262033
Epoch 5/9
----------


Predicting ...:   0%|          | 0/586 [00:00<?, ?it/s]

train prev epoch Loss: 10.62822029615018
train current epoch Loss: 10.503854105496977, lr = 0.0004563867859265298


Predicting ...:   0%|          | 0/196 [00:00<?, ?it/s]

val prev epoch Loss: 13.499544051228737
val current epoch Loss: 13.54538107891472, lr = 0.0004563867859265298
Epoch 6/9
----------


Predicting ...:   0%|          | 0/586 [00:00<?, ?it/s]

train prev epoch Loss: 10.503854105496977
train current epoch Loss: 10.386557709234973, lr = 0.00044276160629398497


Predicting ...:   0%|          | 0/196 [00:00<?, ?it/s]

val prev epoch Loss: 13.54538107891472
val current epoch Loss: 13.597115161467572, lr = 0.00044276160629398497
Epoch 7/9
----------


Predicting ...:   0%|          | 0/586 [00:00<?, ?it/s]

train prev epoch Loss: 10.386557709234973
train current epoch Loss: 10.280058279786093, lr = 0.00043028813754943855


Predicting ...:   0%|          | 0/196 [00:00<?, ?it/s]

val prev epoch Loss: 13.597115161467572
val current epoch Loss: 13.592829689687612, lr = 0.00043028813754943855
Epoch 8/9
----------


Predicting ...:   0%|          | 0/586 [00:00<?, ?it/s]

train prev epoch Loss: 10.280058279786093
train current epoch Loss: 10.17767597465385, lr = 0.0004188127637955846


Predicting ...:   0%|          | 0/196 [00:00<?, ?it/s]

val prev epoch Loss: 13.592829689687612
val current epoch Loss: 13.652900549830223, lr = 0.0004188127637955846
Epoch 9/9
----------


Predicting ...:   0%|          | 0/586 [00:00<?, ?it/s]

train prev epoch Loss: 10.17767597465385
train current epoch Loss: 10.082740886219533, lr = 0.0004082091042707


Predicting ...:   0%|          | 0/196 [00:00<?, ?it/s]

val prev epoch Loss: 13.652900549830223
val current epoch Loss: 13.701968183322828, lr = 0.0004082091042707

Training complete in 66m 29s
Best val loss: 13.398887


In [21]:
results = evaluate_model(model, test_dataloader, wmt_preprocessor, device = device)

Predicting ...:   0%|          | 0/47 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/60 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/60 [00:00<?, ?it/s]

 Result : 0.00933783408254385

Evaluating complete in 0m 9s


In [None]:
model.load_state_dict(torch.load(os.path.join("trained",filename)))

<All keys matched successfully>

In [34]:
for i,batch in enumerate(iter(train_dataloader)):

    inputs,targets = batch["input"],batch["target"]
    idx = np.random.randint(0,len(inputs))
    if i < 1:
        print(f"inputs batch shape: {inputs['encodings'].size()}")
        print(f"targets batch shape: {targets['decoder_input_encodings'].size()}")
        sample = (inputs["encodings"][0],targets["decoder_input_encodings"][0])

        print(f"Input English Sentence : {wmt_preprocessor.decode([sample[0].tolist()], unpad = True, idx2word = wmt_preprocessor.idx2word_en)}")
        print(f"Input German Sentence (Translated) : {wmt_preprocessor.decode([sample[1].tolist()], unpad = True, idx2word = wmt_preprocessor.idx2word_de)}")

        out = model(inputs, targets)
        print(out.shape)

        print(f"Predicted German Sentence (Translated) : {wmt_preprocessor.decode(torch.argmax(out[idx].unsqueeze(0),2).tolist(), unpad = True, idx2word = wmt_preprocessor.idx2word_de)}")

        loss = translationLoss(output = out, target = targets['target_encodings'].to(device), pad_index = wmt_preprocessor.PAD, label_smoothing = 0.2)
        print(loss)

        bleu_score = score(out, targets, wmt_preprocessor, kind = "bleu")
        print(bleu_score)

        break



inputs batch shape: torch.Size([64, 138])
targets batch shape: torch.Size([64, 146])


decoding ...:   0%|          | 0/1 [00:00<?, ?it/s]

Input English Sentence : [['the', 'fontana', 'resort', 'is', 'located', 'in', 'the', 'picturesque', 'town', 'of', '<unk>', 'on', 'the', 'island', 'of', 'hvar', ',', 'which', 'is', 'one', 'of', 'the', '10', 'most', 'beautiful', 'islands', 'in', 'the', 'world', ',', 'voted', 'by', 'the', 'conde', '<unk>', 'group', '&apos;s', 'traveller', 'magazine', '.', '<EOS>']]


decoding ...:   0%|          | 0/1 [00:00<?, ?it/s]

Input German Sentence (Translated) : [['<START>', 'das', 'resort', 'fontana', 'befindet', 'sich', 'im', 'malerischen', 'ort', '<unk>', 'auf', 'der', 'insel', 'hvar', ',', 'welche', 'zu', 'den', '10', 'schönsten', 'inseln', 'der', 'welt', 'zählt', 'laut', 'conde', '<unk>', 'group', '&apos;', 's', 'traveller', 'magazine', '.']]
torch.Size([64, 146, 45362])


decoding ...:   0%|          | 0/1 [00:00<?, ?it/s]

Predicted German Sentence (Translated) : [['diese', 'der', '<unk>', '##at##-##at##', 'apartment', 'bietet', 'gärten', 'und', 'de', '<unk>', 'und', 'sie', 'in', 'apartment', 'in', '.', 'und', '<unk>', ',', '<EOS>']]
tensor(5.1798, device='cuda:0', grad_fn=<AddBackward0>)


decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/64 [00:00<?, ?it/s]

0.10201694809248607


In [36]:
!cp -r './trained/transformer-weights' /content/drive/MyDrive/

In [41]:
infer(model.to(device), "at lease he can run", wmt_preprocessor, eos_idx = wmt_preprocessor.EOS, device = device)

tokenizing ...:   0%|          | 0/1 [00:00<?, ?it/s]

encoding ....:   0%|          | 0/1 [00:00<?, ?it/s]

padding ...:   0%|          | 0/1 [00:00<?, ?it/s]

decoding ...:   0%|          | 0/1 [00:00<?, ?it/s]

Input English Sentence : [['<START>', 'at', 'lease', 'he', 'can', 'run', '<EOS>']]
input_masks : torch.Size([1, 1, 139])
input_masks : tensor([[[ True,  True,  True,  True,  True,  True,  True, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False,

decoding ...:   0%|          | 0/1 [00:00<?, ?it/s]

[['<START>',
  'können',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk