In [1]:
import torch 
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import os
from torch import optim

In [22]:
# Locations of the parallel training corpus and of the vocabulary files.
# Everything lives under datasets/nmt_data_vi; source-side files carry the
# ".vi" suffix and target-side files the ".en" suffix.
data_dir = os.path.join('datasets', 'nmt_data_vi')

# Bare file names.
train_source, train_target = 'train.vi', 'train.en'
vocab_source, vocab_target = 'vocab.vi', 'vocab.en'

# Full paths. Despite the "_dir" suffix these name files (they are passed to
# open() further down), not directories.
train_source_dir = os.path.join(data_dir, train_source)
train_target_dir = os.path.join(data_dir, train_target)
vocab_source_dir = os.path.join(data_dir, vocab_source)
vocab_target_dir = os.path.join(data_dir, vocab_target)

In [3]:
# load training sets
# Each list holds one sentence per element, with the line terminator still
# attached (readlines() does not strip it).
# The encoding is given explicitly so the non-ASCII text is decoded the same
# way on every platform instead of depending on the locale default.
# NOTE(review): assumes the corpus files are UTF-8 encoded - confirm.
with open(train_source_dir, encoding='utf-8') as f_source:
    sentences_source = f_source.readlines()
with open(train_target_dir, encoding='utf-8') as f_target:
    sentences_target = f_target.readlines()

In [16]:
# Report the size of both sides of the training corpus (one line each).
print("\n".join([
    "Total number of sentences in source training set: {}".format(len(sentences_source)),
    "Total number of sentences in target training set: {}".format(len(sentences_target)),
]))

Total number of sentences in source training set: 133317
Total number of sentences in target training set: 133317


In [5]:
# load vocabularies
# One vocabulary entry per line; an entry's position in the list is its word
# index. Trailing whitespace (including the line terminator) is stripped.
# The encoding is given explicitly so the non-ASCII text is decoded the same
# way on every platform instead of depending on the locale default.
# NOTE(review): assumes the vocabulary files are UTF-8 encoded - confirm.
with open(vocab_source_dir, encoding='utf-8') as f_vocab_source:
    index2word_source = [line.rstrip() for line in f_vocab_source]
with open(vocab_target_dir, encoding='utf-8') as f_vocab_target:
    index2word_target = [line.rstrip() for line in f_vocab_target]

In [20]:
# Vocabulary sizes; used further down to size the encoder and decoder.
source_vocab_size, target_vocab_size = len(index2word_source), len(index2word_target)
print("\n".join([
    "Total nummber of words in source vocabulary: {}".format(source_vocab_size),
    "Total nummber of words in target vocabulary: {}".format(target_vocab_size),
]))

Total nummber of words in source vocabulary: 7709
Total nummber of words in target vocabulary: 17191


In [7]:
# Reverse lookup tables: word -> index, the inverse of the index2word lists.
# If a word occurs more than once in a vocabulary, its last index wins.
word2index_source = {token: position for position, token in enumerate(index2word_source)}
word2index_target = {token: position for position, token in enumerate(index2word_target)}

In [8]:
# Preparing Data
def sen2idx(sentence, word2index, unk_index=0):
    """Convert a space-separated sentence into a list of vocabulary indices.

    Args:
        sentence: one line of text whose tokens are separated by spaces. A
            trailing line terminator (as left by readlines()) is allowed.
        word2index: dict mapping a token to its integer index.
        unk_index: index used for tokens missing from word2index. Defaults to
            0, i.e. <unk> is assumed to be the first vocabulary entry.

    Returns:
        A list of ints, one per non-empty token; empty for a blank line.
    """
    # Drop the line terminator first: otherwise the last token keeps its
    # "\n" attached and is looked up as a different (unknown) word.
    tokens = sentence.rstrip('\r\n').split(' ')
    # Leading, trailing or repeated spaces produce '' tokens; skip them so
    # they are not counted as unknown words.
    return [word2index.get(token, unk_index) for token in tokens if token]

def sen2tensor(sentence, word2index):
    """Encode a sentence as a 1-D long tensor of word indices ending in EOS_token.

    Args:
        sentence: text to encode; tokenised and looked up by sen2idx.
        word2index: dict mapping a token to its integer index.

    Returns:
        A torch.long tensor created on the notebook-wide `device`, holding
        the indices from sen2idx followed by EOS_token.
    """
    token_ids = sen2idx(sentence, word2index) + [EOS_token]
    return torch.tensor(token_ids, dtype=torch.long, device=device)

In [9]:
# Parameters
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Default value of train()'s max_length argument.
MAX_LENGTH = 25

# Indices of the start-of-sentence and end-of-sentence markers.
# NOTE(review): sen2idx also maps out-of-vocabulary words to index 0, so
# SOS_token shares an index with <unk>; confirm these two values against the
# first entries of the vocabulary files.
SOS_token, EOS_token = 0, 1

# Width of the embeddings and of the LSTM state in encoder and decoder.
hidden_size = 512

In [10]:
class EncoderLSTM(nn.Module):
    """Single-step LSTM encoder: embeds one token and advances an LSTMCell.

    Args:
        input_size: number of rows of the embedding table (the callers pass
            the source vocabulary size).
        hidden_size: width of both the embeddings and the LSTM state.
    """

    def __init__(self, input_size, hidden_size):
        super(EncoderLSTM, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTMCell(hidden_size, hidden_size)

    def forward(self, input, prev_h, prev_c):
        """Consume one token and return the updated LSTM state.

        Args:
            input: long tensor of token indices, one per batch element
                (the training loop passes a tensor of shape (1,)).
            prev_h: previous hidden state, as returned by initHidden() or by
                an earlier call.
            prev_c: previous cell state, same shape as prev_h.

        Returns:
            Tuple (h, c) with the new hidden and cell state.
        """
        input_embedded = self.embedding(input)
        h, c = self.lstm(input_embedded, (prev_h, prev_c))
        return h, c

    def initHidden(self):
        """Return an all-zero (1, hidden_size) state tensor for a batch of one.

        The tensor is allocated on the device that holds this module's
        weights rather than on a notebook-level global, so the state and the
        weights always live on the same device.
        """
        return torch.zeros(1, self.hidden_size, device=self.embedding.weight.device)

In [11]:
class DecoderLSTM(nn.Module):
    """Single-step LSTM decoder producing log-probabilities over the output vocabulary.

    Args:
        hidden_size: width of both the embeddings and the LSTM state.
        output_size: number of rows of the embedding table and number of
            output classes (the callers pass the target vocabulary size).
    """

    def __init__(self, hidden_size, output_size):
        super(DecoderLSTM, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.lstm = nn.LSTMCell(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        # dim=1 normalises over the vocabulary axis of a (batch, output_size) tensor.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, prev_h, prev_c):
        """Consume one token and predict the next one.

        Args:
            input: long tensor of token indices, one per batch element
                (the training loop passes a tensor of shape (1,)).
            prev_h: previous hidden state.
            prev_c: previous cell state, same shape as prev_h.

        Returns:
            Tuple (output, h, c): output holds log-probabilities over the
            output vocabulary computed from the new hidden state h; c is the
            new cell state.
        """
        input_embedded = self.embedding(input)
        h, c = self.lstm(input_embedded, (prev_h, prev_c))
        output = self.softmax(self.out(h))
        return output, h, c

    def initHidden(self):
        """Return an all-zero (1, hidden_size) state tensor for a batch of one.

        The tensor is allocated on the device that holds this module's
        weights rather than on a notebook-level global, so the state and the
        weights always live on the same device.
        """
        return torch.zeros(1, self.hidden_size, device=self.embedding.weight.device)

In [12]:
def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, max_length=MAX_LENGTH):
    """Run one optimisation step on a single sentence pair and return its loss.

    The source indices are fed through the encoder one at a time, starting
    from the encoder's zero state; the final encoder state seeds the decoder.
    The decoder is trained with teacher forcing: its input at every step is
    the previous gold target index (SOS_token for the first step).

    Args:
        input_tensor: source word indices, indexed along dim 0.
        target_tensor: target word indices, indexed along dim 0.
        encoder: module called as encoder(token, h, c) -> (h, c).
        decoder: module called as decoder(token, h, c) -> (output, h, c).
        encoder_optimizer: optimizer for the encoder; stepped once.
        decoder_optimizer: optimizer for the decoder; stepped once.
        max_length: not used by this function.

    Returns:
        The summed NLL loss of the sentence divided by the number of target
        positions, as a Python float.
    """
    nll = nn.NLLLoss()
    n_source = input_tensor.size(0)
    n_target = target_tensor.size(0)

    state_h = encoder.initHidden()
    state_c = encoder.initHidden()

    for optimizer in (encoder_optimizer, decoder_optimizer):
        optimizer.zero_grad()

    # Encoder pass: only the final (h, c) state is kept.
    for src_pos in range(n_source):
        source_token = input_tensor[src_pos].view(1)
        state_h, state_c = encoder(source_token, state_h, state_c)

    # Decoder pass: starts from the encoder's final state and the SOS marker.
    step_input = torch.tensor([[SOS_token]], device=device)
    total_loss = 0
    for tgt_pos in range(n_target):
        gold_token = target_tensor[tgt_pos]
        log_probs, state_h, state_c = decoder(step_input.view(1), state_h, state_c)
        total_loss += nll(log_probs, gold_token.view(1))
        # Teacher forcing: the next input is the gold token, not the prediction.
        step_input = gold_token

    total_loss.backward()

    for optimizer in (encoder_optimizer, decoder_optimizer):
        optimizer.step()

    return total_loss.item() / n_target

In [13]:
def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100, learning_rate=0.01):
    """Train the encoder/decoder pair for n_iters single-sentence SGD steps.

    Step i (1-based) trains on sentence pair i - 1 of the module-level
    sentences_source / sentences_target lists. When n_iters exceeds the
    number of available pairs the index wraps around, so further steps start
    another pass over the corpus instead of raising IndexError.

    Args:
        encoder: encoder module, passed through to train().
        decoder: decoder module, passed through to train().
        n_iters: number of training steps to run.
        print_every: print the mean per-step loss every this many steps.
        plot_every: accepted for interface compatibility; no plotting data
            is collected, so the value is currently unused.
        learning_rate: learning rate of both SGD optimizers.

    Raises:
        IndexError: if steps are requested but the corpus is empty.
    """
    n_pairs = min(len(sentences_source), len(sentences_target))
    if n_iters > 0 and n_pairs == 0:
        raise IndexError("trainIters: the training corpus is empty")

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)

    print_loss_total = 0  # loss accumulated since the last report

    # `step` rather than `iter`, which would shadow the builtin.
    for step in range(1, n_iters + 1):
        pair_idx = (step - 1) % n_pairs
        input_tensor = sen2tensor(sentences_source[pair_idx], word2index_source)
        target_tensor = sen2tensor(sentences_target[pair_idx], word2index_target)

        loss = train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer)
        print_loss_total += loss

        if step % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('(%d %d%%) %.4f' % (step, step / n_iters * 100, print_loss_avg))

In [14]:
# Build the two networks on the selected device.
encoder1 = EncoderLSTM(source_vocab_size, hidden_size).to(device)
decoder1 = DecoderLSTM(hidden_size, target_vocab_size).to(device)
# One training step per sentence pair: the step count is taken from the
# loaded corpus instead of a hard-coded number, so it stays correct if the
# data changes. print_every=1 logs the loss of every single step.
trainIters(encoder1, decoder1, len(sentences_source), print_every=1)

(1 0%) 9.7950
(2 0%) 9.7657
(3 0%) 9.7038
(4 0%) 9.7050
(5 0%) 9.6514
(6 0%) 9.6383
(7 0%) 9.6300
(8 0%) 9.5787
(9 0%) 9.5858
(10 0%) 9.6439
(11 0%) 9.5941
(12 0%) 9.6279
(13 0%) 9.6036
(14 0%) 9.6211
(15 0%) 9.4434
(16 0%) 9.3253
(17 0%) 9.4326
(18 0%) 9.1876
(19 0%) 9.2698
(20 0%) 9.2184
(21 0%) 9.0665
(22 0%) 8.0887
(23 0%) 9.1504
(24 0%) 8.5712
(25 0%) 9.0539
(26 0%) 8.9005
(27 0%) 8.0030
(28 0%) 8.7416
(29 0%) 8.8727
(30 0%) 9.0309
(31 0%) 9.0578
(32 0%) 8.3912
(33 0%) 7.9025
(34 0%) 8.8234
(35 0%) 8.3513
(36 0%) 8.9950
(37 0%) 9.1683
(38 0%) 8.3022
(39 0%) 7.4667
(40 0%) 7.9123
(41 0%) 8.6362
(42 0%) 8.5013
(43 0%) 8.7922
(44 0%) 8.4008
(45 0%) 7.7894
(46 0%) 8.6016
(47 0%) 8.5772
(48 0%) 7.8845
(49 0%) 6.6711
(50 0%) 6.7027
(51 0%) 8.1349
(52 0%) 7.7260
(53 0%) 7.0871
(54 0%) 7.5398
(55 0%) 8.6970
(56 0%) 8.2807
(57 0%) 7.6769
(58 0%) 7.0634
(59 0%) 7.4636
(60 0%) 7.1844
(61 0%) 6.8641
(62 0%) 6.8567
(63 0%) 7.8745
(64 0%) 7.1352
(65 0%) 7.1144
(66 0%) 7.5639
(67 0%) 6.4010
(68 

KeyboardInterrupt: 