In [1]:
#!/usr/bin/env python
# coding: utf-8

import torch
import torch.nn as nn
from torch.autograd import Variable

import os
import math
import time
import argparse
import numpy as np
from layer import QRNNLayer
from model import QRNNModel

import data.data_utils as data_utils
from data.data_iterator import BiTextIterator
from data.data_iterator import prepare_batch
from data.data_iterator import prepare_train_batch

In [2]:
# ---- Data --------------------------------------------------------------
# Vocabulary dictionaries passed to BiTextIterator as source_dict / target_dict
# (.en / .de suffixes -- presumably English source, German target).
src_vocab = 'data/train.en.json'
tgt_vocab = 'data/train.de.json'

# Parallel corpora. The same small test split currently serves as BOTH the
# training and the validation data, so validation numbers are not held-out.
# Alternative (presumably full, cleaned) training corpus:
# NOTE(review): these point at a French (*.fr) target side while tgt_vocab is
# German -- confirm the vocab before switching to them.
# src_train='data/train.clean.en'
# tgt_train='data/train.clean.fr'
# src_valid='data/train.clean.en'
# tgt_valid='data/train.clean.fr'
src_train = 'data/test.en'
tgt_train = 'data/test.de'
src_valid = 'data/test.en'
tgt_valid = 'data/test.de'

# ---- Network -----------------------------------------------------------
kernel_size = 2            # kernel_size argument of QRNNModel
hidden_size = 10           # hidden_size argument of QRNNModel
num_layers = 2             # number of stacked QRNNLayer instances
emb_size = 500             # embedding size argument of QRNNModel
num_enc_symbols = 30000    # source vocabulary cap (also n_words_source of the iterators)
num_dec_symbols = 30000    # target vocabulary cap (also n_words_target of the iterators)
dropout_rate = 0.3         # NOTE(review): not passed to QRNNModel in create_model

# ---- Training ----------------------------------------------------------
lr = 0.0002                # Adam learning rate
max_grad_norm = 1.0        # gradient-norm clipping threshold
batch_size = 128           # sentence pairs per batch
max_epochs = 10000         # stop once the stored epoch counter reaches this
maxi_batches = 20          # maxibatch_size of the training iterator
max_seq_len = 50           # training-time maximum sequence length
display_freq = 100         # print training stats every N steps
save_freq = 100            # write a checkpoint every N steps
valid_freq = 100           # run validation every N steps
model_dir = 'model/'       # checkpoint directory
model_name = 'model.pkl'   # checkpoint file name inside model_dir
shuffle = True             # shuffle the training data each epoch
sort_by_len = True         # sort training data by length inside a maxibatch

In [3]:
use_cuda = torch.cuda.is_available()

def create_model():
    """Build a QRNNModel, restoring its weights from a checkpoint when one exists.

    The model is configured from the module-level hyper-parameters
    (num_layers, kernel_size, hidden_size, emb_size, num_enc_symbols,
    num_dec_symbols). The checkpoint is looked up at
    os.path.join(model_dir, model_name).

    Returns:
        (model, model_state) where model_state is a dict with keys
        'epoch' and 'train_steps' (restored from the checkpoint, else 0)
        and 'state_dict' (None here; it is filled in when a checkpoint
        is written).
    """
    model = QRNNModel(QRNNLayer, num_layers, kernel_size,
                      hidden_size, emb_size,
                      num_enc_symbols, num_dec_symbols)

    model_state = {'epoch': 0, 'train_steps': 0, 'state_dict': None}

    model_path = os.path.join(model_dir, model_name)
    if os.path.exists(model_path):
        print('Reloading model parameters..')
        # Map every storage to CPU so a checkpoint written on a GPU machine
        # also loads on a CPU-only one; the model is moved with .cuda()
        # afterwards when a GPU is available.
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        model_state['epoch'] = checkpoint['epoch']
        model_state['train_steps'] = checkpoint['train_steps']
        model.load_state_dict(checkpoint['state_dict'])
    else:
        print('Creating new model parameters..')

    return model, model_state

In [4]:
# Load the parallel corpora as batched iterators.
# TODO: using PyTorch DataIterator
# Keyword arguments shared by the training and validation iterators.
_shared_iter_kwargs = dict(source_dict=src_vocab,
                           target_dict=tgt_vocab,
                           batch_size=batch_size,
                           n_words_source=num_enc_symbols,
                           n_words_target=num_dec_symbols)

print('Loading training data..')
train_set = BiTextIterator(source=src_train,
                           target=tgt_train,
                           maxlen=max_seq_len,
                           shuffle_each_epoch=shuffle,
                           sort_by_length=sort_by_len,
                           maxibatch_size=maxi_batches,
                           **_shared_iter_kwargs)

# Validation data is optional: both paths must be set to build it.
if not (src_valid and tgt_valid):
    valid_set = None
else:
    print('Loading validation data..')
    valid_set = BiTextIterator(source=src_valid,
                               target=tgt_valid,
                               maxlen=None,
                               shuffle_each_epoch=False,
                               **_shared_iter_kwargs)

# Create a Quasi-RNN model (create_model restores a saved checkpoint when one exists)
model, model_state = create_model()
if use_cuda:
    print('Using gpu..')
    model = model.cuda()

# Loss and optimizer: cross-entropy that ignores padding positions, minimized with Adam
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss(ignore_index=data_utils.pad_token)

# Running statistics for the current display window; the training loop
# resets them every display_freq steps.
start_time = time.time()
sents_seen = 0
words_seen = 0
loss = 0.0

def _to_variables(tensors, volatile=False):
    """Wrap each tensor in a Variable, moving it to the GPU first when use_cuda is set.

    Args:
        tensors: iterable of tensors (as returned by prepare_train_batch).
        volatile: when True, build volatile Variables so no autograd graph is
            recorded (PyTorch <= 0.3 inference idiom; used for validation).

    Returns:
        A tuple of Variables in the same order as ``tensors``.
    """
    if use_cuda:
        tensors = [t.cuda() for t in tensors]
    if volatile:
        return tuple(Variable(t, volatile=True) for t in tensors)
    return tuple(Variable(t) for t in tensors)


def _perplexity(mean_loss):
    """Return exp(mean_loss), or inf when the loss is >= 300 (avoids math.exp overflow)."""
    return math.exp(mean_loss) if mean_loss < 300 else float('inf')


def _run_validation():
    """Evaluate the model on valid_set and print per-batch and mean perplexity.

    The model is switched to eval mode for the pass and always restored to
    train mode afterwards, even if evaluation raises.
    """
    valid_steps = 0
    valid_loss = 0.0
    valid_sents_seen = 0

    model.eval()
    try:
        for src_batch, tgt_batch in valid_set:
            enc_in, enc_lens, dec_in, dec_tgt, _ = \
                prepare_train_batch(src_batch, tgt_batch)
            # Same guard as training: skip batches that produced no samples.
            if enc_in is None or dec_in is None or dec_tgt is None:
                continue
            enc_in, enc_lens, dec_in, dec_tgt = _to_variables(
                (enc_in, enc_lens, dec_in, dec_tgt), volatile=True)

            dec_logits = model(enc_in, enc_lens, dec_in)
            batch_loss = float(criterion(dec_logits, dec_tgt.view(-1)).data[0])
            print('validation step perplexity {}'.format(_perplexity(batch_loss)))

            valid_steps += 1
            valid_loss += batch_loss
            valid_sents_seen += enc_in.size(0)
            print('  {} samples seen'.format(valid_sents_seen))
    finally:
        model.train()

    if valid_steps == 0:
        print('Valid perplexity: n/a (no validation batches)')
    else:
        print('Valid perplexity: {0:.2f}'.format(_perplexity(valid_loss / valid_steps)))


# Training loop
print('Training..')
for epoch_idx in range(max_epochs):
    # model_state['epoch'] is restored from the checkpoint, so a resumed run
    # stops as soon as the stored epoch counter reaches max_epochs.
    if model_state['epoch'] >= max_epochs:
        print('Training is already complete. '
              'current epoch:{}, max epoch:{}'.format(model_state['epoch'], max_epochs))
        break

    for source_seq, target_seq in train_set:
        # Get a batch from training parallel data
        enc_input, enc_len, dec_input, dec_target, dec_len = \
            prepare_train_batch(source_seq, target_seq, max_seq_len)

        # prepare_train_batch may yield None tensors (presumably when no pair
        # fits max_seq_len -- TODO confirm). This must be checked BEFORE
        # .cuda()/Variable(), which cannot accept None.
        if enc_input is None or dec_input is None or dec_target is None:
            print('No samples under max_seq_length {}'.format(max_seq_len))
            continue

        enc_input, enc_len, dec_input, dec_target, dec_len = _to_variables(
            (enc_input, enc_len, dec_input, dec_target, dec_len))

        # Execute a single training step
        optimizer.zero_grad()
        dec_logits = model(enc_input, enc_len, dec_input)
        step_loss = criterion(dec_logits, dec_target.view(-1))
        step_loss.backward()
        nn.utils.clip_grad_norm(model.parameters(), max_grad_norm)
        optimizer.step()

        # `.data[0]` is the PyTorch <= 0.3 scalar-extraction idiom used by this
        # file (on newer releases use `.item()`).
        loss += float(step_loss.data[0]) / display_freq
        words_seen += torch.sum(enc_len + dec_len).data[0]
        sents_seen += enc_input.size(0)  # batch_size

        model_state['train_steps'] += 1

        # Display training status
        if model_state['train_steps'] % display_freq == 0:
            avg_perplexity = _perplexity(float(loss))
            time_elapsed = time.time() - start_time
            step_time = time_elapsed / display_freq

            words_per_sec = words_seen / time_elapsed
            sents_per_sec = sents_seen / time_elapsed

            print('Epoch {} Step {} Perplexity {:.2f} Step-time {:.2f} '
                  '{:.2f} sents/s {:.2f} words/s'.format(
                      model_state['epoch'], model_state['train_steps'],
                      avg_perplexity, step_time, sents_per_sec, words_per_sec))

            loss = 0.0
            words_seen, sents_seen = 0, 0
            start_time = time.time()

        # Execute a validation process
        if valid_set is not None and model_state['train_steps'] % valid_freq == 0:
            print('Validation step')
            _run_validation()

        # Save the model checkpoint
        if model_state['train_steps'] % save_freq == 0:
            print('Saving the model..')
            model_state['state_dict'] = model.state_dict()
            # torch.save does not create missing directories.
            if not os.path.isdir(model_dir):
                os.makedirs(model_dir)
            torch.save(model_state, os.path.join(model_dir, model_name))

    # Increase the epoch index of the model
    model_state['epoch'] += 1
    print('Epoch {0:} DONE'.format(model_state['epoch']))

Loading training data..
Loading validation data..
Creating new model parameters..
Reloading model parameters..
Using gpu..
Training..
Epoch 800 DONE
Epoch 801 DONE
Epoch 802 DONE
Epoch 803 DONE
Epoch 804 DONE
Epoch 805 DONE
Epoch 806 DONE
Epoch 807 DONE
Epoch 808 DONE
Epoch 809 DONE
Epoch 810 DONE
Epoch 811 DONE
Epoch 812 DONE
Epoch 813 DONE
Epoch 814 DONE
Epoch 815 DONE
Epoch 816 DONE
Epoch 817 DONE
Epoch 818 DONE
Epoch 819 DONE
Epoch 820 DONE
Epoch 821 DONE
Epoch 822 DONE
Epoch 823 DONE
Epoch 824 DONE
Epoch 825 DONE
Epoch 826 DONE
Epoch 827 DONE
Epoch 828 DONE
Epoch 829 DONE
Epoch 830 DONE
Epoch 831 DONE
Epoch 832 DONE
Epoch 833 DONE
Epoch 834 DONE
Epoch 835 DONE
Epoch 836 DONE
Epoch 837 DONE
Epoch 838 DONE
Epoch 839 DONE
Epoch 840 DONE
Epoch 841 DONE
Epoch 842 DONE
Epoch 843 DONE
Epoch 844 DONE
Epoch 845 DONE
Epoch 846 DONE
Epoch 847 DONE
Epoch 848 DONE
Epoch 849 DONE
Epoch 850 DONE
Epoch 851 DONE
Epoch 852 DONE
Epoch 853 DONE
Epoch 854 DONE
Epoch 855 DONE
Epoch 856 DONE
Epoch 857 D

Epoch 1270 DONE
Epoch 1271 DONE
Epoch 1272 DONE
Epoch 1273 DONE
Epoch 1274 DONE
Epoch 1275 DONE
Epoch 1276 DONE
Epoch 1277 DONE
Epoch 1278 DONE
Epoch 1279 DONE
Epoch 1280 DONE
Epoch 1281 DONE
Epoch 1282 DONE
Epoch 1283 DONE
Epoch 1284 DONE
Epoch 1285 DONE
Epoch 1286 DONE
Epoch 1287 DONE
Epoch 1288 DONE
Epoch 1289 DONE
Epoch 1290 DONE
Epoch 1291 DONE
Epoch 1292 DONE
Epoch 1293 DONE
Epoch 1294 DONE
Epoch 1295 DONE
Epoch 1296 DONE
Epoch 1297 DONE
Epoch 1298 DONE
Epoch  1298 Step  1300 Perplexity 5.58 Step-time 0.11 140.63 sents/s 2418.76 words/s
Validation step
validation step loss 4.74394651612
  15 samples seen
Valid perplexity: 4.74
Saving the model..
Epoch 1299 DONE
Epoch 1300 DONE
Epoch 1301 DONE
Epoch 1302 DONE
Epoch 1303 DONE
Epoch 1304 DONE
Epoch 1305 DONE
Epoch 1306 DONE
Epoch 1307 DONE
Epoch 1308 DONE
Epoch 1309 DONE
Epoch 1310 DONE
Epoch 1311 DONE
Epoch 1312 DONE
Epoch 1313 DONE
Epoch 1314 DONE
Epoch 1315 DONE
Epoch 1316 DONE
Epoch 1317 DONE
Epoch 1318 DONE
Epoch 1319 DONE
Epoc

Epoch 1722 DONE
Epoch 1723 DONE
Epoch 1724 DONE
Epoch 1725 DONE
Epoch 1726 DONE
Epoch 1727 DONE
Epoch 1728 DONE
Epoch 1729 DONE
Epoch 1730 DONE
Epoch 1731 DONE
Epoch 1732 DONE
Epoch 1733 DONE
Epoch 1734 DONE
Epoch 1735 DONE
Epoch 1736 DONE
Epoch 1737 DONE
Epoch 1738 DONE
Epoch 1739 DONE
Epoch 1740 DONE
Epoch 1741 DONE
Epoch 1742 DONE
Epoch 1743 DONE
Epoch 1744 DONE
Epoch 1745 DONE
Epoch 1746 DONE
Epoch 1747 DONE
Epoch 1748 DONE
Epoch 1749 DONE
Epoch 1750 DONE
Epoch 1751 DONE
Epoch 1752 DONE
Epoch 1753 DONE
Epoch 1754 DONE
Epoch 1755 DONE
Epoch 1756 DONE
Epoch 1757 DONE
Epoch 1758 DONE
Epoch 1759 DONE
Epoch 1760 DONE
Epoch 1761 DONE
Epoch 1762 DONE
Epoch 1763 DONE
Epoch 1764 DONE
Epoch 1765 DONE
Epoch 1766 DONE
Epoch 1767 DONE
Epoch 1768 DONE
Epoch 1769 DONE
Epoch 1770 DONE
Epoch 1771 DONE
Epoch 1772 DONE
Epoch 1773 DONE
Epoch 1774 DONE
Epoch 1775 DONE
Epoch 1776 DONE
Epoch 1777 DONE
Epoch 1778 DONE
Epoch 1779 DONE
Epoch 1780 DONE
Epoch 1781 DONE
Epoch 1782 DONE
Epoch 1783 DONE
Epoch 17

Epoch 2186 DONE
Epoch 2187 DONE
Epoch 2188 DONE
Epoch 2189 DONE
Epoch 2190 DONE
Epoch 2191 DONE
Epoch 2192 DONE
Epoch 2193 DONE
Epoch 2194 DONE
Epoch 2195 DONE
Epoch 2196 DONE
Epoch 2197 DONE
Epoch 2198 DONE
Epoch  2198 Step  2200 Perplexity 1.21 Step-time 0.10 144.31 sents/s 2482.21 words/s
Validation step
validation step loss 1.18624368234
  15 samples seen
Valid perplexity: 1.19
Saving the model..
Epoch 2199 DONE
Epoch 2200 DONE
Epoch 2201 DONE
Epoch 2202 DONE
Epoch 2203 DONE
Epoch 2204 DONE
Epoch 2205 DONE
Epoch 2206 DONE
Epoch 2207 DONE
Epoch 2208 DONE
Epoch 2209 DONE
Epoch 2210 DONE
Epoch 2211 DONE
Epoch 2212 DONE
Epoch 2213 DONE
Epoch 2214 DONE
Epoch 2215 DONE
Epoch 2216 DONE
Epoch 2217 DONE
Epoch 2218 DONE
Epoch 2219 DONE
Epoch 2220 DONE
Epoch 2221 DONE
Epoch 2222 DONE
Epoch 2223 DONE
Epoch 2224 DONE
Epoch 2225 DONE
Epoch 2226 DONE
Epoch 2227 DONE
Epoch 2228 DONE
Epoch 2229 DONE
Epoch 2230 DONE
Epoch 2231 DONE
Epoch 2232 DONE
Epoch 2233 DONE
Epoch 2234 DONE
Epoch 2235 DONE
Epoc

KeyboardInterrupt: 