Skip to content

Commit

Permalink
Introducing Adasum algorithm to do allreduction. (#1485)
Browse files Browse the repository at this point in the history
* Introducing Adasum algorithm to do allreduction.
 1. Adasum operations for both CPU and NCCL build of Horovod
 2. Framework support in Tensorflow and Pytorch to enable Adasum
 3. A new optimizer added for Tensorflow and Pytorch to deliver more accurate estimation when using Adasum

Main contributors:
Olli Saarikivi (olsaarik)
Vadim Eksarevskiy (vaeksare)
Jaliya Ekanayake (jaliyae)
Todd Mytkowicz (klipto)
Saeed Maleki(saeedmaleki)
Sergii Dymchenko(kit1980)

Signed-off-by: Tix <tix@microsoft.com>
  • Loading branch information
Tixxx authored and tgaddair committed Nov 25, 2019
1 parent 3237ccc commit 5fa1d7a
Show file tree
Hide file tree
Showing 42 changed files with 2,472 additions and 157 deletions.
15 changes: 12 additions & 3 deletions examples/pytorch_imagenet_resnet50.py
Expand Up @@ -30,6 +30,8 @@
help='number of batches processed locally before '
'executing allreduce across workers; it multiplies '
'total batch size.')
parser.add_argument('--use-adasum', action='store_true', default=False,
help='use adasum algorithm to do reduction')

# Default settings from https://arxiv.org/abs/1706.02677.
parser.add_argument('--batch-size', type=int, default=32,
Expand Down Expand Up @@ -125,15 +127,21 @@
# Set up standard ResNet-50 model.
model = models.resnet50()

# By default, Adasum doesn't need scaling up learning rate.
# For sum/average with gradient Accumulation: scale learning rate by batches_per_allreduce
lr_scaler = args.batches_per_allreduce * hvd.size() if not args.use_adasum else 1

if args.cuda:
# Move model to GPU.
model.cuda()
# If using GPU Adasum allreduce, scale learning rate by local_size.
if args.use_adasum and hvd.nccl_built():
lr_scaler = args.batches_per_allreduce * hvd.local_size()

# Horovod: scale learning rate by the number of GPUs.
# Gradient Accumulation: scale learning rate by batches_per_allreduce
optimizer = optim.SGD(model.parameters(),
lr=(args.base_lr *
args.batches_per_allreduce * hvd.size()),
lr_scaler),
momentum=args.momentum, weight_decay=args.wd)

# Horovod: (optional) compression algorithm.
Expand All @@ -143,7 +151,8 @@
optimizer = hvd.DistributedOptimizer(
optimizer, named_parameters=model.named_parameters(),
compression=compression,
backward_passes_per_step=args.batches_per_allreduce)
backward_passes_per_step=args.batches_per_allreduce,
op=hvd.Adasum if args.use_adasum else hvd.Average)

# Restore from a previous checkpoint, if initial_epoch is specified.
# Horovod: restore on the first worker which will broadcast weights to other workers.
Expand Down
16 changes: 13 additions & 3 deletions examples/pytorch_mnist.py
Expand Up @@ -27,6 +27,9 @@
help='how many batches to wait before logging training status')
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
help='use fp16 compression during allreduce')
parser.add_argument('--use-adasum', action='store_true', default=False,
help='use adasum algorithm to do reduction')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

Expand Down Expand Up @@ -89,12 +92,18 @@ def forward(self, x):

model = Net()

# By default, Adasum doesn't need scaling up learning rate.
lr_scaler = hvd.size() if not args.use_adasum else 1

if args.cuda:
# Move model to GPU.
model.cuda()
# If using GPU Adasum allreduce, scale learning rate by local_size.
if args.use_adasum and hvd.nccl_built():
lr_scaler = hvd.local_size()

# Horovod: scale learning rate by the number of GPUs.
optimizer = optim.SGD(model.parameters(), lr=args.lr * hvd.size(),
# Horovod: scale learning rate by lr_scaler.
optimizer = optim.SGD(model.parameters(), lr=args.lr * lr_scaler,
momentum=args.momentum)

# Horovod: broadcast parameters & optimizer state.
Expand All @@ -107,7 +116,8 @@ def forward(self, x):
# Horovod: wrap optimizer with DistributedOptimizer.
optimizer = hvd.DistributedOptimizer(optimizer,
named_parameters=model.named_parameters(),
compression=compression)
compression=compression,
op=hvd.Adasum if args.use_adasum else hvd.Average)


def train(epoch):
Expand Down
14 changes: 12 additions & 2 deletions examples/pytorch_synthetic_benchmark.py
Expand Up @@ -31,6 +31,9 @@
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')

parser.add_argument('--use-adasum', action='store_true', default=False,
help='use adasum algorithm to do reduction')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

Expand All @@ -45,19 +48,26 @@
# Set up standard model.
model = getattr(models, args.model)()

# By default, Adasum doesn't need scaling up learning rate.
lr_scaler = hvd.size() if not args.use_adasum else 1

if args.cuda:
# Move model to GPU.
model.cuda()
# If using GPU Adasum allreduce, scale learning rate by local_size.
if args.use_adasum and hvd.nccl_built():
lr_scaler = hvd.local_size()

optimizer = optim.SGD(model.parameters(), lr=0.01)
optimizer = optim.SGD(model.parameters(), lr=0.01 * lr_scaler)

# Horovod: (optional) compression algorithm.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

# Horovod: wrap optimizer with DistributedOptimizer.
optimizer = hvd.DistributedOptimizer(optimizer,
named_parameters=model.named_parameters(),
compression=compression)
compression=compression,
op=hvd.Adasum if args.use_adasum else hvd.Average)

# Horovod: broadcast parameters & optimizer state.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
Expand Down
18 changes: 15 additions & 3 deletions examples/tensorflow_mnist.py
Expand Up @@ -18,13 +18,19 @@
import tensorflow as tf
import horovod.tensorflow as hvd
import numpy as np
import argparse

from tensorflow import keras

layers = tf.layers

tf.logging.set_verbosity(tf.logging.INFO)

# Training settings
parser = argparse.ArgumentParser(description='Tensorflow MNIST Example')
parser.add_argument('--use-adasum', action='store_true', default=False,
help='use adasum algorithm to do reduction')
args = parser.parse_args()

def conv_model(feature, target, mode):
"""2-layer convolution model."""
Expand Down Expand Up @@ -111,11 +117,17 @@ def main(_):
label = tf.placeholder(tf.float32, [None], name='label')
predict, loss = conv_model(image, label, tf.estimator.ModeKeys.TRAIN)

# Horovod: adjust learning rate based on number of GPUs.
opt = tf.train.AdamOptimizer(0.001 * hvd.size())
lr_scaler = hvd.size()
# By default, Adasum doesn't need scaling when increasing batch size. If used with NCCL,
# scale lr by local_size
if args.use_adasum:
lr_scaler = hvd.local_size() if hvd.nccl_built() else 1

# Horovod: adjust learning rate based on lr_scaler.
opt = tf.train.AdamOptimizer(0.001 * lr_scaler)

# Horovod: add Horovod Distributed Optimizer.
opt = hvd.DistributedOptimizer(opt)
opt = hvd.DistributedOptimizer(opt, op=hvd.Adasum if args.use_adasum else hvd.Average)

global_step = tf.train.get_or_create_global_step()
train_op = opt.minimize(loss, global_step=global_step)
Expand Down
12 changes: 10 additions & 2 deletions examples/tensorflow_synthetic_benchmark.py
Expand Up @@ -31,6 +31,8 @@
help='enables eager execution')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--use-adasum', action='store_true', default=False,
help='use adasum algorithm to do reduction')

args = parser.parse_args()
args.cuda = not args.no_cuda
Expand All @@ -53,13 +55,19 @@
# Set up standard model.
model = getattr(applications, args.model)(weights=None)

opt = tf.train.GradientDescentOptimizer(0.01)
lr_scaler = hvd.size()
# By default, Adasum doesn't need scaling when increasing batch size. If used with NCCL,
# scale lr by local_size
if args.use_adasum:
lr_scaler = hvd.local_size() if args.cuda and hvd.nccl_built() else 1

opt = tf.train.GradientDescentOptimizer(0.01 * lr_scaler)

# Horovod: (optional) compression algorithm.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

# Horovod: wrap optimizer with DistributedOptimizer.
opt = hvd.DistributedOptimizer(opt, compression=compression)
opt = hvd.DistributedOptimizer(opt, compression=compression, op=hvd.Adasum if args.use_adasum else hvd.Average)

init = tf.global_variables_initializer()
bcast_op = hvd.broadcast_global_variables(0)
Expand Down
15 changes: 14 additions & 1 deletion horovod/common/basics.py
@@ -1,5 +1,5 @@
# Copyright (C) 2019 Uber Technologies, Inc.
#
# Modifications copyright Microsoft
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -26,6 +26,10 @@ def __init__(self, pkg_path, *args):
full_path = util.get_extension_full_path(pkg_path, *args)
self.MPI_LIB_CTYPES = ctypes.CDLL(full_path, mode=ctypes.RTLD_GLOBAL)

self.Average = self.MPI_LIB_CTYPES.horovod_reduce_op_average()
self.Sum = self.MPI_LIB_CTYPES.horovod_reduce_op_sum()
self.Adasum = self.MPI_LIB_CTYPES.horovod_reduce_op_adasum()

def init(self, comm=None):
"""A function that initializes Horovod.
Expand Down Expand Up @@ -115,6 +119,15 @@ def local_rank(self):
'Horovod has not been initialized; use hvd.init().')
return local_rank

def is_homogeneous(self):
"""Returns True if the cluster is homogeneous.
Returns:
A boolean value indicating whether every node in the cluster has same number of ranks.
"""
is_homogeneous = self.MPI_LIB_CTYPES.horovod_is_homogeneous()
return bool(is_homogeneous)

def mpi_threads_supported(self):
"""A function that returns a flag indicating whether MPI multi-threading is supported.
Expand Down
2 changes: 2 additions & 0 deletions horovod/common/common.h
Expand Up @@ -40,6 +40,7 @@ namespace common {
#define MEMCPY_IN_HOST_BUFFER "MEMCPY_IN_HOST_BUFFER"
#define MEMCPY_IN_SHARED_BUFFER "MEMCPY_IN_SHARED_BUFFER"
#define MPI_ALLREDUCE "MPI_ALLREDUCE"
#define MPI_ADASUM_ALLREDUCE "MPI_ADASUM_ALLREDUCE"
#define MEMCPY_OUT_HOST_BUFFER "MEMCPY_OUT_HOST_BUFFER"
#define NCCL_ALLREDUCE "NCCL_ALLREDUCE"
#define MEMCPY_OUT_FUSION_BUFFER "MEMCPY_OUT_FUSION_BUFFER"
Expand Down Expand Up @@ -83,6 +84,7 @@ namespace common {
#define HOROVOD_MPI "MPI"
#define HOROVOD_MLSL "MLSL"
#define HOROVOD_GLOO "GLOO"
#define HOROVOD_ADASUM_MPI_CHUNK_SIZE "HOROVOD_ADASUM_MPI_CHUNK_SIZE"

// String constant for gloo interface.
#define GLOO_DEFAULT_IFACE "eth0"
Expand Down
25 changes: 18 additions & 7 deletions horovod/common/controller.cc
@@ -1,4 +1,5 @@
// Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
// Modifications copyright Microsoft
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -321,7 +322,8 @@ ResponseList Controller::ComputeResponseList(std::atomic_bool& shut_down,
// All workers add supported responses to cache. This updates the cache
// order consistently across workers.
for (auto& response : response_list.responses()) {
if (response.response_type() == Response::ResponseType::ALLREDUCE &&
if ((response.response_type() == Response::ResponseType::ALLREDUCE ||
response.response_type() == Response::ResponseType::ADASUM) &&
(int)response.devices().size() == size_) {
response_cache_.put(response, tensor_queue_);
}
Expand All @@ -338,10 +340,10 @@ int64_t Controller::TensorFusionThresholdBytes() {
int64_t proposed_fusion_threshold =
parameter_manager_.TensorFusionThresholdBytes();

// If the cluster is homogeneous and hierarchical allreduce is enabled,
// If the cluster is homogeneous,
// adjust buffer size to make sure it is divisible by local_size to improve
// performance.
if (is_homogeneous_ && parameter_manager_.HierarchicalAllreduce()) {
// performance for operations that perform local reductions by default such as Adasum.
if (is_homogeneous_) {
// Assume the worst-case data type float64, since if it is divisible with
// float64, it will be divisible for other types too.

Expand Down Expand Up @@ -400,6 +402,7 @@ Response Controller::ConstructResponse(std::string& name, int joined_size) {
// If we are doing an allreduce or broadcast, check that all tensor shapes are
// identical.
if (message_type == Request::ALLREDUCE ||
message_type == Request::ADASUM ||
message_type == Request::BROADCAST) {
TensorShape tensor_shape;
for (auto dim : requests[0].tensor_shape()) {
Expand Down Expand Up @@ -497,7 +500,7 @@ Response Controller::ConstructResponse(std::string& name, int joined_size) {

// If there is at least one rank that requested Join, communicate tensor sizes
// in the response, because joined ranks don't have this info.
if (joined_size > 0 && message_type == Request::ALLREDUCE) {
if (joined_size > 0 && (message_type == Request::ALLREDUCE || message_type == Request::ADASUM)) {
TensorShape tensor_shape;
for (auto dim : requests[0].tensor_shape()) {
tensor_shape.AddDim(dim);
Expand Down Expand Up @@ -575,6 +578,14 @@ Response Controller::ConstructResponse(std::string& name, int joined_size) {
}
} else if (message_type == Request::BROADCAST) {
response.set_response_type(Response::BROADCAST);
} else if (message_type == Request::ADASUM) {
response.set_response_type(Response::ADASUM);
if (joined_size > 0) {
for (auto dim : tensor_sizes) {
response.add_tensor_size(dim);
}
response.set_tensor_type(data_type);
}
}
response.set_devices(devices);

Expand Down Expand Up @@ -622,8 +633,8 @@ ResponseList Controller::FuseResponses(std::deque<Response>& responses) {
responses.pop_front();
int64_t tensor_size = 0;
DataType dtype;

if (response.response_type() == Response::ResponseType::ALLREDUCE) {
if (response.response_type() == Response::ResponseType::ALLREDUCE ||
response.response_type() == Response::ResponseType::ADASUM) {
// Attempt to add more responses to this fused response.

// found_tensor can be false for ranks that did Join.
Expand Down
6 changes: 6 additions & 0 deletions horovod/common/global_state.h
Expand Up @@ -102,11 +102,17 @@ struct HorovodGlobalState {

// Number of ranks that did Join()
int joined_size = 0;

// If a rank is Joined, AllReduce uses temporary 0 tensors for it.
bool joined = false;

// ID of the device to create temporary tensors while Joined
int join_device = CPU_DEVICE_ID;

// Chunk size for MPI send/recv in Adasum allreduce. Some versions of Intel MPI
// benefit from a smaller chunk size.
int64_t adasum_mpi_chunk_size = 1<<30;

~HorovodGlobalState() {
// Make sure that the destructor of the background thread is safe to
// call. If a thread is still joinable (not detached or complete) its
Expand Down
7 changes: 7 additions & 0 deletions horovod/common/message.cc
@@ -1,5 +1,6 @@
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
// Modifications copyright (C) 2019 Uber Technologies, Inc.
// Modifications copyright Microsoft
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -75,6 +76,9 @@ const std::string& Request::RequestType_Name(RequestType value) {
case RequestType::JOIN:
static const std::string join("JOIN");
return join;
case RequestType::ADASUM:
static const std::string adasum("ADASUM");
return adasum;
default:
static const std::string unknown("<unknown>");
return unknown;
Expand Down Expand Up @@ -242,6 +246,9 @@ const std::string& Response::ResponseType_Name(ResponseType value) {
case ResponseType::JOIN:
static const std::string join("JOIN");
return join;
case ResponseType::ADASUM:
static const std::string adasum("ADASUM");
return adasum;
case ResponseType::ERROR:
static const std::string error("ERROR");
return error;
Expand Down
5 changes: 3 additions & 2 deletions horovod/common/message.h
@@ -1,5 +1,6 @@
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
// Modifications copyright (C) 2019 Uber Technologies, Inc.
// Modifications copyright Microsoft
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -45,7 +46,7 @@ const std::string& DataType_Name(DataType value);
class Request {
public:
enum RequestType {
ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2, JOIN = 3
ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2, JOIN = 3, ADASUM = 4
};

static const std::string& RequestType_Name(RequestType value);
Expand Down Expand Up @@ -130,7 +131,7 @@ class RequestList {
class Response {
public:
enum ResponseType {
ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2, JOIN = 3, ERROR = 4
ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2, JOIN = 3, ADASUM = 4, ERROR = 5
};

static const std::string& ResponseType_Name(ResponseType value);
Expand Down

0 comments on commit 5fa1d7a

Please sign in to comment.