diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py index 23ffbf5d317c..c01e02c49a60 100644 --- a/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py +++ b/applications/ColossalMoE/colossal_moe/models/mixtral_policy.py @@ -109,8 +109,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli else: module = self.model.model - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls @@ -129,10 +129,10 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): held_layers.append(module.embed_tokens) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.norm) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index f51cb060c356..eba7d1c1f8b8 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -26,7 +26,7 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer from colossalai.shardformer.layer.utils import SeqParallelUtils from colossalai.shardformer.policies.base_policy import Policy from colossalai.tensor.d_tensor.api import is_distributed_tensor @@ -930,6 +930,7 @@ class HybridParallelPlugin(PipelinePluginBase): custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None. pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'. num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. + gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. 
""" @@ -969,6 +970,7 @@ def __init__( custom_policy: Policy = None, pp_style: str = "1f1b", num_model_chunks: int = 1, + gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, ) -> None: super().__init__() @@ -1043,6 +1045,7 @@ def __init__( enable_sequence_parallelism=enable_sequence_parallelism, enable_sequence_overlap=enable_sequence_overlap, parallel_output=parallel_output, + gradient_checkpoint_config=gradient_checkpoint_config, ) self.amp_config = dict( initial_scale=initial_scale, diff --git a/colossalai/inference/engine/policies/bloom.py b/colossalai/inference/engine/policies/bloom.py index f35b50189e82..5bc47c3c1a49 100644 --- a/colossalai/inference/engine/policies/bloom.py +++ b/colossalai/inference/engine/policies/bloom.py @@ -114,12 +114,12 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) if stage_manager.is_first_stage(): held_layers.append(module.word_embeddings) held_layers.append(module.word_embeddings_layernorm) held_layers.append(self.model.lm_head) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.h[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.ln_f) diff --git a/colossalai/inference/engine/policies/chatglm2.py b/colossalai/inference/engine/policies/chatglm2.py index 3e1d94f4785c..c7c6f3b927e1 100644 --- a/colossalai/inference/engine/policies/chatglm2.py +++ b/colossalai/inference/engine/policies/chatglm2.py @@ -69,11 +69,11 @@ def get_held_layers(self) -> List[nn.Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(module.num_layers) if stage_manager.is_first_stage(): held_layers.append(module.embedding) held_layers.append(module.output_layer) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.encoder.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): if module.encoder.post_layer_norm: diff --git a/colossalai/inference/engine/policies/llama.py b/colossalai/inference/engine/policies/llama.py index 11517d7e8a13..a57a4e50cdb9 100644 --- a/colossalai/inference/engine/policies/llama.py +++ b/colossalai/inference/engine/policies/llama.py @@ -194,11 +194,11 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): held_layers.append(module.embed_tokens) held_layers.append(self.model.lm_head) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.norm) diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index c8f9042084da..b0556669b2bc 100644 --- 
a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -1,6 +1,7 @@ import contextlib -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union +import numpy as np import torch.distributed as dist from torch.distributed import ProcessGroup @@ -29,6 +30,8 @@ def __init__( ) -> None: assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False" + self.num_layers_per_stage = None + self.pg_mesh = pg_mesh self.pipeline_axis = pipeline_axis self.prev_rank: Optional[Tuple[int, ...]] = None @@ -69,6 +72,88 @@ def __init__( # for shardformer, hold model chunk id self.model_chunk_id: Optional[int] = None + @property + def control_distribute_layers(self) -> bool: + return self.num_layers_per_stage is not None + + def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None: + """Set the distribution configuration. + This allows user to customize the number of layers for each stage. + + Args: + num_model_layers (int): Number of layers in the model. + num_layers_per_stage (List[int]): Number of layers for each stage. + """ + assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage]) + assert sum(num_layers_per_stage) == num_model_layers + assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1) + self.num_model_layers = num_model_layers + self.num_layers_per_stage = num_layers_per_stage + + def distribute_layers( + self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None + ) -> List[int]: + """Divide layers into stages""" + num_stages = self.num_stages if num_stages is None else num_stages + num_model_chunks = ( + (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks + ) + + if self.control_distribute_layers: + assert num_layers == self.num_model_layers + return self.num_layers_per_stage + + else: + quotient = num_layers // (num_stages * num_model_chunks) + remainder = num_layers % (num_stages * num_model_chunks) + + # calculate the num_layers per stage + layers_per_stage = [quotient] * num_stages * num_model_chunks + + # deal with the rest layers + if remainder > 0: + start_position = (num_stages * num_model_chunks) // 2 - remainder // 2 + for i in range(start_position, start_position + remainder): + layers_per_stage[i] += 1 + return layers_per_stage + + def get_stage_index( + self, + layers_per_stage: List[int], + stage: Optional[int] = None, + num_model_chunks: Optional[int] = None, + num_stages: Optional[int] = None, + ) -> Union[Tuple[int, int], List[Tuple[int, int]]]: + """ + Get the start index and end index of layers for each stage. 
+ + Args: + layers_per_stage (List[int]): number of layers for each stage + stage (int): the stage index + num_stages (int): number of stages + num_model_chunks (int): number of model chunks + + Returns: + - Tuple[int, int]: the start index and end index of this stage + - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk + + """ + stage = self.stage if stage is None else stage + num_model_chunks = ( + (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks + ) + num_stages = self.num_stages if num_stages is None else num_stages + + num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) + + stage_indices = [] + for model_chunk in range(num_model_chunks): + start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages] + end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] + stage_indices.append([start_idx, end_idx]) + + return stage_indices[0] if num_model_chunks == 1 else stage_indices + def is_first_stage(self, ignore_chunk: bool = False) -> bool: """Is the current stage the first stage. diff --git a/colossalai/shardformer/__init__.py b/colossalai/shardformer/__init__.py index 77c2af8d18f7..234e7131728f 100644 --- a/colossalai/shardformer/__init__.py +++ b/colossalai/shardformer/__init__.py @@ -1 +1 @@ -from .shard import ShardConfig, ShardFormer +from .shard import GradientCheckpointConfig, ModelSharder, PipelineGradientCheckpointConfig, ShardConfig, ShardFormer diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 29dc8200f338..eb421c92b82c 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -138,13 +138,25 @@ def llama_model_forward( next_decoder_cache = () if use_cache else None start_idx, end_idx = stage_index[0], stage_index[1] + num_ckpt_layers = 0 + if self.gradient_checkpointing and self.training: + num_ckpt_layers = end_idx - start_idx + # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer + if shard_config.gradient_checkpoint_config is not None: + num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( + stage=stage_manager.stage, + num_layers=end_idx - start_idx, + model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0, + ) + assert num_ckpt_layers <= end_idx - start_idx + for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states += (hidden_states,) past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: + if idx - start_idx < num_ckpt_layers: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index 762e754816bf..d67ab0a3c6bb 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -2,9 +2,8 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union -import numpy as np import torch.nn as nn from torch import Tensor from torch.nn import Module @@ -196,49 +195,3 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: List[Dict[int, 
Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}] """ return [] - - def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: - """Divide layers into stages""" - quotient = num_layers // num_stages - remainder = num_layers % num_stages - - # calculate the num_layers per stage - layers_per_stage = [quotient] * num_stages - - # deal with the rest layers - if remainder > 0: - start_position = num_stages // 2 - remainder // 2 - for i in range(start_position, start_position + remainder): - layers_per_stage[i] += 1 - return layers_per_stage - - def get_stage_index( - self, - layers_per_stage: List[int], - stage: int, - num_model_chunks: int = 1, - num_stages: int = 0, - ) -> Union[Tuple[int, int], List[Tuple[int, int]]]: - """ - Get the start index and end index of layers for each stage. - - Args: - layers_per_stage (List[int]): number of layers for each stage - stage (int): the stage index - num_stages (int): number of stages - num_model_chunks (int): number of model chunks - - Returns: - - Tuple[int, int]: the start index and end index of this stage - - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk - - """ - num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) - - stage_indices = [] - for model_chunk in range(num_model_chunks): - start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages] - end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1] - stage_indices.append([start_idx, end_idx]) - - return stage_indices[0] if num_model_chunks == 1 else stage_indices diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py index 4d50a3c9920c..cd7bdcdd6fb4 100644 --- a/colossalai/shardformer/policies/bert.py +++ b/colossalai/shardformer/policies/bert.py @@ -279,16 +279,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli module = self.model.bert if stage_manager.is_interleave: - layers_per_stage = self.distribute_layers( - len(module.encoder.layer), - stage_manager.num_stages * stage_manager.num_model_chunks, - ) - stage_manager.stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, @@ -298,8 +290,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli } else: - layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, @@ -324,16 +316,8 @@ def get_held_layers(self) -> List[Module]: held_layers = [] if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None - layers_per_stage = self.distribute_layers( - len(module.encoder.layer), - stage_manager.num_stages * stage_manager.num_model_chunks, - ) - stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - 
num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) if stage_manager.is_first_stage(ignore_chunk=True): held_layers.append(module.embeddings) for start_idx, end_idx in stage_indices: @@ -342,10 +326,10 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.pooler) else: - layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) if stage_manager.is_first_stage(): held_layers.append(module.embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.encoder.layer[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.pooler) diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index e4714c8c1b15..55b69d5f0d29 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -203,8 +203,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli else: module = self.model.transformer - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config @@ -226,11 +226,11 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) if stage_manager.is_first_stage(): held_layers.append(module.word_embeddings) held_layers.append(module.word_embeddings_layernorm) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.h[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.ln_f) diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index cbe6254d1561..0830d85f1073 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -179,10 +179,10 @@ def get_held_layers(self) -> List[nn.Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(module.num_layers) if stage_manager.is_first_stage(): held_layers.append(module.embedding) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.encoder.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): if module.encoder.post_layer_norm: @@ -204,8 +204,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli else: module = self.model.transformer - layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages) 
- stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(module.num_layers) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index 16bbc3f23f81..fe61c406fae3 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -161,8 +161,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli else: module = self.model.transformer - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config @@ -181,10 +181,10 @@ def get_held_layers(self) -> List[Module]: module = self.model.transformer stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) if stage_manager.is_first_stage(): held_layers.append(module.word_embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.h[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.ln_f) diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index d1a8c9dce2c7..4bcac3951a6b 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -185,15 +185,8 @@ def get_held_layers(self) -> List[nn.Module]: held_layers = [] if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None - layers_per_stage = self.distribute_layers( - len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks - ) - stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) if stage_manager.is_first_stage(ignore_chunk=True): held_layers.append(module.wte) held_layers.append(module.wpe) @@ -203,12 +196,12 @@ def get_held_layers(self) -> List[nn.Module]: if stage_manager.is_last_stage(ignore_chunk=True): held_layers.append(module.ln_f) else: - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) if stage_manager.is_first_stage(): held_layers.append(module.wte) held_layers.append(module.wpe) held_layers.append(module.drop) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.h[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.ln_f) @@ -226,15 +219,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: 
Callable, poli module = self.model.transformer if stage_manager.is_interleave: - layers_per_stage = self.distribute_layers( - len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks - ) - stage_manager.stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, @@ -243,8 +229,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli ) } else: - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py index b24443298e07..eab4c214a41f 100644 --- a/colossalai/shardformer/policies/gptj.py +++ b/colossalai/shardformer/policies/gptj.py @@ -179,11 +179,11 @@ def get_held_layers(self) -> List[nn.Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) if stage_manager.is_first_stage(): held_layers.append(module.wte) held_layers.append(module.drop) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.h[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.ln_f) @@ -200,8 +200,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli else: module = self.model.transformer - layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.h)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py index daa7708c8fdf..18d79f84a765 100644 --- a/colossalai/shardformer/policies/llama.py +++ b/colossalai/shardformer/policies/llama.py @@ -164,30 +164,20 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli module = self.model.model if stage_manager.is_interleave: - layers_per_stage = self.distribute_layers( - len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks - ) - stage_manager.stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config) } else: - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, 
stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config ) } - self.append_or_create_method_replacement( - description=method_replacement, policy=policy, target_key=model_cls - ) self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) @@ -204,15 +194,8 @@ def get_held_layers(self) -> List[Module]: held_layers = [] if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None - layers_per_stage = self.distribute_layers( - len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks - ) - stage_indices = self.get_stage_index( - layers_per_stage, - stage_manager.stage, - num_model_chunks=stage_manager.num_model_chunks, - num_stages=stage_manager.num_stages, - ) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_indices = stage_manager.get_stage_index(layers_per_stage) if stage_manager.is_first_stage(ignore_chunk=True): held_layers.append(module.embed_tokens) for start_idx, end_idx in stage_indices: @@ -221,10 +204,10 @@ def get_held_layers(self) -> List[Module]: held_layers.append(module.norm) else: - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): held_layers.append(module.embed_tokens) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.norm) diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py index 683f3a9d5a2d..98e584be861b 100644 --- a/colossalai/shardformer/policies/opt.py +++ b/colossalai/shardformer/policies/opt.py @@ -186,12 +186,12 @@ def get_held_layers(self) -> List[nn.Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): held_layers.append(module.embed_tokens) held_layers.append(module.embed_positions) held_layers.append(module.project_in) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.final_layer_norm) @@ -208,8 +208,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli else: module = self.model.model.decoder - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( new_forward, diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index f5f701dc0972..0c8ec15fa0a9 100644 --- a/colossalai/shardformer/policies/t5.py +++ 
b/colossalai/shardformer/policies/t5.py @@ -251,6 +251,8 @@ def distribute_t5_layers( Return the layer distribution as a list and the starting stage of decoder. If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers. """ + stage_manager = self.pipeline_stage_manager + assert stage_manager is not None, "Pipeline stage manager is not set." # number of encoder layers must be a positive integer if num_encoder_layers <= 0: @@ -262,7 +264,7 @@ def distribute_t5_layers( # in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist if num_decoder_layers == 0: - return self.distribute_layers(num_encoder_layers, num_stages), num_stages + return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages # the number of stages distributed between encoder and decoder is optimized in this way: # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) @@ -273,21 +275,26 @@ def objective(num_encoder_stages): num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 num_decoder_stages = num_stages - num_encoder_stages - encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages) - decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages) + encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages) + decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages) return encoder_distribution + decoder_distribution, num_encoder_stages def get_t5_stage_index( self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int - ) -> Tuple[bool, int, int]: + ) -> Tuple[int, int]: """ Input the distribution of layers among stages, the current stage and the first stage of decoder. Return the starting/ending idx of layers in encoder/decoder """ + stage_manager = self.pipeline_stage_manager + assert stage_manager is not None, "Pipeline stage manager is not set." 
+ if stage < decoder_starting_stage: - return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) + return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) else: - return self.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage) + return stage_manager.get_stage_index( + layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage + ) def get_held_layers(self) -> List[nn.Module]: """Get pipeline layers for current stage.""" diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index b0f224e22dc9..905398c4d51e 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -134,10 +134,10 @@ def get_held_layers(self) -> List[nn.Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) if stage_manager.is_first_stage(): held_layers.append(module.embeddings) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.encoder.layer[start_idx:end_idx]) return held_layers @@ -149,8 +149,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, else: module = self.model.vit - layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = {"forward": pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py index 480a4beea581..c63f6d1cc549 100644 --- a/colossalai/shardformer/policies/whisper.py +++ b/colossalai/shardformer/policies/whisper.py @@ -300,6 +300,8 @@ def distribute_whisper_layers( Return the layer distribution as a list and the starting stage of decoder. If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers. 
""" + stage_manager = self.pipeline_stage_manager + assert stage_manager is not None, "pipeline_stage_manager is None" # number of encoder layers must be a positive integer if num_encoder_layers <= 0: @@ -311,7 +313,7 @@ def distribute_whisper_layers( # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist if num_decoder_layers == 0: - return self.distribute_layers(num_encoder_layers, num_stages), num_stages + return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages # the number of stages distributed between encoder and decoder is optimized in this way: # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)) @@ -322,21 +324,24 @@ def objective(num_encoder_stages): num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1 num_decoder_stages = num_stages - num_encoder_stages - encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages) - decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages) + encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages) + decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages) return encoder_distribution + decoder_distribution, num_encoder_stages def get_whisper_stage_index( self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int - ) -> Tuple[bool, int, int]: + ) -> Tuple[int, int]: """ Input the distribution of layers among stages, the current stage and the first stage of decoder. Return the starting/ending idx of layers in encoder/decoder """ + stage_manager = self.pipeline_stage_manager + assert stage_manager is not None, "pipeline_stage_manager is None" + if stage < decoder_starting_stage: - return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) + return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage) else: - return self.get_stage_index( + return stage_manager.get_stage_index( layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage, ) diff --git a/colossalai/shardformer/shard/__init__.py b/colossalai/shardformer/shard/__init__.py index acf8a95a41ca..dff2118c1c1a 100644 --- a/colossalai/shardformer/shard/__init__.py +++ b/colossalai/shardformer/shard/__init__.py @@ -1,5 +1,6 @@ +from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig from .shard_config import ShardConfig from .sharder import ModelSharder from .shardformer import ShardFormer -__all__ = ["ShardConfig", "ModelSharder", "ShardFormer"] +__all__ = ["ShardConfig", "ModelSharder", "ShardFormer", "PipelineGradientCheckpointConfig", "GradientCheckpointConfig"] diff --git a/colossalai/shardformer/shard/grad_ckpt_config.py b/colossalai/shardformer/shard/grad_ckpt_config.py new file mode 100644 index 000000000000..9c6c2b54ea39 --- /dev/null +++ b/colossalai/shardformer/shard/grad_ckpt_config.py @@ -0,0 +1,87 @@ +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass +class GradientCheckpointConfig: + gradient_checkpointing_ratio: float = 0.0 + + def get_num_ckpt_layers(self, num_layers: int) -> int: + return int(self.gradient_checkpointing_ratio * num_layers) + + +@dataclass +class PipelineGradientCheckpointConfig(GradientCheckpointConfig): + r""" + The pipeline gradient config is designed to provide more flexibility for users to control gradient checkpoint in pipeline 
parallelism. + Combined with PipelineStageManager.set_distribution_config, user can fully control the distribution of layers and checkpointed layers in pipeline parallelism. + Refer to https://github.com/hpcaitech/ColossalAI/issues/5509 for more details. + + It provides the following features: + 1. `gradient_checkpointing_ratio`: This is used to control gradient checkpointing more precisely, e.g., set 50% of the layers to use gradient checkpointing. + 2. Customize # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`. + + """ + """ + Args: + gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None. + num_stages (Optional[int]): Number of stages in the pipeline. Defaults to None. For sanity check. + num_model_chunks (Optional[int]): Number of model chunks (1F1B or Interleaved). Defaults to None. For sanity check. + num_model_layers (Optional[int]): Number of model layers. Defaults to None. For sanity check. + num_ckpt_layers_per_stage (Optional[List[int]]): Number of checkpointed layers for each stage. Defaults to None. + + Example 1: + num_stages = 8 + num_layers = 80 + num_model_chunks = 1 + num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11] + num_ckpt_layers_per_stage = [4, 4, 2, 2, 0, 0, 0, 0] + + Example 2: + num_stages = 4 + num_layers = 80 + num_model_chunks = 2 + num_layers_per_stage = [9, 9, 9, 10, 11, 10, 11, 11] + # device 0 holds num_layers_per_stage[0] and num_layers_per_stage[4] layers + ... + + """ + num_stages: Optional[int] = None + num_model_chunks: Optional[int] = None + num_model_layers: Optional[int] = None + num_ckpt_layers_per_stage: Optional[List[int]] = None + + def __post_init__(self): + if self._enable_gradient_checkpointing_ratio: + if not (0 <= self.gradient_checkpointing_ratio <= 1): + raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%") + + if self._enable_customized_ckpt_layers_per_stage: + assert ( + self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None + ) + assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks + assert all( + [0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage] + ) + self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers + + @property + def _enable_gradient_checkpointing_ratio(self) -> bool: + return self.gradient_checkpointing_ratio is not None + + @property + def _enable_customized_ckpt_layers_per_stage(self) -> bool: + return self.num_ckpt_layers_per_stage is not None + + def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int: + if not self._enable_gradient_checkpointing_ratio and not self._enable_customized_ckpt_layers_per_stage: + raise RuntimeError("No checkpointed layers information is provided") + + if self._enable_customized_ckpt_layers_per_stage: + assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks + num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages] + assert num_ckpt_layers <= num_layers + return num_ckpt_layers + else: + return int(self.gradient_checkpointing_ratio * num_layers) diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index da27341d9c29..646b611932b7 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ 
b/colossalai/shardformer/shard/shard_config.py @@ -6,6 +6,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager +from .grad_ckpt_config import GradientCheckpointConfig + __all__ = ["ShardConfig"] @@ -23,6 +25,7 @@ class ShardConfig: enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False. enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False. + gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None. enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False. """ tensor_parallel_process_group: Optional[ProcessGroup] = None @@ -35,6 +38,7 @@ class ShardConfig: enable_sequence_parallelism: bool = False enable_sequence_overlap: bool = False parallel_output: bool = True + gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) # TODO padding vocab # make_vocab_size_divisible_by: int = 128 diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py index 66a42e0176e9..8ef07bdb91b5 100644 --- a/examples/language/openmoe/model/openmoe_policy.py +++ b/examples/language/openmoe/model/openmoe_policy.py @@ -1,4 +1,3 @@ -import warnings from functools import partial from typing import Callable, Dict, List, Optional, Union @@ -21,7 +20,6 @@ class OpenMoePolicy(Policy): - def config_sanity_check(self): pass @@ -43,7 +41,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: if self.shard_config.enable_sequence_parallelism: self.shard_config.enable_sequence_parallelism = False raise NotImplementedError( - "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") + "openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag." 
+ ) if self.shard_config.enable_tensor_parallelism: raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.") @@ -97,8 +96,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli else: module = self.model.model - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) - stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) + stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement( description=method_replacement, policy=policy, target_key=model_cls @@ -117,10 +116,10 @@ def get_held_layers(self) -> List[Module]: stage_manager = self.pipeline_stage_manager held_layers = [] - layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages) + layers_per_stage = stage_manager.distribute_layers(len(module.layers)) if stage_manager.is_first_stage(): held_layers.append(module.embed_tokens) - start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage) held_layers.extend(module.layers[start_idx:end_idx]) if stage_manager.is_last_stage(): held_layers.append(module.norm) @@ -143,7 +142,6 @@ def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]: class OpenMoeModelPolicy(OpenMoePolicy): - def __init__(self) -> None: super().__init__() @@ -169,21 +167,21 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]: class OpenMoeForCausalLMPolicy(OpenMoePolicy): - def module_policy(self): policy = super().module_policy() if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm new_item = { - OpenMoeForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ + OpenMoeForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True), ) - ]) + ] + ) } policy.update(new_item) @@ -208,13 +206,17 @@ def get_held_layers(self) -> List[Module]: def get_shared_params(self) -> List[Dict[int, Tensor]]: llama_model = self.model.model if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1: - if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) - and self.pipeline_stage_manager.num_stages > 1): + if ( + id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight) + and self.pipeline_stage_manager.num_stages > 1 + ): # tie weights - return [{ - 0: llama_model.embed_tokens.weight, - self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, - }] + return [ + { + 0: llama_model.embed_tokens.weight, + self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight, + } + ] return [] @@ -247,12 +249,13 @@ def openmoe_model_forward( logger = logging.get_logger(__name__) - output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) use_cache = use_cache if 
use_cache is not None else self.config.use_cache - return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if stage_manager.is_first_stage(): @@ -320,7 +323,8 @@ def openmoe_model_forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False # decoder layers @@ -333,12 +337,11 @@ def openmoe_model_forward( if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = (past_key_values[idx] if past_key_values is not None else None) + past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, None) @@ -384,14 +387,16 @@ def custom_forward(*inputs): router_z_loss = past_router_z_loss + router_z_loss if stage_manager.is_last_stage(): - return tuple([ - hidden_states, - next_cache, - all_hidden_states, - all_self_attns, - router_aux_loss, - router_z_loss, - ]) + return tuple( + [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + router_aux_loss, + router_z_loss, + ] + ) # always return dict for imediate stage return { "hidden_states": hidden_states, @@ -445,10 +450,11 @@ def llama_for_causal_lm_forward( "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." ```""" logger = logging.get_logger(__name__) - output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) - return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future. 
if output_attentions: @@ -504,7 +510,6 @@ def llama_for_causal_lm_forward( if chunk_head == True: def create_custom_forward(module): - def custom_forward(*inputs): logits = module(inputs[0]) logits = logits.float() @@ -522,8 +527,8 @@ def custom_forward(*inputs): for batch_idx in range(hidden_states.shape[0]): loss = loss + torch.utils.checkpoint.checkpoint( create_custom_forward(self.lm_head), - hidden_states[batch_idx:batch_idx + 1, :], - labels[batch_idx:batch_idx + 1, :], + hidden_states[batch_idx : batch_idx + 1, :], + labels[batch_idx : batch_idx + 1, :], ) logits = None else: diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index 4730642705ff..9f801e0cc732 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -49,9 +49,9 @@ def data_gen_for_casual_lm(): loss_fn_for_seq_classification = lambda output: output["logits"].mean() config = LlamaConfig( - num_hidden_layers=4, - hidden_size=128, - intermediate_size=256, + num_hidden_layers=8, + hidden_size=32, + intermediate_size=64, num_attention_heads=4, max_position_embeddings=128, num_labels=16, diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py index 4ba67225f271..1b7b0073f62e 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -1,4 +1,23 @@ +import random + +from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.policies.t5 import T5BasePolicy +from colossalai.shardformer.shard.shard_config import ShardConfig + + +class _ShardConfig(ShardConfig): + def __post_init__(self): + pass + + +class _PipelineStageManager(PipelineStageManager): + def __init__(self): + self.is_interleave = False + self.num_layers_per_stage = None + + @property + def num_stages(self): + return random.randint(5, 10) def test_t5_pipeline_distribution(): @@ -10,7 +29,10 @@ def test_t5_pipeline_distribution(): "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], } + stage_manager = _PipelineStageManager() + shard_config = _ShardConfig(pipeline_stage_manager=stage_manager) policy = T5BasePolicy() + policy.set_shard_config(shard_config) for i in range(num_test_cases): _, decoder_starting_stage = policy.distribute_t5_layers( test_dict["num_encoder_layers"][i], @@ -35,7 +57,10 @@ def test_t5_pipeline_layers(): } for i in range(num_test_cases): + stage_manager = _PipelineStageManager() + shard_config = _ShardConfig(pipeline_stage_manager=stage_manager) policy = T5BasePolicy() + policy.set_shard_config(shard_config) layers_per_stage, decoder_starting_stage = policy.distribute_t5_layers( test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py index 0500e46e890a..9f8c1ad32d23 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -1,4 +1,23 @@ +import random + +from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer.policies.whisper import WhisperPolicy +from colossalai.shardformer.shard.shard_config import ShardConfig + + +class _ShardConfig(ShardConfig): + def __post_init__(self): + pass + + +class 
_PipelineStageManager(PipelineStageManager): + def __init__(self): + self.is_interleave = False + self.num_layers_per_stage = None + + @property + def num_stages(self): + return random.randint(5, 10) def test_whisper_pipeline_distribution(): @@ -10,7 +29,10 @@ def test_whisper_pipeline_distribution(): "decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2], } + stage_manager = _PipelineStageManager() + shard_config = _ShardConfig(pipeline_stage_manager=stage_manager) policy = WhisperPolicy() + policy.set_shard_config(shard_config) for i in range(num_test_cases): _, decoder_starting_stage = policy.distribute_whisper_layers( test_dict["num_encoder_layers"][i], @@ -34,7 +56,10 @@ def test_whisper_pipeline_layers(): ], } + stage_manager = _PipelineStageManager() + shard_config = _ShardConfig(pipeline_stage_manager=stage_manager) policy = WhisperPolicy() + policy.set_shard_config(shard_config) for i in range(num_test_cases): layers_per_stage, decoder_starting_stage = policy.distribute_whisper_layers( test_dict["num_encoder_layers"][i], diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 126ff23a9f25..55858cbd4960 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -5,6 +5,7 @@ import colossalai from colossalai.logging import disable_existing_loggers +from colossalai.shardformer import PipelineGradientCheckpointConfig from colossalai.shardformer.layer.utils import Randomizer from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn @@ -24,9 +25,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config): + enable_gradient_checkpointing = test_config.pop("enable_gradient_checkpointing", False) org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( model_fn, loss_fn, test_config ) + if enable_gradient_checkpointing: + org_model.gradient_checkpointing_enable() + sharded_model.unwrap().gradient_checkpointing_enable() org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster @@ -101,6 +106,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": True, "precision": "fp16", "initial_scale": 1, + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5), }, { "tp_size": 1, @@ -108,6 +115,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "num_microbatches": 4, "use_lazy_init": False, "precision": "fp32", + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig( + num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0] + ), }, { "tp_size": 4, @@ -189,6 +200,13 @@ def run_llama_test(test_config): "precision": "fp16", "zero_stage": 1, "initial_scale": 1, + "enable_gradient_checkpointing": True, + "gradient_checkpoint_config": PipelineGradientCheckpointConfig( + num_stages=2, + num_model_chunks=2, + num_model_layers=8, + num_ckpt_layers_per_stage=[0, 1, 2, 2], + ), }, ], )
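
Note: for readers of this patch, the following is a minimal standalone sketch of the layer-distribution logic that the PR centralizes on `PipelineStageManager` (the function names and demo values are illustrative stand-ins, not the shipped class methods). It mirrors the even split with the remainder placed on the middle stages, and the cumulative-sum indexing used for interleaved model chunks.

```python
from typing import List, Tuple, Union

import numpy as np


def distribute_layers(num_layers: int, num_stages: int, num_model_chunks: int = 1) -> List[int]:
    """Spread layers as evenly as possible over num_stages * num_model_chunks chunks,
    assigning the leftover layers to the middle chunks."""
    chunks = num_stages * num_model_chunks
    quotient, remainder = divmod(num_layers, chunks)
    layers_per_stage = [quotient] * chunks
    if remainder > 0:
        start = chunks // 2 - remainder // 2
        for i in range(start, start + remainder):
            layers_per_stage[i] += 1
    return layers_per_stage


def get_stage_index(
    layers_per_stage: List[int], stage: int, num_stages: int, num_model_chunks: int = 1
) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
    """Return the (start, end) layer slice for `stage`, one slice per model chunk."""
    accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
    indices = [
        (int(accumulated[stage + chunk * num_stages]), int(accumulated[stage + chunk * num_stages + 1]))
        for chunk in range(num_model_chunks)
    ]
    return indices[0] if num_model_chunks == 1 else indices


if __name__ == "__main__":
    # 8 layers over 2 stages with 2 chunks each (interleaved schedule)
    per_stage = distribute_layers(8, num_stages=2, num_model_chunks=2)
    print(per_stage)  # [2, 2, 2, 2]
    print(get_stage_index(per_stage, stage=0, num_stages=2, num_model_chunks=2))  # [(0, 2), (4, 6)]
```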
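
The T5 and Whisper policies keep their own encoder/decoder split and now delegate the per-side distribution to the stage manager. The sketch below (a simplified stand-in, not the policy code) shows only the split heuristic: pick the number of encoder stages that minimizes the imbalance between layers-per-encoder-stage and layers-per-decoder-stage.

```python
import numpy as np


def split_encoder_decoder_stages(num_encoder_layers: int, num_decoder_layers: int, num_stages: int) -> int:
    """Return the number of stages assigned to the encoder; the rest go to the decoder."""

    def objective(num_encoder_stages: int) -> float:
        return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))

    return int(np.argmin([objective(i) for i in range(1, num_stages)])) + 1


if __name__ == "__main__":
    # e.g. 12 encoder + 4 decoder layers over 4 stages -> 3 encoder stages, 1 decoder stage
    print(split_encoder_decoder_stages(12, 4, 4))  # 3
```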
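
The new `PipelineGradientCheckpointConfig` decides how many layers of each stage's slice are recomputed. A simplified re-implementation of that decision (not the shipped dataclass; assertions and sanity checks are omitted) is shown below: an explicit per-stage list takes precedence over the global `gradient_checkpointing_ratio`, and interleaved chunks index into the list by `stage + model_chunk_id * num_stages`.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class CkptConfigSketch:
    gradient_checkpointing_ratio: Optional[float] = None
    num_stages: Optional[int] = None
    num_model_chunks: Optional[int] = None
    num_ckpt_layers_per_stage: Optional[List[int]] = None

    def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int:
        if self.num_ckpt_layers_per_stage is not None:
            # customized path: look up the entry for this (stage, model chunk) pair
            return self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages]
        if self.gradient_checkpointing_ratio is not None:
            # ratio path: checkpoint the first `ratio * num_layers` layers of the stage
            return int(self.gradient_checkpointing_ratio * num_layers)
        raise RuntimeError("no checkpointing information provided")


if __name__ == "__main__":
    # matches one of the new test configs: 8 layers, 2 stages, 1 chunk, [4, 0] checkpointed layers
    cfg = CkptConfigSketch(num_stages=2, num_model_chunks=1, num_ckpt_layers_per_stage=[4, 0])
    print(cfg.get_num_ckpt_layers(stage=0, num_layers=4))  # 4 -> every layer on stage 0 is recomputed
    print(cfg.get_num_ckpt_layers(stage=1, num_layers=4))  # 0 -> no recomputation on the last stage
```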
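
Inside `llama_model_forward`, the stage slice is then walked once and only the first `num_ckpt_layers` layers of that slice run under activation checkpointing. The toy sketch below reproduces just that gating with `nn.Linear` stand-ins for decoder layers (it assumes a recent PyTorch for the `use_reentrant` keyword).

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def run_stage(layers: nn.ModuleList, start_idx: int, end_idx: int, num_ckpt_layers: int, hidden_states: torch.Tensor):
    """Run one stage's [start_idx, end_idx) slice, checkpointing only its first num_ckpt_layers layers."""
    for idx, layer in enumerate(layers[start_idx:end_idx], start=start_idx):
        if idx - start_idx < num_ckpt_layers:
            hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)  # recomputed in backward
        else:
            hidden_states = layer(hidden_states)  # activations kept in memory
    return hidden_states


if __name__ == "__main__":
    layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(8)])
    x = torch.randn(2, 16, requires_grad=True)
    # stage 0 of the [4, 0] example above: layers 0..3, all four checkpointed
    out = run_stage(layers, start_idx=0, end_idx=4, num_ckpt_layers=4, hidden_states=x)
    out.sum().backward()
    print(x.grad.shape)  # torch.Size([2, 16])
```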
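
Finally, a hedged usage sketch of how the new option is wired through the booster plugin, following the updated `test_shard_llama.py` cases. It is not runnable as-is: it assumes a `colossalai.launch` distributed context and a model/optimizer setup that are omitted here, and the checkpoint plan only takes effect once gradient checkpointing is enabled on the model.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer import PipelineGradientCheckpointConfig

# Checkpoint 4 of the 8 layers on stage 0 and none on stage 1 (values are illustrative).
ckpt_config = PipelineGradientCheckpointConfig(
    num_stages=2,
    num_model_chunks=1,
    num_model_layers=8,
    num_ckpt_layers_per_stage=[4, 0],
)

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    num_microbatches=4,
    gradient_checkpoint_config=ckpt_config,
)
booster = Booster(plugin=plugin)

# model.gradient_checkpointing_enable() must still be called on the (wrapped) model,
# as the updated llama test does, before the per-stage plan has any effect.
```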