Skip to content

⚡️ Speed up method JavaAssertTransformer._infer_type_from_assertion_args by 32% in PR #1655 (feat/add/void/func)#1765

Closed
codeflash-ai[bot] wants to merge 1 commit intofeat/add/void/funcfrom
codeflash/optimize-pr1655-2026-03-04T05.37.57
Closed

⚡️ Speed up method JavaAssertTransformer._infer_type_from_assertion_args by 32% in PR #1655 (feat/add/void/func)#1765
codeflash-ai[bot] wants to merge 1 commit intofeat/add/void/funcfrom
codeflash/optimize-pr1655-2026-03-04T05.37.57

Conversation

@codeflash-ai
Copy link
Contributor

@codeflash-ai codeflash-ai bot commented Mar 4, 2026

⚡️ This pull request contains optimizations for PR #1655

If you approve this dependent PR, these changes will be merged into the original PR branch feat/add/void/func.

This PR will be automatically closed if the original PR is merged.


📄 32% (0.32x) speedup for JavaAssertTransformer._infer_type_from_assertion_args in codeflash/languages/java/remove_asserts.py

⏱️ Runtime : 6.32 milliseconds 4.78 milliseconds (best of 140 runs)

📝 Explanation and details

The optimization replaced character-by-character list accumulation in _extract_first_arg with direct substring slicing (tracking start and end indices instead of building a list via cur.append()), eliminating repeated list operations and the final "".join(cur) call. For the JUnit4 message-string case (where assertEquals("msg", expected, actual) requires extracting the second argument), it introduced _second_arg_if_message, a lightweight two-comma scanner that stops as soon as it confirms three top-level arguments exist and extracts only the second one, avoiding the original _split_top_level_args which built a complete list of all arguments. Line profiler confirms _split_top_level_args accounted for 36.6% of original runtime and _extract_first_arg's list operations added 10.1%; the new index-based approach cuts total runtime by 32% with no behavioral changes.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 1193 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 92.9%
🌀 Click to see Generated Regression Tests
import pytest  # used for our unit tests
from codeflash.languages.java.remove_asserts import JavaAssertTransformer

def test_infer_int_from_simple_integer_literal():
    # Create a real instance of the transformer (use function name "m")
    transformer = JavaAssertTransformer("m")
    # Typical assertEquals with an integer literal should map to "int"
    codeflash_output = transformer._infer_type_from_assertion_args('assertEquals(42, obj.foo());', 'assertEquals'); res = codeflash_output # 6.37μs -> 6.41μs (0.624% slower)

def test_infer_long_from_literal_with_L_suffix():
    transformer = JavaAssertTransformer("m")
    # Long literals end with L or l -> expect "long"
    codeflash_output = transformer._infer_type_from_assertion_args('assertEquals(1234567890123L, x);', 'assertEquals'); res1 = codeflash_output # 8.42μs -> 8.27μs (1.83% faster)
    codeflash_output = transformer._infer_type_from_assertion_args('assertEquals(-1l, x);', 'assertEquals'); res2 = codeflash_output # 3.07μs -> 3.00μs (2.03% faster)

def test_infer_double_and_float_literals():
    transformer = JavaAssertTransformer("m")
    # Plain decimal with dot (or trailing d) -> double
    codeflash_output = transformer._infer_type_from_assertion_args('assertEquals(3.14, x);', 'assertEquals') # 5.40μs -> 5.51μs (2.00% slower)
    codeflash_output = transformer._infer_type_from_assertion_args('assertEquals(2.0D, x);', 'assertEquals') # 2.73μs -> 2.58μs (5.80% faster)
    # Trailing f/F -> float
    codeflash_output = transformer._infer_type_from_assertion_args('assertEquals(1.0f, x);', 'assertEquals') # 1.74μs -> 1.66μs (4.87% faster)
    codeflash_output = transformer._infer_type_from_assertion_args('assertEquals(5F, x);', 'assertEquals') # 1.75μs -> 1.75μs (0.000% faster)

def test_infer_char_and_string_and_boolean_literals():
    transformer = JavaAssertTransformer("m")
    # Char literal -> 'char'
    codeflash_output = transformer._infer_type_from_assertion_args("assertEquals('c', x);", 'assertEquals') # 8.53μs -> 7.95μs (7.31% faster)
    # String literal -> 'String'
    codeflash_output = transformer._infer_type_from_assertion_args('assertEquals("hello", x);', 'assertEquals') # 8.89μs -> 6.97μs (27.4% faster)
    # Boolean literals -> 'boolean'
    codeflash_output = transformer._infer_type_from_assertion_args('assertEquals(true, x);', 'assertEquals') # 1.63μs -> 1.58μs (3.16% faster)
    codeflash_output = transformer._infer_type_from_assertion_args('assertEquals(false, x);', 'assertEquals') # 1.23μs -> 1.17μs (5.12% faster)

def test_missing_parentheses_or_empty_args_return_object():
    transformer = JavaAssertTransformer("m")
    # No '(' should return Object
    codeflash_output = transformer._infer_type_from_assertion_args('assertEquals', 'assertEquals') # 761ns -> 771ns (1.30% slower)
    # Empty parentheses or only whitespace -> Object
    codeflash_output = transformer._infer_type_from_assertion_args('assertEquals()', 'assertEquals') # 1.41μs -> 1.41μs (0.071% slower)
    codeflash_output = transformer._infer_type_from_assertion_args('assertEquals(   )', 'assertEquals') # 3.93μs -> 4.01μs (1.97% slower)

def test_trailing_closing_paren_without_semicolon():
    transformer = JavaAssertTransformer("m")
    # The function accepts either ) or ); endings
    codeflash_output = transformer._infer_type_from_assertion_args('assertEquals("x")', 'assertEquals') # 8.14μs -> 6.36μs (28.0% faster)
    codeflash_output = transformer._infer_type_from_assertion_args('assertEquals(10)', 'assertEquals') # 3.32μs -> 3.36μs (1.19% slower)

def test_junit_message_string_is_skipped_for_assertEquals_and_assertNotEquals():
    transformer = JavaAssertTransformer("m")
    # JUnit4 style: assertEquals(String message, expected, actual)
    # First arg is a string and there are 3 args => expected should be the second arg
    src = 'assertEquals("should be five", 5, obj.get());'
    codeflash_output = transformer._infer_type_from_assertion_args(src, 'assertEquals') # 16.6μs -> 11.3μs (46.7% faster)
    # Also ensure assertNotEquals respects same behavior
    src2 = 'assertNotEquals("boom", 3L, x);'
    codeflash_output = transformer._infer_type_from_assertion_args(src2, 'assertNotEquals') # 9.23μs -> 6.89μs (33.9% faster)

def test_nested_comma_in_first_arg_uses_extractor_and_defaults_to_object_for_expressions():
    transformer = JavaAssertTransformer("m")
    # A first argument that is a method invocation with its own commas:
    # e.g., Arrays.asList(1, 2) should be treated as a non-literal expression -> "Object"
    src = 'assertEquals(Arrays.asList(1, 2), somethingElse());'
    codeflash_output = transformer._infer_type_from_assertion_args(src, 'assertEquals') # 12.3μs -> 11.1μs (10.3% faster)

def test_casted_expression_infers_cast_type():
    transformer = JavaAssertTransformer("m")
    # If the expected argument begins with a cast like (long) value -> infer 'long'
    codeflash_output = transformer._infer_type_from_assertion_args('assertEquals((long) foo(), x);', 'assertEquals') # 11.0μs -> 10.4μs (5.97% faster)
    # Also test another cast type
    codeflash_output = transformer._infer_type_from_assertion_args('assertEquals((MyType) get(), x);', 'assertEquals') # 7.46μs -> 6.85μs (8.92% faster)

def test_large_scale_inference_many_calls_is_deterministic_and_correct():
    transformer = JavaAssertTransformer("m")
    # Build a set of 1000 assertions cycling through known literal forms.
    literals = [
        ("42", "int"),
        ("-7", "int"),
        ("9999999999L", "long"),
        ("0l", "long"),
        ("3.14159", "double"),
        ("2.0D", "double"),
        ("1.5f", "float"),
        ("7F", "float"),
        ("'z'", "char"),
        ('"big string"', "String"),
        ("true", "boolean"),
        ("false", "boolean"),
        # a non-literal expression -> Object
        ("computeValue()", "Object"),
        # nested expression with commas -> Object
        ("Collections.unmodifiableList(Arrays.asList(1,2,3))", "Object"),
        # casted expressions
        ("(int) something()", "int"),
        ("(CustomType) make()", "CustomType"),
    ]
    # Generate many tests and assert each inference matches expected mapping
    n = 1000
    for i in range(n):
        lit, expected = literals[i % len(literals)]
        # Ensure variations with/without semicolon, and with extra whitespace
        if i % 3 == 0:
            src = f'assertEquals({lit}, result);'
        elif i % 3 == 1:
            src = f'assertEquals(  {lit}  )'
        else:
            # include a message-case occasionally to ensure it doesn't break non-string first args
            src = f'assertEquals("msg", {lit}, other);'
        codeflash_output = transformer._infer_type_from_assertion_args(src, 'assertEquals'); got = codeflash_output # 5.31ms -> 3.99ms (33.3% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import pytest
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
from codeflash.languages.java.remove_asserts import JavaAssertTransformer

class TestInferTypeFromAssertionArgs:
    """Unit tests for _infer_type_from_assertion_args method."""

    def setup_method(self):
        """Set up a transformer instance for each test."""
        self.transformer = JavaAssertTransformer(
            function_name="test_method",
            qualified_name="com.example.TestClass.test_method",
            analyzer=get_java_analyzer(),
            is_void=False
        )

    # ========== BASIC TESTS ==========

    def test_basic_int_literal(self):
        """Test inference with a simple integer literal."""
        # When passing a plain integer, type should be inferred as int
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(42);',
            'assertEquals'
        ); result = codeflash_output # 6.04μs -> 6.12μs (1.32% slower)

    def test_basic_long_literal_with_l_suffix(self):
        """Test inference with a long literal using 'l' suffix."""
        # Long literals end with 'l' or 'L'
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(42l);',
            'assertEquals'
        ); result = codeflash_output # 5.68μs -> 5.76μs (1.39% slower)

    def test_basic_long_literal_with_L_suffix(self):
        """Test inference with a long literal using 'L' suffix."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(42L);',
            'assertEquals'
        ); result = codeflash_output # 5.53μs -> 5.58μs (0.914% slower)

    def test_basic_float_literal(self):
        """Test inference with a float literal."""
        # Float literals end with 'f' or 'F'
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(3.14f);',
            'assertEquals'
        ); result = codeflash_output # 4.16μs -> 4.28μs (2.83% slower)

    def test_basic_double_literal(self):
        """Test inference with a double literal."""
        # Double literals can end with 'd', 'D', or have decimal without suffix
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(3.14);',
            'assertEquals'
        ); result = codeflash_output # 5.07μs -> 5.20μs (2.52% slower)

    def test_basic_double_literal_with_d_suffix(self):
        """Test inference with a double literal using 'd' suffix."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(3.14d);',
            'assertEquals'
        ); result = codeflash_output # 5.13μs -> 4.96μs (3.43% faster)

    def test_basic_char_literal_single_quote(self):
        """Test inference with a single character literal."""
        # Character literals are enclosed in single quotes
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            "assertEquals('a');",
            'assertEquals'
        ); result = codeflash_output # 5.34μs -> 5.35μs (0.187% slower)

    def test_basic_string_literal(self):
        """Test inference with a string literal."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals("hello");',
            'assertEquals'
        ); result = codeflash_output # 8.98μs -> 7.03μs (27.6% faster)

    def test_basic_true_boolean_literal(self):
        """Test inference with boolean literal true."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(true);',
            'assertEquals'
        ); result = codeflash_output # 2.45μs -> 2.51μs (2.39% slower)

    def test_basic_false_boolean_literal(self):
        """Test inference with boolean literal false."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(false);',
            'assertEquals'
        ); result = codeflash_output # 2.44μs -> 2.39μs (1.71% faster)

    def test_basic_null_literal(self):
        """Test inference with null literal."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(null);',
            'assertEquals'
        ); result = codeflash_output # 2.50μs -> 2.44μs (2.45% faster)

    def test_method_call_argument(self):
        """Test inference with a method call as argument."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(obj.getValue());',
            'assertEquals'
        ); result = codeflash_output # 5.49μs -> 5.43μs (1.10% faster)

    def test_variable_reference(self):
        """Test inference with a variable reference."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(myVariable);',
            'assertEquals'
        ); result = codeflash_output # 5.40μs -> 5.33μs (1.31% faster)

    # ========== MESSAGE TESTS (JUnit 4 Format) ==========

    def test_junit4_with_message_string_three_args(self):
        """Test assertEquals with message (JUnit 4: message, expected, actual)."""
        # When first arg is a string and there are 3+ args, use second arg
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals("test message", 42, actualValue);',
            'assertEquals'
        ); result = codeflash_output # 19.1μs -> 12.8μs (49.3% faster)

    def test_junit4_with_message_string_three_args_long(self):
        """Test assertEquals with message and long expected value."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals("message", 100L, actual);',
            'assertEquals'
        ); result = codeflash_output # 16.5μs -> 12.1μs (36.8% faster)

    def test_junit4_with_message_string_three_args_float(self):
        """Test assertEquals with message and float expected value."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals("message", 2.5f, actual);',
            'assertEquals'
        ); result = codeflash_output # 15.1μs -> 10.5μs (43.9% faster)

    def test_assertNotEquals_with_message(self):
        """Test assertNotEquals with message string (JUnit 4 format)."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertNotEquals("msg", 10, other);',
            'assertNotEquals'
        ); result = codeflash_output # 14.7μs -> 10.5μs (40.2% faster)

    # ========== EDGE CASES ==========

    def test_empty_args(self):
        """Test with no arguments in parentheses."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals();',
            'assertEquals'
        ); result = codeflash_output # 1.74μs -> 1.78μs (2.19% slower)

    def test_no_opening_paren(self):
        """Test when there is no opening parenthesis."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals',
            'assertEquals'
        ); result = codeflash_output # 761ns -> 821ns (7.31% slower)

    def test_whitespace_around_args(self):
        """Test with significant whitespace around arguments."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(  42  );',
            'assertEquals'
        ); result = codeflash_output # 5.99μs -> 5.93μs (1.01% faster)

    def test_negative_integer(self):
        """Test with a negative integer literal."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(-42);',
            'assertEquals'
        ); result = codeflash_output # 5.86μs -> 5.80μs (1.03% faster)

    def test_negative_long(self):
        """Test with a negative long literal."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(-100L);',
            'assertEquals'
        ); result = codeflash_output # 5.74μs -> 5.86μs (2.05% slower)

    def test_negative_float(self):
        """Test with a negative float literal."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(-3.5f);',
            'assertEquals'
        ); result = codeflash_output # 4.11μs -> 4.15μs (0.964% slower)

    def test_negative_double(self):
        """Test with a negative double literal."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(-3.5);',
            'assertEquals'
        ); result = codeflash_output # 4.98μs -> 5.02μs (0.777% slower)

    def test_char_with_escape_sequence(self):
        """Test character literal with escape sequence."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            "assertEquals('\\n');",
            'assertEquals'
        ); result = codeflash_output # 5.36μs -> 5.50μs (2.55% slower)

    def test_string_with_escaped_quotes(self):
        """Test string literal containing escaped quotes."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals("say \\"hello\\"");',
            'assertEquals'
        ); result = codeflash_output # 10.2μs -> 7.62μs (33.9% faster)

    def test_empty_string(self):
        """Test with empty string literal."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals("");',
            'assertEquals'
        ); result = codeflash_output # 7.72μs -> 6.37μs (21.2% faster)

    def test_nested_method_calls(self):
        """Test with nested method calls in parentheses."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(obj.method1(obj.method2()));',
            'assertEquals'
        ); result = codeflash_output # 5.43μs -> 5.34μs (1.69% faster)

    def test_method_call_with_int_argument(self):
        """Test with method call that takes integer argument."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(getValue(42));',
            'assertEquals'
        ); result = codeflash_output # 5.35μs -> 5.50μs (2.71% slower)

    def test_generic_type_in_argument(self):
        """Test with generics in method call."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(map.get("key"));',
            'assertEquals'
        ); result = codeflash_output # 5.34μs -> 5.30μs (0.755% faster)

    def test_array_access(self):
        """Test with array element access."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(array[0]);',
            'assertEquals'
        ); result = codeflash_output # 5.31μs -> 5.34μs (0.562% slower)

    def test_cast_expression(self):
        """Test with type cast."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals((int)value);',
            'assertEquals'
        ); result = codeflash_output # 6.16μs -> 6.16μs (0.000% faster)

    def test_ternary_operator(self):
        """Test with ternary conditional operator."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(flag ? 1 : 0);',
            'assertEquals'
        ); result = codeflash_output # 5.32μs -> 5.22μs (1.94% faster)

    def test_binary_operation(self):
        """Test with binary arithmetic operation."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(5 + 3);',
            'assertEquals'
        ); result = codeflash_output # 5.67μs -> 5.73μs (1.05% slower)

    def test_single_quoted_empty_char(self):
        """Test that invalid empty char literal defaults to Object."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            "assertEquals('');",
            'assertEquals'
        ); result = codeflash_output # 5.31μs -> 5.12μs (3.69% faster)

    def test_double_without_decimal_part_with_d(self):
        """Test double literal like 5d (integer part only)."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(5d);',
            'assertEquals'
        ); result = codeflash_output # 4.87μs -> 4.97μs (2.03% slower)

    def test_float_without_decimal_part(self):
        """Test float literal like 5f (integer part only)."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(5f);',
            'assertEquals'
        ); result = codeflash_output # 4.16μs -> 4.14μs (0.483% faster)

    def test_zero_int(self):
        """Test with zero as integer."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(0);',
            'assertEquals'
        ); result = codeflash_output # 4.66μs -> 4.72μs (1.29% slower)

    def test_zero_long(self):
        """Test with zero as long."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(0L);',
            'assertEquals'
        ); result = codeflash_output # 5.21μs -> 5.25μs (0.762% slower)

    def test_zero_float(self):
        """Test with zero as float."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(0f);',
            'assertEquals'
        ); result = codeflash_output # 4.07μs -> 4.14μs (1.69% slower)

    def test_zero_double(self):
        """Test with zero as double."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(0.0);',
            'assertEquals'
        ); result = codeflash_output # 4.99μs -> 5.03μs (0.795% slower)

    # ========== MULTIPLE ARGUMENTS TESTS ==========

    def test_two_int_arguments_takes_first(self):
        """Test that with two int args, first is used."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(42, 43);',
            'assertEquals'
        ); result = codeflash_output # 6.32μs -> 6.41μs (1.42% slower)

    def test_two_long_arguments_takes_first(self):
        """Test that with two long args, first is used."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(42L, 43L);',
            'assertEquals'
        ); result = codeflash_output # 5.89μs -> 6.18μs (4.71% slower)

    def test_two_mixed_type_arguments_takes_first(self):
        """Test that with mixed types, first arg type is used."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(42, "string");',
            'assertEquals'
        ); result = codeflash_output # 6.22μs -> 6.15μs (1.14% faster)

    def test_two_arguments_reversed_types(self):
        """Test first argument determines type even if second is different."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals("string", 42);',
            'assertEquals'
        ); result = codeflash_output # 13.6μs -> 10.8μs (26.6% faster)

    def test_three_arguments_non_string_first(self):
        """Test three args where first is not a string (shouldn't trigger JUnit4 logic)."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(42, 43, 44);',
            'assertEquals'
        ); result = codeflash_output # 6.17μs -> 6.13μs (0.636% faster)

    def test_message_with_spaces_and_commas_in_string(self):
        """Test message containing commas doesn't confuse argument parsing."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals("message, with, commas", 100, actual);',
            'assertEquals'
        ); result = codeflash_output # 20.0μs -> 13.8μs (44.7% faster)

    def test_message_with_nested_parens_in_string(self):
        """Test message containing parentheses doesn't break parsing."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals("message(with(parens))", 50L, actual);',
            'assertEquals'
        ); result = codeflash_output # 19.1μs -> 13.2μs (44.9% faster)

    def test_complex_method_call_as_message(self):
        """Test when first argument is a method call (not string literal)."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(getMsg(), 99, other);',
            'assertEquals'
        ); result = codeflash_output # 10.2μs -> 9.53μs (6.83% faster)

    def test_trailing_semicolon_removed(self):
        """Test that trailing semicolon is properly stripped."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(42);',
            'assertEquals'
        ); result = codeflash_output # 5.75μs -> 5.54μs (3.79% faster)

    def test_only_closing_paren_removed(self):
        """Test handling of args_str ending with only closing paren."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(100)',
            'assertEquals'
        ); result = codeflash_output # 5.91μs -> 5.90μs (0.152% faster)

    # ========== EXTRACT_FIRST_ARG FALLBACK TESTS ==========

    def test_extract_first_arg_with_nested_generics(self):
        """Test _extract_first_arg with generics containing commas."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(map.get("key"), actual);',
            'assertEquals'
        ); result = codeflash_output # 11.3μs -> 10.2μs (10.2% faster)

    def test_extract_first_arg_with_nested_angles(self):
        """Test extraction respects angle bracket nesting."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(list.of(1, 2), actual);',
            'assertEquals'
        ); result = codeflash_output # 11.3μs -> 9.90μs (14.3% faster)

    def test_extract_first_arg_with_quoted_comma(self):
        """Test that commas inside quotes don't terminate the argument."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals("a,b", actual);',
            'assertEquals'
        ); result = codeflash_output # 13.6μs -> 11.0μs (23.1% faster)

    def test_extract_first_arg_with_single_quoted_comma(self):
        """Test that commas inside single quotes don't terminate the argument."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            "assertEquals('x', actual);",
            'assertEquals'
        ); result = codeflash_output # 8.36μs -> 7.75μs (7.88% faster)

    def test_extract_first_arg_with_escaped_quote_in_string(self):
        """Test that escaped quotes are handled correctly in string arguments."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals("a\\"b", actual);',
            'assertEquals'
        ); result = codeflash_output # 13.7μs -> 11.0μs (24.6% faster)

    def test_extract_first_arg_with_square_brackets(self):
        """Test that square brackets are respected as delimiters."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(arr[10], actual);',
            'assertEquals'
        ); result = codeflash_output # 6.17μs -> 6.27μs (1.61% slower)

    def test_extract_first_arg_with_curly_braces(self):
        """Test that curly braces are respected as delimiters."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals({1, 2, 3}, actual);',
            'assertEquals'
        ); result = codeflash_output # 10.5μs -> 9.46μs (11.4% faster)

    # ========== LARGE-SCALE TESTS ==========

    def test_very_long_integer(self):
        """Test with a very large integer literal."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(999999999999999);',
            'assertEquals'
        ); result = codeflash_output # 8.76μs -> 8.69μs (0.806% faster)

    def test_very_long_double(self):
        """Test with a very large double literal."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(999999999999.999999999);',
            'assertEquals'
        ); result = codeflash_output # 7.25μs -> 7.27μs (0.289% slower)

    def test_long_string_literal(self):
        """Test with a very long string literal."""
        long_string = '"' + ('x' * 1000) + '"'
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            f'assertEquals({long_string});',
            'assertEquals'
        ); result = codeflash_output # 131μs -> 69.4μs (89.2% faster)

    def test_many_arguments_takes_first(self):
        """Test with many arguments (100+) - should use first."""
        args = ', '.join([str(i) for i in range(100)])
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            f'assertEquals({args});',
            'assertEquals'
        ); result = codeflash_output # 5.88μs -> 5.91μs (0.508% slower)

    def test_deeply_nested_method_calls(self):
        """Test with deeply nested method calls (50+ levels)."""
        nested = 'obj'
        for i in range(50):
            nested = f'{nested}.method()'
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            f'assertEquals({nested});',
            'assertEquals'
        ); result = codeflash_output # 5.51μs -> 5.62μs (1.97% slower)

    def test_method_with_many_arguments_in_first_arg(self):
        """Test method call with many arguments inside first assertion arg."""
        inner_args = ', '.join([str(i) for i in range(50)])
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            f'assertEquals(method({inner_args}), actual);',
            'assertEquals'
        ); result = codeflash_output # 50.8μs -> 44.1μs (15.1% faster)

    def test_large_message_string_with_junit4_format(self):
        """Test JUnit4 format with very large message string."""
        large_msg = 'x' * 500
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            f'assertEquals("{large_msg}", 42, actual);',
            'assertEquals'
        ); result = codeflash_output # 121μs -> 72.5μs (67.7% faster)

    def test_many_commas_in_quotes_and_args(self):
        """Test arguments with many commas in strings and actual args."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals("' + (','.join(['x'] * 100)) + '", 99, actual);',
            'assertEquals'
        ); result = codeflash_output # 57.1μs -> 33.9μs (68.2% faster)

    # ========== METHOD NAME VARIATIONS ==========

    def test_other_assertion_method_assertEquals(self):
        """Test that different assertion methods work with assertEquals."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(42);',
            'assertEquals'
        ); result = codeflash_output # 5.39μs -> 5.50μs (2.00% slower)

    def test_other_assertion_method_assertNotEquals(self):
        """Test that assertNotEquals is handled correctly."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertNotEquals(42);',
            'assertNotEquals'
        ); result = codeflash_output # 5.57μs -> 5.50μs (1.27% faster)

    def test_other_assertion_method_assertTrue(self):
        """Test that assertTrue uses first arg."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertTrue(true);',
            'assertTrue'
        ); result = codeflash_output # 2.56μs -> 2.54μs (0.786% faster)

    def test_other_assertion_method_assertFalse(self):
        """Test that assertFalse uses first arg."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertFalse(false);',
            'assertFalse'
        ); result = codeflash_output # 2.48μs -> 2.40μs (2.95% faster)

    def test_other_assertion_method_assertNull(self):
        """Test that assertNull uses first arg."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertNull(null);',
            'assertNull'
        ); result = codeflash_output # 2.48μs -> 2.50μs (0.441% slower)

    # ========== COMBINATION TESTS ==========

    def test_special_float_notation_with_no_integer_part(self):
        """Test float like .5f (no leading zero)."""
        # This tests edge case where decimal starts without integer part
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals(.5f);',
            'assertEquals'
        ); result = codeflash_output # 5.50μs -> 5.34μs (3.00% faster)

    def test_all_numeric_types_in_sequence(self):
        """Test that each numeric type is correctly identified."""
        tests = [
            ('assertEquals(1);', 'int'),
            ('assertEquals(1L);', 'long'),
            ('assertEquals(1.0f);', 'float'),
            ('assertEquals(1.0);', 'double'),
        ]
        for assertion_text, expected_type in tests:
            codeflash_output = self.transformer._infer_type_from_assertion_args(
                assertion_text, 'assertEquals'
            ); result = codeflash_output # 11.4μs -> 11.3μs (1.04% faster)

    def test_booleans_and_null_in_sequence(self):
        """Test boolean and null literals are identified correctly."""
        tests = [
            ('assertEquals(true);', 'boolean'),
            ('assertEquals(false);', 'boolean'),
            ('assertEquals(null);', 'null'),
        ]
        for assertion_text, expected_type in tests:
            codeflash_output = self.transformer._infer_type_from_assertion_args(
                assertion_text, 'assertEquals'
            ); result = codeflash_output # 4.60μs -> 4.67μs (1.52% slower)

    def test_string_and_char_literals(self):
        """Test string and char distinction."""
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            'assertEquals("a");',
            'assertEquals'
        ); string_result = codeflash_output # 8.75μs -> 6.86μs (27.4% faster)
        codeflash_output = self.transformer._infer_type_from_assertion_args(
            "assertEquals('a');",
            'assertEquals'
        ); char_result = codeflash_output # 3.10μs -> 3.08μs (0.943% faster)

    def test_decimal_formats_all_variations(self):
        """Test all decimal number formats."""
        tests = [
            ('assertEquals(1.0);', 'double'),   # decimal with d default
            ('assertEquals(1.0d);', 'double'),  # decimal with explicit d
            ('assertEquals(1.0D);', 'double'),  # decimal with explicit D
            ('assertEquals(1.0f);', 'float'),   # decimal with f
            ('assertEquals(1.0F);', 'float'),   # decimal with F
        ]
        for assertion_text, expected_type in tests:
            codeflash_output = self.transformer._infer_type_from_assertion_args(
                assertion_text, 'assertEquals'
            ); result = codeflash_output # 12.0μs -> 12.0μs (0.258% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-pr1655-2026-03-04T05.37.57 and push.

Codeflash Static Badge

The optimization replaced character-by-character list accumulation in `_extract_first_arg` with direct substring slicing (tracking `start` and `end` indices instead of building a list via `cur.append()`), eliminating repeated list operations and the final `"".join(cur)` call. For the JUnit4 message-string case (where `assertEquals("msg", expected, actual)` requires extracting the second argument), it introduced `_second_arg_if_message`, a lightweight two-comma scanner that stops as soon as it confirms three top-level arguments exist and extracts only the second one, avoiding the original `_split_top_level_args` which built a complete list of all arguments. Line profiler confirms `_split_top_level_args` accounted for 36.6% of original runtime and `_extract_first_arg`'s list operations added 10.1%; the new index-based approach cuts total runtime by 32% with no behavioral changes.
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Mar 4, 2026
@codeflash-ai codeflash-ai bot mentioned this pull request Mar 4, 2026
5 tasks
@claude
Copy link
Contributor

claude bot commented Mar 4, 2026

Closing stale optimization PR.

@claude claude bot closed this Mar 4, 2026
@claude claude bot deleted the codeflash/optimize-pr1655-2026-03-04T05.37.57 branch March 4, 2026 07:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants