From b8c85bbe5099730867773b480350706eb0b08cef Mon Sep 17 00:00:00 2001 From: ZheyuYe Date: Mon, 6 Jul 2020 12:28:56 +0800 Subject: [PATCH] fix ModelForQABasic --- scripts/question_answering/models.py | 56 +++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/scripts/question_answering/models.py b/scripts/question_answering/models.py index 7458024840..7786c98172 100644 --- a/scripts/question_answering/models.py +++ b/scripts/question_answering/models.py @@ -58,13 +58,67 @@ def hybrid_forward(self, F, tokens, token_types, valid_length, p_mask): contextual_embeddings = self.backbone(tokens, token_types, valid_length) else: contextual_embeddings = self.backbone(tokens, valid_length) - scores = self.qa_outputs(contextual_embedding) + scores = self.qa_outputs(contextual_embeddings) start_scores = scores[:, :, 0] end_scores = scores[:, :, 1] start_logits = masked_logsoftmax(F, start_scores, mask=p_mask, axis=-1) end_logits = masked_logsoftmax(F, end_scores, mask=p_mask, axis=-1) return start_logits, end_logits + def inference(self, tokens, token_types, valid_length, p_mask, + start_top_n: int = 5, end_top_n: int = 5): + """Get the inference result with beam search + + Parameters + ---------- + tokens + The input tokens. Shape (batch_size, sequence_length) + token_types + The input token types. Shape (batch_size, sequence_length) + valid_length + The valid length of the tokens. Shape (batch_size,) + p_mask + The mask which indicates that some tokens won't be used in the calculation. + Shape (batch_size, sequence_length) + start_top_n + The number of candidates to select for the start position. + end_top_n + The number of candidates to select for the end position. + + Returns + ------- + start_top_logits + The top start logits + Shape (batch_size, start_top_n) + start_top_index + Index of the top start logits + Shape (batch_size, start_top_n) + end_top_logits + The top end logits. + Shape (batch_size, end_top_n) + end_top_index + Index of the top end logits + Shape (batch_size, end_top_n) + """ + # Shape (batch_size, sequence_length, C) + if self.use_segmentation: + contextual_embeddings = self.backbone(tokens, token_types, valid_length) + else: + contextual_embeddings = self.backbone(tokens, valid_length) + scores = self.qa_outputs(contextual_embeddings) + start_scores = scores[:, :, 0] + end_scores = scores[:, :, 1] + start_logits = masked_logsoftmax(mx.nd, start_scores, mask=p_mask, axis=-1) + end_logits = masked_logsoftmax(mx.nd, end_scores, mask=p_mask, axis=-1) + # The shape of start_top_index will be (..., start_top_n) + start_top_logits, start_top_index = mx.npx.topk(start_logits, k=start_top_n, axis=-1, + ret_typ='both') + # Note that end_top_index and end_top_log_probs have shape (bsz, start_n_top, end_n_top) + # So that for each start position, there are end_n_top end positions on the third dim. + end_top_logits, end_top_index = mx.npx.topk(end_logits, k=end_top_n, axis=-1, + ret_typ='both') + return start_top_logits, start_top_index, end_top_logits, end_top_index + @use_np class ModelForQAConditionalV1(HybridBlock):