[moe] initialize MoE groups by ProcessGroup #1640

Merged 1 commit on Sep 23, 2022
39 changes: 4 additions & 35 deletions colossalai/context/moe_context.py
@@ -3,6 +3,7 @@
 
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.tensor import ProcessGroup
 
 from typing import Tuple
 
@@ -22,41 +23,9 @@ def __init__(self, ep_size: int, dp_size: int):
         _check_sanity()
         self.ep_size = ep_size
         self.dp_size = dp_size
-        self.ep_group = None
-        # data parallel group for experts, since ep_group is different
-        # we may have different dp_group from get_group(ParallelMode.DATA)
-        self.dp_group = None
-
-        # Here we assume tensor parallel size = 1
-        # Otherwise, MoE can't be used
-        # Since TENSOR parallel group and DATA parallel group
-        # have been created, we can use them directly.
-        if ep_size == 1:
-            from colossalai.core import global_context as gpc
-            self.ep_group = gpc.get_group(ParallelMode.TENSOR)
-            self.dp_group = gpc.get_group(ParallelMode.DATA)
-            return
-
-        if dp_size == 1:
-            from colossalai.core import global_context as gpc
-            self.ep_group = gpc.get_group(ParallelMode.DATA)
-            self.dp_group = gpc.get_group(ParallelMode.TENSOR)
-            return
-
-        rank = dist.get_rank()
-        # Create expert parallel group
-        for i in range(dp_size):
-            ranks = [i * ep_size + j for j in range(ep_size)]
-            group = dist.new_group(ranks)
-            if rank in ranks:
-                self.ep_group = group
-
-        # Create data parallel group
-        for j in range(ep_size):
-            ranks = [i * ep_size + j for i in range(dp_size)]
-            group = dist.new_group(ranks)
-            if rank in ranks:
-                self.dp_group = group
+        self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size)
+        self.ep_group = self.pg.tp_process_group()
+        self.dp_group = self.pg.dp_process_group()
 
 
 class MoeContext(metaclass=SingletonMeta):
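For intuition, here is a minimal sketch of the rank layout that the removed dist.new_group loops built and that the single ProcessGroup(tp_degree=ep_size, dp_degree=dp_size) call is expected to reproduce, with the tensor-parallel axis standing in for expert parallelism. The 4-rank, ep_size = dp_size = 2 setup below is an illustrative assumption, not something this PR fixes.

# Minimal sketch of the MoE group layout, assuming 4 ranks with
# ep_size = dp_size = 2 (illustrative values, not taken from this PR).
ep_size, dp_size = 2, 2
world_size = ep_size * dp_size    # 4 ranks in total

# Expert-parallel groups: consecutive blocks of ep_size ranks,
# mirroring the first removed dist.new_group loop.
ep_groups = [[i * ep_size + j for j in range(ep_size)] for i in range(dp_size)]

# Data-parallel groups for experts: ranks strided across the blocks,
# mirroring the second removed dist.new_group loop.
dp_groups = [[i * ep_size + j for i in range(dp_size)] for j in range(ep_size)]

print(ep_groups)    # [[0, 1], [2, 3]]
print(dp_groups)    # [[0, 2], [1, 3]]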
63 changes: 63 additions & 0 deletions tests/test_moe/test_moe_colo_init.py
@@ -0,0 +1,63 @@
from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import parameterize
from colossalai.utils import free_port
from colossalai.context import MOE_CONTEXT
from colossalai.tensor import ColoParameter
from colossalai.utils.model.colo_init_context import ColoInitContext

from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import get_current_device

from tests.test_zero.common import CONFIG
from tests.test_moe.test_moe_zero_init import MoeModel
from tests.test_tensor.common_utils import debug_print


@parameterize("init_device_type", ['cpu', 'cuda'])
def exam_moe_colo_init(init_device_type):
    world_size = dist.get_world_size()

    if init_device_type == 'cuda':
        init_device = get_current_device()
    elif init_device_type == 'cpu':
        init_device = torch.device("cpu")
    else:
        raise NotImplementedError("Unknown device found.")

    with ColoInitContext(device=init_device):
        model = MoeModel(checkpoint=True)

    for name, param in model.named_parameters():
        assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name)

        if hasattr(param, "moe_info"):
            param.set_process_group(param.moe_info.pg)

        if hasattr(param, "moe_info"):
            assert param.process_group.dp_world_size() == param.moe_info.dp_size
        else:
            assert param.process_group.dp_world_size() == world_size


def _run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    MOE_CONTEXT.setup(seed=42)
    exam_moe_colo_init()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_moe_colo_init(world_size):
    run_func = partial(_run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_moe_colo_init(world_size=4)
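The assertions in exam_moe_colo_init reduce to a size check: a parameter carrying moe_info should report the MoE data-parallel world size, while every other parameter should report the full world size. A minimal sketch of that expectation, again assuming a 4-rank run with ep_size = dp_size = 2 (illustrative numbers, not taken from the test config):

# Expected dp_world_size values, assuming world_size = 4 and
# ep_size = dp_size = 2 (illustrative, not taken from CONFIG).
world_size = 4
ep_size, dp_size = 2, 2
assert ep_size * dp_size == world_size

# Parameters with moe_info are re-attached to the MoE process group ...
expected_moe_dp_world_size = dp_size          # -> 2
# ... while ordinary parameters keep the default data-parallel group.
expected_default_dp_world_size = world_size   # -> 4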