PR: Enable local gradient accumulation #546

Merged
merged 30 commits on Dec 12, 2018
Changes from 26 commits

Commits (30)
4ddf607
Add ignore_gradients function to accumulate gradients locally
andfoy Oct 6, 2018
a998ef9
Reduce gradients by default
andfoy Oct 6, 2018
3bad665
All Reduce only if self._reduce_gradients
andfoy Oct 6, 2018
b0638d3
Checking that handle is not None to raise exception
andfoy Oct 6, 2018
91b3291
Minor typo correction
andfoy Oct 6, 2018
ed5e12a
Update warning message
andfoy Oct 7, 2018
c77a403
Replace ignore_gradients by set_aggregation_delay
andfoy Oct 8, 2018
aeaf351
Add backward_passes_per_step
andfoy Oct 9, 2018
a61788f
Minor error correction
andfoy Oct 9, 2018
da01dfb
Merge remote-tracking branch 'upstream/master' into disable_gradient_…
andfoy Oct 11, 2018
e443509
Expose backward_passes_per_step, update ImageNet example
andfoy Oct 13, 2018
b641b68
Add backward_passes_per_step setter
andfoy Oct 14, 2018
d4b37d4
Update ImageNet example
andfoy Oct 16, 2018
73a35c6
Handle case where len(loader) is not divisble by backward_passes_per_…
andfoy Oct 16, 2018
35f0fbe
Restore loop-like gradient accumulation
andfoy Oct 16, 2018
8833fa0
Add backward_passes_per_step warning comment
andfoy Oct 16, 2018
f8bbf3f
Raise ValueError if backwards step is not an integer divisor
andfoy Oct 16, 2018
b79dd11
Add batches-per-allreduce argument
andfoy Oct 16, 2018
614106f
Split larger batch size into batches of original batch size
andfoy Oct 16, 2018
238edd7
Normalize gradients by allreduce_batch_size
andfoy Oct 16, 2018
db3e101
Remove manual gradient rescaling
andfoy Oct 16, 2018
be0b62c
Variable renaming
andfoy Oct 16, 2018
6150e47
Rescale LR
andfoy Oct 16, 2018
db840f7
Merge remote-tracking branch 'upstream/master' into disable_gradient_…
andfoy Oct 18, 2018
81d3ad6
Add comments
andfoy Oct 18, 2018
f236643
Prevent race conditions
andfoy Oct 19, 2018
32130d8
Add .vscode to gitignore
andfoy Dec 10, 2018
ea8cd22
Rebase with master
andfoy Dec 10, 2018
9f930e5
Scale lr also during warmup
andfoy Dec 10, 2018
bd6baaa
Remove batches_per_allreduce from warmup
andfoy Dec 11, 2018
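Taken together, these commits implement local gradient accumulation: each worker runs several backward passes and only averages gradients across workers on the last one. A minimal, framework-agnostic sketch of the pattern in plain PyTorch (the model, data, and accumulation_steps below are illustrative placeholders, not code from this PR):

```python
import torch
import torch.nn.functional as F

# Illustrative placeholders, not code from this PR.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accumulation_steps = 4  # plays the role of backward_passes_per_step

optimizer.zero_grad()
for step in range(accumulation_steps):
    data = torch.randn(8, 10)
    target = torch.randint(0, 2, (8,))
    loss = F.cross_entropy(model(data), target)
    # Scale so the accumulated gradient is an average, not a sum.
    (loss / accumulation_steps).backward()
# In Horovod, the allreduce across workers fires on the last backward pass;
# step() then applies the averaged gradient.
optimizer.step()
optimizer.zero_grad()
```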
41 changes: 30 additions & 11 deletions examples/pytorch_imagenet_resnet50.py
@@ -1,5 +1,6 @@
from __future__ import print_function

import torch
import argparse
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
@@ -9,6 +10,7 @@
import horovod.torch as hvd
import tensorboardX
import os
import math
from tqdm import tqdm

# Training settings
@@ -24,6 +26,10 @@
help='checkpoint file format')
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
help='use fp16 compression during allreduce')
parser.add_argument('--batches-per-allreduce', type=int, default=1,
help='number of batches processed locally before '
'executing allreduce across workers; it multiplies '
'total batch size.')

# Default settings from https://arxiv.org/abs/1706.02677.
parser.add_argument('--batch-size', type=int, default=32,
@@ -49,6 +55,8 @@
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

allreduce_batch_size = args.batch_size * args.batches_per_allreduce

hvd.init()
torch.manual_seed(args.seed)
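For a concrete sense of the sizes involved, a quick worked example of the arithmetic in this hunk (the --batch-size of 32 is the script default shown below; the accumulation factor and worker count are hypothetical):

```python
# Hypothetical values for illustration.
batch_size = 32              # --batch-size
batches_per_allreduce = 4    # --batches-per-allreduce
num_workers = 8              # hvd.size()

allreduce_batch_size = batch_size * batches_per_allreduce   # 128 samples per worker per step
global_batch_size = allreduce_batch_size * num_workers      # 1024 samples per optimizer step
print(allreduce_batch_size, global_batch_size)              # 128 1024
```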

@@ -93,7 +101,8 @@
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs)
train_dataset, batch_size=allreduce_batch_size,
sampler=train_sampler, **kwargs)

val_dataset = \
datasets.ImageFolder(args.val_dir,
@@ -118,16 +127,20 @@
model.cuda()

# Horovod: scale learning rate by the number of GPUs.
optimizer = optim.SGD(model.parameters(), lr=args.base_lr * hvd.size(),
# Gradient Accumulation: scale learning rate by batches_per_allreduce
optimizer = optim.SGD(model.parameters(),
lr=(args.base_lr *
args.batches_per_allreduce * hvd.size()),
momentum=args.momentum, weight_decay=args.wd)

# Horovod: (optional) compression algorithm.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

# Horovod: wrap optimizer with DistributedOptimizer.
optimizer = hvd.DistributedOptimizer(optimizer,
named_parameters=model.named_parameters(),
compression=compression)
optimizer = hvd.DistributedOptimizer(
optimizer, named_parameters=model.named_parameters(),
compression=compression,
backward_passes_per_step=args.batches_per_allreduce)

# Restore from a previous checkpoint, if initial_epoch is specified.
# Horovod: restore on the first worker which will broadcast weights to other workers.
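Read together, the two changes in this hunk amount to the following setup (a consolidated sketch, assuming the args, model, and imports from the example script above):

```python
# Consolidated sketch of the optimizer setup above; assumes args, model,
# optim, and hvd are defined as in the example script.
optimizer = optim.SGD(model.parameters(),
                      # The LR is scaled by both factors because each one
                      # multiplies the effective global batch size.
                      lr=args.base_lr * args.batches_per_allreduce * hvd.size(),
                      momentum=args.momentum, weight_decay=args.wd)

compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

optimizer = hvd.DistributedOptimizer(
    optimizer,
    named_parameters=model.named_parameters(),
    compression=compression,
    # Horovod will now wait for this many backward passes per parameter
    # before averaging gradients across workers.
    backward_passes_per_step=args.batches_per_allreduce)
```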
@@ -156,13 +169,19 @@ def train(epoch):
if args.cuda:
data, target = data.cuda(), target.cuda()
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
# Split data into sub-batches of size batch_size
for i in range(0, len(data), args.batch_size):
data_batch = data[i:i + args.batch_size]
target_batch = target[i:i + args.batch_size]
output = model(data_batch)
train_accuracy.update(accuracy(output, target_batch))
loss = F.cross_entropy(output, target_batch)
train_loss.update(loss.item())
alsrgv marked this conversation as resolved.
Member (review comment): Need to remove .item(): train_loss.update(loss)

# Average gradients among sub-batches
loss.div_(math.ceil(float(len(data)) / args.batch_size))
loss.backward()
# Gradient is applied across all ranks
optimizer.step()

train_loss.update(loss)
train_accuracy.update(accuracy(output, target))
t.set_postfix({'loss': train_loss.avg.item(),
'accuracy': 100. * train_accuracy.avg.item()})
t.update(1)
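Folding in the review comment above (pass the tensor to train_loss.update rather than loss.item()), the accumulating part of train() reads roughly as follows; this is a sketch of the hunk, not the exact merged code:

```python
# Sketch of the accumulating training step; data/target hold one
# allreduce-sized batch, and the Metric objects accept tensors directly,
# per the review comment about dropping .item().
optimizer.zero_grad()
# Split the large batch back into sub-batches of the original --batch-size.
for i in range(0, len(data), args.batch_size):
    data_batch = data[i:i + args.batch_size]
    target_batch = target[i:i + args.batch_size]
    output = model(data_batch)
    train_accuracy.update(accuracy(output, target_batch))
    loss = F.cross_entropy(output, target_batch)
    train_loss.update(loss)
    # Divide by the number of sub-batches so the accumulated gradient is
    # their average rather than their sum.
    loss.div_(math.ceil(float(len(data)) / args.batch_size))
    loss.backward()  # the final sub-batch triggers the allreduce
# Averaged gradients from all ranks are applied in a single step.
optimizer.step()
```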
56 changes: 45 additions & 11 deletions horovod/torch/__init__.py
@@ -40,7 +40,8 @@


class _DistributedOptimizer(torch.optim.Optimizer):
def __init__(self, params, named_parameters, compression):
def __init__(self, params, named_parameters, compression,
backward_passes_per_step=1):
alsrgv marked this conversation as resolved.
super(self.__class__, self).__init__(params)
self._compression = compression

@@ -57,12 +58,19 @@ def __init__(self, params, named_parameters, compression):

self._parameter_names = {v: k for k, v
in sorted(named_parameters)}
self.backward_passes_per_step = backward_passes_per_step
self._allreduce_delay = {v: self.backward_passes_per_step
for _, v in sorted(named_parameters)}
self._handles = {}
self._grad_accs = []

if size() > 1:
self._register_hooks()

def set_backward_passes_per_step(self, passes):
self.backward_passes_per_step = passes
for p in self._allreduce_delay:
self._allreduce_delay[p] = self.backward_passes_per_step

def _register_hooks(self):
for param_group in self.param_groups:
for p in param_group['params']:
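The _allreduce_delay dictionary introduced in this hunk is simply a per-parameter countdown that set_backward_passes_per_step resets; a toy illustration of the bookkeeping (standalone Python, not the Horovod implementation):

```python
# Toy model of the per-parameter countdown kept in _allreduce_delay.
backward_passes_per_step = 3
allreduce_delay = {"conv.weight": backward_passes_per_step,
                   "conv.bias": backward_passes_per_step}

def on_gradient_ready(name):
    """Called once per backward pass when a parameter's gradient is ready."""
    allreduce_delay[name] -= 1
    if allreduce_delay[name] == 0:
        print("allreduce", name)                          # communicate only now
        allreduce_delay[name] = backward_passes_per_step  # reset, as synchronize() does

for _ in range(backward_passes_per_step):
    for name in allreduce_delay:
        on_gradient_ready(name)  # prints once per parameter, on the 3rd pass
```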
@@ -72,23 +80,41 @@ def _register_hooks(self):
grad_acc.register_hook(self._make_hook(p))
self._grad_accs.append(grad_acc)

def _allreduce_grad(self, p):
name = self._parameter_names.get(p)
tensor = p.grad.data
tensor_compressed, ctx = self._compression.compress(tensor)

handle = allreduce_async_(tensor_compressed, average=True, name=name)
return handle, ctx

def _make_hook(self, p):
def hook(*ignore):
assert p not in self._handles
if p in self._handles and self._handles[p][0] is not None:
if self._allreduce_delay[p] <= 0:
raise AssertionError(
"Gradients were computed more than "
"backward_passes_per_step times before call "
"to step(). Increase backward_passes_per_step to "
"accumulate gradients locally.")
assert not p.grad.requires_grad
name = self._parameter_names.get(p)

tensor = p.grad.data
tensor_compressed, ctx = self._compression.compress(tensor)

handle = allreduce_async_(tensor_compressed, average=True, name=name)
assert self._allreduce_delay[p] > 0
handle, ctx = None, None
self._allreduce_delay[p] -= 1
if self._allreduce_delay[p] == 0:
handle, ctx = self._allreduce_grad(p)
self._handles[p] = (handle, ctx)
return hook

def synchronize(self):
for p, value in self._handles.items():
handle, ctx = value
Member (review comment): Can in-line value, like you do below:
    for p, (handle, ctx) in self._handles.items():

if handle is None:
handle, ctx = self._allreduce_grad(p)
Member (review comment): We need to do two passes. In a first pass, we check all the handles and backfill missing allreduces. In a second pass, we synchronize all the handles. If we don't do that, we can get a race and hang forever.

self._handles[p] = (handle, ctx)
for p, (handle, _) in self._handles.items():
output = synchronize(handle)
self._allreduce_delay[p] = self.backward_passes_per_step
p.grad.data.set_(self._compression.decompress(output, ctx))
self._handles.clear()
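A sketch of the two-pass synchronize() the reviewer asks for above: the first pass posts any allreduce that was skipped (for example when step() is called before backward_passes_per_step passes have run), and only then does the second pass block on the results, so no rank can hang waiting for an allreduce that another rank never posted. This illustrates the suggestion, not necessarily the code that was finally merged:

```python
    def synchronize(self):
        # Pass 1: backfill any allreduce that has not been posted yet, so
        # every rank submits the same set of collectives.
        for p, (handle, ctx) in self._handles.items():
            if handle is None:
                handle, ctx = self._allreduce_grad(p)
                self._handles[p] = (handle, ctx)
        # Pass 2: wait for every allreduce and write the averaged gradients back.
        for p, (handle, ctx) in self._handles.items():
            output = synchronize(handle)
            self._allreduce_delay[p] = self.backward_passes_per_step
            p.grad.data.set_(self._compression.decompress(output, ctx))
        self._handles.clear()
```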

@@ -97,7 +123,9 @@ def step(self, closure=None):
return super(self.__class__, self).step(closure)


def DistributedOptimizer(optimizer, named_parameters=None, compression=Compression.none):
def DistributedOptimizer(optimizer, named_parameters=None,
compression=Compression.none,
backward_passes_per_step=1):
"""
An optimizer that wraps another torch.optim.Optimizer, using an allreduce to
average gradient values before applying gradients to model weights.
@@ -127,12 +155,18 @@ def DistributedOptimizer(optimizer, named_parameters=None, compression=Compressi
compression: Compression algorithm used during allreduce to reduce the amount
of data sent during the each parameter update step. Defaults to
not using compression.
backward_passes_per_step: Number of expected backward passes to perform
before calling step()/synchronize(). This
allows accumulating gradients over multiple
mini-batches before executing averaging and
applying them.
"""
# We dynamically create a new class that inherits from the optimizer that was passed in.
# The goal is to override the `step()` method with an allreduce implementation.
cls = type(optimizer.__class__.__name__, (optimizer.__class__,),
dict(_DistributedOptimizer.__dict__))
return cls(optimizer.param_groups, named_parameters, compression)
return cls(optimizer.param_groups, named_parameters,
compression, backward_passes_per_step)


def broadcast_parameters(params, root_rank):
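Finally, a minimal end-to-end usage sketch of the new parameter. Only the DistributedOptimizer call mirrors the API added in this PR; the model, data, learning rate, and the choice of 4 passes are placeholders:

```python
import torch
import torch.nn.functional as F
import horovod.torch as hvd

hvd.init()

model = torch.nn.Linear(10, 2)
# Scale the LR by both the accumulation factor and the worker count.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01 * 4 * hvd.size())

# Accumulate gradients over 4 local backward passes before each allreduce.
optimizer = hvd.DistributedOptimizer(
    optimizer,
    named_parameters=model.named_parameters(),
    backward_passes_per_step=4)

hvd.broadcast_parameters(model.state_dict(), root_rank=0)

optimizer.zero_grad()
for _ in range(4):
    data, target = torch.randn(32, 10), torch.randint(0, 2, (32,))
    loss = F.cross_entropy(model(data), target) / 4
    loss.backward()   # the allreduce fires only on the 4th backward pass
optimizer.step()      # averaged gradients are applied on every rank
optimizer.zero_grad()
```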