Skip to content

Conversation

@codeflash-ai
Copy link
Contributor

@codeflash-ai codeflash-ai bot commented Sep 23, 2025

⚡️ This pull request contains optimizations for PR #739

If you approve this dependent PR, these changes will be merged into the original PR branch get-throughput-from-output.

This PR will be automatically closed if the original PR is merged.


📄 50% (0.50x) speedup for AsyncCallInstrumenter.visit_AsyncFunctionDef in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime : 6.85 milliseconds 4.56 milliseconds (best of 103 runs)

📝 Explanation and details

The optimized code achieves a 50% speedup by replacing the expensive ast.walk() traversal with a targeted stack-based search in the new _instrument_statement_fast() method.

Key optimizations:

  1. Custom AST traversal replaces ast.walk(): The original code used ast.walk(stmt) which visits every node in the AST subtree. The optimized version uses a manual stack-based traversal that only looks for ast.Await nodes, significantly reducing the number of nodes examined.

  2. Early termination: Once an ast.Await node matching the target criteria is found, the search immediately breaks and returns, avoiding unnecessary traversal of remaining nodes.

  3. Optimized decorator checking: The any() generator expression is replaced with a simple for-loop that can exit early when a timeout decorator is found, though this provides minimal gains compared to the AST optimization.

Why this works so well:

  • ast.walk() performs a breadth-first traversal of all nodes in the AST subtree, which can be hundreds of nodes for complex statements
  • The optimized version only examines nodes that could potentially contain ast.Await expressions, dramatically reducing the search space
  • For large test functions with many statements (as shown in the annotated tests), this optimization scales particularly well - the 500+ await call test cases show 50-53% speedup

The optimization is most effective for test cases with:

  • Large numbers of async function calls (50%+ improvement)
  • Complex nested structures with few actual target calls (40%+ improvement)
  • Mixed await patterns where only some calls need instrumentation (35%+ improvement)

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 48 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
from __future__ import annotations

import ast
import os
from types import ModuleType

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import \
    AsyncCallInstrumenter


# Mocks for FunctionToOptimize, CodePosition, TestingMode
class CodePosition:
    def __init__(self, lineno, col_offset):
        self.lineno = lineno
        self.col_offset = col_offset

class Parent:
    def __init__(self, type_, top_level_parent_name):
        self.type = type_
        self.top_level_parent_name = top_level_parent_name

class FunctionToOptimize:
    def __init__(self, function_name, parents):
        self.function_name = function_name
        self.parents = parents
        self.top_level_parent_name = parents[0].top_level_parent_name if parents else None

# Helper to parse and instrument code, then return the modified AST
def instrument_code(src, instrumenter):
    tree = ast.parse(src)
    new_tree = instrumenter.visit(tree)
    ast.fix_missing_locations(new_tree)
    return new_tree

# -------------------- UNIT TESTS --------------------

# 1. Basic Test Cases

def test_non_test_function_is_unchanged():
    """Non-test async functions should not be instrumented or decorated."""
    src = "async def foo():\n    await target_func()"
    function = FunctionToOptimize("foo", [Parent("Module", None)])
    call_pos = [CodePosition(2, 10)]  # line 2, col 10
    instr = AsyncCallInstrumenter(function, "mod.py", "unittest", call_pos)
    tree = instrument_code(src, instr)
    # Should not have env assignment or decorator
    funcdef = next(n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef))

def test_test_function_instruments_env_assignment():
    """Test function should get env assignment before target call."""
    src = "async def test_bar():\n    await target_func()"
    function = FunctionToOptimize("test_bar", [Parent("Module", None)])
    call_pos = [CodePosition(2, 10)]
    instr = AsyncCallInstrumenter(function, "mod.py", "unittest", call_pos)
    tree = instrument_code(src, instr)
    funcdef = next(n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef))
    # Should have env assignment and timeout_decorator
    assigns = [n for n in funcdef.body if isinstance(n, ast.Assign)]

def test_test_function_multiple_target_calls():
    """Multiple target_func calls should get incrementing env assignments."""
    src = (
        "async def test_baz():\n"
        "    await target_func()\n"
        "    await target_func()"
    )
    function = FunctionToOptimize("test_baz", [Parent("Module", None)])
    call_pos = [CodePosition(2, 10), CodePosition(3, 10)]
    instr = AsyncCallInstrumenter(function, "mod.py", "unittest", call_pos)
    tree = instrument_code(src, instr)
    funcdef = next(n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef))
    assigns = [n for n in funcdef.body if isinstance(n, ast.Assign)]
    # Should have two assignments with values "0" and "1"
    values = [a.value.value for a in assigns]

def test_test_function_non_target_call():
    """Calls to other functions should not be instrumented."""
    src = (
        "async def test_qux():\n"
        "    await not_target_func()\n"
        "    await target_func()"
    )
    function = FunctionToOptimize("test_qux", [Parent("Module", None)])
    call_pos = [CodePosition(3, 10)]
    instr = AsyncCallInstrumenter(function, "mod.py", "unittest", call_pos)
    tree = instrument_code(src, instr)
    funcdef = next(n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef))
    assigns = [n for n in funcdef.body if isinstance(n, ast.Assign)]

def test_test_function_already_has_timeout_decorator():
    """Should not add timeout_decorator if already present."""
    src = (
        "async def test_decor():\n"
        "    await target_func()"
    )
    # Add timeout_decorator manually
    tree = ast.parse(src)
    funcdef = next(n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef))
    funcdef.decorator_list.append(
        ast.Call(
            func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
            args=[ast.Constant(value=15)],
            keywords=[],
        )
    )
    function = FunctionToOptimize("test_decor", [Parent("Module", None)])
    call_pos = [CodePosition(2, 10)]
    instr = AsyncCallInstrumenter(function, "mod.py", "unittest", call_pos)
    new_tree = instr.visit(tree)
    funcdef = next(n for n in ast.walk(new_tree) if isinstance(n, ast.AsyncFunctionDef))
    # Should only have one timeout_decorator
    decorators = [
        d for d in funcdef.decorator_list
        if isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "timeout_decorator.timeout"
    ]

# 2. Edge Test Cases

def test_test_function_empty_body():
    """Test function with no body should not fail or instrument."""
    src = "async def test_empty():\n    pass"
    function = FunctionToOptimize("test_empty", [Parent("Module", None)])
    call_pos = [CodePosition(2, 10)]
    instr = AsyncCallInstrumenter(function, "mod.py", "unittest", call_pos)
    tree = instrument_code(src, instr)
    funcdef = next(n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef))

def test_test_function_target_call_not_in_positions():
    """Target call not in positions should not be instrumented."""
    src = "async def test_noinst():\n    await target_func()"
    function = FunctionToOptimize("test_noinst", [Parent("Module", None)])
    call_pos = [CodePosition(100, 100)]  # wrong position
    instr = AsyncCallInstrumenter(function, "mod.py", "unittest", call_pos)
    tree = instrument_code(src, instr)
    funcdef = next(n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef))

def test_test_function_nested_await():
    """Nested await target_func should be instrumented if position matches."""
    src = (
        "async def test_nested():\n"
        "    if True:\n"
        "        await target_func()"
    )
    function = FunctionToOptimize("test_nested", [Parent("Module", None)])
    call_pos = [CodePosition(3, 14)]
    instr = AsyncCallInstrumenter(function, "mod.py", "unittest", call_pos)
    tree = instrument_code(src, instr)
    funcdef = next(n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef))
    assigns = [n for n in ast.walk(funcdef) if isinstance(n, ast.Assign)]

def test_test_function_with_other_decorators():
    """Should preserve other decorators and add timeout_decorator."""
    src = (
        "@other_decorator\n"
        "async def test_decor2():\n"
        "    await target_func()"
    )
    function = FunctionToOptimize("test_decor2", [Parent("Module", None)])
    call_pos = [CodePosition(3, 10)]
    instr = AsyncCallInstrumenter(function, "mod.py", "unittest", call_pos)
    tree = instrument_code(src, instr)
    funcdef = next(n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef))
    # Should have both decorators
    decorator_names = [
        d.func.id
        for d in funcdef.decorator_list
        if isinstance(d, ast.Call) and isinstance(d.func, ast.Name)
    ]

def test_test_function_with_different_framework():
    """Should not add timeout_decorator for pytest."""
    src = "async def test_pytest():\n    await target_func()"
    function = FunctionToOptimize("test_pytest", [Parent("Module", None)])
    call_pos = [CodePosition(2, 10)]
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_pos)
    tree = instrument_code(src, instr)
    funcdef = next(n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef))

def test_test_function_with_no_target_calls():
    """Test function with no target_func calls should not be instrumented."""
    src = "async def test_none():\n    await something_else()"
    function = FunctionToOptimize("test_none", [Parent("Module", None)])
    call_pos = [CodePosition(2, 10)]
    instr = AsyncCallInstrumenter(function, "mod.py", "unittest", call_pos)
    tree = instrument_code(src, instr)
    funcdef = next(n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef))
    assigns = [n for n in funcdef.body if isinstance(n, ast.Assign)]

def test_test_function_with_try_except():
    """Should instrument target_func inside try/except blocks."""
    src = (
        "async def test_try():\n"
        "    try:\n"
        "        await target_func()\n"
        "    except Exception:\n"
        "        pass"
    )
    function = FunctionToOptimize("test_try", [Parent("Module", None)])
    call_pos = [CodePosition(3, 14)]
    instr = AsyncCallInstrumenter(function, "mod.py", "unittest", call_pos)
    tree = instrument_code(src, instr)
    funcdef = next(n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef))
    assigns = [n for n in ast.walk(funcdef) if isinstance(n, ast.Assign)]

# 3. Large Scale Test Cases

def test_large_number_of_target_calls():
    """Instrument up to 1000 target_func calls and assignments."""
    src = "async def test_large():\n" + "\n".join(
        f"    await target_func()" for _ in range(1000)
    )
    function = FunctionToOptimize("test_large", [Parent("Module", None)])
    call_pos = [CodePosition(i + 2, 10) for i in range(1000)]
    instr = AsyncCallInstrumenter(function, "mod.py", "unittest", call_pos)
    tree = instrument_code(src, instr)
    funcdef = next(n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef))
    assigns = [n for n in funcdef.body if isinstance(n, ast.Assign)]

def test_large_number_of_non_target_calls():
    """Instrument only target_func calls among many other calls."""
    src = "async def test_sparse():\n" + "\n".join(
        f"    await {'target_func()' if i % 10 == 0 else 'other_func()'}"
        for i in range(1000)
    )
    function = FunctionToOptimize("test_sparse", [Parent("Module", None)])
    call_pos = [CodePosition(i + 2, 10) for i in range(0, 1000, 10)]
    instr = AsyncCallInstrumenter(function, "mod.py", "unittest", call_pos)
    tree = instrument_code(src, instr)
    funcdef = next(n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef))
    assigns = [n for n in funcdef.body if isinstance(n, ast.Assign)]

def test_large_test_function_no_target_calls():
    """Large function with no target_func calls should not be instrumented."""
    src = "async def test_none_large():\n" + "\n".join(
        f"    await other_func()" for _ in range(1000)
    )
    function = FunctionToOptimize("test_none_large", [Parent("Module", None)])
    call_pos = [CodePosition(i + 2, 10) for i in range(1000)]
    instr = AsyncCallInstrumenter(function, "mod.py", "unittest", call_pos)
    tree = instrument_code(src, instr)
    funcdef = next(n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef))
    assigns = [n for n in funcdef.body if isinstance(n, ast.Assign)]

def test_large_test_function_mixed_calls_and_positions():
    """Instrument only target_func calls at specified positions."""
    src = "async def test_mixed():\n" + "\n".join(
        f"    await {'target_func()' if i % 7 == 0 else 'other_func()'}"
        for i in range(1000)
    )
    # Only even multiples of 7
    call_pos = [CodePosition(i + 2, 10) for i in range(0, 1000, 14)]
    function = FunctionToOptimize("test_mixed", [Parent("Module", None)])
    instr = AsyncCallInstrumenter(function, "mod.py", "unittest", call_pos)
    tree = instrument_code(src, instr)
    funcdef = next(n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef))
    assigns = [n for n in funcdef.body if isinstance(n, ast.Assign)]
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import ast
import sys
from types import ModuleType

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import \
    AsyncCallInstrumenter


class FunctionToOptimize:
    def __init__(self, function_name, parents=None, top_level_parent_name=None):
        self.function_name = function_name
        self.parents = parents or []
        self.top_level_parent_name = top_level_parent_name

class TestingMode:
    BEHAVIOR = "BEHAVIOR"
    COVERAGE = "COVERAGE"

# Helper to parse code and get AsyncFunctionDef node
def get_async_func_node(src_code: str) -> ast.AsyncFunctionDef:
    tree = ast.parse(src_code)
    for node in ast.walk(tree):
        if isinstance(node, ast.AsyncFunctionDef):
            return node
    raise ValueError("No AsyncFunctionDef found")

# Helper to check if an env assignment was inserted before await calls to 'foo'
def has_env_assignment(node: ast.AsyncFunctionDef) -> bool:
    # Look for ast.Assign to os.environ["CODEFLASH_CURRENT_LINE_ID"]
    for stmt in node.body:
        if isinstance(stmt, ast.Assign):
            target = stmt.targets[0]
            if (
                isinstance(target, ast.Subscript)
                and isinstance(target.value, ast.Attribute)
                and target.value.attr == "environ"
                and isinstance(target.slice, ast.Constant)
                and target.slice.value == "CODEFLASH_CURRENT_LINE_ID"
            ):
                return True
    return False

# Helper to count env assignments
def count_env_assignments(node: ast.AsyncFunctionDef) -> int:
    count = 0
    for stmt in node.body:
        if isinstance(stmt, ast.Assign):
            target = stmt.targets[0]
            if (
                isinstance(target, ast.Subscript)
                and isinstance(target.value, ast.Attribute)
                and target.value.attr == "environ"
                and isinstance(target.slice, ast.Constant)
                and target.slice.value == "CODEFLASH_CURRENT_LINE_ID"
            ):
                count += 1
    return count

# Helper to check for timeout_decorator
def has_timeout_decorator(node: ast.AsyncFunctionDef) -> bool:
    for d in node.decorator_list:
        if (
            isinstance(d, ast.Call)
            and isinstance(d.func, ast.Name)
            and d.func.id == "timeout_decorator.timeout"
        ):
            return True
    return False

# Basic Test Cases

def test_non_test_function_is_unchanged():
    # Function name does not start with test_, should be unchanged
    src = "async def not_a_test():\n    await foo()"
    node = get_async_func_node(src)
    func = FunctionToOptimize("not_a_test")
    instr = AsyncCallInstrumenter(func, "mod.py", "pytest", [], TestingMode.BEHAVIOR)
    codeflash_output = instr.visit_AsyncFunctionDef(node); new_node = codeflash_output # 952ns -> 912ns (4.39% faster)

def test_test_function_with_await_foo_inserts_env_assignment():
    # Function named test_*, with await foo(), should insert env assignment
    src = "async def test_example():\n    await foo()"
    node = get_async_func_node(src)
    func = FunctionToOptimize("test_example")
    instr = AsyncCallInstrumenter(func, "mod.py", "pytest", [], TestingMode.BEHAVIOR)
    codeflash_output = instr.visit_AsyncFunctionDef(node); new_node = codeflash_output # 9.42μs -> 7.65μs (23.0% faster)

def test_test_function_with_multiple_await_foo_inserts_multiple_env_assignments():
    # Multiple await foo() calls, should insert env assignment before each
    src = "async def test_example():\n    await foo()\n    await foo()"
    node = get_async_func_node(src)
    func = FunctionToOptimize("test_example")
    instr = AsyncCallInstrumenter(func, "mod.py", "pytest", [], TestingMode.BEHAVIOR)
    codeflash_output = instr.visit_AsyncFunctionDef(node); new_node = codeflash_output # 14.1μs -> 10.3μs (36.7% faster)

def test_test_function_with_other_await_calls_does_not_instrument():
    # Only await foo() gets instrumented, await bar() does not
    src = "async def test_example():\n    await bar()\n    await foo()"
    node = get_async_func_node(src)
    func = FunctionToOptimize("test_example")
    instr = AsyncCallInstrumenter(func, "mod.py", "pytest", [], TestingMode.BEHAVIOR)
    codeflash_output = instr.visit_AsyncFunctionDef(node); new_node = codeflash_output # 13.9μs -> 10.2μs (36.2% faster)

def test_unittest_framework_adds_timeout_decorator():
    # If test_framework is 'unittest', should add timeout_decorator
    src = "async def test_example():\n    await foo()"
    node = get_async_func_node(src)
    func = FunctionToOptimize("test_example")
    instr = AsyncCallInstrumenter(func, "mod.py", "unittest", [], TestingMode.BEHAVIOR)
    codeflash_output = instr.visit_AsyncFunctionDef(node); new_node = codeflash_output # 12.8μs -> 10.1μs (26.7% faster)

def test_unittest_framework_does_not_duplicate_timeout_decorator():
    # If timeout_decorator already present, should not add another
    src = (
        "async def test_example():\n"
        "    await foo()\n"
    )
    node = get_async_func_node(src)
    # Manually add timeout_decorator
    timeout_decorator = ast.Call(
        func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
        args=[ast.Constant(value=15)],
        keywords=[],
    )
    node.decorator_list.append(timeout_decorator)
    func = FunctionToOptimize("test_example")
    instr = AsyncCallInstrumenter(func, "mod.py", "unittest", [], TestingMode.BEHAVIOR)
    codeflash_output = instr.visit_AsyncFunctionDef(node); new_node = codeflash_output # 10.6μs -> 7.36μs (43.9% faster)
    # Should only be one timeout_decorator
    count = sum(
        1 for d in new_node.decorator_list
        if isinstance(d, ast.Call)
        and isinstance(d.func, ast.Name)
        and d.func.id == "timeout_decorator.timeout"
    )

# Edge Test Cases

def test_test_function_with_no_body():
    # Function with empty body should not fail
    src = "async def test_example():\n    pass"
    node = get_async_func_node(src)
    func = FunctionToOptimize("test_example")
    instr = AsyncCallInstrumenter(func, "mod.py", "pytest", [], TestingMode.BEHAVIOR)
    codeflash_output = instr.visit_AsyncFunctionDef(node); new_node = codeflash_output # 4.75μs -> 3.16μs (50.5% faster)

def test_test_function_with_nested_await_foo_instruments_inner():
    # Await foo() inside an if block
    src = "async def test_example():\n    if True:\n        await foo()"
    node = get_async_func_node(src)
    func = FunctionToOptimize("test_example")
    instr = AsyncCallInstrumenter(func, "mod.py", "pytest", [], TestingMode.BEHAVIOR)
    codeflash_output = instr.visit_AsyncFunctionDef(node); new_node = codeflash_output # 11.4μs -> 9.35μs (21.5% faster)
    # Should have env assignment before await foo()
    found = False
    for stmt in new_node.body:
        if isinstance(stmt, ast.If):
            # Look inside the if block
            for inner in stmt.body:
                if isinstance(inner, ast.Assign):
                    found = True

def test_test_function_with_try_except_and_await_foo():
    # Await foo() inside try block
    src = (
        "async def test_example():\n"
        "    try:\n"
        "        await foo()\n"
        "    except Exception:\n"
        "        pass"
    )
    node = get_async_func_node(src)
    func = FunctionToOptimize("test_example")
    instr = AsyncCallInstrumenter(func, "mod.py", "pytest", [], TestingMode.BEHAVIOR)
    codeflash_output = instr.visit_AsyncFunctionDef(node); new_node = codeflash_output # 13.1μs -> 10.9μs (20.6% faster)
    # Should have env assignment inside try block
    found = False
    for stmt in new_node.body:
        if isinstance(stmt, ast.Try):
            for inner in stmt.body:
                if isinstance(inner, ast.Assign):
                    found = True

def test_test_function_with_multiple_env_assignments_has_correct_indices():
    # Each env assignment should have incrementing values
    src = "async def test_example():\n    await foo()\n    await foo()"
    node = get_async_func_node(src)
    func = FunctionToOptimize("test_example")
    instr = AsyncCallInstrumenter(func, "mod.py", "pytest", [], TestingMode.BEHAVIOR)
    codeflash_output = instr.visit_AsyncFunctionDef(node); new_node = codeflash_output # 13.7μs -> 9.99μs (37.2% faster)
    indices = []
    for stmt in new_node.body:
        if isinstance(stmt, ast.Assign):
            if (
                isinstance(stmt.value, ast.Constant)
                and stmt.value.value.isdigit()
            ):
                indices.append(int(stmt.value.value))

def test_test_function_with_non_await_foo_does_not_instrument():
    # 'foo()' not awaited, should not instrument
    src = "async def test_example():\n    foo()"
    node = get_async_func_node(src)
    func = FunctionToOptimize("test_example")
    instr = AsyncCallInstrumenter(func, "mod.py", "pytest", [], TestingMode.BEHAVIOR)
    codeflash_output = instr.visit_AsyncFunctionDef(node); new_node = codeflash_output # 7.47μs -> 5.88μs (27.1% faster)

def test_test_function_with_different_function_name_only_instruments_target():
    # Only target function name should be instrumented
    src = "async def test_example():\n    await foo()\n    await bar()"
    node = get_async_func_node(src)
    func = FunctionToOptimize("test_example")
    instr = AsyncCallInstrumenter(func, "mod.py", "pytest", [], TestingMode.BEHAVIOR)
    codeflash_output = instr.visit_AsyncFunctionDef(node); new_node = codeflash_output # 13.7μs -> 10.2μs (34.4% faster)

# Large Scale Test Cases

def test_large_test_function_with_many_await_foo():
    # Large function with many await foo() calls
    lines = ["async def test_example():"]
    for i in range(500):
        lines.append(f"    await foo()")
    src = "\n".join(lines)
    node = get_async_func_node(src)
    func = FunctionToOptimize("test_example")
    instr = AsyncCallInstrumenter(func, "mod.py", "pytest", [], TestingMode.BEHAVIOR)
    codeflash_output = instr.visit_AsyncFunctionDef(node); new_node = codeflash_output # 2.06ms -> 1.34ms (53.0% faster)

def test_large_test_function_with_mixed_await_calls():
    # Large function with mixed await foo() and await bar()
    lines = ["async def test_example():"]
    for i in range(250):
        lines.append(f"    await foo()")
        lines.append(f"    await bar()")
    src = "\n".join(lines)
    node = get_async_func_node(src)
    func = FunctionToOptimize("test_example")
    instr = AsyncCallInstrumenter(func, "mod.py", "pytest", [], TestingMode.BEHAVIOR)
    codeflash_output = instr.visit_AsyncFunctionDef(node); new_node = codeflash_output # 2.04ms -> 1.34ms (52.0% faster)

def test_large_test_function_with_nested_blocks_and_await_foo():
    # Large function with nested blocks, each with await foo()
    lines = ["async def test_example():"]
    for i in range(100):
        lines.append(f"    if True:")
        lines.append(f"        await foo()")
    src = "\n".join(lines)
    node = get_async_func_node(src)
    func = FunctionToOptimize("test_example")
    instr = AsyncCallInstrumenter(func, "mod.py", "pytest", [], TestingMode.BEHAVIOR)
    codeflash_output = instr.visit_AsyncFunctionDef(node); new_node = codeflash_output # 587μs -> 418μs (40.4% faster)
    # Should have 100 env assignments (one per block)
    count = 0
    for stmt in new_node.body:
        if isinstance(stmt, ast.If):
            for inner in stmt.body:
                if isinstance(inner, ast.Assign):
                    count += 1

def test_large_test_function_with_no_await_foo():
    # Large function, but no await foo(), should not instrument
    lines = ["async def test_example():"]
    for i in range(500):
        lines.append(f"    await bar()")
    src = "\n".join(lines)
    node = get_async_func_node(src)
    func = FunctionToOptimize("test_example")
    instr = AsyncCallInstrumenter(func, "mod.py", "pytest", [], TestingMode.BEHAVIOR)
    codeflash_output = instr.visit_AsyncFunctionDef(node); new_node = codeflash_output # 2.05ms -> 1.36ms (50.6% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-pr739-2025-09-23T04.24.41 and push.

Codeflash

The optimized code achieves a **50% speedup** by replacing the expensive `ast.walk()` traversal with a targeted stack-based search in the new `_instrument_statement_fast()` method.

**Key optimizations:**

1. **Custom AST traversal replaces `ast.walk()`**: The original code used `ast.walk(stmt)` which visits *every* node in the AST subtree. The optimized version uses a manual stack-based traversal that only looks for `ast.Await` nodes, significantly reducing the number of nodes examined.

2. **Early termination**: Once an `ast.Await` node matching the target criteria is found, the search immediately breaks and returns, avoiding unnecessary traversal of remaining nodes.

3. **Optimized decorator checking**: The `any()` generator expression is replaced with a simple for-loop that can exit early when a timeout decorator is found, though this provides minimal gains compared to the AST optimization.

**Why this works so well:**
- `ast.walk()` performs a breadth-first traversal of *all* nodes in the AST subtree, which can be hundreds of nodes for complex statements
- The optimized version only examines nodes that could potentially contain `ast.Await` expressions, dramatically reducing the search space
- For large test functions with many statements (as shown in the annotated tests), this optimization scales particularly well - the 500+ await call test cases show **50-53% speedup**

The optimization is most effective for test cases with:
- Large numbers of async function calls (50%+ improvement)
- Complex nested structures with few actual target calls (40%+ improvement) 
- Mixed await patterns where only some calls need instrumentation (35%+ improvement)
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Sep 23, 2025
@codeflash-ai codeflash-ai bot closed this Sep 24, 2025
@codeflash-ai
Copy link
Contributor Author

codeflash-ai bot commented Sep 24, 2025

This PR has been automatically closed because the original PR #739 by KRRT7 was closed.

@codeflash-ai codeflash-ai bot deleted the codeflash/optimize-pr739-2025-09-23T04.24.41 branch September 24, 2025 17:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants