This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Implement SquadQA tensorizer in TorchScript #1211

Closed
33 changes: 20 additions & 13 deletions pytext/data/squad_for_bert_tensorizer.py
@@ -62,20 +62,8 @@ def _lookup_tokens(self, text: str, seq_len: int = None):
             max_seq_len=seq_len if seq_len else self.max_seq_len,
         )
 
-    def numberize(self, row):
-        question_column, doc_column = self.columns
-        doc_tokens, start_idx, end_idx = self._lookup_tokens(row[doc_column])
-        question_tokens, _, _ = self._lookup_tokens(row[question_column])
-        question_tokens = [self.vocab.get_bos_index()] + question_tokens
-        seq_lens = (len(question_tokens), len(doc_tokens))
-        segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
-        tokens = list(itertools.chain(question_tokens, doc_tokens))
-        segment_labels = list(itertools.chain(*segment_labels))
-        seq_len = len(tokens)
-        positions = list(range(seq_len))
-
+    def _calculate_answer_indices(self, row, offset, start_idx, end_idx):
         # now map original answer spans to tokenized spans
-        offset = len(question_tokens)
         start_idx_map = {}
         end_idx_map = {}
         for tokenized_idx, (raw_start_idx, raw_end_idx) in enumerate(
@@ -98,6 +86,25 @@ def numberize(self, row):
             answer_start_indices = [self.SPAN_PAD_IDX]
             answer_end_indices = [self.SPAN_PAD_IDX]
 
+        return answer_start_indices, answer_end_indices
+
+    def numberize(self, row):
+        question_column, doc_column = self.columns
+        doc_tokens, start_idx, end_idx = self._lookup_tokens(row[doc_column])
+        question_tokens, _, _ = self._lookup_tokens(row[question_column])
+        question_tokens = [self.vocab.get_bos_index()] + question_tokens
+        seq_lens = (len(question_tokens), len(doc_tokens))
+        segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
+        tokens = list(itertools.chain(question_tokens, doc_tokens))
+        segment_labels = list(itertools.chain(*segment_labels))
+        seq_len = len(tokens)
+        positions = list(range(seq_len))
+
+        # now map original answer spans to tokenized spans
+        offset = len(question_tokens)
+        answer_start_indices, answer_end_indices = self._calculate_answer_indices(
+            row, offset, start_idx, end_idx
+        )
         return (
             tokens,
             segment_labels,
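For readers skimming the diff: the change extracts the answer-span mapping out of numberize into a new _calculate_answer_indices method, with the question length passed in as offset, leaving numberize to build the combined token sequence. Below is a minimal, self-contained sketch of the resulting flow. All token ids and character spans are invented for illustration, and the dict-based span mapping is a simplification of what _lookup_tokens actually returns; this is not PyText code.

import itertools

# Hypothetical token ids; real values come from _lookup_tokens and the vocab.
question_tokens = [101, 2054, 2003]               # e.g. [BOS] "what" "is"
doc_tokens = [1996, 4248, 4419]                   # e.g. "the" "quick" "fox"
doc_token_char_spans = [(0, 3), (4, 9), (10, 13)] # char span of each doc token

# Build the combined sequence the way the new numberize() does.
seq_lens = (len(question_tokens), len(doc_tokens))
segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
tokens = list(itertools.chain(question_tokens, doc_tokens))
segment_labels = list(itertools.chain(*segment_labels))
positions = list(range(len(tokens)))

# Shift raw character-level answer spans into the combined sequence,
# mirroring the offset handed to _calculate_answer_indices.
offset = len(question_tokens)
start_idx_map = {s: offset + i for i, (s, _) in enumerate(doc_token_char_spans)}
end_idx_map = {e: offset + i for i, (_, e) in enumerate(doc_token_char_spans)}

raw_answer_span = (10, 13)                        # character span of "fox"
answer_start = start_idx_map[raw_answer_span[0]]  # -> 5
answer_end = end_idx_map[raw_answer_span[1]]      # -> 5

print(tokens)          # [101, 2054, 2003, 1996, 4248, 4419]
print(segment_labels)  # [0, 0, 0, 1, 1, 1]
print(positions)       # [0, 1, 2, 3, 4, 5]
print(answer_start, answer_end)  # 5 5

With the spans factored out, numberize only concatenates question and document tokens, builds segment labels and positions, and delegates the answer mapping to the new helper, which gives the TorchScript port a single self-contained method to replicate.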