In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
tf.enable_eager_execution()
import importlib
import os

from matplotlib import pyplot as plt
%matplotlib notebook

from transliteration import data, train, model_one, script, decode, evaluate

In [2]:
# All four splits are read with identical options (script pair 'en' -> 'ja'
# and one shared batch size); only the record file differs between them.
batch_size = 128
_en_ja_options = dict(from_script='en', to_script='ja', batch_size=batch_size)

# "our_2" train / validation splits.
our_train_dataset = data.make_dataset(
    '../data/tfrecord/our_2_train.tfrecord', **_en_ja_options)
our_valid_dataset = data.make_dataset(
    '../data/tfrecord/our_2_valid.tfrecord', **_en_ja_options)

# Additional validation-only sets ("eob" and "muse").
eob_valid_dataset = data.make_dataset(
    '../data/tfrecord/eob_valid.tfrecord', **_en_ja_options)
muse_valid_dataset = data.make_dataset(
    '../data/tfrecord/muse_valid.tfrecord', **_en_ja_options)

In [3]:
# Adam optimizer constructed with TensorFlow's default arguments.
optimizer = tf.train.AdamOptimizer()

def loss_function(real, pred):
    """Masked sparse softmax cross-entropy, averaged over every position.

    Args:
        real: integer label ids. Positions whose label id is 0 are masked out
            (presumably 0 is the padding id -- confirm against the `script`
            vocabularies).
        pred: unnormalised logits passed to
            `tf.nn.sparse_softmax_cross_entropy_with_logits` together with
            `real`.

    Returns:
        Scalar tensor: the mean of the per-position losses after masked
        positions have been set to 0. Masked positions still count in the
        denominator of the mean.
    """
    loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred)
    # Build the mask with TensorFlow ops rather than NumPy so the labels are
    # not converted to a NumPy array on every call; the mask is 1 where the
    # label is non-zero and 0 elsewhere, in the same dtype as the loss.
    mask = tf.cast(tf.not_equal(real, 0), loss_.dtype)
    return tf.reduce_mean(loss_ * mask)

# Sizes used by both the encoder and the decoder configuration.
_common_sizes = dict(lstm_size=120, embedding_size=30)

# The encoder is configured with the 'en' vocabulary size and no attention
# size; the decoder with the 'ja' vocabulary size and an attention size of 60.
ja_encoder_config = model_one.Config(vocab_size=script.SCRIPTS['en'].vocab_size,
                                     attention_size=None,
                                     **_common_sizes)
ja_decoder_config = model_one.Config(vocab_size=script.SCRIPTS['ja'].vocab_size,
                                     attention_size=60,
                                     **_common_sizes)

ja_encoder = model_one.Encoder(ja_encoder_config)
ja_decoder = model_one.Decoder(ja_decoder_config)

# Checkpoint files are written under `checkpoint_dir` with the "ckpt" prefix.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# The tf.train.Checkpoint object itself is created later, by the training
# loop, the first time a checkpoint has to be saved.
checkpoint = None

def make_checkpoint_obj():
    """Return a `tf.train.Checkpoint` tracking the optimizer and both models.

    The objects are registered under the names `optimizer`, `ja_encoder` and
    `ja_decoder`, taken from the notebook-level variables of the same names at
    call time.
    """
    tracked = {
        'optimizer': optimizer,
        'ja_encoder': ja_encoder,
        'ja_decoder': ja_decoder,
    }
    return tf.train.Checkpoint(**tracked)

In [4]:
# Early-selection state shared by the training cells below:
#   best_val_loss -- lowest validation loss seen so far (None before training)
#   ja_checkpoint -- path returned by the most recent checkpoint save
#                    (None until a checkpoint has been written)
best_val_loss = ja_checkpoint = None

In [5]:
# Keyword arguments shared by every training, validation and decoding call
# in this cell.
_en_ja_model = dict(from_script='en',
                    to_script='ja',
                    encoder=ja_encoder,
                    decoder=ja_decoder)

for e in range(15):
    # One pass over the training set (second argument True, optimizer given),
    # then one pass over the validation set (second argument False, no
    # optimizer). The flag is presumably "is training" -- confirm against
    # train.run_one_epoch.
    loss = train.run_one_epoch(our_train_dataset,
                               True,
                               optimizer=optimizer,
                               loss_function=loss_function,
                               **_en_ja_model)
    valid_loss = train.run_one_epoch(our_valid_dataset,
                                     False,
                                     loss_function=loss_function,
                                     **_en_ja_model)

    # Save a checkpoint only when the validation loss improves on the best
    # value seen so far; the checkpoint object is created on first use.
    improved = best_val_loss is None or valid_loss < best_val_loss
    if improved:
        best_val_loss = valid_loss
        if checkpoint is None:
            checkpoint = make_checkpoint_obj()
        ja_checkpoint = checkpoint.save(file_prefix=checkpoint_prefix)

    summary = "Epoch {}: Train Loss {:.3f}, Valid Loss {:.3f}".format(e, loss, valid_loss)
    print(summary)

    # Qualitative probe: the two best beam-search outputs for one fixed input.
    sample = decode.transliterate(input_strs=['derick'],
                                  k_best=2,
                                  decoding_method=decode.beam_search_decode,
                                  **_en_ja_model)
    print(sample)

Instructions for updating:
Colocations handled automatically by placer.


Epoch 0: Train Loss 18.893, Valid Loss 17.194
([['', 'ル']], array([[-4.09070444, -8.1306653 ]]))


Epoch 1: Train Loss 15.474, Valid Loss 14.672
([['', 'ー']], array([[-3.96696663, -7.25935888]]))


Epoch 2: Train Loss 14.238, Valid Loss 13.996
([['ー', 'ン']], array([[-6.95424581, -7.10179305]]))


Epoch 3: Train Loss 13.651, Valid Loss 13.575
([['ーー', 'ーン']], array([[-9.0966785 , -9.29305768]]))


Epoch 4: Train Loss 13.300, Valid Loss 13.389
([['ーー', 'ーン']], array([[-8.46067882, -8.61680281]]))


Epoch 5: Train Loss 13.164, Valid Loss 13.219
([['ーー', 'リー']], array([[-8.97501707, -9.21134734]]))


Epoch 6: Train Loss 13.005, Valid Loss 13.053
([['ーー', 'ラー']], array([[-9.38374758, -9.49531746]]))


Epoch 7: Train Loss 12.904, Valid Loss 13.117
([['ーー', 'ラーン']], array([[ -9.63600779, -10.23621213]]))


Epoch 8: Train Loss 12.845, Valid Loss 12.919
([['ーー', 'ラーー']], array([[ -9.61942315, -10.25515229]]))


Epoch 9: Train Loss 12.758, Valid Loss 12.997
([['ラー', 'リーン']], array([[-9.43848062, -9.77764022]]))


Epoch 10: Train Loss 12.644, Valid Loss 13.121
([['リーン', 'リーー']], array([[-10.08886331, -10.18286264]]))


Epoch 11: Train Loss 12.535, Valid Loss 12.360
([['リー', 'ラー']], array([[-9.50044107, -9.51245451]]))


Epoch 12: Train Loss 12.444, Valid Loss 12.439
([['リー', 'リーン']], array([[-9.25696707, -9.55111781]]))


Epoch 13: Train Loss 12.323, Valid Loss 12.569
([['リー', 'リール']], array([[-8.90347123, -9.25400236]]))


Epoch 14: Train Loss 12.290, Valid Loss 12.488
([['リー', 'リン']], array([[-8.347754  , -8.82998943]]))


In [6]:
# Continuation of the previous training cell: same loop body, epochs 15-99.
# Keyword arguments shared by every training, validation and decoding call
# in this cell.
_en_ja_model = dict(from_script='en',
                    to_script='ja',
                    encoder=ja_encoder,
                    decoder=ja_decoder)

for e in range(15, 100):
    # One pass over the training set (second argument True, optimizer given),
    # then one pass over the validation set (second argument False, no
    # optimizer). The flag is presumably "is training" -- confirm against
    # train.run_one_epoch.
    loss = train.run_one_epoch(our_train_dataset,
                               True,
                               optimizer=optimizer,
                               loss_function=loss_function,
                               **_en_ja_model)
    valid_loss = train.run_one_epoch(our_valid_dataset,
                                     False,
                                     loss_function=loss_function,
                                     **_en_ja_model)

    # Save a checkpoint only when the validation loss improves on the best
    # value seen so far; the checkpoint object is created on first use.
    improved = best_val_loss is None or valid_loss < best_val_loss
    if improved:
        best_val_loss = valid_loss
        if checkpoint is None:
            checkpoint = make_checkpoint_obj()
        ja_checkpoint = checkpoint.save(file_prefix=checkpoint_prefix)

    summary = "Epoch {}: Train Loss {:.3f}, Valid Loss {:.3f}".format(e, loss, valid_loss)
    print(summary)

    # Qualitative probe: the two best beam-search outputs for one fixed input.
    sample = decode.transliterate(input_strs=['derick'],
                                  k_best=2,
                                  decoding_method=decode.beam_search_decode,
                                  **_en_ja_model)
    print(sample)

Epoch 15: Train Loss 12.236, Valid Loss 12.349
([['リー', 'リン']], array([[-7.69474316, -8.27618408]]))


Epoch 16: Train Loss 12.216, Valid Loss 12.339
([['リー', 'リン']], array([[-7.87511623, -8.41008997]]))


Epoch 17: Train Loss 12.089, Valid Loss 11.833
([['リー', 'リン']], array([[-7.43134153, -8.0247581 ]]))


Epoch 18: Train Loss 12.037, Valid Loss 11.693
([['リー', 'リン']], array([[-7.06350589, -7.84574306]]))


Epoch 19: Train Loss 11.987, Valid Loss 11.993
([['ロー', 'ロン']], array([[-7.59558558, -8.03289866]]))


Epoch 20: Train Loss 11.935, Valid Loss 11.604
([['リー', 'ロー']], array([[-6.95484495, -7.14413989]]))


Epoch 21: Train Loss 11.896, Valid Loss 12.450
([['リー', 'リール']], array([[-7.45183396, -7.86078775]]))


Epoch 22: Train Loss 11.857, Valid Loss 12.397
([['リー', 'コー']], array([[-6.94243491, -7.55897105]]))


Epoch 23: Train Loss 11.787, Valid Loss 12.353
([['コー', 'ラン']], array([[-7.59818316, -7.82518053]]))


Epoch 24: Train Loss 11.714, Valid Loss 12.020
([['コー', 'ラン']], array([[-7.22297704, -7.36407614]]))


Epoch 25: Train Loss 11.682, Valid Loss 11.570
([['ロー', 'コー']], array([[-7.09749436, -7.50068533]]))


Epoch 26: Train Loss 11.602, Valid Loss 12.327
([['ラント', 'コー']], array([[-7.45920898, -7.59539366]]))


Epoch 27: Train Loss 11.548, Valid Loss 12.112
([['コー', 'ラン']], array([[-7.15042198, -7.29488742]]))


Epoch 28: Train Loss 11.476, Valid Loss 11.318
([['ロー', 'ラント']], array([[-7.07994783, -7.37113078]]))


Epoch 29: Train Loss 11.409, Valid Loss 11.255
([['ロー', 'ロール']], array([[-7.0559473, -7.292904 ]]))


Epoch 30: Train Loss 11.353, Valid Loss 12.420
([['スラン', 'ラント']], array([[-6.94576689, -7.36596085]]))


Epoch 31: Train Loss 11.289, Valid Loss 11.955
([['スラン', 'ロー']], array([[-6.91233492, -7.01830816]]))


Epoch 32: Train Loss 11.195, Valid Loss 11.864
([['スラン', 'ランド']], array([[-6.90497187, -7.326225  ]]))


Epoch 33: Train Loss 11.139, Valid Loss 12.182
([['スラン', 'コール']], array([[-6.70632673, -6.96020754]]))


Epoch 34: Train Loss 11.055, Valid Loss 11.437
([['スラン', 'コール']], array([[-6.73074523, -7.11060132]]))


Epoch 35: Train Loss 11.032, Valid Loss 11.360
([['サン', 'スラン']], array([[-6.37348509, -6.62938705]]))


Epoch 36: Train Loss 10.992, Valid Loss 11.132
([['ディス', 'ディル']], array([[-6.80133558, -6.95877632]]))


Epoch 37: Train Loss 10.860, Valid Loss 11.922
([['サン', 'サンド']], array([[-6.30212533, -6.6556958 ]]))


Epoch 38: Train Loss 10.776, Valid Loss 11.936
([['サン', 'サンド']], array([[-6.59302294, -6.63984786]]))


Epoch 39: Train Loss 10.670, Valid Loss 11.887
([['サン', 'ディス']], array([[-6.54803514, -6.57300754]]))


Epoch 40: Train Loss 10.591, Valid Loss 11.461
([['サンド', 'サント']], array([[-6.56792417, -6.91832024]]))


Epoch 41: Train Loss 10.515, Valid Loss 11.958
([['ディス', 'ディア']], array([[-6.22172285, -6.25650353]]))


Epoch 42: Train Loss 10.428, Valid Loss 12.225
([['ディス', 'ディア']], array([[-6.26754711, -6.39900935]]))


Epoch 43: Train Loss 10.369, Valid Loss 10.922
([['ディス', 'ディア']], array([[-6.5732631, -6.6288981]]))


Epoch 44: Train Loss 10.277, Valid Loss 11.579
([['ディル', 'ディア']], array([[-6.11214685, -6.19505339]]))


Epoch 45: Train Loss 10.225, Valid Loss 11.822
([['デック', 'ディア']], array([[-5.81743953, -6.35966944]]))


Epoch 46: Train Loss 10.101, Valid Loss 11.460
([['デック', 'ディル']], array([[-5.71197397, -6.09957897]]))


Epoch 47: Train Loss 10.022, Valid Loss 11.722
([['デック', 'ドック']], array([[-5.66418384, -6.09656483]]))


Epoch 48: Train Loss 9.916, Valid Loss 11.694
([['デック', 'ディル']], array([[-5.48117416, -5.68262013]]))


Epoch 49: Train Loss 9.839, Valid Loss 11.546
([['デック', 'ディア']], array([[-5.24662854, -5.82407024]]))


Epoch 50: Train Loss 9.735, Valid Loss 11.614
([['ドック', 'ディル']], array([[-5.76182419, -5.82110551]]))


Epoch 51: Train Loss 9.617, Valid Loss 11.379
([['デック', 'ディル']], array([[-5.2168895 , -5.89932529]]))


Epoch 52: Train Loss 9.531, Valid Loss 11.698
([['デック', 'ディル']], array([[-5.38141933, -5.80943595]]))


Epoch 53: Train Loss 9.409, Valid Loss 11.089
([['デック', 'ディル']], array([[-5.01805422, -5.53385383]]))


Epoch 54: Train Loss 9.327, Valid Loss 11.112
([['デック', 'ディル']], array([[-4.8826919 , -5.41494825]]))


Epoch 55: Train Loss 9.223, Valid Loss 11.190
([['デック', 'ダール']], array([[-4.91599151, -5.53230841]]))


Epoch 56: Train Loss 9.087, Valid Loss 11.695
([['デック', 'ディル']], array([[-4.67645852, -5.13162861]]))


Epoch 57: Train Loss 8.947, Valid Loss 10.988
([['デック', 'ディル']], array([[-4.74903919, -5.56818378]]))


Epoch 58: Train Loss 8.814, Valid Loss 12.980
([['デック', 'ディル']], array([[-4.64177781, -5.32941012]]))


Epoch 59: Train Loss 8.686, Valid Loss 12.048
([['デック', 'ダール']], array([[-4.53887695, -5.42842088]]))


Epoch 60: Train Loss 8.541, Valid Loss 12.003
([['デック', 'ディル']], array([[-4.35450761, -5.16484491]]))


Epoch 61: Train Loss 8.393, Valid Loss 11.623
([['デック', 'ダール']], array([[-4.23731735, -5.31813067]]))


Epoch 62: Train Loss 8.231, Valid Loss 12.106
([['デック', 'ダール']], array([[-4.24920981, -5.42189418]]))


Epoch 63: Train Loss 8.065, Valid Loss 12.194
([['デック', 'テック']], array([[-4.04712906, -4.85415128]]))


Epoch 64: Train Loss 7.899, Valid Loss 12.226
([['デック', 'ダール']], array([[-3.92812695, -5.35622302]]))


Epoch 65: Train Loss 7.736, Valid Loss 11.757
([['デック', 'デックス']], array([[-3.95589676, -5.59148896]]))


Epoch 66: Train Loss 7.611, Valid Loss 11.413
([['デック', 'ダーク']], array([[-3.73388763, -4.75318233]]))


Epoch 67: Train Loss 7.399, Valid Loss 11.808
([['デック', 'ダーク']], array([[-3.72477344, -5.30872452]]))


Epoch 68: Train Loss 7.211, Valid Loss 13.555
([['デック', 'ダーク']], array([[-3.77137035, -4.78121322]]))


Epoch 69: Train Loss 7.043, Valid Loss 11.803
([['デック', 'ダーク']], array([[-3.32412944, -4.3202281 ]]))


Epoch 70: Train Loss 6.875, Valid Loss 11.032
([['デック', 'デーク']], array([[-3.46466473, -4.98483496]]))


Epoch 71: Train Loss 6.695, Valid Loss 12.002
([['デック', 'ダール']], array([[-3.81044066, -4.40340318]]))


Epoch 72: Train Loss 6.482, Valid Loss 12.620
([['デック', 'デーク']], array([[-3.5689491 , -5.17268448]]))


Epoch 73: Train Loss 6.280, Valid Loss 12.274
([['デック', 'ダール']], array([[-3.29405502, -5.06248887]]))


Epoch 74: Train Loss 6.095, Valid Loss 13.538
([['デック', 'ダーク']], array([[-3.53195053, -5.05824715]]))


Epoch 75: Train Loss 5.868, Valid Loss 12.090
([['デック', 'ダーク']], array([[-3.29450636, -4.17080583]]))


Epoch 76: Train Loss 5.746, Valid Loss 13.225
([['デック', 'シーク']], array([[-3.54097803, -4.31663749]]))


Epoch 77: Train Loss 5.573, Valid Loss 13.365
([['デック', 'シーク']], array([[-3.63117349, -4.10899925]]))


Epoch 78: Train Loss 5.364, Valid Loss 13.209
([['デック', 'ダーク']], array([[-3.02241894, -4.75569341]]))


Epoch 79: Train Loss 5.145, Valid Loss 13.291
([['デック', 'タール']], array([[-3.48343081, -3.88947146]]))


Epoch 80: Train Loss 4.934, Valid Loss 12.724
([['デック', 'テック']], array([[-3.3690614 , -4.24862789]]))


Epoch 81: Train Loss 4.763, Valid Loss 15.124
([['デック', 'シーク']], array([[-3.19129921, -4.3011933 ]]))


Epoch 82: Train Loss 4.597, Valid Loss 13.368
([['デック', 'シーク']], array([[-2.96625765, -4.00626237]]))


Epoch 83: Train Loss 4.421, Valid Loss 13.584
([['デック', 'シーク']], array([[-3.01486422, -4.13360364]]))


Epoch 84: Train Loss 4.319, Valid Loss 15.529
([['デック', 'ダール']], array([[-3.48407577, -4.14816164]]))


Epoch 85: Train Loss 4.137, Valid Loss 15.065
([['デック', 'ダール']], array([[-3.0029078, -3.6989931]]))


Epoch 86: Train Loss 3.926, Valid Loss 13.110
([['デック', 'シール']], array([[-3.36067244, -3.91971962]]))


Epoch 87: Train Loss 3.824, Valid Loss 15.008
([['デック', 'デルク']], array([[-3.10455038, -3.79842658]]))


Epoch 88: Train Loss 3.655, Valid Loss 15.695
([['デック', 'シーク']], array([[-3.2915145 , -4.12747113]]))


Epoch 89: Train Loss 3.469, Valid Loss 14.852
([['デック', 'デルク']], array([[-3.04959111, -3.30842804]]))


Epoch 90: Train Loss 3.291, Valid Loss 15.658
([['デック', 'シール']], array([[-3.48205354, -3.67319681]]))


Epoch 91: Train Loss 3.124, Valid Loss 15.851
([['デック', 'シール']], array([[-2.86882764, -3.59103976]]))


Epoch 92: Train Loss 3.015, Valid Loss 15.113
([['デック', 'シール']], array([[-3.41172934, -3.72962239]]))


Epoch 93: Train Loss 2.870, Valid Loss 14.266
([['デック', 'シルク']], array([[-3.17465048, -3.62027015]]))


Epoch 94: Train Loss 2.767, Valid Loss 15.368
([['デック', 'シール']], array([[-3.05993757, -3.12989264]]))


Epoch 95: Train Loss 2.603, Valid Loss 15.524
([['デック', 'シック']], array([[-3.530673  , -3.74382568]]))


Epoch 96: Train Loss 2.481, Valid Loss 15.711
([['デック', 'シール']], array([[-3.01422212, -3.3645761 ]]))


Epoch 97: Train Loss 2.376, Valid Loss 14.509
([['シール', 'シック']], array([[-3.18444835, -3.29695437]]))


Epoch 98: Train Loss 2.288, Valid Loss 17.458
([['シール', 'シルク']], array([[-3.2040583 , -3.24740525]]))


Epoch 99: Train Loss 2.151, Valid Loss 16.061
([['デック', 'シック']], array([[-2.91614201, -3.34366792]]))


In [7]:
# Load the checkpoint written at the last validation-loss improvement and
# fail loudly if the restore was not fully consumed.
_restore_status = checkpoint.restore(ja_checkpoint)
_restore_status.assert_consumed()

# Validation loss of the restored model on the "our_2" validation split.
# NOTE(review): the output recorded with this notebook (~11.94) differs from
# the lowest validation loss printed during training (10.922 at epoch 43) --
# confirm whether a validation pass is deterministic for fixed weights.
_restored_valid_loss = train.run_one_epoch(our_valid_dataset,
                                           False,
                                           loss_function=loss_function,
                                           from_script='en',
                                           to_script='ja',
                                           encoder=ja_encoder,
                                           decoder=ja_decoder)
_restored_valid_loss

<tf.Tensor: id=13781081, shape=(), dtype=float32, numpy=11.939638>

In [8]:
# Loss of the current encoder/decoder on the "eob" validation set (second
# argument False, no optimizer passed).
_eob_valid_loss = train.run_one_epoch(eob_valid_dataset,
                                      False,
                                      loss_function=loss_function,
                                      from_script='en',
                                      to_script='ja',
                                      encoder=ja_encoder,
                                      decoder=ja_decoder)
_eob_valid_loss

<tf.Tensor: id=13927716, shape=(), dtype=float32, numpy=41.722157>

In [9]:
# Loss of the current encoder/decoder on the "muse" validation set (second
# argument False, no optimizer passed).
_muse_valid_loss = train.run_one_epoch(muse_valid_dataset,
                                       False,
                                       loss_function=loss_function,
                                       from_script='en',
                                       to_script='ja',
                                       encoder=ja_encoder,
                                       decoder=ja_decoder)
_muse_valid_loss

<tf.Tensor: id=13995484, shape=(), dtype=float32, numpy=21.196993>

In [10]:
# Validation pairs as a DataFrame; later cells read its 'en' and 'ja' columns.
# keep_default_na=False stops pandas from turning strings such as "NA" or
# "null" into NaN, so every entry stays a plain string.
_muse_valid_csv = '../data/split/muse_pairs_valid.csv'
valid_df = pd.read_csv(_muse_valid_csv, keep_default_na=False)

In [11]:
# Beam-search decode every 'en' entry of the validation frame, using 40 beams
# and requesting the 20 best candidates per input.
_english_inputs = valid_df['en'].values
tr = decode.transliterate(input_strs=_english_inputs,
                          decoding_method=decode.beam_search_decode,
                          num_beams=40,
                          k_best=20,
                          from_script='en',
                          to_script='ja',
                          encoder=ja_encoder,
                          decoder=ja_decoder)

In [12]:
# Score the decoded candidates in `tr` against the reference 'ja' column with
# k=1.
_reference_ja = valid_df['ja'].values
evaluate.top_k_accuracy(_reference_ja, tr, k=1)

0.004070556309362279

In [13]:
# Score the decoded candidates in `tr` against the reference 'ja' column with
# k=10.
_reference_ja = valid_df['ja'].values
evaluate.top_k_accuracy(_reference_ja, tr, k=10)

0.012211668928086838