In [1]:
import tensorflow as tf
import numpy as np
tf.enable_eager_execution()
import importlib
import os

from transliteration import data, train, model_one, script, decode

In [2]:
importlib.reload(data)  # re-import so edits to the `data` module take effect
batch_size = 256

# Build the 'en' -> 'ja' splits in train / valid / test order; the three
# calls differ only in the TFRecord path they read.
train_dataset, valid_dataset, test_dataset = [
    data.make_dataset(tfrecord_path,
                      from_script='en',
                      to_script='ja',
                      batch_size=batch_size)
    for tfrecord_path in ('../data/tfrecord/train.tfrecord',
                          '../data/tfrecord/valid.tfrecord',
                          '../data/tfrecord/test.tfrecord')
]

In [2]:
# NOTE: this rebinds the notebook-wide `batch_size` (the 'ja' dataset cell
# sets it to 256), so any later cell reading `batch_size` sees 128.
batch_size = 128

# Build the 'en' -> 'cmu' splits in train / valid / test order; the three
# calls differ only in the TFRecord path they read.
cmu_train_dataset, cmu_valid_dataset, cmu_test_dataset = [
    data.make_dataset(tfrecord_path,
                      from_script='en',
                      to_script='cmu',
                      batch_size=batch_size)
    for tfrecord_path in ('../data/tfrecord/cmu_train.tfrecord',
                          '../data/tfrecord/cmu_valid.tfrecord',
                          '../data/tfrecord/cmu_test.tfrecord')
]

In [15]:
importlib.reload(model_one)  # re-import so edits to the `model_one` module take effect
optimizer = tf.train.AdamOptimizer()  # Adam with the library's default hyper-parameters


def loss_function(real, pred):
    """Masked sparse softmax cross-entropy reduced to a single scalar.

    Args:
        real: integer class ids used as labels. Positions whose id is 0 are
            masked out and contribute zero loss (presumably 0 is the padding
            id -- confirm against the data pipeline).
        pred: unnormalised logits for the same positions as `real`.

    Returns:
        A scalar tensor: the mean of the masked per-position losses. The mean
        is taken over *all* positions, masked ones included, so the
        denominator also counts positions whose label is 0.
    """
    loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred)
    # Build the mask with TensorFlow ops instead of NumPy: this avoids pulling
    # `real` back to the host on every call and keeps the function usable
    # outside eager execution. The values are identical to `1 - (real == 0)`.
    mask = tf.cast(tf.not_equal(real, 0), loss_.dtype)
    return tf.reduce_mean(loss_ * mask)

# Encoder over the 'en' vocabulary (no attention size is configured for it).
encoder_config = model_one.Config(
    vocab_size=script.SCRIPTS['en'].vocab_size,
    embedding_size=60,
    lstm_size=120,
    attention_size=None)
encoder = model_one.Encoder(encoder_config)

# Decoder over the 'ja' vocabulary, explicitly using 'monotonic_bahdanau' attention.
decoder_config = model_one.Config(
    vocab_size=script.SCRIPTS['ja'].vocab_size,
    embedding_size=60,
    lstm_size=80,
    attention_size=60,
    attention='monotonic_bahdanau')
decoder = model_one.Decoder(decoder_config)

# Decoder over the 'cmu' vocabulary; `attention` is not passed here, so
# whatever default `model_one.Config` defines is used.
cmu_decoder_config = model_one.Config(
    vocab_size=script.SCRIPTS['cmu'].vocab_size,
    embedding_size=60,
    lstm_size=80,
    attention_size=60)
cmu_decoder = model_one.Decoder(cmu_decoder_config)

# One checkpoint object tracks the optimizer, the encoder and both decoders,
# so every save/restore covers all four together.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
    optimizer=optimizer,
    encoder=encoder,
    decoder=decoder,
    cmu_decoder=cmu_decoder)

In [None]:
importlib.reload(train)  # re-import so edits to the `train` module take effect

# Fit on the 'en' -> 'cmu' data for up to 200 epochs, saving a checkpoint
# whenever the validation loss reaches a new minimum.
best_val_loss = None  # stays None until the first validation pass is scored
for e in range(200):
    # Pass over the training split (second positional argument True;
    # presumably the "is training" flag -- confirm in train.run_one_epoch).
    loss = train.run_one_epoch(
        cmu_train_dataset,
        True,
        encoder=encoder,
        decoder=cmu_decoder,
        from_script='en',
        to_script='cmu',
        loss_function=loss_function,
        optimizer=optimizer,
    )
    # Pass over the validation split (flag False, no optimizer supplied).
    valid_loss = train.run_one_epoch(
        cmu_valid_dataset,
        False,
        encoder=encoder,
        decoder=cmu_decoder,
        from_script='en',
        to_script='cmu',
        loss_function=loss_function,
    )

    # Only checkpoint on improvement, so the newest checkpoint is the best so far.
    if best_val_loss is None or valid_loss < best_val_loss:
        best_val_loss = valid_loss
        checkpoint.save(file_prefix=checkpoint_prefix)

    print("Epoch {}: Train Loss {:.3f}, Valid Loss {:.3f}".format(e, loss, valid_loss))
    # Qualitative check: the two best beam-search outputs for a fixed sample word.
    print(decode.transliterate(
        input_strs=['derick'],
        encoder=encoder,
        decoder=cmu_decoder,
        from_script='en',
        to_script='cmu',
        decoding_method=decode.beam_search_decode,
        k_best=2,
    ))

In [16]:
importlib.reload(train)  # re-import so edits to the `train` module take effect

# Fit on the 'en' -> 'ja' data for up to 200 epochs, saving a checkpoint
# whenever the validation loss reaches a new minimum.
best_val_loss = None  # stays None until the first validation pass is scored
for e in range(200):
    # Pass over the training split (second positional argument True;
    # presumably the "is training" flag -- confirm in train.run_one_epoch).
    # train_encoder=True is passed explicitly for this language pair.
    loss = train.run_one_epoch(
        train_dataset,
        True,
        encoder=encoder,
        decoder=decoder,
        from_script='en',
        to_script='ja',
        train_encoder=True,
        loss_function=loss_function,
        optimizer=optimizer,
    )
    # Pass over the validation split (flag False, no optimizer supplied).
    valid_loss = train.run_one_epoch(
        valid_dataset,
        False,
        encoder=encoder,
        decoder=decoder,
        from_script='en',
        to_script='ja',
        loss_function=loss_function,
    )

    # Only checkpoint on improvement, so the newest checkpoint is the best so far.
    if best_val_loss is None or valid_loss < best_val_loss:
        best_val_loss = valid_loss
        checkpoint.save(file_prefix=checkpoint_prefix)

    print("Epoch {}: Train Loss {:.3f}, Valid Loss {:.3f}".format(e, loss, valid_loss))
    # Qualitative check: the two best beam-search outputs for a fixed sample word.
    print(decode.transliterate(
        input_strs=['derick'],
        encoder=encoder,
        decoder=decoder,
        from_script='en',
        to_script='ja',
        decoding_method=decode.beam_search_decode,
        k_best=2,
    ))

Epoch 0: Train Loss 21.233, Valid Loss 19.356
([['ーー', 'ンー']], array([[-10.15840602, -10.16695213]]))


Epoch 1: Train Loss 18.729, Valid Loss 18.756
([['アルーー', 'アルーーー']], array([[-17.87236977, -20.87293935]]))


Epoch 2: Train Loss 18.207, Valid Loss 18.049
([['アランラン', 'アリンラン']], array([[-20.51289344, -20.73459411]]))


Epoch 3: Train Loss 17.319, Valid Loss 17.099
([['アラント', 'アリント']], array([[-15.59250736, -15.73684573]]))


Epoch 4: Train Loss 15.924, Valid Loss 15.251
([['ディリータート', 'ディアラータート']], array([[-22.44977331, -25.17119241]]))


Epoch 5: Train Loss 14.076, Valid Loss 13.448
([['ディリックト', 'ディレックト']], array([[-13.81076324, -13.97866535]]))


Epoch 6: Train Loss 12.183, Valid Loss 12.144
([['ディックト', 'ディリックト']], array([[ -9.55809736, -11.36953717]]))


Epoch 7: Train Loss 10.649, Valid Loss 10.246
([['ドリック', 'ディリック']], array([[-7.21610558, -8.3488006 ]]))


Epoch 8: Train Loss 9.551, Valid Loss 9.673
([['ドリック', 'ディック']], array([[-5.58030021, -6.59443706]]))


Epoch 9: Train Loss 8.830, Valid Loss 8.823
([['ディック', 'ドリック']], array([[-5.26165286, -5.88261306]]))


Epoch 10: Train Loss 8.088, Valid Loss 8.036
([['ディック', 'デリック']], array([[-5.06439063, -5.16589776]]))


Epoch 11: Train Loss 7.386, Valid Loss 7.452
([['ディック', 'ディリック']], array([[-4.80800387, -5.33656245]]))


Epoch 12: Train Loss 6.938, Valid Loss 7.324
([['デリック', 'ディック']], array([[-4.12823898, -4.91474569]]))


Epoch 13: Train Loss 6.660, Valid Loss 6.921
([['デリック', 'ディック']], array([[-3.88129631, -3.95454511]]))


Epoch 14: Train Loss 6.299, Valid Loss 6.613
([['デリック', 'ディック']], array([[-3.41025676, -4.54222052]]))


Epoch 15: Train Loss 6.018, Valid Loss 6.360
([['デリック', 'デリク']], array([[-3.21972992, -4.00256772]]))


Epoch 16: Train Loss 5.770, Valid Loss 6.254
([['デリック', 'デリス']], array([[-2.8981768, -4.1935067]]))


Epoch 17: Train Loss 5.559, Valid Loss 6.168
([['デリック', 'デリク']], array([[-2.7408665 , -3.82106943]]))


Epoch 18: Train Loss 5.383, Valid Loss 5.923
([['デリック', 'デリク']], array([[-2.51655251, -3.69024467]]))


Epoch 19: Train Loss 5.235, Valid Loss 5.826
([['デリック', 'デリク']], array([[-2.33607227, -3.59429201]]))


Epoch 20: Train Loss 5.113, Valid Loss 5.792
([['デリック', 'デリク']], array([[-2.31403704, -3.71934621]]))


Epoch 21: Train Loss 4.943, Valid Loss 5.666
([['デリック', 'デリク']], array([[-2.30544494, -3.50285442]]))


Epoch 22: Train Loss 4.805, Valid Loss 5.586
([['デリック', 'デリク']], array([[-2.11921783, -3.3709477 ]]))


Epoch 23: Train Loss 4.667, Valid Loss 5.509
([['デリック', 'デリク']], array([[-2.57041283, -3.54113092]]))


Epoch 24: Train Loss 4.518, Valid Loss 5.397
([['デリック', 'デリク']], array([[-2.03805599, -3.50342301]]))


Epoch 25: Train Loss 4.440, Valid Loss 5.440
([['デリック', 'デリックス']], array([[-1.98899063, -3.46273818]]))


Epoch 26: Train Loss 4.310, Valid Loss 5.323
([['デリック', 'デリックス']], array([[-1.94289263, -3.15559326]]))


Epoch 27: Train Loss 4.176, Valid Loss 5.329
([['デリック', 'デリックス']], array([[-2.01173321, -3.28825078]]))


Epoch 28: Train Loss 4.093, Valid Loss 5.239
([['デリック', 'デリックス']], array([[-1.83368374, -3.17345879]]))


Epoch 29: Train Loss 3.984, Valid Loss 5.249
([['デリック', 'デリックス']], array([[-1.91276402, -3.02227522]]))


KeyboardInterrupt: 

In [37]:
# Restore the most recently written checkpoint rather than a hard-coded save
# index: the index depends on how many times `checkpoint.save` ran in this
# kernel session, so a literal such as '-22' is not reproducible on a fresh
# run. The training loops only save when validation loss improves, so the
# newest checkpoint is also the best one from the last training run.
# To pin a specific checkpoint, pass its path (e.g. checkpoint_prefix + '-22')
# to checkpoint.restore instead.
latest_checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint_path is None:
    # Fail loudly: restoring from None would silently leave the weights untouched.
    raise FileNotFoundError(
        "No checkpoint found in {!r}; run a training cell first.".format(checkpoint_dir))
checkpoint.restore(latest_checkpoint_path)

<tensorflow.python.training.checkpointable.util.CheckpointLoadStatus at 0x7f430fa68d30>

In [34]:
importlib.reload(decode)  # re-import so edits to the `decode` module take effect

# Two best beam-search candidates for one 'en' input using the 'ja' decoder.
# Left as the cell's final expression so the notebook displays the result.
decode.transliterate(
    input_strs=['armor'],
    encoder=encoder,
    decoder=decoder,
    from_script='en',
    to_script='ja',
    decoding_method=decode.beam_search_decode,
    k_best=2,
)

([['アルモール', 'アーモーロー']], array([[-5.08133715, -5.74089668]]))