Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

minimal fixes to run DataCollatorForWholeWordMask with return_tensors="np" and return_tensors="tf" #13891

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions src/transformers/data/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,14 +871,14 @@ def tf_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict
inputs, labels = self.tf_mask_tokens(batch_input, batch_mask)
return {"input_ids": inputs, "labels": labels}

def np_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
def numpy_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
if isinstance(examples[0], (dict, BatchEncoding)):
input_ids = [e["input_ids"] for e in examples]
else:
input_ids = examples
examples = [{"input_ids": e} for e in examples]

batch_input = _tf_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
batch_input = _numpy_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)

mask_labels = []
for e in examples:
Expand Down Expand Up @@ -996,15 +996,15 @@ def tf_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
)
labels = inputs.clone()
labels = tf.identity(inputs)
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)

masked_indices = tf.cast(mask_labels, tf.bool)

special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels
]
masked_indices = masked_indices & ~tf.convert_to_tensor(special_tokens_mask, dtype=tf.bool)
masked_indices = masked_indices & ~tf.cast(special_tokens_mask, dtype=tf.bool)
Rocketknight1 marked this conversation as resolved.
Show resolved Hide resolved
if self.tokenizer._pad_token is not None:
padding_mask = inputs == self.tokenizer.pad_token_id
masked_indices = masked_indices & ~padding_mask
Expand Down Expand Up @@ -1060,9 +1060,7 @@ def numpy_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
indices_random = (
np.random.binomial(1, 0.5, size=labels.shape).astype(np.bool) & masked_indices & ~indices_replaced
)
random_words = np.random.randint(
low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
)
random_words = np.random.randint(low=0, high=len(self.tokenizer), size=labels.shape, dtype=np.int64)
inputs[indices_random] = random_words[indices_random]

# The rest of the time (10% of the time) we keep the masked input tokens unchanged
Expand Down
31 changes: 31 additions & 0 deletions tests/test_data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
DataCollatorForLanguageModeling,
DataCollatorForPermutationLanguageModeling,
DataCollatorForTokenClassification,
DataCollatorForWholeWordMask,
DataCollatorWithPadding,
default_data_collator,
is_tf_available,
Expand Down Expand Up @@ -224,6 +225,16 @@ def test_data_collator_for_language_modeling(self):
pad_features = [list(range(5)), list(range(10))]
self._test_no_pad_and_pad(no_pad_features, pad_features)

def test_data_collator_for_whole_word_mask(self):
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]

tokenizer = BertTokenizer(self.vocab_file)
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="pt")
batch = data_collator(features)

self.assertEqual(batch["input_ids"].shape, torch.Size((2, 10)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 10)))

def test_plm(self):
tokenizer = BertTokenizer(self.vocab_file)
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
Expand Down Expand Up @@ -488,6 +499,16 @@ def test_data_collator_for_language_modeling(self):
pad_features = [list(range(5)), list(range(10))]
self._test_no_pad_and_pad(no_pad_features, pad_features)

def test_data_collator_for_whole_word_mask(self):
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]

tokenizer = BertTokenizer(self.vocab_file)
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="tf")
batch = data_collator(features)

self.assertEqual(batch["input_ids"].shape.as_list(), [2, 10])
self.assertEqual(batch["labels"].shape.as_list(), [2, 10])

def test_plm(self):
tokenizer = BertTokenizer(self.vocab_file)
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
Expand Down Expand Up @@ -750,6 +771,16 @@ def test_data_collator_for_language_modeling(self):
pad_features = [list(range(5)), list(range(10))]
self._test_no_pad_and_pad(no_pad_features, pad_features)

def test_data_collator_for_whole_word_mask(self):
features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]

tokenizer = BertTokenizer(self.vocab_file)
data_collator = DataCollatorForWholeWordMask(tokenizer, return_tensors="np")
batch = data_collator(features)

self.assertEqual(batch["input_ids"].shape, (2, 10))
self.assertEqual(batch["labels"].shape, (2, 10))

def test_plm(self):
tokenizer = BertTokenizer(self.vocab_file)
no_pad_features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]
Expand Down