
Enable local gradient accumulation (#546)

andfoy authored and alsrgv committed Dec 12, 2018
1 parent 7b53b4b commit 9081ba35908819fc08365e88314faaa6a96e7d98
Showing with 70 additions and 18 deletions.
  1. +1 −0 .gitignore
  2. +31 −12 examples/pytorch_imagenet_resnet50.py
  3. +38 −6 horovod/torch/__init__.py
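
In short, the change lets each worker run several forward/backward passes on smaller sub-batches and only then average gradients across workers. A minimal usage sketch of the new API (the toy model, data, and learning rate below are illustrative placeholders, not part of the commit):

import torch
import torch.nn.functional as F
import torch.optim as optim
import horovod.torch as hvd

hvd.init()

model = torch.nn.Linear(10, 2)          # hypothetical tiny model
batches_per_allreduce = 4               # local accumulation factor

# Scale the learning rate by both the accumulation factor and the worker count,
# mirroring what the example script below does.
opt = optim.SGD(model.parameters(), lr=0.01 * batches_per_allreduce * hvd.size())

# New in this commit: backward_passes_per_step tells the wrapper how many
# backward() calls to expect before step()/synchronize().
opt = hvd.DistributedOptimizer(
    opt, named_parameters=model.named_parameters(),
    backward_passes_per_step=batches_per_allreduce)

opt.zero_grad()
for _ in range(batches_per_allreduce):
    x = torch.randn(32, 10)
    y = torch.randint(0, 2, (32,))
    loss = F.cross_entropy(model(x), y)
    loss.div_(batches_per_allreduce)    # average over the accumulated sub-batches
    loss.backward()                     # gradients accumulate locally until the last pass
opt.step()                              # one allreduce, then one weight update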

.gitignore
@@ -4,3 +4,4 @@
 horovod.egg-info
 dist
 build
+.vscode/

examples/pytorch_imagenet_resnet50.py
@@ -1,5 +1,6 @@
 from __future__ import print_function

+import torch
 import argparse
 import torch.backends.cudnn as cudnn
 import torch.nn.functional as F
@@ -9,6 +10,7 @@
 import horovod.torch as hvd
 import tensorboardX
 import os
+import math
 from tqdm import tqdm

 # Training settings
@@ -24,6 +26,10 @@
                     help='checkpoint file format')
 parser.add_argument('--fp16-allreduce', action='store_true', default=False,
                     help='use fp16 compression during allreduce')
+parser.add_argument('--batches-per-allreduce', type=int, default=1,
+                    help='number of batches processed locally before '
+                         'executing allreduce across workers; it multiplies '
+                         'total batch size.')

 # Default settings from https://arxiv.org/abs/1706.02677.
 parser.add_argument('--batch-size', type=int, default=32,
@@ -49,6 +55,8 @@
 args = parser.parse_args()
 args.cuda = not args.no_cuda and torch.cuda.is_available()

+allreduce_batch_size = args.batch_size * args.batches_per_allreduce
+
 hvd.init()
 torch.manual_seed(args.seed)

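The numbers below are illustrative, not defaults enforced by the commit; they just spell out the batch-size arithmetic introduced here:

batch_size = 32                  # samples per forward/backward pass (--batch-size)
batches_per_allreduce = 4        # local accumulation steps (--batches-per-allreduce)
allreduce_batch_size = batch_size * batches_per_allreduce   # 128 samples per worker per step
# Effective global batch = allreduce_batch_size * hvd.size(); the learning-rate
# scaling below compensates for both factors.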
@@ -93,7 +101,8 @@
 train_sampler = torch.utils.data.distributed.DistributedSampler(
     train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
 train_loader = torch.utils.data.DataLoader(
-    train_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs)
+    train_dataset, batch_size=allreduce_batch_size,
+    sampler=train_sampler, **kwargs)

 val_dataset = \
     datasets.ImageFolder(args.val_dir,
@@ -118,16 +127,20 @@
     model.cuda()

 # Horovod: scale learning rate by the number of GPUs.
-optimizer = optim.SGD(model.parameters(), lr=args.base_lr * hvd.size(),
+# Gradient Accumulation: scale learning rate by batches_per_allreduce
+optimizer = optim.SGD(model.parameters(),
+                      lr=(args.base_lr *
+                          args.batches_per_allreduce * hvd.size()),
                       momentum=args.momentum, weight_decay=args.wd)

 # Horovod: (optional) compression algorithm.
 compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

 # Horovod: wrap optimizer with DistributedOptimizer.
-optimizer = hvd.DistributedOptimizer(optimizer,
-                                     named_parameters=model.named_parameters(),
-                                     compression=compression)
+optimizer = hvd.DistributedOptimizer(
+    optimizer, named_parameters=model.named_parameters(),
+    compression=compression,
+    backward_passes_per_step=args.batches_per_allreduce)

 # Restore from a previous checkpoint, if initial_epoch is specified.
 # Horovod: restore on the first worker which will broadcast weights to other workers.
@@ -156,13 +169,19 @@ def train(epoch):
             if args.cuda:
                 data, target = data.cuda(), target.cuda()
             optimizer.zero_grad()
-            output = model(data)
-            loss = F.cross_entropy(output, target)
-            loss.backward()
+            # Split data into sub-batches of size batch_size
+            for i in range(0, len(data), args.batch_size):
+                data_batch = data[i:i + args.batch_size]
+                target_batch = target[i:i + args.batch_size]
+                output = model(data_batch)
+                train_accuracy.update(accuracy(output, target_batch))
+                loss = F.cross_entropy(output, target_batch)
+                train_loss.update(loss.item())
+                # Average gradients among sub-batches
+                loss.div_(math.ceil(float(len(data)) / args.batch_size))
+                loss.backward()
+            # Gradient is applied across all ranks
             optimizer.step()
-
-            train_loss.update(loss)
-            train_accuracy.update(accuracy(output, target))
             t.set_postfix({'loss': train_loss.avg.item(),
                            'accuracy': 100. * train_accuracy.avg.item()})
             t.update(1)
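
The loop above slices each allreduce-sized batch back into --batch-size sub-batches and scales every sub-batch loss so that the accumulated gradient equals the gradient of the mean loss over the whole batch. A self-contained sketch of the same arithmetic with a hypothetical toy model (not the commit's code):

import math
import torch
import torch.nn.functional as F

batch_size = 32                                     # --batch-size
model = torch.nn.Linear(20, 5)                      # hypothetical stand-in model
data = torch.randn(128, 20)                         # one allreduce_batch_size batch
target = torch.randint(0, 5, (128,))

num_sub_batches = math.ceil(float(len(data)) / batch_size)   # 4 sub-batches

model.zero_grad()
for i in range(0, len(data), batch_size):
    data_batch = data[i:i + batch_size]
    target_batch = target[i:i + batch_size]
    loss = F.cross_entropy(model(data_batch), target_batch)
    # Dividing by the sub-batch count makes the summed (accumulated) gradient
    # equal the gradient of the mean loss over all 128 samples.
    loss.div_(num_sub_batches)
    loss.backward()
# model.weight.grad now holds the averaged gradient for the full batch.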
@@ -214,7 +233,7 @@ def adjust_learning_rate(epoch, batch_idx):
     else:
         lr_adj = 1e-3
     for param_group in optimizer.param_groups:
-        param_group['lr'] = args.base_lr * hvd.size() * lr_adj
+        param_group['lr'] = args.base_lr * hvd.size() * args.batches_per_allreduce * lr_adj


 def accuracy(output, target):

horovod/torch/__init__.py
@@ -40,7 +40,8 @@


 class _DistributedOptimizer(torch.optim.Optimizer):
-    def __init__(self, params, named_parameters, compression):
+    def __init__(self, params, named_parameters, compression,
+                 backward_passes_per_step=1):
         super(self.__class__, self).__init__(params)
         self._compression = compression

@@ -62,13 +63,20 @@ def __init__(self, params, named_parameters, compression):
         self._parameter_names = {v: 'allreduce.noname.%s' % i
                                  for param_group in self.param_groups
                                  for i, v in enumerate(param_group['params'])}
-
+        self.backward_passes_per_step = backward_passes_per_step
+        self._allreduce_delay = {v: self.backward_passes_per_step
+                                 for _, v in sorted(named_parameters)}
         self._handles = {}
         self._grad_accs = []
         self._requires_update = set()
         if size() > 1:
             self._register_hooks()

+    def set_backward_passes_per_step(self, passes):
+        self.backward_passes_per_step = passes
+        for p in self._allreduce_delay:
+            self._allreduce_delay[p] = self.backward_passes_per_step
+
     def _register_hooks(self):
         for param_group in self.param_groups:
             for p in param_group['params']:
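
The new _allreduce_delay dict gives every parameter its own countdown of remaining backward passes, and set_backward_passes_per_step resets it, so the accumulation factor can be changed between optimizer steps. A hedged usage sketch (opt is a DistributedOptimizer built as in the first example on this page):

# Hypothetical: drop to 2 local accumulation steps for a later training phase.
opt.set_backward_passes_per_step(2)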
@@ -84,14 +92,25 @@ def _allreduce_grad_async(self, p):
         name = self._parameter_names.get(p)
         tensor = p.grad.data
         tensor_compressed, ctx = self._compression.compress(tensor)
+
         handle = allreduce_async_(tensor_compressed, average=True, name=name)
         return handle, ctx

     def _make_hook(self, p):
         def hook(*ignore):
-            assert p not in self._handles
+            if p in self._handles and self._handles[p][0] is not None:
+                if self._allreduce_delay[p] <= 0:
+                    raise AssertionError(
+                        "Gradients were computed more than "
+                        "backward_passes_per_step times before call "
+                        "to step(). Increase backward_passes_per_step to "
+                        "accumulate gradients locally.")
             assert not p.grad.requires_grad
-            handle, ctx = self._allreduce_grad_async(p)
+            assert self._allreduce_delay[p] > 0
+            handle, ctx = None, None
+            self._allreduce_delay[p] -= 1
+            if self._allreduce_delay[p] == 0:
+                handle, ctx = self._allreduce_grad_async(p)
             self._handles[p] = (handle, ctx)
         return hook

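In the hook above, every backward pass decrements the parameter's counter and stores a placeholder (None) handle; only the pass that drives the counter to zero launches the asynchronous allreduce, and running more backward passes than backward_passes_per_step without calling step() raises the AssertionError. A standalone toy model of that gating logic, not Horovod code:

backward_passes_per_step = 4
delay = backward_passes_per_step        # per-parameter countdown in the real code

def on_backward_pass():
    """Mimics the hook: accumulate locally, allreduce only on the last pass."""
    global delay
    if delay <= 0:
        raise AssertionError("more backward passes than backward_passes_per_step")
    delay -= 1
    if delay == 0:
        print("launch async allreduce of the accumulated gradient")
    else:
        print("keep accumulating locally")

for _ in range(backward_passes_per_step):
    on_backward_pass()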
@@ -103,7 +122,12 @@ def synchronize(self):

         for p, value in self._handles.items():
             handle, ctx = value
+            if handle is None:
+                handle, ctx = self._allreduce_grad_async(p)
+                self._handles[p] = (handle, ctx)
+        for p, (handle, _) in self._handles.items():
             output = synchronize(handle)
+            self._allreduce_delay[p] = self.backward_passes_per_step
             p.grad.data.set_(self._compression.decompress(output, ctx))
         self._handles.clear()

@@ -112,7 +136,9 @@ def step(self, closure=None):
         return super(self.__class__, self).step(closure)


-def DistributedOptimizer(optimizer, named_parameters=None, compression=Compression.none):
+def DistributedOptimizer(optimizer, named_parameters=None,
+                         compression=Compression.none,
+                         backward_passes_per_step=1):
     """
     An optimizer that wraps another torch.optim.Optimizer, using an allreduce to
     average gradient values before applying gradients to model weights.
@@ -142,12 +168,18 @@ def DistributedOptimizer(optimizer, named_parameters=None, compression=Compressi
         compression: Compression algorithm used during allreduce to reduce the amount
                      of data sent during the each parameter update step. Defaults to
                      not using compression.
+        backward_passes_per_step: Number of expected backward passes to perform
+                                  before calling step()/synchronize(). This
+                                  allows accumulating gradients over multiple
+                                  mini-batches before executing averaging and
+                                  applying them.
     """
     # We dynamically create a new class that inherits from the optimizer that was passed in.
     # The goal is to override the `step()` method with an allreduce implementation.
     cls = type(optimizer.__class__.__name__, (optimizer.__class__,),
                dict(_DistributedOptimizer.__dict__))
-    return cls(optimizer.param_groups, named_parameters, compression)
+    return cls(optimizer.param_groups, named_parameters,
+               compression, backward_passes_per_step)


 def broadcast_parameters(params, root_rank):
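
As the updated docstring says, the factory wraps whatever torch.optim.Optimizer it is given, and backward_passes_per_step defaults to 1, which preserves the old one-allreduce-per-backward behaviour. A hedged sketch with a different base optimizer (Adam and the tiny model are illustrative only):

import torch
import torch.optim as optim
import horovod.torch as hvd

hvd.init()
model = torch.nn.Linear(10, 2)                      # hypothetical tiny model

base = optim.Adam(model.parameters(), lr=1e-3)
# The default backward_passes_per_step=1 behaves like the pre-commit wrapper;
# raising it defers the allreduce until that many backward passes have run.
opt = hvd.DistributedOptimizer(
    base, named_parameters=model.named_parameters(),
    backward_passes_per_step=1)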
