[fx] PoC of runtime shape consistency application #1607
Conversation
return shape_consistency_manager.apply(*args, **kwargs)

def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
why is solution a list of int?
The solver's solution is a list of ints; the value of each element is the index of the best strategy chosen for the corresponding node.
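To make the idea concrete, here is a minimal sketch (illustrative only, not this PR's actual pass; the attribute name `best_strategy_index` is hypothetical) of annotating each fx node with the strategy index the solver picked for it:

```python
import torch
import torch.fx

def annotate_best_strategies(gm: torch.fx.GraphModule, solution):
    """Attach the solver's chosen strategy index to each fx node.

    `solution[i]` is assumed to be the index of the best strategy for
    the i-th node in the graph, mirroring the description in this thread.
    """
    nodes = list(gm.graph.nodes)
    assert len(solution) == len(nodes), "one strategy index per node"
    for node, strategy_index in zip(nodes, solution):
        # hypothetical attribute; the real pass records richer strategy info
        node.best_strategy_index = strategy_index
    return gm
```

The pass only records metadata on the nodes; later passes (such as the shape-consistency pass below) can read it back without re-running the solver.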
def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
    mod_graph = gm.graph
what is a mod_graph?
It is the model's graph, i.e. gm.graph.
origin_sharding_spec = ShardingSpec(device_mesh, target_module.weight.shape, {})
setattr(target_module.weight, 'sharding_spec', origin_sharding_spec)
target_weight_sharding_spec = node.best_strategy.input_shardings[1]
target_module.weight.data = target_module.weight.data.permute((1, 0, 2, 3))
Why permute?
If this is because the conv/linear weight needs to be permuted into the desired shape, I can accept it for now, but we should handle this in NodeHandler.
Sure, I just found this problem during testing; I will fix it in a future PR.
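For context on what the permute does (a minimal sketch, assuming a standard Conv2d weight layout): `permute((1, 0, 2, 3))` swaps the output- and input-channel dimensions, which changes which dimension a sharding spec defined by index refers to:

```python
import torch

# Conv2d weight layout is (out_channels, in_channels, kH, kW)
weight = torch.randn(8, 4, 3, 3)

# permute((1, 0, 2, 3)) swaps the first two dims:
# result layout is (in_channels, out_channels, kH, kW)
permuted = weight.permute((1, 0, 2, 3))
assert permuted.shape == (4, 8, 3, 3)
```

Why this is needed for a particular sharding spec is exactly the question raised above; the reviewer's point is that such layout fixes belong in NodeHandler rather than in this pass.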
with mod_graph.inserting_before(user_node):
    shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node))
gm.recompile()
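For readers unfamiliar with the pattern quoted above, here is a self-contained sketch (a hypothetical example, not the PR's code) of using `inserting_before` plus `create_node` to splice a `call_function` node into an fx graph:

```python
import operator
import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = torch.fx.symbolic_trace(M())
graph = gm.graph

# find the output node and insert a new call_function node right before it,
# mirroring the inserting_before/create_node pattern used in the pass
output_node = next(n for n in graph.nodes if n.op == 'output')
with graph.inserting_before(output_node):
    new_node = graph.create_node('call_function', operator.neg,
                                 args=(output_node.args[0],))
# rewire the output to consume the inserted node
output_node.args = (new_node,)
gm.recompile()
```

After `recompile()`, the generated `forward` computes `-(x + 1)`; the shape-consistency pass uses the same mechanism to insert a conversion call between a node and its user.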
recompile should only be called once, after all passes finish.
Done
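The recommended structure can be sketched as follows (illustrative pass names, not this PR's exact API): all graph-mutating passes run first, and `recompile()` is invoked exactly once at the end:

```python
import torch
import torch.fx

# hypothetical passes that mutate gm.graph in place without recompiling
def pass_a(gm: torch.fx.GraphModule):
    for node in gm.graph.nodes:
        node.meta['visited_a'] = True

def pass_b(gm: torch.fx.GraphModule):
    for node in gm.graph.nodes:
        node.meta['visited_b'] = True

def run_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    pass_a(gm)
    pass_b(gm)
    # recompile only once, after every pass has finished mutating the graph;
    # recompiling inside each pass would redundantly regenerate forward()
    gm.recompile()
    return gm
```

This avoids regenerating the module's `forward` code after every individual pass.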
sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
shape_consistency_pass(gm)
nodes = [node for node in gm.graph.nodes]
output = gm(input, sharding_spec_dict, origin_spec_dict)
Such usage is not very intuitive; I would recommend sticking to gm(input) in the future, but I can let it pass for now. We can annotate it with a TODO tag.
Done
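One way to keep the user-facing call as gm(input) (a sketch of the suggestion above, not this PR's eventual design; `bind_spec_dicts` is a hypothetical helper) is to bind the spec dicts once in a closure:

```python
# Hypothetical wrapper: bind the sharding-spec dicts once so callers can
# simply invoke model(input) without threading the dicts through every call.
def bind_spec_dicts(gm, sharding_spec_dict, origin_spec_dict):
    def wrapped(x):
        return gm(x, sharding_spec_dict, origin_spec_dict)
    return wrapped
```

Usage would then look like `model = bind_spec_dicts(gm, sharding_spec_dict, origin_spec_dict)` followed by `output = model(input)`, matching the reviewer's preferred interface.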
No description provided.