src/transformers/trainer.py: 22 changes (20 additions, 2 deletions)

@@ -2630,7 +2630,7 @@ def _inner_training_loop(
                 # Since we perform prefetching, we need to manually set sync_gradients
                 self.accelerator.gradient_state._set_sync_gradients(do_sync_step)

-                if self.args.include_num_input_tokens_seen:
+                if self.args.include_num_input_tokens_seen not in ["no", False]:
                     main_input_name = getattr(self.model, "main_input_name", "input_ids")
                     if main_input_name not in inputs:
                         logger.warning(
@@ -2639,7 +2639,25 @@ def _inner_training_loop(
                             "a `main_input_name` attribute to the model class you are using."
                         )
                     else:
-                        input_tokens = inputs[main_input_name].numel()
+                        if self.args.include_num_input_tokens_seen == "non_padding":
+                            if "attention_mask" in inputs:
+                                input_tokens = inputs["attention_mask"].sum()
+                            elif (
+                                self.processing_class is not None
+                                and hasattr(self.processing_class, "pad_token_id")
+                                and self.processing_class.pad_token_id is not None
+                            ):
+                                input_tokens = (
+                                    inputs[main_input_name] != self.processing_class.pad_token_id
+                                ).sum()
+                            else:
+                                logger.warning(
+                                    "Could not determine method to count non-padding tokens, falling back to counting all tokens."
+                                )
+                                input_tokens = inputs[main_input_name].numel()
+                        else:
+                            input_tokens = inputs[main_input_name].numel()
+
                         input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
                         self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item()
                 if rng_to_sync:
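For reference, here is a minimal, self-contained sketch of the counting rule the hunk above implements, pulled out of the Trainer loop. The helper name `count_input_tokens`, its arguments, and the toy batch are illustrative only and are not part of this PR.

import torch

def count_input_tokens(inputs, mode="all", main_input_name="input_ids", pad_token_id=None):
    # Mirrors the logic in the diff above: "non_padding" prefers the attention mask,
    # then falls back to comparing against a pad token id, then to counting everything.
    if mode == "non_padding":
        if "attention_mask" in inputs:
            return int(inputs["attention_mask"].sum())
        if pad_token_id is not None:
            return int((inputs[main_input_name] != pad_token_id).sum())
        # No padding information available: count all tokens.
        return inputs[main_input_name].numel()
    # "all" (or the legacy True): every element of the main input counts.
    return inputs[main_input_name].numel()

batch = {
    "input_ids": torch.tensor([[101, 7592, 102, 0, 0], [101, 7592, 2088, 999, 102]]),
    "attention_mask": torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]),
}
print(count_input_tokens(batch, mode="all"))          # 10
print(count_input_tokens(batch, mode="non_padding"))  # 8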
src/transformers/training_args.py: 13 changes (11 additions, 2 deletions)

@@ -1507,10 +1507,14 @@ class TrainingArguments:
         metadata={"help": "If set to `True`, the speed metrics will include `tgs` (tokens per second per device)."},
     )

-    include_num_input_tokens_seen: Optional[bool] = field(
+    include_num_input_tokens_seen: Optional[Union[str, bool]] = field(
         default=False,
         metadata={
-            "help": "If set to `True`, will track the number of input tokens seen throughout training. (May be slower in distributed training)"
+            "help": (
+                "Whether to track the number of input tokens seen. "
+                "Can be `'all'` to count all tokens, `'non_padding'` to count only non-padding tokens, "
+                "or a boolean (`True` maps to `'all'`, `False` to `'no'`)."
+            )
         },
     )

@@ -2143,6 +2147,11 @@ def __post_init__(self):
             )
             self.include_for_metrics.append("inputs")

+        if self.include_num_input_tokens_seen is True:
+            self.include_num_input_tokens_seen = "all"
+        elif self.include_num_input_tokens_seen is False:
+            self.include_num_input_tokens_seen = "no"
+
     def __str__(self):
         self_as_dict = asdict(self)
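As a quick usage sketch (assuming a transformers build that includes this change; `output_dir` and the other values are placeholders): the new argument accepts the string values directly, and `__post_init__` normalizes the legacy booleans.

from transformers import TrainingArguments

# Count only non-padding tokens during training.
args = TrainingArguments(
    output_dir="tmp_out",
    include_num_input_tokens_seen="non_padding",
    report_to="none",
)
print(args.include_num_input_tokens_seen)  # "non_padding"

# Legacy boolean still works: True is normalized to "all", False to "no".
legacy = TrainingArguments(output_dir="tmp_out", include_num_input_tokens_seen=True)
print(legacy.include_num_input_tokens_seen)  # "all"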
tests/trainer/test_trainer.py: 98 changes (98 additions, 0 deletions)

@@ -1297,6 +1297,104 @@ def test_tf32(self):
trainer.train()
self.check_trained_model(trainer.model)

def test_include_num_input_tokens_seen(self):
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer.pad_token = "[PAD]"
model.config.pad_token_id = tokenizer.pad_token_id

sentences = ["This is a short sentence.", "This is a much longer sentence that will require padding."]
labels = torch.tensor([0, 1])

# 1. Test with attention_mask
tokenized_dataset_with_mask = tokenizer(sentences, truncation=True, padding="longest", return_tensors="pt")
tokenized_dataset_with_mask["labels"] = labels
dataset_with_mask = datasets.Dataset.from_dict(tokenized_dataset_with_mask)

# 2. Test without attention_mask
tokenized_dataset_no_mask = {k: v for k, v in tokenized_dataset_with_mask.items() if k != "attention_mask"}
dataset_no_mask = datasets.Dataset.from_dict(tokenized_dataset_no_mask)

# 3. Test with no padding information
tokenizer_no_pad = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer_no_pad.pad_token = None

data_collator = default_data_collator

with tempfile.TemporaryDirectory() as tmp_dir:
# Test case 1: "non_padding" with attention_mask
args = TrainingArguments(
output_dir=tmp_dir,
include_num_input_tokens_seen="non_padding",
per_device_train_batch_size=2,
max_steps=1,
report_to="none",
)
trainer = Trainer(
model=model,
args=args,
train_dataset=dataset_with_mask,
data_collator=data_collator,
processing_class=tokenizer,
)
trainer.train()
attention_mask = tokenized_dataset_with_mask["attention_mask"]
non_padded_tokens_with_mask = attention_mask.sum().item()
self.assertEqual(trainer.state.num_input_tokens_seen, non_padded_tokens_with_mask)

# Test case 2: "non_padding" without attention_mask (fallback to pad_token_id)
trainer = Trainer(
model=model,
args=args,
train_dataset=dataset_no_mask,
data_collator=data_collator,
processing_class=tokenizer,
)
trainer.train()
input_ids = tokenized_dataset_with_mask["input_ids"] # use original to compute expected
non_padded_tokens_no_mask = (input_ids != tokenizer.pad_token_id).sum().item()
self.assertEqual(trainer.state.num_input_tokens_seen, non_padded_tokens_no_mask)

# Test case 3: "non_padding" with no padding info (fallback to numel)
with self.assertLogs("transformers.trainer", level="WARNING") as cm:
trainer = Trainer(
model=model,
args=args,
train_dataset=dataset_no_mask, # still has input_ids
data_collator=data_collator,
processing_class=tokenizer_no_pad, # tokenizer without pad token
)
trainer.train()
self.assertTrue(
any("Could not determine method to count non-padding tokens" in log for log in cm.output)
)
total_tokens = input_ids.numel()
self.assertEqual(trainer.state.num_input_tokens_seen, total_tokens)

# Test case 4: "all"
args.include_num_input_tokens_seen = "all"
trainer = Trainer(
model=model,
args=args,
train_dataset=dataset_with_mask,
data_collator=data_collator,
processing_class=tokenizer,
)
trainer.train()
self.assertEqual(trainer.state.num_input_tokens_seen, total_tokens)

# Test case 5: True (backward compatibility)
args.include_num_input_tokens_seen = True
trainer = Trainer(
model=model,
args=args,
train_dataset=dataset_with_mask,
data_collator=data_collator,
processing_class=tokenizer,
)
trainer.train()
self.assertEqual(trainer.state.num_input_tokens_seen, total_tokens)


@require_torch
@require_sentencepiece