-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[moe] initialize MoE groups by ProcessGroup (#1640)
- Loading branch information
Showing
2 changed files
with
67 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from functools import partial | ||
|
||
import colossalai | ||
import pytest | ||
import torch | ||
import torch.multiprocessing as mp | ||
import torch.distributed as dist | ||
from colossalai.testing import parameterize | ||
from colossalai.utils import free_port | ||
from colossalai.context import MOE_CONTEXT | ||
from colossalai.tensor import ColoParameter | ||
from colossalai.utils.model.colo_init_context import ColoInitContext | ||
|
||
from colossalai.testing import rerun_if_address_is_in_use | ||
from colossalai.utils import get_current_device | ||
|
||
from tests.test_zero.common import CONFIG | ||
from tests.test_moe.test_moe_zero_init import MoeModel | ||
from tests.test_tensor.common_utils import debug_print | ||
|
||
|
||
@parameterize("init_device_type", ['cpu', 'cuda'])
def exam_moe_colo_init(init_device_type):
    """Verify process-group assignment of parameters built under ColoInitContext.

    Parameters carrying ``moe_info`` are rebound to their MoE process group and
    must report that group's data-parallel size; all other parameters must keep
    the default data-parallel group spanning the full world.

    Args:
        init_device_type (str): either ``'cpu'`` or ``'cuda'``; selects where
            parameters are materialized during model construction.
    """
    world_size = dist.get_world_size()

    # Resolve the device used for parameter initialization.
    if init_device_type == 'cpu':
        init_device = torch.device("cpu")
    elif init_device_type == 'cuda':
        init_device = get_current_device()
    else:
        raise NotImplementedError("Unknown device found.")

    with ColoInitContext(device=init_device):
        model = MoeModel(checkpoint=True)

    for name, param in model.named_parameters():
        # Everything created inside the context must come out as a ColoParameter.
        assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name)

        if hasattr(param, "moe_info"):
            # MoE parameters are switched onto the process group of their MoE group,
            # whose data-parallel size is smaller than the global world size.
            param.set_process_group(param.moe_info.pg)
            assert param.process_group.dp_world_size() == param.moe_info.dp_size
        else:
            # Non-MoE parameters stay on the default group covering every rank.
            assert param.process_group.dp_world_size() == world_size
|
||
|
||
def _run_dist(rank, world_size, port):
    """Per-process worker: bring up colossalai, initialize MoE groups, run the check.

    Args:
        rank (int): rank of this spawned process.
        world_size (int): total number of processes in the job.
        port (int): free TCP port for the rendezvous on localhost.
    """
    launch_kwargs = dict(config=CONFIG,
                         rank=rank,
                         world_size=world_size,
                         host='localhost',
                         port=port,
                         backend='nccl')
    colossalai.launch(**launch_kwargs)
    # Deterministic seed so MoE parallel groups are built identically on every rank.
    MOE_CONTEXT.setup(seed=42)
    exam_moe_colo_init()
|
||
|
||
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_moe_colo_init(world_size):
    """Spawn ``world_size`` worker processes and run the MoE init check on each.

    ``mp.spawn`` invokes ``_run_dist(rank, world_size, port)`` — the rank is
    supplied by spawn as the first argument, the rest come from ``args``.
    """
    mp.spawn(_run_dist, args=(world_size, free_port()), nprocs=world_size)
|
||
|
||
if __name__ == '__main__':
    # Allow running this distributed test directly, outside of pytest.
    test_moe_colo_init(world_size=4)