[shardformer, pipeline] add gradient_checkpointing_ratio and heterogeneous shard policy for llama (#5508)

* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`

* feat: apply `GradientCheckpointConfig` to policy and llama_forward

* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager

* fix: add optional args for `distribute_layer` and `get_stage_index`

* fix: fix changed API calls

* test: update llama tests

* style: polish `GradientCheckpointConfig`

* fix: fix pipeline utils tests
CWHer committed Apr 1, 2024
1 parent df5e9c5 commit e614aa3
Showing 28 changed files with 396 additions and 213 deletions.
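As a usage sketch (not part of this diff): the plugin below is constructed with the new `gradient_checkpoint_config` keyword. Only that keyword and the exported class names are confirmed by this commit; the `gradient_checkpointing_ratio` field is an assumption inferred from the commit title, and the parallel sizes are arbitrary placeholders.

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer import PipelineGradientCheckpointConfig

# distributed setup (e.g. colossalai.launch_from_torch) is omitted here

# assumed field name, based on the commit title: checkpoint roughly half of each stage's layers
ckpt_config = PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5)

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    gradient_checkpoint_config=ckpt_config,  # new keyword added by this commit
)
booster = Booster(plugin=plugin)
# model.gradient_checkpointing_enable() must still be called on the HF model;
# the llama forward changed below only checkpoints when self.gradient_checkpointing is True.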
@@ -109,8 +109,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
else:
module = self.model.model

layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
@@ -129,10 +129,10 @@ def get_held_layers(self) -> List[Module]:
stage_manager = self.pipeline_stage_manager

held_layers = []
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
5 changes: 4 additions & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -26,7 +26,7 @@
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.d_tensor.api import is_distributed_tensor
@@ -930,6 +930,7 @@ class HybridParallelPlugin(PipelinePluginBase):
custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
"""

@@ -969,6 +970,7 @@ def __init__(
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
) -> None:
super().__init__()
@@ -1043,6 +1045,7 @@ def __init__(
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
gradient_checkpoint_config=gradient_checkpoint_config,
)
self.amp_config = dict(
initial_scale=initial_scale,
4 changes: 2 additions & 2 deletions colossalai/inference/engine/policies/bloom.py
@@ -114,12 +114,12 @@ def get_held_layers(self) -> List[Module]:
stage_manager = self.pipeline_stage_manager

held_layers = []
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
layers_per_stage = stage_manager.distribute_layers(len(module.h))
if stage_manager.is_first_stage():
held_layers.append(module.word_embeddings)
held_layers.append(module.word_embeddings_layernorm)
held_layers.append(self.model.lm_head)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)
4 changes: 2 additions & 2 deletions colossalai/inference/engine/policies/chatglm2.py
@@ -69,11 +69,11 @@ def get_held_layers(self) -> List[nn.Module]:
stage_manager = self.pipeline_stage_manager

held_layers = []
layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
layers_per_stage = stage_manager.distribute_layers(module.num_layers)
if stage_manager.is_first_stage():
held_layers.append(module.embedding)
held_layers.append(module.output_layer)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.encoder.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
if module.encoder.post_layer_norm:
4 changes: 2 additions & 2 deletions colossalai/inference/engine/policies/llama.py
@@ -194,11 +194,11 @@ def get_held_layers(self) -> List[Module]:
stage_manager = self.pipeline_stage_manager

held_layers = []
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
if stage_manager.is_first_stage():
held_layers.append(module.embed_tokens)
held_layers.append(self.model.lm_head)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.layers[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.norm)
87 changes: 86 additions & 1 deletion colossalai/pipeline/stage_manager.py
@@ -1,6 +1,7 @@
import contextlib
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch.distributed as dist
from torch.distributed import ProcessGroup

@@ -29,6 +30,8 @@ def __init__(
) -> None:
assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"

self.num_layers_per_stage = None

self.pg_mesh = pg_mesh
self.pipeline_axis = pipeline_axis
self.prev_rank: Optional[Tuple[int, ...]] = None
@@ -69,6 +72,88 @@ def __init__(
# for shardformer, hold model chunk id
self.model_chunk_id: Optional[int] = None

@property
def control_distribute_layers(self) -> bool:
return self.num_layers_per_stage is not None

def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None:
"""Set the distribution configuration.
        This allows the user to customize the number of layers for each stage.
Args:
num_model_layers (int): Number of layers in the model.
num_layers_per_stage (List[int]): Number of layers for each stage.
"""
assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage])
assert sum(num_layers_per_stage) == num_model_layers
assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1)
self.num_model_layers = num_model_layers
self.num_layers_per_stage = num_layers_per_stage

def distribute_layers(
self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
) -> List[int]:
"""Divide layers into stages"""
num_stages = self.num_stages if num_stages is None else num_stages
num_model_chunks = (
(self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
)

if self.control_distribute_layers:
assert num_layers == self.num_model_layers
return self.num_layers_per_stage

else:
quotient = num_layers // (num_stages * num_model_chunks)
remainder = num_layers % (num_stages * num_model_chunks)

# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages * num_model_chunks

# deal with the rest layers
if remainder > 0:
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage

def get_stage_index(
self,
layers_per_stage: List[int],
stage: Optional[int] = None,
num_model_chunks: Optional[int] = None,
num_stages: Optional[int] = None,
) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
"""
Get the start index and end index of layers for each stage.
Args:
layers_per_stage (List[int]): number of layers for each stage
stage (int): the stage index
num_stages (int): number of stages
num_model_chunks (int): number of model chunks
Returns:
- Tuple[int, int]: the start index and end index of this stage
- List[Tuple[int, int]]: the start index and end index of this stage for each model chunk
"""
stage = self.stage if stage is None else stage
num_model_chunks = (
(self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
)
num_stages = self.num_stages if num_stages is None else num_stages

num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)

stage_indices = []
for model_chunk in range(num_model_chunks):
start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
stage_indices.append([start_idx, end_idx])

return stage_indices[0] if num_model_chunks == 1 else stage_indices

def is_first_stage(self, ignore_chunk: bool = False) -> bool:
"""Is the current stage the first stage.
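The distribution logic added to PipelineStageManager above is easy to check in isolation. The sketch below mirrors distribute_layers and get_stage_index as plain functions (the real methods also honor a num_layers_per_stage override installed via set_distribution_config); the layer and stage counts are just example numbers.

import numpy as np

def distribute_layers(num_layers, num_stages, num_model_chunks=1):
    # Mirrors PipelineStageManager.distribute_layers: even split, remainder spread over the middle slices.
    quotient, remainder = divmod(num_layers, num_stages * num_model_chunks)
    layers_per_stage = [quotient] * num_stages * num_model_chunks
    start = (num_stages * num_model_chunks) // 2 - remainder // 2
    for i in range(start, start + remainder):
        layers_per_stage[i] += 1
    return layers_per_stage

def get_stage_index(layers_per_stage, stage, num_model_chunks=1, num_stages=0):
    # Mirrors PipelineStageManager.get_stage_index: prefix sums give a [start, end) range per model chunk.
    acc = np.insert(np.cumsum(layers_per_stage), 0, 0)
    indices = [(int(acc[stage + c * num_stages]), int(acc[stage + c * num_stages + 1]))
               for c in range(num_model_chunks)]
    return indices[0] if num_model_chunks == 1 else indices

# 30 layers, 4 stages, 2 model chunks (interleaved): 8 slices of 3 or 4 layers.
per_stage = distribute_layers(30, 4, 2)
print(per_stage)  # [3, 4, 4, 4, 4, 4, 4, 3]
print(get_stage_index(per_stage, stage=1, num_model_chunks=2, num_stages=4))  # [(3, 7), (19, 23)]

When set_distribution_config has been called, distribute_layers simply returns the user-supplied num_layers_per_stage (after checking that it sums to num_model_layers), which is what enables the heterogeneous per-stage split mentioned in the commit title.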
2 changes: 1 addition & 1 deletion colossalai/shardformer/__init__.py
@@ -1 +1 @@
from .shard import ShardConfig, ShardFormer
from .shard import GradientCheckpointConfig, ModelSharder, PipelineGradientCheckpointConfig, ShardConfig, ShardFormer
14 changes: 13 additions & 1 deletion colossalai/shardformer/modeling/llama.py
@@ -138,13 +138,25 @@ def llama_model_forward(
next_decoder_cache = () if use_cache else None

start_idx, end_idx = stage_index[0], stage_index[1]
num_ckpt_layers = 0
if self.gradient_checkpointing and self.training:
num_ckpt_layers = end_idx - start_idx
# TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
if shard_config.gradient_checkpoint_config is not None:
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
stage=stage_manager.stage,
num_layers=end_idx - start_idx,
model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
)
assert num_ckpt_layers <= end_idx - start_idx

for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)

past_key_value = past_key_values[idx] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
if idx - start_idx < num_ckpt_layers:

def create_custom_forward(module):
def custom_forward(*inputs):
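The change to llama_model_forward above checkpoints only the first num_ckpt_layers layers of each stage's slice; num_ckpt_layers defaults to the whole slice and can be reduced via gradient_checkpoint_config.get_num_ckpt_layers(...). Below is a simplified, self-contained sketch of that selection pattern, not the actual forward: dummy Linear layers stand in for decoder layers, and the shapes, counts, and threshold are arbitrary.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(6))
start_idx, end_idx = 2, 6   # this stage holds layers [2, 6)
num_ckpt_layers = 2         # e.g. from gradient_checkpoint_config.get_num_ckpt_layers(...)

hidden_states = torch.randn(4, 16, requires_grad=True)
for idx in range(start_idx, end_idx):
    layer = layers[idx]
    if idx - start_idx < num_ckpt_layers:
        # recompute activations of this layer in backward
        hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
    else:
        # remaining layers run normally and keep their activations
        hidden_states = layer(hidden_states)
hidden_states.sum().backward()

Layers past the threshold skip recomputation, which is how a checkpointing ratio below 1.0 trades extra activation memory for less recompute on selected stages or chunks.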
49 changes: 1 addition & 48 deletions colossalai/shardformer/policies/base_policy.py
@@ -2,9 +2,8 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
@@ -196,49 +195,3 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
"""
return []

def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
"""Divide layers into stages"""
quotient = num_layers // num_stages
remainder = num_layers % num_stages

# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages

# deal with the rest layers
if remainder > 0:
start_position = num_stages // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage

def get_stage_index(
self,
layers_per_stage: List[int],
stage: int,
num_model_chunks: int = 1,
num_stages: int = 0,
) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
"""
Get the start index and end index of layers for each stage.
Args:
layers_per_stage (List[int]): number of layers for each stage
stage (int): the stage index
num_stages (int): number of stages
num_model_chunks (int): number of model chunks
Returns:
- Tuple[int, int]: the start index and end index of this stage
- List[Tuple[int, int]]: the start index and end index of this stage for each model chunk
"""
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)

stage_indices = []
for model_chunk in range(num_model_chunks):
start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
stage_indices.append([start_idx, end_idx])

return stage_indices[0] if num_model_chunks == 1 else stage_indices
32 changes: 8 additions & 24 deletions colossalai/shardformer/policies/bert.py
@@ -279,16 +279,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
module = self.model.bert

if stage_manager.is_interleave:
layers_per_stage = self.distribute_layers(
len(module.encoder.layer),
stage_manager.num_stages * stage_manager.num_model_chunks,
)
stage_manager.stage_indices = self.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages,
)
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
new_forward,
@@ -298,8 +290,8 @@ def get_held_layers(self) -> List[Module]:
}

else:
layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
new_forward,
@@ -324,16 +316,8 @@ def get_held_layers(self) -> List[Module]:
held_layers = []
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = self.distribute_layers(
len(module.encoder.layer),
stage_manager.num_stages * stage_manager.num_model_chunks,
)
stage_indices = self.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages,
)
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
stage_indices = stage_manager.get_stage_index(layers_per_stage)
if stage_manager.is_first_stage(ignore_chunk=True):
held_layers.append(module.embeddings)
for start_idx, end_idx in stage_indices:
@@ -342,10 +326,10 @@ def get_held_layers(self) -> List[Module]:
held_layers.append(module.pooler)

else:
layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
if stage_manager.is_first_stage():
held_layers.append(module.embeddings)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.pooler)
8 changes: 4 additions & 4 deletions colossalai/shardformer/policies/bloom.py
@@ -203,8 +203,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
else:
module = self.model.transformer

layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
layers_per_stage = stage_manager.distribute_layers(len(module.h))
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
@@ -226,11 +226,11 @@ def get_held_layers(self) -> List[Module]:
stage_manager = self.pipeline_stage_manager

held_layers = []
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
layers_per_stage = stage_manager.distribute_layers(len(module.h))
if stage_manager.is_first_stage():
held_layers.append(module.word_embeddings)
held_layers.append(module.word_embeddings_layernorm)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
held_layers.extend(module.h[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.ln_f)