In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.datasets import IWSLT
from torchtext.data import Field, BucketIterator
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.nn import TransformerDecoder, TransformerDecoderLayer

import torch.nn.functional as F

import numpy as np
import spacy
import random

In [None]:
import spacy.cli
import en_core_web_sm
import de_core_news_sm


spacy.cli.download("en_core_web_sm")
spacy.cli.download("de_core_news_sm")


spacy_ger = de_core_news_sm.load()
spacy_eng = en_core_web_sm.load()

In [None]:
def tokenizer_de(text):
  return [tok.text for tok in spacy_ger.tokenizer(text)]

def tokenizer_eng(text):
  return [tok.text for tok in spacy_eng.tokenizer(text)]

In [None]:
german = Field(tokenize=tokenizer_de, lower=True, eos_token="<eos>")

english = Field(
    tokenize=tokenizer_eng, lower=True, init_token="<sos>", eos_token="<eos>"
)

In [None]:
train_data, valid_data, test_data = IWSLT.splits(
    exts=(".de", ".en"), fields=(german, english)
)

In [None]:
german.build_vocab(train_data, max_size=10000, min_freq=2)
english.build_vocab(train_data, max_size=10000, min_freq=2)

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.scale = nn.Parameter(torch.ones(1))

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(
            0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.scale * self.pe[:x.size(0), :]
        return self.dropout(x)

In [None]:
class TransformerModel(nn.Module):
    
    def __init__(self, intoken, outtoken ,hidden, enc_layers=2, dec_layers=2, dropout=.1, nheads=2, ff_model=128, ts_unique=70):
        super(TransformerModel, self).__init__()
        
        self.encoder = nn.Embedding(intoken, hidden)
        self.pos_encoder = PositionalEncoding(hidden, dropout)

        self.decoder = nn.Embedding(outtoken, hidden) 
        self.pos_decoder = PositionalEncoding(hidden, dropout)
        
        
        encoder_layers = TransformerEncoderLayer(d_model=hidden, nhead = nheads, dim_feedforward = ff_model, dropout=dropout, activation='relu')
        self.transformer_encoder = TransformerEncoder(encoder_layers, enc_layers)

        encoder_layers = TransformerDecoderLayer(hidden, nheads, ff_model, dropout, activation='relu')
        self.transformer_decoder = TransformerDecoder(encoder_layers, dec_layers)        

        self.fc_out = nn.Linear(hidden, outtoken)

        self.src_mask = None
        self.trg_mask = None
        self.memory_mask = None

        
    def generate_square_subsequent_mask(self, sz, sz1=None):
        
        if sz1 == None:
            mask = torch.triu(torch.ones(sz, sz), 1)
        else:
            mask = torch.triu(torch.ones(sz, sz1), 1)
            
        return mask.masked_fill(mask==1, float('-inf'))

    def make_len_mask_enc(self, inp):
        return (inp == de_pad_idx).transpose(0, 1)   #(batch_size, output_seq_len)
    
    def make_len_mask_dec(self, inp):
        return (inp == en_pad_idx).transpose(0, 1) #(batch_size, input_seq_len)
    


    def forward(self, src, trg): #SRC: (seq_len, batch_size)

        if self.trg_mask is None or self.trg_mask.size(0) != len(trg):
            self.trg_mask = self.generate_square_subsequent_mask(len(trg)).to(trg.device)
            
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            self.src_mask = self.generate_square_subsequent_mask(len(src)).to(trg.device)
            
        if self.memory_mask is None or self.memory_mask.size(0) != len(trg) or self.memory_mask.size(1) != len(src):
            self.memory_mask = self.generate_square_subsequent_mask(len(trg),len(src)).to(trg.device)
            

        #Adding padding mask
        src_pad_mask = self.make_len_mask_enc(src)
        trg_pad_mask = self.make_len_mask_dec(trg)
             

        #Add embeddings Encoder
        src = self.encoder(src)  #Embedding, (seq_len, batch_size, d_model)
        src = self.pos_encoder(src)   #Pos embedding
        
        
        #Add embedding decoder
        trg = self.decoder(trg) #(seq_len, batch_size, d_model)
        trg = self.pos_decoder(trg)

        memory = self.transformer_encoder(src, self.src_mask, src_pad_mask)
        output = self.transformer_decoder(tgt = trg, memory = memory, tgt_mask = self.trg_mask, memory_mask = self.memory_mask, 
                                          tgt_key_padding_mask = trg_pad_mask, memory_key_padding_mask = src_pad_mask)

        output = self.fc_out(output)

        return output

In [None]:
#Training Hyperparameters
num_epochs = 100
lr = 3e-4
batch_size = 16

In [None]:
#Model Hyperparameter
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_size_encoder = len(german.vocab)
input_size_decoder = len(english.vocab)
output_size = len(english.vocab)

d_model = 128

import math 

model = TransformerModel(input_size_encoder, input_size_decoder ,d_model, enc_layers=2, dec_layers=2, dropout=.1, nheads=2, ff_model=128, ts_unique=70).to(device)

In [None]:
step = 0

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=batch_size,
    sort_within_batch=True,
    sort_key=lambda x: len(x.src),
    device=device,
)

In [None]:
en_pad_idx = english.vocab.stoi['<pad>']
de_pad_idx = german.vocab.stoi['<pad>']

criterion = nn.CrossEntropyLoss(ignore_index=en_pad_idx)
optimizer = optim.Adam(model.parameters(), lr=lr)

In [None]:
def translate_sentence(model, sentence, german, english, device, max_length=50):
    model.eval()
    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    if type(sentence) == str:
        tokens = [token.text.lower() for token in spacy_ger(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    # Add <SOS> and <EOS> in beginning and end respectively
    tokens.append(german.eos_token)

    # Go through each german token and convert to an index
    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    preds = [english.vocab.stoi[english.init_token]]

    with torch.no_grad():
        
        emb_src = model.encoder(sentence_tensor)
        emb_src = model.pos_encoder(emb_src)

        memory = model.transformer_encoder(emb_src)

        global aaa 
        aaa = memory
        for i in range(50):

            trg = torch.Tensor(preds).long().unsqueeze(1).to(device)
            trg = model.decoder(trg)
            trg = model.pos_decoder(trg)

            out = model.transformer_decoder(tgt = trg, memory = memory)
            out = model.fc_out(out)
            
            

            new = out.squeeze(1)[-1].argmax().item()
            preds.append(new)
            if new == english.vocab.stoi["<eos>"]:
                break

    
    return [english.vocab.itos[i] for i in preds]

In [None]:
sentence = "ein boot mit mehreren männern darauf wird von einem großen pferdegespann ans ufer gezogen."
total_loss = 0

for epoch in range(num_epochs):
    
    print(f'Epoch [{epoch} / {num_epochs}]')

    #checkpoint = {'state_dict': model.state_dict(), optimizer: optimizer.state_dict()}
    #save_checkpoint(checkpoint)

    model.eval()
    translated_sentence = translate_sentence(model, sentence, german, english, device, max_length=50)
    
    out = beam("ein Mann, der einen Vogel sieht.", 3)   
    
    print(f"Translated example sentence: \n {list(map(convert, out))}")
    print(total_loss/len(train_iterator))
    total_loss = 0
    
    model.train()

    for batch_idx, batch in enumerate(train_iterator):
        
        inp_data = batch.src.to(device)
        target = batch.trg.to(device)

        output = model(inp_data, target[:-1, ])
        output = output.reshape(-1, output.shape[2])
        target = target[1:].reshape(-1)

        optimizer.zero_grad()
        loss = criterion(output, target)
        print(loss)
        total_loss += loss
        
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()

        step = step + 1
        
    

In [20]:
translate_sentence(model, "ein Mann, der einen Vogel sieht", german, english, device, max_length=50)

['<sos>', 'a', 'bird', 'is', 'looking', 'at', 'a', 'bird', '.', '<eos>']

In [None]:
def get_out_encoder(src):
    
    model.eval()
    tokens = [token.text.lower() for token in spacy_ger(src)]

    tokens.append(german.eos_token)

    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)    

    with torch.no_grad():
        
        emb_src = model.encoder(sentence_tensor)
        emb_src = model.pos_encoder(emb_src)

        memory = model.transformer_encoder(emb_src)

        return memory

In [None]:
def beam(phrase, k):
    
    model.eval()
    memory = get_out_encoder(phrase)

    sos = english.vocab.stoi["<sos>"]
    tgt = [sos]

    with torch.no_grad():

        trg = torch.Tensor(tgt).long().unsqueeze(1).to(device)
        trg = model.decoder(trg)
        trg = model.pos_decoder(trg)

        out = model.transformer_decoder(tgt = trg, memory = memory)
        out = F.softmax(model.fc_out(out), dim=-1)[-1].squeeze()

        args = out.argsort()[-k:].detach().cpu().numpy()
        probs = out[args].detach().cpu().numpy()

        probs = np.log(probs)
        possible = list(zip([tgt + [args[i]] for i in range(k)], probs))
        
        for i in range(50):

            test=  []
            for j in range(k):

                tmp_tgt, tmp_prob = possible[j]

                if tmp_tgt[-1] == english.vocab.stoi["<eos>"]:
                    test.append(possible[j])

                else:
                    trg = torch.Tensor(tmp_tgt).long().unsqueeze(1).to(device)
                    trg = model.decoder(trg)
                    trg = model.pos_decoder(trg)

                    out = model.transformer_decoder(tgt = trg, memory = memory)
                    out = F.softmax(model.fc_out(out), dim=-1)[-1].squeeze()

                    tmp_args = out.argsort()[-k:].detach().cpu().numpy()
                    tmp_probs = out[tmp_args].detach().cpu().numpy()
                    tmp_probs = (tmp_prob + np.log(tmp_probs))/(len(tmp_tgt)-1)

                    for r in range(k): 
                        test.append((tmp_tgt + [tmp_args[r]], tmp_probs[r]))


            possible = sorted(test, key=lambda x:x[1], reverse=True)[:k]
            
    return possible

In [None]:
def convert(x):
    
    sentence = x[0]
    sentence = [english.vocab.itos[i] for i in sentence]
    
    return (" ".join(sentence), x[1])

In [19]:
%%time
out =  beam("ich habe essen.", 10)  
list(map(convert, out))

Wall time: 457 ms


[("<sos> so , i 've been a lot . i 've got a lot . i was a lot . <eos>",
  -0.0089537585),
 ("<sos> so , i 've been a lot . i 've got a lot . <eos>", -0.014688058),
 ("<sos> so , i 've got a lot . i was a lot . <eos>", -0.015588673),
 ("<sos> now . i 've got a lot . <eos>", -0.04200139),
 ("<sos> so , i 've been a lot . <eos>", -0.051831104),
 ("<sos> so , i 've got a lot . <eos>", -0.053878494),
 ("<sos> so i 've got a lot . <eos>", -0.055535752),
 ("<sos> so i 've been a lot . <eos>", -0.055751204),
 ("<sos> and i 've got a lot . <eos>", -0.060735006),
 ("<sos> so i 've got to be . <eos>", -0.09270295)]

In [42]:
translate_sentence(model, "ich habe Hunger.", german, english, device, max_length=50)

['<sos>', 'the', '<unk>', 'of', '<unk>', '<unk>', '.', '<eos>']

In [21]:
next(iter(train_iterator)).src

tensor([[ 186,   33,    9,    6,    5,    8,  259,    7,  339,   13,    0,    5,
            6,    5,  449,   61],
        [  12,   12, 1482,  627,   60,  271,   32,   40,    7,  379,  364,   52,
            0,   13,    3,   41],
        [  95,   47,  163, 1014,  551,    3,   12,   14,   26,   20,   70,  113,
          348,   87,   14,   19],
        [1479,    8,  694,   28,  116,   30,  143,   10,   61,  228,  433,    0,
         8091,   94,   59,  735],
        [   3,   36,    3,    3,   12, 5466,   46,  527,  140,    3,   23,   50,
          348, 8543,    0,   15],
        [ 427,    0,    5,   21,    6,    6,    0,    3,  148,  156,   56,   12,
         5202,    3,  557,  207],
        [  12, 5424,   36,    6, 1382,    0, 6255,    6,    6,   30,  159,  333,
          839,   21,    7,    3],
        [ 555,  469,    9,  250, 7118,    3,    3,  306, 1225,   14,   57,   16,
           80, 7850,   46,    6],
        [   0,   49,   22, 1222,    3,    5,   10,  344,   39,   56,    6,  165,