In [None]:
# Move the working directory up to the repository root ('deep_dialog_tutorial')
# so that the relative paths used below (data/, tmp/) resolve correctly.
import os

_start_dir = os.getcwd()
while os.path.basename(os.getcwd()) != 'deep_dialog_tutorial':
    _parent_dir = os.path.dirname(os.getcwd())
    if _parent_dir == os.getcwd():
        # Reached the filesystem root without finding the repository. Going
        # further up is impossible (the old loop spun here forever), so restore
        # the starting directory and fail loudly instead.
        os.chdir(_start_dir)
        raise RuntimeError(
            "could not find a 'deep_dialog_tutorial' directory above "
            f"{_start_dir!r}"
        )
    os.chdir(_parent_dir)
print('current dir:', os.getcwd())

In [None]:
import tensorflow as tf
from deepdialog.transformer.transformer import Transformer
from deepdialog.transformer.preprocess.batch_generator import BatchGenerator

# Create Data

In [None]:
data_path = os.path.join('data', 'natsume.txt')  # corpus file, relative to the repository root

In [None]:
# Create the project's batch generator and load the corpus at `data_path` into it.
# NOTE(review): `load` presumably also builds the vocabulary (the next cell reads
# `batch_generator.vocab_size`) — confirm in BatchGenerator.
batch_generator = BatchGenerator()
batch_generator.load(data_path)

In [None]:
# Vocabulary size reported by the loaded batch generator; the Transformer below is sized with it.
vocab_size = batch_generator.vocab_size

# Create Model

In [None]:
# Dedicated graph for the model; everything created inside the `with` lands in it.
graph = tf.Graph()
with graph.as_default():
    # vocab_size comes from the corpus; the remaining values are fixed architecture
    # and regularization choices for this tutorial.
    transformer = Transformer(vocab_size=vocab_size, hopping_num=4, head_num=8,
                              hidden_dim=512, dropout_rate=0.1, max_length=50)
    # Presumably adds the model ops (later read as `transformer.loss` / `.acc`) to `graph`.
    transformer.build_graph()

# Create Training Graph

In [None]:
# Output locations (relative to the repository root) for TensorBoard logs and checkpoints.
save_dir = 'tmp/learning/transformer/'
log_dir = os.path.join(save_dir, 'log')
ckpt_path = os.path.join(save_dir, 'checkpoints/model.ckpt')

os.makedirs(log_dir, exist_ok=True)
# tf.train.Saver.save needs the checkpoint's parent directory to exist; only the
# log directory used to be created, so make the checkpoint directory up front too.
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

In [None]:
with graph.as_default():
    # Step counter, handed to minimize() so the optimizer advances it on each update.
    global_step = tf.train.get_or_create_global_step()

    # The learning rate is fed from Python on every training step (see the training loop).
    learning_rate = tf.placeholder(dtype=tf.float32, name='learning_rate')
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, beta2=0.98)
    optimize_op = optimizer.minimize(transformer.loss, global_step=global_step)

    # One merged op so a single fetch produces all training scalars.
    summary_op = tf.summary.merge(
        [
            tf.summary.scalar(tag, tensor)
            for tag, tensor in (
                ('train/loss', transformer.loss),
                ('train/acc', transformer.acc),
                ('train/learning_rate', learning_rate),
            )
        ],
        name='train_summary',
    )
    summary_writer = tf.summary.FileWriter(log_dir, graph)
    # Built after the optimizer so that (presumably) its slot variables are saved as well.
    saver = tf.train.Saver()

# Train

In [None]:
max_step = 100_000        # intended cap on the number of training steps
batch_size = 128          # passed to BatchGenerator.get_batch
max_learning_rate = 1e-4  # peak of the warm-up schedule (see get_learning_rate)
warmup_step = 4_000       # step at which the schedule peaks

In [None]:
def get_learning_rate(step: int, max_lr: 'float | None' = None, warmup: 'int | None' = None) -> float:
    """Return the warm-up / inverse-square-root learning rate for ``step``.

    The rate grows linearly with ``step`` until ``warmup`` and then decays as
    ``step ** -0.5``. It is normalized so that the peak, reached at
    ``step == warmup``, equals ``max_lr``.

    Args:
        step: 1-based training step. Values below 1 are clamped to 1, because the
            formula divides by zero at 0 and produces complex numbers for
            negative steps.
        max_lr: Peak learning rate. Defaults to the notebook-level
            ``max_learning_rate`` (read at call time).
        warmup: Warm-up length in steps. Defaults to the notebook-level
            ``warmup_step`` (read at call time).

    Returns:
        The learning rate as a float.

    Raises:
        ValueError: If ``warmup`` is not positive.
    """
    if max_lr is None:
        max_lr = max_learning_rate
    if warmup is None:
        warmup = warmup_step
    if warmup <= 0:
        raise ValueError(f'warmup must be positive, got {warmup}')
    step = max(step, 1)
    # Dividing by warmup ** -0.5 rescales the curve so its maximum is exactly 1.
    rate = min(step ** -0.5, step * warmup ** -1.5) / warmup ** -0.5
    return max_lr * rate

In [None]:
# Open a session on the training graph and give every variable its initial value.
with graph.as_default():
    sess = tf.Session(graph=graph)
    sess.run(tf.global_variables_initializer())
# Python-side copy of the global step; the training loop overwrites it after each update.
step = 0

In [None]:
with graph.as_default():
    try:
        for batch in batch_generator.get_batch(batch_size=batch_size):
            # `max_step` was defined with the other hyper-parameters but never
            # checked, so training ran for as long as get_batch kept yielding.
            # Checking before the update also makes re-running this cell after
            # completion a no-op.
            if step >= max_step:
                break
            feed = {
                **batch,  # presumably placeholder -> value pairs from the generator; TODO confirm
                learning_rate: get_learning_rate(step + 1),
            }
            # NOTE(review): global_step is fetched in the same run as optimize_op;
            # presumably TF1 does not order that read against the increment, so
            # `step` may lag by one — confirm if exact step numbering matters.
            _, loss, acc, step, summary = sess.run(
                [optimize_op, transformer.loss, transformer.acc, global_step, summary_op],
                feed_dict=feed,
            )
            summary_writer.add_summary(summary, step)

            if step % 100 == 0:
                print(f'{step}: loss: {loss},\t acc: {acc}')
                saver.save(sess, ckpt_path, global_step=step)
    finally:
        # Push buffered summaries to disk even when the loop is interrupted
        # from the notebook (e.g. KeyboardInterrupt).
        summary_writer.flush()