In [1]:
#!/usr/bin/env python
# coding: utf-8

import torch
import torch.nn as nn
from torch.autograd import Variable

import os
import math
import time
import argparse
import numpy as np
from layer import QRNNLayer
from model import QRNNModel

import data.data_utils as data_utils
from data.data_iterator import BiTextIterator
from data.data_iterator import prepare_batch
from data.data_iterator import prepare_train_batch

In [2]:
# Data loading parameters
src_vocab='data/train.clean.en.json'
tgt_vocab='data/train.clean.fr.json'
# src_train='data/train.clean.en'
# tgt_train='data/train.clean.fr'
# src_valid='data/train.clean.en'
# tgt_valid='data/train.clean.fr'
src_train='data/test.en'
tgt_train='data/test.fr'
src_valid='data/test.en'
tgt_valid='data/test.fr'

# Network parameters
kernel_size = 2
hidden_size = 1024
num_layers = 2
emb_size = 500
num_enc_symbols = 30000
num_dec_symbols = 30000
dropout_rate = 0.3

# Training parameters
lr = 0.0002
max_grad_norm = 1.0
batch_size = 128
max_epochs = 1000
maxi_batches = 20
max_seq_len = 50
display_freq = 100
save_freq = 100
valid_freq = 100
model_dir = 'model/'
model_name = 'model.pkl'
shuffle = True
sort_by_len = True

In [3]:
use_cuda = torch.cuda.is_available()

def create_model():
    print 'Creating new model parameters..'
    model = QRNNModel(QRNNLayer, num_layers, kernel_size,
    	              hidden_size, emb_size, 
    	              num_enc_symbols, num_dec_symbols)

    # Initialize a model state
#    model_state = vars(config)
    model_state = {}
    model_state['epoch'], model_state['train_steps'] = 0, 0
    model_state['state_dict'] = None
    
    model_path = os.path.join(model_dir, model_name)
    if os.path.exists(model_path):
        print 'Reloading model parameters..'
        checkpoint = torch.load(model_path)

        model_state['epoch'] = checkpoint['epoch']
        model_state['train_steps'] = checkpoint['train_steps']
        model.load_state_dict(checkpoint['state_dict'])

    return model, model_state

In [4]:
# Load parallel data to train
# TODO: using PyTorch DataIterator
print 'Loading training data..'
train_set = BiTextIterator(source=src_train,
                           target=tgt_train,
                           source_dict=src_vocab,
                           target_dict=tgt_vocab,
                           batch_size=batch_size,
                           maxlen=max_seq_len,
                           n_words_source=num_enc_symbols,
                           n_words_target=num_dec_symbols,
                           shuffle_each_epoch=shuffle,
                           sort_by_length=sort_by_len,
                           maxibatch_size=maxi_batches)

if src_valid and tgt_valid:
    print 'Loading validation data..'
    valid_set = BiTextIterator(source=src_valid,
                               target=tgt_valid,
                               source_dict=src_vocab,
                               target_dict=tgt_vocab,
                               batch_size=batch_size,
                               maxlen=None,
                               shuffle_each_epoch=False,
                               n_words_source=num_enc_symbols,
                               n_words_target=num_dec_symbols)
else:
    valid_set = None

# Create a Quasi-RNN model
model, model_state = create_model()
if use_cuda:
    print 'Using gpu..'
    model = model.cuda()

# Loss and Optimizer
criterion = nn.CrossEntropyLoss(ignore_index=data_utils.pad_token)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

loss = 0.0
words_seen, sents_seen = 0, 0
start_time = time.time()

# Training loop
print 'Training..'
for epoch_idx in xrange(max_epochs):
    if model_state['epoch'] >= max_epochs:
        print 'Training is already complete.', \
              'current epoch:{}, max epoch:{}'.format(model_state['epoch'], max_epochs)
        break

    for source_seq, target_seq in train_set:    
        # Get a batch from training parallel data
        enc_input, enc_len, dec_input, dec_target, dec_len = \
            prepare_train_batch(source_seq, target_seq, max_seq_len)

        if use_cuda:
            enc_input = Variable(enc_input.cuda())
            enc_len = Variable(enc_len.cuda())
            dec_input = Variable(dec_input.cuda())
            dec_target = Variable(dec_target.cuda())
            dec_len = Variable(dec_len.cuda())
        else:
            enc_input = Variable(enc_input)
            enc_len = Variable(enc_len)
            dec_input = Variable(dec_input)
            dec_target = Variable(dec_target)
            dec_len = Variable(dec_len)

        if enc_input is None or dec_input is None or dec_target is None:
            print 'No samples under max_seq_length ', max_seq_len
            continue

        # Execute a single training step
        optimizer.zero_grad()
        dec_logits = model(enc_input, enc_len, dec_input)
        step_loss = criterion(dec_logits, dec_target.view(-1))
        step_loss.backward()
        nn.utils.clip_grad_norm(model.parameters(), max_grad_norm)
        optimizer.step()

        loss += float(step_loss.data[0]) / display_freq
        words_seen += torch.sum(enc_len + dec_len).data[0]
        sents_seen += enc_input.size(0)  # batch_size

        model_state['train_steps'] += 1

        # Display training status
        if model_state['train_steps'] % display_freq == 0:

            avg_perplexity = math.exp(float(loss)) if loss < 300 else float("inf")
            time_elapsed = time.time() - start_time
            step_time = time_elapsed / display_freq

            words_per_sec = words_seen / time_elapsed
            sents_per_sec = sents_seen / time_elapsed

            print 'Epoch ', model_state['epoch'], 'Step ', model_state['train_steps'], \
                  'Perplexity {0:.2f}'.format(avg_perplexity), 'Step-time {0:.2f}'.format(step_time), \
                  '{0:.2f} sents/s'.format(sents_per_sec), '{0:.2f} words/s'.format(words_per_sec)

            loss = 0.0
            words_seen, sents_seen = 0, 0
            start_time = time.time()

        # Execute a validation process
        if valid_set and model_state['train_steps'] % valid_freq == 0:
            print 'Validation step'

            valid_steps = 0
            valid_loss = 0.0
            valid_sents_seen = 0
            for source_seq, target_seq in valid_set:
                # Get a batch from validation parallel data
                enc_input, enc_len, dec_input, dec_target, dec_len = \
                    prepare_train_batch(source_seq, target_seq)

                if use_cuda:
                    enc_input = Variable(enc_input.cuda())
                    enc_len = Variable(enc_len.cuda())
                    dec_input = Variable(dec_input.cuda())
                    dec_target = Variable(dec_target.cuda())
                    dec_len = Variable(dec_len.cuda())
                else:
                    enc_input = Variable(enc_input)
                    enc_len = Variable(enc_len)
                    dec_input = Variable(dec_input)
                    dec_target = Variable(dec_target)
                    dec_len = Variable(dec_len)

                dec_logits = model(enc_input, enc_len, dec_input)
                step_loss = criterion(dec_logits, dec_target.view(-1))
                print 'validation step loss', math.exp(step_loss.data[0])
                valid_steps += 1 
                valid_loss += float(step_loss.data[0])
                valid_sents_seen += enc_input.size(0)
                print '  {} samples seen'.format(valid_sents_seen)

            print 'Valid perplexity: {0:.2f}'.format(math.exp(valid_loss / valid_steps))

        # Save the model checkpoint
        if model_state['train_steps'] % save_freq == 0:
            print 'Saving the model..'

            model_state['state_dict'] = model.state_dict()
#                state = dict(list(model_state.items()))
            model_path = os.path.join(model_dir, model_name)
            torch.save(model_state, model_path)

    # Increase the epoch index of the model
    model_state['epoch'] += 1
    print 'Epoch {0:} DONE'.format(model_state['epoch'])

Loading training data..
Loading validation data..
Creating new model parameters..
Using gpu..
Training..
Epoch 1 DONE
Epoch 2 DONE
Epoch 3 DONE
Epoch 4 DONE
Epoch 5 DONE
Epoch 6 DONE
Epoch 7 DONE
Epoch 8 DONE
Epoch 9 DONE
Epoch 10 DONE
Epoch 11 DONE
Epoch 12 DONE
Epoch 13 DONE
Epoch 14 DONE
Epoch 15 DONE
Epoch 16 DONE
Epoch 17 DONE
Epoch 18 DONE
Epoch 19 DONE
Epoch 20 DONE
Epoch 21 DONE
Epoch 22 DONE
Epoch 23 DONE
Epoch 24 DONE
Epoch 25 DONE
Epoch 26 DONE
Epoch 27 DONE
Epoch 28 DONE
Epoch 29 DONE
Epoch 30 DONE
Epoch 31 DONE
Epoch 32 DONE
Epoch 33 DONE
Epoch 34 DONE
Epoch 35 DONE
Epoch 36 DONE
Epoch 37 DONE
Epoch 38 DONE
Epoch 39 DONE
Epoch 40 DONE
Epoch 41 DONE
Epoch 42 DONE
Epoch 43 DONE
Epoch 44 DONE
Epoch 45 DONE
Epoch 46 DONE
Epoch 47 DONE
Epoch 48 DONE
Epoch 49 DONE
Epoch 50 DONE
Epoch 51 DONE
Epoch 52 DONE
Epoch 53 DONE
Epoch 54 DONE
Epoch 55 DONE
Epoch 56 DONE
Epoch 57 DONE
Epoch 58 DONE
Epoch 59 DONE
Epoch 60 DONE
Epoch 61 DONE
Epoch 62 DONE
Epoch 63 DONE
Epoch 64 DONE
Epoch 65

KeyboardInterrupt: 