In [1]:
#!/usr/bin/env python
# coding: utf-8

import torch
import torch.nn as nn
from torch.autograd import Variable

import os
import argparse
from layer import QRNNLayer
from model import QRNNModel

import data.data_utils as data_utils
from data.data_utils import fopen
from data.data_utils import load_inv_dict
from data.data_utils import seq2words

from data.data_iterator import TextIterator
from data.data_iterator import BiTextIterator
from data.data_iterator import prepare_batch
from data.data_iterator import prepare_train_batch

# True when PyTorch can see a CUDA device; the decoding code branches on this
# flag to decide whether the model and batches are moved onto the GPU.
use_cuda = bool(torch.cuda.is_available())

In [2]:
# Data loading parameters
src_vocab='data/train.en.json'
tgt_vocab='data/train.de.json'
# src_train='data/train.clean.en'
# tgt_train='data/train.clean.fr'
# src_valid='data/train.clean.en'
# tgt_valid='data/train.clean.fr'
src_train='data/test.en'
tgt_train='data/test.de'
src_valid='data/test.en'
tgt_valid='data/test.de'

# Network parameters
kernel_size = 2
hidden_size = 10
num_layers = 2
emb_size = 500
num_enc_symbols = 30000
num_dec_symbols = 30000
dropout_rate = 0.3

# Training parameters
lr = 0.0002
max_grad_norm = 1.0
batch_size = 128
max_epochs = 1000
maxi_batches = 20
max_seq_len = 18
display_freq = 100
save_freq = 100
valid_freq = 100
model_dir = 'model/'
model_name = 'model.pkl'
shuffle = True
sort_by_len = True

# Decoding parameters
model_path = 'model/model.pkl'
decode_input = 'data/test.en'
decode_output = 'data/test.en.trans'
max_decode_step = 20

In [3]:
def load_model(path=None):
    """Rebuild the QRNN model and restore its weights from a checkpoint.

    The architecture is rebuilt from the module-level network parameters
    (num_layers, kernel_size, hidden_size, emb_size, num_enc_symbols,
    num_dec_symbols), so they must match the ones used when the checkpoint
    was saved.

    Args:
        path: checkpoint file to load. Defaults to the module-level
            ``model_path`` when omitted.

    Returns:
        (model, checkpoint): the ``QRNNModel`` with ``checkpoint['state_dict']``
        loaded into it, and the raw object returned by ``torch.load``.

    Raises:
        ValueError: if ``path`` does not exist.
    """
    if path is None:
        path = model_path
    if not os.path.exists(path):
        raise ValueError('No such file:[{}]'.format(path))

    print('Reloading model parameters..')
    # Map every saved storage onto the CPU so a checkpoint written during a GPU
    # run can also be loaded on a CPU-only machine. The model itself is built
    # on the CPU; callers are expected to move it with .cuda() when a GPU is used.
    checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    model = QRNNModel(QRNNLayer, num_layers, kernel_size,
                      hidden_size, emb_size,
                      num_enc_symbols, num_dec_symbols)
    model.load_state_dict(checkpoint['state_dict'])
    return model, checkpoint

In [4]:
model, config = load_model()
# Inference only: disable training-time behaviour such as dropout.
model.eval()

# Load source data to decode
test_set = TextIterator(source=decode_input,
                        source_dict=src_vocab,
                        batch_size=batch_size,
                        n_words_source=num_enc_symbols,
                        maxlen=None)

# Parallel data: used below to inspect the decoder on teacher-forced inputs.
valid_set = BiTextIterator(source=src_valid,
                           target=tgt_valid,
                           source_dict=src_vocab,
                           target_dict=tgt_vocab,
                           batch_size=batch_size,
                           maxlen=None,
                           shuffle_each_epoch=False,
                           n_words_source=num_enc_symbols,
                           n_words_target=num_dec_symbols)

target_inv_dict = load_inv_dict(tgt_vocab)

if use_cuda:
    print('Using gpu..')
    model = model.cuda()


def _to_variable(tensor):
    """Move ``tensor`` to the GPU when ``use_cuda`` is set, then wrap it in a Variable."""
    if use_cuda:
        tensor = tensor.cuda()
    return Variable(tensor)


# Bind fout up front so `finally` can tell whether fopen ever succeeded;
# otherwise a failing fopen would surface as a NameError in `finally`.
fout = None
try:
    print('Decoding starts..')
    fout = fopen(decode_output, 'w')

    for idx, (source_seq, target_seq) in enumerate(valid_set):
        # Get a batch from the parallel data (decoder inputs are teacher-forced).
        source, source_len, dec_input, dec_target, dec_len = \
            prepare_train_batch(source_seq, target_seq, max_seq_len)

        print('source {}'.format(source))
        print('source_len {}'.format(source_len))
        print('dec_input {}'.format(dec_input))
        print('dec_target {}'.format(dec_target))

        # Buffers for the (currently disabled) greedy decoding loop below:
        # preds_prev starts every row with the start token.
        preds_prev = torch.zeros(len(source), max_decode_step).long()
        preds_prev[:, 0] += data_utils.start_token
        preds = torch.zeros(len(source), max_decode_step).long()
        if use_cuda:
            preds = preds.cuda()

        # Every model input gets the same device placement + Variable wrapping
        # on both the CPU and GPU paths.
        source, source_len, preds_prev, dec_input = [
            _to_variable(t) for t in (source, source_len, preds_prev, dec_input)]

        states, memories = model.encode(source, source_len)

        # Debug: compare decoder logits for the full teacher-forced input with
        # the logits for only its first time step.
        print('dec_input')
        _, logits_ = model.decode(dec_input, states, memories)
        print(logits_)

        print('dec_input[:,:1]')
        _, logits__ = model.decode(dec_input[:, :1], states, memories)
        print(logits__)

        # TODO: greedy decoding into `preds` and writing translations to `fout`
        # is disabled. NOTE(review): as written, the loop feeds `dec_input`
        # (teacher forcing) instead of `preds_prev[:, :t+1]`, so it would not be
        # true greedy decoding if simply uncommented.
        # for t in xrange(max_decode_step):
        #     # logits: [batch_size x max_decode_step, tgt_vocab_size]
        #     _, logits = model.decode(dec_input, states, memories)
        #     # outputs: [batch_size, max_decode_step]
        #     outputs = torch.max(logits, dim=1)[1].view(len(source), -1)
        #     preds[:, t] = outputs[:, t].data
        #     if t < max_decode_step - 1:
        #         preds_prev[:, t+1] = outputs[:, t]
        #
        # for i in xrange(len(preds)):
        #     fout.write(str(seq2words(preds[i], target_inv_dict)) + '\n')
        #     fout.flush()
        #
        # print('  {}th line decoded'.format(idx * batch_size))
    # print('Decoding terminated')

except IOError as err:
    # Keep the notebook alive, but report why decoding stopped instead of
    # silently swallowing the error.
    print('Decoding aborted by I/O error: {}'.format(err))
finally:
    if fout is not None:
        fout.close()

Reloading model parameters..
Using gpu..
Decoding starts..
source 
   31   170  1809     9   898    12   862    40    45  1112    98     3     3
   14    30    38   138  2080   120  1619  2079   389  2078     3     3     3
 1592     5    29   647     5    14   367    43   743     3     3     3     3
 1961    25    18   326    15    42     5   146    36     3     3     3     3
 1280     6     4   148  1942    68   943    10   931  1170     3     3     3
  519    34    18   894    70     4   725     9  1555     3     3     3     3
 1298  1184     5   171     5    15    16   862    40     3     3     3     3
   14  1915  1925     4  1662     6     4    33    44  1920    17  1398  1387
  519    21   892    12    96    17    16   353    10     4     3     3     3
 1374     5    22    42    34    21   531     5     4     3     3     3     3
  294     4   170  1626     5    14   139    69    22    29     3     3     3
   83     4  2254     5    14    39    38     8  2249     3     3     3    