
Commit

[NFC] polish colossalai/zero/sharded_model/reduce_scatter.py code style (#1554)
Fazziekey committed Sep 8, 2022
1 parent 6c9629b commit 649cbc0
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions colossalai/zero/sharded_model/reduce_scatter.py
@@ -20,6 +20,7 @@


 class Bucket:
+
     def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
         self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device)
         self.group = group
@@ -34,18 +35,18 @@ def flush(self) -> None:
             return
         # reduce-scatter bucket
         if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
-            dist._reduce_scatter_base(
-                self.output_shard[: self.offset], self.buffer[:, : self.offset].contiguous(), group=self.group
-            )
+            dist._reduce_scatter_base(self.output_shard[:self.offset],
+                                      self.buffer[:, :self.offset].contiguous(),
+                                      group=self.group)
         else:
-            dist.reduce_scatter(
-                self.output_shard[: self.offset], list(self.buffer[:, : self.offset].unbind(0)), group=self.group
-            )
+            dist.reduce_scatter(self.output_shard[:self.offset],
+                                list(self.buffer[:, :self.offset].unbind(0)),
+                                group=self.group)
         # execute post-reduction callbacks
         for callback_fn in self.callbacks:
             callback_fn()
         # reuse input bucket but allocate a fresh output shard
-        self.buffer[:, : self.offset].zero_()
+        self.buffer[:, :self.offset].zero_()
         self.offset = 0
         self.callbacks.clear()
         self.output_shard = torch.zeros_like(self.buffer[0])
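
For context, a minimal sketch (not part of the commit) of the two collective calls that flush() dispatches between, stripped of the bucket bookkeeping. The function name reduce_scatter_bucket is illustrative, and it assumes a torch.distributed process group has already been initialized (for example via torchrun); dist._reduce_scatter_base is the private fused API available in the PyTorch versions this file targets, dist.reduce_scatter the public fallback.

# Illustrative sketch, not from the commit. Assumes dist.init_process_group()
# has already run and that `buffer` has shape (world_size, shard_size).
import torch
import torch.distributed as dist


def reduce_scatter_bucket(buffer: torch.Tensor, group=None) -> torch.Tensor:
    output = torch.zeros_like(buffer[0])    # one shard's worth of reduced values
    if hasattr(dist, "_reduce_scatter_base"):
        # fused path: a single flat input tensor that the backend slices per rank
        dist._reduce_scatter_base(output, buffer.contiguous(), group=group)
    else:
        # generic path: an explicit list with one input slice per rank
        dist.reduce_scatter(output, list(buffer.unbind(0)), group=group)
    return output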
@@ -73,12 +74,12 @@ def append(self, tensor_list: List[Tensor], callback_fn: Callable):
         tensor_size = tensor_list[0].numel()
         stacked_input = torch.stack(tensor_list).view(self.group.size(), tensor_size)
         offset = self.offset
-        self.buffer[:, offset: offset + tensor_size].copy_(stacked_input)
+        self.buffer[:, offset:offset + tensor_size].copy_(stacked_input)
         self.offset += tensor_size

         # callback will be given the reduced result
         if callback_fn is not None:
-            result_view = self.output_shard[offset: offset + tensor_size].view_as(tensor_list[0])
+            result_view = self.output_shard[offset:offset + tensor_size].view_as(tensor_list[0])
             self.callbacks.append(functools.partial(callback_fn, result_view))


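A minimal sketch (not part of the commit) of the offset bookkeeping that append() performs on a single bucket. All sizes and names below are illustrative, not taken from the diff.

# Illustrative values, not from the commit: pack one per-rank tensor list into the
# bucket buffer and note which slice of the output shard will hold its reduced result.
import torch

world_size, shard_size = 4, 16
buffer = torch.zeros((world_size, shard_size))
output_shard = torch.zeros(shard_size)
offset = 0

tensor_list = [torch.randn(2, 3) for _ in range(world_size)]     # one tensor per rank
tensor_size = tensor_list[0].numel()                             # 6 elements each
stacked_input = torch.stack(tensor_list).view(world_size, tensor_size)
buffer[:, offset:offset + tensor_size].copy_(stacked_input)      # columns 0..5 now filled
result_view = output_shard[offset:offset + tensor_size].view_as(tensor_list[0])
offset += tensor_size                                            # next append starts at column 6
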
@@ -141,9 +142,8 @@ def reduce_scatter_async(
         """
         world_size = group.size()

-        assert (
-            len(input_list) == world_size
-        ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"
+        assert (len(input_list) == world_size
+                ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"

         first_input = input_list[0]
         first_input_size = first_input.numel()
@@ -183,7 +183,7 @@ def free(self) -> None:

     @functools.lru_cache()
     def _get_shard_size(self, element_size: int, num_shards: int) -> int:
-        if self.bucket_size_mb <= 0:  # Values <= 0 disable bucketing.
+        if self.bucket_size_mb <= 0:    # Values <= 0 disable bucketing.
             return 0
         MB = 1024 * 1024
         bucket_size = self.bucket_size_mb * MB / element_size
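The last hunk touches only the spacing before a trailing comment in _get_shard_size(); the arithmetic it guards works out as in the sketch below. The diff is truncated at this point, so the final division by num_shards is an assumption about the obvious continuation, and every number shown is illustrative.

# Illustrative numbers, not from the commit: how a bucket size in MB becomes a
# per-rank shard size in elements. The // num_shards step is assumed (diff truncated).
bucket_size_mb = 25      # hypothetical bucketer setting
element_size = 4         # bytes per fp32 element
num_shards = 4           # e.g. group.size()

MB = 1024 * 1024
bucket_size = bucket_size_mb * MB / element_size    # 6,553,600 elements fit in the bucket
shard_size = int(bucket_size // num_shards)         # 1,638,400 elements per rank
print(shard_size)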
