diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 76d64d1e6c390..164f5f2f950d5 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -152,6 +152,8 @@ def nested_xla_mesh_reduce(tensors, name): if isinstance(tensors, (list, tuple)): return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors)) + if tensors.ndim == 0: + tensors = tensors[None] return xm.mesh_reduce(name, tensors, torch.cat) else: raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")