Skip to content

Commit

Permalink
Hotfix tweak convolution
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-a committed Jul 18, 2018
1 parent 9f43aa6 commit f85ab03
Showing 1 changed file with 6 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,15 @@ def compose(self, composed_words, sample_size, batch_size):
now_idx += 1
# Composing
outputs = [[[None for _ in range(seq_len[i,j])] for j in range(batch_size)] for i in range(sample_size)]
expr_list = []
for batch, batch_map, batch_word in zip(batches, batch_maps, batch_words):
self.set_words(batch_word)
results = self.transduce(dy.concatenate_to_batch(batch))
results.value()
for idx, (sample_num, batch_num, position) in batch_map.items():
outputs[sample_num][batch_num][position] = dy.pick_batch_elem(results, idx)
expr_list.append(dy.pick_batch_elem(results, idx))
outputs[sample_num][batch_num][position] = expr_list[-1]
dy.forward(expr_list)
return outputs

@handle_xnmt_event
Expand Down Expand Up @@ -125,7 +129,7 @@ def transduce(self, encodings):
pad = dy.zeros((self.embed_dim, self.ngram_size-dim[0][1]))
inp = dy.concatenate([inp, pad], d=1)
dim = inp.dim()
inp = dy.reshape(inp, (1, dim[0][1], dim[0][0]))
inp = dy.reshape(inp, (1, dim[0][1], dim[0][0]), batch_size=dim[1])
encodings = dy.rectify(dy.conv2d_bias(inp, dy.parameter(self.filter), dy.parameter(self.bias), stride=(1, 1), is_valid=True))
return dy.max_dim(dy.max_dim(encodings, d=1), d=0)

Expand Down

0 comments on commit f85ab03

Please sign in to comment.