From 6b8791844103296ff29099cc5e2b9b306d5e37cc Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 30 Sep 2021 11:32:40 -0400 Subject: [PATCH] Fix gather for TPU (#13813) --- src/transformers/trainer_pt_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 76d64d1e6c390..164f5f2f950d5 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -152,6 +152,8 @@ def nested_xla_mesh_reduce(tensors, name): if isinstance(tensors, (list, tuple)): return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors)) + if tensors.ndim == 0: + tensors = tensors[None] return xm.mesh_reduce(name, tensors, torch.cat) else: raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")