-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
155 lines (138 loc) · 7.63 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import datetime
import getopt
import sys
import os.path
from word2vec.word2vec import *
from random import choice
import tensorflow as tf
from tqdm import tqdm
from baseline import BaselineModel
from data_utility import *
from word2vec.load_embeddings import load_embedding
###
# Graph execution
###
def mainFunc(argv):
    """Parse command-line options, build the seq2seq model and run training.

    argv: command-line arguments, typically sys.argv[1:].
    Recognized options: -n/--num_cores, -x/--experiment, -t/--tag, -h (help).
    Exits with status 2 on invalid arguments.
    """
    def printUsage():
        print('main.py -n <num_cores> -x <experiment>')
        print('num_cores = Number of cores requested from the cluster. Set to -1 to leave unset')
        print('experiment = experiment setup that should be executed. e.g \'baseline\' or \'attention\'')
        print('tag = optional tag or name to distinguish the runs, e.g. \'bidirect3layers\' ')

    num_cores = -1
    experiment = ""
    tag = None
    # Command line argument handling.
    # BUGFIX: 'h' was missing from the optstring, so '-h' raised GetoptError
    # and exited with status 2 instead of reaching the help branch below.
    try:
        opts, args = getopt.getopt(argv, "hn:x:t:", ["num_cores=", "experiment=", "tag="])
    except getopt.GetoptError:
        printUsage()
        sys.exit(2)
    for opt, arg in opts:
        if opt == '-h':
            printUsage()
            sys.exit()
        elif opt in ("-n", "--num_cores"):
            num_cores = int(arg)
        elif opt in ("-x", "--experiment"):
            if arg in ("baseline", "attention"):
                experiment = arg
            else:
                printUsage()
                sys.exit(2)
        elif opt in ("-t", "--tag"):
            tag = arg

    print("Executing experiment {} with {} CPU cores".format(experiment, num_cores))
    if num_cores != -1:
        # We set the op_parallelism_threads in the ConfigProto and pass it to the TensorFlow session
        configProto = tf.ConfigProto(inter_op_parallelism_threads=num_cores,
                                     intra_op_parallelism_threads=num_cores)
    else:
        configProto = tf.ConfigProto()

    print("Initializing model")
    # BUGFIX: the original validated via `assert model != None`, which is
    # stripped under `python -O`; validate explicitly instead. The two
    # experiment branches differed only in the `attention` flag, so the
    # duplicated BaselineModel constructions are collapsed into one.
    if experiment not in ("baseline", "attention"):
        printUsage()
        sys.exit(2)
    model = BaselineModel(vocab_size=conf.vocabulary_size,
                          embedding_size=conf.word_embedding_size,
                          bidirectional=conf.bidirectional_encoder,
                          attention=(experiment == "attention"),
                          dropout=conf.use_dropout,
                          num_layers=conf.num_layers,
                          is_training=True)

    print("=== GETTING DATA BY TYPE = TRAIN ===")
    enc_inputs, dec_inputs, word_2_index, index_2_word = get_data_by_type('train')
    print("***********")
    print("Encoder inputs length {}".format(len(enc_inputs)))
    print("Decoder inputs length {}".format(len(dec_inputs)))
    print("***********")

    # Materialize validation data up front so we can sample random batches
    # cheaply during training.
    print("=== GETTING DATA BY TYPE = EVAL ===")
    validation_enc_inputs, validation_dec_inputs, _, _ = get_data_by_type('eval')
    validation_data = list(bucket_by_sequence_length(validation_enc_inputs, validation_dec_inputs,
                                                    conf.batch_size, filter_long_sent=False))

    print("Starting TensorFlow session")
    with tf.Session(config=configProto) as sess:
        global_step = 1
        saver = tf.train.Saver(max_to_keep=3, keep_checkpoint_every_n_hours=4)

        # Init Tensorboard summaries. This will save Tensorboard information into a different folder at each run.
        timestamp = '{0:%Y-%m-%d_%H-%M-%S}'.format(datetime.datetime.now())
        tag_string = ""
        if tag is not None:
            tag_string = "-" + tag
        train_logfolderPath = os.path.join(conf.log_directory, "{}{}-training-{}".format(experiment, tag_string, timestamp))
        train_writer = tf.summary.FileWriter(train_logfolderPath, graph=tf.get_default_graph())
        # BUGFIX: the validation log path concatenated conf.log_directory and
        # the run name without a path separator; use os.path.join, consistent
        # with the training writer above.
        validation_writer = tf.summary.FileWriter(
            os.path.join(conf.log_directory, "{}{}-validation-{}".format(experiment, tag_string, timestamp)),
            graph=tf.get_default_graph())
        copy_config(train_logfolderPath)  # Copies the current config.py to the log directory
        sess.run(tf.global_variables_initializer())

        if conf.use_word2vec:
            print("Using word2vec embeddings")
            if not os.path.isfile(conf.word2vec_path):
                # No cached embedding model on disk — train one first.
                train_sentences = TRAINING_FILEPATH
                train_embeddings(save_to_path=conf.word2vec_path,
                                 embedding_size=conf.word_embedding_size,
                                 minimal_frequency=conf.word2vec_min_word_freq,
                                 train_tuples_path=train_sentences,
                                 validation_path=None,
                                 num_workers=conf.word2vec_workers_count)
            print("Loading word2vec embeddings")
            load_embedding(sess,
                           get_or_create_vocabulary(),
                           model.embedding_matrix,
                           conf.word2vec_path,
                           conf.word_embedding_size,
                           conf.vocabulary_size)

        # Freeze the graph so any accidental op creation inside the training
        # loop fails fast instead of silently leaking memory.
        sess.graph.finalize()

        print("Starting training")
        for i in range(conf.num_epochs):
            print("Training epoch {}".format(i))
            for data_batch, data_sentence_lengths, label_inputs_batch, label_targets_batch, label_sentence_lengths \
                    in tqdm(bucket_by_sequence_length(enc_inputs, dec_inputs, conf.batch_size),
                            total=ceil(len(enc_inputs) / conf.batch_size)):
                feed_dict = model.make_train_inputs(data_batch, data_sentence_lengths,
                                                    label_inputs_batch, label_targets_batch,
                                                    label_sentence_lengths)

                # Collect a full execution trace every conf.trace_frequency
                # steps (hoisted: the modulo was computed twice before).
                is_trace_step = (global_step % conf.trace_frequency == 0)
                run_options = None
                run_metadata = None
                if is_trace_step:
                    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()

                _, train_summary = sess.run([model.train_op, model.summary_op], feed_dict,
                                            options=run_options, run_metadata=run_metadata)
                if is_trace_step:
                    train_writer.add_run_metadata(run_metadata, "step{}".format(global_step))
                train_writer.add_summary(train_summary, global_step)

                if global_step % conf.validation_summary_frequency == 0:
                    # Randomly choose a batch from the validation dataset and use it for loss calculation
                    vali_data_batch, vali_data_sentence_lengths, vali_label_inputs_batch, vali_label_targets_batch, vali_label_sentence_lengths = choice(validation_data)
                    validation_feed_dict = model.make_train_inputs(vali_data_batch, vali_data_sentence_lengths,
                                                                   vali_label_inputs_batch, vali_label_targets_batch,
                                                                   vali_label_sentence_lengths, keep_prob=1.0)
                    validation_summary = sess.run(model.validation_summary_op, validation_feed_dict)
                    validation_writer.add_summary(validation_summary, global_step)

                if global_step % conf.checkpoint_frequency == 0:
                    saver.save(sess,
                               os.path.join(train_logfolderPath,
                                            "{}{}-{}-ep{}.ckpt".format(experiment, tag_string, timestamp, i)),
                               global_step=global_step)
                global_step += 1

        saver.save(sess, os.path.join(train_logfolderPath, "{}{}-{}-ep{}-final.ckpt".format(experiment, tag_string, timestamp, conf.num_epochs)))
        print("Done with training for {} epochs".format(conf.num_epochs))
if __name__ == "__main__":
    # Entry point: forward the CLI arguments (minus the program name) to the driver.
    mainFunc(sys.argv[1:])