Skip to content

Commit

Permalink
informative error for bad LSTM input size (#502)
Browse files Browse the repository at this point in the history
* informative error for bad LSTM input size

* fix unit test
  • Loading branch information
msperber committed Aug 3, 2018
1 parent e0f1048 commit 2505ec5
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
2 changes: 1 addition & 1 deletion test/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_segmenting(self):
def test_reload_exception(self):
with self.assertRaises(ValueError) as context:
run.main(["test/config/reload_exception.yaml"])
- self.assertEqual(str(context.exception), 'VanillaLSTMGates: x_t has inconsistent dimension')
+ self.assertEqual(str(context.exception), 'VanillaLSTMGates: x_t has inconsistent dimension 20, expecting 40')

def test_report(self):
run.main(["test/config/report.yaml"])
Expand Down
12 changes: 11 additions & 1 deletion xnmt/transducers/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,19 @@ def transduce(self, expr_seq: 'expression_seqs.ExpressionSequence') -> 'expressi
x_t = [x_t]
elif type(x_t) != list:
x_t = list(x_t)
+ if sum([x_t_i.dim()[0][0] for x_t_i in x_t]) != self.input_dim:
+ found_dim = sum([x_t_i.dim()[0][0] for x_t_i in x_t])
+ raise ValueError(f"VanillaLSTMGates: x_t has inconsistent dimension {found_dim}, expecting {self.input_dim}")
if self.dropout_rate > 0.0 and self.train:
# apply dropout according to https://arxiv.org/abs/1512.05287 (tied weights)
- gates_t = dy.vanilla_lstm_gates_dropout_concat(x_t, h[-1], self.Wx[layer_i], self.Wh[layer_i], self.b[layer_i], self.dropout_mask_x[layer_i], self.dropout_mask_h[layer_i], self.weightnoise_std if self.train else 0.0)
+ gates_t = dy.vanilla_lstm_gates_dropout_concat(x_t,
+ h[-1],
+ self.Wx[layer_i],
+ self.Wh[layer_i],
+ self.b[layer_i],
+ self.dropout_mask_x[layer_i],
+ self.dropout_mask_h[layer_i],
+ self.weightnoise_std if self.train else 0.0)
else:
gates_t = dy.vanilla_lstm_gates_concat(x_t, h[-1], self.Wx[layer_i], self.Wh[layer_i], self.b[layer_i], self.weightnoise_std if self.train else 0.0)
c_t = dy.vanilla_lstm_c(c[-1], gates_t)
Expand Down

0 comments on commit 2505ec5

Please sign in to comment.