# train.py -- Listen, Attend and Spell (LAS) training script.
# (Removed web-page scraping residue that preceded the module; it was not part
# of the source and would have been a Python syntax error.)
import argparse
import tensorflow as tf
import utils
from config.config import *
from model_helper import las_model_fn
def parse_args():
    """Parse command-line arguments for LAS training.

    Returns:
        argparse.Namespace holding dataset paths, model hyperparameters,
        and training settings. All options have defaults or may be omitted;
        nothing is marked required here.
    """
    parser = argparse.ArgumentParser(
        description='Listen, Attend and Spell (LAS) implementation based on Tensorflow. '
                    'The model utilizes input pipeline and estimator API of Tensorflow, '
                    'which makes the training procedure truly end-to-end.')

    # Data and output locations.
    parser.add_argument('--train', type=str,
                        help='training data in TFRecord format')
    parser.add_argument('--valid', type=str,
                        help='validation data in TFRecord format')
    parser.add_argument('--vocab', type=str,
                        help='vocabulary table, listing vocabulary line by line')
    parser.add_argument('--mapping', type=str,
                        help='additional mapping when evaluation')
    parser.add_argument('--model_dir', type=str,
                        help='path of saving model')
    parser.add_argument('--eval_secs', type=int, default=300,
                        help='evaluation every N seconds, only happening when `valid` is specified')

    # Encoder architecture.
    parser.add_argument('--encoder_units', type=int, default=128,
                        help='rnn hidden units of encoder')
    parser.add_argument('--encoder_layers', type=int, default=3,
                        help='rnn layers of encoder')
    parser.add_argument('--use_pyramidal', action='store_true',
                        help='whether to use pyramidal rnn')

    # Decoder architecture.
    parser.add_argument('--decoder_units', type=int, default=128,
                        help='rnn hidden units of decoder')
    parser.add_argument('--decoder_layers', type=int, default=2,
                        help='rnn layers of decoder')
    parser.add_argument('--embedding_size', type=int, default=0,
                        help='embedding size of target vocabulary, if 0, one hot encoding is applied')
    # Fixed typo: "probabilty" -> "probability".
    parser.add_argument('--sampling_probability', type=float, default=0.2,
                        help='sampling probability of decoder during training')

    # Attention mechanism.
    parser.add_argument('--attention_type', type=str, default='luong', choices=['luong', 'bahdanau', 'custom'],
                        help='type of attention mechanism')
    # Fixed missing space between the two concatenated help strings
    # ("...AttentionWrapperfor more details").
    parser.add_argument('--attention_layer_size', type=int,
                        help='size of attention layer, see tensorflow.contrib.seq2seq.AttentionWrapper '
                             'for more details')
    parser.add_argument('--bottom_only', action='store_true',
                        help='apply attention mechanism only at the bottommost rnn cell')
    parser.add_argument('--pass_hidden_state', action='store_true',
                        help='whether to pass encoder state to decoder')

    # Training settings.
    parser.add_argument('--batch_size', type=int, default=32,
                        help='batch size')
    parser.add_argument('--num_channels', type=int, default=39,
                        help='number of input channels')
    parser.add_argument('--num_epochs', type=int, default=150,
                        help='number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=1e-3,
                        help='learning rate')
    parser.add_argument('--dropout', type=float, default=0.2,
                        help='dropout rate of rnn cell')

    return parser.parse_args()
def input_fn(dataset_filename, vocab_filename, num_channels=39, batch_size=8, num_epochs=1):
    """Build the tf.data input pipeline for training or evaluation.

    Args:
        dataset_filename: path to a TFRecord dataset file.
        vocab_filename: path to the line-by-line vocabulary table.
        num_channels: number of feature channels per frame.
        batch_size: examples per batch.
        num_epochs: number of passes over the data.

    Returns:
        The processed dataset produced by utils.process_dataset.
    """
    raw_dataset = utils.read_dataset(dataset_filename, num_channels)
    table = utils.create_vocab_table(vocab_filename)
    return utils.process_dataset(
        raw_dataset, table, utils.SOS, utils.EOS, batch_size, num_epochs)
def main(args):
    """Build the LAS estimator and run training.

    When ``--valid`` is given, trains with periodic evaluation via
    tf.estimator.train_and_evaluate; otherwise runs plain training.
    Restored from a fully commented-out block: with it commented out,
    executing train.py did nothing at all.

    Args:
        args: argparse.Namespace produced by parse_args().
    """
    vocab_list = utils.load_vocab(args.vocab)
    vocab_size = len(vocab_list)

    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(
        args, vocab_size, utils.SOS_ID, utils.EOS_ID)

    model = tf.estimator.Estimator(
        model_fn=las_model_fn,
        config=config,
        params=hparams)

    if args.valid:
        # Train and evaluate: evaluation runs on the validation set at most
        # once every `eval_secs` seconds, starting after a 60s delay.
        train_spec = tf.estimator.TrainSpec(
            input_fn=lambda: input_fn(
                args.train, args.vocab, num_channels=args.num_channels,
                batch_size=args.batch_size, num_epochs=args.num_epochs))

        eval_spec = tf.estimator.EvalSpec(
            input_fn=lambda: input_fn(
                args.valid or args.train, args.vocab,
                num_channels=args.num_channels, batch_size=args.batch_size),
            start_delay_secs=60,
            throttle_secs=args.eval_secs)

        tf.estimator.train_and_evaluate(model, train_spec, eval_spec)
    else:
        # No validation set specified: plain training run.
        model.train(
            input_fn=lambda: input_fn(
                args.train, args.vocab, num_channels=args.num_channels,
                batch_size=args.batch_size, num_epochs=args.num_epochs))


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)

    args = parse_args()
    main(args)