[fix][FSDP] Add support for saving optimizer state with expert replication (#936)

* checkpoint tests

* checkpoint tests

* fix tests

* lint fixes

* remove prints

* lint fixes

* add comments

* add changelog

* more cleanup

* lint fix
anj-s committed Feb 23, 2022
1 parent cb72ae5 commit 40e7450
Showing 4 changed files with 100 additions and 14 deletions.
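
Not part of the diff itself, but for context: a minimal sketch of the usage pattern this commit enables, assuming the default process group is already initialized. The helper `wrap_toy_moe` and the toy `nn.Linear` experts are illustrative; `FullyShardedDataParallel`, `process_group_reduce_scatter`, and `gather_full_optim_state_dict` are the fairscale APIs exercised by the tests below.

```python
import torch
import torch.distributed as dist
from torch import nn
from fairscale.nn import FullyShardedDataParallel as FSDP


def wrap_toy_moe(d_model: int = 8, num_experts: int = 2) -> FSDP:
    """Wrap a toy model whose expert is sharded over its own replica group."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Ranks e, e + num_experts, ... hold replicas of expert e.
    groups = [list(range(e, world_size, num_experts)) for e in range(num_experts)]
    # new_group() is collective: every rank creates every group, then keeps its own.
    pgs = [dist.new_group(g) for g in groups]
    expert_group = pgs[rank % num_experts]

    expert = FSDP(
        nn.Linear(d_model, d_model),
        process_group=expert_group,
        process_group_reduce_scatter=expert_group,
    )
    return FSDP(nn.Sequential(nn.Linear(d_model, d_model), expert))


# model = wrap_toy_moe()
# optim = torch.optim.Adam(model.parameters(), lr=1e-3)
# ... forward / backward / optim.step() ...
# full_sd = model.gather_full_optim_state_dict(optim)  # consolidated on rank 0
```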
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- FSDP: Added skip_params_check_for_root flag to default_auto_wrap_policy which,
if set, wraps the root module regardless of how many unwrapped params there were
left after children were wrapped. [#930]
- FSDP: Added support for saving optimizer state when using expert replicas. [#936]

## [0.4.5] - 2022-01-14

13 changes: 9 additions & 4 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -2238,6 +2238,11 @@ def _gather_optim_state(
# param that has the optimizer state. So we handle it with the correct
# parameter list.
non_shared_params = cast(FullyShardedDataParallel, self._fsdp_instances[k]).non_shared_params()
# This is the world size and process group of the FSDP submodule, which can be
# different from those of the parent module, e.g. when FSDP is used with MoE.
non_shared_world_size = self._fsdp_instances[k].world_size
non_shared_process_group = self._fsdp_instances[k].process_group

assert (
len(non_shared_params) == 1
), f"Only flatten param or a single non-shared param is supported: len={len(non_shared_params)}"
@@ -2250,15 +2255,15 @@

if ou.is_singleton_tensor(t):
if singleton_buffer is None:
singleton_buffer = list(t.new_zeros(self.world_size).chunk(self.world_size))
dist.all_gather(singleton_buffer, t, group=self.process_group)
singleton_buffer = list(t.new_zeros(non_shared_world_size).chunk(non_shared_world_size))
dist.all_gather(singleton_buffer, t, group=non_shared_process_group)
if self.rank == 0:
singleton_state[k][buffer_name] = [x.cpu().squeeze() for x in singleton_buffer]
assert ou.is_singleton_tensor(singleton_state[k][buffer_name][0])
elif torch.is_tensor(t):
if buffer is None:
buffer = list(t.new_zeros(*desired_buffer_size).chunk(self.world_size))
dist.all_gather(buffer, t, group=self.process_group)
buffer = list(t.new_zeros(*desired_buffer_size).chunk(non_shared_world_size))
dist.all_gather(buffer, t, group=non_shared_process_group)
if self.rank == 0:
gathered_state[k][buffer_name] = [x.cpu() for x in buffer]
elif self.rank == 0: # Add non tensor state
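
The hunk above sizes the all-gather buffers from the expert submodule's own process group rather than from the root module's `self.world_size`/`self.process_group`. A standalone sketch of that pattern (the helper name is illustrative, not from the commit):

```python
import torch
import torch.distributed as dist


def gather_from_submodule_group(t: torch.Tensor, pg) -> list:
    # One receive slot per rank in `pg`; size the list from the submodule's
    # group, not the root FSDP module's group.
    group_size = dist.get_world_size(group=pg)
    buffer = [torch.zeros_like(t) for _ in range(group_size)]
    dist.all_gather(buffer, t, group=pg)
    return buffer
```

If the buffer were sized with the root module's world size instead, its length would not match the expert group that `all_gather` runs over, which is exactly the mismatch this commit fixes for MoE expert parameters.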
6 changes: 3 additions & 3 deletions tests/nn/data_parallel/test_fsdp.py
@@ -775,7 +775,7 @@ def forward(self, *args, **kwargs):


class MixtureOfExperts(NestedWrappedModule):
def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_free_ms=0):
def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_free_ms=0, expert_group=None):
super().__init__(group, wrapper_config)
self.group = group
self.delay_before_free_ms = delay_before_free_ms
@@ -801,9 +801,9 @@ def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_fre
shared = checkpoint_wrapper(shared)

if wrapper_config is not None:
# we create a process group of size 1 for the expert params
# we create a process group of size >= 1 for the expert params
# we also need to pass that group as the reduce_scatter group.
expert_group = torch.distributed.new_group([group.rank()])
expert_group = expert_group or torch.distributed.new_group([group.rank()])
expert = FullyShardedDataParallel(
expert, process_group=expert_group, process_group_reduce_scatter=expert_group, **wrapper_config
)
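
The new optional `expert_group` argument preserves the old default (a private size-1 group per rank) while letting a test share one group across several ranks. A hedged sketch of the two call patterns; `MixtureOfExperts` and `get_moe_group` are the test helpers from this commit, and `wrapper_config` is an illustrative FSDP config:

```python
import torch.distributed as dist

wrapper_config = {"mixed_precision": True, "flatten_parameters": True}
group = dist.new_group()  # all ranks, as in init_and_run() below

# Default: each rank builds its own size-1 expert group, so expert
# parameters stay local to that rank.
moe_local = MixtureOfExperts(group, wrapper_config)

# New: pass a shared group so several ranks hold the same expert
# (see get_moe_group() in the optimizer-utils test below).
moe_replicated = MixtureOfExperts(group, wrapper_config, expert_group=get_moe_group())
```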
94 changes: 87 additions & 7 deletions tests/nn/data_parallel/test_fsdp_optimizer_utils.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import functools
from time import time
import unittest

from parameterized import parameterized
import torch
@@ -12,7 +13,7 @@
from fairscale.nn import FullyShardedDataParallel
from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
from fairscale.utils.params import recursive_copy_to_device
from fairscale.utils.testing import objects_are_equal
from fairscale.utils.testing import dist_init, objects_are_equal, spawn_for_all_world_sizes

from .test_fsdp import (
DistributedTest,
@@ -37,6 +38,57 @@ def assert_equal(a, b):
assert a == b, f"{a} != {b}"


def spawn_and_init_multiple_groups(fn, args=None, **spawn_kwargs):
if args is None:
args = ()

run_fn = functools.partial(init_and_run, fn, args)
spawn_for_all_world_sizes(run_fn, **spawn_kwargs)


def _find_my_group_index(grouped_ranks):
"""Return the index corresponding to the MoE group of the current process."""
my_rank = torch.distributed.get_rank()
for i, group in enumerate(grouped_ranks):
if my_rank in group:
return i
raise RuntimeError(f"Unable to find process rank {my_rank} in the set of grouped ranks {grouped_ranks}.")


def get_moe_group(moe_expert_count=2):
"""Return a process group for initializing a MoE layer."""
if torch.distributed.is_initialized():
world_size = torch.distributed.get_world_size()

# At least as many experts as ranks: each rank forms its own size-1 group.
if world_size <= moe_expert_count:
assert moe_expert_count % world_size == 0
moe_groups = [[i] for i in range(world_size)]

# More ranks than experts: ranks_per_group ranks share each expert.
else:
assert world_size % moe_expert_count == 0
ranks_per_group = world_size // moe_expert_count
moe_groups = [[i + j * moe_expert_count for j in range(ranks_per_group)] for i in range(moe_expert_count)]

moe_pgs = [torch.distributed.new_group(g) for g in moe_groups]

# Find the index in the set of moe_groups which contains the current rank.
my_group_idx = _find_my_group_index(moe_groups)
return moe_pgs[my_group_idx]
else:
return torch.distributed.new_group([torch.distributed.get_rank()])


def init_and_run(fn, args, rank, world_size, filename, filename_rpc):
"""Initialize and run the unit test for testing replicated MoE groups."""
dist_init(rank, world_size, filename, filename_rpc)
torch.cuda.set_device(rank)
group = torch.distributed.new_group()
# Pass the expert group used to initialize the MoE layers.
fn(rank, group, *args, expert_group=get_moe_group())


class TestOptimizerUtils(DistributedTest):
@parameterized.expand(
[[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True], [Adam, True]],
@@ -51,17 +103,33 @@ def test_consolidate_optimizer(self, optim_fn, transformer):

spawn_and_init(test_fn, world_sizes=[min(torch.cuda.device_count(), 4)])

@parameterized.expand(
[[SGD, False], [Adam, False]],
name_func=rename_test,
)
def test_consolidate_optimizer_diff_world_size(self, optim_fn, transformer):
if torch.cuda.device_count() < 4:
raise unittest.SkipTest("This test requires at least 4 GPUs.")
config = {"mixed_precision": True, "flatten_parameters": True}
config["compute_dtype"] = torch.float32
test_fn = functools.partial(self._test_consolidated_optimizer, config, optim_fn=Adam, transformer=transformer)

spawn_and_init_multiple_groups(test_fn, world_sizes=[min(torch.cuda.device_count(), 4)])

@classmethod
def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim.SGD, transformer=False):
def _test_consolidated_optimizer(
self, config, rank, group, optim_fn=torch.optim.SGD, transformer=False, expert_group=None
):
"""FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()"""
# Establish reference behavior.

if transformer:
unwrapped_model = TransformerWithSharedParams(group, wrapper_config=config).cuda()
fsdp = self.get_wrapped_model(group, config=config).cuda()
else:
unwrapped_model = MixtureOfExperts(group, wrapper_config=None).cuda()
fsdp = FullyShardedDataParallel(MixtureOfExperts(group, wrapper_config=config)).cuda()
unwrapped_model = MixtureOfExperts(group, wrapper_config=None, expert_group=expert_group).cuda()
fsdp = FullyShardedDataParallel(
MixtureOfExperts(group, wrapper_config=config, expert_group=expert_group)
).cuda()

try:
fsdp_optim = optim_fn(
@@ -88,9 +156,9 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim
optim_unwrapped.step()
unwrapped_sd = optim_unwrapped.state_dict()

if not transformer:
if not transformer and not expert_group:
no_broadcast_children = [x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state]
assert len(no_broadcast_children) == 1
assert len(no_broadcast_children) == 1, f"Length of non shared params {len(no_broadcast_children)}"
assert fsdp._fsdp_instances[-1].no_broadcast_optim_state
torch.cuda.empty_cache()
cuda_gb_before = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
@@ -115,6 +183,18 @@ def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim
msg = f"got device {t.device} for {k}: {buffer_name}. expected CPU"
assert t.device == torch.device("cpu"), msg

if expert_group:
sd_state = recursive_copy_to_device(sd["state"], non_blocking=False, device="cpu")
orig_state = recursive_copy_to_device(unwrapped_sd["state"], non_blocking=False, device="cpu")

assert_equal(len(sd_state.keys()), len(orig_state.keys()))

assert_equal(
sum([all_tensors_numel_except_for_step(v) for k, v in sd_state.items()]),
sum([all_tensors_numel_except_for_step(v) for k, v in orig_state.items()]),
)
return

unflat_state = sd["state"]
assert "uncollected_local_ids" in sd
shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
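
For concreteness, the grouping arithmetic in `get_moe_group()` above produces the following layout for a 4-GPU run with two experts (plain Python, no distributed init needed):

```python
# Reproduces the list comprehension from get_moe_group() for
# world_size=4 and moe_expert_count=2.
world_size, moe_expert_count = 4, 2
ranks_per_group = world_size // moe_expert_count  # 2 ranks per expert
moe_groups = [
    [i + j * moe_expert_count for j in range(ranks_per_group)]
    for i in range(moe_expert_count)
]
print(moe_groups)  # [[0, 2], [1, 3]] -> expert 0 on ranks 0 and 2, expert 1 on ranks 1 and 3
```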
