diff --git a/seq2seq/tasks/decode_text.py b/seq2seq/tasks/decode_text.py
index eb8d710e..25a0ce5d 100644
--- a/seq2seq/tasks/decode_text.py
+++ b/seq2seq/tasks/decode_text.py
@@ -146,6 +146,8 @@ def before_run(self, _run_context):
 
     if "attention_scores" in self._predictions:
       fetches["attention_scores"] = self._predictions["attention_scores"]
+    elif "beam_search_output.original_outputs.attention_scores" in self._predictions:
+      fetches["beam_search_output.original_outputs.attention_scores"] = self._predictions["beam_search_output.original_outputs.attention_scores"]
 
     return tf.train.SessionRunArgs(fetches)
 
@@ -169,7 +171,10 @@ def after_run(self, _run_context, run_values):
       if self._unk_replace_fn is not None:
         # We slice the attention scores so that we do not
         # accidentially replace UNK with a SEQUENCE_END token
-        attention_scores = fetches["attention_scores"]
+        if "beam_search_output.original_outputs.attention_scores" in fetches:
+          attention_scores = fetches["beam_search_output.original_outputs.attention_scores"][:,0,:]
+        else:
+          attention_scores = fetches["attention_scores"]
         attention_scores = attention_scores[:, :source_len - 1]
         predicted_tokens = self._unk_replace_fn(
             source_tokens=source_tokens,
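
Note on the change: when beam search is enabled, the attention scores are only available under the "beam_search_output.original_outputs.attention_scores" key, so the patch fetches that key as a fallback and keeps only the first beam via [:, 0, :] so the downstream UNK-replacement code sees the same 2-D layout as before. The following standalone sketch (not part of the patch) illustrates that slicing; the [target_len, beam_width, source_len] layout and "beam 0 is the best hypothesis" are assumptions inferred from the slice above, not confirmed elsewhere in this diff.

    # Sketch only: shows the effect of the [:, 0, :] slice on an assumed
    # [target_len, beam_width, source_len] attention tensor.
    import numpy as np

    target_len, beam_width, source_len = 5, 4, 7
    # Stand-in for fetches["beam_search_output.original_outputs.attention_scores"]
    beam_attention = np.random.rand(target_len, beam_width, source_len)

    # Keep only the first (assumed best) beam, yielding the same
    # [target_len, source_len] shape as the plain "attention_scores" tensor.
    attention_scores = beam_attention[:, 0, :]

    # Mirror the existing slice in after_run: drop the final source position
    # so UNK replacement never picks the SEQUENCE_END token.
    attention_scores = attention_scores[:, :source_len - 1]
    print(attention_scores.shape)  # (5, 6)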