[ShardedDDP] Sync buffers + small cleanup (#112)
- adding the buffer broadcast option
- minor cleanup in shardedDDP
blefaudeux committed Sep 29, 2020
1 parent 41819af · commit 79ded82
Showing 3 changed files with 59 additions and 42 deletions.
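
For context, a minimal usage sketch of the option added in this commit, pieced together from the call sites in the diffs below. The import path, model, and hyper-parameters are illustrative assumptions, and the process group (e.g. gloo, world size 2) is assumed to be initialized already:

import torch
from torch.nn import Linear, Sequential
from fairscale.nn.data_parallel import ShardedDataParallel

# Wrap a plain module; the wrapper builds the sharded (OSS) optimizer itself.
model = Sequential(Linear(2, 3), Linear(3, 4))
ddp = ShardedDataParallel(
    module=model,
    optimizer=torch.optim.SGD,
    optimizer_params={"lr": 0.01, "momentum": 0.99},
    world_size=2,
    broadcast_buffers=True,  # sync module buffers from the authoritative rank (0) in forward()
)
optimizer = ddp.optimizer  # the underlying sharded optimizer

loss = ddp(torch.rand(64, 2)).abs().sum()
loss.backward()
ddp.reduce()      # explicit gradient reduction, as required by this wrapper
optimizer.step()
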
8 changes: 6 additions & 2 deletions benchmarks/oss.py
@@ -62,7 +62,7 @@ def train(
):
assert not use_sdp or (use_sdp and use_oss), "ShardedDataParallel requires OSS"
# DDP
dist_init(rank, world_size, backend)
dist_init(rank=rank, world_size=world_size, backend=backend)

# Setup
torch.cuda.set_device(rank)
@@ -81,7 +81,11 @@ def train(

if use_sdp:
ddp = ShardedDataParallel(
module=model, optimizer=OPTIM, optimizer_params={"lr": 1e-4, "momentum": 0.9}, world_size=world_size,
module=model,
optimizer=OPTIM,
optimizer_params={"lr": 1e-4, "momentum": 0.9},
world_size=world_size,
broadcast_buffers=False,
)
ddp.train()
optimizer = ddp.optimizer
74 changes: 40 additions & 34 deletions fairscale/nn/data_parallel/sharded_ddp.py
@@ -25,14 +25,16 @@ class ShardedDataParallel(nn.Module):
"""Implements distributed data parallel training with optimizer state sharding.
A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
This version uses a c10d process group for communication and does not
This version uses a c10d process group for communication and optionally
broadcast buffers.
Args:
module (~torch.nn.Module): module to be parallelized
optimizer (~torch.optim.Optimizer): optimizer to be used for training
optimizer_params(Dict): extra parameters for the optimizer
world_size (int): number of parallel workers
broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
the module at beginning of the forward function. (default: ``True``)
process_group (optional): the c10d process group to be used for
distributed gradient reduction. If None, the default WORLD process group
will be used.
@@ -47,6 +49,7 @@ def __init__(
optimizer: Type[torch.optim.Optimizer],
optimizer_params: Dict[str, Any],
world_size: int,
broadcast_buffers: bool,
process_group: Any = None,
buffer_size: int = 2 ** 28,
):
@@ -56,6 +59,8 @@ def __init__(
self.world_size = world_size
self.process_group = process_group if process_group is not None else dist.group.WORLD
self.rank = dist.get_rank(self.process_group)
self.broadcast_buffers = broadcast_buffers
self.authoritative_rank = 0

# Never use a bigger buffer than the number of model params
self.buffer_size = min(buffer_size, sum(p.numel() for p in self.module.parameters()))
@@ -71,7 +76,7 @@ def __init__(
# Build the sharded optimizer
self.sharded_optimizer = OSS(self.module.parameters(), optim=optimizer, group=process_group, **optimizer_params)

# sanity checks
# Sanity checks
assert len(self.sharded_optimizer.param_to_rank) == len(
list(self.module.parameters())
), "number of params do not match"
@@ -109,6 +114,9 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Tensor:
raise RuntimeError("OssDdp requires explicit reduction, must call OssDdp.reduce")
if not self.accumulate_grads:
self.need_reduction = True
if self.broadcast_buffers and len(list(self.module.buffers())) > 0:
self._sync_buffers()

return self.module(*inputs, **kwargs)

def reduce(self) -> None:
@@ -118,52 +126,35 @@ def reduce(self) -> None:
"""
assert self.module.training, "Cannot call reduce in eval"

def reduce_params(params: List[Parameter], params_rank: int) -> None:
""" Helper to reduce a list of params that should fix in the buffer. """
def reduce_grads(params: List[Parameter], params_rank: int) -> None:
""" Helper to reduce a list of params that should fit in the buffer.
NOTE: All param gradients are assumed to exist"""
assert self.buffer is not None

# Fill in the packed IO buffer
buffer: Tensor = cast(Tensor, self.buffer)
nonzero_buffer = False
if len(params) > 1:
offset = 0
for p in params:
sz = p.numel()
if p.grad is not None:
# The type error could have been fixed in later
# version of pytorch. Same elsewhere.
buffer[offset : offset + sz].copy_(p.grad.data.view(-1)) # type: ignore
nonzero_buffer = True
else:
buffer[offset : offset + sz].zero_()
buffer[offset : offset + sz].copy_(p.grad.data.view(-1)) # type: ignore
offset += sz
else:
# we only have a single grad to reduce
p = params[0]
if p.grad is not None:
buffer = p.grad.data
nonzero_buffer = True
elif p.numel() <= self.buffer.numel():
buffer = buffer[: p.numel()]
buffer.zero_()
else:
buffer = torch.zeros_like(p)

if nonzero_buffer:
buffer.div_(self.world_size) # type: ignore
buffer = params[0].grad.data # type: ignore

dist.reduce(buffer, params_rank, group=self.process_group) # type: ignore
# Reduce
buffer.div_(self.world_size) # type: ignore
dist.reduce(tensor=buffer, dst=params_rank, group=self.process_group) # type: ignore

# Copy reduced grads back into their original place, or free corresponding memory
if params_rank == self.rank:
# copy reduced grads back into their original place
offset = 0
for p in params:
sz = p.numel()
if p.grad is not None:
p.grad.data.copy_(buffer[offset : offset + sz].view_as(p)) # type: ignore
else:
p.grad = buffer[offset : offset + sz].view_as(p).clone()
p.grad.data.copy_(buffer[offset : offset + sz].view_as(p)) # type: ignore
offset += sz
else:
# wipe the grads
for p in params:
p.grad = None

@@ -195,22 +186,37 @@ def reduction_fn() -> None:
if sz > self.buffer.numel():
# reduce big params directly
assert param_rank is not None
reduce_params([param], cast(int, param_rank))
reduce_grads([param], cast(int, param_rank))
else:
# smaller params are packed together from the same device
# and same rank.
if offset + sz > self.buffer.numel() or (
last_param_rank is not None and last_param_rank != param_rank
):
assert last_param_rank is not None
reduce_params(buffered_params, cast(int, last_param_rank))
reduce_grads(buffered_params, cast(int, last_param_rank))
offset = 0
buffered_params.clear()
buffered_params.append(cast(Parameter, param))
offset += sz

if len(buffered_params) > 0:
assert param_rank is not None
reduce_params(buffered_params, cast(int, param_rank))
reduce_grads(buffered_params, cast(int, param_rank))

reduction_fn()

def _sync_buffers(self) -> None:
"""
Sync all the param buffers in between ranks.
TODO: Could be worth bucketing ?
"""
_ = list(
map(
lambda x: x.wait(),
map(
lambda x: dist.broadcast(x, self.authoritative_rank, self.process_group, async_op=True),
self.module.buffers(),
),
)
)
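
The _sync_buffers helper above amounts to firing one non-blocking broadcast per buffer from the authoritative rank and then waiting on every handle. A rough standalone sketch of the same pattern, written as a plain function rather than the method and assuming an already-initialized process group:

import torch.distributed as dist
import torch.nn as nn

def sync_buffers(module: nn.Module, authoritative_rank: int = 0, group=None) -> None:
    # Post one async broadcast per buffer so the transfers can overlap...
    requests = [
        dist.broadcast(buffer, authoritative_rank, group=group, async_op=True)
        for buffer in module.buffers()
    ]
    # ...then block until every buffer holds the authoritative rank's values.
    for request in requests:
        request.wait()

The list comprehension stands in for the nested map/lambda form used in the diff; the behavior is the same.
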
19 changes: 13 additions & 6 deletions tests/nn/data_parallel/test_sharded_ddp.py
@@ -37,12 +37,20 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
if device == torch.device("cuda"):
torch.cuda.set_device(rank)

# Any model works. Add one different buffer per rank
model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)
model.register_buffer("test_buffer", torch.ones((1)) * rank)
model.to(device)

ddp = ShardedDataParallel(
module=model, optimizer=torch.optim.SGD, optimizer_params={"lr": 0.1, "momentum": 0.99}, world_size=world_size
module=model,
optimizer=torch.optim.SGD,
optimizer_params={"lr": 0.01, "momentum": 0.99},
world_size=world_size,
broadcast_buffers=True,
)
optimizer = ddp.optimizer
model = ddp.module

input_tensor = torch.rand((64, 2)).to(device)
output = ddp(input_tensor).abs().sum() / input_tensor.numel()
@@ -58,10 +66,9 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
if param.requires_grad:
assert param.grad.abs().sum().item() > 0.0, "The reduce step should have populated all the gradients"

# Check that the optimization process makes sense (ie. loss goes down for the same data)
optimizer.step()
new_eval = ddp(input_tensor).abs().sum() / input_tensor.numel()
# assert new_eval.item() < output.item()
# Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
for b in model.buffers():
assert b.cpu().item() == 0.0


def run_test(backend, device, world_size=2):
@@ -76,7 +83,7 @@ def run_eval_mode(_unused):
)
model = Sequential(Linear(2, 3), Linear(3, 4))
optimizer_params = {"lr": 0.1, "momentum": 0.99}
ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1)
ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1, broadcast_buffers=False)
optimizer = ddp.optimizer

ddp.eval()
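
The body of run_test is collapsed above. As a hedged sketch only (not the file's actual contents), one common way to drive run_one_step across ranks is torch.multiprocessing.spawn with a shared temporary file for rendezvous:

import tempfile
import torch
import torch.multiprocessing as mp

def run_test_sketch(backend="gloo", device=torch.device("cpu"), world_size=2):
    # mp.spawn passes the rank as the first argument to run_one_step.
    temp_file_name = tempfile.mkstemp()[1]
    mp.spawn(
        run_one_step,
        args=(world_size, backend, device, temp_file_name),
        nprocs=world_size,
        join=True,
    )
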
