[fx] PoC of runtime shape consistency application #1607

Merged

Conversation

YuliangLiu0306 (Contributor)

No description provided.

return shape_consistency_manager.apply(*args, **kwargs)


def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
Contributor

why is solution a list of int?

Contributor Author

The solver's output is a list of ints; the value of each element is the index of the best strategy for the corresponding node.
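
As a minimal sketch of that idea (a toy model and a hypothetical solver output, not this PR's actual solver API), the solution holds one int per graph node, used to index that node's candidate strategies:

from typing import List

import torch
from torch.fx import symbolic_trace


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


gm = symbolic_trace(TinyModel())
nodes = list(gm.graph.nodes)

# Hypothetical solver output: solution[i] is the index of the best strategy
# for nodes[i]; here every node simply uses its 0-th candidate.
solution: List[int] = [0] * len(nodes)

for node, strategy_index in zip(nodes, solution):
    # A real annotation pass would look up the node's candidate strategies
    # with this index; here we only record the chosen index for illustration.
    setattr(node, 'best_strategy_index', strategy_index)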



def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
mod_graph = gm.graph
Contributor

what is a mod_graph?

Contributor Author

It is short for model graph, i.e. gm.graph.

origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {})
setattr(target_module.weight, 'sharding_spec', origin_sharding_spec)
target_weight_sharding_spec = node.best_strategy.input_shardings[1]
target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3))
Contributor

Why permute?

Contributor

If this is because the conv/linear weight is not in the desired shape, I can accept it for now, but we should handle this in NodeHandler.

Contributor Author

Sure, I just found this problem during testing; I will fix it in a future PR.
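
For context, a rough sketch of what the permute does, assuming the mismatch is only the ordering of the first two weight dimensions (the sharding-spec details from this PR are omitted):

import torch

# Conv2d stores its weight as (out_channels, in_channels, kH, kW).
conv = torch.nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3)
print(conv.weight.shape)  # torch.Size([16, 8, 3, 3])

# Swapping the first two dimensions, as in the questioned line, yields an
# (in_channels, out_channels, kH, kW) layout so that the chosen input
# sharding spec can shard along the intended dimension.
conv.weight.data = conv.weight.data.permute((1, 0, 2, 3))
print(conv.weight.shape)  # torch.Size([8, 16, 3, 3])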

with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node))

gm.recompile()
Contributor

recompile should only be called after all passes finish.

Contributor Author

Done
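
A minimal sketch of the agreed pattern, with a stand-in apply_shape_consistency function in place of shape_consistency_manager.apply: each pass only mutates gm.graph, and gm.recompile() is called once after all passes have run.

import torch
from torch.fx import symbolic_trace


def apply_shape_consistency(tensor):
    # Placeholder: a real implementation would convert the tensor between
    # sharding specs; here it simply passes the value through.
    return tensor


def shape_consistency_pass(gm: torch.fx.GraphModule):
    mod_graph = gm.graph
    for node in list(mod_graph.nodes):
        if node.op == 'call_module':
            for user_node in list(node.users):
                # Insert a conversion node right before the consumer.
                with mod_graph.inserting_before(user_node):
                    new_node = mod_graph.create_node(
                        'call_function', apply_shape_consistency, args=(node,))
                # Route the consumer through the inserted node.
                user_node.replace_input_with(node, new_node)
    # Note: no recompile here; it is deferred to the caller.


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


gm = symbolic_trace(TinyModel())
shape_consistency_pass(gm)
gm.recompile()  # called once, after all passes finish
print(gm(torch.randn(2, 4)).shape)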

sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
shape_consistency_pass(gm)
nodes = [node for node in gm.graph.nodes]
output = gm(input, sharding_spec_dict, origin_spec_dict)
Contributor

Such usage is not very intuitive; I would recommend sticking to gm(input) in the future, but I can let it pass for now. We can annotate it with a TODO tag.

Contributor Author

Done
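
For reference, a toy sketch of why the extra positional arguments appear at the call site (this is the general FX mechanism, not this PR's exact passes): an annotation pass that adds placeholder nodes changes the recompiled forward() signature, which is what the TODO would eventually hide again.

import torch
from torch.fx import symbolic_trace


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


gm = symbolic_trace(TinyModel())

# Append an extra placeholder after the existing ones, mimicking how the
# annotation pass exposes sharding_spec_dict as a graph input.
placeholders = [n for n in gm.graph.nodes if n.op == 'placeholder']
with gm.graph.inserting_after(placeholders[-1]):
    gm.graph.placeholder('sharding_spec_dict')
gm.recompile()

# The generated forward is now forward(self, x, sharding_spec_dict),
# so the caller must pass the extra dict explicitly.
out = gm(torch.randn(2, 4), {})
print(out.shape)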

@FrankLeeeee FrankLeeeee merged commit 7d1bb71 into hpcaitech:main Sep 20, 2022