In [1]:
import torch 
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import os
from torch import optim

## Parameter Setup

In [2]:
# parameters
MAX_LENGTH = 25  # sentences are cut to at most this many tokens before training

# Fix the RNG seed so weight initialisation (and therefore the loss curve) is
# repeatable after Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Special token ids used to start decoding and to terminate every sequence.
# NOTE(review): these ids are hard-coded rather than looked up in the vocab
# files, and sen2idx also maps unknown words to index 0 -- confirm that they
# match the layout of vocab.vi / vocab.en.
SOS_token = 0
EOS_token = 1

hidden_size = 128  # width of the embeddings and of the LSTM state

## Data Setup

In [3]:
# Where the parallel corpus and the vocabulary files live
# (paths are relative to the notebook's working directory).
data_dir = os.path.join('datasets', 'nmt_data_vi')

# file names: source side is *.vi, target side is *.en
train_source, train_target = 'train.vi', 'train.en'
vocab_source, vocab_target = 'vocab.vi', 'vocab.en'

# NOTE: despite the "_dir" suffix these are full paths to the individual files
train_source_dir, train_target_dir = (
    os.path.join(data_dir, file_name) for file_name in (train_source, train_target)
)
vocab_source_dir, vocab_target_dir = (
    os.path.join(data_dir, file_name) for file_name in (vocab_source, vocab_target)
)

In [4]:
# load training sets
# Decode explicitly as UTF-8: the corpus contains non-ASCII (Vietnamese) text
# and the platform default encoding is not guaranteed to handle it.
with open(train_source_dir, encoding='utf-8') as f_source:
    sentences_source = f_source.readlines()
with open(train_target_dir, encoding='utf-8') as f_target:
    sentences_target = f_target.readlines()

# Training pairs source and target lines by index, so the two files must have
# the same number of lines.
assert len(sentences_source) == len(sentences_target), \
    "source and target training files have different numbers of lines"

# check the total number of sentences in training sets
print("Total number of sentences in source training set: {}".format(len(sentences_source)))
print("Total number of sentences in target training set: {}".format(len(sentences_target)))

Total number of sentences in source training set: 133317
Total number of sentences in target training set: 133317


In [5]:
# Truncate sentences by maximum length
def _split_and_truncate(sentence, limit=MAX_LENGTH):
    """Return at most `limit` whitespace-separated tokens of `sentence`.

    `sentence` may be a raw line (str) or an already tokenised list, so this
    cell can be re-run without failing on its own previous output.
    """
    tokens = sentence.split() if isinstance(sentence, str) else sentence
    return tokens[:limit]


sentences_source = [_split_and_truncate(line) for line in sentences_source]
sentences_target = [_split_and_truncate(line) for line in sentences_target]

# check the longest sentence after sentence truncation
# Uses the built-in max() rather than rebinding the name `max`, which would
# break every later call to max() in the notebook. On ties the first longest
# sentence is reported; an empty corpus reports an empty sentence.
longest_sentence = max(sentences_source, key=len, default=[])
print("Number of words in the longest sentence in sentences_source: {}".format(len(longest_sentence)))
print("The longest sentence: \n{}".format(longest_sentence))

Number of words in the longest sentence in sentences_source: 25
The longest sentence: 
['Trong', '4', 'phút', ',', 'chuyên', 'gia', 'hoá', 'học', 'khí', 'quyển', 'Rachel', 'Pike', 'giới', 'thiệu', 'sơ', 'lược', 'về', 'những', 'nỗ', 'lực', 'khoa', 'học', 'miệt', 'mài', 'đằng']


In [6]:
# load vocabularies

# index2word: one vocabulary entry per line; the position in the list is the
# word's index. Files are decoded explicitly as UTF-8 because the source
# vocabulary is non-ASCII text.
with open(vocab_source_dir, encoding='utf-8') as f_vocab_source:
    index2word_source = [line.rstrip() for line in f_vocab_source]
with open(vocab_target_dir, encoding='utf-8') as f_vocab_target:
    index2word_target = [line.rstrip() for line in f_vocab_target]

# word2index: inverse lookup (if an entry is repeated, its last index wins)
word2index_source = {word: idx for idx, word in enumerate(index2word_source)}
word2index_target = {word: idx for idx, word in enumerate(index2word_target)}

# check vocabularies size
source_vocab_size = len(index2word_source)
target_vocab_size = len(index2word_target)
print("Total nummber of words in source vocabulary: {}".format(source_vocab_size))
print("Total nummber of words in target vocabulary: {}".format(target_vocab_size))

Total nummber of words in source vocabulary: 7709
Total nummber of words in target vocabulary: 17191


In [7]:
# helper functions to convert a natural-language sentence to a list of word indexes
def sen2idx(sentence, word2index, unk_index=0):
    """Map every word of `sentence` to its index in `word2index`.

    Args:
        sentence: iterable of words (tokens).
        word2index: mapping from word to integer index.
        unk_index: index returned for words missing from `word2index`.
            Defaults to 0, which this notebook assumes is the <unk> entry.

    Returns:
        A new list with one integer index per word, in sentence order.
    """
    return [word2index.get(word, unk_index) for word in sentence]

def sen2tensor(sentence, word2index):
    """Encode `sentence` as a 1-D long tensor of word indexes ending in EOS_token.

    The tensor is created on the notebook-level `device`.
    """
    # sen2idx builds a fresh list, so appending EOS never touches `sentence`
    token_ids = sen2idx(sentence, word2index) + [EOS_token]
    return torch.tensor(token_ids, device=device, dtype=torch.long)

## Define Encoder and Decoder classes

In [8]:
class EncoderLSTM(nn.Module):
    """Step-wise LSTM encoder: embeds one token and advances an LSTMCell.

    Args:
        input_size: number of rows in the embedding table (source vocabulary size).
        hidden_size: width of the embedding vectors and of the LSTM state.
    """

    def __init__(self, input_size, hidden_size):
        super(EncoderLSTM, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTMCell(hidden_size, hidden_size)

    def forward(self, input, prev_h, prev_c):
        """Consume one token and return the updated LSTM state.

        Args:
            input: long tensor of token indexes; this notebook passes shape (1,).
            prev_h: previous hidden state, shape (1, hidden_size).
            prev_c: previous cell state, shape (1, hidden_size).

        Returns:
            (h, c): the new hidden and cell states.
        """
        input_embedded = self.embedding(input)
        h, c = self.lstm(input_embedded, (prev_h, prev_c))
        return h, c

    def initHidden(self):
        """Return an all-zero state tensor of shape (1, hidden_size).

        The tensor is allocated on the device that holds this module's weights
        rather than on a notebook-level `device` global, so it is always
        compatible with the module after `.to(...)`.
        """
        return torch.zeros(1, self.hidden_size, device=self.embedding.weight.device)

In [13]:
class DecoderLSTM(nn.Module):
    """Step-wise LSTM decoder without attention.

    Args:
        hidden_size: width of the embedding vectors and of the LSTM state.
        output_size: number of output classes (target vocabulary size); also
            the number of rows in the embedding table.
    """

    def __init__(self, hidden_size, output_size):
        super(DecoderLSTM, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.lstm = nn.LSTMCell(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, prev_h, prev_c):
        """Consume the previous target token and predict the next one.

        Args:
            input: long tensor of token indexes; this notebook passes shape (1,).
            prev_h: previous hidden state, shape (1, hidden_size).
            prev_c: previous cell state, shape (1, hidden_size).

        Returns:
            (output, h, c): `output` holds log-probabilities over the
            `output_size` classes (log-softmax along dim 1); `h` and `c` are
            the new hidden and cell states.
        """
        input_embedded = self.embedding(input)
        h, c = self.lstm(input_embedded, (prev_h, prev_c))
        output = self.softmax(self.out(h))
        return output, h, c

    def initHidden(self):
        """Return an all-zero state tensor of shape (1, hidden_size).

        The tensor is allocated on the device that holds this module's weights
        rather than on a notebook-level `device` global, so it is always
        compatible with the module after `.to(...)`.
        """
        return torch.zeros(1, self.hidden_size, device=self.embedding.weight.device)

In [9]:
class AttnDecoderLSTM(nn.Module):
    """Step-wise LSTM decoder with attention over a fixed number of encoder states.

    Args:
        hidden_size: width of the embedding vectors and of the LSTM state.
        output_size: number of output classes (target vocabulary size); also
            the number of rows in the embedding table.
        dropout_p: accepted for interface compatibility; no dropout layer is
            applied by this class.
        max_length: number of encoder positions the attention layer scores.
    """

    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
        super(AttnDecoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        # scores every encoder position from [embedded token ; previous hidden]
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        # merges [embedded token ; attention context] back to hidden_size
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.lstm = nn.LSTMCell(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, prev_h, prev_c, encoder_outputs):
        """Consume the previous target token and predict the next one.

        Args:
            input: long tensor holding a single token index (any shape with
                one element; it is reshaped internally).
            prev_h: previous hidden state, shape (1, hidden_size).
            prev_c: previous cell state, shape (1, hidden_size).
            encoder_outputs: encoder hidden states, shape
                (max_length, hidden_size).

        Returns:
            (output, h, c, attn_weights): `output` holds log-probabilities of
            shape (1, output_size); `h` and `c` are the new LSTM states;
            `attn_weights` has shape (1, max_length) and sums to 1 along dim 1.
        """
        embedded = self.embedding(input).view(1, -1)

        attn_weights = F.softmax(
            self.attn(torch.cat((embedded, prev_h), 1)), dim=1)

        # weighted sum of the encoder states -> context of shape (1, hidden_size)
        context = torch.bmm(attn_weights.unsqueeze(0),
                            encoder_outputs.unsqueeze(0))[0]

        lstm_input = F.relu(self.attn_combine(torch.cat((embedded, context), 1)))
        h, c = self.lstm(lstm_input, (prev_h, prev_c))

        # explicit dim: relying on the implicit dimension is deprecated
        output = F.log_softmax(self.out(h), dim=1)
        return output, h, c, attn_weights

    def initHidden(self):
        """Return an all-zero state tensor of shape (1, hidden_size).

        The shape matches what `forward` and the LSTMCell expect, and the
        tensor is allocated on the device that holds this module's weights.
        """
        return torch.zeros(1, self.hidden_size, device=self.embedding.weight.device)

## Training

In [14]:
def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, max_length=MAX_LENGTH):
    """One teacher-forced training step for a decoder without attention.

    Encodes `input_tensor` token by token, decodes `target_tensor` starting
    from SOS_token and the final encoder state, sums the NLL loss over all
    target positions, back-propagates and steps both optimizers.

    `max_length` is accepted but not used by this function. A later cell
    defines another function with the same name `train`, which replaces this
    one when the cells are run top to bottom.

    Returns:
        The summed loss divided by the number of target tokens (a float).
    """
    # fresh zero state for the encoder (hidden and cell)
    h = encoder.initHidden()
    c = encoder.initHidden()

    # clear gradients left over from the previous step
    for optimizer in (encoder_optimizer, decoder_optimizer):
        optimizer.zero_grad()

    n_target = target_tensor.size(0)
    nll = nn.NLLLoss()
    loss = 0

    # encode: feed the source tokens one at a time
    for src_token in input_tensor:
        h, c = encoder(src_token.view(1), h, c)

    # decode: the decoder continues from the encoder's final (h, c) and is
    # always fed the ground-truth previous token (teacher forcing)
    prev_token = torch.tensor([[SOS_token]], device=device)
    for gold_token in target_tensor:
        log_probs, h, c = decoder(prev_token.view(1), h, c)
        loss += nll(log_probs, gold_token.view(1))
        prev_token = gold_token

    loss.backward()

    for optimizer in (encoder_optimizer, decoder_optimizer):
        optimizer.step()

    return loss.item() / n_target

In [10]:
def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, max_length=MAX_LENGTH):
    """One teacher-forced training step on a single sentence pair.

    Works with both decoders defined in this notebook, so it stays usable for
    the plain-decoder run even though this definition shadows the earlier
    `train` when cells are executed top to bottom:

    * AttnDecoderLSTM: the encoder hidden state of each of the first
      `max_length` source tokens is stored and passed to the decoder at every
      step; source tokens beyond `max_length` are not encoded.
    * any other decoder: called as `decoder(token, h, c)` and expected to
      return `(output, h, c)`; the whole source sequence is encoded.

    Args:
        input_tensor: 1-D long tensor of source token indexes.
        target_tensor: 1-D long tensor of target token indexes.
        encoder, decoder: the modules being trained.
        encoder_optimizer, decoder_optimizer: their optimizers.
        max_length: number of encoder states kept for attention; must equal
            the decoder's `max_length` when attention is used.

    Returns:
        The summed NLL loss divided by the number of target tokens (a float).
    """
    use_attention = isinstance(decoder, AttnDecoderLSTM)

    encoder_hidden_h = encoder.initHidden()
    encoder_hidden_c = encoder.initHidden()

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    loss = 0
    criterion = nn.NLLLoss()

    if use_attention:
        # rows past the end of a short sentence stay zero
        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
        encoder_steps = min(input_length, max_length)
    else:
        encoder_outputs = None
        encoder_steps = input_length

    for ei in range(encoder_steps):
        encoder_hidden_h, encoder_hidden_c = encoder(
            input_tensor[ei].view(1), encoder_hidden_h, encoder_hidden_c)
        if use_attention:
            encoder_outputs[ei] = encoder_hidden_h[0]

    # decoding starts from SOS and the final encoder state
    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden_h = encoder_hidden_h
    decoder_hidden_c = encoder_hidden_c

    for di in range(target_length):
        if use_attention:
            decoder_output, decoder_hidden_h, decoder_hidden_c, _ = decoder(
                decoder_input, decoder_hidden_h, decoder_hidden_c, encoder_outputs)
        else:
            decoder_output, decoder_hidden_h, decoder_hidden_c = decoder(
                decoder_input.view(1), decoder_hidden_h, decoder_hidden_c)
        loss += criterion(decoder_output, target_tensor[di].view(1))
        decoder_input = target_tensor[di]  # Teacher forcing

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length

In [11]:
def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100, learning_rate=0.01):
    """Train `encoder`/`decoder` with SGD for `n_iters` single-pair steps.

    Step k uses the k-th sentence pair of `sentences_source` /
    `sentences_target`; when `n_iters` exceeds the number of pairs the corpus
    is cycled through again instead of running off the end of the lists.

    Args:
        encoder, decoder: modules to train; passed to `train` unchanged.
        n_iters: number of training steps.
        print_every: print the mean per-step loss every this many steps.
        plot_every: accepted for interface compatibility; no plot data is
            collected by this function.
        learning_rate: SGD learning rate for both optimizers.

    Returns:
        None. Progress is reported with print().
    """
    n_pairs = min(len(sentences_source), len(sentences_target))
    if n_iters > 0 and n_pairs == 0:
        raise ValueError("no training sentence pairs available")

    print_loss_total = 0

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)

    # `step` (not `iter`) so the built-in iter() is not shadowed
    for step in range(1, n_iters + 1):
        pair_idx = (step - 1) % n_pairs
        input_tensor = sen2tensor(sentences_source[pair_idx], word2index_source)
        target_tensor = sen2tensor(sentences_target[pair_idx], word2index_target)

        loss = train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer)
        print_loss_total += loss

        if step % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('(%d %d%%) %.4f' % (step, step / n_iters * 100, print_loss_avg))

In [15]:
# Baseline: encoder + plain LSTM decoder (no attention).
encoder1 = EncoderLSTM(source_vocab_size, hidden_size).to(device)
decoder1 = DecoderLSTM(hidden_size, target_vocab_size).to(device)
# One step per training sentence; derived from the loaded data instead of a
# hard-coded count so it stays valid if the corpus changes.
# print_every=1 logs the loss of every single step.
trainIters(encoder1, decoder1, n_iters=len(sentences_source), print_every=1)

(1 0%) 9.7689
(2 0%) 9.7480
(3 0%) 9.7672
(4 0%) 9.7140
(5 0%) 9.7676
(6 0%) 9.7430
(7 0%) 9.7461
(8 0%) 9.7596
(9 0%) 9.7405
(10 0%) 9.7522
(11 0%) 9.7221
(12 0%) 9.7295
(13 0%) 9.6937
(14 0%) 9.7025
(15 0%) 9.6719
(16 0%) 9.7168
(17 0%) 9.7136
(18 0%) 9.6315
(19 0%) 9.6501
(20 0%) 9.6224
(21 0%) 9.6988
(22 0%) 9.5729
(23 0%) 9.6763
(24 0%) 9.6138
(25 0%) 9.6221
(26 0%) 9.6716
(27 0%) 9.5627
(28 0%) 9.6106
(29 0%) 9.6049
(30 0%) 9.6160
(31 0%) 9.5528
(32 0%) 9.4708
(33 0%) 9.4915
(34 0%) 9.6058
(35 0%) 9.4918
(36 0%) 9.6004
(37 0%) 9.6531
(38 0%) 9.4746
(39 0%) 9.4198
(40 0%) 9.3809
(41 0%) 9.5493
(42 0%) 9.5169
(43 0%) 9.5385
(44 0%) 9.4711
(45 0%) 9.3027
(46 0%) 9.6065
(47 0%) 9.5685
(48 0%) 9.4791
(49 0%) 9.2409
(50 0%) 9.2098
(51 0%) 9.5125
(52 0%) 9.4247
(53 0%) 9.2147
(54 0%) 9.5244
(55 0%) 9.5537
(56 0%) 9.5278
(57 0%) 9.3684
(58 0%) 9.2818
(59 0%) 9.4141
(60 0%) 9.3359
(61 0%) 9.2290
(62 0%) 9.2474
(63 0%) 9.3009
(64 0%) 9.2622
(65 0%) 9.0566
(66 0%) 9.2232
(67 0%) 9.1051
(68 

KeyboardInterrupt: 

In [12]:
# Encoder + attention decoder.
encoder2 = EncoderLSTM(source_vocab_size, hidden_size).to(device)
decoder2 = AttnDecoderLSTM(hidden_size, target_vocab_size).to(device)
# One step per training sentence; derived from the loaded data instead of a
# hard-coded count so it stays valid if the corpus changes.
# print_every=1 logs the loss of every single step.
trainIters(encoder2, decoder2, n_iters=len(sentences_source), print_every=1)



(1 0%) 9.7440
(2 0%) 9.7407
(3 0%) 9.7158
(4 0%) 9.6977
(5 0%) 9.7569
(6 0%) 9.6949
(7 0%) 9.7189
(8 0%) 9.7101
(9 0%) 9.7171
(10 0%) 9.7166
(11 0%) 9.7111
(12 0%) 9.6984
(13 0%) 9.6643
(14 0%) 9.6864
(15 0%) 9.6610
(16 0%) 9.6248
(17 0%) 9.6389
(18 0%) 9.5807
(19 0%) 9.6297
(20 0%) 9.5947
(21 0%) 9.5822
(22 0%) 9.4879
(23 0%) 9.6561
(24 0%) 9.5249
(25 0%) 9.5871
(26 0%) 9.5548
(27 0%) 9.5324
(28 0%) 9.5565
(29 0%) 9.5513
(30 0%) 9.6077
(31 0%) 9.4717
(32 0%) 9.3608
(33 0%) 9.4252
(34 0%) 9.5487
(35 0%) 9.3952
(36 0%) 9.5239
(37 0%) 9.4832
(38 0%) 9.3302
(39 0%) 9.2509
(40 0%) 9.3108
(41 0%) 9.3851
(42 0%) 9.3637
(43 0%) 9.3892
(44 0%) 9.2423
(45 0%) 8.9826
(46 0%) 9.4136
(47 0%) 9.4137
(48 0%) 9.1472
(49 0%) 8.8242
(50 0%) 8.6504
(51 0%) 9.1220
(52 0%) 9.0072
(53 0%) 8.8228
(54 0%) 9.0984
(55 0%) 9.1414
(56 0%) 8.9019
(57 0%) 8.8805
(58 0%) 8.4655
(59 0%) 8.7115
(60 0%) 8.4812
(61 0%) 8.1885
(62 0%) 8.0434
(63 0%) 8.5263
(64 0%) 8.0180
(65 0%) 7.8279
(66 0%) 8.1766
(67 0%) 7.6406
(68 

KeyboardInterrupt: 