From 6b03e1c3391720ad5a7dba9faf4000434fade957 Mon Sep 17 00:00:00 2001 From: Sam Sharpe Date: Tue, 9 Sep 2025 13:32:10 -0400 Subject: [PATCH] add device --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f3edd080e464..1f6efe60966e 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -5614,7 +5614,7 @@ def get_batch_samples( if num_items_in_batch is not None: if self.args.average_tokens_across_devices: - num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum() + num_items_in_batch = self.accelerator.gather(num_items_in_batch.to(device)).sum() if torch.is_tensor(num_items_in_batch): num_items_in_batch = num_items_in_batch.to(device)