Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
Implement SquadQA tensorizer in TorchScript (#1211)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1211

Implement SquadQA tensorizer in TorchScript

Differential Revision: D18801535

fbshipit-source-id: 1cc929424721eda444dcc20cda3e093b06d114e5
  • Loading branch information
chenyangyu1988 authored and facebook-github-bot committed Dec 20, 2019
1 parent 7ac6144 commit 5f5c328
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions pytext/data/squad_for_bert_tensorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,8 @@ def _lookup_tokens(self, text: str, seq_len: int = None):
max_seq_len=seq_len if seq_len else self.max_seq_len,
)

def numberize(self, row):
question_column, doc_column = self.columns
doc_tokens, start_idx, end_idx = self._lookup_tokens(row[doc_column])
question_tokens, _, _ = self._lookup_tokens(row[question_column])
question_tokens = [self.vocab.get_bos_index()] + question_tokens
seq_lens = (len(question_tokens), len(doc_tokens))
segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
tokens = list(itertools.chain(question_tokens, doc_tokens))
segment_labels = list(itertools.chain(*segment_labels))
seq_len = len(tokens)
positions = list(range(seq_len))

def _calculate_answer_indices(self, row, offset, start_idx, end_idx):
# now map original answer spans to tokenized spans
offset = len(question_tokens)
start_idx_map = {}
end_idx_map = {}
for tokenized_idx, (raw_start_idx, raw_end_idx) in enumerate(
Expand All @@ -98,6 +86,25 @@ def numberize(self, row):
answer_start_indices = [self.SPAN_PAD_IDX]
answer_end_indices = [self.SPAN_PAD_IDX]

return answer_start_indices, answer_end_indices

def numberize(self, row):
question_column, doc_column = self.columns
doc_tokens, start_idx, end_idx = self._lookup_tokens(row[doc_column])
question_tokens, _, _ = self._lookup_tokens(row[question_column])
question_tokens = [self.vocab.get_bos_index()] + question_tokens
seq_lens = (len(question_tokens), len(doc_tokens))
segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
tokens = list(itertools.chain(question_tokens, doc_tokens))
segment_labels = list(itertools.chain(*segment_labels))
seq_len = len(tokens)
positions = list(range(seq_len))

# now map original answer spans to tokenized spans
offset = len(question_tokens)
answer_start_indices, answer_end_indices = self._calculate_answer_indices(
row, offset, start_idx, end_idx
)
return (
tokens,
segment_labels,
Expand Down

0 comments on commit 5f5c328

Please sign in to comment.