Skip to content

Commit

Permalink
Update pre/postscale tests. Deal with HOROVOD_MIXED_INSTALL cases.
Browse files: browse the repository at this point in the history.
Signed-off-by: Josh Romero <joshr@nvidia.com>
  • Loading branch information
romerojosh committed Aug 10, 2020
1 parent f701bd4 commit d331573
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 51 deletions.
29 changes: 14 additions & 15 deletions test/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,22 +171,21 @@ def test_horovod_allreduce_prescale(self):
tensor = mx.nd.random.uniform(-100, 100, shape=shapes[dim],
ctx=ctx)
tensor = tensor.astype(dtype)
tensor_np = tensor.asnumpy()
factor = np.random.uniform()
scaled = hvd.allreduce(tensor, average=False, name=str(count),
prescale_factor=factor)

factor = mx.nd.array([factor], dtype='float64', ctx=ctx)
if ctx != mx.cpu():
if ctx != mx.cpu() and not int(os.environ.get('HOROVOD_MIXED_INSTALL', 0)):
# For integer types, scaling done in FP64
factor = factor.astype(dtype if dtype not in int_types else 'float64')
tensor = tensor.astype(dtype if dtype not in int_types else 'float64')
factor = factor.astype('float64' if dtype in int_types else dtype)
tensor = tensor.astype('float64' if dtype in int_types else dtype)
else:
# For integer types, scaling done in FP64, FP32 math for FP16 on CPU
factor = factor.astype(dtype if dtype not in int_types else
'float32' if dtype == 'float16' else 'float64')
tensor = tensor.astype(dtype if dtype not in int_types else
'float32' if dtype == 'float16' else 'float64')
factor = factor.astype('float32' if dtype == 'float16' else
'float64' if dtype in int_types else dtype)
tensor = tensor.astype('float32' if dtype == 'float16' else
'float64' if dtype in int_types else dtype)

expected = factor * tensor
expected = expected.astype(dtype)
Expand Down Expand Up @@ -229,16 +228,16 @@ def test_horovod_allreduce_postscale(self):
postscale_factor=factor)

factor = mx.nd.array([factor], dtype='float64', ctx=ctx)
if ctx != mx.cpu():
if ctx != mx.cpu() and not int(os.environ.get('HOROVOD_MIXED_INSTALL', 0)):
# For integer types, scaling done in FP64
factor = factor.astype(dtype if dtype not in int_types else 'float64')
tensor = tensor.astype(dtype if dtype not in int_types else 'float64')
factor = factor.astype('float64' if dtype in int_types else dtype)
tensor = tensor.astype('float64' if dtype in int_types else dtype)
else:
# For integer types, scaling done in FP64, FP32 math for FP16 on CPU
factor = factor.astype(dtype if dtype not in int_types else
'float32' if dtype == 'float16' else 'float64')
tensor = tensor.astype(dtype if dtype not in int_types else
'float32' if dtype == 'float16' else 'float64')
factor = factor.astype('float32' if dtype == 'float16' else
'float64' if dtype in int_types else dtype)
tensor = tensor.astype('float32' if dtype == 'float16' else
'float64' if dtype in int_types else dtype)

expected = tensor * size
expected *= factor
Expand Down
44 changes: 22 additions & 22 deletions test/test_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def test_horovod_allreduce_cpu_prescale(self):
hvd.init()
size = hvd.size()
dtypes = self.filter_supported_types([tf.int32, tf.int64, tf.float16, tf.float32])
int_types = [tf.int32, tf.int64]
dims = [1, 2, 3]
for dtype, dim in itertools.product(dtypes, dims):
with tf.device("/cpu:0"):
Expand All @@ -267,16 +268,16 @@ def test_horovod_allreduce_cpu_prescale(self):
prescale_factor=factor)

# Scaling done in FP64 math for integer types, FP32 math for FP16 on CPU
tensor = tf.cast(tensor, dtype if dtype not in [tf.int32, tf.int64, tf.float16]
else tf.float32 if dtype == tf.float16 else tf.float64)
factor = tf.convert_to_tensor(factor, dtype if dtype not in [tf.int32, tf.int64, tf.float16]
else tf.float32 if dtype == tf.float16 else tf.float64)
tensor = tf.cast(tensor, tf.float32 if dtype == tf.float16 else
tf.float64 if dtype in int_types else dtype)
factor = tf.convert_to_tensor(factor, tf.float32 if dtype == tf.float16 else
tf.float64 if dtype in int_types else dtype)
multiplied = tf.cast(factor * tensor, dtype) * size
max_difference = tf.reduce_max(tf.abs(summed - multiplied))

# Threshold for floating point equality depends on number of
# ranks, since we're comparing against precise multiplication.
if size <= 3 or dtype in [tf.int32, tf.int64]:
if size <= 3 or dtype in int_types:
threshold = 0
elif size < 10:
threshold = 1e-4
Expand All @@ -295,6 +296,7 @@ def test_horovod_allreduce_cpu_postscale(self):
hvd.init()
size = hvd.size()
dtypes = self.filter_supported_types([tf.int32, tf.int64, tf.float16, tf.float32])
int_types = [tf.int32, tf.int64]
dims = [1, 2, 3]
for dtype, dim in itertools.product(dtypes, dims):
with tf.device("/cpu:0"):
Expand All @@ -307,16 +309,16 @@ def test_horovod_allreduce_cpu_postscale(self):

multiplied = tensor * size
# Scaling done in FP64 math for integer types, FP32 math for FP16 on CPU
multiplied = tf.cast(multiplied, dtype if dtype not in [tf.int32, tf.int64, tf.float16]
else tf.float32 if dtype == tf.float16 else tf.float64)
factor = tf.convert_to_tensor(factor, dtype if dtype not in [tf.int32, tf.int64, tf.float16]
else tf.float32 if dtype == tf.float16 else tf.float64)
multiplied = tf.cast(multiplied, tf.float32 if dtype == tf.float16 else
tf.float64 if dtype in int_types else dtype)
factor = tf.convert_to_tensor(factor, tf.float32 if dtype == tf.float16 else
tf.float64 if dtype in int_types else dtype)
multiplied = tf.cast(factor * multiplied, dtype)
max_difference = tf.reduce_max(tf.abs(summed - multiplied))

# Threshold for floating point equality depends on number of
# ranks, since we're comparing against precise multiplication.
if size <= 3 or dtype in [tf.int32, tf.int64]:
if size <= 3 or dtype in int_types:
threshold = 0
elif size < 10:
threshold = 1e-4
Expand Down Expand Up @@ -507,14 +509,15 @@ def test_horovod_allreduce_gpu_prescale(self):
if not tf.test.is_gpu_available(cuda_only=True):
return

if os.environ.get('HOROVOD_MIXED_INSTALL'):
if int(os.environ.get('HOROVOD_MIXED_INSTALL', 0)):
# Skip if compiled with CUDA but without HOROVOD_GPU_ALLREDUCE.
return

hvd.init()
size = hvd.size()
local_rank = hvd.local_rank()
dtypes = self.filter_supported_types([tf.int32, tf.int64, tf.float16, tf.float32])
int_types = [tf.int32, tf.int64]
dims = [1, 2, 3]
for dtype, dim in itertools.product(dtypes, dims):
with tf.device("/gpu:%s" % local_rank):
Expand All @@ -526,16 +529,14 @@ def test_horovod_allreduce_gpu_prescale(self):
prescale_factor=factor)

# Scaling done in FP64 math for integer types.
tensor = tf.cast(tensor, dtype if dtype not in [tf.int32, tf.int64]
else tf.float64)
factor = tf.convert_to_tensor(factor, dtype if dtype not in [tf.int32, tf.int64]
else tf.float64)
tensor = tf.cast(tensor, tf.float64 if dtype in int_types else dtype)
factor = tf.convert_to_tensor(factor, tf.float64 if dtype in int_types else dtype)
multiplied = tf.cast(factor * tensor, dtype) * size
max_difference = tf.reduce_max(tf.abs(summed - multiplied))

# Threshold for floating point equality depends on number of
# ranks, since we're comparing against precise multiplication.
if size <= 3 or dtype in [tf.int32, tf.int64]:
if size <= 3 or dtype in int_types:
threshold = 0
elif size < 10:
threshold = 1e-4
Expand All @@ -556,14 +557,15 @@ def test_horovod_allreduce_gpu_postscale(self):
if not tf.test.is_gpu_available(cuda_only=True):
return

if os.environ.get('HOROVOD_MIXED_INSTALL'):
if int(os.environ.get('HOROVOD_MIXED_INSTALL', 0)):
# Skip if compiled with CUDA but without HOROVOD_GPU_ALLREDUCE.
return

hvd.init()
size = hvd.size()
local_rank = hvd.local_rank()
dtypes = self.filter_supported_types([tf.int32, tf.int64, tf.float16, tf.float32])
int_types = [tf.int32, tf.int64]
dims = [1, 2, 3]
for dtype, dim in itertools.product(dtypes, dims):
with tf.device("/gpu:%s" % local_rank):
Expand All @@ -576,16 +578,14 @@ def test_horovod_allreduce_gpu_postscale(self):

multiplied = tensor * size
# Scaling done in FP64 math for integer types.
multiplied = tf.cast(multiplied, dtype if dtype not in [tf.int32, tf.int64]
else tf.float64)
factor = tf.convert_to_tensor(factor, dtype if dtype not in [tf.int32, tf.int64]
else tf.float64)
multiplied = tf.cast(multiplied, tf.float64 if dtype in int_types else dtype)
factor = tf.convert_to_tensor(factor, tf.float64 if dtype in int_types else dtype)
multiplied = tf.cast(factor * multiplied, dtype)
max_difference = tf.reduce_max(tf.abs(summed - multiplied))

# Threshold for floating point equality depends on number of
# ranks, since we're comparing against precise multiplication.
if size <= 3 or dtype in [tf.int32, tf.int64]:
if size <= 3 or dtype in int_types:
threshold = 0
elif size < 10:
threshold = 1e-4
Expand Down
28 changes: 14 additions & 14 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,17 +348,17 @@ def test_horovod_allreduce_prescale(self):
prescale_factor=factor)

factor = torch.tensor(factor, dtype=torch.float64)
if dtype.is_cuda:
if dtype.is_cuda and not int(os.environ.get('HOROVOD_MIXED_INSTALL', 0)):
factor = factor.cuda(hvd.local_rank())
# For integer types, scaling done in FP64
factor = factor.type(dtype if dtype not in int_types else torch.float64)
tensor = tensor.type(dtype if dtype not in int_types else torch.float64)
factor = factor.type(torch.float64 if dtype in int_types else dtype)
tensor = tensor.type(torch.float64 if dtype in int_types else dtype)
else:
# For integer types, scaling done in FP64, FP32 math for FP16 on CPU
factor = factor.type(dtype if dtype not in int_types + half_types else
torch.float32 if dtype in half_types else torch.float64)
tensor = tensor.type(dtype if dtype not in int_types + half_types else
torch.float32 if dtype in half_types else torch.float64)
factor = factor.type(torch.float32 if dtype in half_types else
torch.float64 if dtype in int_types else dtype)
tensor = tensor.type(torch.float32 if dtype in half_types else
torch.float64 if dtype in int_types else dtype)
multiplied = factor * tensor
multiplied = multiplied.type(dtype)
summed, multiplied = self.convert_cpu_fp16_to_fp32(summed, multiplied)
Expand Down Expand Up @@ -402,17 +402,17 @@ def test_horovod_allreduce_postscale(self):
postscale_factor=factor)

factor = torch.tensor(factor, dtype=torch.float64)
if dtype.is_cuda:
if dtype.is_cuda and not int(os.environ.get('HOROVOD_MIXED_INSTALL', 0)):
factor = factor.cuda(hvd.local_rank())
# For integer types, scaling done in FP64
factor.type(dtype if dtype not in int_types else torch.float64)
tensor.type(dtype if dtype not in int_types else torch.float64)
factor = factor.type(torch.float64 if dtype in int_types else dtype)
tensor = tensor.type(torch.float64 if dtype in int_types else dtype)
else:
# For integer types, scaling done in FP64, FP32 math for FP16 on CPU
factor = factor.type(dtype if dtype not in int_types + half_types else
torch.float32 if dtype in half_types else torch.float64)
tensor = tensor.type(dtype if dtype not in int_types + half_types else
torch.float32 if dtype in half_types else torch.float64)
factor = factor.type(torch.float32 if dtype in half_types else
torch.float64 if dtype in int_types else dtype)
tensor = tensor.type(torch.float32 if dtype in half_types else
torch.float64 if dtype in int_types else dtype)
multiplied = size * tensor
multiplied = multiplied * factor
multiplied = multiplied.type(dtype)
Expand Down

0 comments on commit d331573

Please sign in to comment.