# *Sequence to Sequence Learning with Neural Networks* in PyTorch

An encoder-decoder sequence-to-sequence RNN implementation based on the 2014 publication:

**Sequence to Sequence Learning with Neural Networks** - Ilya Sutskever, Oriol Vinyals, Quoc V. Le

https://arxiv.org/abs/1409.3215

## Dataset

The project files include a PyTorch dataset implementation (`fra_eng_dataset.py`) for a small toy dataset (170K sentence pairs), which can be found in the `fra-eng` folder. For this experiment, however, I use the **WMT'14 English-German** dataset (4.5M sentence pairs); the model translates German into English.

The dataset can be found here:
https://nlp.stanford.edu/projects/nmt/

### PyTorch Dataset

In [1]:
from torch.utils.data import Dataset, DataLoader
import pickle
from nltk.tokenize import word_tokenize
import os
import torch
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import os


# Model / training hyper-parameters.
RNN_LAYERS = 4                 # stacked LSTM layers, used by both encoder and decoder
RNN_HIDDEN_SIZE = 1024         # LSTM hidden/cell state width
IN_EMBEDDING_SIZE = 256        # width of the linear projection applied to one-hot inputs
OUT_EMBEDDING_SIZE = 256       # width of the decoder's projection before the vocabulary logits
BATCH_SIZE = 64
EPOCHS = 50
MAXMAX_SENTENCE_LEN = 50       # the collate function truncates sentences to this many tokens


# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
def get_top_dictionary(text_corpus_path, top_n = 50000):
    """Return the ``top_n`` most frequent lower-cased tokens of a text corpus.

    Every line of ``text_corpus_path`` (read as UTF-8) has the ``##AT##``
    marker removed and is split with NLTK's ``word_tokenize``. Each token is
    lower-cased before it is counted.

    Args:
        text_corpus_path: Path to a plain-text corpus, one sentence per line.
        top_n: Maximum number of tokens to return.

    Returns:
        A list of at most ``top_n`` token strings, most frequent first.
        Tokens with equal counts keep the order in which they were first seen.
    """
    from collections import Counter

    print(f"Creating top dictionary from: {text_corpus_path}..")

    token_counts = Counter()

    with open(text_corpus_path, "r", encoding='utf-8') as f:
        # Iterate lazily instead of f.readlines(): the corpus has millions of
        # lines and does not need to be held in memory all at once.
        for idx, line in enumerate(f):

            if (idx+1) % 500000 == 0:
                print(f"Processed {idx+1} lines")

            line = line.replace('##AT##', '')
            token_counts.update(token.lower() for token in word_tokenize(line))

    # The previous loop appended a token before testing ``idx > top_n`` and so
    # returned top_n + 2 tokens; most_common(top_n) returns exactly top_n.
    return [token for token, _ in token_counts.most_common(top_n)]
        


class WMT14_en_de_Dataset(Dataset):
    """Tokenised, index-encoded WMT'14 English/German sentence pairs.

    On first use the class does the following:
    - builds a vocabulary of the most frequent tokens for each language (see
      ``get_top_dictionary``);
    - encodes every sentence pair as lists of token indices terminated by
      ``<EOS>``;
    - caches the result in ``processed_data.pkl``, and the vocabularies in
      ``top_en_tokens.pkl`` / ``top_de_tokens.pkl``, in the current working
      directory.

    When ``processed_data.pkl`` already exists it is loaded instead of
    re-tokenising the corpus.

    Both vocabularies use index 0 = ``<PAD>``, 1 = ``<EOS>``, 2 = ``<UNK>``,
    followed by the top tokens in the order returned by the vocabulary builder.

    NOTE: the caches are read with ``pickle``, which can execute arbitrary
    code; only load cache files you created yourself.
    """

    # Special tokens occupy the first indices of both vocabularies.
    _SPECIAL_TOKENS = ['<PAD>', '<EOS>', '<UNK>']

    def __init__(self, data_source_path = 'wmt14_en_de'):
        super().__init__()

        processed_data_path = "processed_data.pkl"
        top_en_tokens_path = "top_en_tokens.pkl"
        top_de_tokens_path = "top_de_tokens.pkl"

        self.sentence_list = []

        if os.path.exists(processed_data_path):
            with open(processed_data_path, 'rb') as f:
                pickle_data = pickle.load(f)
            self.sentence_list = pickle_data['sentence_list']
            self.en_last_token_idx = pickle_data['en_last_token_idx']
            self.de_last_token_idx = pickle_data['de_last_token_idx']
            self.en_token_idx_to_text = pickle_data['en_token_idx_to_text']
            self.de_token_idx_to_text = pickle_data['de_token_idx_to_text']
            # The token->index dicts are not part of the cache. Rebuild them from
            # the index->token lists so lookups also work after a cached load
            # (previously they only contained the three special tokens).
            self.en_token_dict = self._index_tokens(self.en_token_idx_to_text)
            self.de_token_dict = self._index_tokens(self.de_token_idx_to_text)
            return

        en_sentences_path = os.path.join(data_source_path, "train.en")
        de_sentences_path = os.path.join(data_source_path, "train.de")

        top_en_tokens = self._load_or_build_top_tokens(top_en_tokens_path, en_sentences_path)
        top_de_tokens = self._load_or_build_top_tokens(top_de_tokens_path, de_sentences_path)

        self.en_token_idx_to_text = self._SPECIAL_TOKENS + list(top_en_tokens)
        self.de_token_idx_to_text = self._SPECIAL_TOKENS + list(top_de_tokens)
        self.en_token_dict = self._index_tokens(self.en_token_idx_to_text)
        self.de_token_dict = self._index_tokens(self.de_token_idx_to_text)
        self.en_last_token_idx = len(self.en_token_idx_to_text) - 1
        self.de_last_token_idx = len(self.de_token_idx_to_text) - 1

        with open(de_sentences_path, "r", encoding='utf-8') as de_f, \
                open(en_sentences_path, "r", encoding='utf-8') as en_f:

            print(f"Creating sentences from {de_sentences_path} and {en_sentences_path} corpora")

            # Iterate both files lazily and in lockstep; zip stops at the shorter one.
            for idx, (de_sentence, en_sentence) in enumerate(zip(de_f, en_f)):

                if (idx+1) % 500000 == 0:
                    print(f"Processed {idx+1} lines")

                self.sentence_list.append(
                    dict(
                        en = self._encode_sentence(en_sentence, self.en_token_dict),
                        de = self._encode_sentence(de_sentence, self.de_token_dict)
                    ))

        with open(processed_data_path, "wb") as f:
            pickle_processed_data = dict(
                sentence_list = self.sentence_list,
                en_last_token_idx = self.en_last_token_idx,
                de_last_token_idx = self.de_last_token_idx,
                en_token_idx_to_text = self.en_token_idx_to_text,
                de_token_idx_to_text = self.de_token_idx_to_text
            )
            pickle.dump(pickle_processed_data, f)

    @staticmethod
    def _index_tokens(idx_to_text):
        """Map every token string to its position in ``idx_to_text``."""
        return {token: idx for idx, token in enumerate(idx_to_text)}

    @staticmethod
    def _load_or_build_top_tokens(cache_path, corpus_path):
        """Load the cached top-token list, or build it from ``corpus_path`` and cache it."""
        if os.path.exists(cache_path):
            with open(cache_path, "rb") as f:
                return pickle.load(f)
        top_tokens = get_top_dictionary(corpus_path)
        with open(cache_path, "wb") as f:
            pickle.dump(top_tokens, f)
        return top_tokens

    @staticmethod
    def _encode_sentence(sentence, token_dict):
        """Tokenise ``sentence`` and return its token indices followed by ``<EOS>``.

        Tokens are lower-cased; tokens missing from ``token_dict`` map to ``<UNK>``.
        """
        unk_idx = token_dict['<UNK>']
        sentence = sentence.replace('##AT##', '')
        indices = [token_dict.get(token.lower(), unk_idx) for token in word_tokenize(sentence)]
        indices.append(token_dict['<EOS>'])
        return indices

    def get_en_dict_size(self):
        """Number of entries in the English vocabulary (special tokens included)."""
        return self.en_last_token_idx + 1

    def get_de_dict_size(self):
        """Number of entries in the German vocabulary (special tokens included)."""
        return self.de_last_token_idx + 1

    def get_de_eos_code(self):
        return self.de_token_dict['<EOS>']

    def get_en_eos_code(self):
        return self.en_token_dict['<EOS>']

    def __len__(self):
        return len(self.sentence_list)

    def __getitem__(self, item):
        """Return ``{'en': LongTensor, 'de': LongTensor}`` for sentence pair ``item``."""
        return {key: torch.tensor(indices) for key, indices in self.sentence_list[item].items()}


def en_de_dataset_collate(data, max_len=None):
    """Collate dataset items into lists of truncated sentences sorted by source length.

    Each item is a dict with ``'en'`` and ``'de'`` 1-D index tensors, as returned
    by ``WMT14_en_de_Dataset.__getitem__``. Every sentence is truncated to
    ``max_len`` tokens and given a trailing unit dimension, i.e. shape
    ``(len, 1)``.

    Items are ordered by German sentence length, longest first. The German side
    is the encoder input in ``train``, and ``pack_padded_sequence`` (default
    ``enforce_sorted=True``) requires descending lengths.

    Args:
        data: List of ``{'en': Tensor, 'de': Tensor}`` items.
        max_len: Truncation length in tokens; defaults to ``MAXMAX_SENTENCE_LEN``.

    Returns:
        dict with ``en_sentences`` / ``de_sentences`` (lists of ``(len, 1)``
        tensors) and ``en_lens`` / ``de_lens`` (lists of ints), all in the
        same sorted order.
    """
    if max_len is None:
        max_len = MAXMAX_SENTENCE_LEN

    en_sentences = [s['en'][0:max_len].unsqueeze(dim=1) for s in data]
    de_sentences = [s['de'][0:max_len].unsqueeze(dim=1) for s in data]
    en_sentence_lens = [len(sent) for sent in en_sentences]
    de_sentence_lens = [len(sent) for sent in de_sentences]

    # Rearrange everything by de sentence lens, longest first
    sort_idxes = np.argsort(np.array(de_sentence_lens))[::-1]

    return dict(
        en_sentences = [en_sentences[idx] for idx in sort_idxes],
        en_lens = [en_sentence_lens[idx] for idx in sort_idxes],
        de_sentences = [de_sentences[idx] for idx in sort_idxes],
        de_lens = [de_sentence_lens[idx] for idx in sort_idxes]
    )

### Models

In [3]:
class RNN_encoder_model(nn.Module):
    """Stacked-LSTM encoder whose final state summarises the source sentence.

    The input is a one-hot tensor ``(seq_len, batch, in_dict_size)``. The
    encoder:
    - projects it with an ``nn.Linear`` that acts as the embedding;
    - packs it by sentence length;
    - runs it through an ``RNN_LAYERS``-deep LSTM.

    The final hidden and cell states are stored on the module
    (``self.hidden`` / ``self.cell``) so the decoder can be initialised from
    them via ``get_hidden_and_cell``.
    """

    def __init__(self, in_dict_size):
        super().__init__()

        self.in_dict_size = in_dict_size

        # A Linear layer over one-hot vectors (rather than nn.Embedding) is kept
        # so existing checkpoints keep the same state_dict keys and shapes.
        self.embedding = nn.Linear(
            in_dict_size,
            IN_EMBEDDING_SIZE)

        # LSTM state; reset with init_hidden_and_cell() before each batch.
        self.hidden = None
        self.cell = None

        self.rnn = nn.LSTM(
            input_size = IN_EMBEDDING_SIZE,
            hidden_size = RNN_HIDDEN_SIZE,
            num_layers = RNN_LAYERS
        )

    def init_hidden_and_cell(self, batch_size=None):
        """Reset the LSTM hidden and cell state to zeros.

        Args:
            batch_size: Number of sentences in the next batch. Defaults to
                ``BATCH_SIZE``, which matches every batch because the
                DataLoader is built with ``drop_last=True``.
        """
        if batch_size is None:
            batch_size = BATCH_SIZE
        # Allocate directly on the device the parameters live on, instead of
        # allocating on the CPU and then copying.
        param_device = self.embedding.weight.device
        state_shape = (RNN_LAYERS, batch_size, RNN_HIDDEN_SIZE)
        self.hidden = torch.zeros(state_shape, device=param_device)
        self.cell = torch.zeros(state_shape, device=param_device)

    def get_hidden_and_cell(self):
        """Return the ``(hidden, cell)`` state left by the last ``forward`` call."""
        return self.hidden, self.cell

    def forward(self, x):
        """Encode a padded batch and store the final LSTM state on the module.

        Args:
            x: Tuple ``(padded_sent_one_hot, sent_lens)``. The tensor has shape
               ``(seq_len, batch, in_dict_size)``, and ``sent_lens`` lists the
               true lengths in descending order.

        Returns:
            ``(outputs, lens)``: the top-layer LSTM outputs, padded to
            ``(max_len, batch, RNN_HIDDEN_SIZE)``, and their lengths. These were
            previously computed and discarded (``None`` was returned); callers
            that ignore the return value are unaffected.
        """
        padded_sent_one_hot, sent_lens = x
        padded_sent_emb = self.embedding(padded_sent_one_hot)
        packed = pack_padded_sequence(padded_sent_emb, sent_lens)
        packed, (self.hidden, self.cell) = self.rnn(packed, (self.hidden, self.cell))
        padded, out_lens = pad_packed_sequence(packed)
        return padded, out_lens

In [4]:
class RNN_decoder_model(nn.Module):
    """Stacked-LSTM decoder that outputs a softmax distribution at each timestep.

    The decoder is seeded with the encoder's final ``(hidden, cell)`` via
    ``init_hidden_and_cell``, and decoding starts from the ``<EOS>`` token. Each
    step:
    - projects the previous token (one-hot) to ``IN_EMBEDDING_SIZE``;
    - runs it through the LSTM;
    - projects the result to ``OUT_EMBEDDING_SIZE``, then to vocabulary logits;
    - applies a softmax over the vocabulary.
    """

    def __init__(self, out_dict_size):
        super().__init__()

        self.in_embedding = nn.Linear(
            in_features=out_dict_size,
            out_features=IN_EMBEDDING_SIZE
        )

        self.rnn = nn.LSTM(
            input_size = IN_EMBEDDING_SIZE,
            hidden_size = RNN_HIDDEN_SIZE,
            num_layers = RNN_LAYERS
        )

        self.rnn_to_embedding = nn.Linear(
            in_features = RNN_HIDDEN_SIZE,
            out_features = OUT_EMBEDDING_SIZE
        )

        self.embedding_to_logit = nn.Linear(
            in_features = OUT_EMBEDDING_SIZE,
            out_features = out_dict_size
        )

        self.softmax = nn.Softmax(dim=2)

        # LSTM state; must be set with init_hidden_and_cell() before forward().
        self.hidden = None
        self.cell = None

    def init_hidden_and_cell(self, hidden, cell):
        """Seed the decoder state, typically with the encoder's final state."""
        self.hidden = hidden
        self.cell = cell

    def forward(self, out_eos_code, out_dict_size, max_sentence_len, teacher_tokens=None):
        """Decode ``max_sentence_len`` steps and return the per-step probabilities.

        Args:
            out_eos_code: Index of ``<EOS>``; used as the first input token.
            out_dict_size: Size of the output vocabulary (the one-hot width).
            max_sentence_len: Number of decoding steps.
            teacher_tokens: Optional long tensor of target indices with shape
                ``(T, batch, 1)``, where ``T >= max_sentence_len - 1``. When
                given, step ``t > 0`` is fed the ground-truth token
                ``teacher_tokens[t - 1]`` (teacher forcing). When ``None`` (the
                default), the model's own greedy prediction is fed back, as
                before.

        Returns:
            Tensor ``(max_sentence_len, batch, out_dict_size)`` of softmax
            probabilities.

        Raises:
            RuntimeError: if ``init_hidden_and_cell`` has not been called.
        """
        if self.hidden is None or self.cell is None:
            raise RuntimeError("init_hidden_and_cell() must be called before forward()")

        batch_size = self.hidden.shape[1]
        state_device = self.hidden.device
        # (1, batch, 1) start tokens: every sentence starts from <EOS>.
        prev_outp = torch.full((1, batch_size, 1), out_eos_code, dtype=torch.long, device=state_device)

        all_outp_prob = []

        for timestep in range(max_sentence_len):

            # (1, batch, 1) indices -> (1, batch, out_dict_size) one-hot floats
            prev_outp_one_hot = torch.zeros(prev_outp.shape[0], prev_outp.shape[1], out_dict_size, device=state_device)
            prev_outp_one_hot.scatter_(2, prev_outp, 1)

            prev_outp_in_emb = self.in_embedding(prev_outp_one_hot)

            cur_outp_hid, (self.hidden, self.cell) = self.rnn(prev_outp_in_emb, (self.hidden, self.cell))
            cur_outp_emb = self.rnn_to_embedding(cur_outp_hid)
            cur_outp_logits = self.embedding_to_logit(cur_outp_emb)
            cur_outp_prob = self.softmax(cur_outp_logits)
            all_outp_prob.append(cur_outp_prob)

            if teacher_tokens is not None:
                # Teacher forcing: the next step sees this step's ground-truth token.
                prev_outp = teacher_tokens[timestep:timestep + 1].to(state_device)
            else:
                # Greedy feedback; detached so no gradient flows through the argmax.
                prev_outp = torch.argmax(cur_outp_prob.detach(), dim=2, keepdim=True)

        return torch.cat(all_outp_prob, dim=0)
   

In [5]:
# Build (or load from the on-disk cache) the tokenised corpus and wrap it in a
# shuffling loader. drop_last=True keeps every batch at exactly BATCH_SIZE,
# which the encoder's zero-initialised state assumes.
dataset = WMT14_en_de_Dataset()
sentences_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    collate_fn=en_de_dataset_collate,
)


In [6]:
def print_results(in_sentence_list, out_sentence_list, pred_tensor, num_batches):
    """Append sample translations for up to 50 sentences of a batch to a text file.

    Args:
        in_sentence_list: German (source) sentences, each a tensor of token
            indices (e.g. the ``(len, 1)`` tensors built by the collate function).
        out_sentence_list: English (target) sentences in the same format.
        pred_tensor: ``(timesteps, batch, en_dict_size)`` decoder output; the
            argmax at each step is taken as the predicted token.
        num_batches: Number of batches trained so far, used in the header.

    Uses the module-level ``dataset`` to map indices back to text.
    """
    sentence_prediction_samples_path = 'sentence_predictions.text'

    print(f"Printing sample predictions after {num_batches} batches in {sentence_prediction_samples_path}")

    in_token_to_text = dataset.de_token_idx_to_text
    out_token_to_text = dataset.en_token_idx_to_text
    eos_code = dataset.get_en_eos_code()

    # Greedy prediction for every (timestep, sentence), computed once for the batch.
    pred_indices = torch.argmax(pred_tensor, dim=2)

    # Explicit UTF-8 so writing German text does not depend on the platform default.
    with open(sentence_prediction_samples_path, "a", encoding="utf-8") as f:

        f.write(('=' * 50 + '\n') * 3)
        f.write(f"Sample predictions after {num_batches} batches \n\n")

        for s in range(min(len(in_sentence_list), 50)):

            # reshape(-1) instead of squeeze(): a one-token sentence would
            # squeeze to a 0-d tensor, which cannot be iterated.
            in_tokens = in_sentence_list[s].reshape(-1).tolist()
            out_tokens = out_sentence_list[s].reshape(-1).tolist()

            in_sent_text = ' '.join(in_token_to_text[tok] for tok in in_tokens)
            f.write(f"\nGerman sentence is: {in_sent_text} \n")

            out_sent_text = ' '.join(out_token_to_text[tok] for tok in out_tokens)
            f.write(f"English sentence is: {out_sent_text}\n")

            # Emit predicted tokens up to and including the first <EOS>.
            pred_sent_text = []
            for pred_token in pred_indices[:, s].tolist():
                pred_sent_text.append(out_token_to_text[pred_token])
                if pred_token == eos_code:
                    break
            f.write(f"Translated English sentence is: {' '.join(pred_sent_text)}\n")


## Training

In [7]:
# Encoder reads the German vocabulary; decoder emits the English vocabulary.
rnn_encoder = RNN_encoder_model(dataset.get_de_dict_size()).to(device)
rnn_decoder = RNN_decoder_model(dataset.get_en_dict_size()).to(device)

# Checkpoint locations; train() saves its best weights to the same paths.
trained_encoder_path = 'models/encoder_wmt14_de_en_3nd.pt'
trained_decoder_path = 'models/decoder_wmt14_de_en_3nd.pt'

# Resume from existing checkpoints. map_location lets a checkpoint saved on a
# GPU be loaded on a CPU-only machine (and vice versa).
if os.path.exists(trained_encoder_path):
    rnn_encoder.load_state_dict(torch.load(trained_encoder_path, map_location=device))
if os.path.exists(trained_decoder_path):
    rnn_decoder.load_state_dict(torch.load(trained_decoder_path, map_location=device))


# A single Adam optimiser updates the encoder and decoder jointly.
params = list(rnn_encoder.parameters()) + list(rnn_decoder.parameters())
optimizer = torch.optim.Adam(params, lr = 1e-5)

In [8]:
steps = 0
def train():
    """Train the encoder and decoder for ``EPOCHS`` epochs over ``sentences_loader``.

    For each batch:
    - the German side is one-hot encoded and fed to ``rnn_encoder``;
    - the encoder's final state seeds ``rnn_decoder``, which decodes as many
      steps as the longest English sentence in the batch;
    - the loss is the summed cross-entropy over all non-padding target tokens.

    Every ``loss_print_step`` batches, the mean per-batch loss of that window is
    printed. The weights are saved to ``trained_encoder_path`` /
    ``trained_decoder_path`` whenever the window loss is the best seen in this
    call. On every 10th print, sample translations are written via
    ``print_results``.

    Relies on the module-level ``rnn_encoder``, ``rnn_decoder``, ``optimizer``,
    ``dataset`` and ``sentences_loader``.
    """
    global steps
    loss_print_step = 200
    num_batches = 0
    num_loss_prints = 0

    # Both are tracked across epochs. Previously they were reset at every epoch
    # start, even though the print window spans epoch boundaries: a window
    # straddling two epochs then reported only a partial (artificially low)
    # sum and could overwrite a better checkpoint.
    best_loss = float('inf')
    loss_sum = 0.0

    for epoch in range(EPOCHS):

        print(f"Starting epoch {epoch} =====================")

        for sentences in sentences_loader:

            rnn_encoder.init_hidden_and_cell()

            in_sentences = sentences['de_sentences']
            in_lens = sentences['de_lens']
            out_sentences = sentences['en_sentences']

            # (max_len, batch, 1) index tensors, padded with <PAD> = 0
            padded_in = pad_sequence(in_sentences, padding_value=0).to(device)
            padded_out = pad_sequence(out_sentences, padding_value=0).to(device)

            padded_in_one_hot = torch.zeros(padded_in.shape[0], padded_in.shape[1], dataset.get_de_dict_size(), device=device)
            padded_in_one_hot.scatter_(2, padded_in, 1)

            rnn_encoder((padded_in_one_hot, in_lens))
            hidden, cell = rnn_encoder.get_hidden_and_cell()

            rnn_decoder.init_hidden_and_cell(hidden, cell)

            max_sentence_len = padded_out.shape[0]

            y_pred = rnn_decoder(dataset.get_en_eos_code(), dataset.get_en_dict_size(), max_sentence_len)

            padded_out_one_hot = torch.zeros(max_sentence_len, padded_out.shape[1], dataset.get_en_dict_size(), device=device)
            padded_out_one_hot.scatter_(2, padded_out, 1)

            # Zero the <PAD> (index 0) channel so padded positions contribute
            # nothing to the loss.
            padded_out_one_hot[:, :, 0] = 0
            loss = torch.sum(-torch.log(y_pred + 1e-9) * padded_out_one_hot)

            loss_sum += loss.item()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Every loss_print_step batches: print the average loss and store the weights
            steps += BATCH_SIZE
            num_batches += 1
            if num_batches % loss_print_step == 0:

                print(f"{num_batches} Average loss in the last {loss_print_step} batches is {loss_sum/float(loss_print_step)}")

                num_loss_prints += 1

                if num_loss_prints % 10 == 0:
                    print_results(in_sentences, out_sentences, y_pred.detach().cpu(), num_batches)

                if loss_sum < best_loss:
                    best_loss = loss_sum

                    for checkpoint_path in (trained_encoder_path, trained_decoder_path):
                        os.makedirs(os.path.dirname(checkpoint_path) or '.', exist_ok=True)

                    torch.save(rnn_encoder.state_dict(), trained_encoder_path)
                    torch.save(rnn_decoder.state_dict(), trained_decoder_path)

                loss_sum = 0.0
                steps = 0

    

In [8]:
# Start training; the log below is this cell's (interrupted) output.
train()

200 Average loss in the last 200 batches is 12456.51171875
400 Average loss in the last 200 batches is 11314.78125
600 Average loss in the last 200 batches is 11211.7890625
800 Average loss in the last 200 batches is 11208.6103515625
1000 Average loss in the last 200 batches is 11233.7509765625
1200 Average loss in the last 200 batches is 11134.869140625
1400 Average loss in the last 200 batches is 11068.5078125
1600 Average loss in the last 200 batches is 11022.9384765625
1800 Average loss in the last 200 batches is 10998.482421875
2000 Average loss in the last 200 batches is 10920.26953125
Printing sample predictions after 2000 batches in sentence_predictions.text
2200 Average loss in the last 200 batches is 10852.3564453125
2400 Average loss in the last 200 batches is 11001.01171875
2600 Average loss in the last 200 batches is 10825.5224609375
2800 Average loss in the last 200 batches is 10904.9560546875
3000 Average loss in the last 200 batches is 10801.083984375
3200 Average loss 

KeyboardInterrupt: 

In [9]:
# Calling train() again continues from the in-memory model weights and optimizer
# state; its batch counter starts again from 0, as the log shows.
train()

200 Average loss in the last 200 batches is 9747.49609375
400 Average loss in the last 200 batches is 9838.896484375
600 Average loss in the last 200 batches is 9838.03125
800 Average loss in the last 200 batches is 9887.501953125
1000 Average loss in the last 200 batches is 9797.1171875
1200 Average loss in the last 200 batches is 9779.291015625
1400 Average loss in the last 200 batches is 9824.8076171875
1600 Average loss in the last 200 batches is 9809.30078125
1800 Average loss in the last 200 batches is 9745.7041015625
2000 Average loss in the last 200 batches is 9747.541015625
Printing sample predictions after 2000 batches in sentence_predictions.text
2200 Average loss in the last 200 batches is 9720.42578125
2400 Average loss in the last 200 batches is 9672.2802734375
2600 Average loss in the last 200 batches is 9758.294921875
2800 Average loss in the last 200 batches is 9755.1083984375
3000 Average loss in the last 200 batches is 9773.0859375
3200 Average loss in the last 200 b

24200 Average loss in the last 200 batches is 9097.5791015625
24400 Average loss in the last 200 batches is 9155.7392578125
24600 Average loss in the last 200 batches is 9137.91796875
24800 Average loss in the last 200 batches is 9061.6142578125
25000 Average loss in the last 200 batches is 9092.25
25200 Average loss in the last 200 batches is 9058.88671875
25400 Average loss in the last 200 batches is 9061.0341796875
25600 Average loss in the last 200 batches is 9101.041015625
25800 Average loss in the last 200 batches is 9056.2060546875
26000 Average loss in the last 200 batches is 8972.74609375
Printing sample predictions after 26000 batches in sentence_predictions.text
26200 Average loss in the last 200 batches is 9111.7197265625
26400 Average loss in the last 200 batches is 9089.14453125
26600 Average loss in the last 200 batches is 9092.103515625
26800 Average loss in the last 200 batches is 8986.541015625
27000 Average loss in the last 200 batches is 9046.9296875
27200 Average l

48200 Average loss in the last 200 batches is 8701.6298828125
48400 Average loss in the last 200 batches is 8666.4658203125
48600 Average loss in the last 200 batches is 8714.48046875
48800 Average loss in the last 200 batches is 8730.14453125
49000 Average loss in the last 200 batches is 8733.6025390625
49200 Average loss in the last 200 batches is 8700.58984375
49400 Average loss in the last 200 batches is 8624.59375
49600 Average loss in the last 200 batches is 8746.4130859375
49800 Average loss in the last 200 batches is 8703.1240234375
50000 Average loss in the last 200 batches is 8704.4404296875
Printing sample predictions after 50000 batches in sentence_predictions.text
50200 Average loss in the last 200 batches is 8705.70703125
50400 Average loss in the last 200 batches is 8644.94140625
50600 Average loss in the last 200 batches is 8657.5615234375
50800 Average loss in the last 200 batches is 8682.419921875
51000 Average loss in the last 200 batches is 8618.341796875
51200 Aver

KeyboardInterrupt: 

In [None]:
# Continue training from the current in-memory weights and optimizer state.
train()

200 Average loss in the last 200 batches is 8515.95703125
400 Average loss in the last 200 batches is 8485.75
600 Average loss in the last 200 batches is 8553.5283203125
800 Average loss in the last 200 batches is 8567.0166015625
1000 Average loss in the last 200 batches is 8530.3193359375
1200 Average loss in the last 200 batches is 8580.306640625
1400 Average loss in the last 200 batches is 8576.2333984375
1600 Average loss in the last 200 batches is 8550.1787109375
1800 Average loss in the last 200 batches is 8647.04296875
2000 Average loss in the last 200 batches is 8578.50390625
Printing sample predictions after 2000 batches in sentence_predictions.text
2200 Average loss in the last 200 batches is 8517.2197265625
2400 Average loss in the last 200 batches is 8599.896484375
2600 Average loss in the last 200 batches is 8481.1123046875
2800 Average loss in the last 200 batches is 8500.046875
3000 Average loss in the last 200 batches is 8501.99609375
3200 Average loss in the last 200 b

24200 Average loss in the last 200 batches is 8355.32421875
24400 Average loss in the last 200 batches is 8219.5791015625
24600 Average loss in the last 200 batches is 8335.9970703125
24800 Average loss in the last 200 batches is 8272.494140625
25000 Average loss in the last 200 batches is 8427.666015625
25200 Average loss in the last 200 batches is 8368.103515625
25400 Average loss in the last 200 batches is 8339.986328125
25600 Average loss in the last 200 batches is 8275.7451171875
25800 Average loss in the last 200 batches is 8290.634765625
26000 Average loss in the last 200 batches is 8256.8583984375
Printing sample predictions after 26000 batches in sentence_predictions.text
26200 Average loss in the last 200 batches is 8348.1708984375
26400 Average loss in the last 200 batches is 8243.8935546875
26600 Average loss in the last 200 batches is 8248.8466796875
26800 Average loss in the last 200 batches is 8371.630859375
27000 Average loss in the last 200 batches is 8355.3310546875
2

48200 Average loss in the last 200 batches is 8235.7060546875
48400 Average loss in the last 200 batches is 8270.7548828125
48600 Average loss in the last 200 batches is 8172.2841796875
48800 Average loss in the last 200 batches is 8102.3662109375
49000 Average loss in the last 200 batches is 8226.0166015625
49200 Average loss in the last 200 batches is 8129.18505859375
49400 Average loss in the last 200 batches is 8098.24169921875
49600 Average loss in the last 200 batches is 8113.24267578125
49800 Average loss in the last 200 batches is 8169.158203125
50000 Average loss in the last 200 batches is 8105.26416015625
Printing sample predictions after 50000 batches in sentence_predictions.text
50200 Average loss in the last 200 batches is 8191.80810546875
50400 Average loss in the last 200 batches is 8160.361328125
50600 Average loss in the last 200 batches is 8176.326171875
50800 Average loss in the last 200 batches is 8247.390625
51000 Average loss in the last 200 batches is 8169.933105

72000 Average loss in the last 200 batches is 7997.4599609375
Printing sample predictions after 72000 batches in sentence_predictions.text
72200 Average loss in the last 200 batches is 7994.53564453125
72400 Average loss in the last 200 batches is 7931.41943359375
72600 Average loss in the last 200 batches is 8029.2060546875
72800 Average loss in the last 200 batches is 7958.47119140625
73000 Average loss in the last 200 batches is 7941.45458984375
73200 Average loss in the last 200 batches is 7940.62890625
73400 Average loss in the last 200 batches is 7982.12744140625
73600 Average loss in the last 200 batches is 7947.412109375
73800 Average loss in the last 200 batches is 7990.07421875
74000 Average loss in the last 200 batches is 8023.21630859375
Printing sample predictions after 74000 batches in sentence_predictions.text
74200 Average loss in the last 200 batches is 7894.0263671875
74400 Average loss in the last 200 batches is 8054.64013671875
74600 Average loss in the last 200 bat

95600 Average loss in the last 200 batches is 7967.630859375
95800 Average loss in the last 200 batches is 7866.36767578125
96000 Average loss in the last 200 batches is 7963.55126953125
Printing sample predictions after 96000 batches in sentence_predictions.text
96200 Average loss in the last 200 batches is 7976.87255859375
96400 Average loss in the last 200 batches is 7958.7841796875
96600 Average loss in the last 200 batches is 7919.865234375
96800 Average loss in the last 200 batches is 7945.3486328125
97000 Average loss in the last 200 batches is 7975.392578125
97200 Average loss in the last 200 batches is 7980.22265625
97400 Average loss in the last 200 batches is 7927.1513671875
97600 Average loss in the last 200 batches is 7968.25
97800 Average loss in the last 200 batches is 7982.466796875
98000 Average loss in the last 200 batches is 8093.2158203125
Printing sample predictions after 98000 batches in sentence_predictions.text
98200 Average loss in the last 200 batches is 7909.

118800 Average loss in the last 200 batches is 7893.3935546875
119000 Average loss in the last 200 batches is 7788.71875
119200 Average loss in the last 200 batches is 7960.40576171875
119400 Average loss in the last 200 batches is 7913.21435546875
119600 Average loss in the last 200 batches is 7948.41259765625
119800 Average loss in the last 200 batches is 7859.83544921875
120000 Average loss in the last 200 batches is 7889.13427734375
Printing sample predictions after 120000 batches in sentence_predictions.text
120200 Average loss in the last 200 batches is 7874.4638671875
120400 Average loss in the last 200 batches is 7910.6904296875
120600 Average loss in the last 200 batches is 7818.14453125
120800 Average loss in the last 200 batches is 7867.0419921875
121000 Average loss in the last 200 batches is 7868.86767578125
121200 Average loss in the last 200 batches is 7872.130859375
121400 Average loss in the last 200 batches is 7891.10302734375
121600 Average loss in the last 200 batch

142000 Average loss in the last 200 batches is 7801.388671875
Printing sample predictions after 142000 batches in sentence_predictions.text


In [None]:
#Continue with smaller learning rate.
# NOTE(review): nothing in this cell changes the learning rate; the optimizer
# still uses the lr it was created with (1e-5). Update the lr in
# optimizer.param_groups before calling train() if a smaller rate is intended.
train()

200 Average loss in the last 200 batches is 7711.92578125
