In [1]:
#!/usr/bin/env python
# coding: utf-8

import torch
import torch.nn as nn
from torch.autograd import Variable

import os
import argparse
from layer import QRNNLayer
from model import QRNNModel

import data.data_utils as data_utils
from data.data_utils import fopen
from data.data_utils import load_inv_dict
from data.data_utils import seq2words

from data.data_iterator import TextIterator
from data.data_iterator import BiTextIterator
from data.data_iterator import prepare_batch
from data.data_iterator import prepare_train_batch

# True when a CUDA device is visible; later cells use this flag to decide
# whether to move the model and the batch tensors to the GPU with .cuda().
use_cuda = torch.cuda.is_available()

In [2]:
# Data loading parameters
# Vocabulary JSON files. They are defined once here and reused by the
# decoding cells below, rather than being re-assigned per section.
src_vocab = 'data/train.clean.en.json'
tgt_vocab = 'data/train.clean.fr.json'
# NOTE(review): the small test files are used for training AND validation here;
# the full corpus lives at data/train.clean.en / data/train.clean.fr.
src_train = 'data/test.en'
tgt_train = 'data/test.fr'
src_valid = 'data/test.en'
tgt_valid = 'data/test.fr'

# Network parameters.
# kernel_size, hidden_size, num_layers, emb_size, num_enc_symbols and
# num_dec_symbols are passed to QRNNModel when the checkpoint is reloaded,
# so they must match the values the checkpoint was trained with.
kernel_size = 2
hidden_size = 1024
num_layers = 2
emb_size = 500
num_enc_symbols = 30000
num_dec_symbols = 30000
dropout_rate = 0.3

# Training parameters
lr = 0.0002
max_grad_norm = 1.0
batch_size = 128
max_epochs = 1000
maxi_batches = 20
max_seq_len = 15
display_freq = 100
save_freq = 100
valid_freq = 100
model_dir = 'model/'
model_name = 'model.pkl'
shuffle = True
sort_by_len = True

# Decoding parameters
# Derived from model_dir/model_name so the checkpoint location cannot drift
# from the training settings ('model/model.pkl').
model_path = os.path.join(model_dir, model_name)
decode_input = 'data/test.en'
decode_output = 'data/test.en.trans'
max_decode_step = 20

In [3]:
def load_model(path=None):
    """Rebuild the QRNN model and load its weights from a saved checkpoint.

    Args:
        path: checkpoint file to read. Defaults to the notebook-level
            ``model_path``, read at call time so later edits to it are honoured.

    Returns:
        (model, checkpoint): a QRNNModel with ``checkpoint['state_dict']``
        loaded into it, and the whole object returned by ``torch.load``.

    Raises:
        ValueError: if ``path`` does not exist.
    """
    if path is None:
        path = model_path
    if not os.path.exists(path):
        raise ValueError('No such file:[{}]'.format(path))

    print('Reloading model parameters..')
    # Deserialize every storage onto the CPU. Without map_location, torch.load
    # restores tensors to the device they were saved from, which may not be the
    # current GPU; the model is moved with .cuda() by the caller afterwards.
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    # Architecture hyper-parameters come from the config cell and must match
    # the shapes stored in the checkpoint for load_state_dict to succeed.
    model = QRNNModel(QRNNLayer, num_layers, kernel_size,
                      hidden_size, emb_size,
                      num_enc_symbols, num_dec_symbols)
    model.load_state_dict(checkpoint['state_dict'])
    return model, checkpoint

In [4]:
# Rebuild the model from the checkpoint; `config` holds the whole object
# returned by torch.load inside load_model().
model, config = load_model()

# Load source data to decode (source side only).
# NOTE(review): this iterator is not consumed by the decode loop in the next
# cell, which iterates valid_set instead — TODO confirm which is intended.
test_set = TextIterator(source=decode_input,
                        source_dict=src_vocab,
                        batch_size=batch_size,
                        n_words_source=num_enc_symbols,
                        maxlen=None)

# Parallel (source, target) iterator over the validation files, not reshuffled
# between epochs.
valid_set = BiTextIterator(source=src_valid,
                           target=tgt_valid,
                           source_dict=src_vocab,
                           target_dict=tgt_vocab,
                           batch_size=batch_size,
                           maxlen=None,
                           shuffle_each_epoch=False,
                           n_words_source=num_enc_symbols,
                           n_words_target=num_dec_symbols)

# Inverse target vocabulary, passed to seq2words when writing translations
# (presumably id -> word — TODO confirm against data_utils).
target_inv_dict = load_inv_dict(tgt_vocab)

# Inference only: put the model in evaluation mode so any training-time
# behaviour in its submodules (e.g. dropout) is switched off.
model.eval()

if use_cuda:
    print('Using gpu..')
    model = model.cuda()

# Debug decoding pass: for each validation batch, build padded tensors, print
# them and run the encoder. The greedy decoding loop that would write
# translations to decode_output is still commented out below, so this cell
# currently leaves decode_output as an empty file.
# NOTE(review): the saved output of this cell ends with
# "RuntimeError: tensors are on different GPUs" — TODO check which device each
# input Variable and the model parameters are on.


def _wrap(tensor):
    """Wrap `tensor` in a Variable, moving it to the current GPU first when use_cuda."""
    return Variable(tensor.cuda() if use_cuda else tensor)


# Bound before `try` so `finally` can never hit an unbound (or stale) handle
# when fopen itself fails.
fout = None
try:
    print('Decoding starts..')
    fout = fopen(decode_output, 'w')
    # for idx, source_seq in enumerate(test_set):
    for idx, (source_seq, target_seq) in enumerate(valid_set):
        # source, source_len = prepare_batch(source_seq)
        # Get a batch from the parallel (source, target) data
        source, source_len, dec_input, dec_target, dec_len = \
            prepare_train_batch(source_seq, target_seq, max_seq_len)

        print('source {}'.format(source))
        print('source_len {}'.format(source_len))
        print('dec_input {}'.format(dec_input))
        print('dec_target {}'.format(dec_target))

        # Greedy-decoding buffers: column 0 of preds_prev holds the start
        # token; the disabled loop below feeds predictions back into it.
        preds_prev = torch.zeros(len(source), max_decode_step).long()
        preds_prev[:, 0] += data_utils.start_token
        preds = torch.zeros(len(source), max_decode_step).long()
        if use_cuda:
            preds = preds.cuda()

        # Identical wrapping on CPU and GPU (dec_input used to stay a raw
        # tensor on the CPU path).
        source = _wrap(source)
        source_len = _wrap(source_len)
        preds_prev = _wrap(preds_prev)
        dec_input = _wrap(dec_input)

        states, memories = model.encode(source, source_len)
        # Kept for interactive inspection while debugging.
        embedding = model.decoder.embedding

        # Greedy decoding loop, disabled while debugging the encoder:
        # for t in xrange(max_decode_step):
        #     # logits: [batch_size x max_decode_step, tgt_vocab_size]
        #     #print 'preds_prev', preds_prev[:,:t+1]
        #     _, logits = model.decode(dec_input[:,:3], states, memories, keep_len=True)
        #     # outputs: [batch_size, max_decode_step]
        #     outputs = torch.max(logits, dim=1)[1].view(len(source), -1)
        #     print 'preds', outputs
        #     preds[:,t] = outputs[:,t].data
        #     if t < max_decode_step - 1:
        #         preds_prev[:,t+1] = outputs[:,t]
        #
        # for i in xrange(len(preds)):
        #     fout.write(str(seq2words(preds[i], target_inv_dict)) + '\n')
        #     fout.flush()
        #
        # print '  {}th line decoded'.format(idx * batch_size)
    # print 'Decoding terminated'

except IOError as err:
    # Previously `pass`: report the failure instead of silently leaving a
    # truncated or empty output file behind.
    print('Decoding aborted by I/O error: {}'.format(err))
finally:
    if fout is not None:
        fout.close()

Reloading model parameters..
Using gpu..
Decoding starts..
source 
  427   625     4   316  1102   111     3     3     3     3     3
  367    12     4   803  5324  5323    76   649   271     3     3
    4  2022  1476    14     4   349  5049     3     3     3     3
   72   621    13   723   670  1218    32     4  5032     3     3
   93     5  5339  1888    10  1015     7    49     3     3     3
   10  1229     9     4    52   970  1258     3     3     3     3
   49   898    19   230   107   347     7  5454   141     3     3
   65  1076   334     9  5142   154  1911     3     3     3     3
    4  5353   814  2616    34     9    18  5053     9    11  2648
   43    36   130   396   610    19  1463   705     3     3     3
   32   568    58  2272  5099     9     5    27   107   964     3
  714  5430  5428  4832     5   714  4830  1586  5423  5421     3
 5194  5319     5   940  5316  5313     3     3     3     3     3
  198   103    86   147    24     3     3     3     3     3     3
 5288  52

RuntimeError: tensors are on different GPUs