[shardformer] fix pipeline forward error if custom layer distribution is used #5189

Merged · 9 commits · Mar 27, 2024
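Why this fixes the error: `distribute_layers` and `get_stage_index` were `@staticmethod`s on the base `Policy`, and several call sites invoked them as `Policy.distribute_layers(...)` / `Policy.get_stage_index(...)`. With that static dispatch, a subclass that overrides the helpers to customize how layers map onto pipeline stages is silently bypassed: call sites that already went through `self` (such as `get_held_layers`) and the patched pipeline forward could disagree about which layers a stage owns, which is the mismatch the title refers to. This PR makes both helpers instance methods and routes every call through `self`. The sketch below shows the kind of override this enables; the class name, the import, and the uneven split are illustrative assumptions, not part of this diff.

```python
from typing import List

# Assumed import path for the stock Llama policy; substitute the policy you actually extend.
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy


class FrontLightLlamaPolicy(LlamaForCausalLMPolicy):
    """Hypothetical custom policy: give the first pipeline stage one fewer
    transformer layer so it has headroom for the embedding layer."""

    def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
        # Start from the default split computed by the base policy ...
        layers_per_stage = super().distribute_layers(num_layers, num_stages)
        # ... then shift one layer from the first stage to the last stage.
        if num_stages > 1 and layers_per_stage[0] > 1:
            layers_per_stage[0] -= 1
            layers_per_stage[-1] += 1
        return layers_per_stage
```

After this change, every call site in the diff below resolves to such an override; before it, only the `self.`-based call sites did.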
@@ -110,7 +110,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
                 module = self.model.model

             layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
-            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
             method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
             self.append_or_create_method_replacement(
                 description=method_replacement, policy=policy, target_key=model_cls
7 changes: 3 additions & 4 deletions colossalai/shardformer/policies/base_policy.py
@@ -197,8 +197,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
         """
         return []

-    @staticmethod
-    def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
+    def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
         """Divide layers into stages"""
         quotient = num_layers // num_stages
         remainder = num_layers % num_stages
@@ -213,8 +212,8 @@ def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
             layers_per_stage[i] += 1
         return layers_per_stage

-    @staticmethod
     def get_stage_index(
+        self,
         layers_per_stage: List[int],
         stage: int,
         num_model_chunks: int = 1,
@@ -242,4 +241,4 @@ def get_stage_index(
             end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
             stage_indices.append([start_idx, end_idx])

-        return stage_indices[0] if num_model_chunks == 1 else stage_indices
+        return stage_indices[0] if num_model_chunks == 1 else stage_indices
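For readers skimming the diff, a standalone restatement of what these two helpers compute together may help: `distribute_layers` yields per-stage layer counts, and `get_stage_index` turns their running sum into `[start, end)` index pairs, one pair per model chunk under interleaved scheduling. The snippet below is an illustration only (the real implementations are the instance methods above); the function name `stage_slices` and the example numbers are made up for the demo.

```python
from itertools import accumulate
from typing import List, Union


def stage_slices(
    layers_per_stage: List[int],
    stage: int,
    num_model_chunks: int = 1,
    num_stages: int = 0,
) -> Union[List[int], List[List[int]]]:
    """Illustrative restatement of the accumulation logic in Policy.get_stage_index."""
    num_stages = num_stages or len(layers_per_stage)
    # acc[i] is the global index of the first layer owned by virtual stage i.
    acc = [0] + list(accumulate(layers_per_stage))
    slices = []
    for chunk in range(num_model_chunks):
        start = acc[stage + chunk * num_stages]
        end = acc[stage + chunk * num_stages + 1]
        slices.append([start, end])
    return slices[0] if num_model_chunks == 1 else slices


# Plain 1F1B: 8 layers over 4 stages -> stage 1 runs layers [2, 4)
print(stage_slices([2, 2, 2, 2], stage=1))  # [2, 4]
# Interleaved: 8 layers, 2 stages x 2 chunks -> stage 0 runs [0, 2) and [4, 6)
print(stage_slices([2, 2, 2, 2], stage=0, num_model_chunks=2, num_stages=2))  # [[0, 2], [4, 6]]
```

The resulting `[start, end)` pair is what each per-model `set_pipeline_forward` below binds into the replaced forward via `functools.partial`.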
67 changes: 49 additions & 18 deletions colossalai/shardformer/policies/bert.py
@@ -84,17 +84,26 @@ def module_policy(self):
                     SubModuleReplacementDescription(
                         suffix="attention.self.query",
                         target_module=col_nn.Linear1D_Col,
-                        kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+                        kwargs={
+                            "seq_parallel": use_sequence_parallel,
+                            "overlap": overlap,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.self.key",
                         target_module=col_nn.Linear1D_Col,
-                        kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+                        kwargs={
+                            "seq_parallel": use_sequence_parallel,
+                            "overlap": overlap,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.self.value",
                         target_module=col_nn.Linear1D_Col,
-                        kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+                        kwargs={
+                            "seq_parallel": use_sequence_parallel,
+                            "overlap": overlap,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.self.dropout",
@@ -112,7 +121,10 @@ def module_policy(self):
                     SubModuleReplacementDescription(
                         suffix="intermediate.dense",
                         target_module=col_nn.Linear1D_Col,
-                        kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
+                        kwargs={
+                            "seq_parallel": use_sequence_parallel,
+                            "overlap": overlap,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="output.dense",
@@ -214,7 +226,9 @@ def add_lm_head_policy(self, base_policy):
         if self.shard_config.enable_tensor_parallelism:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
-                    suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
+                    suffix="decoder",
+                    target_module=col_nn.Linear1D_Col,
+                    kwargs={"gather_output": True},
                 ),
                 policy=base_policy,
                 target_key=BertLMPredictionHead,
@@ -241,7 +255,9 @@ def add_lm_prediction_policy(self, base_policy):
             "_load_from_state_dict": col_nn.ParallelModule._load_from_state_dict,
         }
         self.append_or_create_method_replacement(
-            description=method_replacement, policy=base_policy, target_key=BertLMPredictionHead
+            description=method_replacement,
+            policy=base_policy,
+            target_key=BertLMPredictionHead,
         )
         return base_policy

@@ -264,24 +280,32 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli

         if stage_manager.is_interleave:
             layers_per_stage = self.distribute_layers(
-                len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks
+                len(module.encoder.layer),
+                stage_manager.num_stages * stage_manager.num_model_chunks,
             )
-            stage_manager.stage_indices = Policy.get_stage_index(
+            stage_manager.stage_indices = self.get_stage_index(
                 layers_per_stage,
                 stage_manager.stage,
                 num_model_chunks=stage_manager.num_model_chunks,
                 num_stages=stage_manager.num_stages,
             )
             method_replacement = {
-                "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
+                "forward": partial(
+                    new_forward,
+                    stage_manager=stage_manager,
+                    shard_config=self.shard_config,
+                )
             }

         else:
-            layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
-            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
             method_replacement = {
                 "forward": partial(
-                    new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+                    new_forward,
+                    stage_manager=stage_manager,
+                    stage_index=stage_index,
+                    shard_config=self.shard_config,
                 )
             }

@@ -301,9 +325,10 @@ def get_held_layers(self) -> List[Module]:
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
             layers_per_stage = self.distribute_layers(
-                len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks
+                len(module.encoder.layer),
+                stage_manager.num_stages * stage_manager.num_model_chunks,
             )
-            stage_indices = Policy.get_stage_index(
+            stage_indices = self.get_stage_index(
                 layers_per_stage,
                 stage_manager.stage,
                 num_model_chunks=stage_manager.num_model_chunks,
@@ -320,7 +345,7 @@ def get_held_layers(self) -> List[Module]:
             layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
             if stage_manager.is_first_stage():
                 held_layers.append(module.embeddings)
-            start_idx, end_idx = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
             held_layers.extend(module.encoder.layer[start_idx:end_idx])
             if stage_manager.is_last_stage():
                 held_layers.append(module.pooler)
@@ -336,7 +361,9 @@ def module_policy(self):

         if self.pipeline_stage_manager:
             self.set_pipeline_forward(
-                model_cls=BertModel, new_forward=BertPipelineForwards.bert_model_forward, policy=policy
+                model_cls=BertModel,
+                new_forward=BertPipelineForwards.bert_model_forward,
+                policy=policy,
             )
         return policy

@@ -399,7 +426,9 @@ def module_policy(self):

         if self.pipeline_stage_manager:
             self.set_pipeline_forward(
-                model_cls=BertLMHeadModel, new_forward=BertPipelineForwards.bert_lm_head_model_forward, policy=policy
+                model_cls=BertLMHeadModel,
+                new_forward=BertPipelineForwards.bert_lm_head_model_forward,
+                policy=policy,
             )
         return policy

@@ -437,7 +466,9 @@ def module_policy(self):

         if self.pipeline_stage_manager:
             self.set_pipeline_forward(
-                model_cls=BertForMaskedLM, new_forward=BertPipelineForwards.bert_for_masked_lm_forward, policy=policy
+                model_cls=BertForMaskedLM,
+                new_forward=BertPipelineForwards.bert_for_masked_lm_forward,
+                policy=policy,
             )
         return policy

4 changes: 2 additions & 2 deletions colossalai/shardformer/policies/bloom.py
@@ -203,8 +203,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
             else:
                 module = self.model.transformer

-            layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
-            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
             method_replacement = {
                 "forward": partial(
                     new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
4 changes: 2 additions & 2 deletions colossalai/shardformer/policies/chatglm2.py
@@ -204,8 +204,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
             else:
                 module = self.model.transformer

-            layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages)
-            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
+            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
             method_replacement = {
                 "forward": partial(
                     new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
4 changes: 2 additions & 2 deletions colossalai/shardformer/policies/falcon.py
@@ -161,8 +161,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
             else:
                 module = self.model.transformer

-            layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
-            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
             method_replacement = {
                 "forward": partial(
                     new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
8 changes: 4 additions & 4 deletions colossalai/shardformer/policies/gpt2.py
@@ -188,7 +188,7 @@ def get_held_layers(self) -> List[nn.Module]:
             layers_per_stage = self.distribute_layers(
                 len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
             )
-            stage_indices = Policy.get_stage_index(
+            stage_indices = self.get_stage_index(
                 layers_per_stage,
                 stage_manager.stage,
                 num_model_chunks=stage_manager.num_model_chunks,
@@ -229,7 +229,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
             layers_per_stage = self.distribute_layers(
                 len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
             )
-            stage_manager.stage_indices = Policy.get_stage_index(
+            stage_manager.stage_indices = self.get_stage_index(
                 layers_per_stage,
                 stage_manager.stage,
                 num_model_chunks=stage_manager.num_model_chunks,
@@ -243,8 +243,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
                 )
             }
         else:
-            layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
-            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
             method_replacement = {
                 "forward": partial(
                     new_forward,
4 changes: 2 additions & 2 deletions colossalai/shardformer/policies/gptj.py
@@ -200,8 +200,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
             else:
                 module = self.model.transformer

-            layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
-            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
             method_replacement = {
                 "forward": partial(
                     new_forward,
8 changes: 4 additions & 4 deletions colossalai/shardformer/policies/llama.py
@@ -167,7 +167,7 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
             layers_per_stage = self.distribute_layers(
                 len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
             )
-            stage_manager.stage_indices = Policy.get_stage_index(
+            stage_manager.stage_indices = self.get_stage_index(
                 layers_per_stage,
                 stage_manager.stage,
                 num_model_chunks=stage_manager.num_model_chunks,
@@ -178,8 +178,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
             }

         else:
-            layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
-            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
             method_replacement = {
                 "forward": partial(
                     new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
@@ -207,7 +207,7 @@ def get_held_layers(self) -> List[Module]:
             layers_per_stage = self.distribute_layers(
                 len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
             )
-            stage_indices = Policy.get_stage_index(
+            stage_indices = self.get_stage_index(
                 layers_per_stage,
                 stage_manager.stage,
                 num_model_chunks=stage_manager.num_model_chunks,
4 changes: 2 additions & 2 deletions colossalai/shardformer/policies/opt.py
@@ -208,8 +208,8 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
             else:
                 module = self.model.model.decoder

-            layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
-            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
             method_replacement = {
                 "forward": partial(
                     new_forward,