
Commit

fix token_ids
zheyuye committed Jun 28, 2020
1 parent ff7aae8 commit 5c0ca43
Showing 1 changed file with 30 additions and 22 deletions.
scripts/question_answering/run_squad.py (30 additions, 22 deletions)
@@ -124,6 +124,32 @@ def parse_args():
 
 
 class SquadDatasetProcessor:
+
+    def __init__(self, tokenizer, doc_stride, max_seq_length, max_query_length):
+        """
+        Parameters
+        ----------
+        tokenizer
+            The tokenizer
+        doc_stride
+            The stride used to chunk the document
+        max_seq_length
+            Maximum length of the merged data
+        max_query_length
+            Maximum query length
+        """
+        self._tokenizer = tokenizer
+        self._doc_stride = doc_stride
+        self._max_seq_length = max_seq_length
+        self._max_query_length = max_query_length
+
+        vocab = tokenizer.vocab
+        self.pad_id = vocab.pad_id
+        # For the RoBERTa model, treat the special token <s> as [CLS] and </s> as [SEP]
+        self.cls_id = vocab.bos_id if 'cls_token' not in vocab.special_token_keys else vocab.cls_id
+        self.sep_id = vocab.eos_id if 'sep_token' not in vocab.special_token_keys else vocab.sep_id
+
     # TODO(sxjscience) Consider combining the NamedTuple and batchify functionality.
     ChunkFeature = collections.namedtuple('ChunkFeature',
                                           ['qas_id',
@@ -139,7 +165,7 @@ class SquadDatasetProcessor:
                                            'chunk_length'])
     BatchifyFunction = bf.NamedTuple(ChunkFeature,
                                      {'qas_id': bf.List(),
-                                      'data': bf.Pad(),
+                                      'data': bf.Pad(val=self.pad_id),
                                       'valid_length': bf.Stack(),
                                       'segment_ids': bf.Pad(),
                                       'masks': bf.Pad(val=1),
@@ -150,24 +176,6 @@ class SquadDatasetProcessor:
                                       'chunk_start': bf.Stack(),
                                       'chunk_length': bf.Stack()})
 
-    def __init__(self, tokenizer, doc_stride, max_seq_length, max_query_length):
-        """
-        Parameters
-        ----------
-        tokenizer
-            The tokenizer
-        doc_stride
-            The stride used to chunk the document
-        max_seq_length
-            Maximum length of the merged data
-        max_query_length
-            Maximum query length
-        """
-        self._tokenizer = tokenizer
-        self._doc_stride = doc_stride
-        self._max_seq_length = max_seq_length
-        self._max_query_length = max_query_length
-
     def process_sample(self, feature: SquadFeature):
         """Process the data to the following format.
@@ -218,10 +226,10 @@ def process_sample(self, feature: SquadFeature):
             doc_stride=self._doc_stride,
             max_chunk_length=self._max_seq_length - len(truncated_query_ids) - 3)
         for chunk in chunks:
-            data = np.array([self._tokenizer.vocab.cls_id] + truncated_query_ids +
-                            [self._tokenizer.vocab.sep_id] +
+            data = np.array([self.cls_id] + truncated_query_ids +
+                            [self.sep_id] +
                             feature.context_token_ids[chunk.start:(chunk.start + chunk.length)] +
-                            [self.sep_id], dtype=np.int32)
+                            [self.sep_id], dtype=np.int32)
             valid_length = len(data)
             segment_ids = np.array([0] + [0] * len(truncated_query_ids) +
                                    [0] + [1] * chunk.length + [1], dtype=np.int32)
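
The commit resolves the [CLS]/[SEP]/pad ids once in __init__ instead of reading vocab.cls_id and vocab.sep_id on every call: RoBERTa-style vocabularies do not define cls_token/sep_token, so the code falls back to bos_id (<s>) and eos_id (</s>), and batches are now padded with the vocabulary's actual pad_id rather than bf.Pad()'s default fill value. A minimal sketch of that logic, not the committed code (the helper names and the toy vocab below are hypothetical; only special_token_keys and the id attributes come from the commit):

    import numpy as np
    from types import SimpleNamespace

    def resolve_special_ids(vocab):
        # BERT-style vocabularies define [CLS]/[SEP] directly; RoBERTa-style
        # vocabularies only define <s>/</s> (bos/eos), so fall back to those.
        cls_id = vocab.cls_id if 'cls_token' in vocab.special_token_keys else vocab.bos_id
        sep_id = vocab.sep_id if 'sep_token' in vocab.special_token_keys else vocab.eos_id
        return cls_id, sep_id, vocab.pad_id

    def build_chunk(cls_id, sep_id, query_ids, context_ids):
        # [CLS] query [SEP] context [SEP], mirroring the layout in process_sample.
        return np.array([cls_id] + list(query_ids) + [sep_id] +
                        list(context_ids) + [sep_id], dtype=np.int32)

    # Toy RoBERTa-style vocab: no cls_token/sep_token, so bos/eos ids are used.
    roberta_vocab = SimpleNamespace(special_token_keys=['bos_token', 'eos_token', 'pad_token'],
                                    bos_id=0, eos_id=2, pad_id=1)
    assert resolve_special_ids(roberta_vocab) == (0, 2, 1)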
