In [1]:
import os
import math
import time
import random
import numpy as np
import tensorflow as tf

In [2]:
from data.data_utils import prepare_data
from data.data_utils import TextIterator
from seq2seq_model import Seq2SeqModel

In [3]:
# Command-line style configuration, registered on TensorFlow's flags module.
flags = tf.app.flags

# --- Data loading: vocabulary files and train/validation corpora ---
flags.DEFINE_string(
    'source_vocabulary',
    'data/en-fr/wmt15_fr-en.train.en.json',
    'Path to source vocabulary')
flags.DEFINE_string(
    'target_vocabulary',
    'data/en-fr/wmt15_fr-en.train.fr.json',
    'Path to target vocabulary')
flags.DEFINE_string(
    'source_train_data',
    'data/en-fr/wmt15_fr-en.train.en',
    'Path to source training data')
flags.DEFINE_string(
    'target_train_data',
    'data/en-fr/wmt15_fr-en.train.fr',
    'Path to target training data')
flags.DEFINE_string(
    'source_valid_data',
    'data/en-fr/newstest2013.tok.en',
    'Path to source validation data')
flags.DEFINE_string(
    'target_valid_data',
    'data/en-fr/newstest2013.tok.fr',
    'Path to target validation data')

# --- Network architecture ---
flags.DEFINE_string(
    'cell_type', 'lstm',
    'RNN cell to use for encoder and decoder')
flags.DEFINE_string(
    'attention_type', 'bahdanau',
    'Attention mechanism: (bahdanau, luong)')
flags.DEFINE_integer(
    'hidden_units', 1024,
    'Number of hidden units for each layer in the model')
flags.DEFINE_integer(
    'depth', 4,
    'Number of layers for each encoder and decoder')
flags.DEFINE_integer(
    'embedding_size', 500,
    'Embedding dimensions of encoder and decoder inputs')
flags.DEFINE_integer(
    'num_encoder_symbols', 30000,
    'Source vocabulary size')
flags.DEFINE_integer(
    'num_decoder_symbols', 30000,
    'Target vocabulary size')

# Architecture switches and dropout setting.
flags.DEFINE_boolean(
    'use_residual', True,
    'Use residual connection between layers')
flags.DEFINE_boolean(
    'input_feeding', True,
    'Use input feeding method in attentional decoder')
flags.DEFINE_boolean(
    'use_dropout', True,
    'Use dropout in each rnn cell')
flags.DEFINE_float(
    'dropout_keep_prob', 0.3,
    'Dropout keep probability for input/output/state units (1.0: no dropout)')

# --- Training: optimisation, batching, logging and checkpointing ---
flags.DEFINE_float(
    'learning_rate', 0.0002,
    'Learning rate')
flags.DEFINE_float(
    'max_gradient_norm', 1.0,
    'Clip gradients to this norm')
flags.DEFINE_integer(
    'batch_size', 128,
    'Batch size to use during training')
flags.DEFINE_integer(
    'max_epochs', 10,
    'Maximum # of training epochs')
flags.DEFINE_integer(
    'max_load_batches', 20,
    'Maximum # of batches to load at one time')
flags.DEFINE_integer(
    'max_seq_length', 50,
    'Maximum sequence length')
flags.DEFINE_integer(
    'display_freq', 100,
    'Display training status every this iteration')
flags.DEFINE_integer(
    'save_freq', 1000,
    'Save model checkpoint every this iteration')
flags.DEFINE_integer(
    'valid_freq', 1000,
    'Evaluate model every this iteration: valid_data needed')
flags.DEFINE_string(
    'optimizer', 'adam',
    'Optimizer for training: (adadelta, adam, rmsprop)')
flags.DEFINE_string(
    'model_dir', 'model/',
    'Path to save model checkpoints')
flags.DEFINE_string(
    'summary_dir', 'model/summary',
    'Path to save model summary')
flags.DEFINE_string(
    'model_name', 'translate.ckpt',
    'File name used for model checkpoints')
flags.DEFINE_boolean(
    'shuffle_each_epoch', True,
    'Shuffle training dataset for each epoch')

# --- Runtime: numeric precision and device placement ---
flags.DEFINE_boolean(
    'use_fp16', False,
    'Use half precision float16 instead of float32 as dtype')
flags.DEFINE_boolean(
    'allow_soft_placement', True,
    'Allow device soft placement')
flags.DEFINE_boolean(
    'log_device_placement', False,
    'Log placement of ops on devices')

# Shared handle through which the values defined above are read.
FLAGS = flags.FLAGS

In [4]:
def load_model(session, FLAGS):
    """Build a Seq2SeqModel and restore its parameters from a checkpoint.

    The model is constructed as ``Seq2SeqModel(FLAGS, 'decoding')``. The
    checkpoint state is then looked up in ``FLAGS.model_dir``; if a
    checkpoint is recorded there and exists, it is loaded through
    ``model.restore``.

    Args:
        session: Session handed to ``model.restore`` when a checkpoint is
            found (presumably a ``tf.Session`` -- confirm against the caller).
        FLAGS: Configuration object passed to the model constructor; must
            expose a ``model_dir`` attribute.

    Returns:
        The constructed model. NOTE: when no checkpoint is found, only a
        message is printed and the model is returned without ``restore``
        having been called.
    """
    model = Seq2SeqModel(FLAGS, 'decoding')

    ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        # Single-argument print(...) produces the same output on Python 2
        # and is valid syntax on Python 3.
        print('Reloading model parameters..')
        model.restore(session, ckpt.model_checkpoint_path)
    else:
        print('Tensorflow checkpoints do not exist')

    return model