Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
Implement SquadQA tensorizer in TorchScript
Browse files: browse the repository at this point in the history
Summary: Implement SquadQA tensorizer in TorchScript

Differential Revision: D18801535

fbshipit-source-id: 0951edff93f22bc51e8d46240db3c1e43a9f9b6c
  • Loading branch information
chenyangyu1988 authored and facebook-github-bot committed Dec 19, 2019
1 parent 6cab6b8 commit b489794
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions pytext/data/squad_for_bert_tensorizer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -62,20 +62,8 @@ def _lookup_tokens(self, text: str, seq_len: int = None):
max_seq_len=seq_len if seq_len else self.max_seq_len,
)

def numberize(self, row):
question_column, doc_column = self.columns
doc_tokens, start_idx, end_idx = self._lookup_tokens(row[doc_column])
question_tokens, _, _ = self._lookup_tokens(row[question_column])
question_tokens = [self.vocab.get_bos_index()] + question_tokens
seq_lens = (len(question_tokens), len(doc_tokens))
segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
tokens = list(itertools.chain(question_tokens, doc_tokens))
segment_labels = list(itertools.chain(*segment_labels))
seq_len = len(tokens)
positions = list(range(seq_len))

def _calculate_answer_indices(self, row, offset, start_idx, end_idx):
# now map original answer spans to tokenized spans
offset = len(question_tokens)
start_idx_map = {}
end_idx_map = {}
for tokenized_idx, (raw_start_idx, raw_end_idx) in enumerate(
Expand All @@ -98,6 +86,25 @@ def numberize(self, row):
answer_start_indices = [self.SPAN_PAD_IDX]
answer_end_indices = [self.SPAN_PAD_IDX]

return answer_start_indices, answer_end_indices

def numberize(self, row):
question_column, doc_column = self.columns
doc_tokens, start_idx, end_idx = self._lookup_tokens(row[doc_column])
question_tokens, _, _ = self._lookup_tokens(row[question_column])
question_tokens = [self.vocab.get_bos_index()] + question_tokens
seq_lens = (len(question_tokens), len(doc_tokens))
segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
tokens = list(itertools.chain(question_tokens, doc_tokens))
segment_labels = list(itertools.chain(*segment_labels))
seq_len = len(tokens)
positions = list(range(seq_len))

# now map original answer spans to tokenized spans
offset = len(question_tokens)
answer_start_indices, answer_end_indices = self._calculate_answer_indices(
row, offset, start_idx, end_idx
)
return (
tokens,
segment_labels,
Expand Down

0 comments on commit b489794

Please sign in to comment.