[Feature] MoE refactor with newest version of ZeRO #5801

Merged · 1 commit · Jun 12, 2024
colossalai/zero/low_level/low_level_strategy.py (4 changes: 3 additions & 1 deletion)
@@ -66,7 +66,9 @@ def __init__(
         # it will not manage the tensors used by mixed precision training
         self._param_store = ParameterStore(process_group)
         self._grad_store = GradientStore(process_group, partition_grad=partition_grad)
-        self._bucket_store = BucketStore(process_group)
+        self._bucket_store = BucketStore(
+            process_group, reduce_bucket_size=reduce_bucket_size, overlap_communication=overlap_communication
+        )

         # working and master params for mixed precision training
         group_params = []
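The change above moves the bucket configuration into the store itself: each strategy's `BucketStore` now receives its own `reduce_bucket_size` and `overlap_communication` instead of relying on a single global setting. Below is a minimal sketch of bucketed gradient reduction driven by those two knobs; only the two argument names come from the diff, while the `ToyBucketStore` class and everything inside it are illustrative, not ColossalAI code.

```python
# Illustrative only: a toy bucketed-reduction store parameterized the same way the
# refactored BucketStore is (reduce_bucket_size, overlap_communication). The class
# and method names are hypothetical; it assumes torch.distributed is initialized.
import torch
import torch.distributed as dist


class ToyBucketStore:
    def __init__(self, process_group, reduce_bucket_size=1024 * 1024, overlap_communication=True):
        self.process_group = process_group
        self.reduce_bucket_size = reduce_bucket_size      # flush threshold, in bytes
        self.overlap_communication = overlap_communication
        self._bucket = []
        self._bucket_bytes = 0

    def add_grad(self, grad: torch.Tensor) -> None:
        # accumulate gradients until the configured bucket size is reached
        self._bucket.append(grad)
        self._bucket_bytes += grad.numel() * grad.element_size()
        if self._bucket_bytes >= self.reduce_bucket_size:
            self.flush()

    def flush(self) -> None:
        for g in self._bucket:
            # async_op=True lets the reduction overlap with ongoing backward computation
            dist.all_reduce(g, group=self.process_group, async_op=self.overlap_communication)
        self._bucket.clear()
        self._bucket_bytes = 0
```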
tests/test_moe/test_moe_zero_fwd_bwd.py (107 changes: 0 additions & 107 deletions)

This file was deleted.

tests/test_moe/test_moe_zero_fwd_bwd_optim.py (62 changes: 31 additions & 31 deletions)
@@ -14,6 +14,7 @@
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from colossalai.zero import LowLevelZeroOptimizer
+from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy, MoeZeroStrategy
 from tests.test_moe.moe_utils import loose_close

 tokens, n_experts = 7, 4
@@ -59,14 +60,30 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.
     zero_model = EPMixtralSparseMoeBlock.from_native_module(zero_model, ep_group=plugin.ep_group)

     zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
+    zero_params = list(filter(lambda x: not is_moe_tensor(x), zero_model.parameters()))
+    moe_params = list(filter(lambda x: is_moe_tensor(x), zero_model.parameters()))
+    zero_optimizer.param_groups.clear()
+    zero_optimizer.add_param_group({"params": zero_params})
+    zero_optimizer.add_param_group({"params": moe_params})
+    strategies = [
+        LowLevelOptStrategy(
+            param_group=zero_optimizer.param_groups[0],
+            process_group=plugin.global_dp_group,
+            overlap_communication=False,
+            partition_grad=(stage == 2),
+        ),
+        MoeZeroStrategy(
+            param_group=zero_optimizer.param_groups[1],
+            process_group=plugin.moe_dp_group,
+            overlap_communication=True,
+            partition_grad=(stage == 2),
+        ),
+    ]
     zero_optimizer = LowLevelZeroOptimizer(
         zero_optimizer,
-        overlap_communication=True,
-        initial_scale=1,
-        reduce_bucket_size=1024 * 1024,
+        strategies,
         master_weights=master_weights,
-        moe_extra_dp_process_group=plugin.moe_dp_group,
-        partition_grad=(stage == 2),
+        initial_scale=1,
     )

     ori_optimizer = torch.optim.SGD(ori_model.parameters(), lr=1)
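The hunk above is the heart of the refactor as exercised by the test: parameters are split into a dense group and a MoE group, each group gets its own ZeRO strategy (`LowLevelOptStrategy` on the global data-parallel group, `MoeZeroStrategy` on the MoE data-parallel group), and `LowLevelZeroOptimizer` is built from the inner optimizer plus the strategy list rather than from flat keyword arguments. The following is a hedged, single-process sketch of the same grouping pattern; the `StrategyCfg` dataclass and the `is_moe_param` predicate are made-up stand-ins, not ColossalAI APIs.

```python
# Toy stand-in for the param-group / strategy wiring shown in the diff above.
# StrategyCfg and is_moe_param are hypothetical; in the real test the split uses
# is_moe_tensor(p) and each group is bound to a real process group.
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class StrategyCfg:
    name: str
    overlap_communication: bool
    partition_grad: bool


def is_moe_param(param_name: str) -> bool:
    return "experts" in param_name  # crude stand-in for is_moe_tensor()


model = nn.ModuleDict({"experts": nn.Linear(8, 8), "gate": nn.Linear(8, 4)})

dense_params = [p for n, p in model.named_parameters() if not is_moe_param(n)]
moe_params = [p for n, p in model.named_parameters() if is_moe_param(n)]

# rebuild the inner optimizer's param groups so each group can be paired with a strategy
inner = torch.optim.SGD(model.parameters(), lr=1.0)
inner.param_groups.clear()
inner.add_param_group({"params": dense_params})
inner.add_param_group({"params": moe_params})

stage = 2
strategies = [
    StrategyCfg("dense", overlap_communication=False, partition_grad=(stage == 2)),
    StrategyCfg("moe", overlap_communication=True, partition_grad=(stage == 2)),
]

for group, cfg in zip(inner.param_groups, strategies):
    print(f"{cfg.name}: {len(group['params'])} params, overlap={cfg.overlap_communication}")
```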
@@ -89,34 +106,17 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch.

     # check grad
     name_to_p = {n: p for n, p in ori_model.module.named_parameters()}
-
     for n, p in zero_model.named_parameters():
-        if is_moe_tensor(p):  # moe param
-            if p.grad is None:
-                """
-                For fixed input seed, the test input may cause a certain expert not to be routed to,
-                so its gradient is None instead of a tensor, which may lead to a potential bug.
-                TODO(haze188) fix later
-                """
-                p.grad = torch.zeros_like(p)
-                continue
-            dist.all_reduce(
-                p.grad, group=plugin.moe_dp_group
-            )  # TODO(haze188) bug fix: this step should be finished by zero
-            p.grad = (
-                p.grad / plugin.moe_dp_group.size()
-            )  # moe param scaling amoung the moe dp group, not the WORLD group.
-            loose_close(p.grad, name_to_p[n].grad, dtype=dtype)
+        zero_grad = zero_optimizer.get_param_grad(p)
+        if p.grad is None:
+            """
+            For fixed input seed, the test input may cause a certain expert not to be routed to,
+            so its gradient is None instead of a tensor, which may lead to a potential bug.
+            """
+            # TODO(haze188) fix later
+            p.grad = torch.zeros_like(p)
+            continue
-        else:
-            zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(p))
-            assert len(zero_grad_list) != 0
-            ori_grad_list = split_grad(name_to_p[n].grad, world_size)
-            if stage == 2:
-                # Zero2 splits the gradient, and each rank holds the corresponding part
-                ori_grad_list = ori_grad_list[rank : rank + 1]
-            for zero_grad, torch_grad in zip(zero_grad_list, ori_grad_list):
-                loose_close(zero_grad, torch_grad, dtype=dtype)
+        loose_close(zero_grad, name_to_p[n].grad, dtype=dtype)

     # zero-dp step
     zero_optimizer.step()
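With the strategy-based optimizer, the test no longer reaches into `_grad_store` or manually all-reduces MoE gradients: `zero_optimizer.get_param_grad(p)` returns a full gradient that can be compared directly against the reference model. Below is a hedged toy sketch of what such an accessor has to do under ZeRO-2-style gradient partitioning; `ToyPartitionedGrads` and its methods are invented for illustration, not ColossalAI internals.

```python
# Hedged sketch: with ZeRO-2-style partitioning each rank holds only a shard of a
# parameter's gradient, and the old test had to fetch shards and compare piecewise.
# A unified accessor can reassemble the shards into a full gradient first.
import torch


class ToyPartitionedGrads:
    def __init__(self, world_size: int):
        self.world_size = world_size
        self._shards = {}  # param id -> list of per-rank gradient shards

    def store(self, param: torch.Tensor, full_grad: torch.Tensor) -> None:
        # mimic ZeRO-2: split the flat gradient evenly across ranks
        self._shards[id(param)] = list(full_grad.flatten().chunk(self.world_size))

    def get_param_grad(self, param: torch.Tensor) -> torch.Tensor:
        # mimic the new accessor: gather the shards back into a full gradient
        shards = self._shards[id(param)]
        return torch.cat(shards).view_as(param)


p = torch.randn(4, 3, requires_grad=True)
p.grad = torch.randn_like(p)

store = ToyPartitionedGrads(world_size=2)
store.store(p, p.grad)

# the test can now compare against the reference gradient directly,
# instead of splitting the reference and matching shard by shard
assert torch.allclose(store.get_param_grad(p), p.grad)
```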
tests/test_moe/test_moe_zero_optim.py (125 changes: 0 additions & 125 deletions)

This file was deleted.
