diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index ebea374b8..196b868e4 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -222,7 +222,8 @@ def find_and_update_line_node( def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: # TODO: Ensure that this class inherits from unittest.TestCase. Don't modify non unittest.TestCase classes. - for inner_node in ast.walk(node): + # Iterate only over direct children for efficiency. + for inner_node in node.body: if isinstance(inner_node, ast.FunctionDef): self.visit_FunctionDef(inner_node, node.name) elif isinstance(inner_node, ast.AsyncFunctionDef): @@ -269,20 +270,19 @@ def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = line_node = node.body[i] # TODO: Validate if the functional call actually did not raise any exceptions + # Fast path: operate directly on the node bodies, only calling find_and_update_line_node on stmts if isinstance(line_node, (ast.With, ast.For, ast.While, ast.If)): j = len(line_node.body) - 1 while j >= 0: compound_line_node: ast.stmt = line_node.body[j] - internal_node: ast.AST - for internal_node in ast.walk(compound_line_node): - if isinstance(internal_node, (ast.stmt, ast.Assign)): - updated_node = self.find_and_update_line_node( - internal_node, node.name, str(i) + "_" + str(j), test_class_name - ) - if updated_node is not None: - line_node.body[j : j + 1] = updated_node - did_update = True - break + updated_node = self.find_and_update_line_node( + compound_line_node, node.name, f"{i}_{j}", test_class_name + ) + if updated_node is not None: + line_node.body[j : j + 1] = updated_node + did_update = True + # break out after updating, as in the original logic + break j -= 1 else: updated_node = self.find_and_update_line_node(line_node, node.name, str(i), test_class_name)