#!/usr/bin/env python
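"""Baseline nn for snli (forked from matpalm/snli_nn).

Embeds the two sentences of an snli pair, runs an rnn (SimpleRnn or GruRnn,
optionally bidirectional) over each, concats the final hidden states and
applies an MLP + softmax over the three labels.

Example invocation (all flags are defined below; data paths default to the
snli_1.0 jsonl files under data/):

  ./nn_baseline.py --rnn-type GruRnn --bidirectional --tied-embeddings \
      --hidden-dim 100 --update-fn rmsprop
"""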
import argparse
from concat_with_softmax import ConcatWithSoftmax
from dropout import APPLY_DROPOUT, NO_DROPOUT
from embeddings import Embeddings, TiedEmbeddings
from gru_rnn import GruRnn
import itertools
import json
import numpy as np
import os
import random
from simple_rnn import SimpleRnn
from sklearn.metrics import confusion_matrix
from stats import Stats
import sys
import time
import theano
import theano.tensor as T
import tokenise_parse
import util
from updates import *
from vocab import Vocab
parser = argparse.ArgumentParser()
parser.add_argument("--train-set", default="data/snli_1.0_train.jsonl")
parser.add_argument("--num-from-train", default=-1, type=int,
help='number of egs to read from train. -1 => all')
parser.add_argument("--dev-set", default="data/snli_1.0_dev.jsonl")
parser.add_argument("--num-from-dev", default=-1, type=int,
help='number of egs to read from dev. -1 => all')
parser.add_argument("--dev-run-freq", default=100000, type=int,
help='frequency (in num examples trained) to run against dev set')
parser.add_argument("--num-epochs", default=-1, type=int,
help='number of epoches to run. -1 => forever')
parser.add_argument("--max-run-time-sec", default=-1, type=int,
help='max secs to run before early stopping. -1 => dont early stop')
parser.add_argument('--learning-rate', default=0.01, type=float, help='learning rate')
parser.add_argument('--momentum', default=0., type=float,
help='momentum (when applicable)')
parser.add_argument('--update-fn', default='vanilla',
help='vanilla (sgd), momentum or rmsprop. not applied to embeddings')
parser.add_argument('--hidden-dim', default=50, type=int,
help='hidden node dimensionality')
parser.add_argument('--bidirectional', action='store_true',
help='whether to build bidirectional rnns for s1 & s2')
parser.add_argument('--embedding-dim', default=100, type=int,
help='embedding node dimensionality')
parser.add_argument('--tied-embeddings', action='store_true',
help='whether to tie embeddings for each rnn')
parser.add_argument('--initial-embeddings',
help='initial embeddings npy file. for now only applicable if'
' --tied-embeddings. requires --vocab-file')
parser.add_argument('--vocab-file',
help='vocab (token -> idx) for embeddings,'
' required if using --initial-embeddings')
parser.add_argument('--l2-penalty', default=0.0001, type=float,
help='l2 penalty for params')
parser.add_argument('--rnn-type', default="SimpleRnn",
help='rnn cell type {SimpleRnn,GruRnn}')
parser.add_argument('--gru-initial-bias', default=2, type=int,
help='initial gru bias for r & z. higher => more like SimpleRnn')
parser.add_argument('--swap-symmetric-examples', action='store_true',
help='if set we flip s1/s2 for symmetric labels (contra or neutral')
parser.add_argument('--dump-norms', action='store_true',
help='dump l2 norms of all params with stats')
parser.add_argument('--keep-prob', default=1.0, type=float,
help='post concat, pre MLP, dropout keep probability. 1.0 => noop')
parser.add_argument('--parse-mode', default='BINARY_WITHOUT_PARENTHESIS',
help='what parse type to use; BINARY_WITHOUT_PARENTHESIS'
'| BINARY_WITH_PARENTHESIS | PARSE_WITH_OPEN_CLOSE_TAGS')
opts = parser.parse_args()
print >>sys.stderr, opts
# check that --vocab-file and --initial-embeddings are either both set or both unset.
assert not ((opts.vocab_file is None) ^ (opts.initial_embeddings is None)), \
    "must set both --vocab-file and --initial-embeddings (or neither)"
# furthermore these are only valid with tied embeddings (at least for now,
# that's all that's implemented)
if opts.vocab_file and not opts.tied_embeddings:
    raise Exception("must set --tied-embeddings if using pre initialised embeddings")
# sanity check other opts
assert 0.0 <= opts.keep_prob <= 1.0
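# snli sentence pairs are labelled entailment, neutral or contradiction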
NUM_LABELS = 3
def log(s):
    print >>sys.stderr, util.dts(), s
# slurp training data, including conversion of tokens -> ids.
# if opts.vocab_file is set, read the vocab from that file; otherwise populate
# lookups as tokens are first seen.
vocab = Vocab(opts.vocab_file)
train_x, train_y, train_stats = util.load_data(opts.train_set, vocab,
                                               update_vocab=True,
                                               max_egs=int(opts.num_from_train),
                                               parse_mode=opts.parse_mode)
log("train_stats %s %s" % (len(train_x), train_stats))
dev_x, dev_y, dev_stats = util.load_data(opts.dev_set, vocab,
                                         update_vocab=False,
                                         max_egs=int(opts.num_from_dev),
                                         parse_mode=opts.parse_mode)
log("dev_stats %s %s" % (len(dev_x), dev_stats))
# input/output example vars
s1_idxs = T.ivector('s1') # sequence for sentence one
s2_idxs = T.ivector('s2') # sequence for sentence two
actual_y = T.ivector('y') # single element vector for the sentence pair label; 0, 1 or 2
# dropout keep prob for post concat, pre MLP
apply_dropout = T.bscalar('apply_dropout') # dropout.{APPLY_DROPOUT|NO_DROPOUT}
keep_prob = theano.shared(opts.keep_prob) # recall 1.0 => noop
keep_prob = T.cast(keep_prob, 'float32') # cast down; shared() stores a python float as float64
# keep track of different "layers" that handle their own gradients.
# includes rnns, final concat & softmax and, potentially, special handling for
# tied embeddings
layers = []
# decide the set of sequence idxs we'll be processing. there will always be
# two for the forward passes over s1 and s2 and, optionally, two more for the
# reverse passes over s1 & s2 in the bidirectional case.
idxs = [s1_idxs, s2_idxs]
names = ["s1f", "s2f"]
if opts.bidirectional:
    idxs.extend([s1_idxs[::-1], s2_idxs[::-1]])
    names.extend(["s1b", "s2b"])
# build embedding layers. we know we will build an rnn for each sequence idx but depending
# on whether we are using tied embeddings there will be either 1 global embedding matrix
# (whose gradients are managed by TiedEmbeddings) or there will be 1 embeddings matrix per
# rnn (whose gradients are managed by the rnn itself). we build one embedding obj per
# element in idxs
embeddings = None
def build_embedding(idxs=None, sequence_embeddings=None):
    return Embeddings(vocab.size(), opts.embedding_dim, idxs=idxs,
                      sequence_embeddings=sequence_embeddings)
if opts.tied_embeddings:
    # make shared tied embeddings helper
    tied_embeddings = TiedEmbeddings(vocab.size(), opts.embedding_dim,
                                     opts.initial_embeddings)
    layers.append(tied_embeddings)
    # one embeddings object per idx slice; the rnns don't maintain their own
    # embeddings in this case.
    slices = tied_embeddings.slices_for_idxs(idxs)
    embeddings = [build_embedding(sequence_embeddings=s) for s in slices]
else:
    # no tied embeddings; each rnn manages its own embedding matrix
    embeddings = [build_embedding(idxs=i) for i in idxs]
    layers.extend(embeddings)
# build rnns over these embedded sequences
h0 = theano.shared(np.zeros(opts.hidden_dim, dtype='float32'), name='h0', borrow=True)
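# look the rnn cell class and update function up by name; both must be among
# the imports above (SimpleRnn / GruRnn, and the fns from updates).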
rnn_fn = globals().get(opts.rnn_type)
if rnn_fn is None:
    raise Exception("unknown rnn type [%s]" % opts.rnn_type)
update_fn = globals().get(opts.update_fn)
if update_fn is None:
    raise Exception("unknown update function [%s]" % opts.update_fn)
rnns = [rnn_fn("", opts.embedding_dim, opts.hidden_dim, opts, update_fn, h0,
               inputs=e.embeddings()) for e in embeddings]
# concat final states of rnns, do a final linear combo and apply softmax for prediction.
final_rnn_states = [rnn.final_state() for rnn in rnns]
concat_with_softmax = ConcatWithSoftmax(final_rnn_states, NUM_LABELS, opts.hidden_dim,
                                        update_fn, apply_dropout, keep_prob)
layers.append(concat_with_softmax)
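# prob_y: distribution over the 3 labels (fed to the xent cost below);
# pred_y: the hard prediction used for the dev confusion matrix.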
prob_y, pred_y = concat_with_softmax.prob_pred()
# calc l2_sum across all params
params = [l.params_for_l2_penalty() for l in layers]
l2_sum = sum([(p**2).sum() for p in itertools.chain(*params)])
# calculate cost: xent + l2 penalty
cross_entropy_cost = T.mean(T.nnet.categorical_crossentropy(prob_y, actual_y))
l2_cost = opts.l2_penalty * l2_sum
total_cost = cross_entropy_cost + l2_cost
# calculate updates
updates = []
for layer in layers:
    updates.extend(layer.updates_wrt_cost(total_cost, opts))
log("compiling")
train_fn = theano.function(inputs=[apply_dropout, s1_idxs, s2_idxs, actual_y],
                           outputs=[total_cost],
                           updates=updates)
test_fn = theano.function(inputs=[apply_dropout, s1_idxs, s2_idxs, actual_y],
                          outputs=[pred_y, total_cost])
def stats_from_dev_set(stats):
    actuals = []
    predicteds = []
    for (s1, s2), y in zip(dev_x, dev_y):
        pred_y, cost = test_fn(NO_DROPOUT, s1, s2, [y])
        actuals.append(y)
        predicteds.append(pred_y)
        stats.record_dev_cost(cost)
    dev_c = confusion_matrix(actuals, predicteds)
    dev_accuracy = util.accuracy(dev_c)
    stats.set_dev_accuracy(dev_accuracy)
    print "dev confusion\n %s (%s)" % (dev_c, dev_accuracy)
log("training")
epoch = 0
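# wall clock deadline; only enforced when --max-run-time-sec != -1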
training_early_stop_time = opts.max_run_time_sec + time.time()
stats = Stats(os.path.basename(__file__), opts)
egs = zip(train_x, train_y)
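# note: examples are processed one at a time (no minibatching); under python2
# zip gives a list, so egs can be shuffled in place each epoch.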
while epoch != opts.num_epochs:
    random.shuffle(egs)
    for (s1, s2), y in egs:
        # we may choose to swap s1/s2 for symmetric examples; i.e. contradictions
        # and neutral statements.
        flip_s1_s2 = opts.swap_symmetric_examples and util.coin_flip() and \
                     util.symmetric_example(y)
        if flip_s1_s2:
            cost, = train_fn(APPLY_DROPOUT, s2, s1, [y])
        else:
            cost, = train_fn(APPLY_DROPOUT, s1, s2, [y])
        stats.record_training_cost(cost)
        early_stop = False
        if opts.max_run_time_sec != -1 and time.time() > training_early_stop_time:
            early_stop = True
        if stats.n_egs_trained % opts.dev_run_freq == 0 or early_stop:
            stats_from_dev_set(stats)
            if opts.dump_norms:
                stats.set_param_norms(util.norms(layers))
            stats.flush_to_stdout(epoch)
        if early_stop:
            exit(0)
    epoch += 1