In [1]:
import os
import random
import sys
import warnings

import numpy as np
import spacy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torchtext.data import Field, BucketIterator
from torchtext.datasets import Multi30k#IWSLT

warnings.filterwarnings("ignore")

In [2]:
# Translation direction. The codes are also used as the Multi30k file
# extensions (".en" / ".de") when the dataset splits are loaded.
CFG = dict(IN_LANG="en", OUT_LANG="de")

In [3]:
import spacy.cli
import en_core_web_sm
import de_core_news_sm


# Fetch both spaCy pipelines (English and German).
spacy.cli.download("en_core_web_sm")
spacy.cli.download("de_core_news_sm")

# Choose which pipeline tokenizes the source side and which the target side:
# English -> German when IN_LANG is "en", otherwise German -> English.
_in_pkg, _out_pkg = (
    (en_core_web_sm, de_core_news_sm)
    if CFG["IN_LANG"] == "en"
    else (de_core_news_sm, en_core_web_sm)
)
spacy_in_lang = _in_pkg.load()
spacy_out_lang = _out_pkg.load()
    

✔ Download and installation successful
You can now load the model via spacy.load('en_core_web_sm')
✔ Download and installation successful
You can now load the model via spacy.load('de_core_news_sm')


In [4]:
def tokenizer_in(text):
    """Split ``text`` with the input-language spaCy tokenizer and return the token strings."""
    pieces = []
    for tok in spacy_in_lang.tokenizer(text):
        pieces.append(tok.text)
    return pieces

def tokenizer_out(text):
    """Split ``text`` with the output-language spaCy tokenizer and return the token strings."""
    pieces = []
    for tok in spacy_out_lang.tokenizer(text):
        pieces.append(tok.text)
    return pieces


# Source-side field: lower-cased tokens, batches come back as (tensor, lengths).
# No <sos>/<eos> markers are configured for the source.
in_lang = Field(
    tokenize=tokenizer_in,
    lower=True,
    include_lengths=True,
)

# Target-side field: same processing, plus explicit <sos>/<eos> markers.
out_lang = Field(
    tokenize=tokenizer_out,
    init_token="<sos>",
    eos_token="<eos>",
    lower=True,
    include_lengths=True,
)

In [5]:
# File extensions follow the configured direction, e.g. (".en", ".de").
_lang_exts = tuple("." + CFG[key] for key in ("IN_LANG", "OUT_LANG"))

train_data, valid_data, test_data = Multi30k.splits(
    exts=_lang_exts,
    fields=(in_lang, out_lang),
)

In [6]:
# Build both vocabularies from the training split only, with identical limits
# (at most 10000 entries, tokens must occur at least twice).
for _field in (in_lang, out_lang):
    _field.build_vocab(train_data, max_size=10000, min_freq=2)

In [7]:
class Encoder(nn.Module):
    """Bidirectional GRU encoder.

    Embeds the source token ids, runs them through a bidirectional GRU and
    projects the concatenated forward/backward final hidden state down to
    ``hidden_size`` so it can initialise the decoder.
    """

    def __init__(self, input_size, embedding_size, hidden_size, num_layers, p):
        """
        Args:
            input_size: size of the source vocabulary.
            embedding_size: dimension of the token embeddings.
            hidden_size: GRU hidden size (per direction).
            num_layers: number of stacked GRU layers.
            p: dropout probability applied to the embeddings.
        """
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(input_size, embedding_size)
        self.rnn = nn.GRU(embedding_size, hidden_size, num_layers, bidirectional=True)

        # Maps [forward ; backward] hidden state (2 * hidden_size) to hidden_size.
        self.fc_hidden = nn.Linear(hidden_size * 2, hidden_size)
        self.dropout = nn.Dropout(p)

    def forward(self, x, inp_length=None):
        """Encode a batch of source sequences.

        Args:
            x: token ids, sequence-first (the GRU is not ``batch_first``).
            inp_length: optional tensor of true sequence lengths. When given,
                the batch is packed so the GRU skips padding;
                ``pack_padded_sequence`` is called with its default
                ``enforce_sorted=True``, so lengths must be in decreasing order.

        Returns:
            ``(encoder_states, hidden)``: the per-position GRU outputs
            (last dim ``2 * hidden_size``) and the projected final hidden
            state (first dim 1, last dim ``hidden_size``).
        """
        embedding = self.dropout(self.embedding(x))

        if inp_length is None:
            encoder_states, hidden = self.rnn(embedding)
        else:
            packed = pack_padded_sequence(embedding, inp_length.cpu())  # To speed up training
            encoder_states, hidden = self.rnn(packed)
            # total_length keeps the output as long as the padded input, so the
            # states always line up with masks built from the source tensor.
            encoder_states, _ = pad_packed_sequence(encoder_states, total_length=x.shape[0])

        # hidden[0] / hidden[1] are the forward / backward states of the first
        # GRU layer; they are joined on the feature axis and projected.
        hidden = self.fc_hidden(torch.cat((hidden[0:1], hidden[1:2]), dim=2))

        return encoder_states, hidden


In [8]:
class Decoder(nn.Module):
    """Single-step GRU decoder with additive attention over the encoder states.

    Reads the module-level ``in_pad_idx`` at call time to mask padded source
    positions out of the attention distribution.
    """

    def __init__(self, input_size, embedding_size, hidden_size, output_size, num_layers, p):
        super().__init__()
        self.hidden_size, self.num_layers = hidden_size, num_layers

        # Sub-modules are created in a fixed order so parameter registration
        # (and therefore state_dict layout) stays stable.
        self.embedding = nn.Embedding(input_size, embedding_size)
        # The GRU consumes [context vector (2 * hidden) ; previous-token embedding].
        self.rnn = nn.GRU(hidden_size * 2 + embedding_size, hidden_size, num_layers)

        self.energy = nn.Linear(hidden_size, 1)
        self.fc = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(p)

        # Projections of the decoder hidden state and of the encoder states
        # into a common space before they are summed.
        self.fc_key = nn.Linear(hidden_size, hidden_size)
        self.fc_query = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, x, encoder_states, hidden, source):
        """Decode one step.

        Args:
            x: previous target token ids, one per batch element.
            encoder_states: encoder outputs (last dim ``2 * hidden_size``).
            hidden: current decoder hidden state.
            source: source token ids, used only to locate padding.

        Returns:
            ``(predictions, hidden)``: vocabulary scores for this step and the
            updated hidden state.
        """
        embedded = self.dropout(self.embedding(x.unsqueeze(0)))

        projected_hidden = self.fc_key(hidden)
        projected_states = self.fc_query(encoder_states)

        # One score per source position; the hidden projection broadcasts over
        # the source-length axis.
        scores = self.energy(torch.tanh(projected_hidden + projected_states)).squeeze(-1)
        pad_mask = source == in_pad_idx
        scores = scores.masked_fill(pad_mask, -float('inf')).unsqueeze(-1)

        # Normalise over source positions (dim 0), then take the weighted sum
        # of the encoder states.
        attention = F.softmax(scores, dim=0)
        context_vector = torch.einsum("snk,snl->knl", attention, encoder_states)

        rnn_input = torch.cat((context_vector, embedded), dim=2)
        outputs, hidden = self.rnn(rnn_input, hidden)

        return self.fc(outputs).squeeze(0), hidden

In [9]:
class Seq2Seq(nn.Module):
    """Encoder/decoder wrapper that decodes a whole target sequence with teacher forcing."""

    def __init__(self, encoder, decoder):
        """
        Args:
            encoder: module called as ``encoder(source, inp_length)`` returning
                ``(encoder_states, hidden)``.
            decoder: module called as ``decoder(x, encoder_states, hidden, source)``
                returning ``(scores, hidden)``; its ``fc`` layer defines the
                output vocabulary size.
        """
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, source, target, inp_length):
        """Run the encoder once and the decoder for every target position.

        Args:
            source: source token ids, sequence-first.
            target: target token ids, sequence-first; ``target[0]`` is fed as
                the first decoder input (the <sos> row).
            inp_length: source lengths forwarded to the encoder.

        Returns:
            Tensor of shape ``(target_len, batch_size, vocab_size)``. Row 0 is
            left as zeros because no prediction is made for the first position.
        """
        batch_size = source.shape[1]
        target_len = target.shape[0]
        # Take the vocabulary size and device from the model/inputs instead of
        # notebook globals, so the module works wherever it is moved.
        target_vocab_size = self.decoder.fc.out_features

        outputs = torch.zeros(target_len, batch_size, target_vocab_size, device=source.device)
        encoder_states, hidden = self.encoder(source, inp_length)

        x = target[0]  # <SOS>

        for t in range(1, target_len):
            output, hidden = self.decoder(x, encoder_states, hidden, source)
            outputs[t] = output

            # Always teacher forcing: the next input is the ground-truth token,
            # never the model's own prediction.
            x = target[t]

        return outputs

In [10]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Training hyperparameters
num_epochs = 100
lr = 3e-4
batch_size = 64

# Model width; embeddings use d_model, the recurrent layers use 4 * d_model.
d_model = 128
encoder_embedding_size = d_model
decoder_embedding_size = d_model
hidden_size = 4 * d_model

# Vocabulary-derived sizes: the decoder both reads and predicts target tokens.
input_size_encoder = len(in_lang.vocab)
input_size_decoder = len(out_lang.vocab)
output_size = len(out_lang.vocab)

num_layers = 1
dropout = 0.1

In [11]:
# Bucket examples by source length; within each batch the examples are sorted
# by the same key (sort_within_batch=True).
_splits = (train_data, valid_data, test_data)
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    _splits,
    batch_size=batch_size,
    device=device,
    sort_key=lambda example: len(example.src),
    sort_within_batch=True,
)

In [12]:
# Build the two halves on the target device, then wrap them.
encoder_net = Encoder(
    input_size_encoder,
    encoder_embedding_size,
    hidden_size,
    num_layers,
    dropout,
)
encoder_net = encoder_net.to(device)

decoder_net = Decoder(
    input_size_decoder,
    decoder_embedding_size,
    hidden_size,
    output_size,
    num_layers,
    dropout,
)
decoder_net = decoder_net.to(device)

model = Seq2Seq(encoder_net, decoder_net).to(device)

In [13]:
# Padding ids of both vocabularies: the source one is read by the decoder's
# attention mask, the target one is skipped by the loss.
in_pad_idx, out_pad_idx = (
    in_lang.vocab.stoi['<pad>'],
    out_lang.vocab.stoi['<pad>'],
)

criterion = nn.CrossEntropyLoss(ignore_index=out_pad_idx)
optimizer = optim.Adam(model.parameters(), lr=lr)

In [14]:
def translate_sentence_bahdanau(model, sentence, max_length=50): #Translate from raw text using the trained model
    """Greedy-decode a translation of ``sentence``.

    Args:
        model: object exposing ``eval()``, ``encoder`` and ``decoder`` as used
            by the ``Seq2Seq`` wrapper.
        sentence: raw source text.
        max_length: maximum number of decoder steps.

    Returns:
        List of target-vocabulary token strings, without <sos>/<eos>. An empty
        input yields an empty list.
    """
    model.eval()

    # Tokenizer only (no full spaCy pipeline) and lower-casing after
    # tokenization, matching how the training data was processed.
    tokens = [token.text.lower() for token in spacy_in_lang.tokenizer(sentence)]
    if not tokens:
        return []

    text_to_indices = [in_lang.vocab.stoi[token] for token in tokens]
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    eos_idx = out_lang.vocab.stoi["<eos>"]
    preds = [out_lang.vocab.stoi[out_lang.init_token]]

    with torch.no_grad():

        encoder_states, hidden = model.encoder(sentence_tensor)

        for _ in range(max_length):

            trg = torch.tensor([preds[-1]], dtype=torch.long, device=device)

            # Decoder.forward takes exactly (x, encoder_states, hidden, source).
            output, hidden = model.decoder(trg, encoder_states, hidden, sentence_tensor)
            new = output.argmax(1).item()

            # Stop before storing <eos>, so nothing has to be sliced off later
            # and a translation cut at max_length keeps its last real token.
            if new == eos_idx:
                break

            preds.append(new)

    return [out_lang.vocab.itos[i] for i in preds[1:]]

In [15]:
def beam(phrase, k):  #K: beam width
    """Beam-search translation of ``phrase`` using the global ``model``.

    Args:
        phrase: raw source text.
        k: beam width (number of hypotheses kept per step), at least 1.

    Returns:
        Up to ``k`` tuples ``(token_ids, score, hidden)`` sorted best first.
        ``token_ids`` starts with <sos> and ends with <eos> when the hypothesis
        finished; ``score`` is the summed log-probability divided by the number
        of generated tokens; ``hidden`` is the decoder state after the last
        step. An empty phrase yields an empty list.

    Raises:
        ValueError: if ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError("beam width k must be at least 1")

    model.eval()

    sos = out_lang.vocab.stoi["<sos>"]
    eos = out_lang.vocab.stoi["<eos>"]
    max_steps = 51  # same budget as before: one initial step plus 50 more

    # Prepare the sentence the same way as for training and greedy decoding:
    # lower-cased tokenizer output, with no extra boundary tokens (the source
    # field defines no init/eos token).
    tokens = [token.text.lower() for token in spacy_in_lang.tokenizer(phrase)]
    if not tokens:
        return []

    text_to_indices = [in_lang.vocab.stoi[token] for token in tokens]
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(1).to(device)

    def _normalised(seq, log_prob):
        # Mean log-probability per generated token (<sos> is not counted).
        return log_prob / (len(seq) - 1)

    with torch.no_grad():

        # Get encoder output
        encoder_states, hidden = model.encoder(sentence_tensor)

        # Each hypothesis is (token ids, summed log-prob, decoder hidden state).
        # The raw sum is stored and normalised only for ranking, so the length
        # penalty is never applied to an already normalised score.
        possible = [([sos], 0.0, hidden)]

        for _ in range(max_steps):

            if all(seq[-1] == eos for seq, _, _ in possible):
                break

            candidates = []
            for seq, log_prob, state in possible:

                if seq[-1] == eos:  # If sentence already ended
                    candidates.append((seq, log_prob, state))
                    continue

                # Compute output
                trg = torch.tensor([seq[-1]], dtype=torch.long, device=device)
                output, new_state = model.decoder(trg, encoder_states, state, sentence_tensor)
                log_probs = F.log_softmax(output, dim=1).squeeze(0)

                # Scores and indices come from the same topk call, so they
                # cannot get out of sync.
                top_scores, top_ids = log_probs.topk(min(k, log_probs.numel()))
                for step_score, token_id in zip(top_scores.tolist(), top_ids.tolist()):
                    candidates.append((seq + [token_id], log_prob + step_score, new_state))

            possible = sorted(
                candidates, key=lambda c: _normalised(c[0], c[1]), reverse=True
            )[:k]

    return [(seq, _normalised(seq, log_prob), state) for seq, log_prob, state in possible]



def convert(x):
    """Turn a beam hypothesis into ``(text, score)``.

    ``x[0]`` is a sequence of target-vocabulary ids and ``x[1]`` its score; the
    ids are mapped to token strings and joined with single spaces.
    """
    token_ids, score = x[0], x[1]

    words = []
    for idx in token_ids:
        words.append(out_lang.vocab.itos[idx])

    return (" ".join(words), score)

In [16]:
def run_epoch():
    """Train the global ``model`` for one pass over ``train_iterator``.

    Returns:
        The mean per-batch training loss.
    """
    model.train()
    running_loss = 0

    for step, batch in enumerate(train_iterator):

        src, src_len = batch.src
        trg, _ = batch.trg
        src = src.to(device)
        trg = trg.to(device)

        logits = model(src, trg, src_len)

        # Drop position 0 (the <sos> slot) and flatten time and batch so the
        # criterion sees one row per target token.
        vocab_dim = logits.shape[2]
        flat_logits = logits[1:].reshape(-1, vocab_dim)
        flat_target = trg[1:].reshape(-1)

        optimizer.zero_grad()
        batch_loss = criterion(flat_logits, flat_target)
        batch_loss.backward()

        # Clip the global gradient norm to 1 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()

        running_loss += batch_loss.item()

        # Overwrite the same console line with the current batch index.
        sys.stdout.write("\r %d" % (step))
        sys.stdout.flush()

    return running_loss / len(train_iterator)

In [17]:
# NOTE(review): this cell redefines run_epoch with the same behaviour as the
# definition in the previous cell; the later definition is the one in effect.
def run_epoch():
    """Run one optimisation pass over ``train_iterator`` and return the average batch loss."""
    model.train()
    loss_sum = 0

    for idx, batch in enumerate(train_iterator):

        # Both fields were built with include_lengths=True, so each attribute
        # is a (tensor, lengths) pair.
        inputs, input_lengths = batch.src
        inputs = inputs.to(device)
        targets = batch.trg[0].to(device)

        predictions = model(inputs, targets, input_lengths)

        # Skip the first time step and collapse (time, batch) into one axis.
        predictions = predictions[1:].reshape(-1, predictions.shape[2])
        targets = targets[1:].reshape(-1)

        optimizer.zero_grad()
        step_loss = criterion(predictions, targets)
        step_loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()

        loss_sum += step_loss.item()

        # Single-line progress counter.
        sys.stdout.write("\r %d" % (idx))
        sys.stdout.flush()

    n_batches = len(train_iterator)
    return loss_sum / n_batches

In [18]:
def run_validation():
    """Evaluate the global ``model`` on ``valid_iterator``.

    The model is switched to eval mode and the whole pass runs without
    gradient tracking, so no autograd graph is kept for the validation batches.

    Returns:
        The mean per-batch validation loss.
    """
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in valid_iterator:

            inp_data, inp_length = batch.src
            inp_data = inp_data.to(device)

            target, _ = batch.trg
            target = target.to(device)

            output = model(inp_data, target, inp_length)

            # Drop position 0 (the <sos> slot) and flatten time and batch.
            output = output[1:].reshape(-1, output.shape[2])
            target = target[1:].reshape(-1)

            loss = criterion(output, target)
            total_loss += loss.item()

    return total_loss / len(valid_iterator)

In [19]:
# Example source sentence translated after every epoch to eyeball progress.
sentence = 'a man in green holds a guitar while the other man observes his shirt .'

# Best validation loss seen so far; start at +inf so the first epoch always
# counts as an improvement, instead of relying on an arbitrary large number.
best_loss = float("inf")

In [None]:
# This loop trains for 20 epochs. The progress header reports that number so it
# matches what actually runs (it previously showed the unrelated num_epochs).
epochs_this_run = 20

# Make sure the checkpoint directory exists before training starts, so the
# first save cannot fail after a full epoch of work.
checkpoint_path = "../models/rnn_model"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(epochs_this_run):

    print(f'Epoch [{epoch + 1} / {epochs_this_run}]')

    loss = run_epoch()
    validation_loss = run_validation()

    # Translate the example sentence with both decoders to monitor quality.
    translated_sentence = translate_sentence_bahdanau(model, sentence, max_length=50)
    out = beam(sentence, 3)

    print(f"Translated example sentence: \n {list(map(convert, out[:2]))}")
    print(f"Greedy: {translated_sentence}")

    print(f"\n Train loss {loss} | Validation loss {validation_loss} \n \n")

    # Keep only the weights with the lowest validation loss so far.
    if validation_loss < best_loss:
        torch.save(model.state_dict(), checkpoint_path)
        best_loss = validation_loss