diff --git a/horovod/torch/mpi_ops.py b/horovod/torch/mpi_ops.py
index 8a177472ca..337396a4cb 100644
--- a/horovod/torch/mpi_ops.py
+++ b/horovod/torch/mpi_ops.py
@@ -518,8 +518,11 @@ def sparse_allreduce_async(tensor, name, op):
     values_handle = allgather_async(t._values(), name=f'{name}.values')
 
     def handle():
-        indices = synchronize(indices_handle)
+        # We need to sync the values handle first for torch nightly >= 1.10.0
+        # Issue: https://github.com/horovod/horovod/issues/2961
         values = synchronize(values_handle)
+        indices = synchronize(indices_handle)
+
         values = (values / size()) if op == Average else values
 
         if indices.dim() == 0 or values.dim() == 0:
diff --git a/test/parallel/test_torch.py b/test/parallel/test_torch.py
index 799af627ad..6822fc6f82 100644
--- a/test/parallel/test_torch.py
+++ b/test/parallel/test_torch.py
@@ -2421,8 +2421,6 @@ def forward(self, x):
         loss.backward()
         opt.step()
 
-    @pytest.mark.skipif(LooseVersion(torch.__version__) > LooseVersion('1.10.0'),
-                        reason='https://github.com/horovod/horovod/issues/2961')
     def test_async_sparse_allreduce(self):
         """Test that allgather over indices and values is equivalent to allreduce."""
         hvd.init()
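
Note on the change: the patch only swaps which handle is synchronized first; both allgathers still complete before the sparse result is rebuilt, so the reordering is semantics-preserving and merely works around the hang described in issue #2961. For context, the re-enabled test asserts that allgathering a sparse tensor's indices and values and then coalescing reproduces a dense allreduce-average. The single-process sketch below illustrates that equivalence with plain PyTorch; the worker count, tensor shapes, and values are invented for illustration and are not taken from the Horovod test itself.

import torch

world_size = 2  # hypothetical worker count (not from the patch)

# Two workers' sparse gradients over a 4x3 embedding table.
grads = [
    torch.sparse_coo_tensor([[0, 2]], torch.ones(2, 3), (4, 3)),
    torch.sparse_coo_tensor([[2, 3]], torch.full((2, 3), 3.0), (4, 3)),
]

# "Allgather" indices and values, as sparse_allreduce_async does across ranks.
indices = torch.cat([g._indices() for g in grads], dim=1)
values = torch.cat([g._values() for g in grads], dim=0) / world_size  # op == Average

# Coalescing sums values at duplicate indices: the reduction step of the
# sparse allreduce. Row 2 appears on both workers and is averaged correctly.
sparse_avg = torch.sparse_coo_tensor(indices, values, (4, 3)).coalesce()

# Dense reference: elementwise average, i.e. what a dense allreduce returns.
dense_avg = torch.stack([g.to_dense() for g in grads]).mean(dim=0)

assert torch.allclose(sparse_avg.to_dense(), dense_avg)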