In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
tf.enable_eager_execution()
import importlib
import os

from matplotlib import pyplot as plt
%matplotlib notebook

from transliteration import data, train, model_one, script, decode, evaluate

In [2]:
# Mini-batch size used for every dataset built in this cell.
batch_size = 128

# Keyword arguments common to all four 'en' -> 'ja' datasets.
en_ja_dataset_options = dict(from_script='en',
                             to_script='ja',
                             batch_size=batch_size)

# Training split and the three validation splits, all read from TFRecord files.
our_train_dataset = data.make_dataset('../data/tfrecord/our_train.tfrecord',
                                      **en_ja_dataset_options)
our_valid_dataset = data.make_dataset('../data/tfrecord/our_valid.tfrecord',
                                      **en_ja_dataset_options)
eob_valid_dataset = data.make_dataset('../data/tfrecord/eob_valid.tfrecord',
                                      **en_ja_dataset_options)
muse_valid_dataset = data.make_dataset('../data/tfrecord/muse_valid.tfrecord',
                                       **en_ja_dataset_options)

In [3]:
# Adam optimizer constructed with TensorFlow's default hyperparameters (no arguments
# are passed). The same object is handed to the training loop and to the checkpoint.
optimizer = tf.train.AdamOptimizer()

def loss_function(real, pred):
    """Masked sparse softmax cross-entropy between labels and logits.

    Positions whose label equals 0 contribute zero loss (0 is presumably the
    padding id -- confirm against the script vocabulary).

    Args:
        real: Integer label ids, passed as ``labels`` to
            ``tf.nn.sparse_softmax_cross_entropy_with_logits``.
        pred: Unnormalised logits, passed as ``logits`` to the same op.

    Returns:
        A scalar tensor: the mean of the masked per-position losses.
    """
    loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred)
    # Build the mask with TensorFlow ops rather than NumPy: same values (1 where
    # the label is non-zero, 0 elsewhere) but no tensor -> ndarray round trip, and
    # it also works when the function is traced in graph mode.
    mask = tf.cast(tf.not_equal(real, 0), loss_.dtype)
    # NOTE(review): the mean is taken over every position, masked ones included,
    # so the value shrinks as the share of label-0 positions grows. Kept as-is so
    # reported losses stay comparable with earlier runs.
    return tf.reduce_mean(loss_ * mask)

# LSTM and embedding widths are identical for the encoder and the decoder.
shared_sizes = dict(lstm_size=120, embedding_size=30)

# Encoder: sized for the 'en' script vocabulary, with attention_size left as None.
ja_encoder_config = model_one.Config(attention_size=None,
                                     vocab_size=script.SCRIPTS['en'].vocab_size,
                                     **shared_sizes)
# Decoder: sized for the 'ja' script vocabulary, with an attention size of 60.
ja_decoder_config = model_one.Config(attention_size=60,
                                     vocab_size=script.SCRIPTS['ja'].vocab_size,
                                     **shared_sizes)

# Instantiate the encoder first and the decoder second.
ja_encoder = model_one.Encoder(ja_encoder_config)
ja_decoder = model_one.Decoder(ja_decoder_config)

# Directory that receives checkpoint files, and the file prefix later passed to
# `checkpoint.save(file_prefix=...)`.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Checkpoint object; stays None until the training loop creates it through
# make_checkpoint_obj() on the first validation improvement.
checkpoint = None

def make_checkpoint_obj():
    """Return a ``tf.train.Checkpoint`` tracking the optimizer and both models.

    The module-level ``optimizer``, ``ja_encoder`` and ``ja_decoder`` are looked
    up when the function is called, not when it is defined.
    """
    tracked_objects = {
        'optimizer': optimizer,
        'ja_encoder': ja_encoder,
        'ja_decoder': ja_decoder,
    }
    return tf.train.Checkpoint(**tracked_objects)

In [4]:
# Lowest validation loss seen so far; None until an epoch has been evaluated.
# Re-running this cell resets the tracking used by the training loops below.
best_val_loss = None
# Value returned by the most recent `checkpoint.save(...)`; later passed to
# `checkpoint.restore(...)`. None until a checkpoint has been written.
ja_checkpoint = None

In [5]:
# Keyword arguments shared by the training, validation and decoding calls below.
en_ja_models = dict(from_script='en',
                    to_script='ja',
                    encoder=ja_encoder,
                    decoder=ja_decoder)

for epoch in range(15):
    # Second positional argument is True for the training pass and False for the
    # validation pass (presumably a "training" flag -- confirm in train.run_one_epoch).
    train_loss = train.run_one_epoch(our_train_dataset,
                                     True,
                                     optimizer=optimizer,
                                     loss_function=loss_function,
                                     **en_ja_models)
    valid_loss = train.run_one_epoch(our_valid_dataset,
                                     False,
                                     loss_function=loss_function,
                                     **en_ja_models)

    # Save a checkpoint only when the validation loss reaches a new minimum.
    improved = best_val_loss is None or valid_loss < best_val_loss
    if improved:
        best_val_loss = valid_loss
        # The checkpoint object is created lazily, on the first improvement.
        checkpoint = make_checkpoint_obj() if checkpoint is None else checkpoint
        ja_checkpoint = checkpoint.save(file_prefix=checkpoint_prefix)

    print("Epoch {}: Train Loss {:.3f}, Valid Loss {:.3f}".format(epoch, train_loss, valid_loss))

    # Decode one fixed sample word each epoch as a quick qualitative check.
    sample_output = decode.transliterate(input_strs=['derick'],
                                         k_best=2,
                                         decoding_method=decode.beam_search_decode,
                                         **en_ja_models)
    print(sample_output)

Instructions for updating:
Colocations handled automatically by placer.


Epoch 0: Train Loss 14.392, Valid Loss 13.576
([['アーー', 'アース']], array([[-10.28815603, -10.39903569]]))


Epoch 1: Train Loss 13.459, Valid Loss 13.335
([['アース', 'マース']], array([[-10.18462709, -10.2174125 ]]))


Epoch 2: Train Loss 13.182, Valid Loss 13.158
([['アース', 'アンタ']], array([[-9.82302047, -9.8355126 ]]))


Epoch 3: Train Loss 13.005, Valid Loss 13.004
([['アルコ', 'アルト']], array([[-9.67963633, -9.71657748]]))


Epoch 4: Train Loss 12.878, Valid Loss 12.934
([['サント', 'コント']], array([[-9.61216854, -9.64887603]]))


Epoch 5: Train Loss 12.802, Valid Loss 12.844
([['フェル', 'フォル']], array([[-8.76831627, -8.8010599 ]]))


Epoch 6: Train Loss 12.743, Valid Loss 12.791
([['アイコ', 'アイト']], array([[-9.49943934, -9.59871354]]))


Epoch 7: Train Loss 12.680, Valid Loss 12.742
([['アリア', 'アリオ']], array([[-8.88664953, -9.02956691]]))


Epoch 8: Train Loss 12.625, Valid Loss 12.704
([['シェル', 'シュル']], array([[-9.00573689, -9.05271047]]))


Epoch 9: Train Loss 12.570, Valid Loss 12.682
([['シュル', 'シェル']], array([[-8.69352408, -8.69719134]]))


Epoch 10: Train Loss 12.533, Valid Loss 12.697
([['スタン', 'シュル']], array([[-8.75985097, -8.78409141]]))


Epoch 11: Train Loss 12.488, Valid Loss 12.631
([['カルト', 'オルト']], array([[-9.67866716, -9.73689082]]))


Epoch 12: Train Loss 12.440, Valid Loss 12.617
([['カルト', 'カルコ']], array([[-9.52784336, -9.55611947]]))


Epoch 13: Train Loss 12.390, Valid Loss 12.587
([['スタン', 'スタル']], array([[-8.73756421, -9.43957877]]))


Epoch 14: Train Loss 12.335, Valid Loss 12.572
([['スタン', 'コック']], array([[-9.01734382, -9.44148189]]))


In [6]:
# Keyword arguments shared by the training, validation and decoding calls below.
en_ja_models = dict(from_script='en',
                    to_script='ja',
                    encoder=ja_encoder,
                    decoder=ja_decoder)

# Continue training from epoch 15, reusing best_val_loss / checkpoint state
# left behind by the earlier cells.
for epoch in range(15, 100):
    # Second positional argument is True for the training pass and False for the
    # validation pass (presumably a "training" flag -- confirm in train.run_one_epoch).
    train_loss = train.run_one_epoch(our_train_dataset,
                                     True,
                                     optimizer=optimizer,
                                     loss_function=loss_function,
                                     **en_ja_models)
    valid_loss = train.run_one_epoch(our_valid_dataset,
                                     False,
                                     loss_function=loss_function,
                                     **en_ja_models)

    # Save a checkpoint only when the validation loss reaches a new minimum.
    improved = best_val_loss is None or valid_loss < best_val_loss
    if improved:
        best_val_loss = valid_loss
        # The checkpoint object is created lazily, on the first improvement.
        checkpoint = make_checkpoint_obj() if checkpoint is None else checkpoint
        ja_checkpoint = checkpoint.save(file_prefix=checkpoint_prefix)

    print("Epoch {}: Train Loss {:.3f}, Valid Loss {:.3f}".format(epoch, train_loss, valid_loss))

    # Decode one fixed sample word each epoch as a quick qualitative check.
    sample_output = decode.transliterate(input_strs=['derick'],
                                         k_best=2,
                                         decoding_method=decode.beam_search_decode,
                                         **en_ja_models)
    print(sample_output)

Epoch 15: Train Loss 12.273, Valid Loss 12.532
([['スタン', 'シェル']], array([[-8.81433985, -9.17567664]]))


Epoch 16: Train Loss 12.215, Valid Loss 12.519
([['ファック', 'フォル']], array([[-9.50694835, -9.61435378]]))


Epoch 17: Train Loss 12.141, Valid Loss 12.489
([['フォル', 'フォー']], array([[-8.9566741 , -9.22642171]]))


Epoch 18: Train Loss 12.069, Valid Loss 12.501
([['コック', 'コッコ']], array([[-8.66506797, -9.0536558 ]]))


Epoch 19: Train Loss 11.987, Valid Loss 12.483
([['スタン', 'コック']], array([[-8.37900326, -8.60699941]]))


Epoch 20: Train Loss 11.901, Valid Loss 12.471
([['アルツィ', 'アイチャ']], array([[-10.04967617, -10.07043385]]))


Epoch 21: Train Loss 11.805, Valid Loss 12.493
([['スタン', 'コック']], array([[-7.99442369, -8.3191779 ]]))


Epoch 22: Train Loss 11.704, Valid Loss 12.481
([['スター', 'コック']], array([[-7.8888122 , -7.98038058]]))


Epoch 23: Train Loss 11.593, Valid Loss 12.530
([['コック', 'コップ']], array([[-7.92117533, -7.98976117]]))


Epoch 24: Train Loss 11.475, Valid Loss 12.601
([['ダック', 'コック']], array([[-7.759684, -8.314972]]))


Epoch 25: Train Loss 11.350, Valid Loss 12.611
([['スルム', 'コック']], array([[-7.83733969, -8.10845034]]))


Epoch 26: Train Loss 11.223, Valid Loss 12.615
([['ダック', 'ドック']], array([[-7.14815707, -7.30802326]]))


Epoch 27: Train Loss 11.077, Valid Loss 12.712
([['ダック', 'コック']], array([[-7.29989286, -7.76347581]]))


Epoch 28: Train Loss 10.940, Valid Loss 12.752
([['ドール', 'コール']], array([[-7.37191931, -7.376223  ]]))


Epoch 29: Train Loss 10.791, Valid Loss 12.885
([['ジャル', 'ジュル']], array([[-5.98720983, -6.12350587]]))


Epoch 30: Train Loss 10.636, Valid Loss 12.904
([['ドルム', 'ドーム']], array([[-6.76180489, -6.77989257]]))


Epoch 31: Train Loss 10.475, Valid Loss 13.028
([['ドルム', 'コルム']], array([[-6.34475897, -6.44341401]]))


Epoch 32: Train Loss 10.315, Valid Loss 13.122
([['コルム', 'ドーム']], array([[-6.43668607, -6.52390778]]))


KeyboardInterrupt: 

In [7]:
# Reload the most recently saved checkpoint (written at the last validation
# improvement) and fail loudly if any part of it was left unmatched.
restore_status = checkpoint.restore(ja_checkpoint)
restore_status.assert_consumed()

# Re-evaluate the restored model on our validation split; the bare call as the
# last expression makes the notebook display the resulting loss.
eval_kwargs = dict(from_script='en',
                   to_script='ja',
                   encoder=ja_encoder,
                   decoder=ja_decoder,
                   loss_function=loss_function)
train.run_one_epoch(our_valid_dataset, False, **eval_kwargs)

<tf.Tensor: id=64540676, shape=(), dtype=float32, numpy=12.489432>

In [8]:
# Evaluate the current model on the eob validation split (no optimizer is
# passed, and the second positional argument is False).
eval_kwargs = dict(from_script='en',
                   to_script='ja',
                   encoder=ja_encoder,
                   decoder=ja_decoder,
                   loss_function=loss_function)
train.run_one_epoch(eob_valid_dataset, False, **eval_kwargs)

<tf.Tensor: id=64689840, shape=(), dtype=float32, numpy=33.10506>

In [10]:
# Evaluate the current model on the muse validation split (no optimizer is
# passed, and the second positional argument is False).
eval_kwargs = dict(from_script='en',
                   to_script='ja',
                   encoder=ja_encoder,
                   decoder=ja_decoder,
                   loss_function=loss_function)
train.run_one_epoch(muse_valid_dataset, False, **eval_kwargs)

<tf.Tensor: id=64756464, shape=(), dtype=float32, numpy=20.740007>

In [11]:
# keep_default_na=False stops pandas from turning strings such as "NA" or
# "null" into NaN, so every cell is kept as the text that appears in the CSV.
muse_valid_csv = '../data/split/muse_pairs_valid.csv'
valid_df = pd.read_csv(muse_valid_csv, keep_default_na=False)

In [12]:
# Beam-search decode every value in the 'en' column of the validation frame.
# k_best=20 / num_beams=40 presumably mean "keep 20 candidates per input out of
# 40 beams" -- confirm against decode.transliterate.
english_inputs = valid_df['en'].values
beam_settings = dict(k_best=20,
                     num_beams=40,
                     decoding_method=decode.beam_search_decode)
tr = decode.transliterate(input_strs=english_inputs,
                          from_script='en',
                          to_script='ja',
                          encoder=ja_encoder,
                          decoder=ja_decoder,
                          **beam_settings)

In [13]:
# Score the decoded candidates against the 'ja' column with k=1.
reference_ja = valid_df['ja'].values
evaluate.top_k_accuracy(reference_ja, tr, k=1)

0.0047489823609226595

In [14]:
# Score the decoded candidates against the 'ja' column with k=10.
reference_ja = valid_df['ja'].values
evaluate.top_k_accuracy(reference_ja, tr, k=10)

0.018995929443690638