[fix] OSS unit test to check data group (#129)
* new unit test to catch rank issues in OSS
blefaudeux committed Oct 8, 2020
1 parent 22ff665 commit 81ac5b2
Showing 3 changed files with 92 additions and 11 deletions.
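
For context on what the new test exercises (this note and the sketch below are not part of the commit): when OSS is handed a process sub-group, torch.distributed collectives such as broadcast expect a global rank for src, while ranks enumerated inside the group are group-local. A minimal sketch of that mapping, assuming a gloo backend, three spawned CPU processes, and a free port 29502 (the demo function name and the port are illustrative only):

import os

import torch.distributed as dist
import torch.multiprocessing as mp


def demo(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29502"  # assumption: any free port works
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

    # Sub-group made of the even ranks only, mirroring the new unit test
    group = dist.new_group(ranks=[0, 2])

    if rank in (0, 2):
        local_rank = dist.get_rank(group)  # group-local numbering: 0 or 1
        # Same private helper that OSS.get_global_rank wraps in the diff below
        global_rank = dist.distributed_c10d._get_global_rank(group, local_rank)
        # On global rank 2, local_rank is 1: broadcast(src=...) needs 2, not 1
        assert global_rank == rank

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(demo, args=(3,), nprocs=3, join=True)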
10 changes: 9 additions & 1 deletion .circleci/config.yml
@@ -101,7 +101,12 @@ run_oss_benchmark: &run_oss_benchmark
       name: Run OSS Benchmark
       command: |
         python benchmarks/oss.py --check_regression --world_size 4 --reference_speed 21.2 --reference_memory 4220 --reference_loss 0.63
-        python benchmarks/oss.py --gloo --optim_type oss
+
+run_oss_gloo: &run_oss_gloo
+  - run:
+      name: Run OSS with Gloo
+      command: |
+        python benchmarks/oss.py --gloo --optim_type oss
 # -------------------------------------------------------------------------------------
@@ -254,6 +259,9 @@ jobs:
 
       - <<: *run_oss_benchmark
 
+      - <<: *run_oss_gloo
+
+
 
 
 workflows:
17 changes: 7 additions & 10 deletions fairscale/optim/oss.py
@@ -194,7 +194,7 @@ def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) ->
             device,
             device_params,
         ) in self.per_device_params.items():  # all the params on this device (inc all ranks)
-            self._broadcast_params(self._broadcast_buffers[device], device_params, self.group, self.global_rank)
+            self._broadcast_params(self._broadcast_buffers[device], device_params)
 
         return loss

@@ -408,10 +408,7 @@ def get_global_rank(group: Any, rank: int) -> int:
         global_rank = dist.distributed_c10d._get_global_rank(group, rank)  # type: ignore
         return global_rank
 
-    @staticmethod
-    def _broadcast_params(
-        buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]], group: Any, self_rank: int
-    ) -> None:
+    def _broadcast_params(self, buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]]) -> None:
         """Helper function to broadcast all the parameters from a given device
         """
         buffer_size = buffers[0].numel()
@@ -425,7 +422,7 @@ def _broadcast_params(
             if len(params) == 0:
                 continue
 
-            global_rank = OSS.get_global_rank(group, rank)
+            global_rank = OSS.get_global_rank(self.group, rank)
 
             # Copy small parameters into per-GPU buffers
             i_bucketed = 0  # the number of tensors packed in the buffer
@@ -434,14 +431,14 @@
             # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
             while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
                 end = offset + params[i_bucketed].numel()
-                if global_rank == self_rank:
+                if global_rank == self.global_rank:
                     buffer[offset:end].copy_(params[i_bucketed].data.view(-1))  # type: ignore
                 offset = end
                 i_bucketed += 1
 
             if i_bucketed > 0:
-                future = dist.broadcast(tensor=buffer, src=global_rank, group=group, async_op=True)
-                if global_rank != self_rank:
+                future = dist.broadcast(tensor=buffer, src=global_rank, group=self.group, async_op=True)
+                if global_rank != self.global_rank:
                     # This request will need to be unrolled
                     bucket_requests.append((future, rank))
 
@@ -455,7 +452,7 @@ def _broadcast_params(
                     restore_require_grad.append(param)
                     param.requires_grad = False
 
-                requests.append(dist.broadcast(tensor=param, src=global_rank, group=group, async_op=True))
+                requests.append(dist.broadcast(tensor=param, src=global_rank, group=self.group, async_op=True))
 
         # Unroll the initial packed small parameters
         for gate, rank in bucket_requests:
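
As an illustration of the bucketing logic kept in the diff above (not part of the commit, with made-up tensor sizes): parameters are sorted by increasing size, those that fit are flattened into the shared broadcast buffer, and anything larger falls through to the per-tensor broadcast path.

import torch

# Pretend broadcast bucket and a size-sorted parameter list (illustrative sizes)
buffer = torch.empty(16)
params = [torch.rand(3), torch.rand(2, 2), torch.rand(5, 5)]

offset, i_bucketed = 0, 0
while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer.numel():
    end = offset + params[i_bucketed].numel()
    buffer[offset:end].copy_(params[i_bucketed].view(-1))
    offset, i_bucketed = end, i_bucketed + 1

# The 3-element and 2x2 tensors are packed (3 + 4 < 16); the 5x5 one is not
# and would go through the direct dist.broadcast path instead.
print(i_bucketed, offset)  # -> 2 7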
76 changes: 76 additions & 0 deletions tests/optim/test_oss.py
@@ -9,6 +9,7 @@
 
 import os
 
+import numpy as np
 import pytest
 import torch
 import torch.distributed as dist
@@ -334,3 +335,78 @@ def test_collect_shards():
     mp.spawn(
         run_test_collect_shards, args=(world_size, reference_rank), nprocs=world_size, join=True,
     )
+
+
+def run_test_multiple_groups(rank, world_size):
+    # Only work with the even ranks, to check that the global_rank indexing is properly used
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "29501"
+    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
+    sub_group_ranks = [0, 2, 4]
+    process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend="gloo")
+
+    # Make sure that all the ranks get different training data
+    # So that the sync check in between their models is meaningful
+    torch.manual_seed(rank)
+    np.random.seed(rank)
+
+    # Standard deep learning setup
+    device = "cpu"
+    epochs, batch, input_width, hidden, target_width = 5, 3, 20, 10, 5
+    loss_fn = torch.nn.L1Loss().to(device)
+
+    def check(optimizer):
+        # Just run a couple of epochs, check that the model is properly updated
+        for _ in range(epochs):
+            target = torch.rand((batch, target_width), device=device)
+            inputs = torch.rand((batch, input_width), device=device)
+
+            def closure():
+                optimizer.zero_grad()
+                output = model(inputs)
+                loss = loss_fn(output, target)
+                loss /= world_size
+                loss.backward()
+                dist.all_reduce(loss, group=process_group)  # Not strictly needed for the test below
+
+                return loss
+
+            _ = optimizer.step(closure=closure)
+
+            # Check that all the params are the same on all ranks
+            for pg in optimizer.param_groups:
+                for p in pg["params"]:
+                    receptacle = [p.clone() for _ in sub_group_ranks] if rank == 0 else []
+                    dist.gather(p, receptacle, dst=0, group=process_group)
+                    if rank == 0:
+                        for sync_p in receptacle[1:]:
+                            assert torch.all(torch.eq(receptacle[0], sync_p)), "Models differ in between ranks"
+
+    if rank in sub_group_ranks:
+        # Model fitting in the broadcast bucket
+        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)).to(
+            device
+        )
+
+        # With SGD, Momentum is required to get a state to shard
+        optimizer = optim.OSS(
+            model.parameters(), lr=0.1, momentum=0.99, group=process_group, broadcast_buffer_size=2 ** 20
+        )
+        check(optimizer)
+
+        # Model not-fitting in the broadcast bucket
+        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)).to(
+            device
+        )
+
+        # With SGD, Momentum is required to get a state to shard
+        optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99, group=process_group, broadcast_buffer_size=0)
+        check(optimizer)
+
+
+def test_multiple_groups():
+    world_size = 6
+
+    mp.spawn(
+        run_test_multiple_groups, args=(world_size,), nprocs=world_size, join=True,
+    )
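
A possible way to run just the new test locally (assuming pytest is installed and this is launched from the repository root; the test spawns six gloo CPU processes):

import sys

import pytest

# Select only the new multi-group test from this file
sys.exit(pytest.main(["-x", "tests/optim/test_oss.py::test_multiple_groups"]))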
