Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 40 additions & 17 deletions codeflash/code_utils/edit_generated_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,29 +61,52 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctio

def _process_function_def_common(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
self.context_stack.append(node.name)
i = len(node.body) - 1
node_body = node.body
node_body_len = len(node_body)
test_qualified_name = ".".join(self.context_stack)
key = test_qualified_name + "#" + str(self.abs_path)

# Precompute key prefix
key_prefix = f"{test_qualified_name}#{self.abs_path}"

i = node_body_len - 1
orig_rt = self.original_runtimes
opt_rt = self.optimized_runtimes
get_comment = self.get_comment

# Hoist isinstance tuple constants out of loop
compound_types = (ast.With, ast.For, ast.While, ast.If)
valid_types = (ast.stmt, ast.Assign)

while i >= 0:
line_node = node.body[i]
if isinstance(line_node, (ast.With, ast.For, ast.While, ast.If)):
j = len(line_node.body) - 1
line_node = node_body[i]
if isinstance(line_node, compound_types):
line_node_body = line_node.body
j = len(line_node_body) - 1
nodes_to_check_append = nodes_to_check_extend = None # Avoids local lookups
while j >= 0:
compound_line_node: ast.stmt = line_node.body[j]
nodes_to_check = [compound_line_node]
nodes_to_check.extend(getattr(compound_line_node, "body", []))
for internal_node in nodes_to_check:
if isinstance(internal_node, (ast.stmt, ast.Assign)):
inv_id = str(i) + "_" + str(j)
match_key = key + "#" + inv_id
if match_key in self.original_runtimes and match_key in self.optimized_runtimes:
self.results[internal_node.lineno] = self.get_comment(match_key)
compound_line_node: ast.stmt = line_node_body[j]
# Pre-extend only if there's a .body attribute, avoid repeated getattr cost
compound_line_node_body = getattr(compound_line_node, "body", None)
if compound_line_node_body:
nodes_to_check = [compound_line_node, *compound_line_node_body]
else:
nodes_to_check = [compound_line_node]

inv_id = f"{i}_{j}"
match_key = f"{key_prefix}#{inv_id}"

if match_key in orig_rt and match_key in opt_rt:
comment = get_comment(match_key)
# Avoid repeated isinstance - enumerate only actual stmt/Assign
for internal_node in nodes_to_check:
if isinstance(internal_node, valid_types):
self.results[internal_node.lineno] = comment
j -= 1
else:
inv_id = str(i)
match_key = key + "#" + inv_id
if match_key in self.original_runtimes and match_key in self.optimized_runtimes:
self.results[line_node.lineno] = self.get_comment(match_key)
match_key = f"{key_prefix}#{inv_id}"
if match_key in orig_rt and match_key in opt_rt:
self.results[line_node.lineno] = get_comment(match_key)
i -= 1
self.context_stack.pop()

Expand Down
Loading