Skip to content

Commit

Permalink
Fix missing changes
Browse files Browse the repository at this point in the history
  • Loading branch information
insujang committed Jan 6, 2024
1 parent 10786ae commit fb2f2b2
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 19 deletions.
67 changes: 49 additions & 18 deletions colossalai/shardformer/policies/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,26 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
Expand All @@ -112,7 +121,10 @@ def module_policy(self):
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel, "overlap": overlap},
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
),
SubModuleReplacementDescription(
suffix="output.dense",
Expand Down Expand Up @@ -214,7 +226,9 @@ def add_lm_head_policy(self, base_policy):
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="decoder", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True},
),
policy=base_policy,
target_key=BertLMPredictionHead,
Expand All @@ -241,7 +255,9 @@ def add_lm_prediction_policy(self, base_policy):
"_load_from_state_dict": col_nn.ParallelModule._load_from_state_dict,
}
self.append_or_create_method_replacement(
description=method_replacement, policy=base_policy, target_key=BertLMPredictionHead
description=method_replacement,
policy=base_policy,
target_key=BertLMPredictionHead,
)
return base_policy

Expand All @@ -264,24 +280,32 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli

if stage_manager.is_interleave:
layers_per_stage = self.distribute_layers(
len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks
len(module.encoder.layer),
stage_manager.num_stages * stage_manager.num_model_chunks,
)
stage_manager.stage_indices = Policy.get_stage_index(
stage_manager.stage_indices = self.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
num_stages=stage_manager.num_stages,
)
method_replacement = {
"forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
"forward": partial(
new_forward,
stage_manager=stage_manager,
shard_config=self.shard_config,
)
}

else:
layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
layers_per_stage = self.distributed_layers(len(module.encoder.layer), stage_manager.num_stages)
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {
"forward": partial(
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config,
)
}

Expand All @@ -301,9 +325,10 @@ def get_held_layers(self) -> List[Module]:
if stage_manager.is_interleave:
assert stage_manager.num_model_chunks is not None
layers_per_stage = self.distribute_layers(
len(module.encoder.layer), stage_manager.num_stages * stage_manager.num_model_chunks
len(module.encoder.layer),
stage_manager.num_stages * stage_manager.num_model_chunks,
)
stage_indices = Policy.get_stage_index(
stage_indices = self.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
Expand All @@ -320,7 +345,7 @@ def get_held_layers(self) -> List[Module]:
layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
if stage_manager.is_first_stage():
held_layers.append(module.embeddings)
start_idx, end_idx = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
held_layers.extend(module.encoder.layer[start_idx:end_idx])
if stage_manager.is_last_stage():
held_layers.append(module.pooler)
Expand All @@ -336,7 +361,9 @@ def module_policy(self):

if self.pipeline_stage_manager:
self.set_pipeline_forward(
model_cls=BertModel, new_forward=BertPipelineForwards.bert_model_forward, policy=policy
model_cls=BertModel,
new_forward=BertPipelineForwards.bert_model_forward,
policy=policy,
)
return policy

Expand Down Expand Up @@ -399,7 +426,9 @@ def module_policy(self):

if self.pipeline_stage_manager:
self.set_pipeline_forward(
model_cls=BertLMHeadModel, new_forward=BertPipelineForwards.bert_lm_head_model_forward, policy=policy
model_cls=BertLMHeadModel,
new_forward=BertPipelineForwards.bert_lm_head_model_forward,
policy=policy,
)
return policy

Expand Down Expand Up @@ -437,7 +466,9 @@ def module_policy(self):

if self.pipeline_stage_manager:
self.set_pipeline_forward(
model_cls=BertForMaskedLM, new_forward=BertPipelineForwards.bert_for_masked_lm_forward, policy=policy
model_cls=BertForMaskedLM,
new_forward=BertPipelineForwards.bert_for_masked_lm_forward,
policy=policy,
)
return policy

Expand Down
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def get_held_layers(self) -> List[Module]:
len(module.layers),
stage_manager.num_stages * stage_manager.num_model_chunks,
)
stage_indices = Policy.get_stage_index(
stage_indices = self.get_stage_index(
layers_per_stage,
stage_manager.stage,
num_model_chunks=stage_manager.num_model_chunks,
Expand Down

0 comments on commit fb2f2b2

Please sign in to comment.