In [1]:
import tensorflow as tf
import numpy as np
# Switch TensorFlow to eager execution immediately after importing it, before
# any other TensorFlow call made in this notebook.
tf.enable_eager_execution()
import importlib
import os

# Project-local modules. data, model_one and train are re-imported in later
# cells with importlib.reload.
from transliteration import data, train, model_one, script

In [2]:
# Re-import the data module so this kernel runs its current on-disk code.
importlib.reload(data)

batch_size = 128

# Build the train / validation / test datasets, in that order. The three calls
# use the same 'en' -> 'ja' script settings and batch size; only the TFRecord
# path differs.
train_dataset, valid_dataset, test_dataset = (
    data.make_dataset(tfrecord_path,
                      from_script='en',
                      to_script='ja',
                      batch_size=batch_size)
    for tfrecord_path in ('../data/tfrecord/train.tfrecord',
                          '../data/tfrecord/valid.tfrecord',
                          '../data/tfrecord/test.tfrecord')
)

In [3]:
# Adam optimizer constructed with no arguments, so it uses whatever defaults
# tf.train.AdamOptimizer defines. It is passed to the training pass and tracked
# by the checkpoint object in later cells.
optimizer = tf.train.AdamOptimizer()


def loss_function(real, pred):
    """Masked sparse softmax cross-entropy, averaged over every position.

    Args:
        real: integer class ids used as the labels. Positions whose id is 0 are
            given zero weight (presumably 0 is the padding id -- confirm
            against the data module).
        pred: unnormalised logits for the same positions, in the layout
            expected by ``tf.nn.sparse_softmax_cross_entropy_with_logits``.

    Returns:
        A scalar tensor: the mean of the per-position losses after zeroing the
        masked positions. The mean is taken over all positions, masked ones
        included, so the value scales exactly as it did before this change.
    """
    loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred)
    # Build the mask with TensorFlow ops instead of NumPy: this avoids pulling
    # the labels out of TensorFlow, does not depend on eager-only tensor/NumPy
    # interop, and makes the int -> float conversion explicit rather than
    # relying on implicit casting inside the multiply.
    mask = tf.cast(tf.not_equal(real, 0), loss_.dtype)
    return tf.reduce_mean(loss_ * mask)

In [4]:
# Re-import the model definitions so this kernel runs their current on-disk code.
importlib.reload(model_one)

# Both configs use the same lstm_size and attention_size; embedding_size and
# vocab_size are set separately for the 'en' (encoder) and 'ja' (decoder) side.
encoder_config = model_one.Config(
    lstm_size=60,
    embedding_size=60,
    attention_size=30,
    vocab_size=script.SCRIPTS['en'].vocab_size,
)
decoder_config = model_one.Config(
    lstm_size=60,
    embedding_size=30,
    attention_size=30,
    vocab_size=script.SCRIPTS['ja'].vocab_size,
)

# Instantiate the two networks from their configs.
encoder = model_one.Encoder(encoder_config)
decoder = model_one.Decoder(decoder_config)

In [5]:
# Checkpoints are saved under ./training_checkpoints using the "ckpt" file prefix.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# One checkpoint object tracks the optimizer and both networks together.
checkpoint = tf.train.Checkpoint(
    optimizer=optimizer,
    encoder=encoder,
    decoder=decoder,
)

In [6]:
# Re-import the training helpers so this kernel runs their current on-disk code.
importlib.reload(train)

# Lowest validation loss seen so far; None until the first epoch has finished.
best_val_loss = None

for epoch in range(200):
    # Pass over the training set. The second positional argument is True here
    # and False for validation (presumably the "training" switch -- confirm in
    # train.run_one_epoch); only this call is given the optimizer.
    loss = train.run_one_epoch(
        train_dataset,
        True,
        from_script='en',
        to_script='ja',
        encoder=encoder,
        decoder=decoder,
        optimizer=optimizer,
        loss_function=loss_function,
    )

    # Pass over the validation set with the same models and loss function.
    valid_loss = train.run_one_epoch(
        valid_dataset,
        False,
        from_script='en',
        to_script='ja',
        encoder=encoder,
        decoder=decoder,
        loss_function=loss_function,
    )

    # Save a checkpoint only when validation loss beats the best value so far
    # (the first epoch always saves).
    if best_val_loss is None or valid_loss < best_val_loss:
        best_val_loss = valid_loss
        checkpoint.save(file_prefix=checkpoint_prefix)

    print("Epoch {}: Train Loss {:.3f}, Valid Loss {:.3f}".format(epoch, loss, valid_loss))

    # Print the model's current output for one fixed sample input after
    # every epoch.
    print(train.transliterate(
        input_strs=['derick'],
        from_script='en',
        to_script='ja',
        encoder=encoder,
        decoder=decoder,
    ))

Instructions for updating:
Colocations handled automatically by placer.


Epoch 0: Train Loss 20.510, Valid Loss 19.166
['リーー']


Epoch 1: Train Loss 18.329, Valid Loss 17.981
['アリンルル']


Epoch 2: Train Loss 16.788, Valid Loss 16.224
['カラントラン']


Epoch 3: Train Loss 15.123, Valid Loss 14.806
['ディリート']


Epoch 4: Train Loss 13.907, Valid Loss 13.705
['ディルリークル']


Epoch 5: Train Loss 12.930, Valid Loss 12.694
['ディルクルク']


Epoch 6: Train Loss 11.978, Valid Loss 11.763
['ディルドリック']


Epoch 7: Train Loss 10.839, Valid Loss 10.789
['ディルドリック']


Epoch 8: Train Loss 9.815, Valid Loss 9.768
['ディリック']


Epoch 9: Train Loss 8.908, Valid Loss 8.883
['デリーク']


Epoch 10: Train Loss 8.141, Valid Loss 8.299
['デリーク']


Epoch 11: Train Loss 7.579, Valid Loss 7.862
['デリーク']


Epoch 12: Train Loss 7.109, Valid Loss 7.453
['デリーク']


Epoch 13: Train Loss 6.720, Valid Loss 7.386
['デリーク']


Epoch 14: Train Loss 6.398, Valid Loss 6.859
['デリック']


Epoch 15: Train Loss 6.105, Valid Loss 6.779
['デリック']


Epoch 16: Train Loss 5.914, Valid Loss 6.502
['デリック']


Epoch 17: Train Loss 5.682, Valid Loss 6.313
['デリック']


Epoch 18: Train Loss 5.552, Valid Loss 6.228
['デリック']


Epoch 19: Train Loss 5.382, Valid Loss 6.119
['デリック']


Epoch 20: Train Loss 5.189, Valid Loss 6.039
['デリック']


Epoch 21: Train Loss 5.073, Valid Loss 5.874
['デリック']


Epoch 22: Train Loss 4.928, Valid Loss 5.937
['デリーククク']


Epoch 23: Train Loss 4.850, Valid Loss 5.861
['デリック']


Epoch 24: Train Loss 4.693, Valid Loss 5.717
['デリック']


Epoch 25: Train Loss 4.606, Valid Loss 5.701
['デリック']


Epoch 26: Train Loss 4.508, Valid Loss 5.664
['デリック']


Epoch 27: Train Loss 4.435, Valid Loss 5.581
['デリック']


Epoch 28: Train Loss 4.332, Valid Loss 5.499
['デリック']


Epoch 29: Train Loss 4.241, Valid Loss 5.799
['デリック']


Epoch 30: Train Loss 4.188, Valid Loss 5.434
['デリック']


Epoch 31: Train Loss 4.082, Valid Loss 5.500
['デリック']


Epoch 32: Train Loss 4.024, Valid Loss 5.395
['デリック']


Epoch 33: Train Loss 3.970, Valid Loss 5.450
['デリック']


Epoch 34: Train Loss 3.899, Valid Loss 5.410
['デリック']


Epoch 35: Train Loss 3.821, Valid Loss 5.398
['デリック']


Epoch 36: Train Loss 3.746, Valid Loss 5.313
['デリック']


Epoch 37: Train Loss 3.684, Valid Loss 5.245
['デリック']


Epoch 38: Train Loss 3.643, Valid Loss 5.364
['デリック']


Epoch 39: Train Loss 3.601, Valid Loss 5.393
['デリック']


Epoch 40: Train Loss 3.538, Valid Loss 5.336
['デリック']


Epoch 41: Train Loss 3.455, Valid Loss 5.292
['デリック']


Epoch 42: Train Loss 3.411, Valid Loss 5.375
['デリック']


Epoch 43: Train Loss 3.369, Valid Loss 5.302
['デリック']


Epoch 44: Train Loss 3.360, Valid Loss 5.408
['デリック']


Epoch 45: Train Loss 3.333, Valid Loss 5.374
['デリック']


Epoch 46: Train Loss 3.270, Valid Loss 5.358
['デリック']


Epoch 47: Train Loss 3.216, Valid Loss 5.477
['デリック']


Epoch 48: Train Loss 3.218, Valid Loss 5.508
['デリック']


Epoch 49: Train Loss 3.113, Valid Loss 5.401
['デリック']


Epoch 50: Train Loss 3.095, Valid Loss 5.378
['デリック']


Epoch 51: Train Loss 3.073, Valid Loss 5.396
['デリック']


Epoch 52: Train Loss 2.996, Valid Loss 5.477
['デリック']


Epoch 53: Train Loss 2.958, Valid Loss 5.345
['デリック']


Epoch 54: Train Loss 2.900, Valid Loss 5.432
['デリック']


Epoch 55: Train Loss 2.862, Valid Loss 5.526
['デリック']


Epoch 56: Train Loss 2.816, Valid Loss 5.471
['デリック']


Epoch 57: Train Loss 2.848, Valid Loss 5.729
['デリック']


Epoch 58: Train Loss 2.813, Valid Loss 5.516
['デリック']


KeyboardInterrupt: 

In [23]:
# Re-import both helper modules so the call below runs their current on-disk code.
importlib.reload(train)
importlib.reload(data)

# Run the current encoder/decoder on a single sample input; as the cell's last
# expression, the returned value is what the notebook displays.
train.transliterate(
    input_strs=['bob'],
    from_script='en',
    to_script='ja',
    encoder=encoder,
    decoder=decoder,
)

['ボボボ']

In [42]:
# Display what data.intern_katakana_char returns for the '<end>' marker string
# (presumably its integer id in the katakana vocabulary -- confirm in the data module).
data.intern_katakana_char('<end>')

98