In [None]:
import sys
from utils import *
from opennmt.utils.misc import count_lines

In [None]:
# Dataset locations. Corpora, their translation files (src_trans / tgt_trans)
# and vocabularies live under `data_dir`; the embedding files under `emb_dir`.
# Change these two prefixes to point the notebook at another dataset.
# NOTE: later cells rebind src/tgt (to input batches) and src_emb/tgt_emb (to
# embedding tensors), so re-run this cell before re-running those cells.
data_dir = './data/unsupervised-nmt-enfr-dev'
emb_dir = './data/unsupervised-nmt-enfr'

src = data_dir + '/train.en.10k'
tgt = data_dir + '/train.fr.10k'
src_trans = data_dir + '/train.en.10k.m1'
tgt_trans = data_dir + '/train.fr.10k.m1'
src_vocab_path = data_dir + '/en-vocab.txt'
tgt_vocab_path = data_dir + '/fr-vocab.txt'
src_emb = emb_dir + '/wmt14m.en300.vec'
tgt_emb = emb_dir + '/wmt14m.fr300.vec'
model_dir = './trained_model'

In [None]:
# Number of ids per language. The lookup tables built from these files use
# vocab_size=<size> - 1 and num_oov_buckets=1, i.e. every token missing from
# the vocabulary file maps to one extra id after the file's entries. The "+ 1"
# accounts for that OOV bucket, so the generators' output size covers it too.
src_vocab_size = count_lines(src_vocab_path) + 1  # +1 for the OOV bucket
tgt_vocab_size = count_lines(tgt_vocab_path) + 1  # +1 for the OOV bucket

In [None]:
def _make_vocab_table(vocab_path, vocab_size):
    """Build a token -> id lookup table from a vocabulary file.

    The first `vocab_size - 1` ids come from the file; any other token falls
    into a single OOV bucket whose id is `vocab_size - 1` (the last id).
    """
    return tf.contrib.lookup.index_table_from_file(
        vocab_path,
        vocab_size=vocab_size - 1,
        num_oov_buckets=1)

src_vocab = _make_vocab_table(src_vocab_path, src_vocab_size)
tgt_vocab = _make_vocab_table(tgt_vocab_path, tgt_vocab_size)

In [None]:
# Keep the input pipeline on the CPU.
# NOTE: `src` and `tgt` are rebound here from file paths to the batches
# returned by get_next(); re-running this cell on its own would hand those
# batches to load_data instead of paths, so re-run the path cell first.
with tf.device("/cpu:0"):
    src_iterator = load_data(src, src_trans, src_vocab, tgt_vocab)
    tgt_iterator = load_data(tgt, tgt_trans, tgt_vocab, src_vocab)
    src, tgt = src_iterator.get_next(), tgt_iterator.get_next()

In [None]:
def _scoped_embeddings(scope, embedding_path, vocab_path):
    """Call load_embeddings(embedding_path, vocab_path) inside variable scope `scope`."""
    with tf.variable_scope(scope):
        return load_embeddings(embedding_path, vocab_path)

# NOTE: src_emb / tgt_emb are rebound from file paths to whatever
# load_embeddings returns, so this cell is not safe to re-run on its own.
src_emb = _scoped_embeddings("src", src_emb, src_vocab_path)
tgt_emb = _scoped_embeddings("tgt", tgt_emb, tgt_vocab_path)

In [None]:
hidden_size = 512
num_layers = 2  # depth of both the encoder and the decoder

# One encoder and one decoder are shared by both languages; what differs per
# language is the embedding matrix passed in (src_emb / tgt_emb).
encoder = onmt.encoders.BidirectionalRNNEncoder(num_layers, hidden_size)
decoder = onmt.decoders.AttentionalRNNDecoder(num_layers, hidden_size, bridge=onmt.layers.CopyBridge())

# Auto-encoding inputs: each language's own sentences, noised and encoded.
# Only this first call passes reuse=None; every later call passes reuse=True,
# presumably so the shared encoder variables are created once and then reused.
src_encoder_auto = add_noise_and_encode(src["ids"], src["length"], src_emb, encoder, reuse=None)
tgt_encoder_auto = add_noise_and_encode(tgt["ids"], tgt["length"], tgt_emb, encoder, reuse=True)

# Cross-domain inputs: the translation stored with each batch (trans_ids),
# embedded with the other language's embeddings.
# NOTE(review): assumes tgt["trans_ids"] are source-language ids (and vice
# versa), consistent with load_data receiving (tgt_vocab, src_vocab) — confirm.
src_encoder_cross = add_noise_and_encode(tgt["trans_ids"], tgt["trans_length"], src_emb, encoder, reuse=True)
tgt_encoder_cross = add_noise_and_encode(src["trans_ids"], src["trans_length"], tgt_emb, encoder, reuse=True)

In [None]:
def _build_generator(scope, vocab_size):
    """Create a Dense projection to `vocab_size` outputs under variable scope `scope`.

    The layer is built immediately for inputs of depth `hidden_size`,
    presumably so its variables are created inside `scope`.
    """
    with tf.variable_scope(scope):
        generator = tf.layers.Dense(vocab_size)
        generator.build([None, hidden_size])
    return generator

src_gen = _build_generator("src", src_vocab_size)
tgt_gen = _build_generator("tgt", tgt_vocab_size)

# Auto-encoding losses: reconstruct each language from its own encoding.
# The shared decoder is created on the first call (reuse=None) and reused after.
l_auto_src = denoise(src, src_emb, src_encoder_auto, src_gen, decoder, reuse=None)
l_auto_tgt = denoise(tgt, tgt_emb, tgt_encoder_auto, tgt_gen, decoder, reuse=True)

# Cross-domain losses: reconstruct each language from the encoding of the
# translation that came with its batch.
l_cd_src = denoise(src, src_emb, tgt_encoder_cross, src_gen, decoder, reuse=True)
l_cd_tgt = denoise(tgt, tgt_emb, src_encoder_cross, tgt_gen, decoder, reuse=True)

In [None]:
# Batch size of the source corpus (kept as a notebook-level name).
batch_size = tf.shape(src["length"])[0]

# Encoder outputs fed to the discriminator, each paired with the language id
# of the sentences it encodes: 0 for the source language, 1 for the target.
# NOTE(review): output[0] is assumed batch-major (batch, time, ...) and
# output[2] the sequence lengths, as the time padding below implies — confirm
# against add_noise_and_encode.
all_encoder_outputs = [src_encoder_auto, src_encoder_cross, tgt_encoder_auto, tgt_encoder_cross]
all_lang_ids = [0, 0, 1, 1]

# Build the labels from each output's own batch dimension. src_encoder_cross
# and tgt_encoder_auto are fed from the tgt batch, not the src one, so a single
# shared batch_size would mislabel or mis-size `lang_ids` whenever the two
# batches differ (e.g. a final partial batch). Identical result when they match.
lang_ids = tf.concat(
    [tf.fill([tf.shape(output[0])[0]], lang_id)
     for output, lang_id in zip(all_encoder_outputs, all_lang_ids)], 0)

# Pad every encoding to the longest time dimension so they can be stacked
# along the batch axis.
max_time = tf.reduce_max([tf.shape(output[0])[1] for output in all_encoder_outputs])

encodings = tf.concat([pad_in_time(output[0], max_time - tf.shape(output[0])[1]) for output in all_encoder_outputs], 0)
sequence_lengths = tf.concat([output[2] for output in all_encoder_outputs], 0)

with tf.variable_scope("discriminator"):
    l_d, l_adv = discriminator(encodings, sequence_lengths, lang_ids)

# Weights of the three encoder/decoder loss terms.
lambda_auto = 1
lambda_cd = 1
lambda_adv = 1

l_auto = l_auto_src + l_auto_tgt
l_cd = l_cd_src + l_cd_tgt

l_final = (lambda_auto * l_auto + lambda_cd * l_cd + lambda_adv * l_adv)

In [None]:
# Split the trainable variables in tf.trainable_variables() order: those whose
# name starts with "discriminator" (the discriminator scope) versus all others.
all_trainable = tf.trainable_variables()
discri_variables = [v for v in all_trainable if v.name.startswith("discriminator")]
encdec_variables = [v for v in all_trainable if not v.name.startswith("discriminator")]

global_step = tf.train.get_or_create_global_step()
train_op = build_train_op(global_step, encdec_variables, discri_variables, l_final, l_d)

In [None]:
# Training loop. Every iteration runs train_op; only the logged losses
# alternate: even iterations report the encoder/decoder losses, odd ones the
# discriminator loss. Runs until the session asks to stop.
i = 0
with tf.train.MonitoredTrainingSession(checkpoint_dir=model_dir) as sess:
    # Initialize both dataset iterators before the first batch is fetched.
    sess.run([src_iterator.initializer, tgt_iterator.initializer])
    while not sess.should_stop():
        if i % 2 != 0:
            _, step, _l_d = sess.run([train_op, global_step, l_d])
            print("{} - l_d = {}".format(step, _l_d))
        else:
            fetched = sess.run([train_op, global_step, l_auto, l_cd, l_adv, l_final])
            _, step, _l_auto, _l_cd, _l_adv, _l = fetched
            print("{} - l_auto = {}; l_cd = {}, l_adv = {}; l = {}".format(step, _l_auto, _l_cd, _l_adv, _l))
        i += 1
        sys.stdout.flush()