[feat] OSS/SDP : bucketing (#122)
Same bucketing strategy for OSS and SDP:
sort everything ahead of time, per rank and per increasing size, smallest tensors first. Pack the smallest elements into a fixed buffer and reduce it asynchronously, then reduce all the remaining tensors asynchronously, and come back to the bucket. Once it is done, scatter its contents back if needed.
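
For readers skimming the diff, the ordering described above can be sketched in a few lines of plain PyTorch. This is an illustration only, not the fairscale implementation: the hypothetical fake_async_reduce helper and FakeWork handle stand in for dist.reduce(..., async_op=True) and the work object it returns, and per_rank_grads is assumed to be pre-sorted smallest first, as the real code guarantees.

import torch


class FakeWork:
    """Stand-in for the async work handle returned by dist.reduce(..., async_op=True)."""

    def wait(self) -> None:
        pass


def fake_async_reduce(tensor: torch.Tensor) -> FakeWork:
    # A real implementation would call dist.reduce(tensor, dst=owner_rank, group=group, async_op=True)
    return FakeWork()


def bucketed_reduce(per_rank_grads, buffer_size: int = 2 ** 19, self_rank: int = 0) -> None:
    buffers = [torch.zeros(buffer_size) for _ in per_rank_grads]
    bucket_requests, requests = [], []

    for rank, grads in enumerate(per_rank_grads):
        buffer, offset, packed = buffers[rank], 0, 0

        # 1. Pack as many small gradients as fit into the fixed bucket, then reduce it asynchronously
        while packed < len(grads) and offset + grads[packed].numel() < buffer_size:
            end = offset + grads[packed].numel()
            buffer[offset:end].copy_(grads[packed].view(-1))
            offset, packed = end, packed + 1
        if packed > 0:
            bucket_requests.append((fake_async_reduce(buffer), rank, packed))

        # 2. Reduce the remaining, larger gradients asynchronously, one request each
        for grad in grads[packed:]:
            requests.append(fake_async_reduce(grad))

    # 3. Get back to the buckets: wait, then scatter the contents to the params this rank owns
    for work, rank, packed in bucket_requests:
        work.wait()
        if rank == self_rank:
            buffer, offset = buffers[rank], 0
            for grad in per_rank_grads[rank][:packed]:
                end = offset + grad.numel()
                grad.copy_(buffer[offset:end].view_as(grad))
                offset = end

    # 4. Finally wait on the direct reductions
    for work in requests:
        work.wait()


# Tiny demo: with a 16-element bucket, the two small grads are packed, the 1024-element one goes direct
bucketed_reduce([[torch.randn(4), torch.randn(8), torch.randn(1024)]], buffer_size=16)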
blefaudeux committed Oct 6, 2020
1 parent 6e7ad79 commit 341d8b2
Showing 4 changed files with 257 additions and 144 deletions.
7 changes: 6 additions & 1 deletion benchmarks/oss.py
@@ -165,6 +165,8 @@ def closure():

print("[Regression Test] VALID")

dist.destroy_process_group() # type: ignore


if __name__ == "__main__":

@@ -246,7 +248,10 @@ class OptimType(str, Enum):
backend,
True, # OSS
True, # SDP
False, # no regression check
args.check_regression,
-1, # Not checking SDP for speed regression for now, still slower than OSS
args.reference_memory,
args.reference_loss,
),
nprocs=args.world_size,
join=True,
183 changes: 99 additions & 84 deletions fairscale/nn/data_parallel/sharded_ddp.py
@@ -11,7 +11,7 @@

from contextlib import contextmanager
import copy
from typing import Any, Dict, Generator, List, Optional, Type, cast
from typing import Any, Dict, Generator, List, Type, cast

import torch
from torch import Tensor, nn
@@ -39,7 +39,7 @@ class ShardedDataParallel(nn.Module):
distributed gradient reduction. If None, the default WORLD process group
will be used.
buffer_size (int, optional): number of elements to buffer before
performing reduce (default: 256M). Used to reduce multiple small
performing reduce (default: 512k). Used to reduce multiple small
params to avoid communication overhead.
"""

@@ -51,7 +51,7 @@ def __init__(
world_size: int,
broadcast_buffers: bool,
process_group: Any = None,
buffer_size: int = 2 ** 28,
buffer_size: int = 2 ** 19,
):
super().__init__()

@@ -62,10 +62,6 @@ def __init__(
self.broadcast_buffers = broadcast_buffers
self.authoritative_rank = 0

# Never use a bigger buffer than the number of model params
self.buffer_size = min(buffer_size, sum(p.numel() for p in self.module.parameters()))
self.buffer: Optional[Tensor] = None

# Flag used to make sure we only reduce gradients one time in the execution engine
self.need_reduction = False

@@ -76,6 +72,18 @@ def __init__(
# Build the sharded optimizer
self.sharded_optimizer = OSS(self.module.parameters(), optim=optimizer, group=process_group, **optimizer_params)

# Allocate reduce buffers
# - Never use a bigger buffer than the number of model params
buffer_size = min(buffer_size, sum(p.numel() for p in self.module.parameters()))
self._reduce_buffers: Dict[torch.device, List[torch.Tensor]] = {}

# - One buffer per rank per device
for device, per_device in self.sharded_optimizer.per_device_params.items():
buffer_dtype = per_device[0][0].dtype
self._reduce_buffers[device] = [
torch.zeros(buffer_size, dtype=buffer_dtype, device=device) for _ in range(len(per_device))
]

# Sanity checks
assert len(self.sharded_optimizer.param_to_rank) == len(
list(self.module.parameters())
@@ -126,85 +134,92 @@ def reduce(self) -> None:
"""
assert self.module.training, "Cannot call reduce in eval"

def reduce_grads(params: List[Parameter], params_rank: int) -> None:
""" Helper to reduce a list of params that should fit in the buffer.
NOTE: All param gradients are assumed to exist"""
assert self.buffer is not None
if not self.need_reduction or self.accumulate_grads:
return

# Fill in the packed IO buffer
buffer: Tensor = cast(Tensor, self.buffer)
if len(params) > 1:
offset = 0
for p in params:
sz = p.numel()
buffer[offset : offset + sz].copy_(p.grad.data.view(-1)) # type: ignore
offset += sz
else:
# we only have a single grad to reduce
buffer = params[0].grad.data # type: ignore

# Reduce
buffer.div_(self.world_size) # type: ignore
dist.reduce(tensor=buffer, dst=params_rank, group=self.process_group) # type: ignore

# Copy reduced grads back into their original place, or free corresponding memory
if params_rank == self.rank:
offset = 0
for p in params:
sz = p.numel()
p.grad.data.copy_(buffer[offset : offset + sz].view_as(p)) # type: ignore
offset += sz
else:
for p in params:
p.grad = None

def reduction_fn() -> None:
# This function only needs to be called once
if not self.need_reduction or self.accumulate_grads:
return
self.need_reduction = False

if self.buffer is None:
self.buffer = next(self.module.parameters()).new(self.buffer_size) # type: ignore

for params in self.sharded_optimizer.per_device_params:
# Reduce the gradients in buckets
self.need_reduction = False

with torch.no_grad():
for device, per_device in self.sharded_optimizer.per_device_params.items():
self._reduce_grads_task(
self._reduce_buffers[device],
per_device,
group=self.process_group,
self_rank=self.rank,
world_size=self.world_size,
)

@staticmethod
def _reduce_grads_task(
buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]], group: Any, self_rank: int, world_size: int
) -> None:
"""Helper to reduce a list of params. The params are sorted by size, smallest first, which allows for
an opportunistic bucketing.
NOTE: All param gradients are assumed to exist"""

buffer_size = buffers[0].numel()
bucket_requests = []
requests = []

for (rank, params), buffer in zip(enumerate(per_rank_params), buffers):
# All the params are sorted per rank and per increasing size
if len(params) == 0:
continue

for p in params:
if p.grad is None:
p.grad = torch.zeros_like(p)

global_rank = OSS.get_global_rank(group, rank)

# Copy small gradients into per-GPU buffers and then async reduce
i_bucketed = 0 # the number of tensors packed in the buffer
offset = 0

# Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
end = offset + params[i_bucketed].numel()
buffer[offset:end].copy_(params[i_bucketed].grad.data.view(-1)) # type: ignore
offset = end
i_bucketed += 1

if i_bucketed > 0:
buffer.div_(world_size) # type: ignore
bucket_requests.append(
(
dist.reduce(tensor=buffer, dst=global_rank, group=group, async_op=True), # type: ignore
rank,
)
)

# Directly reduce the other grads
for p in params[i_bucketed:]:
p.grad = cast(Tensor, p.grad)
if p.grad.requires_grad:
raise RuntimeError("DistributedDataParallel only works with gradients that don't require grad")

p.grad.div_(world_size) # type: ignore
requests.append(dist.reduce(tensor=p.grad, dst=global_rank, group=group, async_op=True)) # type: ignore

# Unroll the initial packed small gradients, as soon as possible
for future, rank in bucket_requests:
future.wait()

if rank == self_rank:
i_bucketed = 0 # the number of tensors packed in the buffer
offset = 0
buffered_params: List[Parameter] = []
param_rank: Optional[int] = None
for param in params:
last_param_rank: Optional[int] = param_rank
param_rank = self.sharded_optimizer.param_to_rank[param]
if not param.requires_grad:
continue

if param.grad is None:
param.grad = torch.zeros_like(param)
if param.grad.requires_grad:
raise RuntimeError("DistributedDataParallel only works with gradients that don't require grad")
sz = param.numel()
if sz > self.buffer.numel():
# reduce big params directly
assert param_rank is not None
reduce_grads([param], cast(int, param_rank))
else:
# smaller params are packed together from the same device
# and same rank.
if offset + sz > self.buffer.numel() or (
last_param_rank is not None and last_param_rank != param_rank
):
assert last_param_rank is not None
reduce_grads(buffered_params, cast(int, last_param_rank))
offset = 0
buffered_params.clear()
buffered_params.append(cast(Parameter, param))
offset += sz

if len(buffered_params) > 0:
assert param_rank is not None
reduce_grads(buffered_params, cast(int, param_rank))

reduction_fn()
params = per_rank_params[rank]
buffer = buffers[rank]

while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
end = offset + params[i_bucketed].numel()
params[i_bucketed].grad.data.copy_(buffer[offset:end].view_as(params[i_bucketed])) # type: ignore
offset = end
i_bucketed += 1

# Make sure that we're done with this device before moving on and cleaning the unused params
_ = list(map(lambda x: x.wait(), requests))

def _sync_buffers(self) -> None:
"""
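
For context on how the wrapper documented in sharded_ddp.py above is used, here is a minimal single-process sketch. The import path, constructor argument names, and the explicit reduce() call are inferred from this diff and from OSS usage in the same repository, so treat it as an illustration rather than authoritative API documentation.

import os

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel

# Single-process "distributed" setup so the sketch can run standalone
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(32, 32)
ddp = ShardedDataParallel(
    module=model,
    optimizer=torch.optim.SGD,      # optimizer class, instantiated per shard through OSS
    optimizer_params={"lr": 1e-3},
    world_size=1,
    broadcast_buffers=True,
    buffer_size=2 ** 19,            # reduce bucket size in elements, the new default in this commit
)

loss = ddp(torch.randn(8, 32)).sum()
loss.backward()
ddp.reduce()                        # bucketed gradient reduction shown above
ddp.sharded_optimizer.step()

dist.destroy_process_group()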
