Introducing the Adasum algorithm for allreduce. #1485

Merged: 20 commits, Nov 25, 2019
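For context: rather than averaging gradients, Adasum combines each pair with a rule that halves the directions the two gradients share and passes their orthogonal components through untouched, so it behaves like an average for well-aligned gradients and like a sum for independent ones. Below is a minimal NumPy sketch of the pairwise operator as we understand it from the Adasum write-up; it is an illustration only, not code from this diff, and it assumes nonzero gradients:

```python
import numpy as np

def adasum_pair(g1, g2):
    # Shared (parallel) components are halved, average-like;
    # orthogonal components pass through, sum-like.
    dot = np.dot(g1, g2)
    return ((1.0 - dot / (2.0 * np.dot(g1, g1))) * g1 +
            (1.0 - dot / (2.0 * np.dot(g2, g2))) * g2)

g = np.array([1.0, 2.0])
assert np.allclose(adasum_pair(g, g), g)       # identical -> average
a, b = np.array([1.0, 0.0]), np.array([0.0, 1.0])
assert np.allclose(adasum_pair(a, b), a + b)   # orthogonal -> sum
```

Across many ranks the operator is applied recursively in a reduction tree, which is why the examples below need no global learning-rate rescaling by default.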
15 changes: 12 additions & 3 deletions examples/pytorch_imagenet_resnet50.py
@@ -30,6 +30,8 @@
help='number of batches processed locally before '
'executing allreduce across workers; it multiplies '
'total batch size.')
+parser.add_argument('--use-adasum', action='store_true', default=False,
+                    help='use adasum algorithm to do reduction')

# Default settings from https://arxiv.org/abs/1706.02677.
parser.add_argument('--batch-size', type=int, default=32,
@@ -125,15 +127,21 @@
# Set up standard ResNet-50 model.
model = models.resnet50()

+# By default, Adasum doesn't need scaling up learning rate.
+# For sum/average with gradient Accumulation: scale learning rate by batches_per_allreduce
+lr_scaler = args.batches_per_allreduce * hvd.size() if not args.use_adasum else 1

if args.cuda:
    # Move model to GPU.
    model.cuda()
+    # If using GPU Adasum allreduce, scale learning rate by local_size.
+    if args.use_adasum and hvd.nccl_built():
+        lr_scaler = args.batches_per_allreduce * hvd.local_size()

# Horovod: scale learning rate by the number of GPUs.
# Gradient Accumulation: scale learning rate by batches_per_allreduce
optimizer = optim.SGD(model.parameters(),
                      lr=(args.base_lr *
-                          args.batches_per_allreduce * hvd.size()),
+                          lr_scaler),
                      momentum=args.momentum, weight_decay=args.wd)

# Horovod: (optional) compression algorithm.
@@ -143,7 +151,8 @@
optimizer = hvd.DistributedOptimizer(
optimizer, named_parameters=model.named_parameters(),
compression=compression,
-    backward_passes_per_step=args.batches_per_allreduce)
+    backward_passes_per_step=args.batches_per_allreduce,
+    op=hvd.Adasum if args.use_adasum else hvd.Average)

# Restore from a previous checkpoint, if initial_epoch is specified.
# Horovod: restore on the first worker which will broadcast weights to other workers.
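The learning-rate handling added above boils down to one decision rule, repeated across all the examples in this PR. A consolidated sketch (the function name is ours; the branches mirror the diff):

```python
def lr_multiplier(use_adasum, cuda, nccl_built, size, local_size,
                  batches_per_allreduce=1):
    """Learning-rate multiplier implied by the examples in this PR."""
    if not use_adasum:
        # Plain sum/average: scale by world size (and by gradient
        # accumulation steps when batches_per_allreduce > 1).
        return batches_per_allreduce * size
    if cuda and nccl_built:
        # With NCCL, gradients are first summed across the ranks of a
        # node before Adasum combines nodes, so scale by local_size.
        return batches_per_allreduce * local_size
    # Pure MPI Adasum: no scaling needed.
    return 1
```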
16 changes: 13 additions & 3 deletions examples/pytorch_mnist.py
@@ -27,6 +27,9 @@
help='how many batches to wait before logging training status')
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
help='use fp16 compression during allreduce')
+parser.add_argument('--use-adasum', action='store_true', default=False,
+                    help='use adasum algorithm to do reduction')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

@@ -89,12 +92,18 @@ def forward(self, x):

model = Net()

+# By default, Adasum doesn't need scaling up learning rate.
+lr_scaler = hvd.size() if not args.use_adasum else 1

if args.cuda:
    # Move model to GPU.
    model.cuda()
+    # If using GPU Adasum allreduce, scale learning rate by local_size.
+    if args.use_adasum and hvd.nccl_built():
+        lr_scaler = hvd.local_size()

-# Horovod: scale learning rate by the number of GPUs.
-optimizer = optim.SGD(model.parameters(), lr=args.lr * hvd.size(),
+# Horovod: scale learning rate by lr_scaler.
+optimizer = optim.SGD(model.parameters(), lr=args.lr * lr_scaler,
momentum=args.momentum)

# Horovod: broadcast parameters & optimizer state.
@@ -107,7 +116,8 @@ def forward(self, x):
# Horovod: wrap optimizer with DistributedOptimizer.
optimizer = hvd.DistributedOptimizer(optimizer,
named_parameters=model.named_parameters(),
-                                     compression=compression)
+                                     compression=compression,
+                                     op=hvd.Adasum if args.use_adasum else hvd.Average)


def train(epoch):
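Beyond DistributedOptimizer, the constants registered in basics.py suggest the op can also be chosen at the tensor level. A hedged sketch, assuming the op argument is plumbed through hvd.allreduce as well (not shown in this diff):

```python
import torch
import horovod.torch as hvd

hvd.init()
grad = torch.randn(1000)
# Combine this tensor across workers with Adasum instead of averaging.
combined = hvd.allreduce(grad, op=hvd.Adasum)
```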
14 changes: 12 additions & 2 deletions examples/pytorch_synthetic_benchmark.py
@@ -31,6 +31,9 @@
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')

+parser.add_argument('--use-adasum', action='store_true', default=False,
+                    help='use adasum algorithm to do reduction')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

@@ -45,19 +48,26 @@
# Set up standard model.
model = getattr(models, args.model)()

+# By default, Adasum doesn't need scaling up learning rate.
+lr_scaler = hvd.size() if not args.use_adasum else 1

if args.cuda:
    # Move model to GPU.
    model.cuda()
+    # If using GPU Adasum allreduce, scale learning rate by local_size.
+    if args.use_adasum and hvd.nccl_built():
+        lr_scaler = hvd.local_size()

-optimizer = optim.SGD(model.parameters(), lr=0.01)
+optimizer = optim.SGD(model.parameters(), lr=0.01 * lr_scaler)

# Horovod: (optional) compression algorithm.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

# Horovod: wrap optimizer with DistributedOptimizer.
optimizer = hvd.DistributedOptimizer(optimizer,
named_parameters=model.named_parameters(),
-                                     compression=compression)
+                                     compression=compression,
+                                     op=hvd.Adasum if args.use_adasum else hvd.Average)

# Horovod: broadcast parameters & optimizer state.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
18 changes: 15 additions & 3 deletions examples/tensorflow_mnist.py
@@ -18,13 +18,19 @@
import tensorflow as tf
import horovod.tensorflow as hvd
import numpy as np
+import argparse

from tensorflow import keras

layers = tf.layers

tf.logging.set_verbosity(tf.logging.INFO)

+# Training settings
+parser = argparse.ArgumentParser(description='Tensorflow MNIST Example')
+parser.add_argument('--use-adasum', action='store_true', default=False,
+                    help='use adasum algorithm to do reduction')
+args = parser.parse_args()

def conv_model(feature, target, mode):
"""2-layer convolution model."""
@@ -111,11 +117,17 @@ def main(_):
label = tf.placeholder(tf.float32, [None], name='label')
predict, loss = conv_model(image, label, tf.estimator.ModeKeys.TRAIN)

-# Horovod: adjust learning rate based on number of GPUs.
-opt = tf.train.AdamOptimizer(0.001 * hvd.size())
+lr_scaler = hvd.size()
+# By default, Adasum doesn't need scaling when increasing batch size. If used with NCCL,
+# scale lr by local_size
+if args.use_adasum:
+    lr_scaler = hvd.local_size() if hvd.nccl_built() else 1
+
+# Horovod: adjust learning rate based on lr_scaler.
+opt = tf.train.AdamOptimizer(0.001 * lr_scaler)

# Horovod: add Horovod Distributed Optimizer.
-opt = hvd.DistributedOptimizer(opt)
+opt = hvd.DistributedOptimizer(opt, op=hvd.Adasum if args.use_adasum else hvd.Average)

global_step = tf.train.get_or_create_global_step()
train_op = opt.minimize(loss, global_step=global_step)
12 changes: 10 additions & 2 deletions examples/tensorflow_synthetic_benchmark.py
@@ -31,6 +31,8 @@
help='enables eager execution')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
+parser.add_argument('--use-adasum', action='store_true', default=False,
+                    help='use adasum algorithm to do reduction')

args = parser.parse_args()
args.cuda = not args.no_cuda
@@ -53,13 +55,19 @@
# Set up standard model.
model = getattr(applications, args.model)(weights=None)

-opt = tf.train.GradientDescentOptimizer(0.01)
+lr_scaler = hvd.size()
+# By default, Adasum doesn't need scaling when increasing batch size. If used with NCCL,
+# scale lr by local_size
+if args.use_adasum:
+    lr_scaler = hvd.local_size() if args.cuda and hvd.nccl_built() else 1
+
+opt = tf.train.GradientDescentOptimizer(0.01 * lr_scaler)

# Horovod: (optional) compression algorithm.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

# Horovod: wrap optimizer with DistributedOptimizer.
-opt = hvd.DistributedOptimizer(opt, compression=compression)
+opt = hvd.DistributedOptimizer(opt, compression=compression, op=hvd.Adasum if args.use_adasum else hvd.Average)

init = tf.global_variables_initializer()
bcast_op = hvd.broadcast_global_variables(0)
15 changes: 14 additions & 1 deletion horovod/common/basics.py
@@ -1,5 +1,5 @@
# Copyright (C) 2019 Uber Technologies, Inc.
#
+# Modifications copyright Microsoft
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -26,6 +26,10 @@ def __init__(self, pkg_path, *args):
full_path = util.get_extension_full_path(pkg_path, *args)
self.MPI_LIB_CTYPES = ctypes.CDLL(full_path, mode=ctypes.RTLD_GLOBAL)

+self.Average = self.MPI_LIB_CTYPES.horovod_reduce_op_average()
+self.Sum = self.MPI_LIB_CTYPES.horovod_reduce_op_sum()
+self.Adasum = self.MPI_LIB_CTYPES.horovod_reduce_op_adasum()

def init(self, comm=None):
"""A function that initializes Horovod.

@@ -115,6 +119,15 @@ def local_rank(self):
'Horovod has not been initialized; use hvd.init().')
return local_rank

+def is_homogeneous(self):
+    """Returns True if the cluster is homogeneous.
+
+    Returns:
+      A boolean value indicating whether every node in the cluster has same number of ranks.
+    """
+    is_homogeneous = self.MPI_LIB_CTYPES.horovod_is_homogeneous()
+    return bool(is_homogeneous)

def mpi_threads_supported(self):
"""A function that returns a flag indicating whether MPI multi-threading is supported.

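A short usage sketch of the Python surface added here, the hvd.Average / hvd.Sum / hvd.Adasum constants and hvd.is_homogeneous() (illustrative; the warning text is ours):

```python
import horovod.torch as hvd  # horovod.tensorflow exposes the same names

hvd.init()

use_adasum = True
op = hvd.Adasum if use_adasum else hvd.Average

# Adasum's local reductions assume every node runs the same number of
# ranks; the new is_homogeneous() makes that checkable from Python.
if use_adasum and not hvd.is_homogeneous():
    print('warning: ranks per node differ; Adasum performance may suffer')
```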
2 changes: 2 additions & 0 deletions horovod/common/common.h
@@ -40,6 +40,7 @@ namespace common {
#define MEMCPY_IN_HOST_BUFFER "MEMCPY_IN_HOST_BUFFER"
#define MEMCPY_IN_SHARED_BUFFER "MEMCPY_IN_SHARED_BUFFER"
#define MPI_ALLREDUCE "MPI_ALLREDUCE"
+#define MPI_ADASUM_ALLREDUCE "MPI_ADASUM_ALLREDUCE"
#define MEMCPY_OUT_HOST_BUFFER "MEMCPY_OUT_HOST_BUFFER"
#define NCCL_ALLREDUCE "NCCL_ALLREDUCE"
#define MEMCPY_OUT_FUSION_BUFFER "MEMCPY_OUT_FUSION_BUFFER"
@@ -83,6 +84,7 @@ namespace common {
#define HOROVOD_MPI "MPI"
#define HOROVOD_MLSL "MLSL"
#define HOROVOD_GLOO "GLOO"
+#define HOROVOD_ADASUM_MPI_CHUNK_SIZE "HOROVOD_ADASUM_MPI_CHUNK_SIZE"

// String constant for gloo interface.
#define GLOO_DEFAULT_IFACE "eth0"
25 changes: 18 additions & 7 deletions horovod/common/controller.cc
@@ -1,4 +1,5 @@
// Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
+// Modifications copyright Microsoft
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -321,7 +322,8 @@ ResponseList Controller::ComputeResponseList(std::atomic_bool& shut_down,
// All workers add supported responses to cache. This updates the cache
// order consistently across workers.
for (auto& response : response_list.responses()) {
-if (response.response_type() == Response::ResponseType::ALLREDUCE &&
+if ((response.response_type() == Response::ResponseType::ALLREDUCE ||
+     response.response_type() == Response::ResponseType::ADASUM) &&
(int)response.devices().size() == size_) {
response_cache_.put(response, tensor_queue_);
}
@@ -338,10 +340,10 @@ int64_t Controller::TensorFusionThresholdBytes() {
int64_t proposed_fusion_threshold =
parameter_manager_.TensorFusionThresholdBytes();

-// If the cluster is homogeneous and hierarchical allreduce is enabled,
+// If the cluster is homogeneous,
 // adjust buffer size to make sure it is divisible by local_size to improve
-// performance.
-if (is_homogeneous_ && parameter_manager_.HierarchicalAllreduce()) {
+// performance for operations that perform local reductions by default such as Adasum.
+if (is_homogeneous_) {
// Assume the worst-case data type float64, since if it is divisible with
// float64, it will be divisible for other types too.

@@ -400,6 +402,7 @@ Response Controller::ConstructResponse(std::string& name, int joined_size) {
// If we are doing an allreduce or broadcast, check that all tensor shapes are
// identical.
if (message_type == Request::ALLREDUCE ||
+    message_type == Request::ADASUM ||
message_type == Request::BROADCAST) {
TensorShape tensor_shape;
for (auto dim : requests[0].tensor_shape()) {
@@ -497,7 +500,7 @@ Response Controller::ConstructResponse(std::string& name, int joined_size) {

// If there is at least one rank that requested Join, communicate tensor sizes
// in the response, because joined ranks don't have this info.
-if (joined_size > 0 && message_type == Request::ALLREDUCE) {
+if (joined_size > 0 && (message_type == Request::ALLREDUCE || message_type == Request::ADASUM)) {
TensorShape tensor_shape;
for (auto dim : requests[0].tensor_shape()) {
tensor_shape.AddDim(dim);
@@ -575,6 +578,14 @@ Response Controller::ConstructResponse(std::string& name, int joined_size) {
}
} else if (message_type == Request::BROADCAST) {
response.set_response_type(Response::BROADCAST);
+} else if (message_type == Request::ADASUM) {
+  response.set_response_type(Response::ADASUM);
+  if (joined_size > 0) {
+    for (auto dim : tensor_sizes) {
+      response.add_tensor_size(dim);
+    }
+    response.set_tensor_type(data_type);
+  }
}
response.set_devices(devices);

@@ -622,8 +633,8 @@ ResponseList Controller::FuseResponses(std::deque<Response>& responses) {
responses.pop_front();
int64_t tensor_size = 0;
DataType dtype;

-if (response.response_type() == Response::ResponseType::ALLREDUCE) {
+if (response.response_type() == Response::ResponseType::ALLREDUCE ||
+    response.response_type() == Response::ResponseType::ADASUM) {
// Attempt to add more responses to this fused response.

// found_tensor can be false for ranks that did Join.
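The TensorFusionThresholdBytes() change above drops the hierarchical-allreduce condition, so the fusion buffer is local_size-aligned whenever the cluster is homogeneous. A Python sketch of the alignment step (the real code is C++; we assume rounding up to a multiple of local_size * sizeof(double), the worst-case element size named in the comment):

```python
def aligned_fusion_threshold(proposed_bytes, local_size):
    # A buffer divisible by local_size * 8 (worst-case float64) splits
    # evenly across local ranks for ops that reduce locally first,
    # such as Adasum.
    div = local_size * 8
    return ((proposed_bytes + div - 1) // div) * div

assert aligned_fusion_threshold(64 * 1024 * 1024, 4) % (4 * 8) == 0
```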
6 changes: 6 additions & 0 deletions horovod/common/global_state.h
@@ -102,11 +102,17 @@ struct HorovodGlobalState {

// Number of ranks that did Join()
int joined_size = 0;

// If a rank is Joined, AllReduce uses temporary 0 tensors for it.
bool joined = false;

// ID of the device to create temporary tensors while Joined
int join_device = CPU_DEVICE_ID;

+// Chunk size for MPI send/recv in Adasum allreduce. Some versions of Intel MPI
+// benefit from a smaller chunk size.
+int64_t adasum_mpi_chunk_size = 1<<30;

~HorovodGlobalState() {
// Make sure that the destructor of the background thread is safe to
// call. If a thread is still joinable (not detached or complete) its
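The new adasum_mpi_chunk_size (default 1 << 30 bytes, overridable via the HOROVOD_ADASUM_MPI_CHUNK_SIZE variable defined in common.h) caps individual MPI transfers. A sketch of the chunking pattern we'd expect it to drive (illustrative only; send_fn stands in for the actual MPI send/recv calls):

```python
import os

chunk_size = int(os.environ.get('HOROVOD_ADASUM_MPI_CHUNK_SIZE', 1 << 30))

def transfer_in_chunks(buf, send_fn):
    # Splitting one large message into chunk_size pieces keeps MPI
    # implementations that degrade on huge messages (e.g. some Intel
    # MPI versions, per the comment above) in their fast path.
    for start in range(0, len(buf), chunk_size):
        send_fn(buf[start:start + chunk_size])
```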
7 changes: 7 additions & 0 deletions horovod/common/message.cc
@@ -1,5 +1,6 @@
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
// Modifications copyright (C) 2019 Uber Technologies, Inc.
+// Modifications copyright Microsoft
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -75,6 +76,9 @@ const std::string& Request::RequestType_Name(RequestType value) {
case RequestType::JOIN:
static const std::string join("JOIN");
return join;
+  case RequestType::ADASUM:
+    static const std::string adasum("ADASUM");
+    return adasum;
default:
static const std::string unknown("<unknown>");
return unknown;
@@ -242,6 +246,9 @@ const std::string& Response::ResponseType_Name(ResponseType value) {
case ResponseType::JOIN:
static const std::string join("JOIN");
return join;
+  case ResponseType::ADASUM:
+    static const std::string adasum("ADASUM");
+    return adasum;
case ResponseType::ERROR:
static const std::string error("ERROR");
return error;
5 changes: 3 additions & 2 deletions horovod/common/message.h
@@ -1,5 +1,6 @@
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
// Modifications copyright (C) 2019 Uber Technologies, Inc.
+// Modifications copyright Microsoft
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -45,7 +46,7 @@ const std::string& DataType_Name(DataType value);
class Request {
public:
enum RequestType {
-    ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2, JOIN = 3
+    ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2, JOIN = 3, ADASUM = 4
};

static const std::string& RequestType_Name(RequestType value);
@@ -130,7 +131,7 @@ class RequestList {
class Response {
public:
enum ResponseType {
-    ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2, JOIN = 3, ERROR = 4
+    ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2, JOIN = 3, ADASUM = 4, ERROR = 5
};

static const std::string& ResponseType_Name(ResponseType value);