In [1]:
import tensorflow as tf
import numpy as np
tf.enable_eager_execution()
import importlib
import os

from transliteration import data, train, model_one, script, decode

In [2]:
# Re-import the data module so edits made since the first import are picked up.
importlib.reload(data)

batch_size = 256

# Build the en -> ja train / valid / test datasets in a single pass. The three
# splits differ only in the TFRecord file they read, so the shared arguments
# are spelled out once; the datasets are created in the order listed.
train_dataset, valid_dataset, test_dataset = (
    data.make_dataset(tfrecord_path,
                      from_script='en',
                      to_script='ja',
                      batch_size=batch_size)
    for tfrecord_path in ('../data/tfrecord/train.tfrecord',
                          '../data/tfrecord/valid.tfrecord',
                          '../data/tfrecord/test.tfrecord')
)

In [9]:
batch_size = 256

# Build the en -> cmu train / valid / test datasets in a single pass. Only the
# TFRecord path changes between splits; the datasets are created in the order
# listed.
cmu_train_dataset, cmu_valid_dataset, cmu_test_dataset = (
    data.make_dataset(tfrecord_path,
                      from_script='en',
                      to_script='cmu',
                      batch_size=batch_size)
    for tfrecord_path in ('../data/tfrecord/cmu_train.tfrecord',
                          '../data/tfrecord/cmu_valid.tfrecord',
                          '../data/tfrecord/cmu_test.tfrecord')
)

In [10]:
# Reload model_one first so the Config / Encoder / Decoder classes used later in
# this cell come from the module's current source.
importlib.reload(model_one)
# Adam optimizer constructed with no arguments, i.e. TensorFlow's defaults.
optimizer = tf.train.AdamOptimizer()


def loss_function(real, pred):
    """Padding-masked sparse softmax cross-entropy, averaged over all positions.

    Args:
        real: Integer tensor of target token ids. Positions whose id is 0 are
            treated as padding and contribute zero loss.
        pred: Logits tensor accepted by
            `tf.nn.sparse_softmax_cross_entropy_with_logits` for `real`
            (i.e. `real`'s shape plus a trailing axis of per-class scores).

    Returns:
        Scalar tensor: the per-position cross-entropy with padded positions
        zeroed out, averaged over every position. Padded positions still count
        in the denominator, exactly as before, so the values remain comparable
        with previously recorded losses.
    """
    loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred)
    # Build the mask with TensorFlow ops instead of NumPy: this keeps the
    # computation on the tensor's device (no host round-trip per batch) and
    # lets the function work outside eager mode as well.
    mask = tf.cast(tf.not_equal(real, 0), loss_.dtype)
    return tf.reduce_mean(loss_ * mask)

# Encoder hyperparameters for the 'en' source script; attention_size is left
# unset (None) for the encoder.
encoder_config = model_one.Config(
    vocab_size=script.SCRIPTS['en'].vocab_size,
    lstm_size=120,
    embedding_size=60,
    attention_size=None,
)

# The 'ja' and 'cmu' decoders share every hyperparameter except the target
# vocabulary, so both configs are produced from one template (ja first).
decoder_config, cmu_decoder_config = (
    model_one.Config(
        vocab_size=script.SCRIPTS[target_script].vocab_size,
        lstm_size=80,
        embedding_size=60,
        attention_size=60,
        attention='monotonic_bahdanau',
    )
    for target_script in ('ja', 'cmu')
)

# One encoder feeds two independent decoders; construction order is kept as
# encoder, ja decoder, cmu decoder.
encoder = model_one.Encoder(encoder_config)
decoder = model_one.Decoder(decoder_config)
cmu_decoder = model_one.Decoder(cmu_decoder_config)

# A single checkpoint object tracks the optimizer and all three sub-models, so
# every save written under checkpoint_prefix captures the full training state.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
    optimizer=optimizer,
    encoder=encoder,
    decoder=decoder,
    cmu_decoder=cmu_decoder,
)

In [11]:
# Reload train so the loop below runs against the module's current source.
importlib.reload(train)

# Keyword arguments shared by every call in this stage: en -> cmu, using the
# shared encoder together with the CMU decoder.
_cmu_stage = dict(from_script='en',
                  to_script='cmu',
                  encoder=encoder,
                  decoder=cmu_decoder)

best_val_loss = None
for e in range(10):
    # The second positional argument is True for the training pass and False
    # for validation (presumably a "training" flag -- confirm in
    # train.run_one_epoch). Only the training pass is given the optimizer.
    loss = train.run_one_epoch(cmu_train_dataset,
                               True,
                               optimizer=optimizer,
                               loss_function=loss_function,
                               **_cmu_stage)
    valid_loss = train.run_one_epoch(cmu_valid_dataset,
                                     False,
                                     loss_function=loss_function,
                                     **_cmu_stage)
    # Save a checkpoint only when validation loss beats the best seen so far
    # in this loop (the first epoch always saves).
    if best_val_loss is None or valid_loss < best_val_loss:
        best_val_loss = valid_loss
        checkpoint.save(file_prefix=checkpoint_prefix)
    print("Epoch {}: Train Loss {:.3f}, Valid Loss {:.3f}".format(e, loss, valid_loss))
    # Qualitative spot check: 2-best beam-search output for one fixed word.
    print(decode.transliterate(input_strs=['derick'],
                               k_best=2,
                               decoding_method=decode.beam_search_decode,
                               **_cmu_stage))

Epoch 0: Train Loss 16.177, Valid Loss 7.804
([['D ER0 IH0 K K AH0 K', 'D ER0 IH0 K K AH0 K AH0 S']], array([[ -8.2832819 , -11.33341632]]))


Epoch 1: Train Loss 5.812, Valid Loss 4.820
([['D EH1 R IH0 K AH0 S', 'D EH1 R IH0 K AH0 K AH0 N']], array([[-6.2494958 , -8.99595803]]))


Epoch 2: Train Loss 4.102, Valid Loss 3.865
([['D EH1 R IH0 K AH0 S', 'D EH1 R IH0 K AH0 K AH0 N']], array([[-4.97263764, -7.36846613]]))


Epoch 3: Train Loss 3.398, Valid Loss 3.263
([['D EH1 R IH0 K AH0 K ER0', 'D EH1 R IH0 K AH0 K AH0 N']], array([[-6.09805015, -7.05461708]]))


Epoch 4: Train Loss 2.992, Valid Loss 3.024
([['D EH1 R IH0 K AH0 CH ER0', 'D EH1 R IH0 K AH0 K AH0 N']], array([[-5.87977612, -8.12592036]]))


Epoch 5: Train Loss 2.784, Valid Loss 2.739
([['D EH1 R IH0 K AH0 CH ER0', 'D EH1 R IH0 K AH0 L Z']], array([[-5.35534295, -5.89101115]]))


Epoch 6: Train Loss 2.526, Valid Loss 2.623
([['D EH1 R IH0 K AH0 CH ER0', 'D EH1 R IH0 K AH0 K AH0 L Z']], array([[-5.45755468, -9.41242491]]))


Epoch 7: Train Loss 2.376, Valid Loss 2.515
([['D EH1 R IH0 K AH0 CH ER0', 'D EH1 R IH0 K AH0 K AH0 L D ER0']], array([[ -5.4120316 , -11.06813771]]))


Epoch 8: Train Loss 2.272, Valid Loss 2.366
([['D EH1 R IH0 K AH0 CH ER0', 'D EH1 R IH0 K AH0 NG K ER0']], array([[-5.50152449, -6.60748948]]))


Epoch 9: Train Loss 2.169, Valid Loss 2.286
([['D EH1 R IH0 K ER0', 'D EH1 R IH0 K AH0 CH ER0']], array([[-3.5732613, -4.9583194]]))


In [12]:
# Reload train so the loop below runs against the module's current source.
importlib.reload(train)

# Keyword arguments shared by every call in this stage: en -> ja, using the
# shared encoder together with the ja decoder.
_ja_stage = dict(from_script='en',
                 to_script='ja',
                 encoder=encoder,
                 decoder=decoder)

best_val_loss = None
for e in range(5):
    # Training pass with train_encoder=False (presumably leaves the encoder
    # weights untouched so only the ja decoder is updated -- confirm in
    # train.run_one_epoch).
    loss = train.run_one_epoch(train_dataset,
                               True,
                               train_encoder=False,
                               optimizer=optimizer,
                               loss_function=loss_function,
                               **_ja_stage)
    # Validation pass: second positional argument False, no optimizer.
    valid_loss = train.run_one_epoch(valid_dataset,
                                     False,
                                     loss_function=loss_function,
                                     **_ja_stage)
    # Save a checkpoint only when validation loss beats the best seen so far
    # in this loop (the first epoch always saves).
    if best_val_loss is None or valid_loss < best_val_loss:
        best_val_loss = valid_loss
        checkpoint.save(file_prefix=checkpoint_prefix)
    print("Epoch {}: Train Loss {:.3f}, Valid Loss {:.3f}".format(e, loss, valid_loss))
    # Qualitative spot check: 2-best beam-search output for one fixed word.
    print(decode.transliterate(input_strs=['derick'],
                               k_best=2,
                               decoding_method=decode.beam_search_decode,
                               **_ja_stage))

Epoch 0: Train Loss 14.830, Valid Loss 10.378
([['デリック', 'デリクク']], array([[-4.42585427, -4.96765587]]))


Epoch 1: Train Loss 8.380, Valid Loss 7.523
([['デリック', 'デリックル']], array([[-3.11422116, -5.02311182]]))


Epoch 2: Train Loss 6.674, Valid Loss 6.537
([['デリック', 'デリックス']], array([[-3.37739228, -4.80727522]]))


Epoch 3: Train Loss 5.952, Valid Loss 6.151
([['デリック', 'デリックス']], array([[-2.88655928, -4.43254995]]))


Epoch 4: Train Loss 5.427, Valid Loss 5.988
([['デリック', 'デリックス']], array([[-2.59951185, -4.37641649]]))


In [35]:
# Reload train so the loop below runs against the module's current source.
importlib.reload(train)

# Keyword arguments shared by every call in this stage: en -> ja, using the
# shared encoder together with the ja decoder.
_ja_stage = dict(from_script='en',
                 to_script='ja',
                 encoder=encoder,
                 decoder=decoder)

best_val_loss = None
for e in range(5):
    # Training pass with train_encoder=True (presumably the encoder is updated
    # together with the ja decoder -- confirm in train.run_one_epoch).
    loss = train.run_one_epoch(train_dataset,
                               True,
                               train_encoder=True,
                               optimizer=optimizer,
                               loss_function=loss_function,
                               **_ja_stage)
    # Validation pass: second positional argument False, no optimizer.
    valid_loss = train.run_one_epoch(valid_dataset,
                                     False,
                                     loss_function=loss_function,
                                     **_ja_stage)
    # Save a checkpoint only when validation loss beats the best seen so far
    # in this loop (the first epoch always saves).
    if best_val_loss is None or valid_loss < best_val_loss:
        best_val_loss = valid_loss
        checkpoint.save(file_prefix=checkpoint_prefix)
    print("Epoch {}: Train Loss {:.3f}, Valid Loss {:.3f}".format(e, loss, valid_loss))
    # Qualitative spot check: 2-best beam-search output for one fixed word.
    print(decode.transliterate(input_strs=['derick'],
                               k_best=2,
                               decoding_method=decode.beam_search_decode,
                               **_ja_stage))

Epoch 0: Train Loss 3.061, Valid Loss 4.737
([['デリック', 'デリク']], array([[-2.0089123, -3.2298286]]))


Epoch 1: Train Loss 2.940, Valid Loss 4.718
([['デリック', 'デリク']], array([[-1.90779161, -3.51031125]]))


Epoch 2: Train Loss 2.837, Valid Loss 4.761
([['デリック', 'デリク']], array([[-1.91291335, -2.73543338]]))


Epoch 3: Train Loss 2.738, Valid Loss 4.881
([['デリック', 'デリク']], array([[-1.55329115, -2.82706814]]))


Epoch 4: Train Loss 2.653, Valid Loss 4.749
([['デリック', 'デリックス']], array([[-1.62591756, -3.16085919]]))


In [44]:
# Restore the most recently written checkpoint instead of a hard-coded
# "ckpt-<N>" suffix. The save counter depends on how many times the training
# cells have been executed in this kernel, so a fixed number breaks on a fresh
# Restart & Run All. Because the training loops only save when validation loss
# improves, the latest checkpoint is the best one of the last training loop run.
latest_checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint_path is None:
    raise FileNotFoundError(
        "No checkpoint found in {!r}; run a training cell first.".format(checkpoint_dir))
# assert_consumed() fails loudly if any saved value was left unmatched.
checkpoint.restore(latest_checkpoint_path).assert_consumed()

# Re-evaluate the restored model on the validation set; the call is the cell's
# final expression so the resulting loss is displayed.
train.run_one_epoch(valid_dataset,
                    False,
                    from_script='en',
                    to_script='ja',
                    encoder=encoder,
                    decoder=decoder,
                    loss_function=loss_function)

<tf.Tensor: id=129273772, shape=(), dtype=float32, numpy=4.719208>

In [46]:
# Reload decode so edits to the module take effect without a kernel restart.
importlib.reload(decode)

# 2-best beam-search transliteration of a single word, en -> ja. Kept as the
# cell's final expression so the notebook displays the returned value.
decode.transliterate(
    input_strs=['room'],
    from_script='en',
    to_script='ja',
    k_best=2,
    decoding_method=decode.beam_search_decode,
    encoder=encoder,
    decoder=decoder,
)

([['ルーム', 'ルームズ']], array([[-1.43902729, -3.161433  ]]))