Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
zheyuye committed Jul 29, 2020
1 parent 9623240 commit 995e5d7
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/gluonnlp/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def incremental_decode(self, F, data, states, mem, mem_valid_length, mem_attn_ma
step_value = F.npx.reshape(step_value, (-2, -2, self._num_heads, -1))
new_key = F.np.concatenate([prev_key, step_key], axis=time_axis)
new_value = F.np.concatenate([prev_value, step_value], axis=time_axis)
out, _ = self.self_attention(step_query, new_key, new_value, None)
out, [_, attn_weight] = self.self_attention(step_query, new_key, new_value, None)
out = self.proj_in(out)
out = self.dropout_layer(out)
out = out + data
Expand Down Expand Up @@ -1209,7 +1209,7 @@ def decode_seq(self, F, tgt_data, tgt_valid_length, mem_data, mem_valid_length):
tgt_data = tgt_data + F.np.expand_dims(self.tgt_pos_embed_layer(
F.npx.arange_like(tgt_data, axis=0)), axis=1)
if self.layernorm_embedding:
tgt_data = self.src_embed_ln(tgt_data)
tgt_data = self.tgt_embed_ln(tgt_data)
dec_out = self.decoder(tgt_data, tgt_valid_length, mem_data, mem_valid_length)
return dec_out

Expand Down

0 comments on commit 995e5d7

Please sign in to comment.