From d429bbee5d0a6b6ff6c214a0ff9545832629c036 Mon Sep 17 00:00:00 2001
From: nem6ishi
Date: Thu, 6 Jul 2017 16:33:26 +0900
Subject: [PATCH] fixed that unk_replacement and beam_search work together

---
 seq2seq/tasks/decode_text.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/seq2seq/tasks/decode_text.py b/seq2seq/tasks/decode_text.py
index eb8d710e..25a0ce5d 100644
--- a/seq2seq/tasks/decode_text.py
+++ b/seq2seq/tasks/decode_text.py
@@ -146,6 +146,8 @@ def before_run(self, _run_context):
 
     if "attention_scores" in self._predictions:
       fetches["attention_scores"] = self._predictions["attention_scores"]
+    elif "beam_search_output.original_outputs.attention_scores" in self._predictions:
+      fetches["beam_search_output.original_outputs.attention_scores"] = self._predictions["beam_search_output.original_outputs.attention_scores"]
 
     return tf.train.SessionRunArgs(fetches)
 
@@ -169,7 +171,10 @@ def after_run(self, _run_context, run_values):
       if self._unk_replace_fn is not None:
         # We slice the attention scores so that we do not
         # accidentially replace UNK with a SEQUENCE_END token
-        attention_scores = fetches["attention_scores"]
+        if "beam_search_output.original_outputs.attention_scores" in fetches:
+          attention_scores = fetches["beam_search_output.original_outputs.attention_scores"][:,0,:]
+        else:
+          attention_scores = fetches["attention_scores"]
         attention_scores = attention_scores[:, :source_len - 1]
         predicted_tokens = self._unk_replace_fn(
             source_tokens=source_tokens,