In [1]:
from model import tf_estimator_model_fn
from config import args
from data import DataLoader
from utils import greedy_decode, prepare_params

import json
import numpy as np
import tensorflow as tf

# Raise TensorFlow's logger to INFO level for the whole session, before any
# estimator is built, so INFO-level messages are emitted.
tf.logging.set_verbosity(tf.logging.INFO)


def main():
    """Train the dialog model for two rounds, decoding sample prompts after each.

    Steps:
      1. Load the source/target corpora through ``DataLoader`` and print the
         size of both vocabularies.
      2. Build a ``tf.estimator.Estimator`` around ``tf_estimator_model_fn``
         with the params derived from the loader by ``prepare_params``.
      3. Twice in a row: run one ``train`` call fed by a shuffled
         ``numpy_input_fn`` over the loaded arrays, then call
         ``greedy_decode`` on a fixed set of prompts.
    """
    loader = DataLoader(
        source_path='../temp/dialog_source.txt',
        target_path='../temp/dialog_target.txt')
    sources, targets = loader.load()
    print('Source Vocab Size:', len(loader.source_word2idx))
    print('Target Vocab Size:', len(loader.target_word2idx))

    model_params = prepare_params(loader)
    estimator = tf.estimator.Estimator(
        tf_estimator_model_fn, params=model_params)

    for _ in range(2):
        # Fresh feature dict and input function for every round, so each
        # ``train`` call receives its own objects.
        features = {'source': sources, 'target': targets}
        train_input_fn = tf.estimator.inputs.numpy_input_fn(
            x=features,
            batch_size=args.batch_size,
            shuffle=True)
        estimator.train(train_input_fn)

        # Sanity-check the model on a handful of fixed prompts.
        prompts = ['你是谁', '你喜欢我吗', '给我唱一首歌', '我帅吗']
        greedy_decode(prompts, estimator, loader)


if __name__ == '__main__':
    # Echo the run configuration first so it appears ahead of the training
    # output, then start training.
    config_dump = json.dumps(args, indent=4)
    print(config_dump)
    main()

{
    "source_max_len": 10,
    "target_max_len": 20,
    "min_freq": 50,
    "hidden_units": 128,
    "num_blocks": 2,
    "num_heads": 8,
    "dropout_rate": 0.1,
    "batch_size": 64,
    "position_encoding": "non_param",
    "activation": "relu",
    "tied_proj_weight": true,
    "tied_embedding": false,
    "label_smoothing": false,
    "lr_decay_strategy": "exp"
}
Source Vocab Size: 2022
Target Vocab Size: 2481
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpv4cinv49', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x11917ceb8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', 

INFO:tensorflow:global_step/sec: 3.94299
INFO:tensorflow:loss = 2.7659223, step = 4301 (25.361 sec)
INFO:tensorflow:lr = 0.0009057327 (25.361 sec)
INFO:tensorflow:global_step/sec: 4.04701
INFO:tensorflow:loss = 2.8845396, step = 4401 (24.710 sec)
INFO:tensorflow:lr = 0.0009036495 (24.710 sec)
INFO:tensorflow:global_step/sec: 4.05306
INFO:tensorflow:loss = 2.4121358, step = 4501 (24.672 sec)
INFO:tensorflow:lr = 0.0009015712 (24.672 sec)
INFO:tensorflow:global_step/sec: 4.07548
INFO:tensorflow:loss = 2.9767103, step = 4601 (24.537 sec)
INFO:tensorflow:lr = 0.0008994976 (24.537 sec)
INFO:tensorflow:global_step/sec: 4.01176
INFO:tensorflow:loss = 3.2390113, step = 4701 (24.927 sec)
INFO:tensorflow:lr = 0.00089742884 (24.927 sec)
INFO:tensorflow:global_step/sec: 3.91639
INFO:tensorflow:loss = 3.0650501, step = 4801 (25.534 sec)
INFO:tensorflow:lr = 0.0008953648 (25.534 sec)
INFO:tensorflow:Saving checkpoints for 4813 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpv4cinv49/model.c

INFO:tensorflow:lr = 0.0008111105 (23.952 sec)
INFO:tensorflow:global_step/sec: 3.96434
INFO:tensorflow:loss = 3.4723723, step = 9193 (25.225 sec)
INFO:tensorflow:lr = 0.00080924504 (25.225 sec)
INFO:tensorflow:global_step/sec: 4.15312
INFO:tensorflow:loss = 3.0881352, step = 9293 (24.078 sec)
INFO:tensorflow:lr = 0.0008073838 (24.078 sec)
INFO:tensorflow:global_step/sec: 4.15924
INFO:tensorflow:loss = 3.28991, step = 9393 (24.043 sec)
INFO:tensorflow:lr = 0.00080552686 (24.043 sec)
INFO:tensorflow:Saving checkpoints for 9447 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpv4cinv49/model.ckpt.
INFO:tensorflow:global_step/sec: 3.48498
INFO:tensorflow:loss = 2.6655982, step = 9493 (28.695 sec)
INFO:tensorflow:lr = 0.0008036742 (28.694 sec)
INFO:tensorflow:global_step/sec: 4.23014
INFO:tensorflow:loss = 2.7164102, step = 9593 (23.640 sec)
INFO:tensorflow:lr = 0.0008018258 (23.640 sec)
INFO:tensorflow:global_step/sec: 4.22662
INFO:tensorflow:loss = 2.833529, step = 9693 (23.659 se