Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
Modify Return Signature of TorchScript BERT
Browse files Browse the repository at this point in the history
Summary:
* Return actual answer text instead of spans. (Blank means no answer, no need for exception handling in caller.)
* Return answer confidence score.
* Return has_answer score.

Differential Revision: D17983996

fbshipit-source-id: fc3084681e9d1b0453b8e4a28eeb1293dcb4d541
  • Loading branch information
Debojeet Chatterjee authored and facebook-github-bot committed Oct 17, 2019
1 parent 8f93ce1 commit a7cdb3a
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions pytext/models/output_layers/squad_output_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_position_preds(
max_span_length: int,
):
# the following is to enforce end_pos > start_pos. We create a matrix
# of start_positions X end_positions, fill it with the sum logits,
# of start_position X end_position, fill it with the sum logits,
# then mask it to be upper-triangular
# e.g. start_pos_logits = [1, 3, 0, 5, 2]
# end_pos_logits = [2, 4, 6, 3, 5]
Expand Down Expand Up @@ -94,14 +94,10 @@ def get_position_preds(
for i in range(logit_sum_matrix.size()[1]):
logit_sum_matrix[:, i, i + max_span_length :] = 0
vals, ids = logit_sum_matrix.max(-1)
_, start_positions = vals.max(-1)
end_positions = ids.gather(-1, start_positions.unsqueeze(-1)).squeeze(-1)
_, start_position = vals.max(-1)
end_position = ids.gather(-1, start_position.unsqueeze(-1)).squeeze(-1)

return (
start_positions,
end_positions,
logit_sum_matrix[0, start_positions, end_positions],
)
return start_position, end_position

def get_pred(
self,
Expand All @@ -110,7 +106,7 @@ def get_pred(
contexts: Dict[str, List[Any]],
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
start_pos_logits, end_pos_logits, has_answer_logits = logits
start_pos_preds, end_pos_preds, _ = self.get_position_preds(
start_pos_preds, end_pos_preds = self.get_position_preds(
start_pos_logits, end_pos_logits, self.max_answer_len
)
has_answer_preds = has_answer_logits.argmax(-1)
Expand Down

0 comments on commit a7cdb3a

Please sign in to comment.