diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py
index dd2724e06a093..b88ae4ae7a343 100644
--- a/src/transformers/utils/fx.py
+++ b/src/transformers/utils/fx.py
@@ -472,7 +472,7 @@ def trace(
 
         self._patch_leaf_functions_for_root(root)
 
-        graph = super().trace(root, concrete_args=concrete_args)
+        self.graph = super().trace(root, concrete_args=concrete_args)
 
         self._patch_leaf_functions_for_root(root, restore=True)
 
@@ -482,16 +482,16 @@ def trace(
         # This is necessary because concrete args are added as input to the traced module since
         # https://github.com/pytorch/pytorch/pull/55888.
         # A PR that solves this was posted: https://github.com/pytorch/pytorch/pull/59569 but it was not merged yet.
-        for node in graph.nodes:
+        for node in self.graph.nodes:
             if node.op == "placeholder":
                 # Removing default values for inputs as the forward pass will fail with them.
                 if node.target in input_names:
                     node.args = ()
                 # It is a concrete arg so it is not used and should be removed.
                 else:
-                    graph.erase_node(node)
+                    self.graph.erase_node(node)
 
-        return graph
+        return self.graph
 
     def _insert_module_as_submodule(self, mod: nn.Module) -> str:
         """
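
For context, this change makes `HFTracer.trace` keep the traced graph on the tracer instance as `self.graph`, in addition to returning it, which matches the contract of the base `torch.fx.Tracer`. The sketch below is not part of the diff: it uses the plain `torch.fx.Tracer` and a toy module (`TinyModel` is a made-up stand-in, not a transformers model) to illustrate why that invariant is useful: after tracing, callers can still reach the graph through the tracer, for example when constructing a `GraphModule`.

```python
# Minimal sketch, assuming only stock torch.fx. HFTracer subclasses
# torch.fx.Tracer, whose trace() stores the graph on the instance and
# returns that same object; the patch above mirrors this for HFTracer.
import torch
from torch import nn
from torch.fx import GraphModule, Tracer


class TinyModel(nn.Module):
    def forward(self, x):
        return x * 2


model = TinyModel()
tracer = Tracer()
graph = tracer.trace(model)

# The returned graph and the attribute refer to the same Graph, so edits
# made through either name (e.g. erase_node) are visible to both.
assert graph is tracer.graph

# Downstream code can therefore build the module from tracer.graph alone.
module = GraphModule(model, tracer.graph)
print(module(torch.ones(3)))  # tensor([2., 2., 2.])
```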