Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[autoparallel] Add metainfo support for F.linear #1987

Merged

Conversation

Cypher30
Copy link
Contributor

What’s New?

In this PR, I done some work to support torch.nn.functional.linear in our metainfo generation, the memory estimation results are aligned with torch.nn.Linear (without bias). And though the biased linear is now separated by us into matmul and bias add, I still retain the part that generate metainfo for biased linear for future use.

Cypher30 and others added 30 commits July 14, 2022 16:07
@@ -65,7 +68,7 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
has_bias: bool = False
input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
weight_tensor = next(filter(lambda x: x.name == 'weight', args)).data
weight_tensors = [x.data for x in args if x.type == OperationDataType.PARAM]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modify this part for more robust code

assert meta_register.has(self._target.__class__), f'{self._target.__class__} not found in the meta registry'
meta_func = meta_register.get(self._target.__class__)
try:
# module
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Support the case that node.op == “call_function”

@@ -104,8 +106,12 @@ def mem_test_for_node_strategy(rank: int,
)

# estimated memory
metainfo = MetaInfo(target_node.strategies_vector[strategy_index],
target_node.graph.owning_module.get_submodule(target_node.target))
if target_node.op == "call_module":
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modify this part to support node.op == “call_function”

@YuliangLiu0306 YuliangLiu0306 merged commit 6cd784f into hpcaitech:main Nov 23, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants