Skip to content

Conversation

@codeflash-ai
Copy link
Contributor

@codeflash-ai codeflash-ai bot commented Sep 26, 2025

⚡️ This pull request contains optimizations for PR #769

If you approve this dependent PR, these changes will be merged into the original PR branch clean-async-branch.

This PR will be automatically closed if the original PR is merged.


📄 24% (0.24x) speedup for InjectPerfOnly.find_and_update_line_node in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime : 21.6 milliseconds 17.4 milliseconds (best of 49 runs)

📝 Explanation and details

The optimization achieves a 24% speedup by targeting two key performance bottlenecks identified in the line profiler results:

1. Optimized node_in_call_position function (~22% faster):

  • Reduced attribute lookups: Pre-fetches lineno, col_offset, end_lineno, and end_col_offset once using getattr() instead of repeatedly calling hasattr() and accessing attributes in the loop
  • Early exit optimization: Returns False immediately if not an ast.Call node, avoiding unnecessary work
  • Simplified conditional logic: Combines nested checks into a single block to reduce Python opcode jumps

2. Optimized find_and_update_line_node method (~18% faster):

  • Cached attribute access: Stores frequently accessed attributes (self.function_object.function_name, self.mode, etc.) in local variables to avoid repeated object attribute lookups
  • Efficient list construction: Builds the args list incrementally using extend() instead of creating multiple intermediate lists with unpacking operators
  • Early termination: Breaks immediately after finding and modifying the matching call node, avoiding unnecessary continuation of ast.walk()

Performance gains are most significant for:

  • Large-scale test cases with many function calls (up to 38% faster for 500+ calls)
  • Mixed workloads with calls and non-calls (25% faster)
  • Tests with keyword arguments (13% faster)

The optimizations maintain identical behavior while reducing CPU-intensive operations like attribute lookups and list operations that dominate the execution time in AST transformation workflows.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 1685 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 91.4%
🌀 Generated Regression Tests and Runtime
from __future__ import annotations

import ast
from collections.abc import Iterable
from typing import List, Optional

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import InjectPerfOnly


# Minimal stubs for dependencies
class FunctionToOptimize:
    def __init__(self, function_name, qualified_name=None, is_async=False, parents=None, top_level_parent_name=None):
        self.function_name = function_name
        self.qualified_name = qualified_name or function_name
        self.is_async = is_async
        self.parents = parents or []
        self.top_level_parent_name = top_level_parent_name

class CodePosition:
    def __init__(self, line_no, col_no, end_col_offset=None):
        self.line_no = line_no
        self.col_no = col_no
        self.end_col_offset = end_col_offset

class Parent:
    def __init__(self, type_):
        self.type = type_

class TestingMode:
    BEHAVIOR = "BEHAVIOR"
    PERF = "PERF"
from codeflash.code_utils.instrument_existing_tests import InjectPerfOnly


# Helper for building AST nodes with position info
def parse_with_positions(source: str) -> ast.stmt:
    # Parse source and set missing end_lineno/end_col_offset for compatibility
    node = ast.parse(source).body[0]
    for n in ast.walk(node):
        if hasattr(n, "lineno") and not hasattr(n, "end_lineno"):
            n.end_lineno = n.lineno
        if hasattr(n, "col_offset") and not hasattr(n, "end_col_offset"):
            n.end_col_offset = n.col_offset + 1
    return node

# Helper to extract call node from AST
def get_call_node(stmt: ast.stmt) -> Optional[ast.Call]:
    for n in ast.walk(stmt):
        if isinstance(n, ast.Call):
            return n
    return None

# ========== UNIT TESTS ==========

# ----------- BASIC TEST CASES -----------

def test_basic_name_call_is_updated():
    """Test that a simple function call by name is wrapped if it matches the position."""
    # Setup
    src = "result = myfunc(1, 2)"
    test_node = parse_with_positions(src)
    # myfunc is at line 1, col 9
    call_pos = [CodePosition(line_no=1, col_no=9)]
    function = FunctionToOptimize("myfunc")
    inj = InjectPerfOnly(function, "modpath", "pytest", call_pos)
    # Act
    codeflash_output = inj.find_and_update_line_node(test_node, "testnode", "0"); out = codeflash_output # 13.3μs -> 13.4μs (0.493% slower)
    call = get_call_node(out[0])
    # Should include codeflash_loop_index
    found = any(isinstance(a, ast.Name) and a.id == "codeflash_loop_index" for a in call.args)

def test_basic_attribute_call_is_updated():
    """Test that a method call via attribute is wrapped if it matches the position."""
    src = "result = obj.myfunc(3, 4)"
    test_node = parse_with_positions(src)
    call_pos = [CodePosition(1, 13)]
    function = FunctionToOptimize("myfunc")
    inj = InjectPerfOnly(function, "modpath", "pytest", call_pos)
    codeflash_output = inj.find_and_update_line_node(test_node, "testnode", "1"); out = codeflash_output # 24.3μs -> 24.9μs (2.23% slower)
    call = get_call_node(out[0])

def test_basic_no_matching_call_returns_none():
    """Test that if no call matches the position, returns None."""
    src = "result = myfunc(1, 2)"
    test_node = parse_with_positions(src)
    # Position does not match the call
    call_pos = [CodePosition(2, 0)]
    function = FunctionToOptimize("myfunc")
    inj = InjectPerfOnly(function, "modpath", "pytest", call_pos)
    codeflash_output = inj.find_and_update_line_node(test_node, "testnode", "0"); out = codeflash_output # 8.85μs -> 9.20μs (3.84% slower)

def test_basic_async_function_returns_original():
    """Test that if the function is async, the node is returned unchanged."""
    src = "result = myfunc(1, 2)"
    test_node = parse_with_positions(src)
    call_pos = [CodePosition(1, 9)]
    function = FunctionToOptimize("myfunc", is_async=True)
    inj = InjectPerfOnly(function, "modpath", "pytest", call_pos)
    codeflash_output = inj.find_and_update_line_node(test_node, "testnode", "0"); out = codeflash_output # 7.23μs -> 7.11μs (1.59% faster)
    # Should not be wrapped
    call = get_call_node(out[0])

def test_basic_behavior_mode_includes_cur_and_con():
    """Test that in BEHAVIOR mode, codeflash_cur and codeflash_con are included."""
    src = "result = myfunc(1)"
    test_node = parse_with_positions(src)
    call_pos = [CodePosition(1, 9)]
    function = FunctionToOptimize("myfunc")
    inj = InjectPerfOnly(function, "modpath", "pytest", call_pos, mode=TestingMode.BEHAVIOR)
    codeflash_output = inj.find_and_update_line_node(test_node, "testnode", "0"); out = codeflash_output # 12.0μs -> 11.5μs (4.34% faster)
    call = get_call_node(out[0])
    arg_ids = [a.id for a in call.args if isinstance(a, ast.Name)]

def test_basic_perf_mode_excludes_cur_and_con():
    """Test that in PERF mode, codeflash_cur and codeflash_con are NOT included."""
    src = "result = myfunc(1)"
    test_node = parse_with_positions(src)
    call_pos = [CodePosition(1, 9)]
    function = FunctionToOptimize("myfunc")
    inj = InjectPerfOnly(function, "modpath", "pytest", call_pos, mode=TestingMode.PERF)
    codeflash_output = inj.find_and_update_line_node(test_node, "testnode", "0"); out = codeflash_output # 11.5μs -> 11.4μs (1.24% faster)
    call = get_call_node(out[0])
    arg_ids = [a.id for a in call.args if isinstance(a, ast.Name)]

# ----------- EDGE TEST CASES -----------

def test_edge_call_at_end_of_line():
    """Test call at the end of line with end_col_offset."""
    src = "result = myfunc(1, 2)"
    test_node = parse_with_positions(src)
    # The col_no is at the end of the call
    call_pos = [CodePosition(1, test_node.end_col_offset)]
    function = FunctionToOptimize("myfunc")
    inj = InjectPerfOnly(function, "modpath", "pytest", call_pos)
    codeflash_output = inj.find_and_update_line_node(test_node, "testnode", "0"); out = codeflash_output # 12.3μs -> 12.4μs (1.00% slower)
    call = get_call_node(out[0])

def test_edge_multiple_calls_only_one_updated():
    """Test that only the call at the matching position is updated."""
    src = "a = myfunc(1)\nb = myfunc(2)"
    mod = ast.parse(src)
    # Set positions for both calls
    for n in ast.walk(mod):
        if hasattr(n, "lineno") and not hasattr(n, "end_lineno"):
            n.end_lineno = n.lineno
        if hasattr(n, "col_offset") and not hasattr(n, "end_col_offset"):
            n.end_col_offset = n.col_offset + 1
    test_node = mod.body[1]  # second line
    call_pos = [CodePosition(2, 6)]  # position of myfunc(2)
    function = FunctionToOptimize("myfunc")
    inj = InjectPerfOnly(function, "modpath", "pytest", call_pos)
    codeflash_output = inj.find_and_update_line_node(test_node, "testnode", "1"); out = codeflash_output # 12.0μs -> 12.2μs (1.74% slower)
    call = get_call_node(out[0])
    # First line should not be affected
    test_node1 = mod.body[0]
    codeflash_output = inj.find_and_update_line_node(test_node1, "testnode", "0"); out1 = codeflash_output # 6.67μs -> 7.17μs (7.02% slower)

def test_edge_call_with_keywords():
    """Test that keyword arguments are preserved."""
    src = "result = myfunc(a=1, b=2)"
    test_node = parse_with_positions(src)
    call_pos = [CodePosition(1, 9)]
    function = FunctionToOptimize("myfunc")
    inj = InjectPerfOnly(function, "modpath", "pytest", call_pos)
    codeflash_output = inj.find_and_update_line_node(test_node, "testnode", "0"); out = codeflash_output # 12.2μs -> 12.0μs (1.00% faster)
    call = get_call_node(out[0])

def test_edge_call_with_class_parent():
    """Test that class_name is passed if parent is ClassDef."""
    src = "result = myfunc(1)"
    test_node = parse_with_positions(src)
    call_pos = [CodePosition(1, 9)]
    parents = [Parent("ClassDef")]
    function = FunctionToOptimize("myfunc", parents=parents, top_level_parent_name="TestClass")
    inj = InjectPerfOnly(function, "modpath", "pytest", call_pos)
    codeflash_output = inj.find_and_update_line_node(test_node, "testnode", "0", test_class_name="TestClass"); out = codeflash_output # 12.7μs -> 12.5μs (1.17% faster)
    call = get_call_node(out[0])

def test_edge_call_with_none_class_name():
    """Test that None is passed if no class_name is provided."""
    src = "result = myfunc(1)"
    test_node = parse_with_positions(src)
    call_pos = [CodePosition(1, 9)]
    function = FunctionToOptimize("myfunc")
    inj = InjectPerfOnly(function, "modpath", "pytest", call_pos)
    codeflash_output = inj.find_and_update_line_node(test_node, "testnode", "0"); out = codeflash_output # 11.9μs -> 11.6μs (2.02% faster)
    call = get_call_node(out[0])

def test_edge_call_with_no_args():
    """Test call with no arguments is wrapped correctly."""
    src = "result = myfunc()"
    test_node = parse_with_positions(src)
    call_pos = [CodePosition(1, 9)]
    function = FunctionToOptimize("myfunc")
    inj = InjectPerfOnly(function, "modpath", "pytest", call_pos)
    codeflash_output = inj.find_and_update_line_node(test_node, "testnode", "0"); out = codeflash_output # 11.8μs -> 11.8μs (0.262% slower)
    call = get_call_node(out[0])

def test_edge_call_with_multiple_args():
    """Test call with multiple positional arguments."""
    src = "result = myfunc(1, 2, 3, 4)"
    test_node = parse_with_positions(src)
    call_pos = [CodePosition(1, 9)]
    function = FunctionToOptimize("myfunc")
    inj = InjectPerfOnly(function, "modpath", "pytest", call_pos)
    codeflash_output = inj.find_and_update_line_node(test_node, "testnode", "0"); out = codeflash_output # 12.1μs -> 11.9μs (1.25% faster)
    call = get_call_node(out[0])
    # Last arguments should be 1,2,3,4
    values = [a.value for a in call.args if isinstance(a, ast.Constant) and a.value in [1,2,3,4]]

def test_edge_call_with_nested_calls():
    """Test that only the outer call is wrapped if it matches position."""
    src = "result = myfunc(otherfunc(1))"
    test_node = parse_with_positions(src)
    call_pos = [CodePosition(1, 9)]
    function = FunctionToOptimize("myfunc")
    inj = InjectPerfOnly(function, "modpath", "pytest", call_pos)
    codeflash_output = inj.find_and_update_line_node(test_node, "testnode", "0"); out = codeflash_output # 11.8μs -> 11.8μs (0.363% slower)
    call = get_call_node(out[0])
    # The nested call should remain unchanged
    nested_call = [n for n in ast.walk(call) if isinstance(n, ast.Call) and n is not call]

# ----------- LARGE SCALE TEST CASES -----------

def test_large_scale_many_calls():
    """Test performance and correctness with many calls (up to 1000)."""
    src_lines = [f"result{i} = myfunc({i})" for i in range(1000)]
    src = "\n".join(src_lines)
    mod = ast.parse(src)
    # Set positions for all calls
    for n in ast.walk(mod):
        if hasattr(n, "lineno") and not hasattr(n, "end_lineno"):
            n.end_lineno = n.lineno
        if hasattr(n, "col_offset") and not hasattr(n, "end_col_offset"):
            n.end_col_offset = n.col_offset + 1
    # Test only the last one matches
    test_node = mod.body[-1]
    call_pos = [CodePosition(1000, 9)]
    function = FunctionToOptimize("myfunc")
    inj = InjectPerfOnly(function, "modpath", "pytest", call_pos)
    codeflash_output = inj.find_and_update_line_node(test_node, "testnode", "999"); out = codeflash_output # 20.8μs -> 20.2μs (2.97% faster)
    call = get_call_node(out[0])

def test_large_scale_no_matching_calls():
    """Test with many calls, none matching position."""
    src_lines = [f"result{i} = myfunc({i})" for i in range(1000)]
    src = "\n".join(src_lines)
    mod = ast.parse(src)
    for n in ast.walk(mod):
        if hasattr(n, "lineno") and not hasattr(n, "end_lineno"):
            n.end_lineno = n.lineno
        if hasattr(n, "col_offset") and not hasattr(n, "end_col_offset"):
            n.end_col_offset = n.col_offset + 1
    test_node = mod.body[500]
    # Position does not match any call
    call_pos = [CodePosition(2000, 9)]
    function = FunctionToOptimize("myfunc")
    inj = InjectPerfOnly(function, "modpath", "pytest", call_pos)
    codeflash_output = inj.find_and_update_line_node(test_node, "testnode", "500"); out = codeflash_output # 11.6μs -> 12.5μs (6.68% slower)

def test_large_scale_multiple_positions():
    """Test with many positions, only correct ones are updated."""
    src_lines = [f"result{i} = myfunc({i})" for i in range(1000)]
    src = "\n".join(src_lines)
    mod = ast.parse(src)
    for n in ast.walk(mod):
        if hasattr(n, "lineno") and not hasattr(n, "end_lineno"):
            n.end_lineno = n.lineno
        if hasattr(n, "col_offset") and not hasattr(n, "end_col_offset"):
            n.end_col_offset = n.col_offset + 1
    # Positions for every 100th call
    positions = [CodePosition(i+1, 9) for i in range(0, 1000, 100)]
    function = FunctionToOptimize("myfunc")
    inj = InjectPerfOnly(function, "modpath", "pytest", positions)
    # Only every 100th call should be wrapped
    for i in range(0, 1000, 100):
        test_node = mod.body[i]
        codeflash_output = inj.find_and_update_line_node(test_node, "testnode", str(i)); out = codeflash_output # 106μs -> 102μs (3.22% faster)
        call = get_call_node(out[0])
    # Others should not be wrapped
    for i in range(1, 1000, 100):
        test_node = mod.body[i]
        codeflash_output = inj.find_and_update_line_node(test_node, "testnode", str(i)); out = codeflash_output # 69.6μs -> 69.6μs (0.011% slower)

def test_large_scale_attribute_calls():
    """Test with many attribute calls (obj.myfunc) and correct wrapping."""
    src_lines = [f"result{i} = obj.myfunc({i})" for i in range(1000)]
    src = "\n".join(src_lines)
    mod = ast.parse(src)
    for n in ast.walk(mod):
        if hasattr(n, "lineno") and not hasattr(n, "end_lineno"):
            n.end_lineno = n.lineno
        if hasattr(n, "col_offset") and not hasattr(n, "end_col_offset"):
            n.end_col_offset = n.col_offset + 1
    # Only last call matches
    test_node = mod.body[-1]
    call_pos = [CodePosition(1000, 13)]
    function = FunctionToOptimize("myfunc")
    inj = InjectPerfOnly(function, "modpath", "pytest", call_pos)
    codeflash_output = inj.find_and_update_line_node(test_node, "testnode", "999"); out = codeflash_output # 36.4μs -> 39.1μs (6.95% slower)
    call = get_call_node(out[0])
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from __future__ import annotations

import ast
from collections.abc import Iterable

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import InjectPerfOnly


# Dummy classes for FunctionToOptimize, CodePosition, TestingMode
class CodePosition:
    def __init__(self, line_no, col_no, end_col_offset=None):
        self.line_no = line_no
        self.col_no = col_no
        self.end_col_offset = end_col_offset

class FunctionToOptimize:
    def __init__(self, function_name, qualified_name, is_async=False, parents=None, top_level_parent_name=None):
        self.function_name = function_name
        self.qualified_name = qualified_name
        self.is_async = is_async
        self.parents = parents or []
        self.top_level_parent_name = top_level_parent_name

class TestingMode:
    BEHAVIOR = "BEHAVIOR"
    PERF = "PERF"
from codeflash.code_utils.instrument_existing_tests import InjectPerfOnly


# Helper to parse code and attach lineno/col_offset/end_lineno/end_col_offset
def parse_stmt_with_positions(code: str) -> ast.stmt:
    # Parse code to AST
    mod = ast.parse(code)
    ast.increment_lineno(mod, 0)
    for node in ast.walk(mod):
        # Attach end_lineno and end_col_offset for all nodes if missing
        ast.fix_missing_locations(node)
    # Return the first statement
    return mod.body[0]

# Helper to extract the call node from a statement
def get_call_node(stmt):
    for node in ast.walk(stmt):
        if isinstance(node, ast.Call):
            return node
    return None

# ========== BASIC TEST CASES ==========

def test_basic_function_name_call_replacement():
    """Test a simple function call is replaced with codeflash_wrap and arguments are injected."""
    code = "result = foo(1, 2)"
    stmt = parse_stmt_with_positions(code)
    # Set positions to match the call
    call_node = get_call_node(stmt)
    call_node.lineno = stmt.lineno
    call_node.col_offset = 9
    call_node.end_lineno = stmt.lineno
    call_node.end_col_offset = 17
    call_positions = [CodePosition(line_no=stmt.lineno, col_no=10, end_col_offset=17)]
    func = FunctionToOptimize("foo", "foo", is_async=False)
    inj = InjectPerfOnly(func, "module.py", "pytest", call_positions)
    codeflash_output = inj.find_and_update_line_node(stmt, "nodeA", "idx1"); out = codeflash_output # 12.4μs -> 13.1μs (5.48% slower)
    updated_stmt = out[0]
    updated_call = get_call_node(updated_stmt)

def test_basic_attribute_call_replacement():
    """Test a method call (attribute) is replaced with codeflash_wrap and arguments are injected."""
    code = "result = obj.foo(3, 4)"
    stmt = parse_stmt_with_positions(code)
    call_node = get_call_node(stmt)
    call_node.lineno = stmt.lineno
    call_node.col_offset = 9
    call_node.end_lineno = stmt.lineno
    call_node.end_col_offset = 21
    call_positions = [CodePosition(line_no=stmt.lineno, col_no=10, end_col_offset=21)]
    func = FunctionToOptimize("foo", "obj.foo", is_async=False)
    inj = InjectPerfOnly(func, "module.py", "pytest", call_positions)
    codeflash_output = inj.find_and_update_line_node(stmt, "nodeB", "idx2", test_class_name="TestClass"); out = codeflash_output # 22.0μs -> 21.5μs (2.43% faster)
    updated_stmt = out[0]
    updated_call = get_call_node(updated_stmt)

def test_basic_async_function_returns_original():
    """Test that if is_async is True, the original node is returned unchanged."""
    code = "result = foo(5, 6)"
    stmt = parse_stmt_with_positions(code)
    call_node = get_call_node(stmt)
    call_node.lineno = stmt.lineno
    call_node.col_offset = 9
    call_node.end_lineno = stmt.lineno
    call_node.end_col_offset = 17
    call_positions = [CodePosition(line_no=stmt.lineno, col_no=10, end_col_offset=17)]
    func = FunctionToOptimize("foo", "foo", is_async=True)
    inj = InjectPerfOnly(func, "module.py", "pytest", call_positions)
    codeflash_output = inj.find_and_update_line_node(stmt, "nodeC", "idx3"); out = codeflash_output # 6.33μs -> 6.37μs (0.612% slower)
    orig_call = get_call_node(out[0])

def test_basic_behavior_mode_injects_extra_args():
    """Test that in BEHAVIOR mode, codeflash_cur and codeflash_con are injected."""
    code = "result = foo(7)"
    stmt = parse_stmt_with_positions(code)
    call_node = get_call_node(stmt)
    call_node.lineno = stmt.lineno
    call_node.col_offset = 9
    call_node.end_lineno = stmt.lineno
    call_node.end_col_offset = 14
    call_positions = [CodePosition(line_no=stmt.lineno, col_no=10, end_col_offset=14)]
    func = FunctionToOptimize("foo", "foo", is_async=False)
    inj = InjectPerfOnly(func, "module.py", "pytest", call_positions, mode=TestingMode.BEHAVIOR)
    codeflash_output = inj.find_and_update_line_node(stmt, "nodeD", "idx4"); out = codeflash_output # 11.0μs -> 10.7μs (2.11% faster)
    updated_call = get_call_node(out[0])
    # codeflash_cur and codeflash_con should be present
    arg_names = [a.id for a in updated_call.args if isinstance(a, ast.Name)]

def test_basic_perf_mode_does_not_inject_extra_args():
    """Test that in PERF mode, codeflash_cur and codeflash_con are NOT injected."""
    code = "result = foo(8)"
    stmt = parse_stmt_with_positions(code)
    call_node = get_call_node(stmt)
    call_node.lineno = stmt.lineno
    call_node.col_offset = 9
    call_node.end_lineno = stmt.lineno
    call_node.end_col_offset = 14
    call_positions = [CodePosition(line_no=stmt.lineno, col_no=10, end_col_offset=14)]
    func = FunctionToOptimize("foo", "foo", is_async=False)
    inj = InjectPerfOnly(func, "module.py", "pytest", call_positions, mode=TestingMode.PERF)
    codeflash_output = inj.find_and_update_line_node(stmt, "nodeE", "idx5"); out = codeflash_output # 10.6μs -> 10.4μs (1.92% faster)
    updated_call = get_call_node(out[0])
    arg_names = [a.id for a in updated_call.args if isinstance(a, ast.Name)]

# ========== EDGE TEST CASES ==========

def test_edge_no_matching_call_returns_none():
    """Test that if no call matches the call_positions, None is returned."""
    code = "result = foo(9)"
    stmt = parse_stmt_with_positions(code)
    call_node = get_call_node(stmt)
    call_node.lineno = stmt.lineno
    call_node.col_offset = 9
    call_node.end_lineno = stmt.lineno
    call_node.end_col_offset = 14
    # Provide call_positions that do not match
    call_positions = [CodePosition(line_no=stmt.lineno+1, col_no=10, end_col_offset=14)]
    func = FunctionToOptimize("foo", "foo", is_async=False)
    inj = InjectPerfOnly(func, "module.py", "pytest", call_positions)
    codeflash_output = inj.find_and_update_line_node(stmt, "nodeF", "idx6"); out = codeflash_output # 7.73μs -> 7.75μs (0.206% slower)

def test_edge_multiple_calls_only_first_replaced():
    """Test that if multiple calls are present, only the first matching is replaced."""
    code = "result = foo(10)\nother = foo(11)"
    mod = ast.parse(code)
    for node in ast.walk(mod):
        ast.fix_missing_locations(node)
    stmt1 = mod.body[0]
    stmt2 = mod.body[1]
    call1 = get_call_node(stmt1)
    call2 = get_call_node(stmt2)
    call1.lineno = stmt1.lineno
    call1.col_offset = 9
    call1.end_lineno = stmt1.lineno
    call1.end_col_offset = 16
    call2.lineno = stmt2.lineno
    call2.col_offset = 8
    call2.end_lineno = stmt2.lineno
    call2.end_col_offset = 15
    call_positions = [
        CodePosition(line_no=stmt1.lineno, col_no=10, end_col_offset=16),
        CodePosition(line_no=stmt2.lineno, col_no=9, end_col_offset=15),
    ]
    func = FunctionToOptimize("foo", "foo", is_async=False)
    inj = InjectPerfOnly(func, "module.py", "pytest", call_positions)
    codeflash_output = inj.find_and_update_line_node(stmt1, "nodeG", "idx7"); out1 = codeflash_output # 11.9μs -> 11.8μs (0.525% faster)
    codeflash_output = inj.find_and_update_line_node(stmt2, "nodeH", "idx8"); out2 = codeflash_output # 8.86μs -> 8.82μs (0.442% faster)
    updated_call1 = get_call_node(out1[0])
    updated_call2 = get_call_node(out2[0])

def test_edge_call_with_keywords_preserved():
    """Test that keyword arguments are preserved in the transformed call."""
    code = "result = foo(a=1, b=2)"
    stmt = parse_stmt_with_positions(code)
    call_node = get_call_node(stmt)
    call_node.lineno = stmt.lineno
    call_node.col_offset = 9
    call_node.end_lineno = stmt.lineno
    call_node.end_col_offset = 22
    call_positions = [CodePosition(line_no=stmt.lineno, col_no=10, end_col_offset=22)]
    func = FunctionToOptimize("foo", "foo", is_async=False)
    inj = InjectPerfOnly(func, "module.py", "pytest", call_positions)
    codeflash_output = inj.find_and_update_line_node(stmt, "nodeI", "idx9"); out = codeflash_output # 10.9μs -> 10.9μs (0.484% slower)
    updated_call = get_call_node(out[0])

def test_edge_nested_call_only_outer_replaced():
    """Test that only the outer call matching call_positions is replaced."""
    code = "result = foo(bar(1))"
    stmt = parse_stmt_with_positions(code)
    outer_call = get_call_node(stmt)
    outer_call.lineno = stmt.lineno
    outer_call.col_offset = 9
    outer_call.end_lineno = stmt.lineno
    outer_call.end_col_offset = 19
    call_positions = [CodePosition(line_no=stmt.lineno, col_no=10, end_col_offset=19)]
    func = FunctionToOptimize("foo", "foo", is_async=False)
    inj = InjectPerfOnly(func, "module.py", "pytest", call_positions)
    codeflash_output = inj.find_and_update_line_node(stmt, "nodeJ", "idx10"); out = codeflash_output # 11.2μs -> 10.9μs (3.30% faster)
    updated_call = get_call_node(out[0])

def test_edge_call_at_start_of_line():
    """Test a call at the very start of a line is matched and replaced."""
    code = "foo(12)"
    stmt = parse_stmt_with_positions(code)
    call_node = get_call_node(stmt)
    call_node.lineno = stmt.lineno
    call_node.col_offset = 0
    call_node.end_lineno = stmt.lineno
    call_node.end_col_offset = 7
    call_positions = [CodePosition(line_no=stmt.lineno, col_no=0, end_col_offset=7)]
    func = FunctionToOptimize("foo", "foo", is_async=False)
    inj = InjectPerfOnly(func, "module.py", "pytest", call_positions)
    codeflash_output = inj.find_and_update_line_node(stmt, "nodeK", "idx11"); out = codeflash_output # 10.3μs -> 10.2μs (0.962% faster)
    updated_call = get_call_node(out[0])

def test_edge_call_at_end_of_line():
    """Test a call at the very end of a line is matched and replaced."""
    code = "x = foo(13)"
    stmt = parse_stmt_with_positions(code)
    call_node = get_call_node(stmt)
    call_node.lineno = stmt.lineno
    call_node.col_offset = 4
    call_node.end_lineno = stmt.lineno
    call_node.end_col_offset = 12
    call_positions = [CodePosition(line_no=stmt.lineno, col_no=12, end_col_offset=12)]
    func = FunctionToOptimize("foo", "foo", is_async=False)
    inj = InjectPerfOnly(func, "module.py", "pytest", call_positions)
    codeflash_output = inj.find_and_update_line_node(stmt, "nodeL", "idx12"); out = codeflash_output # 11.3μs -> 11.1μs (1.80% faster)
    updated_call = get_call_node(out[0])

def test_edge_call_with_none_class_name():
    """Test that passing None for test_class_name works and is injected as None."""
    code = "result = foo(14)"
    stmt = parse_stmt_with_positions(code)
    call_node = get_call_node(stmt)
    call_node.lineno = stmt.lineno
    call_node.col_offset = 9
    call_node.end_lineno = stmt.lineno
    call_node.end_col_offset = 16
    call_positions = [CodePosition(line_no=stmt.lineno, col_no=10, end_col_offset=16)]
    func = FunctionToOptimize("foo", "foo", is_async=False)
    inj = InjectPerfOnly(func, "module.py", "pytest", call_positions)
    codeflash_output = inj.find_and_update_line_node(stmt, "nodeM", "idx13", test_class_name=None); out = codeflash_output # 11.4μs -> 11.2μs (1.71% faster)
    updated_call = get_call_node(out[0])

# ========== LARGE SCALE TEST CASES ==========

def test_large_scale_many_calls():
    """Test performance and correctness with many calls in a single function."""
    # Create a function with 500 calls to foo
    code_lines = [f"x{i} = foo({i})" for i in range(500)]
    code = "\n".join(code_lines)
    mod = ast.parse(code)
    for node in ast.walk(mod):
        ast.fix_missing_locations(node)
    # Set matching positions for all calls
    call_positions = []
    for i, stmt in enumerate(mod.body):
        call_node = get_call_node(stmt)
        call_node.lineno = stmt.lineno
        call_node.col_offset = len(f"x{i} = ")
        call_node.end_lineno = stmt.lineno
        call_node.end_col_offset = call_node.col_offset + len(f"foo({i})")
        call_positions.append(CodePosition(line_no=stmt.lineno, col_no=call_node.col_offset+1, end_col_offset=call_node.end_col_offset))
    func = FunctionToOptimize("foo", "foo", is_async=False)
    inj = InjectPerfOnly(func, "module.py", "pytest", call_positions)
    # Test each call is replaced
    for i, stmt in enumerate(mod.body):
        codeflash_output = inj.find_and_update_line_node(stmt, f"node{i}", f"idx{i}"); out = codeflash_output # 12.1ms -> 8.73ms (38.2% faster)
        updated_call = get_call_node(out[0])

def test_large_scale_many_non_matching_calls():
    """Test that with 500 calls but no matching positions, None is returned for each."""
    code_lines = [f"x{i} = foo({i})" for i in range(500)]
    code = "\n".join(code_lines)
    mod = ast.parse(code)
    for node in ast.walk(mod):
        ast.fix_missing_locations(node)
    # Provide call_positions that do NOT match any call
    call_positions = [CodePosition(line_no=1000, col_no=0, end_col_offset=10)]
    func = FunctionToOptimize("foo", "foo", is_async=False)
    inj = InjectPerfOnly(func, "module.py", "pytest", call_positions)
    for i, stmt in enumerate(mod.body):
        codeflash_output = inj.find_and_update_line_node(stmt, f"node{i}", f"idx{i}"); out = codeflash_output # 2.92ms -> 2.99ms (2.30% slower)

def test_large_scale_calls_with_keywords():
    """Test 100 calls with keyword arguments are all replaced and keywords are preserved."""
    code_lines = [f"x{i} = foo(a={i}, b={i+1})" for i in range(100)]
    code = "\n".join(code_lines)
    mod = ast.parse(code)
    for node in ast.walk(mod):
        ast.fix_missing_locations(node)
    call_positions = []
    for i, stmt in enumerate(mod.body):
        call_node = get_call_node(stmt)
        call_node.lineno = stmt.lineno
        call_node.col_offset = len(f"x{i} = ")
        call_node.end_lineno = stmt.lineno
        call_node.end_col_offset = call_node.col_offset + len(f"foo(a={i}, b={i+1})")
        call_positions.append(CodePosition(line_no=stmt.lineno, col_no=call_node.col_offset+1, end_col_offset=call_node.end_col_offset))
    func = FunctionToOptimize("foo", "foo", is_async=False)
    inj = InjectPerfOnly(func, "module.py", "pytest", call_positions)
    for i, stmt in enumerate(mod.body):
        codeflash_output = inj.find_and_update_line_node(stmt, f"node{i}", f"idx{i}"); out = codeflash_output # 1.12ms -> 987μs (13.0% faster)
        updated_call = get_call_node(out[0])

def test_large_scale_mixed_calls_and_non_calls():
    """Test a mix of function calls and assignments, only calls are replaced."""
    code_lines = []
    for i in range(250):
        code_lines.append(f"x{i} = foo({i})")
        code_lines.append(f"y{i} = {i}")
    code = "\n".join(code_lines)
    mod = ast.parse(code)
    for node in ast.walk(mod):
        ast.fix_missing_locations(node)
    call_positions = []
    for i in range(250):
        stmt = mod.body[i*2]
        call_node = get_call_node(stmt)
        call_node.lineno = stmt.lineno
        call_node.col_offset = len(f"x{i} = ")
        call_node.end_lineno = stmt.lineno
        call_node.end_col_offset = call_node.col_offset + len(f"foo({i})")
        call_positions.append(CodePosition(line_no=stmt.lineno, col_no=call_node.col_offset+1, end_col_offset=call_node.end_col_offset))
    func = FunctionToOptimize("foo", "foo", is_async=False)
    inj = InjectPerfOnly(func, "module.py", "pytest", call_positions)
    # Only even-numbered stmts should be replaced
    for i in range(250):
        call_stmt = mod.body[i*2]
        codeflash_output = inj.find_and_update_line_node(call_stmt, f"node{i}", f"idx{i}"); out = codeflash_output # 4.04ms -> 3.21ms (25.8% faster)
        updated_call = get_call_node(out[0])
        # Non-call stmts should return None
        non_call_stmt = mod.body[i*2+1]
        codeflash_output = inj.find_and_update_line_node(non_call_stmt, f"node{i}", f"idx{i}"); out2 = codeflash_output # 916μs -> 933μs (1.83% slower)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-pr769-2025-09-26T22.57.31 and push.

Codeflash

The optimization achieves a **24% speedup** by targeting two key performance bottlenecks identified in the line profiler results:

**1. Optimized `node_in_call_position` function (~22% faster):**
- **Reduced attribute lookups**: Pre-fetches `lineno`, `col_offset`, `end_lineno`, and `end_col_offset` once using `getattr()` instead of repeatedly calling `hasattr()` and accessing attributes in the loop
- **Early exit optimization**: Returns `False` immediately if not an `ast.Call` node, avoiding unnecessary work
- **Simplified conditional logic**: Combines nested checks into a single block to reduce Python opcode jumps

**2. Optimized `find_and_update_line_node` method (~18% faster):**
- **Cached attribute access**: Stores frequently accessed attributes (`self.function_object.function_name`, `self.mode`, etc.) in local variables to avoid repeated object attribute lookups
- **Efficient list construction**: Builds the `args` list incrementally using `extend()` instead of creating multiple intermediate lists with unpacking operators
- **Early termination**: Breaks immediately after finding and modifying the matching call node, avoiding unnecessary continuation of `ast.walk()`

**Performance gains are most significant for:**
- Large-scale test cases with many function calls (up to 38% faster for 500+ calls)
- Mixed workloads with calls and non-calls (25% faster)
- Tests with keyword arguments (13% faster)

The optimizations maintain identical behavior while reducing CPU-intensive operations like attribute lookups and list operations that dominate the execution time in AST transformation workflows.
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Sep 26, 2025
@KRRT7 KRRT7 closed this Sep 27, 2025
@codeflash-ai codeflash-ai bot deleted the codeflash/optimize-pr769-2025-09-26T22.57.31 branch September 27, 2025 00:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant