diff --git a/pytext/models/output_layers/squad_output_layer.py b/pytext/models/output_layers/squad_output_layer.py index 864b26c6d..20fd48dc7 100644 --- a/pytext/models/output_layers/squad_output_layer.py +++ b/pytext/models/output_layers/squad_output_layer.py @@ -64,7 +64,7 @@ def get_position_preds( max_span_length: int, ): # the following is to enforce end_pos > start_pos. We create a matrix - # of start_positions X end_positions, fill it with the sum logits, + # of start_position X end_position, fill it with the sum logits, # then mask it to be upper-triangular # e.g. start_pos_logits = [1, 3, 0, 5, 2] # end_pos_logits = [2, 4, 6, 3, 5] @@ -94,14 +94,10 @@ def get_position_preds( for i in range(logit_sum_matrix.size()[1]): logit_sum_matrix[:, i, i + max_span_length :] = 0 vals, ids = logit_sum_matrix.max(-1) - _, start_positions = vals.max(-1) - end_positions = ids.gather(-1, start_positions.unsqueeze(-1)).squeeze(-1) + _, start_position = vals.max(-1) + end_position = ids.gather(-1, start_position.unsqueeze(-1)).squeeze(-1) - return ( - start_positions, - end_positions, - logit_sum_matrix[0, start_positions, end_positions], - ) + return start_position, end_position def get_pred( self, @@ -110,7 +106,7 @@ def get_pred( contexts: Dict[str, List[Any]], ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: start_pos_logits, end_pos_logits, has_answer_logits = logits - start_pos_preds, end_pos_preds, _ = self.get_position_preds( + start_pos_preds, end_pos_preds = self.get_position_preds( start_pos_logits, end_pos_logits, self.max_answer_len ) has_answer_preds = has_answer_logits.argmax(-1)