diff --git a/pytext/data/squad_for_bert_tensorizer.py b/pytext/data/squad_for_bert_tensorizer.py index 44eda3330..c271be021 100644 --- a/pytext/data/squad_for_bert_tensorizer.py +++ b/pytext/data/squad_for_bert_tensorizer.py @@ -62,20 +62,8 @@ def _lookup_tokens(self, text: str, seq_len: int = None): max_seq_len=seq_len if seq_len else self.max_seq_len, ) - def numberize(self, row): - question_column, doc_column = self.columns - doc_tokens, start_idx, end_idx = self._lookup_tokens(row[doc_column]) - question_tokens, _, _ = self._lookup_tokens(row[question_column]) - question_tokens = [self.vocab.get_bos_index()] + question_tokens - seq_lens = (len(question_tokens), len(doc_tokens)) - segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens)) - tokens = list(itertools.chain(question_tokens, doc_tokens)) - segment_labels = list(itertools.chain(*segment_labels)) - seq_len = len(tokens) - positions = list(range(seq_len)) - + def _calculate_answer_indices(self, row, offset, start_idx, end_idx): # now map original answer spans to tokenized spans - offset = len(question_tokens) start_idx_map = {} end_idx_map = {} for tokenized_idx, (raw_start_idx, raw_end_idx) in enumerate( @@ -98,6 +86,25 @@ def numberize(self, row): answer_start_indices = [self.SPAN_PAD_IDX] answer_end_indices = [self.SPAN_PAD_IDX] + return answer_start_indices, answer_end_indices + + def numberize(self, row): + question_column, doc_column = self.columns + doc_tokens, start_idx, end_idx = self._lookup_tokens(row[doc_column]) + question_tokens, _, _ = self._lookup_tokens(row[question_column]) + question_tokens = [self.vocab.get_bos_index()] + question_tokens + seq_lens = (len(question_tokens), len(doc_tokens)) + segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens)) + tokens = list(itertools.chain(question_tokens, doc_tokens)) + segment_labels = list(itertools.chain(*segment_labels)) + seq_len = len(tokens) + positions = list(range(seq_len)) + + # now map original answer spans to tokenized spans + offset = len(question_tokens) + answer_start_indices, answer_end_indices = self._calculate_answer_indices( + row, offset, start_idx, end_idx + ) return ( tokens, segment_labels,