[Tensor] add module handler for linear #1021
Conversation

Wesley-Jzy commented on May 24, 2022:
- Add a module spec function (see the sketch after this list)
- Add ColoLinear as a demo
- Add a unit test
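As a rough sketch only (the PR's actual ColoLinear may differ), the module spec for Linear could look something like this, built on the _register_shard_params / _register_allowed_patterns helpers quoted below; the dist-spec values and label are placeholders, not this PR's real code:

    # Hypothetical sketch of a ColoLinear module spec. The dist specs
    # passed in are placeholders, not the PR's real values.
    class ColoLinear(ColoModule):
        def __init__(self):
            super().__init__()
            # Parameters of torch.nn.Linear that are allowed to be sharded.
            self._register_shard_params(['weight', 'bias'])
            # One allowed rule set per compute pattern and label.
            self._register_allowed_patterns(
                compute_pattern=ComputePattern.TP1D,
                dist_specs={'weight': weight_dist_spec, 'bias': bias_dist_spec},
                label='default',
            )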
colossalai/tensor/module_utils.py (Outdated)
    for param_name in param_names:
        param = module.get_parameter(param_name)
        if not isinstance(param, ColoParameter):
            print(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
Move the error message into an Exception instead of printing it.
Done.
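For reference, the fix presumably swaps the print for a raise along these lines (a sketch of the change, not the merged diff):

    for param_name in param_names:
        param = module.get_parameter(param_name)
        if not isinstance(param, ColoParameter):
            # Raise instead of printing so an invalid spec cannot be silently ignored.
            raise ValueError(
                f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')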
    def _register_shard_params(self, params: List[str]):
        self._shard_params = params

    def _register_allowed_patterns(self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], label='default'):
Use an enum class for label; a plain str will invite typos.
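A minimal sketch of that suggestion; the enum members here are illustrative, not the PR's actual labels:

    from enum import Enum

    class SpecLabel(Enum):
        # Hypothetical members; an Enum turns a typo into an immediate
        # AttributeError instead of a silently unmatched string key.
        DEFAULT = 'default'
        ROW = 'row'
        COL = 'col'

    def _register_allowed_patterns(self, compute_pattern: ComputePattern,
                                   dist_specs: Dict[str, _DistSpec],
                                   label: SpecLabel = SpecLabel.DEFAULT):
        ...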
What is the purpose of adding a module spec function?

It sets all the params in a module at once and stores the allowed spec rules, so we avoid setting a mismatched spec on each param individually. It also lets init_colo_module() set up the whole model after all ops are registered in the init context.
    for param_name, dist_spec in colo_module.get_dist_specs_with_label(compute_pattern, label=label).items():
        if dist_spec is None:
            continue
        param = module.get_parameter(param_name)
Matching names between colo_module and module via plain strings: is this robust enough for nested names like attention.linear.weight?
Have you considered name nesting? A more complicated model would put it to the test.
Add a simple_net test.
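On the nesting concern specifically: torch.nn.Module.get_parameter already resolves dotted paths, which a tiny check can confirm (SimpleNet below is a made-up example, not the test in this PR):

    import torch

    class Attention(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

    class SimpleNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.attention = Attention()

    model = SimpleNet()
    # get_parameter walks the dotted target through nested submodules.
    w = model.get_parameter('attention.linear.weight')
    assert w is model.attention.linear.weight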
    register_colo_module(torch.nn.Linear, ColoLinear())
    parallel_action = ParallelAction(ComputePattern.TP1D)
    init_colo_module(model, parallel_action, recursive=True, label=label)
recursive=True is meaningless in a module that contains only a single Linear.
A more complicated test case would exercise the features you propose.
Add a simple_net test.
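The requested test might look roughly like this: a net with more than one nested Linear, so recursive=True actually has to traverse submodules. The model and label are illustrative; the registration and init calls mirror the snippet quoted above:

    import torch

    class SimpleNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = torch.nn.Linear(16, 32)
            self.proj = torch.nn.Linear(32, 16)

    model = SimpleNet()
    register_colo_module(torch.nn.Linear, ColoLinear())
    parallel_action = ParallelAction(ComputePattern.TP1D)
    # With two nested Linears, recursive=True must apply the spec to both,
    # which a single top-level Linear could never verify.
    init_colo_module(model, parallel_action, recursive=True, label='default')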