In [1]:
import pickle
import torch
import torchtext.transforms as T
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

def loadFiles(file):
    with open(file, "rb") as file:
        data = pickle.load(file)
    print("The size of the dataset is:", len(data))
    return data


def separateData(data):
    X = data[:, 0]
    y = data[:, 1]
    return X, y

In [2]:
data = loadFiles(r'./english-german-both.pkl')

The size of the dataset is: 10000


In [3]:
eng, germ = separateData(data)

In [4]:
def findLongestSequence(sentList):
    return max(len(seq.split()) for seq in sentList)


In [5]:
enTokenizer = get_tokenizer('spacy', language='en')
deTokenizer = get_tokenizer('spacy', language='de')
def yieldTokensEn(data):
    for text in data:
        yield enTokenizer(text[:-1])

def yieldTokensDe(data):
    
    for text in data:
        yield deTokenizer(text[:-1])



In [6]:
vocabEn = build_vocab_from_iterator(yieldTokensEn(eng), specials=["<pad>", "<sos>", "<eos>", "<unk>"], special_first=True)
vocabDe = build_vocab_from_iterator(yieldTokensDe(germ), specials=["<pad>", "<sos>", "<eos>", "<unk>"], special_first=True)

In [7]:
textPipelineEn = lambda x: vocabEn(enTokenizer(x))
textPipelineDe = lambda x: vocabDe(deTokenizer(x))

In [8]:
from sklearn.model_selection import train_test_split

SEED = 42
trainEn, testEn, trainDe, testDe = train_test_split(eng, germ, test_size=0.1, random_state=SEED)
trainEn, valEn, trainDe, valDe = train_test_split(eng, germ, test_size=0.1, random_state=SEED)

In [69]:
BATCH_SIZE = 64
PAD_IDX = vocabEn(['<pad>'])[0]
SOS_IDX = vocabEn(['<sos>'])
EOS_IDX = vocabEn(['<eos>'])
print(PAD_IDX)
print(SOS_IDX)
print(EOS_IDX)

0
[1]
[2]


In [70]:
from torch.nn.utils.rnn import pad_sequence

def generateData(eng, deu):
    data = []
    for en, de in zip(eng, deu):
        enTensor = torch.tensor(textPipelineEn(en[:-1]), dtype=torch.long)
        deTensor = torch.tensor(textPipelineDe(de[:-1]), dtype=torch.long)
        data.append((enTensor, deTensor))
    return data


trainData = generateData(trainEn, trainDe)
valData = generateData(valEn, valDe)         
testData = generateData(testEn, testDe)        

In [71]:
from torch.utils.data import DataLoader

    
def generateBatch(data_batch):
  de_batch, en_batch = [], []
  # print(data_batch)
  for (de_item, en_item) in data_batch:
    de_batch.append(torch.cat([torch.tensor(SOS_IDX), de_item, torch.tensor(EOS_IDX)], dim=0))
    en_batch.append(torch.cat([torch.tensor(SOS_IDX), en_item, torch.tensor(EOS_IDX)], dim=0))
  deLength = len(de_batch)
  batch = pad_sequence(en_batch + de_batch, padding_value=PAD_IDX, batch_first=True)
  en_batch, de_batch = batch[:deLength], batch[deLength:]
  return de_batch, en_batch


trainIter = DataLoader(trainData, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generateBatch)
valIter =  DataLoader(valData, batch_size=BATCH_SIZE, shuffle=True, collate_fn=generateBatch)

In [72]:
from model.transformer import Transformer
from model.positenc import PositionalEncodingTorch
from torch.nn import Embedding, Module, Linear

EMB_DIM = 128
HEADS = 8
LINEAR_DIM = 512
DROPOUT = 0.1
LAYERS = 2
BETA_1 = 0.9
BETA_2 = 0.98
EPSILON = 10**-9
ENG_VOCAB_LEN = vocabEn.__len__()
DE_VOCAB_LEN = vocabDe.__len__()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"The english has {ENG_VOCAB_LEN} words.")
print(f"The german has {DE_VOCAB_LEN} words.")

EN_MAX_SEQ_LEN, DE_MAX_SEQ_LEN = findLongestSequence(eng) + 2 , findLongestSequence(germ) + 2

#mask for the decoder
def createMask(x):
    batch, seq_length, _ = x.size()
    mask = torch.ones((batch, seq_length, seq_length)).to(device)
    mask = torch.tril(mask, diagonal=0)
    return mask

def createMaskGen(x):
    x = x.unsqueeze(1)
    batch, _, seq_length = x.size()
    mask = torch.ones((batch, 1, seq_length)).to(device)
    maskR = torch.logical_and(mask, x)
    maskC = torch.logical_and(x.transpose(-1, 1), mask.transpose(-1, 1))
    mask = torch.multiply(maskR, maskC)
    return mask
    


class TransformerModel(Module):
    def __init__(self):
        super().__init__()
        self.embEn = Embedding(ENG_VOCAB_LEN, EMB_DIM)
        self.embDe = Embedding(DE_VOCAB_LEN, EMB_DIM)
        self.positEn = PositionalEncodingTorch(EN_MAX_SEQ_LEN, EMB_DIM)
        self.positDe = PositionalEncodingTorch(DE_MAX_SEQ_LEN, EMB_DIM)
        self.transformer = Transformer(LAYERS, EMB_DIM, EMB_DIM, HEADS, LINEAR_DIM, DROPOUT)
        self.linear = Linear(EMB_DIM, DE_VOCAB_LEN)
        
    def forward(self, eng, de, device=device, ret_att=False):
        mask1 = createMaskGen(eng)
        eng = self.embEn(eng)
        eng = self.positEn(eng, device=device)
        mask = createMaskGen(de)
        de = self.embDe(de)
        de = self.positDe(de, device=device)
        mask2 = createMask(de)
        mask2 = torch.logical_and(mask, mask2)
        softmax, dec = self.transformer(eng, de, mask1, mask2, return_att=ret_att)
        lin = self.linear(dec) 
        if ret_att:
            return softmax, lin   
        return lin
        






The english has 2594 words.
The german has 4167 words.


In [73]:
class TransformerLRScheduler(object):
    def __init__(self,  warmup_steps:int=10, d_model:int=512):
        self.warmup_steps = warmup_steps
        self.d_model = d_model
    
    
    def __call__(self, epoch):
        epoch = epoch + 1
        minimum = min(epoch**-0.5, epoch * ((self.warmup_steps) ** (-1.5)))
        return (self.d_model**-0.5 ) * minimum
    

In [74]:
def calculateAccuracy(prediction, target):
    padding_mask = torch.logical_not(torch.eq(target, torch.tensor(0)))
    accuracy = torch.eq(target, torch.argmax(prediction, axis=2))
    accuracy = torch.logical_and(padding_mask, accuracy)
    accuracy = accuracy.type(torch.float32)
    padding_mask = padding_mask.type(torch.float32)
    return torch.sum(accuracy) / torch.sum(padding_mask)
    

In [75]:
def validateModel(model, testIter, loss, device):
    lossPerBatch = []
    accPerBatch = []
    with torch.no_grad():
        model.eval()
        for i, (X, y) in enumerate(testIter):
            X, y = X.to(device), y.to(device)
            out = model(X[:, 1:], y[:, :-1])
            l = loss(out.contiguous().view(-1, 4167), y[:, 1:].contiguous().view(-1))
            a = calculateAccuracy(out, y[:, 1:])
            lossPerBatch.append(l.item())
            accPerBatch.append(a.item())
        meanLoss = sum(lossPerBatch)/len(lossPerBatch)
    return lossPerBatch, meanLoss, accPerBatch

In [76]:
from torch import optim
import torch.nn as nn
from torch.optim import lr_scheduler
model = TransformerModel()
model.to(device)
optimizer = optim.Adam(model.parameters(), betas=(BETA_1, BETA_2), eps=EPSILON)
loss = nn.CrossEntropyLoss(ignore_index=0)
scheduler = TransformerLRScheduler(20)

In [77]:
from tqdm import tqdm



def train(model, trainIter, 
          testIter=None, 
          epochs=None, 
          loss=None, 
          optimizer=None, 
          device=device, 
          scheduler=None):
    # model.to(device)
    # pbar = trange(epochs, desc="Epochs ", unit="batches")
    # with tqdm(trainIter, unit="epochs") as tepoch:
    logs_dic = {
        "valildationLoss": [],
        "trainingLoss" : [],
        "validationAccuracy": [],
        "trainingAccuracy": []
    }
    for epoch in range(epochs):
        trainLossPerBatch = []
        trainAccuracyPerBatch = []
        with tqdm(trainIter, unit="batches") as tepoch:
            for i, (X,y) in enumerate(tepoch):
                model.train()
                optimizer.zero_grad()
                X, y = X.to(device), y.to(device)
                out = model(X[:, 1:], y[:, :-1])
                l = loss(out.contiguous().view(-1, 4167), y[:, 1:].contiguous().view(-1))
                acc = calculateAccuracy(out, y[:, 1:])
                trainLossPerBatch.append(l.item())
                trainAccuracyPerBatch.append(acc.item())
                tepoch.set_description(f"Epoch {epoch + 1}")            
                tepoch.set_postfix(loss=l.item(), accuracy=acc.item())
                l.backward()
                optimizer.step()
            valLoss, meanValLoss, valAcc = validateModel(model, testIter, loss=loss, device=device)
            print(f"The validation loss is: {meanValLoss}")
            logs_dic["valildationLoss"].append(valLoss)
            logs_dic["trainingLoss"].append(trainLossPerBatch)
            logs_dic["trainingAccuracy"].append(trainAccuracyPerBatch)
            logs_dic["validationAccuracy"].append(valAcc)
            # print(f"Epoch: {epoch+1}     loss: {l}")
            if scheduler:
                if scheduler.__module__ == lr_scheduler.__name__:
                    scheduler.step()
                else:
                    for param_group in optimizer.param_groups:
                        lr = scheduler(epoch)
                        param_group['lr'] = lr
                
    return logs_dic
history = train(model, trainIter, testIter=valIter, epochs=100, loss=loss, optimizer=optimizer, device=device, scheduler=scheduler)

Epoch 1:   0%|          | 0/141 [00:00<?, ?batches/s, accuracy=0, loss=8.29]

Epoch 1: 100%|██████████| 141/141 [00:09<00:00, 15.45batches/s, accuracy=0.5, loss=3.43]  


The validation loss is: 3.6708921790122986


Epoch 2: 100%|██████████| 141/141 [00:09<00:00, 14.88batches/s, accuracy=0.574, loss=2.82]


The validation loss is: 3.1927057802677155


Epoch 3: 100%|██████████| 141/141 [00:09<00:00, 15.21batches/s, accuracy=0.579, loss=2.74]


The validation loss is: 2.8250343948602676


Epoch 4: 100%|██████████| 141/141 [00:08<00:00, 16.01batches/s, accuracy=0.649, loss=2.19]


The validation loss is: 2.5116057693958282


Epoch 5: 100%|██████████| 141/141 [00:08<00:00, 15.70batches/s, accuracy=0.577, loss=2.43]


The validation loss is: 2.318987302482128


Epoch 6: 100%|██████████| 141/141 [00:09<00:00, 14.83batches/s, accuracy=0.716, loss=1.45]


The validation loss is: 2.1574486270546913


Epoch 7: 100%|██████████| 141/141 [00:09<00:00, 14.91batches/s, accuracy=0.661, loss=1.58]


The validation loss is: 2.108462393283844


Epoch 8: 100%|██████████| 141/141 [00:09<00:00, 15.25batches/s, accuracy=0.739, loss=1.23] 


The validation loss is: 2.0209746211767197


Epoch 9: 100%|██████████| 141/141 [00:09<00:00, 15.17batches/s, accuracy=0.788, loss=0.941]


The validation loss is: 2.2718184366822243


Epoch 10: 100%|██████████| 141/141 [00:09<00:00, 15.51batches/s, accuracy=0.781, loss=0.963]


The validation loss is: 2.263396233320236


Epoch 11: 100%|██████████| 141/141 [00:09<00:00, 15.39batches/s, accuracy=0.696, loss=1.52] 


The validation loss is: 2.3549006655812263


Epoch 12: 100%|██████████| 141/141 [00:09<00:00, 15.37batches/s, accuracy=0.706, loss=1.25] 


The validation loss is: 2.405008926987648


Epoch 13: 100%|██████████| 141/141 [00:09<00:00, 15.54batches/s, accuracy=0.696, loss=1.27] 


The validation loss is: 2.4056278243660927


Epoch 14: 100%|██████████| 141/141 [00:09<00:00, 15.65batches/s, accuracy=0.731, loss=1.38] 


The validation loss is: 2.597766488790512


Epoch 15: 100%|██████████| 141/141 [00:09<00:00, 15.64batches/s, accuracy=0.827, loss=0.826]


The validation loss is: 2.4417013376951218


Epoch 16: 100%|██████████| 141/141 [00:09<00:00, 15.34batches/s, accuracy=0.77, loss=1.09]  


The validation loss is: 2.4305384978652


Epoch 17: 100%|██████████| 141/141 [00:09<00:00, 15.19batches/s, accuracy=0.772, loss=0.881]


The validation loss is: 2.513243466615677


Epoch 18: 100%|██████████| 141/141 [00:08<00:00, 15.74batches/s, accuracy=0.753, loss=1.17] 


The validation loss is: 2.4118078127503395


Epoch 19: 100%|██████████| 141/141 [00:09<00:00, 15.57batches/s, accuracy=0.805, loss=0.757]


The validation loss is: 2.423986181616783


Epoch 20: 100%|██████████| 141/141 [00:09<00:00, 15.23batches/s, accuracy=0.742, loss=1.1]  


The validation loss is: 2.39555025100708


Epoch 21: 100%|██████████| 141/141 [00:09<00:00, 15.22batches/s, accuracy=0.799, loss=0.977]


The validation loss is: 2.43848218023777


Epoch 22: 100%|██████████| 141/141 [00:08<00:00, 15.82batches/s, accuracy=0.744, loss=1.07] 


The validation loss is: 2.4493724033236504


Epoch 23: 100%|██████████| 141/141 [00:09<00:00, 15.34batches/s, accuracy=0.855, loss=0.776]


The validation loss is: 2.3991695418953896


Epoch 24: 100%|██████████| 141/141 [00:09<00:00, 15.00batches/s, accuracy=0.893, loss=0.43] 


The validation loss is: 2.3736437633633614


Epoch 25: 100%|██████████| 141/141 [00:09<00:00, 15.62batches/s, accuracy=0.855, loss=0.622]


The validation loss is: 2.363986909389496


Epoch 26: 100%|██████████| 141/141 [00:09<00:00, 15.66batches/s, accuracy=0.897, loss=0.405]


The validation loss is: 2.454198196530342


Epoch 27: 100%|██████████| 141/141 [00:09<00:00, 15.44batches/s, accuracy=0.886, loss=0.387]


The validation loss is: 2.43957556784153


Epoch 28: 100%|██████████| 141/141 [00:09<00:00, 15.13batches/s, accuracy=0.864, loss=0.48] 


The validation loss is: 2.475130155682564


Epoch 29: 100%|██████████| 141/141 [00:09<00:00, 15.54batches/s, accuracy=0.853, loss=0.599]


The validation loss is: 2.3790031522512436


Epoch 30: 100%|██████████| 141/141 [00:09<00:00, 15.07batches/s, accuracy=0.89, loss=0.334] 


The validation loss is: 2.438953571021557


Epoch 31: 100%|██████████| 141/141 [00:09<00:00, 15.51batches/s, accuracy=0.924, loss=0.291]


The validation loss is: 2.435166008770466


Epoch 32: 100%|██████████| 141/141 [00:09<00:00, 15.26batches/s, accuracy=0.895, loss=0.447]


The validation loss is: 2.408282719552517


Epoch 33: 100%|██████████| 141/141 [00:09<00:00, 15.45batches/s, accuracy=0.916, loss=0.268]


The validation loss is: 2.3945429623126984


Epoch 34: 100%|██████████| 141/141 [00:09<00:00, 14.92batches/s, accuracy=0.882, loss=0.358]


The validation loss is: 2.440194569528103


Epoch 35: 100%|██████████| 141/141 [00:09<00:00, 15.42batches/s, accuracy=0.852, loss=0.445]


The validation loss is: 2.4954653829336166


Epoch 36: 100%|██████████| 141/141 [00:09<00:00, 15.25batches/s, accuracy=0.932, loss=0.179]


The validation loss is: 2.4449894055724144


Epoch 37: 100%|██████████| 141/141 [00:08<00:00, 15.75batches/s, accuracy=0.91, loss=0.256] 


The validation loss is: 2.4414227083325386


Epoch 38: 100%|██████████| 141/141 [00:09<00:00, 15.04batches/s, accuracy=0.905, loss=0.277]


The validation loss is: 2.4296866357326508


Epoch 39: 100%|██████████| 141/141 [00:09<00:00, 15.41batches/s, accuracy=0.895, loss=0.408]


The validation loss is: 2.4206850305199623


Epoch 40: 100%|██████████| 141/141 [00:08<00:00, 15.81batches/s, accuracy=0.915, loss=0.312] 


The validation loss is: 2.4508455470204353


Epoch 41: 100%|██████████| 141/141 [00:09<00:00, 15.33batches/s, accuracy=0.893, loss=0.28] 


The validation loss is: 2.446577861905098


Epoch 42: 100%|██████████| 141/141 [00:09<00:00, 15.22batches/s, accuracy=0.948, loss=0.166]


The validation loss is: 2.484200395643711


Epoch 43: 100%|██████████| 141/141 [00:09<00:00, 15.46batches/s, accuracy=0.909, loss=0.242] 


The validation loss is: 2.4314692839980125


Epoch 44: 100%|██████████| 141/141 [00:09<00:00, 15.19batches/s, accuracy=0.918, loss=0.297]


The validation loss is: 2.51539858430624


Epoch 45: 100%|██████████| 141/141 [00:09<00:00, 15.34batches/s, accuracy=0.897, loss=0.434] 


The validation loss is: 2.4533243849873543


Epoch 46: 100%|██████████| 141/141 [00:09<00:00, 15.08batches/s, accuracy=0.956, loss=0.0966]


The validation loss is: 2.3826469853520393


Epoch 47: 100%|██████████| 141/141 [00:09<00:00, 15.34batches/s, accuracy=0.938, loss=0.177] 


The validation loss is: 2.429791674017906


Epoch 48: 100%|██████████| 141/141 [00:09<00:00, 15.34batches/s, accuracy=0.967, loss=0.107] 


The validation loss is: 2.3964041993021965


Epoch 49: 100%|██████████| 141/141 [00:09<00:00, 15.08batches/s, accuracy=0.951, loss=0.161] 


The validation loss is: 2.441951848566532


Epoch 50: 100%|██████████| 141/141 [00:09<00:00, 15.01batches/s, accuracy=0.941, loss=0.235] 


The validation loss is: 2.3985691517591476


Epoch 51: 100%|██████████| 141/141 [00:09<00:00, 14.52batches/s, accuracy=0.95, loss=0.151]  


The validation loss is: 2.416132867336273


Epoch 52: 100%|██████████| 141/141 [00:09<00:00, 14.14batches/s, accuracy=0.96, loss=0.11]   


The validation loss is: 2.428842380642891


Epoch 53: 100%|██████████| 141/141 [00:09<00:00, 14.47batches/s, accuracy=0.972, loss=0.146] 


The validation loss is: 2.4510183706879616


Epoch 54: 100%|██████████| 141/141 [00:09<00:00, 14.71batches/s, accuracy=0.954, loss=0.185] 


The validation loss is: 2.4140170514583588


Epoch 55: 100%|██████████| 141/141 [00:09<00:00, 14.73batches/s, accuracy=0.95, loss=0.149]  


The validation loss is: 2.432015120983124


Epoch 56: 100%|██████████| 141/141 [00:09<00:00, 14.94batches/s, accuracy=0.972, loss=0.103] 


The validation loss is: 2.4406226724386215


Epoch 57: 100%|██████████| 141/141 [00:09<00:00, 15.31batches/s, accuracy=0.949, loss=0.176]


The validation loss is: 2.4900977090001106


Epoch 58: 100%|██████████| 141/141 [00:09<00:00, 14.75batches/s, accuracy=0.955, loss=0.106]


The validation loss is: 2.4943379536271095


Epoch 59: 100%|██████████| 141/141 [00:09<00:00, 14.56batches/s, accuracy=0.972, loss=0.105] 


The validation loss is: 2.5057672634720802


Epoch 60: 100%|██████████| 141/141 [00:09<00:00, 15.28batches/s, accuracy=0.939, loss=0.183] 


The validation loss is: 2.4684274941682816


Epoch 61: 100%|██████████| 141/141 [00:09<00:00, 15.07batches/s, accuracy=0.706, loss=1.35]  


The validation loss is: 2.442855179309845


Epoch 62: 100%|██████████| 141/141 [00:09<00:00, 14.83batches/s, accuracy=0.988, loss=0.0666]


The validation loss is: 2.4079596176743507


Epoch 63: 100%|██████████| 141/141 [00:09<00:00, 14.50batches/s, accuracy=0.962, loss=0.108] 


The validation loss is: 2.4858641177415848


Epoch 64: 100%|██████████| 141/141 [00:09<00:00, 14.58batches/s, accuracy=0.941, loss=0.279] 


The validation loss is: 2.4442346319556236


Epoch 65: 100%|██████████| 141/141 [00:09<00:00, 14.92batches/s, accuracy=0.989, loss=0.0735]


The validation loss is: 2.458151251077652


Epoch 66: 100%|██████████| 141/141 [00:09<00:00, 14.38batches/s, accuracy=0.914, loss=0.252] 


The validation loss is: 2.4854013845324516


Epoch 67: 100%|██████████| 141/141 [00:09<00:00, 14.72batches/s, accuracy=0.955, loss=0.141] 


The validation loss is: 2.4065169245004654


Epoch 68: 100%|██████████| 141/141 [00:09<00:00, 14.51batches/s, accuracy=0.96, loss=0.166]  


The validation loss is: 2.3798323422670364


Epoch 69: 100%|██████████| 141/141 [00:09<00:00, 14.62batches/s, accuracy=0.983, loss=0.0701]


The validation loss is: 2.5471760779619217


Epoch 70: 100%|██████████| 141/141 [00:09<00:00, 15.28batches/s, accuracy=0.942, loss=0.149] 


The validation loss is: 2.4613746106624603


Epoch 71: 100%|██████████| 141/141 [00:09<00:00, 14.73batches/s, accuracy=0.947, loss=0.163] 


The validation loss is: 2.432933948934078


Epoch 72: 100%|██████████| 141/141 [00:09<00:00, 15.04batches/s, accuracy=0.953, loss=0.187] 


The validation loss is: 2.4253192394971848


Epoch 73: 100%|██████████| 141/141 [00:09<00:00, 15.20batches/s, accuracy=0.989, loss=0.0484]


The validation loss is: 2.4780069068074226


Epoch 74: 100%|██████████| 141/141 [00:09<00:00, 15.40batches/s, accuracy=0.989, loss=0.0313]


The validation loss is: 2.3607506453990936


Epoch 75: 100%|██████████| 141/141 [00:09<00:00, 14.96batches/s, accuracy=0.949, loss=0.16]  


The validation loss is: 2.388797625899315


Epoch 76: 100%|██████████| 141/141 [00:09<00:00, 14.96batches/s, accuracy=0.83, loss=0.749]  


The validation loss is: 2.3817024007439613


Epoch 77: 100%|██████████| 141/141 [00:09<00:00, 14.80batches/s, accuracy=0.955, loss=0.13]  


The validation loss is: 2.372066468000412


Epoch 78: 100%|██████████| 141/141 [00:09<00:00, 15.15batches/s, accuracy=0.973, loss=0.117] 


The validation loss is: 2.3921579718589783


Epoch 79: 100%|██████████| 141/141 [00:09<00:00, 14.58batches/s, accuracy=0.966, loss=0.0929]


The validation loss is: 2.4209439381957054


Epoch 80: 100%|██████████| 141/141 [00:09<00:00, 15.07batches/s, accuracy=0.983, loss=0.0302]


The validation loss is: 2.402333326637745


Epoch 81: 100%|██████████| 141/141 [00:09<00:00, 14.85batches/s, accuracy=1, loss=0.0161]    


The validation loss is: 2.4163745045661926


Epoch 82: 100%|██████████| 141/141 [00:09<00:00, 15.01batches/s, accuracy=0.768, loss=1.06]  


The validation loss is: 2.4674023166298866


Epoch 83: 100%|██████████| 141/141 [00:09<00:00, 14.81batches/s, accuracy=0.989, loss=0.0358]


The validation loss is: 2.420039303600788


Epoch 84: 100%|██████████| 141/141 [00:09<00:00, 14.80batches/s, accuracy=0.968, loss=0.0736]


The validation loss is: 2.4660540744662285


Epoch 85: 100%|██████████| 141/141 [00:09<00:00, 15.05batches/s, accuracy=0.96, loss=0.12]   


The validation loss is: 2.4337507858872414


Epoch 86: 100%|██████████| 141/141 [00:09<00:00, 15.17batches/s, accuracy=0.994, loss=0.0168]


The validation loss is: 2.488387629389763


Epoch 87: 100%|██████████| 141/141 [00:09<00:00, 14.91batches/s, accuracy=0.978, loss=0.0491]


The validation loss is: 2.4742741510272026


Epoch 88: 100%|██████████| 141/141 [00:09<00:00, 14.84batches/s, accuracy=0.969, loss=0.114] 


The validation loss is: 2.503173589706421


Epoch 89: 100%|██████████| 141/141 [00:09<00:00, 14.80batches/s, accuracy=0.968, loss=0.0843]


The validation loss is: 2.487916871905327


Epoch 90: 100%|██████████| 141/141 [00:09<00:00, 15.23batches/s, accuracy=0.984, loss=0.0782]


The validation loss is: 2.47575431317091


Epoch 91: 100%|██████████| 141/141 [00:09<00:00, 15.26batches/s, accuracy=0.984, loss=0.0614]


The validation loss is: 2.4763273373246193


Epoch 92: 100%|██████████| 141/141 [00:09<00:00, 14.97batches/s, accuracy=0.984, loss=0.0456]


The validation loss is: 2.4875194802880287


Epoch 93: 100%|██████████| 141/141 [00:09<00:00, 14.98batches/s, accuracy=0.994, loss=0.0225]


The validation loss is: 2.504108637571335


Epoch 94: 100%|██████████| 141/141 [00:09<00:00, 15.14batches/s, accuracy=0.995, loss=0.0184]


The validation loss is: 2.5152123793959618


Epoch 95: 100%|██████████| 141/141 [00:09<00:00, 14.92batches/s, accuracy=0.978, loss=0.0806]


The validation loss is: 2.4345881417393684


Epoch 96: 100%|██████████| 141/141 [00:09<00:00, 14.59batches/s, accuracy=0.99, loss=0.0265] 


The validation loss is: 2.560623422265053


Epoch 97: 100%|██████████| 141/141 [00:09<00:00, 14.57batches/s, accuracy=0.994, loss=0.0282]


The validation loss is: 2.468105398118496


Epoch 98: 100%|██████████| 141/141 [00:09<00:00, 14.52batches/s, accuracy=0.977, loss=0.0937]


The validation loss is: 2.514083608984947


Epoch 99: 100%|██████████| 141/141 [00:09<00:00, 14.57batches/s, accuracy=1, loss=0.0164]    


The validation loss is: 2.528106354176998


Epoch 100: 100%|██████████| 141/141 [00:09<00:00, 14.71batches/s, accuracy=0.971, loss=0.0748]


The validation loss is: 2.4915659576654434


In [78]:
def saveHistory(history, filename):
    print("pickling history.")
    with open(filename, 'wb') as fp:
        pickle.dump(history, fp)
    print("successfully pickled") 
    
saveHistory(history, "./history2")   

pickling history.
successfully pickled


In [18]:
# %matplotlib inline
# import matplotlib.pyplot as plt

# lr = TransformerLRScheduler(1000)
# plt.plot([i for i in range(5000)], [lr(i) for i in range(5000)])
# plt.grid()
# plt.xlabel('Epochs')
# plt.ylabel("Learning Rate")

In [79]:
torch.save(model.state_dict(), '2.pth')


In [81]:
model.load_state_dict(torch.load('2.pth', map_location='cpu'))

def calculateAccuracyTest(prediction, target):
    paddingMask = torch.logical_not(torch.eq(target, torch.tensor(0)))
    # print(target, prediction)
    targLen = len(target[0, :])
    accuracy = torch.eq(target,prediction[:, :targLen])
    accuracy = torch.logical_and(paddingMask, accuracy)
    accuracy = accuracy.type(torch.float32)
    paddingMask = paddingMask.type(torch.float32)
    return torch.sum(accuracy) / torch.sum(paddingMask)



def test(model, encInput):
    model.eval()
    batch, seq_length = encInput.size()
    out = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    decOutput =  torch.tensor(out)
    decOutput = torch.tile(decOutput, (batch, 1))
    for i in range(DE_MAX_SEQ_LEN-1):
        prediction = model(encInput.to(device)[:, 1:], decOutput.to(device))
        
        prediction = prediction[:, i, :]
        predictedId = torch.argmax(prediction, dim=-1)
        decOutput[:, i+1] = predictedId[ :]
        if predictedId.tolist()[0] == EOS_IDX[0]:
          break
    return decOutput
testIter = DataLoader(testData, batch_size=1, shuffle=True, collate_fn=generateBatch)



for i, (X, Y) in enumerate(testIter):
    pred = test(model, X)
    for x, y, z in zip(pred.tolist(), X.tolist(), Y.tolist()):
      print(' '.join(vocabDe.lookup_tokens(x)))
      print(' '.join(vocabEn.lookup_tokens(y)))
      print(' '.join(vocabDe.lookup_tokens(z)))
      print("The accuracy for this prediction is:", round(calculateAccuracyTest(pred, Y).item(), 4))
      print("\n")
      
    if i == 5:
      break

<sos> ich hab <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<sos> i m so excite <eos> <pad>
<sos> ich bin ja so aufgereg <eos>
The accuracy for this prediction is: 0.2857


<sos> sollen wir beginne a <eos> <pad> <pad> <pad> <pad> <pad> <pad>
<sos> shall we star <eos>
<sos> sollen wir anfange <eos>
The accuracy for this prediction is: 0.6


<sos> wie er si <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<sos> how beautifu <eos>
<sos> wie scho <eos>
The accuracy for this prediction is: 0.5


<sos> ich kenne <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<sos> i know your so <eos>
<sos> ich kenne ihren soh <eos>
The accuracy for this prediction is: 0.5


<sos> du hast <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<sos> you should res <eos> <pad>
<sos> sie sollten sich ausruhe <eos>
The accuracy for this prediction is: 0.1667


<sos> auto schmeckt <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
<sos> is this car ne <eos>
<sos> ist dieser wagen ne <eos>
The accu

with open("history64h4l1", 'rb') as fp:
    logs = pickle.load(fp)
    


In [20]:
import matplotlib.pyplot as plt

plt.style.available

['Solarize_Light2',
 '_classic_test_patch',
 '_mpl-gallery',
 '_mpl-gallery-nogrid',
 'bmh',
 'classic',
 'dark_background',
 'fast',
 'fivethirtyeight',
 'ggplot',
 'grayscale',
 'seaborn-v0_8',
 'seaborn-v0_8-bright',
 'seaborn-v0_8-colorblind',
 'seaborn-v0_8-dark',
 'seaborn-v0_8-dark-palette',
 'seaborn-v0_8-darkgrid',
 'seaborn-v0_8-deep',
 'seaborn-v0_8-muted',
 'seaborn-v0_8-notebook',
 'seaborn-v0_8-paper',
 'seaborn-v0_8-pastel',
 'seaborn-v0_8-poster',
 'seaborn-v0_8-talk',
 'seaborn-v0_8-ticks',
 'seaborn-v0_8-white',
 'seaborn-v0_8-whitegrid',
 'tableau-colorblind10']