# Seq2Seq with batch training

This notebook implementation of machine translation using seq2seq  where both encoder and decoder are LSTM. Batch training is enabled by padding input and output sentences. It's based on the [PyTorch tutorial](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html).

In [7]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from itertools import chain

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Prepare dataset

This part is taken from [PyTorch tutorial](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html) with some modification

data has been put under the folder /data.

In [8]:
# PAD is a token used to faciliate batch training when input sentences are not of uniform lengths.
# SOS is start of sentence inserted in the beginning of the output sequence
# EOS is end of sentence inserted at the end of both input and output sequences
PAD_token = 0
SOS_token = 1
EOS_token = 2

# Language class consists of conversion from word to index and its inverse

class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "PAD", 1: "SOS", 2: "EOS"}
        self.n_words = 3  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

In [9]:
# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters


def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

In [10]:
# read in sentence from the data file
def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines
    lines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\
        read().strip().split('\n')

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs

In [11]:
# Only input sentence of length < 10 is used because this is a toy example
MAX_LENGTH = 10

eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)


def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH and \
        p[1].startswith(eng_prefixes)


def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

In [12]:
# prepare training data from raw data
def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs


input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
print(random.choice(pairs))

Reading lines...
Read 135842 sentence pairs
Trimmed to 10599 sentence pairs
Counting words...
Counted words:
fra 4346
eng 2804
['il est gaucher .', 'he s a southpaw .']


# Set up model: encoder and decoder

This part is taken from original PyTorch tutorial with some modification to enable batch processing

In [13]:
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, device):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.device = device
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    # input shape: (B, largest seq length)
    # input is padded
    # hidden shape: (B, hidden_size)
    # outputs shape: (B, sequence length, hidden_size)
    def forward(self, inputs, hidden, batch_size=1, input_lengths=None):
        embedded = self.embedding(inputs).view(batch_size, -1, self.hidden_size)
        # in batch training, input sentences are of various lengths 
        if input_lengths:
            outputs = pack_padded_sequence(embedded, input_lengths,
                                            batch_first=True, enforce_sorted=False)
            
            outputs, hidden = self.gru(outputs, hidden)
            outputs, output_lengths = pad_packed_sequence(outputs, batch_first=True)
            #outputs = outputs[np.arange(batch_size), output_lengths-1, :]
           
        # in prediction time, only one sentence is input
        else:
            outputs, hidden = self.gru(embedded, hidden)
        
        return outputs, hidden

    def initHidden(self, batch_size=1):
        return torch.zeros(1, batch_size, self.hidden_size, device=self.device)

In [14]:
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, device):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.device=device
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    # input shape: (B, largest seq length)
    # input is padded
    # embedded shape: (B, seq length, hidden size)
    # outputs shape: (B, seq length, output size)
    def forward(self, inputs, hidden, batch_size=1, input_lengths=None):
        embedded = self.embedding(inputs).view(batch_size, -1, self.hidden_size)
        embedded = F.relu(embedded)
        # in batch training, input sentences are of various lengths
        if input_lengths:
            outputs = pack_padded_sequence(embedded, input_lengths,
                                           batch_first=True, enforce_sorted=False)
            
            
            outputs, hidden = self.gru(outputs, hidden)
            outputs, output_lengths = pad_packed_sequence(outputs, batch_first=True)
        # in prediction
        else:
            outputs, hidden = self.gru(embedded, hidden)
        
        outputs = self.out(outputs)
        return outputs, hidden

    def initHidden(self, batch_size=1):
        return torch.zeros(1, batch_size, self.hidden_size, device=self.device)

# Data generator

The original PyTorch tutorial does not contain this part for seq2seq. It skipped to seq2seq with attention.

In [15]:
# dataset for translation
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

def pad_sequence_seq2seq(batch):
    # output batch sentence pairs and length of each sentence in the batch
    input_tokens = [pair[0] for pair in batch]
    output_tokens = [pair[1] for pair in batch]
    input_tokens_padded = pad_sequence(input_tokens, batch_first=True)
    output_tokens_padded = pad_sequence(output_tokens, batch_first=True)
    
    input_lengths = [len(pair[0]) for pair in batch]
    output_lengths = [len(pair[1]) for pair in batch]
    
    return input_tokens_padded, input_lengths, output_tokens_padded, output_lengths

class dataset(Dataset):
    def __init__(self, device):
        super(dataset, self).__init__()
        self.data = pairs
        self.device = device
    
    def __len__(self):
        return len(pairs)
    
    def __getitem__(self, idx):
        input, output = pairs[idx]
        # add EOS and SOS
        input_token = torch.tensor([input_lang.word2index[word] for word in input.split()] +
                                   [EOS_token],
                                   device=self.device)
        output_token = torch.tensor([SOS_token] + 
                                    [output_lang.word2index[word] for word in output.split()] +
                                    [EOS_token],
                                    device=self.device)
        return input_token, output_token
    
data = dataset(device)
dataloader = DataLoader(data, collate_fn=pad_sequence_seq2seq ,batch_size=2)

# Construct model class

In [16]:
class seq2seq():
    def __init__(self, input_size, hidden_size, output_size, device):
        super(seq2seq, self).__init__()
        self.device = device
        self.hidden_size = hidden_size
        self.encoder = EncoderRNN(input_size, hidden_size, self.device).to(self.device)
        self.decoder = DecoderRNN(hidden_size,output_size, self.device).to(self.device)
    def train(self, dataloader, epochs=5, encoder_lr=0.01, decoder_lr=0.01):
        criterion = nn.CrossEntropyLoss()
        encoder_optimizer = torch.optim.SGD(self.encoder.parameters(), lr=0.01, momentum=0.9)
        decoder_optimizer = torch.optim.SGD(self.decoder.parameters(), lr=0.01, momentum=0.9)
        for epoch in range(epochs):
            epoch_loss = 0
            for token_fra, token_fra_len, token_eng, token_eng_len in dataloader:
                batch_size = token_fra.shape[0]
                
                # start running encoder
                encoder_optimizer.zero_grad()
                decoder_optimizer.zero_grad()
                loss = 0
                
                # init encoder hidden
                encoder_hidden = self.encoder.initHidden(batch_size)
                encoder_outputs, encoder_hidden = self.encoder.forward(token_fra,
                                                      encoder_hidden,
                                                    input_lengths=token_fra_len,
                                                    batch_size=batch_size
                                                    )
                
                # encoder ouput is first hidden of decoder
                decoder_hidden = encoder_outputs[np.arange(batch_size), 
                                                 torch.tensor(token_fra_len)-1, :].view(1,batch_size,-1)
                decoder_outputs, decoder_hidden = self.decoder.forward(token_eng,
                                                                       decoder_hidden,
                                                                       batch_size=batch_size,
                                                                      input_lengths=token_eng_len)
                for batch_index in range(batch_size):
                    # decoder_outputs: (B, seq len, hidden size)
                    output_length = token_eng_len[batch_index]
                    loss += criterion(decoder_outputs[batch_index, 0:output_length-1, :],
                                         token_eng[batch_index,1:output_length])
                  
                
                epoch_loss += loss
                loss.backward()
                decoder_optimizer.step()
                encoder_optimizer.step()
            print("epoch:", epoch, epoch_loss/len(dataloader))

    def predict(model, inputs):
        with torch.no_grad():
            
            encoder_hidden = model.encoder.initHidden(1)
            loss = 0
            encoder_outputs, encoder_hidden = model.encoder.forward(inputs,encoder_hidden)
            decoder_hidden = encoder_outputs[:,-1,:].view(1,-1,7)
            decoder_output = None
            eng_output = "SOS"
            decoder_token = torch.tensor([SOS_token], device=device)
    
            while eng_output != "EOS":
            
                decoder_output,decoder_hidden = model.decoder.forward(decoder_token,
                                                                     decoder_hidden)
                decoder_token = torch.argmax(decoder_output)
                if device.type == 'cuda':
                    eng_output = output_lang.index2word[int(decoder_token.cpu().numpy())]
                else:
                    eng_output = output_lang.index2word[int(decoder_token.numpy())]
                print(eng_output)
    def save(self, model_path):
        encoder_path = model_path + "_encoder"
        decoder_path = model_path + "_decoder"
        torch.save(self.encoder.state_dict(), encoder_path)
        torch.save(self.decoder.state_dict(), decoder_path)
    
    def load(self, model_path):
        self.encoder.load_state_dict(torch.load(model_path + "_encoder"))
        self.decoder.load_state_dict(torch.load(model_path + "_decoder"))

In [17]:
model = seq2seq(input_size=input_lang.n_words, hidden_size=7,output_size=output_lang.n_words, device=device)

In [18]:
import time
start = time.time()
model.train(dataloader,epochs=4)
end = time.time()
print(end-start)

epoch: 0 tensor(6.6274, grad_fn=<DivBackward0>)
epoch: 1 tensor(5.7268, grad_fn=<DivBackward0>)
epoch: 2 tensor(5.4315, grad_fn=<DivBackward0>)
epoch: 3 tensor(5.2251, grad_fn=<DivBackward0>)
496.23062896728516


In [33]:
model.save("seq2seq")

# Make prediction

In [19]:
#model = seq2seq(input_size=input_lang.n_words, hidden_size=5,output_size=output_lang.n_words, device=device)
#model.load("seq2seq")
input_sentence = "tu es tres bon ."
input_token = torch.tensor([input_lang.word2index[word] for word in input_sentence.split()] +
                                   [EOS_token],
                                   device=device)
#predict(model, input_token)
model.predict(input_token)

you
re
very
the
the
the
the
the
the
the
.
EOS
