diff --git a/horovod/_keras/__init__.py b/horovod/_keras/__init__.py
index 49e2cadc8d..8a35862a7a 100644
--- a/horovod/_keras/__init__.py
+++ b/horovod/_keras/__init__.py
@@ -28,7 +28,8 @@
 def create_distributed_optimizer(keras, optimizer, name, device_dense, device_sparse,
                                  compression, sparse_as_dense, gradient_predivide_factor,
                                  op, backward_passes_per_step=1,
-                                 average_aggregated_gradients=False):
+                                 average_aggregated_gradients=False,
+                                 num_groups=0):
     class _DistributedOptimizer(keras.optimizers.Optimizer):
         _HAS_AGGREGATE_GRAD = True
 
@@ -43,7 +44,8 @@ def __init__(self, **kwargs):
                 compression,
                 sparse_as_dense,
                 op,
-                gradient_predivide_factor)
+                gradient_predivide_factor,
+                num_groups)
 
             self._agg_helper = None
             if backward_passes_per_step > 1:
diff --git a/horovod/keras/__init__.py b/horovod/keras/__init__.py
index cdd9210fe7..b88a8fa13d 100644
--- a/horovod/keras/__init__.py
+++ b/horovod/keras/__init__.py
@@ -38,7 +38,8 @@ def DistributedOptimizer(optimizer, name=None,
                          compression=Compression.none,
                          sparse_as_dense=False,
                          gradient_predivide_factor=1.0,
-                         op=Average):
+                         op=Average,
+                         num_groups=0):
     """
     An optimizer that wraps another keras.optimizers.Optimizer, using an allreduce to
     average gradient values before applying gradients to model weights.
@@ -65,6 +66,8 @@ def DistributedOptimizer(optimizer, name=None,
                                    gradient_predivide_factor / size after the sum.
         op: The reduction operation to use when combining gradients across
             different ranks. Defaults to Average.
+        num_groups: Number of groups to assign gradient allreduce ops to for explicit
+                    grouping. Defaults to no explicit groups.
     """
     if gradient_predivide_factor != 1.0 and rocm_built():
         raise ValueError('gradient_predivide_factor not supported yet with ROCm')
@@ -82,6 +85,7 @@ def DistributedOptimizer(optimizer, name=None,
         sparse_as_dense=sparse_as_dense,
         gradient_predivide_factor=gradient_predivide_factor,
         op=op,
+        num_groups=num_groups,
     )
 
 
diff --git a/test/parallel/test_torch.py b/test/parallel/test_torch.py
index 93d28975e7..5027969fd3 100644
--- a/test/parallel/test_torch.py
+++ b/test/parallel/test_torch.py
@@ -644,6 +644,7 @@ def test_horovod_grouped_allreduce_average(self):
             tensors = [torch.FloatTensor(*([17] * dim)).random_(-100, 100) for _ in range(5)]
             tensors = [self.cast_and_place(tensor, dtype) for tensor in tensors]
             averaged = hvd.grouped_allreduce(tensors, average=True)
+            tensors, averaged = zip(*[self.convert_cpu_fp16_to_fp32(t, m) for t, m in zip(tensors, averaged)])
 
             # Threshold for floating point equality depends on number of
             # ranks, since we're comparing against precise multiplication.
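
For context, a minimal usage sketch of the new argument (not part of the diff; the SGD optimizer and the value num_groups=2 below are arbitrary choices for illustration). With num_groups > 0 the wrapper assigns the per-tensor gradient allreduce ops to that many explicit groups, while the default of 0 leaves grouping to Horovod's usual behavior.

# Illustrative sketch only -- model definition and training loop omitted.
import keras
import horovod.keras as hvd

hvd.init()

# Scale the learning rate by the number of workers, per the standard Horovod recipe.
opt = keras.optimizers.SGD(lr=0.01 * hvd.size())

# num_groups=2 requests two explicit allreduce groups; num_groups=0 (default) means no explicit grouping.
opt = hvd.DistributedOptimizer(opt, num_groups=2)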