In [1]:
from data import generate_batches
from data import prepare_data
from data import data_to_index
from data import DEP_LABELS
from data import random_batch

from model.encoder import Encoder
from model.decoder import Decoder_luong
from model.tree_lstm import Tree_lstm

from BLEU import BLEU

from utils import time_since

from evaluator import Evaluator

import torch
import torch.nn as nn
from torch.nn import functional
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F

import numpy as np
import time
import random

#from validation import Evaluator

%load_ext autoreload
%autoreload 2

In [2]:
USE_CUDA = True
MAX_LENGTH = 100
DIR_FILES = 'data/translation/train/'
DIR_RESULTS = 'results/step_1'
SPLIT_TRAIN = 0.7
SPLIT_VALID = 0.15
# The rest is for test

# Reading the data

In [3]:
input_lang, output_lang, input_trees, _, pairs = prepare_data('eng', 'esp', dir=DIR_FILES, return_trees=True)

Reading lines...
Read 115244 sentence pairs
Filtered to 83374 pairs
Creating vocab...
Creating trees...
Indexed 12248 words in input language, 22537 words in output


In [4]:
pairs_train = np.array(pairs[:60000])
pairs_test = np.array(pairs[60000:])

trees_train = np.array(input_trees[:60000])
trees_test = np.array(input_trees[60000:])

# Train

In [5]:
def train(input_batches, target_batches, input_tree,\
          encoder, decoder, tree, criterion, batch_ix, train=True):
    
    if train and (batch_ix % batch_size) == 0:
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        tree_optimizer.zero_grad()
        
    loss = 0
    
    encoder_hidden = encoder.init_hidden(1)
    encoder_outputs, encoder_hidden = encoder(input_batches, encoder_hidden)

    state, tree_hidden = tree(input_tree[0], encoder_outputs)
    encoder_outputs = torch.cat((encoder_outputs, state.unsqueeze(0)))
    #print(encoder_outputs.shape, state.shape)
    
    decoder_context = Variable(torch.zeros(1, decoder.hidden_size))   
    decoder_hidden = torch.cat((encoder_hidden, tree_hidden.unsqueeze(0)))
    #decoder_hidden = encoder_hidden
    # set the start of the sentences of the batch
    decoder_input = torch.LongTensor([input_lang.vocab.stoi['<sos>']] * 1)

    # store the decoder outputs to estimate the loss
    all_decoder_outputs = Variable(torch.zeros(target_batches.size()[0], 1, len(output_lang.vocab.stoi)))
    
    if USE_CUDA:
        decoder_input = decoder_input.cuda()
        all_decoder_outputs = all_decoder_outputs.cuda()
        decoder_context = decoder_context.cuda()  
    
    if train:
        use_teacher_forcing = random.random() < teacher_forcing_ratio
    else:
        use_teacher_forcing = False
    
    if use_teacher_forcing:        
        # Use targets as inputs
        for di in range(target_batches.shape[0]):
            decoder_output, decoder_context, decoder_hidden, decoder_attention = decoder(
                decoder_input.unsqueeze(0), decoder_context, decoder_hidden, encoder_outputs)
            
            all_decoder_outputs[di] = decoder_output
            decoder_input = target_batches[di]
    else:        
        # Use decoder output as inputs
        for di in range(target_batches.shape[0]): 
            decoder_output, decoder_context, decoder_hidden, decoder_attention = decoder(
                decoder_input.unsqueeze(0), decoder_context, decoder_hidden, encoder_outputs) 
            
            all_decoder_outputs[di] = decoder_output
            
            # Greedy approach, take the word with highest probability
            topv, topi = decoder_output.data.topk(1)            
            decoder_input = Variable(torch.LongTensor(topi.cpu()).squeeze(dim=0))
            if USE_CUDA: decoder_input = decoder_input.cuda()
    
    loss = nn.NLLLoss()(all_decoder_outputs.view(-1, decoder.output_size), target_batches.contiguous().view(-1))          
    
    if train and (batch_ix % batch_size) == 0:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)
        torch.nn.utils.clip_grad_norm_(tree.parameters(), clip)
        encoder_optimizer.step()
        decoder_optimizer.step()
        tree_optimizer.step()
    elif train:
        loss.backward()
    
    return loss.item() 

# Model

In [6]:
attn_model = 'general'
hidden_size = 512
emb_size = 300
n_layers = 2
dropout_p = 0.1
seed = 12
teacher_forcing_ratio = 0.5
clip = 5.0

n_epochs = 20
batch_size = 128

In [7]:
torch.manual_seed(seed)
np.random.seed(seed)

In [9]:
encoder = Encoder(len(input_lang.vocab.stoi), hidden_size, emb_size, n_layers, dropout_p, input_lang, USE_CUDA)
decoder = Decoder_luong(attn_model, hidden_size, len(output_lang.vocab.stoi), emb_size, 2 * n_layers + 1, dropout_p, output_lang, USE_CUDA)
tree = Tree_lstm(hidden_size, hidden_size)

if USE_CUDA:
    encoder = encoder.cuda()
    decoder = decoder.cuda()
    tree = tree.cuda()
    
learning_rate = 0.001
encoder_optimizer = optim.Adam(filter(lambda p: p.requires_grad, encoder.parameters()), lr=learning_rate)
decoder_optimizer = optim.Adam(filter(lambda p: p.requires_grad, decoder.parameters()), lr=learning_rate)
tree_optimizer = optim.Adam(tree.parameters(), lr=learning_rate)
criterion = nn.NLLLoss()

In [10]:
# Keep track of time elapsed and running averages
start = time.time()
train_losses = []
validation_losses = []
validation_bleu = []

plot_every = 5
print_every = 5
validate_loss_every = 25
best_bleu = 0
print_loss_total = 0 # Reset every print_every
plot_loss_total = 0 # Reset every plot_every

In [None]:
for epoch in range(1, n_epochs): 
    # Shuffle data
    #id_aux = np.random.permutation(np.arange(len(pairs_train)))
    #pairs_train = pairs_train[id_aux]
    
    # Get the batches for this epoch
    input_batches, input_trees, target_batches = generate_batches(input_lang, output_lang, 1, pairs_train, arr_dep=trees_train, USE_CUDA=USE_CUDA)    
    
    for batch_ix, (input_batch, input_tree, target_var) in enumerate(zip(input_batches, input_trees, target_batches)):
        encoder.train()
        decoder.train()
        tree.train()
        
        #[input_var, _, _, _, _, _, _, _] = input_batch
        input_var = input_batch

        # Run the train function
        loss = train(input_var, target_var, input_tree,\
                 encoder, decoder, tree, criterion, batch_ix, train=True)
        #loss = train_luong(input_var, target_var, input_var.size(1), True)
            
        torch.cuda.empty_cache()

        # Keep track of loss
        print_loss_total += loss
        plot_loss_total += loss

        if batch_ix == 0: continue            

        if batch_ix % (print_every * batch_size) == 0:
            print_loss_avg = print_loss_total / (print_every * batch_size)
            print_loss_total = 0
            print_summary = '%s (%d %d%%) %.4f' % (time_since(start, epoch / n_epochs), epoch, batch_ix / len(input_batches) * 100, print_loss_avg)
            train_losses.append(loss)
            print(print_summary)
    
    input_batches, input_trees, target_batches = generate_batches(input_lang, output_lang, 1, pairs_test, arr_dep=trees_test, USE_CUDA=USE_CUDA)
    print_loss_total = 0
    for batch_ix, (input_batch, input_tree, target_var) in enumerate(zip(input_batches, input_trees, target_batches)):
    
        encoder.eval()
        decoder.eval()
        tree.eval()
    
        with torch.no_grad():
            #[input_var, _, _, _, _, _, _, _] = input_batch
            input_var = input_batch
            # Run the train function
            loss = train(input_var, target_var, input_tree,\
                     encoder, decoder, tree, criterion, batch_ix, train=False)

            print_loss_total += loss
            del loss
            torch.cuda.empty_cache()
    val_loss = print_loss_total / len(input_batches)
    validation_losses.append(val_loss)
    # Evaluating Bleu
    #evaluator = Evaluator(encoder, decoder, input_lang, output_lang, MAX_LENGTH, True)
    #candidates, references = evaluator.get_candidates_and_references(pairs_test, k_beams=1)
    #bleu = BLEU(candidates, [references])
    #if bleu[0] > best_bleu:
    #    best_bleu = bleu[0]
    #    torch.save(encoder.state_dict(), f'{DIR_RESULTS}/encoder.pkl')
    #    torch.save(decoder.state_dict(), f'{DIR_RESULTS}/decoder.pkl')
    #validation_bleu.append(bleu)
    print(f'val_loss: {val_loss:.4f} - bleu: {0}')

    # Prevent overflow gpu memory
    #del evaluator

1m 36s (- 30m 36s) (1 1%) 8.9360
2m 57s (- 56m 18s) (1 2%) 7.9924
4m 16s (- 81m 5s) (1 3%) 7.6602
5m 41s (- 108m 13s) (1 4%) 7.3888
7m 5s (- 134m 44s) (1 5%) 7.2123
8m 30s (- 161m 44s) (1 6%) 7.4825
9m 53s (- 187m 54s) (1 7%) 7.4259
11m 21s (- 215m 49s) (1 8%) 7.1312
12m 45s (- 242m 23s) (1 9%) 6.9250
14m 11s (- 269m 29s) (1 10%) 6.8648
15m 34s (- 295m 49s) (1 11%) 6.9617
16m 58s (- 322m 37s) (1 12%) 6.9180
18m 24s (- 349m 40s) (1 13%) 6.8046
19m 48s (- 376m 19s) (1 14%) 7.1147
21m 9s (- 401m 57s) (1 16%) 8.0596
22m 33s (- 428m 38s) (1 17%) 6.7547
23m 58s (- 455m 38s) (1 18%) 6.7011
25m 23s (- 482m 21s) (1 19%) 6.8869
26m 50s (- 510m 6s) (1 20%) 6.8017
28m 17s (- 537m 38s) (1 21%) 6.7670
29m 42s (- 564m 33s) (1 22%) 6.6528
31m 5s (- 590m 52s) (1 23%) 6.9381
32m 35s (- 619m 16s) (1 24%) 7.6222
34m 0s (- 646m 10s) (1 25%) 7.8320
35m 26s (- 673m 32s) (1 26%) 7.0296
36m 51s (- 700m 24s) (1 27%) 6.9229
38m 21s (- 728m 53s) (1 28%) 6.8373
39m 48s (- 756m 22s) (1 29%) 6.5892
41m 15s (- 784m 2

In [None]:
evaluator = Evaluator(encoder, decoder, input_lang, output_lang, 
                      MAX_LENGTH, USE_CUDA)
candidates, references = evaluator.get_candidates_and_references(pairs_test[:10000], k_beams=2)
len(candidates), len(references)

In [12]:
BLEU(candidates, [references]) 

(0.28063523097173265,
 [0.6573135078342698,
  0.37567686039915427,
  0.22488307382629802,
  0.13494545201862276],
 0.953820572858132)

In [11]:
train_losses

[9.041672706604004,
 7.149496555328369,
 14.810012817382812,
 12.101698875427246,
 9.06682014465332,
 7.260944366455078,
 6.044005393981934,
 6.0010833740234375,
 7.1673264503479,
 6.06639289855957,
 6.229933261871338,
 8.623738288879395,
 5.882908344268799,
 4.927737236022949,
 5.671040058135986,
 7.489520072937012,
 5.501565456390381,
 5.914140224456787,
 6.912267208099365,
 6.224165916442871,
 7.6937079429626465,
 6.377658843994141,
 8.010522842407227,
 8.193879127502441,
 7.284544944763184,
 4.864162445068359,
 6.447579860687256,
 6.805881977081299,
 5.03970193862915,
 6.144567012786865,
 5.522188186645508,
 6.3946533203125,
 7.317024230957031,
 7.084739685058594,
 4.866414546966553,
 4.8789286613464355,
 6.360021114349365,
 5.258521556854248,
 7.594843864440918,
 5.99109411239624,
 6.1218085289001465,
 4.6263885498046875,
 6.505831241607666,
 6.49678897857666,
 6.661844253540039,
 5.703457832336426,
 6.080120086669922,
 5.556210994720459,
 4.3718485832214355,
 7.616245269775391,
 

In [13]:
pairs_test[480]

array(["tom wasn 't convinced it was a good idea .",
       'tom no estaba convencido de que fuera una buena idea .'],
      dtype='<U245')

In [14]:
pairs_train[80000]

array(['just act as if nothing has happened .',
       'haga de cuenta que nada ha ocurrido .'],
      dtype='<U245')