[shardformer] add custom policy in hybrid parallel plugin (#4718)
* add custom policy

* update assert
oahzxl committed Sep 15, 2023
1 parent 451c346 commit ac27979
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -22,6 +22,7 @@
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.zero.low_level import LowLevelZeroOptimizer

 from .pp_plugin_base import PipelinePluginBase
@@ -38,13 +39,15 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 class HybridParallelModule(ModelWrapper):

     def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
-                 ddp_config: dict) -> None:
+                 ddp_config: dict, custom_policy: Policy) -> None:

         self.stage_manager = shard_config.pipeline_stage_manager
         self.dp_group = dp_group

         shardformer = ShardFormer(shard_config)
-        module, self.shared_params = shardformer.optimize(module)
+        if custom_policy is not None:
+            assert isinstance(custom_policy, object)
+        module, self.shared_params = shardformer.optimize(module, policy=custom_policy)

         # setting process groups for shared parameters
         self.shared_param_process_groups = []
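
For context (not part of this diff): the `custom_policy` forwarded to `shardformer.optimize` above is expected to be an instance of the Shardformer `Policy` base class imported in the first hunk. Below is a minimal sketch of such a subclass; the hook names follow the abstract interface in `colossalai/shardformer/policies/base_policy.py` as I understand it, and the no-op behaviour, the class name `NoOpPolicy`, and the empty return values are illustrative assumptions, not code from this repository.

```python
# Illustrative sketch only, not part of this commit. Assumes the abstract
# hooks defined on colossalai.shardformer.policies.base_policy.Policy.
from colossalai.shardformer.policies.base_policy import Policy


class NoOpPolicy(Policy):
    """Hypothetical custom policy that leaves the wrapped model unchanged."""

    def config_sanity_check(self):
        # No extra constraints on the shard config in this sketch.
        pass

    def preprocess(self):
        # Return the model unchanged before sharding is applied.
        return self.model

    def module_policy(self):
        # An empty mapping means no layers are substituted or sharded.
        return {}

    def postprocess(self):
        # Return the model unchanged after sharding is applied.
        return self.model

    def get_held_layers(self):
        # Only relevant for pipeline parallelism; nothing is held here.
        return []

    def get_shared_params(self):
        # No parameters are shared across pipeline stages in this sketch.
        return []
```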
@@ -270,6 +273,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
         communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
         overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
+        custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
     """

     def __init__(self,
@@ -302,7 +306,8 @@ def __init__(self,
                  zero_bucket_size_in_m: int = 12,
                  cpu_offload: bool = False,
                  communication_dtype: Optional[torch.dtype] = None,
-                 overlap_communication: bool = True) -> None:
+                 overlap_communication: bool = True,
+                 custom_policy: Policy = None) -> None:

         super().__init__()
         assert dist.get_world_size() % (
@@ -326,6 +331,7 @@ def __init__(self,
         self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
         self.stage_manager = None
         self.schedule = None
+        self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
             assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
@@ -405,7 +411,7 @@ def configure(
         if not isinstance(model, ModelWrapper):
             use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
             model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
-                                         self.ddp_config)
+                                         self.ddp_config, self.custom_policy)
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
            if self.zero_stage == 0:
                if self.precision in ['fp16', 'bf16']:
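
A usage sketch for the new plugin argument (again, not part of this diff): the plugin simply stores the policy and passes it through to `HybridParallelModule`, which hands it to `ShardFormer.optimize`. The launch/boost boilerplate, the parallel sizes, and the reuse of the hypothetical `NoOpPolicy` from the sketch above are illustrative assumptions.

```python
# Illustrative usage sketch; assumes a torchrun-launched script and the
# hypothetical NoOpPolicy defined in the earlier sketch.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    precision='fp16',
    custom_policy=NoOpPolicy(),  # the argument added by this commit
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
#     model, optimizer, criterion, dataloader, lr_scheduler)
```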
