HFTracer.trace should use/return self.graph to be compatible with torch.fx.Tracer (huggingface#15824)

pbelevich authored and James Reed committed Mar 8, 2022
1 parent 877a4b5 commit 876c079
Showing 1 changed file with 4 additions and 4 deletions.

src/transformers/utils/fx.py (4 additions, 4 deletions)
@@ -472,7 +472,7 @@ def trace(
 
         self._patch_leaf_functions_for_root(root)
 
-        graph = super().trace(root, concrete_args=concrete_args)
+        self.graph = super().trace(root, concrete_args=concrete_args)
 
         self._patch_leaf_functions_for_root(root, restore=True)
 
@@ -482,16 +482,16 @@ def trace(
         # This is necessary because concrete args are added as input to the traced module since
         # https://github.com/pytorch/pytorch/pull/55888.
         # A PR that solves this was posted: https://github.com/pytorch/pytorch/pull/59569 but it was not merged yet.
-        for node in graph.nodes:
+        for node in self.graph.nodes:
             if node.op == "placeholder":
                 # Removing default values for inputs as the forward pass will fail with them.
                 if node.target in input_names:
                     node.args = ()
                 # It is a concrete arg so it is not used and should be removed.
                 else:
-                    graph.erase_node(node)
+                    self.graph.erase_node(node)
 
-        return graph
+        return self.graph
 
     def _insert_module_as_submodule(self, mod: nn.Module) -> str:
         """
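For context, here is a minimal sketch of the torch.fx convention this change aligns with (ToyModule is purely illustrative and not part of the repository): consumers of a Tracer commonly read the traced graph back off the instance as tracer.graph, for example when building a GraphModule, so keeping self.graph in sync with the value HFTracer.trace returns matches what the base torch.fx.Tracer does.

import torch
from torch import nn
from torch.fx import GraphModule, Tracer


class ToyModule(nn.Module):
    """Illustrative module, not part of the transformers codebase."""

    def forward(self, x):
        return torch.relu(x + 1)


tracer = Tracer()
graph = tracer.trace(ToyModule())

# The base Tracer stores the traced graph on the instance; downstream code
# often reads it back as tracer.graph, e.g. to build a GraphModule.
assert graph is tracer.graph
gm = GraphModule(tracer.root, tracer.graph)
print(gm(torch.randn(3)))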
