
[Tensor] add module handler for linear #1021

Merged: 7 commits into hpcaitech:main on May 26, 2022

Conversation

Wesley-Jzy (Contributor):

  1. add module spec function
  2. add ColoLinear as a demo
  3. add unit test

colossalai/tensor/module_utils.py
for param_name in param_names:
    param = module.get_parameter(param_name)
    if not isinstance(param, ColoParameter):
        print(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter .')
Contributor:

Move the error message into an exception instead of printing it.

Contributor (Author):

Done.
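
A minimal sketch of the resolved check, assuming the fix simply raises instead of printing; the exact exception type used in the PR is not shown in this thread:

for param_name in param_names:
    param = module.get_parameter(param_name)
    if not isinstance(param, ColoParameter):
        # Raising surfaces the misconfiguration immediately instead of
        # letting it scroll past in the logs.
        raise ValueError(
            f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')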

colossalai/tensor/module_utils.py
def _register_shard_params(self, params: List[str]):
    self._shard_params = params

def _register_allowed_patterns(self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], label='default'):
Contributor:

Use an enum class for label; a raw str invites typos.
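
A minimal sketch of what an enum-based label could look like, as the reviewer suggests; the member names below are assumptions, not the PR's actual code:

from enum import Enum

class PatternLabel(Enum):
    # Hypothetical members; the real set of labels is not shown in this thread.
    DEFAULT = 'default'
    ROW = 'row'
    COL = 'col'

# Passing PatternLabel.ROW instead of the bare string 'row' turns a typo into
# an AttributeError at call time rather than a silently unmatched pattern.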

tests/test_tensor/test_module_spec.py
@feifeibear (Contributor):

What is the purpose of adding the module spec function?
I see you still need to set the weight and bias specs. Does that make setting the param dist specs simpler?

@Wesley-Jzy (Contributor, Author):

> What is the purpose of adding the module spec function? I see you still need to set the weight and bias specs. Does that make setting the param dist specs simpler?

It sets all the params in a module at once and saves the allowed spec rules, so we avoid setting an unmatched spec on each param individually. It can also set up the whole model via init_colo_module() after all ops are registered in the init context.
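
For context, one plausible shape for that rule table, inferred only from the helper names visible in this diff (_register_shard_params, _register_allowed_patterns, get_dist_specs_with_label); a conceptual sketch, not the PR's implementation:

class ColoModuleSketch:
    """Hypothetical rule table keyed by (compute_pattern, label)."""

    def __init__(self):
        self._shard_params = []
        self._allowed_patterns = {}

    def _register_shard_params(self, params):
        # Names of the parameters this module type is allowed to shard.
        self._shard_params = params

    def _register_allowed_patterns(self, compute_pattern, dist_specs, label='default'):
        # dist_specs maps parameter name -> dist spec (or None to skip).
        self._allowed_patterns[(compute_pattern, label)] = dist_specs

    def get_dist_specs_with_label(self, compute_pattern, label='default'):
        return self._allowed_patterns[(compute_pattern, label)]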

for param_name, dist_spec in colo_module.get_dist_specs_with_label(compute_pattern, label=label).items():
    if dist_spec is None:
        continue
    param = module.get_parameter(param_name)
Contributor:

Parameters in colo_module and module are matched by name strings. Is that robust enough for nested names like attention.linear.weight?
Have you considered name nesting? A more complicated model would exercise it.

Contributor (Author):

Added a simple_net test.
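
For what it's worth, torch.nn.Module.get_parameter does resolve dotted paths through submodules, which is what the nesting question hinges on; a small self-contained illustration (module names here are made up):

import torch.nn as nn

class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = Attention()

block = Block()
# Dotted targets walk the submodule tree, so nested names resolve.
w = block.get_parameter('attention.linear.weight')
assert w is block.attention.linear.weight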


register_colo_module(torch.nn.Linear, ColoLinear())
parallel_action = ParallelAction(ComputePattern.TP1D)
init_colo_module(model, parallel_action, recursive=True, label=label)
Contributor:

recursive=True is meaningless in a module that contains only a Linear.
A more complicated test case would exercise the features you propose.

Contributor (Author):

Added a simple_net test.
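
A hedged sketch of the kind of nested model such a test needs; the layer names are illustrative, not the actual simple_net code:

import torch.nn as nn

class SubNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Linear(4, 8)
        self.sub = SubNet()

# With recursive=True, init_colo_module should reach both SimpleNet.embed and
# the nested SimpleNet.sub.proj, which a single top-level Linear cannot check.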

Wesley-Jzy merged commit 32291dd into hpcaitech:main on May 26, 2022.