# *Sequence to Sequence Learning with Neural Networks* in PyTorch

Encoder-decoder sequence to sequence RNN implementation based on 2014 publication:

**Sequence to Sequence Learning with Neural Networks** - Ilya Sutskever, Oriol Vinyals, Quoc V. Le

https://arxiv.org/abs/1409.3215

## Dataset

In the project files(PyTorch dataset implementation fra_eng_dataset.py) I use small toy dataset(170K sentences) in which can be found in the project files in fra-eng folder. But for this experiment I will try to use **WMT'14 English-German** dataset (4.5M sentences)

The dataset can be found here:
https://nlp.stanford.edu/projects/nmt/

### PyTorch Dataset

In [1]:
from torch.utils.data import Dataset, DataLoader
import pickle
from nltk.tokenize import word_tokenize
import os
import torch
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import os


RNN_LAYERS = 4
RNN_HIDDEN_SIZE = 1024
IN_EMBEDDING_SIZE = 128
OUT_EMBEDDING_SIZE = 128
BATCH_SIZE = 128
EPOCHS = 50
MAXMAX_SENTENCE_LEN = 20


device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [2]:
def get_top_dictionary(text_corpus_path, top_n = 50000):

    print(f"Creating top dictionary from: {text_corpus_path}..")
    
    last_token_idx = 0
    token_dict = dict()
    token_counts_list = []
    token_idx_to_token = []
    
    with open(text_corpus_path, "r", encoding='utf-8') as f:
        
        for idx, line in enumerate(f.readlines()):
            
            if (idx+1)% 500000 == 0:
                print(f"Processed {idx+1} lines")

            line = line.replace('##AT##', '')
            token_list = word_tokenize(line)
            for token in token_list:
                token = token.lower()
                if token not in token_dict:
                    token_dict[token] = last_token_idx
                    token_idx_to_token.append(token)
                    token_counts_list.append((0,last_token_idx))
                    last_token_idx += 1

                token_idx = token_dict[token]
                count, _ = token_counts_list[token_idx]
                token_counts_list[token_idx] = (count+1,token_idx)
                
    token_counts_list = sorted(token_counts_list, reverse=True)
    
    top_token_list = []
    
    for idx, (count, token_idx) in enumerate(token_counts_list):
        top_token_list.append(token_idx_to_token[token_idx])
        
        if idx > top_n:
            break
    
    return top_token_list
        


class WMT14_en_de_Dataset(Dataset):
    def __init__(self, data_source_path = 'wmt14_en_de'):
        super().__init__()
        
        processed_data_path = "processed_data.pkl"
        top_en_tokens_path = "top_en_tokens.pkl"
        top_de_tokens_path = "top_de_tokens.pkl"
        
        self.sentence_list = []
        
        self.en_token_dict = dict()
        self.en_token_dict['<PAD>'] = 0
        self.en_token_dict['<EOS>'] = 1
        self.en_token_dict['<UNK>'] = 2
        self.en_last_token_idx = 2
        self.en_token_idx_to_text = ['<PAD>', '<EOS>', '<UNK>']
        
        self.de_token_dict = dict()
        self.de_token_dict['<PAD>'] = 0
        self.de_token_dict['<EOS>'] = 1
        self.de_token_dict['<UNK>'] = 2
        self.de_last_token_idx = 2
        self.de_token_idx_to_text = ['<PAD>', '<EOS>', '<UNK>']
        
        
        if os.path.exists(processed_data_path):
            with open(processed_data_path, 'rb') as f:
                pickle_data = pickle.load(f)
                self.sentence_list = pickle_data['sentence_list']
                self.en_last_token_idx = pickle_data['en_last_token_idx']
                self.de_last_token_idx = pickle_data['de_last_token_idx']
                self.en_token_idx_to_text = pickle_data['en_token_idx_to_text']
                self.de_token_idx_to_text = pickle_data['de_token_idx_to_text']
        else:
        
            en_sentences_path = os.path.join(data_source_path, "train.en")
            de_sentences_path = os.path.join(data_source_path, "train.de")
            
            if os.path.exists(top_en_tokens_path):
                with open(top_en_tokens_path, "rb") as f:
                    top_en_tokens = pickle.load(f)
            else:
                top_en_tokens = get_top_dictionary(en_sentences_path)
                with open(top_en_tokens_path, "wb") as f:
                    pickle.dump(top_en_tokens, f)
            
            
            for token in top_en_tokens:
                self.en_last_token_idx += 1
                self.en_token_dict[token] = self.en_last_token_idx
                self.en_token_idx_to_text.append(token)
                
 
            if os.path.exists(top_de_tokens_path):
                with open(top_de_tokens_path, "rb") as f:
                    top_de_tokens = pickle.load(f)
            else:
                top_de_tokens = get_top_dictionary(de_sentences_path)
                with open(top_de_tokens_path, "wb") as f:
                    pickle.dump(top_de_tokens, f)
            
            for token in top_de_tokens:
                self.de_last_token_idx += 1
                self.de_token_dict[token] = self.de_last_token_idx
                self.de_token_idx_to_text.append(token)         
                    
            
            with open(de_sentences_path, "r", encoding='utf-8') as de_f:
                with open(en_sentences_path, "r", encoding='utf-8') as en_f:
                    
                    print("Creating sentences from {de_sentences_path} and {en_sentences_path} coropuses")
                    
                    for idx, (de_sentence, en_sentence) in enumerate(zip(de_f.readlines(), en_f.readlines())):
                        
                        if (idx+1)%500000 == 0:
                            print(f"Processed {idx+1} lines")
                            
                        de_sentence = de_sentence.replace('##AT##', '')
                        en_sentence = en_sentence.replace('##AT##', '')
                        
                        en_token_sentence = []
                        de_token_sentence = []

                        en_token_list = word_tokenize(en_sentence)
                        for token in en_token_list:
                            token = token.lower()
                            if token in self.en_token_dict:
                                token_idx = self.en_token_dict[token]
                            else:
                                token_idx = self.en_token_dict['<UNK>']
                                
                            en_token_sentence.append(token_idx)

                        en_token_sentence.append(self.en_token_dict['<EOS>'])

                        de_token_list = word_tokenize(de_sentence)
                        for token in de_token_list:
                            token = token.lower()
                            if token in self.de_token_dict:
                                token_idx = self.de_token_dict[token]
                            else:
                                token_idx = self.de_token_dict['<UNK>']
                                    

                            de_token_sentence.append(token_idx)

                        de_token_sentence.append(self.de_token_dict['<EOS>'])

                        self.sentence_list.append(
                            dict(
                                en = en_token_sentence,
                                de = de_token_sentence
                            ))
                        
            with open(processed_data_path, "wb") as f:
                pickle_processed_data = dict(
                    sentence_list = self.sentence_list,
                    en_last_token_idx = self.en_last_token_idx,
                    de_last_token_idx = self.de_last_token_idx,
                    en_token_idx_to_text = self.en_token_idx_to_text,
                    de_token_idx_to_text = self.de_token_idx_to_text
                )
                pickle.dump(pickle_processed_data, f)
            
    def get_en_dict_size(self):
        return self.en_last_token_idx + 1
        
    def get_de_dict_size(self):
        return self.de_last_token_idx + 1
    
    def get_de_eos_code(self):
        return self.de_token_dict['<EOS>']
    
    def get_en_eos_code(self):
        return self.en_token_dict['<EOS>']

    def __len__(self):
        return len(self.sentence_list)

    def __getitem__(self, item):
        ret = dict()
        for key in self.sentence_list[item]:
            ret[key] = torch.tensor(self.sentence_list[item][key])
        return ret


def en_de_dataset_collate(data):

    en_sentences = []
    en_sentence_lens = []
    de_sentences = []
    de_sentence_lens = []
    
    en_sentences_sorted = []
    en_sentence_lens_sorted = []
    de_sentences_sorted = []
    de_sentence_lens_sorted = []
    
    for s in data:
        
        sent = s['en'][0:MAXMAX_SENTENCE_LEN]
        en_sentences.append(sent.unsqueeze(dim=1))
        en_sentence_lens.append(len(sent))
        
        sent = s['de'][0:MAXMAX_SENTENCE_LEN]
        de_sentences.append(sent.unsqueeze(dim=1))
        de_sentence_lens.append(len(sent))

    #Rearrange everything by de sentence lens
    sort_idxes = np.argsort(np.array(de_sentence_lens))[::-1]
    for idx in sort_idxes:
        en_sentences_sorted.append(en_sentences[idx])
        en_sentence_lens_sorted.append(en_sentence_lens[idx])
        de_sentences_sorted.append(de_sentences[idx])
        de_sentence_lens_sorted.append(de_sentence_lens[idx])
    
    return dict(
        en_sentences = en_sentences_sorted,
        en_lens = en_sentence_lens_sorted,
        de_sentences = de_sentences_sorted,
        de_lens = de_sentence_lens_sorted
    )

### Models

In [3]:
class RNN_encoder_model(nn.Module):
    def __init__(self, in_dict_size):
        super().__init__()
        
        self.in_dict_size = in_dict_size

        self.embedding = nn.Linear(
            in_dict_size, 
            IN_EMBEDDING_SIZE)
        
        
        self.hidden = None 
        self.cell = None
        
        self.rnn = nn.LSTM(
            input_size = IN_EMBEDDING_SIZE,
            hidden_size = RNN_HIDDEN_SIZE,
            num_layers = RNN_LAYERS
        )
        
    def init_hidden_and_cell(self):
        self.hidden = torch.randn(RNN_LAYERS, BATCH_SIZE, RNN_HIDDEN_SIZE).to(device)
        self.cell = torch.rand(RNN_LAYERS, BATCH_SIZE, RNN_HIDDEN_SIZE).to(device)
    
    def get_hidden_and_cell(self):
        return self.hidden, self.cell
    
    def forward(self, x):
        padded_sent_one_hot, sent_lens = x
        padded_sent_emb = self.embedding.forward(padded_sent_one_hot)
        packed = pack_padded_sequence(padded_sent_emb, sent_lens)
        packed, (self.hidden, self.cell) = self.rnn.forward(packed, (self.hidden,self.cell))
        padded, sent_lens = pad_packed_sequence(packed)

In [4]:
class RNN_decoder_model(nn.Module):
    def __init__(self, out_dict_size):
        super().__init__()
      
        self.in_embedding = nn.Linear(
            in_features=out_dict_size,
            out_features=IN_EMBEDDING_SIZE
        )

        self.rnn = nn.LSTM(
            input_size = IN_EMBEDDING_SIZE,
            hidden_size = RNN_HIDDEN_SIZE,
            num_layers = RNN_LAYERS
        )

        self.rnn_to_embedding = nn.Linear(
            in_features = RNN_HIDDEN_SIZE,
            out_features = OUT_EMBEDDING_SIZE
        )

        self.embedding_to_logit = nn.Linear(
            in_features = OUT_EMBEDDING_SIZE, 
            out_features = out_dict_size
        )

        self.softmax = nn.Softmax(dim=2)
    
    def init_hidden_and_cell(self, hidden, cell):
        self.hidden = hidden
        self.cell = cell
    
    
    def forward(self, out_eos_code, out_dict_size, max_sentence_len):
        batch_size = self.hidden.shape[1]
        prev_outp = (torch.ones(1, batch_size, 1) * out_eos_code).long()
        prev_outp = prev_outp.to(device)
        
        all_outp_prob = []
        
        for timestep in range(max_sentence_len):
            
            prev_outp_one_hot = torch.zeros(prev_outp.shape[0], prev_outp.shape[1], out_dict_size).to(device)
            prev_outp_one_hot = prev_outp_one_hot.scatter_(2,prev_outp.data,1)
            
            prev_outp_in_emb = self.in_embedding(prev_outp_one_hot)
         
            cur_outp_hid, (self.hidden, self.cell) = self.rnn.forward(prev_outp_in_emb, (self.hidden, self.cell))
            cur_outp_emb = self.rnn_to_embedding.forward(cur_outp_hid)
            cur_outp_logits = self.embedding_to_logit(cur_outp_emb)
            cur_outp_prob = self.softmax(cur_outp_logits)
            all_outp_prob.append(cur_outp_prob)
            
            prev_outp = torch.argmax(cur_outp_prob.data.to(device), dim=2, keepdim=True)
        
        all_outp_prob_tensor = torch.cat(all_outp_prob, dim=0)
    
        return all_outp_prob_tensor
   

In [5]:
dataset = WMT14_en_de_Dataset()
sentences_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, collate_fn=en_de_dataset_collate)


In [6]:
def print_results(in_sentence_list, out_sentence_list, pred_tensor, num_batches):
    
    sentence_prediction_samples_path = 'sentence_predictions.text'
    
    print(f"Printing sample predictions after {num_batches} batches in {sentence_prediction_samples_path}")
    
    with open(sentence_prediction_samples_path, "a") as f:

        for i in range(3):
            f.write('='*50 + '\n')
            
        f.write(f"Sample predictions after {num_batches} batches \n\n")
              
        in_token_to_text = dataset.de_token_idx_to_text
        out_token_to_text = dataset.en_token_idx_to_text

        for s in range(min(len(in_sentence_list),50)):

            in_sent_text = []
            for in_token in in_sentence_list[s].squeeze():
                in_sent_text.append(in_token_to_text[in_token])

            f.write(f"\nGerman sentence is: {' '.join(in_sent_text)} \n")

            out_sent_text = []

            for out_token in out_sentence_list[s].squeeze():
                  out_sent_text.append(out_token_to_text[out_token])
            f.write(f"English sentence is: {' '.join(out_sent_text)}\n")

            pred_sent_text = []
            for ts in range(pred_tensor.shape[0]):
                pred_token = torch.argmax(pred_tensor[ts, s,:]).data
                pred_sent_text.append(out_token_to_text[pred_token])

                if pred_token == dataset.get_en_eos_code():
                    break
            f.write(f"Translated English sentence is: {' '.join(pred_sent_text)}\n")


## Training

In [7]:
rnn_encoder = RNN_encoder_model(dataset.get_de_dict_size()).to(device)
rnn_decoder = RNN_decoder_model(dataset.get_en_dict_size()).to(device)

trained_encoder_path = None
trained_decoder_path = None

trained_encoder_path = 'models/encoder_wmt14_de_en.pt'
trained_decoder_path = 'models/decoder_wmt14_de_en.pt'

if os.path.exists(trained_encoder_path):
    rnn_encoder.load_state_dict(torch.load(trained_encoder_path))
if os.path.exists(trained_decoder_path):
    rnn_decoder.load_state_dict(torch.load(trained_decoder_path))


params = list(rnn_encoder.parameters()) + list(rnn_decoder.parameters())
optimizer = torch.optim.Adam(params, lr = 1e-3)

In [8]:
steps = 0
num_batches = 0
num_loss_prints = 0

for epoch in range(EPOCHS):
    
    print(f"Starting epoch {epoch} =====================")
    
    best_loss = 1e10
    loss_sum = 0
    
    for idx, sentences in enumerate(sentences_loader):

        rnn_encoder.init_hidden_and_cell()
       
   
        in_sentences = sentences['de_sentences']
        in_lens = sentences['de_lens']
        out_sentences = sentences['en_sentences']
        out_lens = sentences['en_lens']
        

        padded_in = pad_sequence(in_sentences, padding_value=0).to(device)
        padded_out = pad_sequence(out_sentences, padding_value=0).to(device)

        padded_in_one_hot = torch.zeros(padded_in.shape[0], padded_in.shape[1], dataset.get_de_dict_size()).to(device)
        padded_in_one_hot = padded_in_one_hot.scatter_(2,padded_in.data,1)
       
        rnn_encoder.forward((padded_in_one_hot, in_lens))
        hidden, cell = rnn_encoder.get_hidden_and_cell()
       
        rnn_decoder.init_hidden_and_cell(hidden,cell)
       
        max_sentence_len = padded_out.shape[0]
            
        y_pred = rnn_decoder.forward(dataset.get_en_eos_code(), dataset.get_en_dict_size(), max_sentence_len)
       
        padded_out = padded_out[0:max_sentence_len]
        padded_out_one_hot = torch.zeros(padded_out.shape[0], padded_out.shape[1], dataset.get_en_dict_size()).to(device)
        padded_out_one_hot = padded_out_one_hot.scatter_(2,padded_out.data,1)
       
        #Make all padded one-hot vectors to all zeros, which will make
        #padded components loss 0 and so they wont affect the loss
        padded_out_one_hot[:,:,0] = torch.zeros(max_sentence_len, padded_out_one_hot.shape[1])
        loss = torch.sum(-torch.log(y_pred + 1e-9) * padded_out_one_hot)
       
        loss_sum += loss.to('cpu').detach().data
       
        #print(loss.to('cpu').detach().data)
       
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        #Every 100 batches print the average loss and store the model weights
        steps += BATCH_SIZE
        num_batches += 1
        loss_print_step = 200
        if num_batches % loss_print_step == 0:
            
            print(f"{num_batches} Average loss in the last {loss_print_step} batches is {loss_sum/float(loss_print_step)}")
            steps = 0
            
            
            num_loss_prints += 1 
            
            if num_loss_prints % 10 == 0:
                print_results(in_sentences, out_sentences, y_pred.to('cpu').detach().data, num_batches)
            
            if best_loss > loss_sum:
                best_loss = loss_sum

                models_path = "models"
                if not os.path.exists(models_path):
                    os.mkdir(models_path)

                torch.save(rnn_encoder.state_dict(), trained_encoder_path)
                torch.save(rnn_decoder.state_dict(), trained_decoder_path)
            
            loss_sum = 0
            steps = 0

    

200 Average loss in the last 200 batches is 15070.2412109375
400 Average loss in the last 200 batches is 14894.48828125
600 Average loss in the last 200 batches is 14763.2783203125
800 Average loss in the last 200 batches is 14697.6083984375
1000 Average loss in the last 200 batches is 14628.3720703125
1200 Average loss in the last 200 batches is 14548.44921875
1400 Average loss in the last 200 batches is 14447.9462890625
1600 Average loss in the last 200 batches is 14417.166015625
1800 Average loss in the last 200 batches is 14295.33203125
2000 Average loss in the last 200 batches is 14311.87109375
Printing sample predictions after 2000 batches in sentence_predictions.text
2200 Average loss in the last 200 batches is 14228.9013671875
2400 Average loss in the last 200 batches is 14185.505859375
2600 Average loss in the last 200 batches is 14131.49609375
2800 Average loss in the last 200 batches is 14131.35546875
3000 Average loss in the last 200 batches is 14102.1298828125
3200 Average

24000 Average loss in the last 200 batches is 11962.40234375
Printing sample predictions after 24000 batches in sentence_predictions.text
24200 Average loss in the last 200 batches is 11963.759765625
24400 Average loss in the last 200 batches is 11932.4873046875
24600 Average loss in the last 200 batches is 11936.9072265625
24800 Average loss in the last 200 batches is 11911.490234375
25000 Average loss in the last 200 batches is 11897.30078125
25200 Average loss in the last 200 batches is 11898.2958984375
25400 Average loss in the last 200 batches is 11848.021484375
25600 Average loss in the last 200 batches is 11850.505859375
25800 Average loss in the last 200 batches is 11899.7265625
26000 Average loss in the last 200 batches is 11879.7822265625
Printing sample predictions after 26000 batches in sentence_predictions.text
26200 Average loss in the last 200 batches is 11845.486328125
26400 Average loss in the last 200 batches is 11820.9140625
26600 Average loss in the last 200 batches

47400 Average loss in the last 200 batches is 11190.5966796875
47600 Average loss in the last 200 batches is 11204.7666015625
47800 Average loss in the last 200 batches is 11151.2060546875
48000 Average loss in the last 200 batches is 11186.5341796875
Printing sample predictions after 48000 batches in sentence_predictions.text
48200 Average loss in the last 200 batches is 11207.580078125
48400 Average loss in the last 200 batches is 11133.84375
48600 Average loss in the last 200 batches is 11160.302734375
48800 Average loss in the last 200 batches is 11178.63671875
49000 Average loss in the last 200 batches is 11151.1533203125
49200 Average loss in the last 200 batches is 11166.1298828125
49400 Average loss in the last 200 batches is 11207.677734375
49600 Average loss in the last 200 batches is 11086.244140625
49800 Average loss in the last 200 batches is 11183.81640625
50000 Average loss in the last 200 batches is 11185.8515625
Printing sample predictions after 50000 batches in senten

KeyboardInterrupt: 

In [9]:
#Continue to train the same model with smaller learning rate of 1e-4 (

optimizer = torch.optim.Adam(params, lr = 1e-4)

steps = 0
num_batches = 0
num_loss_prints = 0

for epoch in range(EPOCHS):
    
    print(f"Starting epoch {epoch} =====================")
    
    best_loss = 1e10
    loss_sum = 0
    
    for idx, sentences in enumerate(sentences_loader):

        rnn_encoder.init_hidden_and_cell()
       
   
        in_sentences = sentences['de_sentences']
        in_lens = sentences['de_lens']
        out_sentences = sentences['en_sentences']
        out_lens = sentences['en_lens']
        

        padded_in = pad_sequence(in_sentences, padding_value=0).to(device)
        padded_out = pad_sequence(out_sentences, padding_value=0).to(device)

        padded_in_one_hot = torch.zeros(padded_in.shape[0], padded_in.shape[1], dataset.get_de_dict_size()).to(device)
        padded_in_one_hot = padded_in_one_hot.scatter_(2,padded_in.data,1)
       
        rnn_encoder.forward((padded_in_one_hot, in_lens))
        hidden, cell = rnn_encoder.get_hidden_and_cell()
       
        rnn_decoder.init_hidden_and_cell(hidden,cell)
       
        max_sentence_len = padded_out.shape[0]
            
        y_pred = rnn_decoder.forward(dataset.get_en_eos_code(), dataset.get_en_dict_size(), max_sentence_len)
       
        padded_out = padded_out[0:max_sentence_len]
        padded_out_one_hot = torch.zeros(padded_out.shape[0], padded_out.shape[1], dataset.get_en_dict_size()).to(device)
        padded_out_one_hot = padded_out_one_hot.scatter_(2,padded_out.data,1)
       
        #Make all padded one-hot vectors to all zeros, which will make
        #padded components loss 0 and so they wont affect the loss
        padded_out_one_hot[:,:,0] = torch.zeros(max_sentence_len, padded_out_one_hot.shape[1])
        loss = torch.sum(-torch.log(y_pred + 1e-9) * padded_out_one_hot)
       
        loss_sum += loss.to('cpu').detach().data
       
        #print(loss.to('cpu').detach().data)
       
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        #Every 200 batches print the average loss and store the model weights
        steps += BATCH_SIZE
        num_batches += 1
        loss_print_step = 200
        if num_batches % loss_print_step == 0:
            
            print(f"{num_batches} Average loss in the last {loss_print_step} batches is {loss_sum/float(loss_print_step)}")
            steps = 0
            
            
            num_loss_prints += 1 
            
            if num_loss_prints % 10 == 0:
                print_results(in_sentences, out_sentences, y_pred.to('cpu').detach().data, num_batches)
            
            if best_loss > loss_sum:
                best_loss = loss_sum

                models_path = "models"
                if not os.path.exists(models_path):
                    os.mkdir(models_path)

                torch.save(rnn_encoder.state_dict(), trained_encoder_path)
                torch.save(rnn_decoder.state_dict(), trained_decoder_path)
            
            loss_sum = 0
            steps = 0

200 Average loss in the last 200 batches is 10807.32421875
400 Average loss in the last 200 batches is 10749.759765625
600 Average loss in the last 200 batches is 10706.5400390625
800 Average loss in the last 200 batches is 10728.8974609375
1000 Average loss in the last 200 batches is 10725.8134765625
1200 Average loss in the last 200 batches is 10670.068359375
1400 Average loss in the last 200 batches is 10720.861328125
1600 Average loss in the last 200 batches is 10716.0791015625
1800 Average loss in the last 200 batches is 10679.7021484375
2000 Average loss in the last 200 batches is 10715.9677734375
Printing sample predictions after 2000 batches in sentence_predictions.text
2200 Average loss in the last 200 batches is 10706.7177734375
2400 Average loss in the last 200 batches is 10709.73828125
2600 Average loss in the last 200 batches is 10705.5029296875
2800 Average loss in the last 200 batches is 10698.0166015625
3000 Average loss in the last 200 batches is 10647.263671875
3200 A

KeyboardInterrupt: 

In [10]:
#Continue to train the same model with even smaller learning rate of 1e-5 (

optimizer = torch.optim.Adam(params, lr = 1e-5)

steps = 0
num_loss_prints = 0

for epoch in range(EPOCHS):
    
    print(f"Starting epoch {epoch} =====================")
    
    best_loss = 1e10
    loss_sum = 0
    
    for idx, sentences in enumerate(sentences_loader):

        rnn_encoder.init_hidden_and_cell()
       
   
        in_sentences = sentences['de_sentences']
        in_lens = sentences['de_lens']
        out_sentences = sentences['en_sentences']
        out_lens = sentences['en_lens']
        

        padded_in = pad_sequence(in_sentences, padding_value=0).to(device)
        padded_out = pad_sequence(out_sentences, padding_value=0).to(device)

        padded_in_one_hot = torch.zeros(padded_in.shape[0], padded_in.shape[1], dataset.get_de_dict_size()).to(device)
        padded_in_one_hot = padded_in_one_hot.scatter_(2,padded_in.data,1)
       
        rnn_encoder.forward((padded_in_one_hot, in_lens))
        hidden, cell = rnn_encoder.get_hidden_and_cell()
       
        rnn_decoder.init_hidden_and_cell(hidden,cell)
       
        max_sentence_len = padded_out.shape[0]
            
        y_pred = rnn_decoder.forward(dataset.get_en_eos_code(), dataset.get_en_dict_size(), max_sentence_len)
       
        padded_out = padded_out[0:max_sentence_len]
        padded_out_one_hot = torch.zeros(padded_out.shape[0], padded_out.shape[1], dataset.get_en_dict_size()).to(device)
        padded_out_one_hot = padded_out_one_hot.scatter_(2,padded_out.data,1)
       
        #Make all padded one-hot vectors to all zeros, which will make
        #padded components loss 0 and so they wont affect the loss
        padded_out_one_hot[:,:,0] = torch.zeros(max_sentence_len, padded_out_one_hot.shape[1])
        loss = torch.sum(-torch.log(y_pred + 1e-9) * padded_out_one_hot)
       
        loss_sum += loss.to('cpu').detach().data
       
        #print(loss.to('cpu').detach().data)
       
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        #Every 200 batches print the average loss and store the model weights
        steps += BATCH_SIZE
        num_batches += 1
        loss_print_step = 200
        if num_batches % loss_print_step == 0:
            
            print(f"{num_batches} Average loss in the last {loss_print_step} batches is {loss_sum/float(loss_print_step)}")
            steps = 0
            
            
            num_loss_prints += 1 
            
            if num_loss_prints % 10 == 0:
                print_results(in_sentences, out_sentences, y_pred.to('cpu').detach().data, num_batches)
            
            if best_loss > loss_sum:
                best_loss = loss_sum

                models_path = "models"
                if not os.path.exists(models_path):
                    os.mkdir(models_path)

                torch.save(rnn_encoder.state_dict(), trained_encoder_path)
                torch.save(rnn_decoder.state_dict(), trained_decoder_path)
            
            loss_sum = 0
            steps = 0

9800 Average loss in the last 200 batches is 7144.40869140625
10000 Average loss in the last 200 batches is 10564.2177734375
10200 Average loss in the last 200 batches is 10587.638671875
10400 Average loss in the last 200 batches is 10588.2275390625
10600 Average loss in the last 200 batches is 10566.5234375
10800 Average loss in the last 200 batches is 10594.4365234375
11000 Average loss in the last 200 batches is 10523.6474609375
11200 Average loss in the last 200 batches is 10557.849609375
11400 Average loss in the last 200 batches is 10607.12890625
11600 Average loss in the last 200 batches is 10611.0126953125
Printing sample predictions after 11600 batches in sentence_predictions.text
11800 Average loss in the last 200 batches is 10565.3251953125
12000 Average loss in the last 200 batches is 10654.5576171875
12200 Average loss in the last 200 batches is 10603.1640625
12400 Average loss in the last 200 batches is 10576.9775390625
12600 Average loss in the last 200 batches is 10605.

33400 Average loss in the last 200 batches is 10601.86328125
33600 Average loss in the last 200 batches is 10552.1201171875
Printing sample predictions after 33600 batches in sentence_predictions.text
33800 Average loss in the last 200 batches is 10566.4296875
34000 Average loss in the last 200 batches is 10544.1611328125
34200 Average loss in the last 200 batches is 10555.4111328125
34400 Average loss in the last 200 batches is 10603.4716796875
34600 Average loss in the last 200 batches is 10574.5322265625
34800 Average loss in the last 200 batches is 10580.744140625
35000 Average loss in the last 200 batches is 10578.3525390625
35200 Average loss in the last 200 batches is 10586.884765625
35400 Average loss in the last 200 batches is 10589.6533203125
35600 Average loss in the last 200 batches is 10606.4423828125
Printing sample predictions after 35600 batches in sentence_predictions.text
35800 Average loss in the last 200 batches is 10574.8154296875
36000 Average loss in the last 200

56800 Average loss in the last 200 batches is 10553.919921875
57000 Average loss in the last 200 batches is 10580.888671875
57200 Average loss in the last 200 batches is 10554.2724609375
57400 Average loss in the last 200 batches is 10559.7822265625
57600 Average loss in the last 200 batches is 10523.25
Printing sample predictions after 57600 batches in sentence_predictions.text
57800 Average loss in the last 200 batches is 10527.16796875
58000 Average loss in the last 200 batches is 10530.142578125
58200 Average loss in the last 200 batches is 10532.5859375
58400 Average loss in the last 200 batches is 10584.2666015625
58600 Average loss in the last 200 batches is 10549.6279296875
58800 Average loss in the last 200 batches is 10562.8486328125
59000 Average loss in the last 200 batches is 10489.9912109375


KeyboardInterrupt: 

In [None]:


optimizer = torch.optim.Adam(params, lr = 1e-4)

steps = 0
num_loss_prints = 0

for epoch in range(EPOCHS):
    
    print(f"Starting epoch {epoch} =====================")
    
    best_loss = 1e10
    loss_sum = 0
    
    for idx, sentences in enumerate(sentences_loader):

        rnn_encoder.init_hidden_and_cell()
       
   
        in_sentences = sentences['de_sentences']
        in_lens = sentences['de_lens']
        out_sentences = sentences['en_sentences']
        out_lens = sentences['en_lens']
        

        padded_in = pad_sequence(in_sentences, padding_value=0).to(device)
        padded_out = pad_sequence(out_sentences, padding_value=0).to(device)

        padded_in_one_hot = torch.zeros(padded_in.shape[0], padded_in.shape[1], dataset.get_de_dict_size()).to(device)
        padded_in_one_hot = padded_in_one_hot.scatter_(2,padded_in.data,1)
       
        rnn_encoder.forward((padded_in_one_hot, in_lens))
        hidden, cell = rnn_encoder.get_hidden_and_cell()
       
        rnn_decoder.init_hidden_and_cell(hidden,cell)
       
        max_sentence_len = padded_out.shape[0]
            
        y_pred = rnn_decoder.forward(dataset.get_en_eos_code(), dataset.get_en_dict_size(), max_sentence_len)
       
        padded_out = padded_out[0:max_sentence_len]
        padded_out_one_hot = torch.zeros(padded_out.shape[0], padded_out.shape[1], dataset.get_en_dict_size()).to(device)
        padded_out_one_hot = padded_out_one_hot.scatter_(2,padded_out.data,1)
       
        #Make all padded one-hot vectors to all zeros, which will make
        #padded components loss 0 and so they wont affect the loss
        padded_out_one_hot[:,:,0] = torch.zeros(max_sentence_len, padded_out_one_hot.shape[1])
        loss = torch.sum(-torch.log(y_pred + 1e-9) * padded_out_one_hot)
       
        loss_sum += loss.to('cpu').detach().data
       
        #print(loss.to('cpu').detach().data)
       
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        #Every 200 batches print the average loss and store the model weights
        steps += BATCH_SIZE
        num_batches += 1
        loss_print_step = 200
        if num_batches % loss_print_step == 0:
            
            print(f"{num_batches} Average loss in the last {loss_print_step} batches is {loss_sum/float(loss_print_step)}")
            steps = 0
            
            
            num_loss_prints += 1 
            
            if num_loss_prints % 10 == 0:
                print_results(in_sentences, out_sentences, y_pred.to('cpu').detach().data, num_batches)
            
            if best_loss > loss_sum:
                best_loss = loss_sum

                models_path = "models"
                if not os.path.exists(models_path):
                    os.mkdir(models_path)

                torch.save(rnn_encoder.state_dict(), trained_encoder_path)
                torch.save(rnn_decoder.state_dict(), trained_decoder_path)
            
            loss_sum = 0
            steps = 0

59200 Average loss in the last 200 batches is 1776.3555908203125
59400 Average loss in the last 200 batches is 10538.2021484375
59600 Average loss in the last 200 batches is 10581.0322265625
59800 Average loss in the last 200 batches is 10599.4912109375
