Commit

Fix hvd.barrier() tensor queue management and torch test failures from op name mismatches (#3300)

Signed-off-by: Max H. Gerlach <git@maxgerlach.de>
maxhgerlach committed Dec 9, 2021
1 parent 5af1e22 commit be3b72d
Showing 2 changed files with 90 additions and 57 deletions.
6 changes: 0 additions & 6 deletions horovod/common/controller.cc

@@ -273,12 +273,6 @@ ResponseList Controller::ComputeResponseList(bool this_process_requested_shutdown
 
     bool reduce = IncrementTensorCount(message, process_set.joined_size);
 
-    // For barrier request, if not ready to reduce, we add it back to tensor queue
-    // to process in the next cycle.
-    if(!reduce && message.request_type() == Request::BARRIER) {
-      tensor_queue_.PushMessageToQueue(message);
-    }
-
     stall_inspector_.RecordUncachedTensorStart(
         message.tensor_name(), message.request_rank(), size_);
     if (reduce) {
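For context, a minimal sketch of how the barrier collective handled by this code path is typically invoked from PyTorch. This is an assumed usage example based on the hvd.barrier() API named in the commit title, not code from the commit:

# Hedged sketch: every rank calls hvd.barrier() and blocks until all ranks
# in the process set have reached the call.
import horovod.torch as hvd

hvd.init()
print(f"rank {hvd.rank()} reached the barrier")
hvd.barrier()   # returns only once every rank has arrived
print(f"rank {hvd.rank()} passed the barrier")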
141 changes: 90 additions & 51 deletions test/parallel/test_torch.py

@@ -607,6 +607,7 @@ def test_horovod_allreduce_duplicate_name_error(self):
         two concurrent operations with the same name."""
         hvd.init()
         size = hvd.size()
+        rank = hvd.rank()
 
         # This test does not apply if there is only one worker.
         if size == 1:
@@ -615,13 +616,22 @@ def test_horovod_allreduce_duplicate_name_error(self):
         dims = [17] * 3
         tensor = torch.FloatTensor(*dims)
 
-        hvd.allreduce_async(tensor, name='duplicate_name')
-        try:
-            hvd.allreduce_async(tensor, name='duplicate_name')
-            assert False, 'hvd.allreduce_async did not throw error'
-        except (torch.FatalError, ValueError):
-            pass
+        for i in range(10):
+            if rank == 0:
+                hvd.allreduce_async(tensor, name='duplicate_name')
+                try:
+                    hvd.allreduce_async(tensor, name='duplicate_name')
+                    assert False, 'hvd.allreduce_async did not throw error'
+                except (torch.FatalError, ValueError):
+                    pass
+            hvd.allreduce(torch.FloatTensor([1]), name="synch1")
+            if rank > 0:
+                hvd.allreduce_async(tensor, name='duplicate_name')
+                try:
+                    hvd.allreduce_async(tensor, name='duplicate_name')
+                    assert False, 'hvd.allreduce_async did not throw error'
+                except (torch.FatalError, ValueError):
+                    pass
+            hvd.allreduce(torch.FloatTensor([2]), name="synch2")
 
     def test_horovod_allreduce_grad(self):
         """Test the correctness of the allreduce gradient."""
@@ -1213,6 +1223,7 @@ def test_horovod_allgather_duplicate_name_error(self):
         two concurrent operations with the same name."""
         hvd.init()
         size = hvd.size()
+        rank = hvd.rank()
 
         # This test does not apply if there is only one worker.
         if size == 1:
@@ -1221,13 +1232,22 @@ def test_horovod_allgather_duplicate_name_error(self):
         dims = [17] * 3
         tensor = torch.FloatTensor(*dims)
 
-        hvd.allgather_async(tensor, name='duplicate_name')
-        try:
-            hvd.allgather_async(tensor, name='duplicate_name')
-            assert False, 'hvd.allgather_async did not throw error'
-        except (torch.FatalError, ValueError):
-            pass
+        for i in range(10):
+            if rank == 0:
+                hvd.allgather_async(tensor, name='duplicate_name')
+                try:
+                    hvd.allgather_async(tensor, name='duplicate_name')
+                    assert False, 'hvd.allgather_async did not throw error'
+                except (torch.FatalError, ValueError):
+                    pass
+            hvd.allreduce(torch.FloatTensor([1]), name="synch1")
+            if rank > 0:
+                hvd.allgather_async(tensor, name='duplicate_name')
+                try:
+                    hvd.allgather_async(tensor, name='duplicate_name')
+                    assert False, 'hvd.allgather_async did not throw error'
+                except (torch.FatalError, ValueError):
+                    pass
+            hvd.allreduce(torch.FloatTensor([2]), name="synch2")
 
     def test_horovod_allgather_grad(self):
         """Test the correctness of the allgather gradient."""
@@ -1523,6 +1543,7 @@ def test_horovod_broadcast_duplicate_name_error(self):
         two concurrent operations with the same name."""
         hvd.init()
         size = hvd.size()
+        rank = hvd.rank()
 
         # This test does not apply if there is only one worker.
         if size == 1:
@@ -1531,13 +1552,22 @@ def test_horovod_broadcast_duplicate_name_error(self):
         dims = [17] * 3
         tensor = torch.FloatTensor(*dims)
 
-        hvd.broadcast_async(tensor, root_rank=0, name='duplicate_name')
-        try:
-            hvd.broadcast_async(tensor, root_rank=0, name='duplicate_name')
-            assert False, 'hvd.broadcast_async did not throw error'
-        except (torch.FatalError, ValueError):
-            pass
+        for i in range(10):
+            if rank == 0:
+                hvd.broadcast_async(tensor, name='duplicate_name', root_rank=0)
+                try:
+                    hvd.broadcast_async(tensor, name='duplicate_name', root_rank=0)
+                    assert False, 'hvd.broadcast_async did not throw error'
+                except (torch.FatalError, ValueError):
+                    pass
+            hvd.allreduce(torch.FloatTensor([1]), name="synch1")
+            if rank > 0:
+                hvd.broadcast_async(tensor, name='duplicate_name', root_rank=0)
+                try:
+                    hvd.broadcast_async(tensor, name='duplicate_name', root_rank=0)
+                    assert False, 'hvd.broadcast_async did not throw error'
+                except (torch.FatalError, ValueError):
+                    pass
+            hvd.allreduce(torch.FloatTensor([2]), name="synch2")
 
     def test_horovod_broadcast_grad(self):
         """Test the correctness of the broadcast gradient."""
@@ -2743,7 +2773,7 @@ def test_horovod_join_allreduce(self):
         integral_types = [torch.IntTensor, torch.LongTensor, torch.cuda.IntTensor, torch.cuda.LongTensor]
 
         dims = [1, 2, 3]
-        first_join_ranks = [0, 1]
+        first_join_ranks = list(range(size))
         cachings = [False, True]
         for dtype, dim, first_join_rank, caching in itertools.product(dtypes, dims, first_join_ranks, cachings):
             torch.manual_seed(1234)
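The Join tests now iterate over every possible first-joining rank instead of hard-coding it; the allgather and broadcast variants below restructure their bodies accordingly. For reference, a minimal sketch of the Join pattern they exercise; the op name and tensor are illustrative, and the assertions in the tests check only that the return value is the same on every rank and is not the first rank to join:

# Hedged sketch: one rank runs out of work and joins early; the others keep
# reducing and then join as well. CPU case shown; on GPU the tests pass
# hvd.local_rank() as the device argument to hvd.join().
import torch
import horovod.torch as hvd

hvd.init()
if hvd.rank() == 0:
    ret = hvd.join()                              # join without further work
else:
    hvd.allreduce(torch.ones(3), name='step')     # still succeeds after rank 0 joins
    ret = hvd.join()
print(f"hvd.join() returned {ret} on rank {hvd.rank()}")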
@@ -2814,25 +2844,31 @@ def test_horovod_join_allgather(self):
         dims = [17] * 3
         tensor = torch.FloatTensor(*dims)
 
-        if rank == 0:
-            if torch.cuda.is_available():
-                ret = hvd.join(hvd.local_rank())
-            else:
-                ret = hvd.join()
-        else:
-            try:
-                hvd.allgather(tensor)
-                assert False, 'hvd.allgather did not throw error'
-            except (torch.FatalError, RuntimeError):
-                pass
-
-            ret = hvd.join(hvd.local_rank())
-
-        self.assertNotEqual(ret, 0,
-                            msg="The return value of hvd.join() may not be equal to 0 because that would be the first rank to join")
-        ret_values = hvd.allgather_object(ret)
-        self.assertSequenceEqual(ret_values, [ret] * size,
-                                 msg="hvd.join() did not return the same value on each rank")
+        first_join_ranks = list(range(size))
+
+        for first_join_rank in first_join_ranks:
+            if rank == first_join_rank:
+                if torch.cuda.is_available():
+                    ret = hvd.join(hvd.local_rank())
+                else:
+                    ret = hvd.join()
+            else:
+                try:
+                    hvd.allgather(tensor)
+                    assert False, 'hvd.allgather did not throw error'
+                except (torch.FatalError, RuntimeError):
+                    pass
+
+                if torch.cuda.is_available():
+                    ret = hvd.join(hvd.local_rank())
+                else:
+                    ret = hvd.join()
+
+            self.assertNotEqual(ret, first_join_rank,
+                                msg="The return value of hvd.join() may not be equal to first_join_rank")
+            ret_values = hvd.allgather_object(ret)
+            self.assertSequenceEqual(ret_values, [ret] * size,
+                                     msg="hvd.join() did not return the same value on each rank")
 
     def test_horovod_join_broadcast(self):
         """Test Join op with broadcast."""
@@ -2847,25 +2883,28 @@ def test_horovod_join_broadcast(self):
         dims = [17] * 3
         tensor = torch.FloatTensor(*dims)
 
-        if rank == 0:
-            ret = hvd.join(hvd.local_rank())
-        else:
-            try:
-                broadcasted_tensor = hvd.broadcast(tensor, 1, name="test_horovod_join_broadcast")
-                assert False, 'hvd.broadcast did not throw error'
-            except (torch.FatalError, RuntimeError):
-                pass
-
-            if torch.cuda.is_available():
-                ret = hvd.join(hvd.local_rank())
-            else:
-                ret = hvd.join()
-
-        self.assertNotEqual(ret, 0,
-                            msg="The return value of hvd.join() may not be equal to 0 because that would be the first rank to join")
-        ret_values = hvd.allgather_object(ret)
-        self.assertSequenceEqual(ret_values, [ret] * size,
-                                 msg="hvd.join() did not return the same value on each rank")
+        first_join_ranks = list(range(size))
+
+        for first_join_rank in first_join_ranks:
+            if rank == first_join_rank:
+                ret = hvd.join(hvd.local_rank())
+            else:
+                try:
+                    broadcasted_tensor = hvd.broadcast(tensor, rank, name="test_horovod_join_broadcast")
+                    assert False, 'hvd.broadcast did not throw error'
+                except (torch.FatalError, RuntimeError):
+                    pass
+
+                if torch.cuda.is_available():
+                    ret = hvd.join(hvd.local_rank())
+                else:
+                    ret = hvd.join()
+
+            self.assertNotEqual(ret, first_join_rank,
+                                msg="The return value of hvd.join() may not be equal to first_join_rank")
+            ret_values = hvd.allgather_object(ret)
+            self.assertSequenceEqual(ret_values, [ret] * size,
+                                     msg="hvd.join() did not return the same value on each rank")
 
     def test_horovod_sync_batch_norm(self):
         """Tests Horovod version of SyncBatchNorm."""
