[Fix] Fix placement policy in ColossalAIStrategy (#1440)
fanqiNO1 committed Dec 23, 2023
1 parent efcd364 commit 671f3bc
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions mmengine/_strategy/colossalai.py
@@ -120,8 +120,9 @@ def backward(self, loss: torch.Tensor, **kwargs) -> None:
         self.optimizer.backward(loss, **kwargs)
 
 
-@MODEL_WRAPPERS.register_module()
-class CollosalAIModelWrapper:
+@MODEL_WRAPPERS.register_module(
+    name=['ColossalAIModelWrapper', 'CollosalAIModelWrapper'])
+class ColossalAIModelWrapper:
 
     def __init__(self, model_wrapper: ModelWrapper, model: nn.Module):
         self.model_wrapper = model_wrapper
@@ -238,7 +239,7 @@ class ColossalAIStrategy(BaseStrategy):
     OPTIMIZER_DIR = 'optimizer'  # directory to save optimizer state.
     MODEL_DIR = 'model'  # directory to save model
     SCHEDULER_DIR = 'scheduler'  # directory to save scheduelrs
-    model: CollosalAIModelWrapper  # type: ignore
+    model: ColossalAIModelWrapper  # type: ignore
     optim_wrapper: ColossalAIOptimWrapper  # type: ignore
 
     def __init__(
@@ -468,8 +469,14 @@ def save_checkpoint(
     def _build_plugin(self, plugin: Union[str, dict]):
         if isinstance(plugin, str):
             if plugin == 'gemini':
-                plugin = colo_plugin.GeminiPlugin(
-                    precision='bf16', placement_policy='cuda')
+                try:
+                    plugin = colo_plugin.GeminiPlugin(
+                        precision='bf16', placement_policy='auto')
+                except AssertionError:
+                    from colossalai.zero.gemini.placement_policy import \
+                        PlacementPolicyFactory as colo_placement
+                    raise ValueError('placement policy must be one of ' +
+                                     f'{list(colo_placement.policies.keys())}')
             elif plugin == 'lowlevel-zero':
                 plugin = colo_plugin.LowLevelZeroPlugin()
             else:
@@ -508,11 +515,11 @@ def _wrap(
         self,
         model: nn.Module,
         optim_wrapper: Optional[OptimWrapper] = None,
-    ) -> Union[Tuple[CollosalAIModelWrapper, ColossalAIOptimWrapper],
-               CollosalAIModelWrapper]:  # type: ignore
+    ) -> Union[Tuple[ColossalAIModelWrapper, ColossalAIOptimWrapper],
+               ColossalAIModelWrapper]:  # type: ignore
         """Wrap model with :class:`ModelWrapper`."""
         if self.model_wrapper is None:
-            self.model_wrapper = {'type': 'CollosalAIModelWrapper'}
+            self.model_wrapper = {'type': 'ColossalAIModelWrapper'}
 
         # For zero series parallel, move `data_preprocessor` to current device
         # is reasonable. We need to `BaseDataPreprocessor.to` manually since
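
For reference, a minimal usage sketch of the changed code path (not part of this commit). It assumes ColossalAIStrategy is importable from mmengine._strategy and that its constructor accepts a plugin argument that is forwarded to _build_plugin, as the diff suggests; the comments restate what the diff above changes.

# Hedged sketch, not from the repository: assumes ColossalAIStrategy is
# exported by mmengine._strategy and takes a `plugin` argument that is
# passed to _build_plugin (the method changed in this commit).
from mmengine._strategy import ColossalAIStrategy

# With plugin='gemini', the strategy now builds
# GeminiPlugin(precision='bf16', placement_policy='auto') instead of
# placement_policy='cuda'; if the installed colossalai rejects 'auto',
# a ValueError listing the supported placement policies is raised.
strategy = ColossalAIStrategy(plugin='gemini')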
