Skip to content

Commit

Permalink
[fx][split] Copy node metadata for placeholders
Browse files Browse the repository at this point in the history
- Follow-up for pytorch#107248 which copies metadata for placeholder nodes in
the top-level FX graph
  • Loading branch information
gs-olive committed Aug 25, 2023
1 parent 7ef13b1 commit 47b4b5e
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions torch/fx/passes/split_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def flatten(x: torch.fx.node.Argument) -> NodeList:
# Placeholders in the original graph get copied to main graph.
if node.op == "placeholder":
main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
main_remapping[node].meta = copy.copy(node.meta)
continue

# Get_attr nodes are ignored because we are not tagging them.
Expand Down

0 comments on commit 47b4b5e

Please sign in to comment.