From b6ea9e70fb1e4b59591f09fccd5138f5cb8017e0 Mon Sep 17 00:00:00 2001 From: genghaozhe <939857490@qq.com> Date: Wed, 12 Jun 2024 05:18:29 +0000 Subject: [PATCH] [moe refactor] update unit test with the refactored ZeRO and remove useless test --- .../zero/low_level/low_level_strategy.py | 4 +- tests/test_moe/test_moe_zero_fwd_bwd.py | 107 --------------- tests/test_moe/test_moe_zero_fwd_bwd_optim.py | 62 ++++----- tests/test_moe/test_moe_zero_optim.py | 125 ------------------ 4 files changed, 34 insertions(+), 264 deletions(-) delete mode 100644 tests/test_moe/test_moe_zero_fwd_bwd.py delete mode 100644 tests/test_moe/test_moe_zero_optim.py diff --git a/colossalai/zero/low_level/low_level_strategy.py b/colossalai/zero/low_level/low_level_strategy.py index 16effac9c80a..7298ef543eae 100644 --- a/colossalai/zero/low_level/low_level_strategy.py +++ b/colossalai/zero/low_level/low_level_strategy.py @@ -66,7 +66,9 @@ def __init__( # it will not manage the tensors used by mixed precision training self._param_store = ParameterStore(process_group) self._grad_store = GradientStore(process_group, partition_grad=partition_grad) - self._bucket_store = BucketStore(process_group) + self._bucket_store = BucketStore( + process_group, reduce_bucket_size=reduce_bucket_size, overlap_communication=overlap_communication + ) # working and master params for mixed precision training group_params = [] diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py deleted file mode 100644 index c0722881bfcd..000000000000 --- a/tests/test_moe/test_moe_zero_fwd_bwd.py +++ /dev/null @@ -1,107 +0,0 @@ -import pytest -import torch -import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel as DDP - -import colossalai -from colossalai.moe.manager import MOE_MANAGER -from colossalai.tensor.moe_tensor.api import is_moe_tensor -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.testing.random import seed_all -from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer -from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy, MoeZeroStrategy -from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, sync_local_from_ep - - -def run_zero_test(local_rank): - dp_size = world_size = dist.get_world_size() - assert world_size >= 4, f"{world_size=}: at least 4 processes are required for this test (ep=2, moe_dp=2)" - criterion = torch.nn.CrossEntropyLoss() - - ep_size = 2 - extra_dp_size = world_size // ep_size - - MOE_MANAGER.__init__() - MOE_MANAGER.setup(parallel="EP", mode="fixed", fixed_dp_size=extra_dp_size, fixed_ep_size=ep_size, fixed_pp_size=1) - - zero_model = MoeModel().bfloat16().cuda() - - dp_group = dist.group.WORLD - ep_group = MOE_MANAGER.parallel_info_dict[ep_size].ep_group - moe_extra_dp_group = MOE_MANAGER.parallel_info_dict[ep_size].dp_group - - zero_params = list(filter(lambda x: not is_moe_tensor(x), zero_model.parameters())) - moe_params = list(filter(lambda x: is_moe_tensor(x), zero_model.parameters())) - print(f"{len(zero_params)=}, {len(moe_params)=}") - lr = 1e-3 - zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=lr) - zero_optimizer.param_groups.clear() - zero_optimizer.add_param_group({"params": zero_params}) - zero_optimizer.add_param_group({"params": moe_params}) - - strategies = [ - LowLevelOptStrategy( - param_group=zero_optimizer.param_groups[0], - process_group=dp_group, - overlap_communication=False, - partition_grad=True, - ), - MoeZeroStrategy( - param_group=zero_optimizer.param_groups[1], - process_group=moe_extra_dp_group, - overlap_communication=True, - partition_grad=False, - ), - ] - zero_optimizer = LowLevelZeroOptimizer( - zero_optimizer, - strategies, - ) - - MOE_MANAGER.__init__() - MOE_MANAGER.setup(parallel=None) - ddp_model = DDP(MoeModel().bfloat16().cuda(), static_graph=True) - delete_moe_info(ddp_model) - torch_optim = torch.optim.SGD(ddp_model.parameters(), lr=lr) - sync_local_from_ep(ddp_model, zero_model) - - seed_all(42 + local_rank) - data = torch.randn(16, 4).bfloat16().cuda() - label = torch.randint(0, 4, (16,)).cuda() - - ddp_model.train() - zero_model.train() - ddp_out = criterion(ddp_model(data), label).float() - zero_out = criterion(zero_model(data), label).float() - assert torch.allclose(ddp_out, zero_out) - print(f"{local_rank=} {ddp_out.mean()=}") - - ddp_out.backward() - zero_optimizer.backward(zero_out) - - for (zero_name, zero_param), (ddp_name, ddp_param) in zip( - zero_model.named_parameters(), ddp_model.named_parameters() - ): - torch_grad = ddp_param.grad - zero_grad = zero_optimizer.get_param_grad(zero_param) - if is_moe_tensor(zero_param): - moe_grad_list = [torch.empty_like(zero_grad) for _ in range(ep_size)] - dist.all_gather(moe_grad_list, zero_grad, group=ep_group) - zero_grad = torch.cat(moe_grad_list, dim=0) - loose_close(torch_grad, zero_grad, dtype=torch_grad.dtype) - - -def run_dist(rank, world_size, port, stage): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_zero_test(rank, stage=stage) - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [4]) -@rerun_if_address_is_in_use() -def test_moe_zero_model(world_size): - spawn(run_dist, world_size) - - -if __name__ == "__main__": - test_moe_zero_model(world_size=4) diff --git a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py index 7dcd3d19a734..126ddc6fea65 100644 --- a/tests/test_moe/test_moe_zero_fwd_bwd_optim.py +++ b/tests/test_moe/test_moe_zero_fwd_bwd_optim.py @@ -14,6 +14,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all from colossalai.zero import LowLevelZeroOptimizer +from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy, MoeZeroStrategy from tests.test_moe.moe_utils import loose_close tokens, n_experts = 7, 4 @@ -59,14 +60,30 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch. zero_model = EPMixtralSparseMoeBlock.from_native_module(zero_model, ep_group=plugin.ep_group) zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) + zero_params = list(filter(lambda x: not is_moe_tensor(x), zero_model.parameters())) + moe_params = list(filter(lambda x: is_moe_tensor(x), zero_model.parameters())) + zero_optimizer.param_groups.clear() + zero_optimizer.add_param_group({"params": zero_params}) + zero_optimizer.add_param_group({"params": moe_params}) + strategies = [ + LowLevelOptStrategy( + param_group=zero_optimizer.param_groups[0], + process_group=plugin.global_dp_group, + overlap_communication=False, + partition_grad=(stage == 2), + ), + MoeZeroStrategy( + param_group=zero_optimizer.param_groups[1], + process_group=plugin.moe_dp_group, + overlap_communication=True, + partition_grad=(stage == 2), + ), + ] zero_optimizer = LowLevelZeroOptimizer( zero_optimizer, - overlap_communication=True, - initial_scale=1, - reduce_bucket_size=1024 * 1024, + strategies, master_weights=master_weights, - moe_extra_dp_process_group=plugin.moe_dp_group, - partition_grad=(stage == 2), + initial_scale=1, ) ori_optimizer = torch.optim.SGD(ori_model.parameters(), lr=1) @@ -89,34 +106,17 @@ def run_zero_with_original_model(world_size, master_weights: bool, dtype: torch. # check grad name_to_p = {n: p for n, p in ori_model.module.named_parameters()} - for n, p in zero_model.named_parameters(): - if is_moe_tensor(p): # moe param - if p.grad is None: - """ - For fixed input seed, the test input may cause a certain expert not to be routed to, - so its gradient is None instead of a tensor, which may lead to a potential bug. - TODO(haze188) fix later - """ - p.grad = torch.zeros_like(p) - continue - dist.all_reduce( - p.grad, group=plugin.moe_dp_group - ) # TODO(haze188) bug fix: this step should be finished by zero - p.grad = ( - p.grad / plugin.moe_dp_group.size() - ) # moe param scaling amoung the moe dp group, not the WORLD group. - loose_close(p.grad, name_to_p[n].grad, dtype=dtype) + zero_grad = zero_optimizer.get_param_grad(p) + if p.grad is None: + """ + For fixed input seed, the test input may cause a certain expert not to be routed to, + so its gradient is None instead of a tensor, which may lead to a potential bug. + """ + # TODO(haze188) fix later + p.grad = torch.zeros_like(p) continue - else: - zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(p)) - assert len(zero_grad_list) != 0 - ori_grad_list = split_grad(name_to_p[n].grad, world_size) - if stage == 2: - # Zero2 splits the gradient, and each rank holds the corresponding part - ori_grad_list = ori_grad_list[rank : rank + 1] - for zero_grad, torch_grad in zip(zero_grad_list, ori_grad_list): - loose_close(zero_grad, torch_grad, dtype=dtype) + loose_close(zero_grad, name_to_p[n].grad, dtype=dtype) # zero-dp step zero_optimizer.step() diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py deleted file mode 100644 index 3bbd90fd6aac..000000000000 --- a/tests/test_moe/test_moe_zero_optim.py +++ /dev/null @@ -1,125 +0,0 @@ -import pytest -import torch -import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel as DDP - -import colossalai -from colossalai.moe.manager import MOE_MANAGER -from colossalai.tensor.moe_tensor.api import is_moe_tensor -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.testing.random import seed_all -from colossalai.zero.low_level.low_level_optim import LowLevelZeroOptimizer -from colossalai.zero.low_level.low_level_strategy import LowLevelOptStrategy, MoeZeroStrategy -from tests.test_moe.moe_utils import MoeModel, delete_moe_info, loose_close, sync_local_from_ep - - -def run_zero_test(local_rank): - dp_size = world_size = dist.get_world_size() - assert world_size >= 4, f"{world_size=}: at least 4 processes are required for this test (ep=2, moe_dp=2)" - criterion = torch.nn.CrossEntropyLoss() - - ep_size = 2 - extra_dp_size = world_size // ep_size - - MOE_MANAGER.__init__() - MOE_MANAGER.setup(parallel="EP", mode="fixed", fixed_dp_size=extra_dp_size, fixed_ep_size=ep_size, fixed_pp_size=1) - - zero_model = MoeModel().bfloat16().cuda() - - dp_group = dist.group.WORLD - ep_group = MOE_MANAGER.parallel_info_dict[ep_size].ep_group - moe_extra_dp_group = MOE_MANAGER.parallel_info_dict[ep_size].dp_group - - zero_params = list(filter(lambda x: not is_moe_tensor(x), zero_model.parameters())) - moe_params = list(filter(lambda x: is_moe_tensor(x), zero_model.parameters())) - print(f"{len(zero_params)=}, {len(moe_params)=}") - lr = 1e-3 - zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=lr) - zero_optimizer.param_groups.clear() - zero_optimizer.add_param_group({"params": zero_params}) - zero_optimizer.add_param_group({"params": moe_params}) - - strategies = [ - LowLevelOptStrategy( - param_group=zero_optimizer.param_groups[0], - process_group=dp_group, - overlap_communication=False, - partition_grad=True, - ), - MoeZeroStrategy( - param_group=zero_optimizer.param_groups[1], - process_group=moe_extra_dp_group, - overlap_communication=True, - partition_grad=False, - ), - ] - zero_optimizer = LowLevelZeroOptimizer( - zero_optimizer, - strategies, - ) - - MOE_MANAGER.__init__() - MOE_MANAGER.setup(parallel=None) - ddp_model = DDP(MoeModel().bfloat16().cuda(), static_graph=True) - delete_moe_info(ddp_model) - torch_optim = torch.optim.SGD(ddp_model.parameters(), lr=lr) - sync_local_from_ep(ddp_model, zero_model) - - seed_all(42 + local_rank) - data = torch.randn(16, 4).bfloat16().cuda() - label = torch.randint(0, 4, (16,)).cuda() - - ddp_model.train() - zero_model.train() - ddp_out = criterion(ddp_model(data), label).float() - zero_out = criterion(zero_model(data), label).float() - assert torch.allclose(ddp_out, zero_out) - print(f"{local_rank=} {ddp_out.mean()=}") - - ddp_out.backward() - zero_optimizer.backward(zero_out) - - for (zero_name, zero_param), (ddp_name, ddp_param) in zip( - zero_model.named_parameters(), ddp_model.named_parameters() - ): - torch_grad = ddp_param.grad - zero_grad = zero_optimizer.get_param_grad(zero_param) - if is_moe_tensor(zero_param): - moe_grad_list = [torch.empty_like(zero_grad) for _ in range(ep_size)] - dist.all_gather(moe_grad_list, zero_grad, group=ep_group) - zero_grad = torch.cat(moe_grad_list, dim=0) - loose_close(torch_grad, zero_grad, dtype=torch_grad.dtype) - - torch_optim.step() - zero_optimizer.step() - - for (zero_name, zero_param), (ddp_name, ddp_param) in zip( - zero_model.named_parameters(), ddp_model.named_parameters() - ): - if is_moe_tensor(zero_param): - moe_param_list = [torch.empty_like(zero_param) for _ in range(ep_size)] - dist.all_gather(moe_param_list, zero_param, group=ep_group) - zero_param = torch.cat(moe_param_list, dim=0) - assert ddp_param.dtype == zero_param.dtype - ddp_param.numel() // dp_size - loose_close( - ddp_param, - zero_param, - dtype=ddp_param.dtype, - ) - - -def run_dist(rank, world_size, port, stage): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - run_zero_test(rank, stage=stage) - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [4]) -@rerun_if_address_is_in_use() -def test_moe_zero_model(world_size): - spawn(run_dist, world_size) - - -if __name__ == "__main__": - test_moe_zero_model(world_size=4)