Commit
style: polish code
CWHer committed Mar 25, 2024
1 parent ca5b811 commit 09d6945
Showing 2 changed files with 45 additions and 40 deletions.
examples/language/openmoe/model/openmoe_policy.py (77 changes: 41 additions & 36 deletions)
@@ -1,4 +1,3 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Optional, Union

@@ -21,7 +20,6 @@


class OpenMoePolicy(Policy):

def config_sanity_check(self):
pass

@@ -43,7 +41,8 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
raise NotImplementedError(
"openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
"openmoe doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
)

if self.shard_config.enable_tensor_parallelism:
raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.")
@@ -143,7 +142,6 @@ def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:


class OpenMoeModelPolicy(OpenMoePolicy):

def __init__(self) -> None:
super().__init__()

@@ -169,21 +167,21 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:


class OpenMoeForCausalLMPolicy(OpenMoePolicy):

def module_policy(self):
policy = super().module_policy()

if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm
new_item = {
OpenMoeForCausalLM:
ModulePolicyDescription(sub_module_replacement=[
OpenMoeForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
)
])
]
)
}
policy.update(new_item)

@@ -208,13 +206,17 @@ def get_held_layers(self) -> List[Module]:
def get_shared_params(self) -> List[Dict[int, Tensor]]:
llama_model = self.model.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1):
if (
id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
and self.pipeline_stage_manager.num_stages > 1
):
# tie weights
return [{
0: llama_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}]
return [
{
0: llama_model.embed_tokens.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
}
]
return []


@@ -247,12 +249,13 @@ def openmoe_model_forward(

logger = logging.get_logger(__name__)

output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache

return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
@@ -320,7 +323,8 @@ def openmoe_model_forward(
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False

# decoder layers
@@ -333,12 +337,11 @@
if output_hidden_states:
all_hidden_states += (hidden_states,)

past_key_value = (past_key_values[idx] if past_key_values is not None else None)
past_key_value = past_key_values[idx] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:

def create_custom_forward(module):

def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
@@ -384,14 +387,16 @@ def custom_forward(*inputs):
router_z_loss = past_router_z_loss + router_z_loss

if stage_manager.is_last_stage():
return tuple([
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
router_aux_loss,
router_z_loss,
])
return tuple(
[
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
router_aux_loss,
router_z_loss,
]
)
# always return dict for intermediate stage
return {
"hidden_states": hidden_states,
@@ -445,10 +450,11 @@ def llama_for_causal_lm_forward(
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```"""
logger = logging.get_logger(__name__)
output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
if output_attentions:
@@ -504,7 +510,6 @@ def llama_for_causal_lm_forward(
if chunk_head == True:

def create_custom_forward(module):

def custom_forward(*inputs):
logits = module(inputs[0])
logits = logits.float()
@@ -522,8 +527,8 @@ def custom_forward(*inputs):
for batch_idx in range(hidden_states.shape[0]):
loss = loss + torch.utils.checkpoint.checkpoint(
create_custom_forward(self.lm_head),
hidden_states[batch_idx:batch_idx + 1, :],
labels[batch_idx:batch_idx + 1, :],
hidden_states[batch_idx : batch_idx + 1, :],
labels[batch_idx : batch_idx + 1, :],
)
logits = None
else:
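The reformatted get_shared_params hunk above keeps an existing tied-weight check: when the input embedding and the LM head reference the same tensor, the first and last pipeline stages must report it as a shared parameter so its gradients stay synchronized. Below is a minimal, self-contained sketch of that check, not part of the commit; embed, lm_head, and num_stages are made-up names for illustration.

# Minimal illustration (not from the commit) of the tied-weight check above:
# if the embedding and LM head share one tensor, it is reported as a parameter
# shared between the first and the last pipeline stage.
import torch.nn as nn

embed = nn.Embedding(32, 8)             # hypothetical vocab size 32, hidden size 8
lm_head = nn.Linear(8, 32, bias=False)
lm_head.weight = embed.weight           # tie weights, as many causal LMs do

num_stages = 4                          # assumed pipeline depth for this example
if id(embed.weight) == id(lm_head.weight) and num_stages > 1:
    shared_params = [{0: embed.weight, num_stages - 1: lm_head.weight}]
else:
    shared_params = []

print(shared_params[0].keys() if shared_params else "no tied weights")  # dict_keys([0, 3])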
tests/test_booster/test_plugin/test_3d_plugin.py (8 changes: 4 additions & 4 deletions)
@@ -83,7 +83,7 @@ def _criterion(outputs, inputs):

@parameterize("init_method", ["none", "lazy"])
def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
"""check gemini plugin over model zoo
"""check hybrid plugin over model zoo
Args:
early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
@@ -260,7 +260,7 @@ def run_grad_acc_test(test_args):
origin_model, origin_optimizer, dataloader=dataloader
)
for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()):
assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)


def run_dist(rank, world_size, port, early_stop: bool = True):
@@ -271,9 +271,9 @@ def run_dist(rank, world_size, port, early_stop: bool = True):


@rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True):
def test_3d_plugin(early_stop: bool = True):
spawn(run_dist, 4, early_stop=early_stop)


if __name__ == "__main__":
test_gemini_plugin(early_stop=False)
test_3d_plugin(early_stop=False)
