In [1]:
# カレントディレクトリをリポジトリ直下にするおまじない
import os
while os.getcwd().split('/')[-1] != 'deep_dialog_tutorial': os.chdir('..')
print('current dir:', os.getcwd())

current dir: /home/harumitsu.nobuta/git/deep_dialog_tutorial


In [2]:
import tensorflow as tf
from deepdialog.transformer.transformer import Transformer
from deepdialog.transformer.preprocess.batch_generator import BatchGenerator

# Create Data

In [3]:
data_path = 'data/natsume.txt'

In [4]:
batch_generator = BatchGenerator()
batch_generator.load(data_path)

In [5]:
vocab_size = batch_generator.vocab_size

# Create Model

In [None]:
graph = tf.Graph()
with graph.as_default():
    transformer = Transformer(
        vocab_size=vocab_size,
        hopping_num=4,
        head_num=8,
        hidden_dim=512,
        dropout_rate=0.1,
        max_length=50,
    )
    transformer.build_graph()

# Create Training Graph

In [None]:
save_dir = 'tmp/learning/transformer/'
log_dir = os.path.join(save_dir, 'log')
ckpt_path = os.path.join(save_dir, 'checkpoints/model.ckpt')

os.makedirs(log_dir, exist_ok=True)

In [None]:
with graph.as_default():
    global_step = tf.train.get_or_create_global_step()
    
    learning_rate = tf.placeholder(dtype=tf.float32, name='learning_rate')
    optimizer = tf.train.AdamOptimizer(
        learning_rate=learning_rate,
        beta2=0.98,
    )
    optimize_op = optimizer.minimize(transformer.loss, global_step=global_step)

    summary_op = tf.summary.merge([
        tf.summary.scalar('train/loss', transformer.loss),
        tf.summary.scalar('train/acc', transformer.acc),
        tf.summary.scalar('train/learning_rate', learning_rate),
    ], name='train_summary')
    summary_writer = tf.summary.FileWriter(log_dir, graph)
    saver = tf.train.Saver()

# Train

In [None]:
max_step = 100000
batch_size = 128
max_learning_rate = 0.0001
warmup_step = 4000

In [None]:
def get_learning_rate(step: int) -> float:
    rate = min(step ** -0.5, step * warmup_step ** -1.5) / warmup_step ** -0.5
    return max_learning_rate * rate

In [None]:
with graph.as_default():
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    step = 0

In [None]:
with graph.as_default():
    for batch in batch_generator.get_batch(batch_size=batch_size):
        feed = {
            **batch,
            learning_rate: get_learning_rate(step + 1),
        }
        _, loss, acc, step, summary = sess.run([optimize_op, transformer.loss, transformer.acc, global_step, summary_op], feed_dict=feed)
        summary_writer.add_summary(summary, step)
        
        if step % 100 == 0:
            print(f'{step}: loss: {loss},\t acc: {acc}')
            saver.save(sess, ckpt_path, global_step=step)

0: loss: 8.456110000610352,	 acc: 0.00042753314482979476
100: loss: 8.063234329223633,	 acc: 0.061625875532627106
200: loss: 7.624130725860596,	 acc: 0.08877591043710709
300: loss: 7.2388014793396,	 acc: 0.15279187262058258
400: loss: 6.831792831420898,	 acc: 0.15193502604961395
800: loss: 6.131741523742676,	 acc: 0.16190476715564728
900: loss: 6.099096298217773,	 acc: 0.16284014284610748
1000: loss: 6.00535774230957,	 acc: 0.17789646983146667
1100: loss: 5.965171813964844,	 acc: 0.175105482339859
1200: loss: 6.056082248687744,	 acc: 0.16189385950565338
1300: loss: 5.734684944152832,	 acc: 0.19673576951026917
1400: loss: 5.750892162322998,	 acc: 0.19291338324546814
1500: loss: 5.762808322906494,	 acc: 0.19190600514411926
1600: loss: 5.654571056365967,	 acc: 0.20242369174957275
1700: loss: 5.622186660766602,	 acc: 0.20016610622406006
1800: loss: 5.621791362762451,	 acc: 0.19756199419498444
1900: loss: 5.568434238433838,	 acc: 0.20691144466400146
2000: loss: 5.44687557220459,	 acc: 0.213