From b2e6c067b205d2f6e40293dcdbbbc009e2e13cbd Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Fri, 28 Sep 2018 16:57:23 -0700 Subject: [PATCH] FP16 support for GPU tensors in all frameworks (#529) * Initial support for FP16 Bump version to a dev release Cast vars to fp16 before allreduce to compress gradients Abstracted compression algorithm into a class hierarchy and added algorithm flag to optimizer and allreduce signatures Changed compressor to set the dtype on initialization Resolved conflicts Additional conflicts Formatting More formats Updated license Added fp16 compression for Keras Added arguments to keras examples Fixed imports * Added compression to tf.keras * Added PyTorch compression API Added unit tests Whitespace * Added C interfaces and types * Forward declare * Removed Half from older versions of PyTorch * Added error for old version of PyTorch * Removed reference to float16 * Updated examples, added compression to the Keras model load * Cleaned imports * Removed dependency on enums * Updated unit tests * Test compatability fix * Reverted version updates * Fixed message * Removed imports * Added cuda.HalfTensor to all PyTorch tests with CUDA * Only compare versions once * Renamed --fp16 in examples to --fp16-allreduce for clarity * Replaced assignment with set_ * Modified compression algorithms to be stateless with optional context parameters * Removed optional ctx parameter * Replaced 0.4.2 with 1.0.0 * Only run GPU tests with HalfTensors if fp16 is supported --- examples/keras_imagenet_resnet50.py | 10 ++- examples/pytorch_mnist.py | 10 ++- horovod/common/mpi_message.cc | 3 + horovod/common/mpi_message.h | 7 +- horovod/common/operations.cc | 25 ++++++- horovod/common/wire/mpi_message.fbs | 7 +- horovod/common/wire/mpi_message_generated.h | 8 +- horovod/keras/__init__.py | 18 ++++- horovod/keras/impl.py | 17 +++-- horovod/tensorflow/__init__.py | 26 +++++-- horovod/tensorflow/compression.py | 74 ++++++++++++++++++ horovod/tensorflow/keras/__init__.py | 18 ++++- horovod/tensorflow/mpi_ops.cc | 8 +- horovod/torch/__init__.py | 25 +++++-- horovod/torch/adapter.cc | 3 +- horovod/torch/adapter_v2.cc | 2 + horovod/torch/compression.py | 74 ++++++++++++++++++ horovod/torch/mpi_ops.cc | 4 +- horovod/torch/mpi_ops.py | 20 ++++- horovod/torch/mpi_ops_v2.cc | 12 +++ test/test_tensorflow.py | 55 +++++++++++--- test/test_torch.py | 83 ++++++++++++++++++--- 22 files changed, 438 insertions(+), 71 deletions(-) create mode 100644 horovod/tensorflow/compression.py create mode 100644 horovod/torch/compression.py diff --git a/examples/keras_imagenet_resnet50.py b/examples/keras_imagenet_resnet50.py index feceb784f1..ce1919c7d7 100644 --- a/examples/keras_imagenet_resnet50.py +++ b/examples/keras_imagenet_resnet50.py @@ -31,6 +31,8 @@ help='tensorboard log directory') parser.add_argument('--checkpoint-format', default='./checkpoint-{epoch}.h5', help='checkpoint file format') +parser.add_argument('--fp16-allreduce', action='store_true', default=False, + help='use fp16 compression during allreduce') # Default settings from https://arxiv.org/abs/1706.02677. parser.add_argument('--batch-size', type=int, default=32, @@ -91,11 +93,15 @@ # Set up standard ResNet-50 model. model = keras.applications.resnet50.ResNet50(weights=None) +# Horovod: (optional) compression algorithm. +compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none + # Restore from a previous checkpoint, if initial_epoch is specified. # Horovod: restore on the first worker which will broadcast both model and optimizer weights # to other workers. if resume_from_epoch > 0 and hvd.rank() == 0: - model = hvd.load_model(args.checkpoint_format.format(epoch=resume_from_epoch)) + model = hvd.load_model(args.checkpoint_format.format(epoch=resume_from_epoch), + compression=compression) else: # ResNet-50 model that is included with Keras is optimized for inference. # Add L2 weight decay & adjust BN settings. @@ -117,7 +123,7 @@ momentum=args.momentum) # Horovod: add Horovod Distributed Optimizer. - opt = hvd.DistributedOptimizer(opt) + opt = hvd.DistributedOptimizer(opt, compression=compression) model.compile(loss=keras.losses.categorical_crossentropy, optimizer=opt, diff --git a/examples/pytorch_mnist.py b/examples/pytorch_mnist.py index 0831b211e1..9f1342a587 100644 --- a/examples/pytorch_mnist.py +++ b/examples/pytorch_mnist.py @@ -25,6 +25,8 @@ help='random seed (default: 42)') parser.add_argument('--log-interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status') +parser.add_argument('--fp16-allreduce', action='store_true', default=False, + help='use fp16 compression during allreduce') args = parser.parse_args() args.cuda = not args.no_cuda and torch.cuda.is_available() @@ -92,9 +94,13 @@ def forward(self, x): optimizer = optim.SGD(model.parameters(), lr=args.lr * hvd.size(), momentum=args.momentum) +# Horovod: (optional) compression algorithm. +compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none + # Horovod: wrap optimizer with DistributedOptimizer. -optimizer = hvd.DistributedOptimizer( - optimizer, named_parameters=model.named_parameters()) +optimizer = hvd.DistributedOptimizer(optimizer, + named_parameters=model.named_parameters(), + compression=compression) def train(epoch): diff --git a/horovod/common/mpi_message.cc b/horovod/common/mpi_message.cc index e8a1a3e851..5f0359df25 100644 --- a/horovod/common/mpi_message.cc +++ b/horovod/common/mpi_message.cc @@ -41,6 +41,9 @@ const std::string& MPIDataType_Name(MPIDataType value) { case HOROVOD_INT64: static const std::string int64("int64"); return int64; + case HOROVOD_FLOAT16: + static const std::string float16("float16"); + return float16; case HOROVOD_FLOAT32: static const std::string float32("float32"); return float32; diff --git a/horovod/common/mpi_message.h b/horovod/common/mpi_message.h index 2c1cb84ac0..005d12effa 100644 --- a/horovod/common/mpi_message.h +++ b/horovod/common/mpi_message.h @@ -30,9 +30,10 @@ enum MPIDataType { HOROVOD_INT16 = 3, HOROVOD_INT32 = 4, HOROVOD_INT64 = 5, - HOROVOD_FLOAT32 = 6, - HOROVOD_FLOAT64 = 7, - HOROVOD_BOOL = 8 + HOROVOD_FLOAT16 = 6, + HOROVOD_FLOAT32 = 7, + HOROVOD_FLOAT64 = 8, + HOROVOD_BOOL = 9 }; const std::string& MPIDataType_Name(MPIDataType value); diff --git a/horovod/common/operations.cc b/horovod/common/operations.cc index f5474916cd..3ea102f377 100644 --- a/horovod/common/operations.cc +++ b/horovod/common/operations.cc @@ -182,6 +182,9 @@ struct HorovodGlobalState { // COMM_WORLD ranks of processes running on this node. std::vector local_comm_ranks; + // MPI custom data type for float16. + MPI_Datatype mpi_float16_t; + // Private MPI communicator for Horovod to ensure no collisions with other // threads using MPI. MPI_Comm mpi_comm; @@ -520,6 +523,8 @@ MPI_Datatype GetMPIDataType(const std::shared_ptr tensor) { return MPI_INT32_T; case HOROVOD_INT64: return MPI_INT64_T; + case HOROVOD_FLOAT16: + return horovod_global.mpi_float16_t; case HOROVOD_FLOAT32: return MPI_FLOAT; case HOROVOD_FLOAT64: @@ -539,6 +544,8 @@ ncclDataType_t GetNCCLDataType(const std::shared_ptr tensor) { return ncclInt32; case HOROVOD_INT64: return ncclInt64; + case HOROVOD_FLOAT16: + return ncclFloat16; case HOROVOD_FLOAT32: return ncclFloat32; case HOROVOD_FLOAT64: @@ -1010,7 +1017,7 @@ void PerformOperation(TensorTable& tensor_table, MPIResponse response) { #else if (horovod_global.hierarchical_allreduce) { int element_size; - MPI_Type_size(GetMPIDataType(first_entry.tensor), &element_size); + MPI_Type_size(GetMPIDataType(first_entry.tensor), &element_size); // If cluster is homogeneous and we are using fusion buffer, include // dummy elements from the buffer (if necessary) to make sure the data @@ -1110,7 +1117,7 @@ void PerformOperation(TensorTable& tensor_table, MPIResponse response) { WAIT_FOR_EVENTS(entries, timeline, event_queue) // According to https://docs.nvidia.com/cuda/cuda-runtime-api/ - // api-sync-behavior.html#api-sync-behavior__memcpy-async, + // api-sync-behavior.html#api-sync-behavior__memcpy-async, // cudaMemcpyAsync is synchronous with respect to the host, so we // memcpy (effectively) synchronously to generate an accurate timeline ACTIVITY_START_ALL(entries, timeline, MEMCPY_IN_HOST_BUFFER) @@ -1508,6 +1515,11 @@ void BackgroundThreadLoop(HorovodGlobalState& state) { MPI_Comm_rank(cross_comm, &cross_rank); MPI_Comm_size(cross_comm, &cross_size); + // Create custom MPI float16 data type. + MPI_Datatype mpi_float16_t; + MPI_Type_contiguous(2, MPI_BYTE, &mpi_float16_t); + MPI_Type_commit(&mpi_float16_t); + state.rank = rank; state.local_rank = local_rank; state.cross_rank = cross_rank; @@ -1516,6 +1528,7 @@ void BackgroundThreadLoop(HorovodGlobalState& state) { state.cross_size = cross_size; state.local_comm = local_comm; state.cross_comm = cross_comm; + state.mpi_float16_t = mpi_float16_t; state.mpi_threads_supported = (provided == MPI_THREAD_MULTIPLE); state.local_comm_ranks = local_comm_ranks; @@ -1558,8 +1571,8 @@ void BackgroundThreadLoop(HorovodGlobalState& state) { } // Override Tensor Fusion threshold, if it's set. - auto horovod_fusion_threshold = std::getenv("HOROVOD_FUSION_THRESHOLD"); - int64_t proposed_fusion_threshold = (horovod_fusion_threshold != nullptr) ? + auto horovod_fusion_threshold = std::getenv("HOROVOD_FUSION_THRESHOLD"); + int64_t proposed_fusion_threshold = (horovod_fusion_threshold != nullptr) ? std::strtol(horovod_fusion_threshold, nullptr, 10) : state.tensor_fusion_threshold; @@ -1924,6 +1937,10 @@ void horovod_shutdown() { MPI_Comm_free(&horovod_global.cross_comm); } + if (horovod_global.mpi_float16_t != MPI_DATATYPE_NULL) { + MPI_Type_free(&horovod_global.mpi_float16_t); + } + if (horovod_global.should_finalize) { #if HAVE_DDL // ddl_finalize calls MPI_Finalize diff --git a/horovod/common/wire/mpi_message.fbs b/horovod/common/wire/mpi_message.fbs index 3dcac3d6d9..d838f00e60 100644 --- a/horovod/common/wire/mpi_message.fbs +++ b/horovod/common/wire/mpi_message.fbs @@ -24,9 +24,10 @@ enum MPIDataType:byte { HOROVOD_INT16 = 3, HOROVOD_INT32 = 4, HOROVOD_INT64 = 5, - HOROVOD_FLOAT32 = 6, - HOROVOD_FLOAT64 = 7, - HOROVOD_BOOL = 8 + HOROVOD_FLOAT16 = 6, + HOROVOD_FLOAT32 = 7, + HOROVOD_FLOAT64 = 8, + HOROVOD_BOOL = 9 } // An MPIRequest is a message sent from a rank greater than zero to the diff --git a/horovod/common/wire/mpi_message_generated.h b/horovod/common/wire/mpi_message_generated.h index 5d3edc4fb3..3b63205f00 100644 --- a/horovod/common/wire/mpi_message_generated.h +++ b/horovod/common/wire/mpi_message_generated.h @@ -40,9 +40,10 @@ enum MPIDataType { MPIDataType_HOROVOD_INT16 = 3, MPIDataType_HOROVOD_INT32 = 4, MPIDataType_HOROVOD_INT64 = 5, - MPIDataType_HOROVOD_FLOAT32 = 6, - MPIDataType_HOROVOD_FLOAT64 = 7, - MPIDataType_HOROVOD_BOOL = 8, + MPIDataType_HOROVOD_FLOAT16 = 6, + MPIDataType_HOROVOD_FLOAT32 = 7, + MPIDataType_HOROVOD_FLOAT64 = 8, + MPIDataType_HOROVOD_BOOL = 9, MPIDataType_MIN = MPIDataType_HOROVOD_UINT8, MPIDataType_MAX = MPIDataType_HOROVOD_BOOL }; @@ -55,6 +56,7 @@ inline const char **EnumNamesMPIDataType() { "HOROVOD_INT16", "HOROVOD_INT32", "HOROVOD_INT64", + "HOROVOD_FLOAT16", "HOROVOD_FLOAT32", "HOROVOD_FLOAT64", "HOROVOD_BOOL", diff --git a/horovod/keras/__init__.py b/horovod/keras/__init__.py index 5443b70ce8..7c926a774c 100644 --- a/horovod/keras/__init__.py +++ b/horovod/keras/__init__.py @@ -23,12 +23,15 @@ from horovod.tensorflow import rank from horovod.tensorflow import local_rank from horovod.tensorflow import mpi_threads_supported +from horovod.tensorflow import Compression from horovod.keras import callbacks from horovod.keras import impl as _impl -def DistributedOptimizer(optimizer, name=None, device_dense='', device_sparse=''): +def DistributedOptimizer(optimizer, name=None, + device_dense='', device_sparse='', + compression=Compression.none): """ An optimizer that wraps another keras.optimizers.Optimizer, using an allreduce to average gradient values before applying gradients to model weights. @@ -42,8 +45,12 @@ def DistributedOptimizer(optimizer, name=None, device_dense='', device_sparse='' if Horovod was build with HOROVOD_GPU_ALLREDUCE. device_sparse: Device to be used for sparse tensors. Uses GPU by default if Horovod was build with HOROVOD_GPU_ALLGATHER. + compression: Compression algorithm used to reduce the amount of data + sent and received by each worker node. Defaults to not + using compression. """ - return _impl.create_distributed_optimizer(keras, optimizer, name, device_dense, device_sparse) + return _impl.create_distributed_optimizer(keras, optimizer, name, + device_dense, device_sparse, compression) def broadcast_global_variables(root_rank): @@ -99,7 +106,7 @@ def broadcast(value, root_rank, name=None): return _impl.broadcast(K, value, root_rank, name) -def load_model(filepath, custom_optimizers=None, custom_objects=None): +def load_model(filepath, custom_optimizers=None, custom_objects=None, compression=Compression.none): """ Loads a saved Keras model with a Horovod DistributedOptimizer. @@ -119,6 +126,9 @@ def load_model(filepath, custom_optimizers=None, custom_objects=None): during loading. custom_objects: Optional dictionary mapping names (strings) to custom classes or functions to be considered during deserialization. + compression: Compression algorithm used to reduce the amount of data + sent and received by each worker node. Defaults to not + using compression. # Returns A Keras model instance. @@ -128,5 +138,5 @@ def load_model(filepath, custom_optimizers=None, custom_objects=None): ValueError: In case of an invalid savefile. """ def wrap_optimizer(cls): - return lambda **kwargs: DistributedOptimizer(cls(**kwargs)) + return lambda **kwargs: DistributedOptimizer(cls(**kwargs), compression=compression) return _impl.load_model(keras, wrap_optimizer, filepath, custom_optimizers, custom_objects) diff --git a/horovod/keras/impl.py b/horovod/keras/impl.py index ddde89ec11..fd2ca784b1 100644 --- a/horovod/keras/impl.py +++ b/horovod/keras/impl.py @@ -17,7 +17,7 @@ import tensorflow as tf -def create_distributed_optimizer(keras, optimizer, name=None, device_dense='', device_sparse=''): +def create_distributed_optimizer(keras, optimizer, name, device_dense, device_sparse, compression): class _DistributedOptimizer(keras.optimizers.Optimizer): def __init__(self, name, device_dense, device_sparse, **kwargs): if name is None: @@ -25,6 +25,7 @@ def __init__(self, name, device_dense, device_sparse, **kwargs): self._name = name self._device_dense = device_dense self._device_sparse = device_sparse + self._compression = compression super(self.__class__, self).__init__(**kwargs) def get_gradients(self, loss, params): @@ -42,8 +43,10 @@ def get_gradients(self, loss, params): with tf.name_scope(self._name + "_Allreduce"): for grad in gradients: if grad is not None: - avg_grad = hvd.allreduce(grad, device_dense=self._device_dense, - device_sparse=self._device_sparse) + avg_grad = hvd.allreduce(grad, + device_dense=self._device_dense, + device_sparse=self._device_sparse, + compression=self._compression) averaged_gradients.append(avg_grad) else: averaged_gradients.append(None) @@ -65,22 +68,22 @@ def broadcast_global_variables(backend, root_rank): return backend.get_session().run(bcast_op) -def allreduce(backend, value, name=None, average=True): +def allreduce(backend, value, name, average): allreduce_op = hvd.allreduce(tf.constant(value, name=name), average=average) return backend.get_session().run(allreduce_op) -def allgather(backend, value, name=None): +def allgather(backend, value, name): allgather_op = hvd.allgather(tf.constant(value, name=name)) return backend.get_session().run(allgather_op) -def broadcast(backend, value, root_rank, name=None): +def broadcast(backend, value, root_rank, name): bcast_op = hvd.broadcast(tf.constant(value, name=name), root_rank) return backend.get_session().run(bcast_op) -def load_model(keras, wrap_optimizer, filepath, custom_optimizers=None, custom_objects=None): +def load_model(keras, wrap_optimizer, filepath, custom_optimizers, custom_objects): horovod_objects = { subclass.__name__.lower(): wrap_optimizer(subclass) for subclass in keras.optimizers.Optimizer.__subclasses__() diff --git a/horovod/tensorflow/__init__.py b/horovod/tensorflow/__init__.py index 7d8561769c..8715d699a0 100644 --- a/horovod/tensorflow/__init__.py +++ b/horovod/tensorflow/__init__.py @@ -33,6 +33,7 @@ check_extension('horovod.tensorflow', 'HOROVOD_WITH_TENSORFLOW', __file__, 'mpi_lib') +from horovod.tensorflow.compression import Compression from horovod.tensorflow.mpi_ops import allgather, broadcast, _allreduce from horovod.tensorflow.mpi_ops import init, shutdown from horovod.tensorflow.mpi_ops import size, local_size, rank, local_rank @@ -41,7 +42,8 @@ import tensorflow as tf -def allreduce(tensor, average=True, device_dense='', device_sparse=''): +def allreduce(tensor, average=True, device_dense='', device_sparse='', + compression=Compression.none): """Perform an allreduce on a tf.Tensor or tf.IndexedSlices. Arguments: @@ -53,6 +55,9 @@ def allreduce(tensor, average=True, device_dense='', device_sparse=''): if Horovod was build with HOROVOD_GPU_ALLREDUCE. device_sparse: Device to be used for sparse tensors. Uses GPU by default if Horovod was build with HOROVOD_GPU_ALLGATHER. + compression: Compression algorithm used to reduce the amount of data + sent and received by each worker node. Defaults to not + using compression. This function performs a bandwidth-optimal ring allreduce on the input tensor. If the input is an tf.IndexedSlices, the function instead does an @@ -73,8 +78,10 @@ def allreduce(tensor, average=True, device_dense='', device_sparse=''): dense_shape=tensor.dense_shape) else: with tf.device(device_dense): - horovod_size = tf.cast(size(), tensor.dtype) - summed_tensor = _allreduce(tensor) + horovod_size = tf.cast(size(), dtype=tensor.dtype) + tensor_compressed, ctx = compression.compress(tensor) + summed_tensor_compressed = _allreduce(tensor_compressed) + summed_tensor = compression.decompress(summed_tensor_compressed, ctx) new_tensor = (tf.div(summed_tensor, horovod_size) if average else summed_tensor) return new_tensor @@ -130,7 +137,7 @@ class DistributedOptimizer(tf.train.Optimizer): average gradient values before applying gradients to model weights.""" def __init__(self, optimizer, name=None, use_locking=False, device_dense='', - device_sparse=''): + device_sparse='', compression=Compression.none): """Construct a new DistributedOptimizer, which uses another optimizer under the hood for computing single-process gradient values and applying gradient updates after the gradient values have been averaged @@ -152,6 +159,10 @@ def __init__(self, optimizer, name=None, use_locking=False, device_dense='', device_sparse: Device to be used for sparse tensors. Uses GPU by default if Horovod was build with HOROVOD_GPU_ALLGATHER. + compression: + Compression algorithm used during allreduce to reduce the amount + of data sent during the each parameter update step. Defaults to + not using compression. """ if name is None: name = "Distributed{}".format(type(optimizer).__name__) @@ -159,6 +170,7 @@ def __init__(self, optimizer, name=None, use_locking=False, device_dense='', self._optimizer = optimizer self._device_dense = device_dense self._device_sparse = device_sparse + self._compression = compression super(DistributedOptimizer, self).__init__( name=name, use_locking=use_locking) @@ -176,8 +188,10 @@ def compute_gradients(self, *args, **kwargs): with tf.name_scope(self._name + "_Allreduce"): for grad, var in gradients: if grad is not None: - avg_grad = allreduce(grad, device_dense=self._device_dense, - device_sparse=self._device_sparse) + avg_grad = allreduce(grad, + device_dense=self._device_dense, + device_sparse=self._device_sparse, + compression=self._compression) averaged_gradients.append((avg_grad, var)) else: averaged_gradients.append((None, var)) diff --git a/horovod/tensorflow/compression.py b/horovod/tensorflow/compression.py new file mode 100644 index 0000000000..931c176818 --- /dev/null +++ b/horovod/tensorflow/compression.py @@ -0,0 +1,74 @@ +# Copyright 2018 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gradient compression algorithms.""" + +import tensorflow as tf + + +class Compressor(object): + """Interface for compressing and decompressing a given tensor.""" + @staticmethod + def compress(tensor): + """Compresses a tensor and returns it with the context needed to decompress it.""" + pass + + @staticmethod + def decompress(tensor, ctx): + """Decompress the tensor with the given context.""" + pass + + +class NoneCompressor(Compressor): + """Default no-op compression.""" + @staticmethod + def compress(tensor): + """Returns the tensor unmodified.""" + return tensor, None + + @staticmethod + def decompress(tensor, ctx): + """Returns the tensor unmodified.""" + return tensor + + +class FP16Compressor(Compressor): + """Compress all floating point gradients to 16-bit.""" + @staticmethod + def compress(tensor): + """Downcasts the tensor to 16-bit.""" + tensor_compressed = tensor + if tensor.dtype.is_floating: + # Only allow compression from other floating point types + tensor_compressed = tf.cast(tensor, dtype=tf.float16) + return tensor_compressed, tensor.dtype + + @staticmethod + def decompress(tensor, ctx): + """Upcasts the tensor to the initialization dtype.""" + tensor_decompressed = tensor + dtype = ctx + if dtype.is_floating: + tensor_decompressed = tf.cast(tensor, dtype=dtype) + return tensor_decompressed + + +class Compression(object): + """Optional gradient compression algorithm used during allreduce.""" + + """Do not compress the gradients. This is the default.""" + none = NoneCompressor + + """Compress all floating point gradients to 16-bit.""" + fp16 = FP16Compressor diff --git a/horovod/tensorflow/keras/__init__.py b/horovod/tensorflow/keras/__init__.py index b5855c3873..dfeb2fd902 100644 --- a/horovod/tensorflow/keras/__init__.py +++ b/horovod/tensorflow/keras/__init__.py @@ -30,12 +30,15 @@ from horovod.tensorflow import rank from horovod.tensorflow import local_rank from horovod.tensorflow import mpi_threads_supported +from horovod.tensorflow import Compression from horovod.keras import _impl from horovod.tensorflow.keras import callbacks -def DistributedOptimizer(optimizer, name=None, device_dense='', device_sparse=''): +def DistributedOptimizer(optimizer, name=None, + device_dense='', device_sparse='', + compression=Compression.none): """ An optimizer that wraps another keras.optimizers.Optimizer, using an allreduce to average gradient values before applying gradients to model weights. @@ -49,8 +52,12 @@ def DistributedOptimizer(optimizer, name=None, device_dense='', device_sparse='' if Horovod was build with HOROVOD_GPU_ALLREDUCE. device_sparse: Device to be used for sparse tensors. Uses GPU by default if Horovod was build with HOROVOD_GPU_ALLGATHER. + compression: Compression algorithm used to reduce the amount of data + sent and received by each worker node. Defaults to not + using compression. """ - return _impl.create_distributed_optimizer(keras, optimizer, name, device_dense, device_sparse) + return _impl.create_distributed_optimizer(keras, optimizer, name, + device_dense, device_sparse, compression) def broadcast_global_variables(root_rank): @@ -106,7 +113,7 @@ def broadcast(value, root_rank, name=None): return _impl.broadcast(K, value, root_rank, name) -def load_model(filepath, custom_optimizers=None, custom_objects=None): +def load_model(filepath, custom_optimizers=None, custom_objects=None, compression=Compression.none): """ Loads a saved Keras model with a Horovod DistributedOptimizer. @@ -126,6 +133,9 @@ def load_model(filepath, custom_optimizers=None, custom_objects=None): during loading. custom_objects: Optional dictionary mapping names (strings) to custom classes or functions to be considered during deserialization. + compression: Compression algorithm used to reduce the amount of data + sent and received by each worker node. Defaults to not + using compression. # Returns A Keras model instance. @@ -135,6 +145,6 @@ def load_model(filepath, custom_optimizers=None, custom_objects=None): ValueError: In case of an invalid savefile. """ def wrap_optimizer(cls): - return lambda **kwargs: DistributedOptimizer(cls(**kwargs)) + return lambda **kwargs: DistributedOptimizer(cls(**kwargs), compression=compression) return _impl.load_model(keras, wrap_optimizer, filepath, custom_optimizers, custom_objects) diff --git a/horovod/tensorflow/mpi_ops.cc b/horovod/tensorflow/mpi_ops.cc index dfc9a33a86..18cb46c811 100644 --- a/horovod/tensorflow/mpi_ops.cc +++ b/horovod/tensorflow/mpi_ops.cc @@ -180,6 +180,8 @@ const common::MPIDataType TFTensor::dtype() const { return common::HOROVOD_INT32; case DT_INT64: return common::HOROVOD_INT64; + case DT_HALF: + return common::HOROVOD_FLOAT16; case DT_FLOAT: return common::HOROVOD_FLOAT32; case DT_DOUBLE: @@ -305,7 +307,7 @@ REGISTER_KERNEL_BUILDER(Name("HorovodAllreduce").Device(DEVICE_GPU), #endif REGISTER_OP("HorovodAllreduce") - .Attr("T: {int32, int64, float32, float64}") + .Attr("T: {int32, int64, float16, float32, float64}") .Input("tensor: T") .Output("sum: T") .SetShapeFn([](shape_inference::InferenceContext* c) { @@ -362,7 +364,7 @@ REGISTER_KERNEL_BUILDER(Name("HorovodAllgather").Device(DEVICE_GPU), REGISTER_OP("HorovodAllgather") .Attr( - "T: {uint8, int8, uint16, int16, int32, int64, float32, float64, bool}") + "T: {uint8, int8, uint16, int16, int32, int64, float16, float32, float64, bool}") .Input("tensor: T") .Output("output: T") .SetShapeFn([](shape_inference::InferenceContext* c) { @@ -435,7 +437,7 @@ REGISTER_KERNEL_BUILDER(Name("HorovodBroadcast").Device(DEVICE_GPU), REGISTER_OP("HorovodBroadcast") .Attr( - "T: {uint8, int8, uint16, int16, int32, int64, float32, float64, bool}") + "T: {uint8, int8, uint16, int16, int32, int64, float16, float32, float64, bool}") .Attr("root_rank: int") .Input("tensor: T") .Output("output: T") diff --git a/horovod/torch/__init__.py b/horovod/torch/__init__.py index ef12a6f492..3ca9a324f6 100644 --- a/horovod/torch/__init__.py +++ b/horovod/torch/__init__.py @@ -26,6 +26,7 @@ check_extension('horovod.torch', 'HOROVOD_WITH_PYTORCH', __file__, 'mpi_lib', '_mpi_lib') +from horovod.torch.compression import Compression from horovod.torch.mpi_ops import allreduce, allreduce_async, allreduce_, allreduce_async_ from horovod.torch.mpi_ops import allgather, allgather_async from horovod.torch.mpi_ops import broadcast, broadcast_async, broadcast_, broadcast_async_ @@ -39,8 +40,9 @@ class _DistributedOptimizer(torch.optim.Optimizer): - def __init__(self, params, named_parameters=None): + def __init__(self, params, named_parameters, compression): super(self.__class__, self).__init__(params) + self._compression = compression if named_parameters is not None: named_parameters = list(named_parameters) @@ -75,13 +77,19 @@ def hook(*ignore): assert p not in self._handles assert not p.grad.requires_grad name = self._parameter_names.get(p) - handle = allreduce_async_(p.grad.data, average=True, name=name) - self._handles[p] = handle + + tensor = p.grad.data + tensor_compressed, ctx = self._compression.compress(tensor) + + handle = allreduce_async_(tensor_compressed, average=True, name=name) + self._handles[p] = (handle, ctx) return hook def synchronize(self): - for handle in self._handles.values(): - synchronize(handle) + for p, value in self._handles.items(): + handle, ctx = value + output = synchronize(handle) + p.grad.data.set_(self._compression.decompress(output, ctx)) self._handles.clear() def step(self, closure=None): @@ -89,7 +97,7 @@ def step(self, closure=None): return super(self.__class__, self).step(closure) -def DistributedOptimizer(optimizer, named_parameters=None): +def DistributedOptimizer(optimizer, named_parameters=None, compression=Compression.none): """ An optimizer that wraps another torch.optim.Optimizer, using an allreduce to average gradient values before applying gradients to model weights. @@ -116,12 +124,15 @@ def DistributedOptimizer(optimizer, named_parameters=None): optimizer: Optimizer to use for computing gradients and applying updates. named_parameters: A mapping between parameter names and values. Used for naming of allreduce operations. Typically just `model.named_parameters()`. + compression: Compression algorithm used during allreduce to reduce the amount + of data sent during the each parameter update step. Defaults to + not using compression. """ # We dynamically create a new class that inherits from the optimizer that was passed in. # The goal is to override the `step()` method with an allreduce implementation. cls = type(optimizer.__class__.__name__, (optimizer.__class__,), dict(_DistributedOptimizer.__dict__)) - return cls(optimizer.param_groups, named_parameters) + return cls(optimizer.param_groups, named_parameters, compression) def broadcast_parameters(params, root_rank): diff --git a/horovod/torch/adapter.cc b/horovod/torch/adapter.cc index 39647c2350..7eb16d072d 100644 --- a/horovod/torch/adapter.cc +++ b/horovod/torch/adapter.cc @@ -170,7 +170,8 @@ ADAPTER_DEFINE_TYPE(MPIDataType::HOROVOD_INT32, DeviceType::GPU, THCudaIntTensor) ADAPTER_DEFINE_TYPE(MPIDataType::HOROVOD_INT64, DeviceType::GPU, THCudaLongTensor) -ADAPTER_DEFINE_TYPE(MPIDataType::HOROVOD_FLOAT32, DeviceType::GPU, THCudaTensor) +ADAPTER_DEFINE_TYPE(MPIDataType::HOROVOD_FLOAT32, DeviceType::GPU, + THCudaTensor) ADAPTER_DEFINE_TYPE(MPIDataType::HOROVOD_FLOAT64, DeviceType::GPU, THCudaDoubleTensor) #endif diff --git a/horovod/torch/adapter_v2.cc b/horovod/torch/adapter_v2.cc index 4b2b9cf667..ac5ea24d7c 100644 --- a/horovod/torch/adapter_v2.cc +++ b/horovod/torch/adapter_v2.cc @@ -48,6 +48,8 @@ const MPIDataType TorchTensor::dtype() const { return common::HOROVOD_INT32; case at::ScalarType::Long: return common::HOROVOD_INT64; + case at::ScalarType::Half: + return common::HOROVOD_FLOAT16; case at::ScalarType::Float: return common::HOROVOD_FLOAT32; case at::ScalarType::Double: diff --git a/horovod/torch/compression.py b/horovod/torch/compression.py new file mode 100644 index 0000000000..75ce91e071 --- /dev/null +++ b/horovod/torch/compression.py @@ -0,0 +1,74 @@ +# Copyright 2018 Uber Technologies, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gradient compression algorithms.""" + +import torch + + +class Compressor(object): + """Interface for compressing and decompressing a given tensor.""" + @staticmethod + def compress(tensor): + """Compresses a tensor and returns it with the context needed to decompress it.""" + pass + + @staticmethod + def decompress(tensor, ctx): + """Decompress the tensor with the given context.""" + pass + + +class NoneCompressor(Compressor): + """Default no-op compression.""" + @staticmethod + def compress(tensor): + """Returns the tensor unmodified.""" + return tensor, None + + @staticmethod + def decompress(tensor, ctx): + """Returns the tensor unmodified.""" + return tensor + + +class FP16Compressor(Compressor): + """Compress all floating point gradients to 16-bit.""" + @staticmethod + def compress(tensor): + """Downcasts the tensor to 16-bit.""" + tensor_compressed = tensor + if tensor.dtype.is_floating_point: + # Only allow compression from other floating point types + tensor_compressed = tensor.type(torch.float16) + return tensor_compressed, tensor.dtype + + @staticmethod + def decompress(tensor, ctx): + """Upcasts the tensor to the initialization dtype.""" + tensor_decompressed = tensor + dtype = ctx + if dtype.is_floating_point: + tensor_decompressed = tensor.type(dtype) + return tensor_decompressed + + +class Compression(object): + """Optional gradient compression algorithm used during allreduce.""" + + """Do not compress the gradients. This is the default.""" + none = NoneCompressor + + """Compress all floating point gradients to 16-bit.""" + fp16 = FP16Compressor diff --git a/horovod/torch/mpi_ops.cc b/horovod/torch/mpi_ops.cc index 28b0656c2e..451f599e3d 100644 --- a/horovod/torch/mpi_ops.cc +++ b/horovod/torch/mpi_ops.cc @@ -358,8 +358,8 @@ BROADCAST(torch_cuda_LongTensor, MPIDataType::HOROVOD_INT64, DeviceType::GPU, THCudaLongTensor) BROADCAST(torch_cuda_FloatTensor, MPIDataType::HOROVOD_FLOAT32, DeviceType::GPU, THCudaTensor) -BROADCAST(torch_cuda_DoubleTensor, MPIDataType::HOROVOD_FLOAT64, - DeviceType::GPU, THCudaDoubleTensor) +BROADCAST(torch_cuda_DoubleTensor, MPIDataType::HOROVOD_FLOAT64, DeviceType::GPU, + THCudaDoubleTensor) #endif #define BROADCAST_CUDA_ON_CPU(torch_Tensor, HorovodType, THCTensor, THTensor) \ diff --git a/horovod/torch/mpi_ops.py b/horovod/torch/mpi_ops.py index 7a09e12328..9bf6edebfc 100644 --- a/horovod/torch/mpi_ops.py +++ b/horovod/torch/mpi_ops.py @@ -17,9 +17,12 @@ from __future__ import division from __future__ import print_function +from distutils.version import LooseVersion + # Load all the necessary PyTorch C types. import torch +from horovod.torch.compression import Compression try: from horovod.torch import mpi_lib_v2 as mpi_lib from horovod.common import HorovodBasics as _HorovodBasics @@ -47,6 +50,9 @@ # before the operation is finished. _handle_map = {} +# Only support fp16 allreduce for PyTorch versions using v2 API. +_fp16_supported = LooseVersion(torch.__version__) >= LooseVersion('1.0.0') + def _check_function(function_factory, tensor): function = function_factory(tensor) @@ -62,6 +68,11 @@ def _allreduce_function_factory(tensor): def _allreduce_async(tensor, output, average, name): + if tensor.dtype == torch.float16 and not _fp16_supported: + raise NotImplementedError( + 'float16 allreduce is not supported for PyTorch version {} < 1.0.0' + .format(torch.__version__)) + function = _check_function(_allreduce_function_factory, tensor) handle = getattr(mpi_lib, function)(tensor, output, average, name.encode() if name is not None else _NULL) @@ -107,7 +118,7 @@ def backward(ctx, grad_output): return allreduce(grad_output, ctx.average), None, None -def allreduce(tensor, average=True, name=None): +def allreduce(tensor, average=True, name=None, compression=Compression.none): """ A function that performs averaging or summation of the input tensor over all the Horovod processes. The input tensor is not modified. @@ -126,12 +137,17 @@ def allreduce(tensor, average=True, name=None): average: A flag indicating whether to compute average or summation, defaults to average. name: A name of the reduction operation. + compression: Compression algorithm used during allreduce to reduce the amount + of data sent during the each parameter update step. Defaults to + not using compression. Returns: A tensor of the same shape and type as `tensor`, averaged or summed across all processes. """ - return HorovodAllreduce.apply(tensor, average, name) + tensor_compressed, ctx = compression.compress(tensor) + summed_tensor_compressed = HorovodAllreduce.apply(tensor_compressed, average, name) + return compression.decompress(summed_tensor_compressed, ctx) def allreduce_async_(tensor, average=True, name=None): diff --git a/horovod/torch/mpi_ops_v2.cc b/horovod/torch/mpi_ops_v2.cc index 0efce26937..45ccd97fe6 100644 --- a/horovod/torch/mpi_ops_v2.cc +++ b/horovod/torch/mpi_ops_v2.cc @@ -237,11 +237,13 @@ PYBIND11_MODULE(mpi_lib_v2, m) { // allreduce m.def("horovod_torch_allreduce_async_torch_IntTensor", &DoAllreduce); m.def("horovod_torch_allreduce_async_torch_LongTensor", &DoAllreduce); + m.def("horovod_torch_allreduce_async_torch_HalfTensor", &DoAllreduce); m.def("horovod_torch_allreduce_async_torch_FloatTensor", &DoAllreduce); m.def("horovod_torch_allreduce_async_torch_DoubleTensor", &DoAllreduce); #if HOROVOD_GPU_ALLREDUCE m.def("horovod_torch_allreduce_async_torch_cuda_IntTensor", &DoAllreduce); m.def("horovod_torch_allreduce_async_torch_cuda_LongTensor", &DoAllreduce); + m.def("horovod_torch_allreduce_async_torch_cuda_HalfTensor", &DoAllreduce); m.def("horovod_torch_allreduce_async_torch_cuda_FloatTensor", &DoAllreduce); m.def("horovod_torch_allreduce_async_torch_cuda_DoubleTensor", &DoAllreduce); #else @@ -249,6 +251,8 @@ PYBIND11_MODULE(mpi_lib_v2, m) { &DoAllreduceCudaOnCPU); m.def("horovod_torch_allreduce_async_torch_cuda_LongTensor", &DoAllreduceCudaOnCPU); + m.def("horovod_torch_allreduce_async_torch_cuda_HalfTensor", + &DoAllreduceCudaOnCPU); m.def("horovod_torch_allreduce_async_torch_cuda_FloatTensor", &DoAllreduceCudaOnCPU); m.def("horovod_torch_allreduce_async_torch_cuda_DoubleTensor", @@ -261,6 +265,7 @@ PYBIND11_MODULE(mpi_lib_v2, m) { m.def("horovod_torch_allgather_async_torch_ShortTensor", &DoAllgather); m.def("horovod_torch_allgather_async_torch_IntTensor", &DoAllgather); m.def("horovod_torch_allgather_async_torch_LongTensor", &DoAllgather); + m.def("horovod_torch_allgather_async_torch_HalfTensor", &DoAllgather); m.def("horovod_torch_allgather_async_torch_FloatTensor", &DoAllgather); m.def("horovod_torch_allgather_async_torch_DoubleTensor", &DoAllgather); #if HOROVOD_GPU_ALLGATHER @@ -269,6 +274,7 @@ PYBIND11_MODULE(mpi_lib_v2, m) { m.def("horovod_torch_allgather_async_torch_cuda_ShortTensor", &DoAllgather); m.def("horovod_torch_allgather_async_torch_cuda_IntTensor", &DoAllgather); m.def("horovod_torch_allgather_async_torch_cuda_LongTensor", &DoAllgather); + m.def("horovod_torch_allgather_async_torch_cuda_HalfTensor", &DoAllgather); m.def("horovod_torch_allgather_async_torch_cuda_FloatTensor", &DoAllgather); m.def("horovod_torch_allgather_async_torch_cuda_DoubleTensor", &DoAllgather); #else @@ -282,6 +288,8 @@ PYBIND11_MODULE(mpi_lib_v2, m) { &DoAllgatherCudaOnCPU); m.def("horovod_torch_allgather_async_torch_cuda_LongTensor", &DoAllgatherCudaOnCPU); + m.def("horovod_torch_allgather_async_torch_cuda_HalfTensor", + &DoAllgatherCudaOnCPU); m.def("horovod_torch_allgather_async_torch_cuda_FloatTensor", &DoAllgatherCudaOnCPU); m.def("horovod_torch_allgather_async_torch_cuda_DoubleTensor", @@ -294,6 +302,7 @@ PYBIND11_MODULE(mpi_lib_v2, m) { m.def("horovod_torch_broadcast_async_torch_ShortTensor", &DoBroadcast); m.def("horovod_torch_broadcast_async_torch_IntTensor", &DoBroadcast); m.def("horovod_torch_broadcast_async_torch_LongTensor", &DoBroadcast); + m.def("horovod_torch_broadcast_async_torch_HalfTensor", &DoBroadcast); m.def("horovod_torch_broadcast_async_torch_FloatTensor", &DoBroadcast); m.def("horovod_torch_broadcast_async_torch_DoubleTensor", &DoBroadcast); #if HOROVOD_GPU_BROADCAST @@ -302,6 +311,7 @@ PYBIND11_MODULE(mpi_lib_v2, m) { m.def("horovod_torch_broadcast_async_torch_cuda_ShortTensor", &DoBroadcast); m.def("horovod_torch_broadcast_async_torch_cuda_IntTensor", &DoBroadcast); m.def("horovod_torch_broadcast_async_torch_cuda_LongTensor", &DoBroadcast); + m.def("horovod_torch_broadcast_async_torch_cuda_HalfTensor", &DoBroadcast); m.def("horovod_torch_broadcast_async_torch_cuda_FloatTensor", &DoBroadcast); m.def("horovod_torch_broadcast_async_torch_cuda_DoubleTensor", &DoBroadcast); #else @@ -315,6 +325,8 @@ PYBIND11_MODULE(mpi_lib_v2, m) { &DoBroadcastCudaOnCPU); m.def("horovod_torch_broadcast_async_torch_cuda_LongTensor", &DoBroadcastCudaOnCPU); + m.def("horovod_torch_broadcast_async_torch_cuda_HalfTensor", + &DoBroadcastCudaOnCPU); m.def("horovod_torch_broadcast_async_torch_cuda_FloatTensor", &DoBroadcastCudaOnCPU); m.def("horovod_torch_broadcast_async_torch_cuda_DoubleTensor", diff --git a/test/test_tensorflow.py b/test/test_tensorflow.py index 80251cd886..fc17fb39ad 100644 --- a/test/test_tensorflow.py +++ b/test/test_tensorflow.py @@ -133,7 +133,7 @@ def test_horovod_allreduce_gpu(self): size = hvd.size() with self.test_session(config=self.config) as session: - dtypes = [tf.int32, tf.int64, tf.float32, tf.float64] + dtypes = [tf.int32, tf.int64, tf.float16, tf.float32, tf.float64] dims = [1, 2, 3] for dtype, dim in itertools.product(dtypes, dims): with tf.device("/gpu:%d" % local_rank): @@ -174,7 +174,7 @@ def test_horovod_allreduce_gpu_fused(self): size = hvd.size() with self.test_session(config=self.config) as session: - dtypes = [tf.int32, tf.int64, tf.float32, tf.float64] + dtypes = [tf.int32, tf.int64, tf.float16, tf.float32, tf.float64] dims = [1, 2, 3] tests = [] for dtype, dim in itertools.product(dtypes, dims): @@ -219,7 +219,7 @@ def test_horovod_allreduce_multi_gpu(self): iter = 0 gpu_ids = [local_rank * 2, local_rank * 2 + 1] with self.test_session(config=self.config) as session: - dtypes = [tf.int32, tf.int64, tf.float32, tf.float64] + dtypes = [tf.int32, tf.int64, tf.float16, tf.float32, tf.float64] dims = [1, 2, 3] for dtype, dim in itertools.product(dtypes, dims): iter += 1 @@ -353,8 +353,8 @@ def test_horovod_allgather(self): with self.test_session(config=self.config) as session: dtypes = [tf.uint8, tf.int8, tf.uint16, tf.int16, - tf.int32, tf.int64, tf.float32, tf.float64, - tf.bool] + tf.int32, tf.int64, tf.float16, tf.float32, + tf.float64, tf.bool] dims = [1, 2, 3] for dtype, dim in itertools.product(dtypes, dims): tensor = tf.ones([17] * dim) * rank @@ -392,8 +392,8 @@ def test_horovod_allgather_variable_size(self): with self.test_session(config=self.config) as session: dtypes = [tf.uint8, tf.int8, tf.uint16, tf.int16, - tf.int32, tf.int64, tf.float32, tf.float64, - tf.bool] + tf.int32, tf.int64, tf.float16, tf.float32, + tf.float64, tf.bool] dims = [1, 2, 3] for dtype, dim in itertools.product(dtypes, dims): # Support tests up to MPI Size of 35 @@ -518,8 +518,8 @@ def test_horovod_broadcast(self): with self.test_session(config=self.config) as session: dtypes = [tf.uint8, tf.int8, tf.uint16, tf.int16, - tf.int32, tf.int64, tf.float32, tf.float64, - tf.bool] + tf.int32, tf.int64, tf.float16, tf.float32, + tf.float64, tf.bool] dims = [1, 2, 3] root_ranks = list(range(size)) for dtype, dim, root_rank in itertools.product(dtypes, dims, root_ranks): @@ -623,6 +623,43 @@ def test_horovod_broadcast_grad(self): "gradient %s differs from expected %s, " "error: %s" % (grad_out, expected, str(err))) + def test_compression_fp16(self): + valid_dtypes = [tf.float16, tf.float32, tf.float64] + invalid_dtypes = [tf.uint8, tf.int8, tf.uint16, tf.int16, + tf.int32, tf.int64, tf.bool] + + tensor_size = [17] * 3 + compression = hvd.Compression.fp16 + + with self.test_session(config=self.config) as session: + for dtype in valid_dtypes: + tensor = tf.ones(tensor_size, dtype=dtype) + + tensor_compressed, ctx = compression.compress(tensor) + self.assertEqual(tensor_compressed.dtype, tf.float16) + + tensor_decompressed = compression.decompress(tensor_compressed, ctx) + self.assertEqual(tensor_decompressed.dtype, dtype) + + actual = session.run(tensor_decompressed) + expected = np.ones(tensor_size) + err = np.linalg.norm(expected - actual) + self.assertLess(err, 0.00000001) + + for dtype in invalid_dtypes: + tensor = tf.ones(tensor_size, dtype=dtype) + + tensor_compressed, ctx = compression.compress(tensor) + self.assertEqual(tensor_compressed.dtype, dtype) + + tensor_decompressed = compression.decompress(tensor_compressed, ctx) + self.assertEqual(tensor_decompressed.dtype, dtype) + + actual = session.run(tensor_decompressed) + expected = np.ones(tensor_size) + err = np.linalg.norm(expected - actual) + self.assertLess(err, 0.00000001) + if __name__ == '__main__': tf.test.main() diff --git a/test/test_torch.py b/test/test_torch.py index dcfdd7f1ea..d2228389a7 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -17,19 +17,22 @@ from __future__ import division from __future__ import print_function +from distutils.version import LooseVersion import inspect import itertools +import numpy as np import os import tempfile import torch import torch.nn.functional as F import unittest -import numpy as np import horovod.torch as hvd from common import mpi_env_rank_and_size +_fp16_supported = LooseVersion(torch.__version__) >= LooseVersion('1.0.0') + class TorchTests(unittest.TestCase): """ @@ -59,6 +62,8 @@ def test_horovod_allreduce(self): if torch.cuda.is_available(): dtypes += [torch.cuda.IntTensor, torch.cuda.LongTensor, torch.cuda.FloatTensor, torch.cuda.DoubleTensor] + if _fp16_supported: + dtypes += [torch.cuda.HalfTensor] dims = [1, 2, 3] for dtype, dim in itertools.product(dtypes, dims): torch.manual_seed(1234) @@ -91,6 +96,8 @@ def test_horovod_allreduce_average(self): if torch.cuda.is_available(): dtypes += [torch.cuda.IntTensor, torch.cuda.LongTensor, torch.cuda.FloatTensor, torch.cuda.DoubleTensor] + if _fp16_supported: + dtypes += [torch.cuda.HalfTensor] dims = [1, 2, 3] for dtype, dim in itertools.product(dtypes, dims): torch.manual_seed(1234) @@ -122,6 +129,8 @@ def test_horovod_allreduce_inplace(self): if torch.cuda.is_available(): dtypes += [torch.cuda.IntTensor, torch.cuda.LongTensor, torch.cuda.FloatTensor, torch.cuda.DoubleTensor] + if _fp16_supported: + dtypes += [torch.cuda.HalfTensor] dims = [1, 2, 3] for dtype, dim in itertools.product(dtypes, dims): torch.manual_seed(1234) @@ -155,6 +164,8 @@ def test_horovod_allreduce_async_fused(self): if torch.cuda.is_available(): dtypes += [torch.cuda.IntTensor, torch.cuda.LongTensor, torch.cuda.FloatTensor, torch.cuda.DoubleTensor] + if _fp16_supported: + dtypes += [torch.cuda.HalfTensor] dims = [1, 2, 3] tests = [] is_hvd_poll_false_once = False @@ -202,6 +213,9 @@ def test_horovod_allreduce_multi_gpu(self): iter = 0 dtypes = [torch.cuda.IntTensor, torch.cuda.LongTensor, torch.cuda.FloatTensor, torch.cuda.DoubleTensor] + if _fp16_supported: + dtypes += [torch.cuda.HalfTensor] + dims = [1, 2, 3] for dtype, dim in itertools.product(dtypes, dims): iter += 1 @@ -320,6 +334,8 @@ def test_horovod_allreduce_grad(self): dtypes = [torch.FloatTensor, torch.DoubleTensor] if torch.cuda.is_available(): dtypes += [torch.cuda.FloatTensor, torch.cuda.DoubleTensor] + if _fp16_supported: + dtypes += [torch.cuda.HalfTensor] dims = [1, 2, 3] for dtype, dim in itertools.product(dtypes, dims): torch.manual_seed(1234) @@ -344,6 +360,8 @@ def test_horovod_allreduce_grad_average(self): dtypes = [torch.FloatTensor, torch.DoubleTensor] if torch.cuda.is_available(): dtypes += [torch.cuda.FloatTensor, torch.cuda.DoubleTensor] + if _fp16_supported: + dtypes += [torch.cuda.HalfTensor] dims = [1, 2, 3] for dtype, dim in itertools.product(dtypes, dims): torch.manual_seed(1234) @@ -371,8 +389,10 @@ def test_horovod_allgather(self): torch.IntTensor, torch.LongTensor, torch.FloatTensor, torch.DoubleTensor] if torch.cuda.is_available(): dtypes += [torch.cuda.ByteTensor, torch.cuda.CharTensor, torch.cuda.ShortTensor, - torch.cuda.IntTensor, torch.cuda.LongTensor, torch.cuda.FloatTensor, - torch.cuda.DoubleTensor] + torch.cuda.IntTensor, torch.cuda.LongTensor, + torch.cuda.FloatTensor, torch.cuda.DoubleTensor] + if _fp16_supported: + dtypes += [torch.cuda.HalfTensor] dims = [1, 2, 3] for dtype, dim in itertools.product(dtypes, dims): tensor = torch.FloatTensor(*([17] * dim)).fill_(1).mul_(rank) @@ -399,8 +419,10 @@ def test_horovod_allgather_variable_size(self): torch.IntTensor, torch.LongTensor, torch.FloatTensor, torch.DoubleTensor] if torch.cuda.is_available(): dtypes += [torch.cuda.ByteTensor, torch.cuda.CharTensor, torch.cuda.ShortTensor, - torch.cuda.IntTensor, torch.cuda.LongTensor, torch.cuda.FloatTensor, - torch.cuda.DoubleTensor] + torch.cuda.IntTensor, torch.cuda.LongTensor, + torch.cuda.FloatTensor, torch.cuda.DoubleTensor] + if _fp16_supported: + dtypes += [torch.cuda.HalfTensor] dims = [1, 2, 3] for dtype, dim in itertools.product(dtypes, dims): # Support tests up to MPI Size of 35 @@ -480,6 +502,8 @@ def test_horovod_allgather_grad(self): dtypes = [torch.FloatTensor, torch.DoubleTensor] if torch.cuda.is_available(): dtypes += [torch.cuda.FloatTensor, torch.cuda.DoubleTensor] + if _fp16_supported: + dtypes += [torch.cuda.HalfTensor] dims = [1, 2, 3] for dtype, dim in itertools.product(dtypes, dims): # Support tests up to MPI Size of 35 @@ -525,8 +549,10 @@ def test_horovod_broadcast(self): torch.IntTensor, torch.LongTensor, torch.FloatTensor, torch.DoubleTensor] if torch.cuda.is_available(): dtypes += [torch.cuda.ByteTensor, torch.cuda.CharTensor, torch.cuda.ShortTensor, - torch.cuda.IntTensor, torch.cuda.LongTensor, torch.cuda.FloatTensor, - torch.cuda.DoubleTensor] + torch.cuda.IntTensor, torch.cuda.LongTensor, + torch.cuda.FloatTensor, torch.cuda.DoubleTensor] + if _fp16_supported: + dtypes += [torch.cuda.HalfTensor] dims = [1, 2, 3] root_ranks = list(range(size)) for dtype, dim, root_rank in itertools.product(dtypes, dims, root_ranks): @@ -555,8 +581,10 @@ def test_horovod_broadcast_inplace(self): torch.IntTensor, torch.LongTensor, torch.FloatTensor, torch.DoubleTensor] if torch.cuda.is_available(): dtypes += [torch.cuda.ByteTensor, torch.cuda.CharTensor, torch.cuda.ShortTensor, - torch.cuda.IntTensor, torch.cuda.LongTensor, torch.cuda.FloatTensor, - torch.cuda.DoubleTensor] + torch.cuda.IntTensor, torch.cuda.LongTensor, + torch.cuda.FloatTensor, torch.cuda.DoubleTensor] + if _fp16_supported: + dtypes += [torch.cuda.HalfTensor] dims = [1, 2, 3] root_ranks = list(range(size)) for dtype, dim, root_rank in itertools.product(dtypes, dims, root_ranks): @@ -647,6 +675,8 @@ def test_horovod_broadcast_grad(self): dtypes = [torch.FloatTensor, torch.DoubleTensor] if torch.cuda.is_available(): dtypes += [torch.cuda.FloatTensor, torch.cuda.DoubleTensor] + if _fp16_supported: + dtypes += [torch.cuda.HalfTensor] dims = [1, 2, 3] root_ranks = list(range(size)) for dtype, dim, root_rank in itertools.product(dtypes, dims, root_ranks): @@ -787,3 +817,38 @@ def new_optimizer(cls): (opt_param_value == opt_param_value_after).all()) else: self.assertEqual(opt_param_value, opt_param_value_after) + + def test_compression_fp16(self): + valid_dtypes = [torch.float32, torch.float64] + invalid_dtypes = [torch.uint8, torch.int8, torch.int16, + torch.int32, torch.int64] + + tensor_size = [5] * 3 + compression = hvd.Compression.fp16 + + for dtype in valid_dtypes: + tensor = torch.ones(tensor_size, dtype=dtype) + + tensor_compressed, ctx = compression.compress(tensor) + self.assertEqual(tensor_compressed.dtype, torch.float16) + + tensor_decompressed = compression.decompress(tensor_compressed, ctx) + self.assertEqual(tensor_decompressed.dtype, dtype) + + expected = np.ones(tensor_size) + err = np.linalg.norm(expected - tensor_decompressed.data.numpy()) + self.assertLess(err, 0.00000001) + + for dtype in invalid_dtypes: + tensor = torch.ones(tensor_size, dtype=dtype) + + tensor_compressed, ctx = compression.compress(tensor) + self.assertEqual(tensor_compressed.dtype, dtype) + + tensor_decompressed = compression.decompress(tensor_compressed, ctx) + self.assertEqual(tensor_decompressed.dtype, dtype) + + if dtype != torch.int8: # Cannot cast to NumPy with a CharTensor + expected = np.ones(tensor_size) + err = np.linalg.norm(expected - tensor_decompressed.data.numpy()) + self.assertLess(err, 0.00000001)