[autoparallel] Add metainfo support for F.linear #1987
Conversation
Merge ColossalAI
Daily merge
…r30/ColossalAI into feature/metainfo_for_auto_parallel
@@ -65,7 +68,7 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     has_bias: bool = False
     input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
     output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
-    weight_tensor = next(filter(lambda x: x.name == 'weight', args)).data
+    weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
Modified this part for more robust code.
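The robustness gain can be shown in isolation: selecting parameters by their `OperationDataType` picks up every parameter (weight and, when present, bias), whereas filtering on the hard-coded name `'weight'` misses the bias tensor. A minimal sketch, using simplified stand-ins for ColossalAI's `OperationData` container and enum (the field names and values here are illustrative, not the real types):

```python
from dataclasses import dataclass
from enum import Enum, auto


class OperationDataType(Enum):
    # Simplified stand-in for the real ColossalAI enum
    ARG = auto()
    OUTPUT = auto()
    PARAM = auto()


@dataclass
class OperationData:
    name: str
    type: OperationDataType
    data: object


args = [
    OperationData('input', OperationDataType.ARG, 'input_tensor'),
    OperationData('output', OperationDataType.OUTPUT, 'output_tensor'),
    OperationData('weight', OperationDataType.PARAM, 'weight_tensor'),
    OperationData('bias', OperationDataType.PARAM, 'bias_tensor'),
]

# Fragile: only finds the tensor literally named 'weight', misses the bias
weight_tensor = next(filter(lambda x: x.name == 'weight', args)).data

# Robust: collects every parameter regardless of its name
weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
print(weight_tensors)  # ['weight_tensor', 'bias_tensor']
```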
    assert meta_register.has(self._target.__class__), f'{self._target.__class__} not found in the meta registry'
    meta_func = meta_register.get(self._target.__class__)
    try:
        # module
Support the case that node.op == "call_function".
@@ -104,8 +106,12 @@ def mem_test_for_node_strategy(rank: int,
     )

     # estimated memory
-    metainfo = MetaInfo(target_node.strategies_vector[strategy_index],
-                        target_node.graph.owning_module.get_submodule(target_node.target))
+    if target_node.op == "call_module":
Modify this part to support node.op == "call_function".
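The branch added above is needed because `torch.fx` resolves the two op kinds differently: for `call_module` the node's `target` is a submodule path that must be looked up on the owning module, while for `call_function` the `target` already is the function object (e.g. `torch.nn.functional.linear`). A hedged sketch with a minimal fake node, not the real `torch.fx.Node` type:

```python
from dataclasses import dataclass


@dataclass
class FakeNode:
    # Minimal stand-in for a torch.fx Node: only the fields we branch on
    op: str
    target: object


def resolve_target(node, submodules):
    """Return the callable/module a node refers to.

    'call_module' targets name a submodule; 'call_function' targets
    are the function object itself. Lookup here is a plain dict as a
    simplified stand-in for Module.get_submodule.
    """
    if node.op == "call_module":
        return submodules[node.target]
    if node.op == "call_function":
        return node.target
    raise ValueError(f"unsupported op: {node.op}")


def fake_linear(x):  # stands in for torch.nn.functional.linear
    return x


submodules = {"fc1": "Linear(4, 4) submodule"}
print(resolve_target(FakeNode("call_module", "fc1"), submodules))
print(resolve_target(FakeNode("call_function", fake_linear), submodules) is fake_linear)
```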
What’s New?
In this PR, I added support for torch.nn.functional.linear in our metainfo generation; the memory estimation results are aligned with torch.nn.Linear (without bias). Although we currently decompose a biased linear into a matmul and a bias add, I have kept the code path that generates metainfo for the biased linear for future use.
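The alignment claim and the motivation for keeping the biased path can be sanity-checked with a rough element count. Without a bias, F.linear and nn.Linear describe exactly the same tensors, so any element-count-based memory estimate agrees; decomposing a biased linear into matmul + bias add, however, materializes the matmul result as an extra intermediate activation. This sketch is purely illustrative (the function names and formulas are assumptions, not ColossalAI's actual estimator, which also tracks forward/backward phases and dtypes):

```python
def fused_linear_mem(batch, fin, fout, bias=True):
    # Elements touched by a fused linear: parameters + one output activation
    params = fin * fout + (fout if bias else 0)
    activations = batch * fout
    return params, activations


def decomposed_linear_mem(batch, fin, fout):
    # Matmul followed by a separate bias add materializes the matmul
    # result as an extra intermediate activation of the output's size
    params = fin * fout + fout
    activations = 2 * batch * fout
    return params, activations


fused = fused_linear_mem(8, 64, 32)       # biased nn.Linear-style estimate
split = decomposed_linear_mem(8, 64, 32)  # matmul + bias-add decomposition
print(fused)   # (2080, 256)
print(split)   # (2080, 512): same params, extra intermediate activation
```

This is why the metainfo path for the fused biased linear is worth keeping: if the decomposition is ever fused back, the cheaper estimate applies directly.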