In [1]:
from model import tf_estimator_model_fn
from config import args
from data import DataLoader
from utils import greedy_decode, prepare_params

import json
import numpy as np
import tensorflow as tf

# Set TensorFlow's logger to INFO so that INFO-level messages are emitted.
tf.logging.set_verbosity(tf.logging.INFO)


def main():
    """Load the dialog data, train the estimator, and decode sample prompts.

    Steps, in order:
      1. Build a ``DataLoader`` over the source/target text files under
         ``../temp`` and load both sides.
      2. Print the sizes of the loader's source and target word-to-index maps.
      3. Create a ``tf.estimator.Estimator`` from ``tf_estimator_model_fn``
         with the params returned by ``prepare_params``.
      4. For a single outer iteration: train on one shuffled pass over the
         data in batches of ``args.batch_size``, then call ``greedy_decode``
         on a fixed list of sample prompts.
    """
    loader = DataLoader(
        source_path='../temp/dialog_source.txt',
        target_path='../temp/dialog_target.txt',
    )
    sources, targets = loader.load()

    source_vocab_size = len(loader.source_word2idx)
    print('Source Vocab Size:', source_vocab_size)
    target_vocab_size = len(loader.target_word2idx)
    print('Target Vocab Size:', target_vocab_size)

    estimator_params = prepare_params(loader)
    estimator = tf.estimator.Estimator(
        tf_estimator_model_fn,
        params=estimator_params,
    )

    # Single outer iteration; the loop index itself is not used.
    for _ in range(1):
        features = {'source': sources, 'target': targets}
        train_input_fn = tf.estimator.inputs.numpy_input_fn(
            x=features,
            batch_size=args.batch_size,
            num_epochs=1,
            shuffle=True,
        )
        estimator.train(train_input_fn)

        # Decode a few fixed prompts after the training pass.
        sample_prompts = ['你是谁', '你喜欢我吗', '给我唱一首歌', '我帅吗']
        greedy_decode(sample_prompts, estimator, loader)


# Script entry point: print the run configuration as indented JSON, then
# start training.
if __name__ == '__main__':
    print(json.dumps(args, indent=4))
    main()

{
    "source_max_len": 10,
    "target_max_len": 20,
    "min_freq": 50,
    "hidden_units": 128,
    "num_blocks": 2,
    "num_heads": 8,
    "dropout_rate": 0.1,
    "batch_size": 64,
    "position_encoding": "non_param",
    "activation": "relu",
    "tied_proj_weight": true,
    "tied_embedding": false,
    "label_smoothing": false,
    "lr_decay_strategy": "exp"
}
Source Vocab Size: 2022
Target Vocab Size: 2481
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp8hk_ofiw', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x121a68080>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', 

INFO:tensorflow:global_step/sec: 4.25284
INFO:tensorflow:loss = 3.209977, step = 4301 (23.513 sec)
INFO:tensorflow:lr = 0.0009057327 (23.513 sec)
INFO:tensorflow:global_step/sec: 4.2517
INFO:tensorflow:loss = 2.8030608, step = 4401 (23.520 sec)
INFO:tensorflow:lr = 0.0009036495 (23.520 sec)
INFO:tensorflow:global_step/sec: 4.24093
INFO:tensorflow:loss = 3.0776007, step = 4501 (23.580 sec)
INFO:tensorflow:lr = 0.0009015712 (23.580 sec)
INFO:tensorflow:global_step/sec: 4.2648
INFO:tensorflow:loss = 3.4140217, step = 4601 (23.448 sec)
INFO:tensorflow:lr = 0.0008994976 (23.448 sec)
INFO:tensorflow:global_step/sec: 4.24942
INFO:tensorflow:loss = 2.836907, step = 4701 (23.533 sec)
INFO:tensorflow:lr = 0.00089742884 (23.532 sec)
INFO:tensorflow:global_step/sec: 4.25072
INFO:tensorflow:loss = 2.5549767, step = 4801 (23.525 sec)
INFO:tensorflow:lr = 0.0008953648 (23.525 sec)
INFO:tensorflow:Saving checkpoints for 4809 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmp8hk_ofiw/model.ckpt.