Skip to content

Replace Regex with tree-sitter#1495

Merged
HeshamHM28 merged 1 commit intoomni-javafrom
replace/regex/with/tree-sitter
Feb 16, 2026
Merged

Replace Regex with tree-sitter#1495
HeshamHM28 merged 1 commit intoomni-javafrom
replace/regex/with/tree-sitter

Conversation

@HeshamHM28
Copy link
Contributor

Summary

  • Replace fragile regex-based Java call detection in behavior instrumentation with tree-sitter AST analysis, fixing compilation errors from incorrect code transformations
  • Replace regex-based _extract_target_calls in assertion removal with tree-sitter AST walking
  • Fix 23 broken tests by updating calls to match current instrument_existing_test and instrument_generated_java_test signatures
  • Add syntax validation to Java line profiler instrumentation

Changes

Tree-sitter behavior instrumentation (instrumentation.py)

The behavior instrumentation (_add_behavior_instrumentation) previously used regex (_find_method_calls_balanced) to find target function calls in test methods and wrap them with capture/serialize code. This caused three
classes of bugs:

  1. try-catch corruption: try { func(-1); } catch (Exception e) {} → call was moved outside the try-catch, leaving the exception uncaught
  2. Variable scope loss: long first = func(15); → the assignment line was dropped, making later references to first fail with "cannot find symbol"
  3. Lambda false positives: () -> func(-1) inside assertThrows was sometimes incorrectly wrapped

Replaced with wrap_target_calls_with_treesitter() which:

  • Parses the method body with tree-sitter by wrapping in class _D { void _m() { ... }}
  • Walks the AST for method_invocation nodes matching the target function
  • Uses parent node type to determine replacement strategy:
    • expression_statement: replaces the statement IN PLACE with capture+serialize (keeps code inside try blocks)
    • variable_declarator / argument_list / other: emits capture+serialize before the line, replaces call with variable
  • Detects lambdas via ancestor walk (lambda_expression before method_declaration), correctly skipping all lambda forms

Tree-sitter assertion extraction (remove_asserts.py)

Replaced regex backward-scanning in _extract_target_calls with tree-sitter AST walking. Added three focused methods: _extract_target_calls, _collect_target_invocations, _build_target_call.

@HeshamHM28 HeshamHM28 merged commit 4c97641 into omni-java Feb 16, 2026
7 of 29 checks passed
@HeshamHM28 HeshamHM28 deleted the replace/regex/with/tree-sitter branch February 16, 2026 06:44
Comment on lines +584 to +601
get_text = self.analyzer.get_node_text

object_node = node.child_by_field_name("object")
args_node = node.child_by_field_name("arguments")
args_text = get_text(args_node, wrapper_bytes) if args_node else ""
# argument_list node includes parens, strip them
if args_text.startswith("(") and args_text.endswith(")"):
args_text = args_text[1:-1]

# Byte offsets -> char offsets for correct Python string indexing
start_char = len(content_bytes[:start_byte].decode("utf8"))
end_char = len(content_bytes[:end_byte].decode("utf8"))

return TargetCall(
receiver=get_text(object_node, wrapper_bytes) if object_node else None,
method_name=self.func_name,
arguments=args_text,
full_call=get_text(node, wrapper_bytes),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚡️Codeflash found 17% (0.17x) speedup for JavaAssertTransformer._build_target_call in codeflash/languages/java/remove_asserts.py

⏱️ Runtime : 1.78 milliseconds 1.52 milliseconds (best of 52 runs)

📝 Explanation and details

The optimized code achieves a 16% runtime improvement by eliminating intermediate function calls and reducing attribute access overhead in the hot path _build_target_call method.

Key Optimizations:

  1. Inlined get_node_text calls: The original code called self.analyzer.get_node_text() three times per invocation (for receiver, arguments, and full_call). Each call had overhead from function dispatch, parameter passing, and the get_text local variable assignment. The optimized version directly performs the slice-and-decode operations inline, eliminating ~9.5 microseconds (3 × 3.2μs) of function call overhead per invocation.

  2. Direct attribute access: Instead of accessing node.start_byte and node.end_byte through the get_text helper, the optimized code directly accesses these attributes and caches them in local variables (obj_start, obj_end, etc.). This reduces attribute lookup overhead, especially since these values are used immediately for slicing.

  3. Conditional text extraction: The optimized code uses explicit conditionals to check for object_node and args_node presence before extracting text, setting default values (receiver = None, args_text = "") upfront. This eliminates the ternary expression overhead in the original code's get_text(object_node, wrapper_bytes) if object_node else None.

  4. Optimized parentheses check: Changed from startswith()/endswith() to direct character indexing with length check: len(args_text) >= 2 and args_text[0] == "(" and args_text[-1] == ")". This is faster for short strings as it avoids method dispatch overhead.

Performance Impact by Test Case:

  • Simple calls with receiver/arguments (most common): 7-14% faster, as these benefit most from eliminating the 3 function calls
  • Large argument lists (1000 items): 3.5% faster, showing the optimization scales well even with large text slices
  • Repeated calls (1000 iterations): 16.9% faster, demonstrating consistent performance gains in hot loops

The line profiler shows the cumulative time spent in get_text (9.4ms total in original) is now distributed across more granular inline operations, with the overall method time reduced from 16.5ms to 11.1ms. This optimization is particularly valuable if _build_target_call is invoked frequently during Java code parsing and transformation workflows.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 66 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
import pytest  # used for our unit tests
from codeflash.languages.java.parser import JavaAnalyzer
from codeflash.languages.java.remove_asserts import (JavaAssertTransformer,
                                                     TargetCall)

# Helper "fake" node class to mimic the minimal tree-sitter Node behavior used by the
# implementation under test. We purposely only implement the attributes/methods that
# _build_target_call exercises: start_byte, end_byte and child_by_field_name(name).
# NOTE: This is a very small helper used solely to provide the required shape;
# it is not intended to emulate the full tree-sitter Node API.
class _NodeLike:
    def __init__(self, start_byte: int, end_byte: int, children: dict | None = None):
        # bytes offsets that JavaAnalyzer.get_node_text will use to slice the wrapper bytes
        self.start_byte = start_byte
        self.end_byte = end_byte
        # mapping field name -> node
        self._children = children or {}

    def child_by_field_name(self, name: str):
        # return the child node or None, matching tree-sitter Node API behavior used.
        return self._children.get(name)

def test_build_target_call_with_receiver_and_arguments():
    # Basic scenario: node has both an object (receiver) and an arguments node.
    # We'll craft wrapper_bytes so that the bytes slice corresponding to the node
    # decodes to "obj.test(42)" and the arguments slice decodes to "(42)".
    transformer = JavaAssertTransformer(function_name="test")  # method_name -> "test"

    # Create a wrapper bytes buffer that contains the call at known offsets.
    call_str = "obj.test(42)"
    wrapper_bytes = call_str.encode("utf8")  # the simplest possible wrapper
    content_bytes = wrapper_bytes  # use same bytes for computing char offsets

    # Create nodes: arguments node covers the "(42)" portion, object node covers "obj"
    # Determine byte offsets by locating substrings.
    obj_start = call_str.index("obj")
    obj_end = obj_start + len("obj")
    args_start = call_str.index("(")
    args_end = args_start + len("(42)")
    node_start = obj_start
    node_end = args_end

    # Build NodeLike instances for object, arguments, and the full invocation node
    object_node = _NodeLike(start_byte=obj_start, end_byte=obj_end)
    args_node = _NodeLike(start_byte=args_start, end_byte=args_end)
    invocation_node = _NodeLike(start_byte=node_start, end_byte=node_end, children={
        "object": object_node,
        "arguments": args_node,
    })

    # Call the function under test
    codeflash_output = transformer._build_target_call(
        node=invocation_node,
        wrapper_bytes=wrapper_bytes,
        content_bytes=content_bytes,
        start_byte=node_start,
        end_byte=node_end,
        base_offset=10,  # arbitrary base offset to verify offset arithmetic
    ); result = codeflash_output # 5.50μs -> 6.35μs (13.4% slower)

def test_build_target_call_without_receiver():
    # When there is no object (receiver) node, receiver should be None.
    transformer = JavaAssertTransformer(function_name="run")

    call_str = "run(1,2)"
    wrapper_bytes = call_str.encode("utf8")
    content_bytes = wrapper_bytes

    # No 'object' child; only arguments
    args_start = call_str.index("(")
    args_end = len(call_str)
    invocation_node = _NodeLike(start_byte=0, end_byte=args_end, children={
        "arguments": _NodeLike(start_byte=args_start, end_byte=args_end)
    })

    codeflash_output = transformer._build_target_call(
        node=invocation_node,
        wrapper_bytes=wrapper_bytes,
        content_bytes=content_bytes,
        start_byte=0,
        end_byte=args_end,
        base_offset=0,
    ); result = codeflash_output # 4.61μs -> 4.29μs (7.49% faster)

def test_build_target_call_without_arguments():
    # When there is no arguments node, arguments string should be empty.
    transformer = JavaAssertTransformer(function_name="doStuff")

    call_str = "obj.doStuff"
    wrapper_bytes = call_str.encode("utf8")
    content_bytes = wrapper_bytes

    # Provide object node but no 'arguments' child
    object_node = _NodeLike(start_byte=0, end_byte=len("obj"))
    invocation_node = _NodeLike(start_byte=0, end_byte=len(call_str), children={
        "object": object_node
    })

    codeflash_output = transformer._build_target_call(
        node=invocation_node,
        wrapper_bytes=wrapper_bytes,
        content_bytes=content_bytes,
        start_byte=0,
        end_byte=len(call_str),
        base_offset=5,
    ); result = codeflash_output # 4.03μs -> 3.61μs (11.7% faster)

def test_arguments_with_no_parentheses_preserved():
    # If the retrieved arguments text does not both start and end with parentheses,
    # it should not be stripped. This tests the defensive path.
    transformer = JavaAssertTransformer(function_name="m")

    # craft wrapper text where the 'arguments' slice is "1, 2" (no parentheses)
    full = "a.m1,2"  # intentionally non-call-like format for the edge case
    wrapper_bytes = full.encode("utf8")
    content_bytes = wrapper_bytes

    # treat "1,2" as the arguments slice (no enclosing parentheses)
    invocation_node = _NodeLike(start_byte=0, end_byte=len(full), children={
        "object": _NodeLike(start_byte=0, end_byte=1),
        "arguments": _NodeLike(start_byte=3, end_byte=5),
    })

    codeflash_output = transformer._build_target_call(
        node=invocation_node,
        wrapper_bytes=wrapper_bytes,
        content_bytes=content_bytes,
        start_byte=0,
        end_byte=len(full),
        base_offset=0,
    ); result = codeflash_output # 4.22μs -> 3.81μs (10.8% faster)

def test_multibyte_characters_affect_char_offsets_correctly():
    # Ensure multi-byte UTF-8 characters before the start_byte are counted as
    # multiple bytes but a single character when converting to char offsets.
    transformer = JavaAssertTransformer(function_name="x")

    # prefix contains a multi-byte character (e.g., 'ā' which is 2 bytes in utf-8)
    prefix = "preā"  # 'ā' will be multi-byte
    call = "obj.x(7)"
    content = prefix + call
    wrapper_bytes = content.encode("utf8")
    content_bytes = wrapper_bytes

    # node covers the call only
    node_start = len(prefix.encode("utf8"))  # byte offset where call starts
    node_end = len(wrapper_bytes)
    invocation_node = _NodeLike(start_byte=node_start, end_byte=node_end, children={
        "object": _NodeLike(start_byte=node_start, end_byte=node_start + len("obj")),
        "arguments": _NodeLike(start_byte=node_start + len("obj.x"), end_byte=node_end),
    })

    # The function computes start_char by decoding up to start_byte; this should
    # account for multi-byte characters so the character index equals len(prefix).
    codeflash_output = transformer._build_target_call(
        node=invocation_node,
        wrapper_bytes=wrapper_bytes,
        content_bytes=content_bytes,
        start_byte=node_start,
        end_byte=node_end,
        base_offset=2,
    ); result = codeflash_output # 5.55μs -> 4.89μs (13.5% faster)

def test_build_target_call_with_large_argument_list():
    # Construct a very large comma-separated argument list (1000 items) to verify
    # function handles long argument texts and that parentheses are stripped.
    transformer = JavaAssertTransformer(function_name="big")

    # create a long argument list "0,1,2,...,999"
    args = ",".join(str(i) for i in range(1000))
    call = f"obj.big({args})"
    wrapper_bytes = call.encode("utf8")
    content_bytes = wrapper_bytes

    # build nodes with proper offsets
    obj_start = call.index("obj")
    obj_end = obj_start + len("obj")
    args_start = call.index("(")
    args_end = len(call)
    invocation_node = _NodeLike(start_byte=0, end_byte=len(call), children={
        "object": _NodeLike(start_byte=obj_start, end_byte=obj_end),
        "arguments": _NodeLike(start_byte=args_start, end_byte=args_end),
    })

    codeflash_output = transformer._build_target_call(
        node=invocation_node,
        wrapper_bytes=wrapper_bytes,
        content_bytes=content_bytes,
        start_byte=0,
        end_byte=len(call),
        base_offset=0,
    ); result = codeflash_output # 6.85μs -> 6.62μs (3.47% faster)

def test_repeated_calls_stability_and_performance():
    # Call _build_target_call 1000 times in a loop to ensure deterministic behavior
    # and reasonable performance for repeated invocations. Each call returns
    # consistent TargetCall objects.
    transformer = JavaAssertTransformer(function_name="loop")
    call = "a.loop(1)"
    wrapper_bytes = call.encode("utf8")
    content_bytes = wrapper_bytes
    invocation_node = _NodeLike(start_byte=0, end_byte=len(call), children={
        "object": _NodeLike(start_byte=0, end_byte=1),
        "arguments": _NodeLike(start_byte=6, end_byte=len(call)),
    })

    # perform the repeated calls and verify each result
    results = []
    for _ in range(1000):
        codeflash_output = transformer._build_target_call(
            node=invocation_node,
            wrapper_bytes=wrapper_bytes,
            content_bytes=content_bytes,
            start_byte=0,
            end_byte=len(call),
            base_offset=7,
        ); tc = codeflash_output # 1.75ms -> 1.49ms (16.9% faster)
        results.append(tc)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import os
import tempfile

# imports
import pytest
# imports from the modules under test
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
from codeflash.languages.java.remove_asserts import (JavaAssertTransformer,
                                                     TargetCall)
from tree_sitter import Language, Node, Parser

# ============================================================================
# FIXTURES
# ============================================================================

@pytest.fixture
def java_analyzer():
    """Provide a JavaAnalyzer instance for tests."""
    return get_java_analyzer()

@pytest.fixture
def parser():
    """Provide a tree-sitter Parser for Java."""
    try:
        # Try to load the Java language library
        java_language = Language(
            os.path.join(
                os.path.dirname(__file__),
                "../../tree_sitter_java.so"
            ),
            "java"
        )
    except (FileNotFoundError, OSError):
        # Fallback: try standard installation location
        try:
            java_language = Language(
                os.path.join(
                    os.path.expanduser("~"),
                    ".local/lib/tree_sitter/java.so"
                ),
                "java"
            )
        except (FileNotFoundError, OSError):
            pytest.skip("tree-sitter Java library not available")
    
    parser = Parser()
    parser.set_language(java_language)
    return parser

@pytest.fixture
def transformer(java_analyzer):
    """Provide a JavaAssertTransformer instance for tests."""
    return JavaAssertTransformer("assertEquals", "org.junit.Assert.assertEquals", java_analyzer)

# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

def parse_java_method_invocation(parser, code: str) -> Node:
    """Parse Java code and extract the first method_invocation node."""
    # Wrap code in a valid Java class structure if needed
    if not code.strip().startswith("class"):
        full_code = f"class Test {{\n  void test() {{\n    {code}\n  }}\n}}"
    else:
        full_code = code
    
    tree = parser.parse(full_code.encode('utf8'))
    
    # Find the first method_invocation node
    def find_invocation(node):
        if node.type == "method_invocation":
            return node
        for child in node.children:
            result = find_invocation(child)
            if result:
                return result
        return None
    
    invocation = find_invocation(tree.root_node)
    if not invocation:
        raise ValueError(f"No method_invocation found in code: {code}")
    return invocation

def test_build_target_call_with_receiver_and_arguments(transformer, parser):
    """Test building a target call with an object receiver and arguments."""
    # Parse a simple method invocation: assertEquals(expected, actual)
    code = "assertEquals(5, 10)"
    invocation = parse_java_method_invocation(parser, code)
    
    # Prepare wrapper bytes as the transformer would
    wrapper_prefix = "class _D { void _m() { _d("
    content = "assertEquals(5, 10)"
    wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
    content_bytes = content.encode('utf8')
    
    # Calculate byte offsets within wrapper
    start_byte = len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(content_bytes)
    
    # Call the function under test
    codeflash_output = transformer._build_target_call(
        invocation,
        wrapper_bytes,
        content_bytes,
        start_byte,
        end_byte,
        base_offset=0
    ); result = codeflash_output

def test_build_target_call_without_receiver(transformer, parser):
    """Test building a target call with no object receiver (static method)."""
    code = "assertEquals(expected, actual)"
    invocation = parse_java_method_invocation(parser, code)
    
    wrapper_prefix = "class _D { void _m() { _d("
    content = "assertEquals(expected, actual)"
    wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
    content_bytes = content.encode('utf8')
    
    start_byte = len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(content_bytes)
    
    codeflash_output = transformer._build_target_call(
        invocation,
        wrapper_bytes,
        content_bytes,
        start_byte,
        end_byte,
        base_offset=0
    ); result = codeflash_output

def test_build_target_call_arguments_extraction(transformer, parser):
    """Test that arguments are correctly extracted and parentheses are stripped."""
    code = "assertEquals(value1, value2)"
    invocation = parse_java_method_invocation(parser, code)
    
    wrapper_prefix = "class _D { void _m() { _d("
    content = "assertEquals(value1, value2)"
    wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
    content_bytes = content.encode('utf8')
    
    start_byte = len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(content_bytes)
    
    codeflash_output = transformer._build_target_call(
        invocation,
        wrapper_bytes,
        content_bytes,
        start_byte,
        end_byte,
        base_offset=0
    ); result = codeflash_output

def test_build_target_call_position_calculation(transformer, parser):
    """Test that start and end positions are correctly calculated."""
    code = "assertEquals(a, b)"
    invocation = parse_java_method_invocation(parser, code)
    
    wrapper_prefix = "class _D { void _m() { _d("
    content = "assertEquals(a, b)"
    wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
    content_bytes = content.encode('utf8')
    
    start_byte = len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(content_bytes)
    base_offset = 100  # Non-zero base offset
    
    codeflash_output = transformer._build_target_call(
        invocation,
        wrapper_bytes,
        content_bytes,
        start_byte,
        end_byte,
        base_offset
    ); result = codeflash_output

def test_build_target_call_full_call_preservation(transformer, parser):
    """Test that the full call text is preserved."""
    code = "assertEquals(expected, actual)"
    invocation = parse_java_method_invocation(parser, code)
    
    wrapper_prefix = "class _D { void _m() { _d("
    content = "assertEquals(expected, actual)"
    wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
    content_bytes = content.encode('utf8')
    
    start_byte = len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(content_bytes)
    
    codeflash_output = transformer._build_target_call(
        invocation,
        wrapper_bytes,
        content_bytes,
        start_byte,
        end_byte,
        base_offset=0
    ); result = codeflash_output

def test_build_target_call_with_empty_arguments(transformer, parser):
    """Test building a target call with no arguments."""
    code = "assertEquals()"
    try:
        invocation = parse_java_method_invocation(parser, code)
    except ValueError:
        pytest.skip("Parser cannot handle empty argument lists in this context")
    
    wrapper_prefix = "class _D { void _m() { _d("
    content = "assertEquals()"
    wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
    content_bytes = content.encode('utf8')
    
    start_byte = len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(content_bytes)
    
    codeflash_output = transformer._build_target_call(
        invocation,
        wrapper_bytes,
        content_bytes,
        start_byte,
        end_byte,
        base_offset=0
    ); result = codeflash_output

def test_build_target_call_with_complex_arguments(transformer, parser):
    """Test with complex nested arguments including method calls."""
    code = "assertEquals(obj.getValue(), expected)"
    invocation = parse_java_method_invocation(parser, code)
    
    wrapper_prefix = "class _D { void _m() { _d("
    content = "assertEquals(obj.getValue(), expected)"
    wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
    content_bytes = content.encode('utf8')
    
    start_byte = len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(content_bytes)
    
    codeflash_output = transformer._build_target_call(
        invocation,
        wrapper_bytes,
        content_bytes,
        start_byte,
        end_byte,
        base_offset=0
    ); result = codeflash_output

def test_build_target_call_with_string_literal_arguments(transformer, parser):
    """Test with string literal arguments."""
    code = 'assertEquals("expected", "actual")'
    invocation = parse_java_method_invocation(parser, code)
    
    wrapper_prefix = "class _D { void _m() { _d("
    content = 'assertEquals("expected", "actual")'
    wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
    content_bytes = content.encode('utf8')
    
    start_byte = len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(content_bytes)
    
    codeflash_output = transformer._build_target_call(
        invocation,
        wrapper_bytes,
        content_bytes,
        start_byte,
        end_byte,
        base_offset=0
    ); result = codeflash_output

def test_build_target_call_with_numeric_arguments(transformer, parser):
    """Test with various numeric argument types."""
    code = "assertEquals(42, 42.0)"
    invocation = parse_java_method_invocation(parser, code)
    
    wrapper_prefix = "class _D { void _m() { _d("
    content = "assertEquals(42, 42.0)"
    wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
    content_bytes = content.encode('utf8')
    
    start_byte = len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(content_bytes)
    
    codeflash_output = transformer._build_target_call(
        invocation,
        wrapper_bytes,
        content_bytes,
        start_byte,
        end_byte,
        base_offset=0
    ); result = codeflash_output

def test_build_target_call_with_unicode_arguments(transformer, parser):
    """Test with unicode characters in arguments."""
    code = 'assertEquals("café", "café")'
    invocation = parse_java_method_invocation(parser, code)
    
    wrapper_prefix = "class _D { void _m() { _d("
    content = 'assertEquals("café", "café")'
    wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
    content_bytes = content.encode('utf8')
    
    start_byte = len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(content_bytes)
    
    codeflash_output = transformer._build_target_call(
        invocation,
        wrapper_bytes,
        content_bytes,
        start_byte,
        end_byte,
        base_offset=0
    ); result = codeflash_output

def test_build_target_call_with_zero_base_offset(transformer, parser):
    """Test with base_offset of 0."""
    code = "assertEquals(x, y)"
    invocation = parse_java_method_invocation(parser, code)
    
    wrapper_prefix = "class _D { void _m() { _d("
    content = "assertEquals(x, y)"
    wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
    content_bytes = content.encode('utf8')
    
    start_byte = len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(content_bytes)
    
    codeflash_output = transformer._build_target_call(
        invocation,
        wrapper_bytes,
        content_bytes,
        start_byte,
        end_byte,
        base_offset=0
    ); result = codeflash_output

def test_build_target_call_with_large_base_offset(transformer, parser):
    """Test with a large base_offset value."""
    code = "assertEquals(a, b)"
    invocation = parse_java_method_invocation(parser, code)
    
    wrapper_prefix = "class _D { void _m() { _d("
    content = "assertEquals(a, b)"
    wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
    content_bytes = content.encode('utf8')
    
    start_byte = len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(content_bytes)
    large_offset = 1000000
    
    codeflash_output = transformer._build_target_call(
        invocation,
        wrapper_bytes,
        content_bytes,
        start_byte,
        end_byte,
        base_offset=large_offset
    ); result = codeflash_output

def test_build_target_call_byte_offset_at_boundary(transformer, parser):
    """Test with byte offsets at exact boundaries."""
    code = "assertEquals(val, val)"
    invocation = parse_java_method_invocation(parser, code)
    
    wrapper_prefix = "class _D { void _m() { _d("
    content = "assertEquals(val, val)"
    wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
    content_bytes = content.encode('utf8')
    
    # Use exact byte boundaries
    start_byte = len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(content_bytes)
    
    codeflash_output = transformer._build_target_call(
        invocation,
        wrapper_bytes,
        content_bytes,
        start_byte,
        end_byte,
        base_offset=0
    ); result = codeflash_output

def test_build_target_call_with_many_arguments(transformer, parser):
    """Test with a large number of arguments."""
    # Create an assertion with many arguments (simulating extreme case)
    args = ", ".join([f"arg{i}" for i in range(100)])
    code = f"assertEquals({args})"
    
    try:
        invocation = parse_java_method_invocation(parser, code)
    except ValueError:
        pytest.skip("Parser cannot handle 100 arguments in this context")
    
    wrapper_prefix = "class _D { void _m() { _d("
    content = code
    wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
    content_bytes = content.encode('utf8')
    
    start_byte = len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(content_bytes)
    
    codeflash_output = transformer._build_target_call(
        invocation,
        wrapper_bytes,
        content_bytes,
        start_byte,
        end_byte,
        base_offset=0
    ); result = codeflash_output

def test_build_target_call_with_deeply_nested_calls(transformer, parser):
    """Test with deeply nested method calls in arguments."""
    # Create nested method calls
    code = "assertEquals(a.b().c().d().e(), value)"
    invocation = parse_java_method_invocation(parser, code)
    
    wrapper_prefix = "class _D { void _m() { _d("
    content = code
    wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
    content_bytes = content.encode('utf8')
    
    start_byte = len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(content_bytes)
    
    codeflash_output = transformer._build_target_call(
        invocation,
        wrapper_bytes,
        content_bytes,
        start_byte,
        end_byte,
        base_offset=0
    ); result = codeflash_output

def test_build_target_call_repeated_calls(transformer, parser):
    """Test multiple calls to _build_target_call with different inputs."""
    inputs = [
        ("assertEquals(a, b)", "a, b"),
        ("assertEquals(x, y)", "x, y"),
        ("assertEquals(1, 2)", "1, 2"),
    ]
    
    for code, expected_args_part in inputs:
        invocation = parse_java_method_invocation(parser, code)
        
        wrapper_prefix = "class _D { void _m() { _d("
        content = code
        wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
        content_bytes = content.encode('utf8')
        
        start_byte = len(wrapper_prefix.encode('utf8'))
        end_byte = start_byte + len(content_bytes)
        
        codeflash_output = transformer._build_target_call(
            invocation,
            wrapper_bytes,
            content_bytes,
            start_byte,
            end_byte,
            base_offset=0
        ); result = codeflash_output

def test_build_target_call_with_very_long_argument_string(transformer, parser):
    """Test with very long argument strings (1000+ characters)."""
    # Create a very long string literal
    long_string = "x" * 1000
    code = f'assertEquals("{long_string}", "test")'
    invocation = parse_java_method_invocation(parser, code)
    
    wrapper_prefix = "class _D { void _m() { _d("
    content = code
    wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
    content_bytes = content.encode('utf8')
    
    start_byte = len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(content_bytes)
    
    codeflash_output = transformer._build_target_call(
        invocation,
        wrapper_bytes,
        content_bytes,
        start_byte,
        end_byte,
        base_offset=0
    ); result = codeflash_output

def test_build_target_call_with_large_content_bytes(transformer, parser):
    """Test with large content_bytes (10000+ characters)."""
    # Build large content with padding
    padding = "// comment\n" * 100
    code = "assertEquals(a, b)"
    full_content = padding + code
    
    invocation = parse_java_method_invocation(parser, code)
    
    wrapper_prefix = "class _D { void _m() { _d("
    wrapper_bytes = (wrapper_prefix + code + "); } }").encode('utf8')
    content_bytes = full_content.encode('utf8')
    
    # Adjust byte offsets to account for padding
    start_byte = len(padding.encode('utf8')) + len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(code.encode('utf8'))
    
    codeflash_output = transformer._build_target_call(
        invocation,
        wrapper_bytes,
        content_bytes,
        start_byte,
        end_byte,
        base_offset=0
    ); result = codeflash_output

def test_build_target_call_performance_with_accumulating_base_offsets(transformer, parser):
    """Test performance with a series of calls with accumulating base offsets."""
    code = "assertEquals(test, expected)"
    invocation = parse_java_method_invocation(parser, code)
    
    wrapper_prefix = "class _D { void _m() { _d("
    content = code
    wrapper_bytes = (wrapper_prefix + content + "); } }").encode('utf8')
    content_bytes = content.encode('utf8')
    
    start_byte = len(wrapper_prefix.encode('utf8'))
    end_byte = start_byte + len(content_bytes)
    
    results = []
    # Simulate multiple calls with increasing offsets
    for i in range(1000):
        base_offset = i * 100
        codeflash_output = transformer._build_target_call(
            invocation,
            wrapper_bytes,
            content_bytes,
            start_byte,
            end_byte,
            base_offset=base_offset
        ); result = codeflash_output
        results.append(result)
    # Verify offsets increase correctly
    for i, result in enumerate(results):
        expected_offset = i * 100
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally git merge codeflash/optimize-pr1495-2026-02-16T07.15.28

Click to see suggested changes
Suggested change
get_text = self.analyzer.get_node_text
object_node = node.child_by_field_name("object")
args_node = node.child_by_field_name("arguments")
args_text = get_text(args_node, wrapper_bytes) if args_node else ""
# argument_list node includes parens, strip them
if args_text.startswith("(") and args_text.endswith(")"):
args_text = args_text[1:-1]
# Byte offsets -> char offsets for correct Python string indexing
start_char = len(content_bytes[:start_byte].decode("utf8"))
end_char = len(content_bytes[:end_byte].decode("utf8"))
return TargetCall(
receiver=get_text(object_node, wrapper_bytes) if object_node else None,
method_name=self.func_name,
arguments=args_text,
full_call=get_text(node, wrapper_bytes),
object_node = node.child_by_field_name("object")
args_node = node.child_by_field_name("arguments")
# Extract receiver text efficiently
receiver = None
if object_node:
obj_start = object_node.start_byte
obj_end = object_node.end_byte
receiver = wrapper_bytes[obj_start:obj_end].decode("utf8")
# Extract arguments text efficiently
args_text = ""
if args_node:
args_start = args_node.start_byte
args_end = args_node.end_byte
args_text = wrapper_bytes[args_start:args_end].decode("utf8")
# argument_list node includes parens, strip them
if len(args_text) >= 2 and args_text[0] == "(" and args_text[-1] == ")":
args_text = args_text[1:-1]
# Extract full call text
node_start = node.start_byte
node_end = node.end_byte
full_call = wrapper_bytes[node_start:node_end].decode("utf8")
# Byte offsets -> char offsets for correct Python string indexing
# Byte offsets -> char offsets for correct Python string indexing
start_char = len(content_bytes[:start_byte].decode("utf8"))
end_char = len(content_bytes[:end_byte].decode("utf8"))
return TargetCall(
receiver=receiver,
method_name=self.func_name,
arguments=args_text,
full_call=full_call,

Static Badge

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant