diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index bf05563e9b36..45d546a3997e 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2630,7 +2630,7 @@ def _inner_training_loop(
                 # Since we perform prefetching, we need to manually set sync_gradients
                 self.accelerator.gradient_state._set_sync_gradients(do_sync_step)
 
-                if self.args.include_num_input_tokens_seen:
+                if self.args.include_num_input_tokens_seen not in ["no", False]:
                     main_input_name = getattr(self.model, "main_input_name", "input_ids")
                     if main_input_name not in inputs:
                         logger.warning(
@@ -2639,7 +2639,25 @@ def _inner_training_loop(
                             "a `main_input_name` attribute to the model class you are using."
                         )
                     else:
-                        input_tokens = inputs[main_input_name].numel()
+                        if self.args.include_num_input_tokens_seen == "non_padding":
+                            if "attention_mask" in inputs:
+                                input_tokens = inputs["attention_mask"].sum()
+                            elif (
+                                self.processing_class is not None
+                                and hasattr(self.processing_class, "pad_token_id")
+                                and self.processing_class.pad_token_id is not None
+                            ):
+                                input_tokens = (
+                                    inputs[main_input_name] != self.processing_class.pad_token_id
+                                ).sum()
+                            else:
+                                logger.warning(
+                                    "Could not determine method to count non-padding tokens, falling back to counting all tokens."
+                                )
+                                input_tokens = inputs[main_input_name].numel()
+                        else:
+                            input_tokens = inputs[main_input_name].numel()
+
                         input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
                         self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item()
                 if rng_to_sync:
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 2337edc93b33..a2bdba93d1a6 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1507,10 +1507,14 @@ class TrainingArguments:
         metadata={"help": "If set to `True`, the speed metrics will include `tgs` (tokens per second per device)."},
     )
 
-    include_num_input_tokens_seen: Optional[bool] = field(
+    include_num_input_tokens_seen: Optional[Union[str, bool]] = field(
         default=False,
         metadata={
-            "help": "If set to `True`, will track the number of input tokens seen throughout training. (May be slower in distributed training)"
+            "help": (
+                "Whether to track the number of input tokens seen. "
+                "Can be `'all'` to count all tokens, `'non_padding'` to count only non-padding tokens, "
+                "or a boolean (`True` maps to `'all'`, `False` to `'no'`)."
+            )
         },
     )
 
@@ -2143,6 +2147,11 @@ def __post_init__(self):
             )
             self.include_for_metrics.append("inputs")
 
+        if self.include_num_input_tokens_seen is True:
+            self.include_num_input_tokens_seen = "all"
+        elif self.include_num_input_tokens_seen is False:
+            self.include_num_input_tokens_seen = "no"
+
     def __str__(self):
         self_as_dict = asdict(self)
 
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 683c76032dd0..e2534687512f 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1297,6 +1297,104 @@ def test_tf32(self):
         trainer.train()
         self.check_trained_model(trainer.model)
 
+    def test_include_num_input_tokens_seen(self):
+        model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
+        tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
+        tokenizer.pad_token = "[PAD]"
+        model.config.pad_token_id = tokenizer.pad_token_id
+
+        sentences = ["This is a short sentence.", "This is a much longer sentence that will require padding."]
+        labels = torch.tensor([0, 1])
+
+        # 1. Test with attention_mask
+        tokenized_dataset_with_mask = tokenizer(sentences, truncation=True, padding="longest", return_tensors="pt")
+        tokenized_dataset_with_mask["labels"] = labels
+        dataset_with_mask = datasets.Dataset.from_dict(tokenized_dataset_with_mask)
+
+        # 2. Test without attention_mask
+        tokenized_dataset_no_mask = {k: v for k, v in tokenized_dataset_with_mask.items() if k != "attention_mask"}
+        dataset_no_mask = datasets.Dataset.from_dict(tokenized_dataset_no_mask)
+
+        # 3. Test with no padding information
+        tokenizer_no_pad = AutoTokenizer.from_pretrained("bert-base-cased")
+        tokenizer_no_pad.pad_token = None
+
+        data_collator = default_data_collator
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Test case 1: "non_padding" with attention_mask
+            args = TrainingArguments(
+                output_dir=tmp_dir,
+                include_num_input_tokens_seen="non_padding",
+                per_device_train_batch_size=2,
+                max_steps=1,
+                report_to="none",
+            )
+            trainer = Trainer(
+                model=model,
+                args=args,
+                train_dataset=dataset_with_mask,
+                data_collator=data_collator,
+                processing_class=tokenizer,
+            )
+            trainer.train()
+            attention_mask = tokenized_dataset_with_mask["attention_mask"]
+            non_padded_tokens_with_mask = attention_mask.sum().item()
+            self.assertEqual(trainer.state.num_input_tokens_seen, non_padded_tokens_with_mask)
+
+            # Test case 2: "non_padding" without attention_mask (fallback to pad_token_id)
+            trainer = Trainer(
+                model=model,
+                args=args,
+                train_dataset=dataset_no_mask,
+                data_collator=data_collator,
+                processing_class=tokenizer,
+            )
+            trainer.train()
+            input_ids = tokenized_dataset_with_mask["input_ids"]  # use original to compute expected
+            non_padded_tokens_no_mask = (input_ids != tokenizer.pad_token_id).sum().item()
+            self.assertEqual(trainer.state.num_input_tokens_seen, non_padded_tokens_no_mask)
+
+            # Test case 3: "non_padding" with no padding info (fallback to numel)
+            with self.assertLogs("transformers.trainer", level="WARNING") as cm:
+                trainer = Trainer(
+                    model=model,
+                    args=args,
+                    train_dataset=dataset_no_mask,  # still has input_ids
+                    data_collator=data_collator,
+                    processing_class=tokenizer_no_pad,  # tokenizer without pad token
+                )
+                trainer.train()
+                self.assertTrue(
+                    any("Could not determine method to count non-padding tokens" in log for log in cm.output)
+                )
+            total_tokens = input_ids.numel()
+            self.assertEqual(trainer.state.num_input_tokens_seen, total_tokens)
+
+            # Test case 4: "all"
+            args.include_num_input_tokens_seen = "all"
+            trainer = Trainer(
+                model=model,
+                args=args,
+                train_dataset=dataset_with_mask,
+                data_collator=data_collator,
+                processing_class=tokenizer,
+            )
+            trainer.train()
+            self.assertEqual(trainer.state.num_input_tokens_seen, total_tokens)
+
+            # Test case 5: True (backward compatibility)
+            args.include_num_input_tokens_seen = True
+            trainer = Trainer(
+                model=model,
+                args=args,
+                train_dataset=dataset_with_mask,
+                data_collator=data_collator,
+                processing_class=tokenizer,
+            )
+            trainer.train()
+            self.assertEqual(trainer.state.num_input_tokens_seen, total_tokens)
+
 
 @require_torch
 @require_sentencepiece
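
For context, a minimal usage sketch of the option introduced above, mirroring the test: the toy sentences, output directory, and training hyperparameters are illustrative placeholders and not part of this PR.

```python
import datasets
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    default_data_collator,
)

# Illustrative toy data: two sentences of different lengths so padding is required.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

batch = tokenizer(
    ["A short sentence.", "A much longer sentence that forces padding of the short one."],
    padding="longest",
    return_tensors="pt",
)
batch["labels"] = torch.tensor([0, 1])
train_dataset = datasets.Dataset.from_dict(batch)

args = TrainingArguments(
    output_dir="token-count-demo",
    # New in this PR: count only non-padding tokens; `True` / "all" keeps the old
    # behaviour of counting every token, `False` / "no" disables tracking.
    include_num_input_tokens_seen="non_padding",
    per_device_train_batch_size=2,
    max_steps=1,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=default_data_collator,
    processing_class=tokenizer,
)
trainer.train()

# With an attention_mask present, this equals batch["attention_mask"].sum(),
# gathered across all processes.
print(trainer.state.num_input_tokens_seen)
```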