train.py
import tensorflow as tf
from data_helpers import loadDataset, getBatches, sentence2enco
from model import Seq2SeqModel
from tqdm import tqdm
import math
import os
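
# Model and training hyperparameters, exposed as TF1-style command-line flags.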
tf.app.flags.DEFINE_integer('rnn_size', 1024, 'Number of hidden units in each layer')
tf.app.flags.DEFINE_integer('num_layers', 2, 'Number of layers in each encoder and decoder')
tf.app.flags.DEFINE_integer('embedding_size', 1024, 'Embedding dimensions of encoder and decoder inputs')
tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate')
tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size')
tf.app.flags.DEFINE_integer('numEpochs', 30, 'Maximum # of training epochs')
tf.app.flags.DEFINE_integer('steps_per_checkpoint', 100, 'Save model checkpoint every this iteration')
tf.app.flags.DEFINE_string('model_dir', 'model/', 'Path to save model checkpoints')
tf.app.flags.DEFINE_string('model_name', 'chatbot.ckpt', 'File name used for model checkpoints')
FLAGS = tf.app.flags.FLAGS

# Raw string so the Windows backslashes are not interpreted as escape sequences.
data_path = r'E:\PycharmProjects\seq2seq_chatbot\seq2seq_chatbot_new\data\dataset-cornell-length10-filter1-vocabSize40000.pkl'
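# loadDataset returns the vocabulary mappings and the preprocessed training pairs.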
word2id, id2word, trainingSamples = loadDataset(data_path)

with tf.Session() as sess:
    # Build the seq2seq graph in training mode (attention enabled, beam search off).
    model = Seq2SeqModel(FLAGS.rnn_size, FLAGS.num_layers, FLAGS.embedding_size, FLAGS.learning_rate, word2id,
                         mode='train', use_attention=True, beam_search=False, beam_size=5, max_gradient_norm=5.0)
    # Resume from the latest checkpoint if one exists; otherwise initialize from scratch.
    ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        print('Reloading model parameters..')
        model.restore(sess, ckpt.model_checkpoint_path)
    else:
        print('Created new model parameters..')
        sess.run(tf.global_variables_initializer())
    current_step = 0
    summary_writer = tf.summary.FileWriter(FLAGS.model_dir, graph=sess.graph)
    for e in range(FLAGS.numEpochs):
        print("----- Epoch {}/{} -----".format(e + 1, FLAGS.numEpochs))
        # Re-batch the training samples each epoch.
        batches = getBatches(trainingSamples, FLAGS.batch_size)
        for nextBatch in tqdm(batches, desc="Training"):
            loss, summary = model.train(sess, nextBatch)
            current_step += 1
            if current_step % FLAGS.steps_per_checkpoint == 0:
                # Perplexity is exp(cross-entropy loss); cap it to avoid overflow
                # in math.exp when the loss is still very large.
                perplexity = math.exp(float(loss)) if loss < 300 else float('inf')
                tqdm.write("----- Step %d -- Loss %.2f -- Perplexity %.2f" % (current_step, loss, perplexity))
                summary_writer.add_summary(summary, current_step)
                checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name)
                model.saver.save(sess, checkpoint_path, global_step=current_step)
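
# Usage sketch (the flag values below are illustrative, not from the repo):
#   python train.py --batch_size 256 --numEpochs 10
# With the TF1 flags API, such command-line overrides are parsed the first time
# a FLAGS attribute is read.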