In [4]:
%load_ext autoreload
%autoreload 2
import tensorflow as tf, re, time, math
from wmt_data import *
from seq2seq_model import *

def create_model(V_en, V_fr, batch_size, buckets, forward_only=True):
    """Construct the Seq2SeqModel used throughout this notebook.

    Args:
        V_en: first vocabulary size handed to Seq2SeqModel (presumably the
            English / source side -- confirm against seq2seq_model).
        V_fr: second vocabulary size handed to Seq2SeqModel (presumably the
            French / target side -- confirm against seq2seq_model).
        batch_size: batch size the model is built for. Previously this
            argument was overwritten with a hard-coded 128 inside the
            function; it is now honoured.
        buckets: bucket list forwarded unchanged to Seq2SeqModel.
        forward_only: forwarded to Seq2SeqModel. Defaults to True, which is
            the value this function always used before.
            NOTE(review): the training cell calls model.step(..., False);
            confirm whether the model must be built with forward_only=False
            for that to work.

    Returns:
        The Seq2SeqModel instance.
    """
    state_size = 256
    num_layers = 1
    max_gradient_norm = 5.0
    learning_rate = 0.5
    num_samples = 512
    return Seq2SeqModel(V_en, V_fr, buckets, state_size, num_layers,
                        max_gradient_norm, batch_size, learning_rate,
                        num_samples, forward_only, dtype=tf.float32)

# --- Notebook-level configuration -------------------------------------------
# Vocabulary sizes passed to both prepare_wmt_data and create_model.
V_en = 50000
V_fr = 50000
batch_size = 128
# Forwarded as the last argument of read_data for the training set
# (presumably 0 means "no limit" -- TODO confirm against wmt_data.read_data).
max_train_data_size = 0
# Bucket definitions shared by read_data, get_batch and the model
# (presumably (source_length, target_length) pairs -- TODO confirm).
buckets = [(5, 10), (20, 30)]

# --- Data -------------------------------------------------------------------
(en_train, fr_train,
 en_dev, fr_dev,
 en_vocab_path, fr_vocab_path) = prepare_wmt_data(V_en, V_fr)

dev_set = read_data(en_dev, fr_dev, buckets)
train_set = read_data(en_train, fr_train, buckets, max_train_data_size)

# --- Model ------------------------------------------------------------------
model = create_model(V_en, V_fr, batch_size, buckets)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
def train(max_steps=None):
    """Run the training loop on the notebook-level `model` and `train_set`.

    Reads the notebook globals `model`, `train_set`, `dev_set`, `buckets`
    and `batch_size`. Every `steps_per_checkpoint` steps it prints the
    averaged training perplexity, saves a checkpoint under ``wmt/`` and
    evaluates one batch from each non-empty dev bucket.

    Args:
        max_steps: optional number of training steps after which to return.
            The default (None) keeps the original behaviour of looping until
            the kernel is interrupted.

    Raises:
        ValueError: if every bucket of `train_set` is empty.
    """
    # Imported locally so this cell does not rely on these names leaking in
    # through the notebook's star imports.
    import os
    import sys

    import numpy as np

    steps_per_checkpoint = 200

    # os.path.join with a single argument is a no-op, so build the path from
    # its parts, and make sure the directory exists before the first save.
    checkpoint_dir = "wmt"
    checkpoint_path = os.path.join(checkpoint_dir, "translate.ckpt")
    if not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    train_bucket_sizes = [len(train_set[b]) for b in xrange(len(buckets))]
    train_total_size = float(sum(train_bucket_sizes))
    if train_total_size == 0:
        # Without this guard the division below fails with a bare
        # ZeroDivisionError.
        raise ValueError("train_set has no examples in any bucket")

    # Cumulative fraction of the training data up to and including bucket i;
    # used to sample a bucket in proportion to its size.
    train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size
                           for i in xrange(len(train_bucket_sizes))]

    # NOTE(review): create_model builds the model with forward_only=True,
    # while the step below is called with False as its last argument --
    # confirm Seq2SeqModel supports that combination.
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())

        # This is the training loop.
        step_time, loss = 0.0, 0.0
        current_step = 0
        previous_losses = []
        while max_steps is None or current_step < max_steps:
            # Choose a bucket according to data distribution: draw a number
            # in [0, 1) and take the first bucket whose cumulative share
            # exceeds it.
            random_number_01 = np.random.random_sample()
            bucket_id = min(i for i, scale in enumerate(train_buckets_scale)
                            if scale > random_number_01)

            # Get a batch and make a step.
            start_time = time.time()
            encoder_inputs, decoder_inputs, target_weights = get_batch(train_set, batch_size, buckets, bucket_id)
            _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, False)
            # Both accumulators are running means over one checkpoint window.
            step_time += (time.time() - start_time) / steps_per_checkpoint
            loss += step_loss / steps_per_checkpoint
            current_step += 1

            # Once in a while, we save checkpoint, print statistics, and run evals.
            if current_step % steps_per_checkpoint == 0:
                # Print statistics for the window that just finished; the
                # cut-off at 300 keeps math.exp from overflowing.
                perplexity = math.exp(float(loss)) if loss < 300 else float("inf")
                print ("global step %d learning rate %.4f step-time %.2f perplexity %.2f" % (model.global_step.eval(), model.learning_rate, step_time, perplexity))
                # Decrease learning rate if no improvement was seen over last 3 times.
#                 if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
#                     sess.run(model.learning_rate_decay_op)
                previous_losses.append(loss)
                # Save checkpoint and zero timer and loss.
                model.saver.save(sess, checkpoint_path, global_step=model.global_step)
                step_time, loss = 0.0, 0.0
                # Run evals on development set and print their perplexity.
                for eval_bucket_id in xrange(len(buckets)):
                    if len(dev_set[eval_bucket_id]) == 0:
                        print("  eval: empty bucket %d" % (eval_bucket_id))
                        continue
                    encoder_inputs, decoder_inputs, target_weights = get_batch(dev_set, batch_size, buckets, eval_bucket_id)
                    _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, eval_bucket_id, True)
                    eval_ppx = math.exp(float(eval_loss)) if eval_loss < 300 else float("inf")
                    print("  eval: bucket %d perplexity %.2f" % (eval_bucket_id, eval_ppx))
                sys.stdout.flush()
train()

In [28]:
# See some translations
# Saver used further down in this cell to restore a trained checkpoint.
# It is constructed with no arguments -- presumably that covers every
# variable of the current graph (TODO confirm for the TensorFlow version used).
saver = tf.train.Saver()

# initialize_vocabulary yields a pair per language; the second element
# (rev_vocab_*) is the one indexed by token id when sentences are printed.
vocab_en, rev_vocab_en = initialize_vocabulary(en_vocab_path)
vocab_fr, rev_vocab_fr = initialize_vocabulary(fr_vocab_path)

def buildSentence(ids, vocab, stop_at_eos=False):
    """Turn a sequence of token ids into a space-separated sentence.

    The '_PAD' and '_EOS' marker tokens are dropped at token level, so the
    result has no stray double or trailing spaces and a regular token that
    merely contains one of the markers as a substring is left intact.

    Args:
        ids: iterable of token ids (consumed once, so an iterator such as
            `reversed(...)` is fine).
        vocab: id -> token lookup supporting `vocab[i]` (list or dict).
        stop_at_eos: when True, everything from the first '_EOS' token on is
            discarded. Defaults to False, which keeps every non-marker token.

    Returns:
        The tokens joined by single spaces; '' for empty input.
    """
    markers = ('_PAD', '_EOS')
    words = []
    for token_id in ids:
        word = vocab[token_id]
        if stop_at_eos and word == '_EOS':
            break
        if word not in markers:
            words.append(word)
    return " ".join(words)

with tf.Session() as sess:

    saver.restore(sess, "wmt/translate.ckpt-41000")
    print("Model restored.")

    for bucket_id in xrange(len(buckets)):
        # Same guard as the eval loop in train(): nothing to sample from an
        # empty dev bucket.
        if len(dev_set[bucket_id]) == 0:
            print("  eval: empty bucket %d" % (bucket_id))
            continue
        encoder_inputs, decoder_inputs, target_weights = get_batch(dev_set, batch_size, buckets, bucket_id)
        _, loss, outputs = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
        # Transpose so that each row is one sentence (assumes get_batch
        # returns one entry per time step -- TODO confirm).
        english_sentences = np.array(encoder_inputs).T
        # Greedy choice: highest-scoring vocabulary entry at every position
        # (assumes outputs stack to (time, batch, vocab) -- TODO confirm).
        french_sentences = np.argmax(np.array(outputs), axis=2).T
        # Show at most 5 examples, never indexing past the batch.
        for i in xrange(min(5, len(english_sentences))):
            # The source ids are reversed before printing; presumably
            # get_batch feeds the encoder in reverse order -- TODO confirm.
            print(buildSentence(reversed(english_sentences[i]), rev_vocab_en))
            print(buildSentence(french_sentences[i], rev_vocab_fr))
            print("------------------------------")


Model restored.
_UNK .   
_UNK .  .   .   
------------------------------
Visit unique museums  
_UNK         
------------------------------
_UNK .   
_UNK .  .   .   
------------------------------
A disappointment .  
A .  .  .  .  
------------------------------
My _UNK identity crisis 
_UNK _UNK _UNK       
------------------------------
This is palliative care , given when there is nothing else that can be done .    
Il est donc de façon que les changements de la façon de la façon de façon de façon de façon .  .   .    
------------------------------
With _UNK , pay and sell without banks            
_UNK , _UNK et _UNK , _UNK                       
------------------------------
Many of the foreigners assert that receiving official status in our country is not that easy .   
En 0000 , les pays qui ne sont pas de plus de plus en plus de plus de plus en plus de plus en plus .  .  
------------------------------
The buyer and seller often find each other through friends .         


In [3]:
# We can load data in batches ... :)
# Demo cell: reads up to the given limit of training data into buckets and
# draws a single batch from bucket 0.
vocab_en, rev_vocab_en = initialize_vocabulary(en_vocab_path)
vocab_fr, rev_vocab_fr = initialize_vocabulary(fr_vocab_path)

# NOTE(review): batch_size and buckets rebind the notebook-level names that
# train() and the translation cell read as globals; re-running those cells
# after this one will use these smaller values.
batch_size = 10
buckets = [(20,30)]
# NOTE(review): en_train_ids_path / fr_train_ids_path are not assigned in any
# visible cell (prepare_wmt_data's results are bound to en_train / fr_train);
# this line presumably relies on leftover kernel state or a star import --
# verify it survives Restart & Run All.
# The last argument sits where max_train_data_size is passed elsewhere
# (presumably a cap on the number of lines read -- TODO confirm).
data = read_data(en_train_ids_path, fr_train_ids_path, buckets, 200000) # In buckets

# One batch from bucket 0, the only bucket defined in this cell.
batch_encoder_inputs, batch_decoder_inputs, batch_weights = get_batch(data, batch_size, buckets, 0)

# Disabled preview of the batch: rebuilds each sentence j by taking element j
# of every entry and looking the ids up in the reverse vocabularies.
# for j in range(batch_size):
#     sent = [rev_vocab_en[i] for i in reversed([w[j] for w in batch_encoder_inputs])]
#     sent_fr = [rev_vocab_fr[i] for i in [w[j] for w in batch_decoder_inputs]]
#     print " ".join(sent)
#     print " ".join(sent_fr[1:])
#     print "-------------------------------"

  reading data line 100000
  reading data line 200000
