Skip to content

⚡️ Speed up method JavaAssertTransformer._infer_return_type by 44% in PR #1655 (feat/add/void/func)#1757

Merged
claude[bot] merged 1 commit intofeat/add/void/funcfrom
codeflash/optimize-pr1655-2026-03-04T03.21.59
Mar 4, 2026
Merged

⚡️ Speed up method JavaAssertTransformer._infer_return_type by 44% in PR #1655 (feat/add/void/func)#1757
claude[bot] merged 1 commit intofeat/add/void/funcfrom
codeflash/optimize-pr1655-2026-03-04T03.21.59

Conversation

@codeflash-ai
Copy link
Contributor

@codeflash-ai codeflash-ai bot commented Mar 4, 2026

⚡️ This pull request contains optimizations for PR #1655

If you approve this dependent PR, these changes will be merged into the original PR branch feat/add/void/func.

This PR will be automatically closed if the original PR is merged.


📄 44% (0.44x) speedup for JavaAssertTransformer._infer_return_type in codeflash/languages/java/remove_asserts.py

⏱️ Runtime : 4.73 milliseconds 3.30 milliseconds (best of 151 runs)

📝 Explanation and details

The optimization adds a fast-path check in _infer_type_from_assertion_args that looks for the first comma and extracts the substring before it directly when no special delimiter characters (quotes, parentheses, braces) precede that comma, bypassing the expensive full _extract_first_arg parser for simple literals like 42, 100L, or true. Line profiler shows the original _extract_first_arg call consumed 49% of method runtime (~11.1 ms); the optimized version reduces this to 6.7% (~0.98 ms) by handling 1518 of 1632 cases via the cheap substring path, cutting per-call overhead from ~6808 ns to ~140–280 ns for common assertions. Runtime improves 43% (4.73 ms → 3.30 ms) with no correctness regressions; the fallback preserves exact behavior for nested or quoted arguments.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 2027 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
from collections import \
    namedtuple  # used to create simple real tuple-based records

# imports
import pytest  # used for our unit tests
from codeflash.languages.java.remove_asserts import (JUNIT5_VALUE_ASSERTIONS,
                                                     JavaAssertTransformer)

# Create a lightweight real class-like tuple to represent AssertionMatch records.
# Using collections.namedtuple produces a real class type (not a mock), and instances
# behave like normal objects with named attributes (assertion_method, original_text)
AssertionMatch = namedtuple("AssertionMatch", ["assertion_method", "original_text"])

def test_basic_assert_true_false_and_null_return_object():
    # Create a transformer instance (real class constructor as required).
    transformer = JavaAssertTransformer(function_name="f")

    # assertTrue should always fall back to "Object"
    am_true = AssertionMatch(assertion_method="assertTrue", original_text="assertTrue(f())")
    codeflash_output = transformer._infer_return_type(am_true) # 772ns -> 631ns (22.3% faster)

    # assertFalse should also fall back to "Object"
    am_false = AssertionMatch(assertion_method="assertFalse", original_text="assertFalse(x)")
    codeflash_output = transformer._infer_return_type(am_false) # 361ns -> 270ns (33.7% faster)

    # assertNull should keep "Object" (reference type)
    am_null = AssertionMatch(assertion_method="assertNull", original_text="assertNull(obj)")
    codeflash_output = transformer._infer_return_type(am_null) # 331ns -> 330ns (0.303% faster)

    # assertNotNull should also keep "Object"
    am_notnull = AssertionMatch(assertion_method="assertNotNull", original_text="assertNotNull(obj)")
    codeflash_output = transformer._infer_return_type(am_notnull) # 190ns -> 310ns (38.7% slower)

def test_infer_common_literal_types_from_assert_equals():
    # transformer instance used across checks
    transformer = JavaAssertTransformer(function_name="func")

    # Basic integer literal
    am_int = AssertionMatch("assertEquals", "assertEquals(42, actual)")
    codeflash_output = transformer._infer_return_type(am_int) # 8.53μs -> 7.03μs (21.2% faster)

    # Negative integer literal
    am_neg = AssertionMatch("assertEquals", "assertEquals(-7, actual)")
    codeflash_output = transformer._infer_return_type(am_neg) # 4.60μs -> 3.69μs (24.7% faster)

    # Long literal with L suffix (case-insensitive)
    am_long = AssertionMatch("assertEquals", "assertEquals(123L, actual)")
    codeflash_output = transformer._infer_return_type(am_long) # 4.26μs -> 3.27μs (30.4% faster)

    am_long_lower = AssertionMatch("assertEquals", "assertEquals(-5l, actual)")
    codeflash_output = transformer._infer_return_type(am_long_lower) # 3.47μs -> 2.43μs (42.4% faster)

    # Float literal with f suffix
    am_float = AssertionMatch("assertEquals", "assertEquals(2.0f, actual)")
    codeflash_output = transformer._infer_return_type(am_float) # 2.88μs -> 1.82μs (57.8% faster)

    # Double literal (plain decimal) and with d suffix
    am_double = AssertionMatch("assertEquals", "assertEquals(3.1415, actual)")
    codeflash_output = transformer._infer_return_type(am_double) # 4.00μs -> 2.31μs (72.7% faster)

    am_double_d = AssertionMatch("assertEquals", "assertEquals(2d, actual)")
    codeflash_output = transformer._infer_return_type(am_double_d) # 2.73μs -> 1.95μs (40.0% faster)

    # Char literal
    am_char = AssertionMatch("assertEquals", "assertEquals('x', actual)")
    codeflash_output = transformer._infer_return_type(am_char) # 3.27μs -> 5.26μs (37.9% slower)

    # String literal
    am_string = AssertionMatch("assertEquals", 'assertEquals("hello", actual)')
    codeflash_output = transformer._infer_return_type(am_string) # 8.77μs -> 9.40μs (6.72% slower)

    # Boolean literal
    am_bool = AssertionMatch("assertEquals", "assertEquals(true, actual)")
    codeflash_output = transformer._infer_return_type(am_bool) # 2.71μs -> 1.66μs (62.7% faster)

def test_null_and_casts_and_nonvalue_cases():
    transformer = JavaAssertTransformer(function_name="f")

    # explicit null expected should map to Object (reference)
    am_null = AssertionMatch("assertEquals", "assertEquals(null, actual)")
    codeflash_output = transformer._infer_return_type(am_null) # 5.49μs -> 3.98μs (38.0% faster)

    # cast expressions like (byte)0 or (short)-1 should return the cast type
    am_byte = AssertionMatch("assertEquals", "assertEquals((byte)0, actual)")
    codeflash_output = transformer._infer_return_type(am_byte) # 7.43μs -> 9.07μs (18.0% slower)

    am_short = AssertionMatch("assertEquals", "assertEquals((short)-1, actual)")
    codeflash_output = transformer._infer_return_type(am_short) # 5.61μs -> 6.33μs (11.4% slower)

    # When expected is an expression / method call, we cannot infer -> fallback to Object
    am_expr = AssertionMatch("assertEquals", "assertEquals(Collections.singletonList(1), actual)")
    codeflash_output = transformer._infer_return_type(am_expr) # 9.07μs -> 9.97μs (9.04% slower)

def test_junit4_message_signature_skips_message_string():
    transformer = JavaAssertTransformer(function_name="f")

    # JUnit4 style: assertEquals(String message, expected, actual) -> skip the first string argument
    am_with_msg = AssertionMatch(
        "assertEquals", 'assertEquals("should match", 55, someCall())'
    )
    # The expected value is the second arg (55) so type should be int
    codeflash_output = transformer._infer_return_type(am_with_msg) # 17.6μs -> 17.8μs (1.46% slower)

    # Also works without semicolon and with extra whitespace
    am_with_msg_ws = AssertionMatch(
        "assertEquals", 'assertEquals ( "x" , 123L , foo() )'
    )
    codeflash_output = transformer._infer_return_type(am_with_msg_ws) # 11.4μs -> 11.8μs (2.73% slower)

def test_malformed_or_empty_arguments_return_object():
    transformer = JavaAssertTransformer(function_name="f")

    # Missing opening parenthesis -> cannot extract args
    am_malformed = AssertionMatch("assertEquals", "assertEquals")
    codeflash_output = transformer._infer_return_type(am_malformed) # 1.45μs -> 1.41μs (2.90% faster)

    # Empty parentheses -> nothing to extract
    am_empty = AssertionMatch("assertEquals", "assertEquals()")
    codeflash_output = transformer._infer_return_type(am_empty) # 2.00μs -> 1.64μs (22.0% faster)

    # Parentheses with only whitespace -> no argument
    am_whitespace = AssertionMatch("assertEquals", "assertEquals(   )")
    codeflash_output = transformer._infer_return_type(am_whitespace) # 1.89μs -> 3.89μs (51.3% slower)

def test_string_with_comma_and_nested_delimiters_handling():
    transformer = JavaAssertTransformer(function_name="f")

    # First argument is a string containing a comma. Extracted first argument should keep the quotes,
    # and since it's a string literal, the result should be String.
    am_str_comma = AssertionMatch("assertEquals", 'assertEquals("a,b", other)')
    codeflash_output = transformer._infer_return_type(am_str_comma) # 12.7μs -> 13.2μs (3.20% slower)

    # If first argument is a message string but there are only two args, message should not be skipped
    am_message_and_actual = AssertionMatch("assertEquals", 'assertEquals("msg, with, commas", actual)')
    # Since there are only 2 args, the first arg is treated as expected -> String
    codeflash_output = transformer._infer_return_type(am_message_and_actual) # 11.5μs -> 11.9μs (3.20% slower)

def test_large_scale_alternating_int_and_long_inference():
    transformer = JavaAssertTransformer(function_name="f")

    # Large-scale test: 1000 iterations, alternate between int and long suffixes.
    # This checks both correctness across many inputs and reasonable performance.
    n = 1000
    for i in range(n):
        if i % 2 == 0:
            original = f"assertEquals({i}L, something{i})"
            expected = "long"
        else:
            original = f"assertEquals({i}, something{i})"
            expected = "int"
        am = AssertionMatch("assertEquals", original)
        # Each iteration deterministically yields the expected type
        codeflash_output = transformer._infer_return_type(am) # 3.08ms -> 2.16ms (42.3% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import pytest
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
from codeflash.languages.java.remove_asserts import (JUNIT5_VALUE_ASSERTIONS,
                                                     JavaAssertTransformer)

# Helper class to create AssertionMatch objects for testing
class AssertionMatch:
    """Represents a matched assertion for testing purposes."""
    def __init__(self, assertion_method: str, original_text: str):
        self.assertion_method = assertion_method
        self.original_text = original_text

def test_infer_return_type_assert_true():
    """Test that assertTrue returns Object type."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertTrue", "assertTrue(value > 0);")
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 732ns -> 581ns (26.0% faster)

def test_infer_return_type_assert_false():
    """Test that assertFalse returns Object type."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertFalse", "assertFalse(value < 0);")
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 641ns -> 551ns (16.3% faster)

def test_infer_return_type_assert_null():
    """Test that assertNull returns Object type."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertNull", "assertNull(result);")
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 702ns -> 691ns (1.59% faster)

def test_infer_return_type_assert_not_null():
    """Test that assertNotNull returns Object type."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertNotNull", "assertNotNull(result);")
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 671ns -> 672ns (0.149% slower)

def test_infer_return_type_assert_equals_int_literal():
    """Test assertEquals with int literal infers int type."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertEquals", "assertEquals(42, obj.getValue());")
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 8.72μs -> 6.96μs (25.2% faster)

def test_infer_return_type_assert_equals_long_literal():
    """Test assertEquals with long literal infers long type."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertEquals", "assertEquals(100L, obj.getValue());")
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 8.37μs -> 6.66μs (25.6% faster)

def test_infer_return_type_assert_equals_float_literal():
    """Test assertEquals with float literal infers float type."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertEquals", "assertEquals(3.14f, obj.getValue());")
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 7.10μs -> 5.11μs (39.0% faster)

def test_infer_return_type_assert_equals_double_literal():
    """Test assertEquals with double literal infers double type."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertEquals", "assertEquals(3.14, obj.getValue());")
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 7.74μs -> 5.94μs (30.3% faster)

def test_infer_return_type_assert_equals_string_literal():
    """Test assertEquals with string literal infers String type."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertEquals", 'assertEquals("hello", obj.getValue());')
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 15.3μs -> 16.1μs (4.87% slower)

def test_infer_return_type_assert_equals_boolean_true():
    """Test assertEquals with boolean true literal infers boolean type."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertEquals", "assertEquals(true, obj.isValid());")
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 5.27μs -> 3.76μs (40.3% faster)

def test_infer_return_type_assert_equals_boolean_false():
    """Test assertEquals with boolean false literal infers boolean type."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertEquals", "assertEquals(false, obj.isValid());")
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 5.45μs -> 3.67μs (48.6% faster)

def test_infer_return_type_assert_equals_char_literal():
    """Test assertEquals with char literal infers char type."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertEquals", "assertEquals('a', obj.getChar());")
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 7.68μs -> 8.67μs (11.3% slower)

def test_infer_return_type_assert_equals_null():
    """Test assertEquals with null literal infers Object type."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertEquals", "assertEquals(null, obj.getValue());")
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 5.30μs -> 3.78μs (40.3% faster)

def test_infer_return_type_assert_not_equals():
    """Test assertNotEquals infers type from first argument."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertNotEquals", "assertNotEquals(10, obj.getValue());")
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 7.70μs -> 6.52μs (18.1% faster)

def test_infer_return_type_assert_same():
    """Test assertSame infers from first argument."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertSame", "assertSame(expected, actual);")
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 8.74μs -> 6.19μs (41.1% faster)

def test_infer_return_type_assert_that_fluent():
    """Test assertThat (fluent assertion) returns Object."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertThat", "assertThat(obj).isNotNull();")
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 721ns -> 822ns (12.3% slower)

def test_infer_return_type_junit4_with_message():
    """Test assertEquals with JUnit 4 message parameter (3 args: message, expected, actual)."""
    transformer = JavaAssertTransformer("testMethod")
    # JUnit 4 format: assertEquals(String message, expected, actual)
    assertion = AssertionMatch("assertEquals", 'assertEquals("Values should match", 42, obj.getValue());')
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 20.1μs -> 20.6μs (2.48% slower)

def test_type_from_literal_boolean_true():
    """Test _type_from_literal with boolean true."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("true")

def test_type_from_literal_boolean_false():
    """Test _type_from_literal with boolean false."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("false")

def test_type_from_literal_int():
    """Test _type_from_literal with int literal."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("42")

def test_type_from_literal_long():
    """Test _type_from_literal with long literal."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("100L")

def test_type_from_literal_float():
    """Test _type_from_literal with float literal."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("3.14f")

def test_type_from_literal_double():
    """Test _type_from_literal with double literal."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("3.14")

def test_type_from_literal_string():
    """Test _type_from_literal with string literal."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal('"hello"')

def test_type_from_literal_char():
    """Test _type_from_literal with char literal."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("'a'")

def test_type_from_literal_null():
    """Test _type_from_literal with null."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("null")

def test_type_from_literal_cast_byte():
    """Test _type_from_literal with cast expression (byte)."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("(byte)0")

def test_type_from_literal_cast_short():
    """Test _type_from_literal with cast expression (short)."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("(short)100")

def test_extract_first_arg_simple():
    """Test _extract_first_arg with simple single argument."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._extract_first_arg("42")

def test_extract_first_arg_multiple():
    """Test _extract_first_arg stops at first comma."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._extract_first_arg("42, obj.getValue()")

def test_extract_first_arg_with_parens():
    """Test _extract_first_arg respects nested parentheses."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._extract_first_arg("obj.method(1, 2), other")

def test_extract_first_arg_with_string():
    """Test _extract_first_arg respects string delimiters."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._extract_first_arg('"hello, world", 42')

def test_extract_first_arg_with_generics():
    """Test _extract_first_arg respects generic angle brackets."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._extract_first_arg("List<String>, value")

def test_infer_return_type_empty_assertion_text():
    """Test with empty assertion text."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertEquals", "")
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 1.50μs -> 1.46μs (2.67% faster)

def test_infer_return_type_no_parentheses():
    """Test with assertion text that has no opening parenthesis."""
    transformer = JavaAssertTransformer("testMethod")
    assertion = AssertionMatch("assertEquals", "assertEquals")
    codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 1.37μs -> 1.20μs (14.1% faster)

def test_extract_first_arg_empty_string():
    """Test _extract_first_arg with empty string."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._extract_first_arg("")

def test_extract_first_arg_only_whitespace():
    """Test _extract_first_arg with only whitespace."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._extract_first_arg("   ")

def test_extract_first_arg_leading_whitespace():
    """Test _extract_first_arg with leading whitespace."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._extract_first_arg("   42, other")

def test_extract_first_arg_trailing_whitespace():
    """Test _extract_first_arg trims trailing whitespace."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._extract_first_arg("42   , other")

def test_type_from_literal_negative_int():
    """Test _type_from_literal with negative int."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("-42")

def test_type_from_literal_negative_long():
    """Test _type_from_literal with negative long."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("-100L")

def test_type_from_literal_negative_float():
    """Test _type_from_literal with negative float."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("-3.14f")

def test_type_from_literal_negative_double():
    """Test _type_from_literal with negative double."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("-3.14")

def test_type_from_literal_zero_int():
    """Test _type_from_literal with zero."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("0")

def test_type_from_literal_char_escaped():
    """Test _type_from_literal with escaped char."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("'\\n'")

def test_type_from_literal_unknown_reference():
    """Test _type_from_literal with unknown variable reference."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("someVariable")

def test_type_from_literal_method_call():
    """Test _type_from_literal with method call expression."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("obj.getValue()")

def test_split_top_level_args_empty():
    """Test _split_top_level_args with empty string."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._split_top_level_args("")

def test_split_top_level_args_single_arg():
    """Test _split_top_level_args with single argument."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._split_top_level_args("42")

def test_split_top_level_args_multiple_simple():
    """Test _split_top_level_args with multiple simple arguments."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._split_top_level_args("42, 'hello', true")

def test_split_top_level_args_nested_parens():
    """Test _split_top_level_args with nested parentheses."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._split_top_level_args("obj.method(1, 2), value")

def test_split_top_level_args_string_with_comma():
    """Test _split_top_level_args with comma inside string."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._split_top_level_args('"hello, world", 42')

def test_split_top_level_args_generic_type():
    """Test _split_top_level_args with generic type."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._split_top_level_args("List<String>, value")

def test_split_top_level_args_nested_generics():
    """Test _split_top_level_args with nested generics."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._split_top_level_args("Map<String, Integer>, value")

def test_split_top_level_args_escaped_string():
    """Test _split_top_level_args with escaped quotes in string."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._split_top_level_args(r'"hello\"world", 42')

def test_extract_first_arg_nested_braces():
    """Test _extract_first_arg with nested braces."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._extract_first_arg("{a=1, b=2}, value")

def test_extract_first_arg_nested_brackets():
    """Test _extract_first_arg with nested brackets."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._extract_first_arg("new int[]{1, 2, 3}, value")

def test_infer_type_from_assertion_args_junit4_message_string_actual():
    """Test _infer_type_from_assertion_args with JUnit4 format where message is string."""
    transformer = JavaAssertTransformer("testMethod")
    original_text = 'assertEquals("msg", 99, obj.getValue());'
    result = transformer._infer_type_from_assertion_args(original_text, "assertEquals")

def test_infer_type_from_assertion_args_no_paren():
    """Test _infer_type_from_assertion_args when no opening paren found."""
    transformer = JavaAssertTransformer("testMethod")
    original_text = "assertEquals"
    result = transformer._infer_type_from_assertion_args(original_text, "assertEquals")

def test_infer_type_from_assertion_args_empty_args():
    """Test _infer_type_from_assertion_args with empty arguments."""
    transformer = JavaAssertTransformer("testMethod")
    original_text = "assertEquals();"
    result = transformer._infer_type_from_assertion_args(original_text, "assertEquals")

def test_type_from_literal_large_int():
    """Test _type_from_literal with large integer."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("2147483647")

def test_type_from_literal_large_long():
    """Test _type_from_literal with large long."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("9223372036854775807L")

def test_type_from_literal_float_no_decimal():
    """Test _type_from_literal with float literal without decimal point."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("42f")

def test_type_from_literal_double_uppercase_d():
    """Test _type_from_literal with double using uppercase D."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("3.14D")

def test_type_from_literal_long_lowercase_l():
    """Test _type_from_literal with long using lowercase l."""
    transformer = JavaAssertTransformer("testMethod")
    result = transformer._type_from_literal("100l")

def test_split_top_level_args_many_simple_args():
    """Test _split_top_level_args with many simple arguments (100 args)."""
    transformer = JavaAssertTransformer("testMethod")
    # Create a string with 100 simple int arguments
    args_list = [str(i) for i in range(100)]
    args_str = ", ".join(args_list)
    result = transformer._split_top_level_args(args_str)

def test_split_top_level_args_deeply_nested_parens():
    """Test _split_top_level_args with deeply nested parentheses."""
    transformer = JavaAssertTransformer("testMethod")
    # Create deeply nested method calls: method(method(method(...)))
    nested = "42"
    for _ in range(50):
        nested = f"method({nested})"
    args_str = f"{nested}, value"
    result = transformer._split_top_level_args(args_str)

def test_split_top_level_args_many_generics():
    """Test _split_top_level_args with many generic type arguments."""
    transformer = JavaAssertTransformer("testMethod")
    # Create complex generics with many nested angle brackets
    complex_type = "Map<String, List<Map<Integer, List<String>>>>"
    args_str = f"{complex_type}, value"
    result = transformer._split_top_level_args(args_str)

def test_split_top_level_args_many_strings():
    """Test _split_top_level_args with many string arguments."""
    transformer = JavaAssertTransformer("testMethod")
    # Create 100 string arguments
    args_list = [f'"string{i}"' for i in range(100)]
    args_str = ", ".join(args_list)
    result = transformer._split_top_level_args(args_str)

def test_extract_first_arg_from_many_args():
    """Test _extract_first_arg extracts only first from 1000 arguments."""
    transformer = JavaAssertTransformer("testMethod")
    # Create string with 1000 simple arguments
    args_list = [str(i) for i in range(1000)]
    args_str = ", ".join(args_list)
    result = transformer._extract_first_arg(args_str)

def test_type_from_literal_various_types_in_sequence():
    """Test _type_from_literal with many different type patterns."""
    transformer = JavaAssertTransformer("testMethod")
    # Test various literal patterns in sequence
    test_cases = [
        ("42", "int"),
        ("100L", "long"),
        ("3.14f", "float"),
        ("3.14", "double"),
        ("true", "boolean"),
        ("'a'", "char"),
        ('"str"', "String"),
        ("null", "Object"),
        ("(byte)5", "byte"),
        ("-999", "int"),
    ]
    for literal, expected_type in test_cases * 100:
        result = transformer._type_from_literal(literal)

def test_infer_return_type_many_assertions():
    """Test _infer_return_type with many different assertion patterns."""
    transformer = JavaAssertTransformer("testMethod")
    # Test many different assertion methods
    assertions = [
        (AssertionMatch("assertEquals", "assertEquals(42, value);"), "int"),
        (AssertionMatch("assertEquals", "assertEquals(100L, value);"), "long"),
        (AssertionMatch("assertEquals", "assertEquals(3.14f, value);"), "float"),
        (AssertionMatch("assertEquals", "assertEquals(3.14, value);"), "double"),
        (AssertionMatch("assertNotEquals", "assertNotEquals(true, value);"), "boolean"),
        (AssertionMatch("assertTrue", "assertTrue(condition);"), "Object"),
        (AssertionMatch("assertFalse", "assertFalse(condition);"), "Object"),
        (AssertionMatch("assertNull", "assertNull(value);"), "Object"),
        (AssertionMatch("assertNotNull", "assertNotNull(value);"), "Object"),
    ]
    # Run each assertion 100 times
    for assertion, expected_type in assertions * 100:
        codeflash_output = transformer._infer_return_type(assertion); result = codeflash_output # 1.41ms -> 902μs (55.8% faster)

def test_extract_first_arg_very_long_arg():
    """Test _extract_first_arg with a single very long argument (10000 chars)."""
    transformer = JavaAssertTransformer("testMethod")
    # Create a very long string argument
    long_string = "x" * 10000
    args_str = f'"{long_string}", other'
    result = transformer._extract_first_arg(args_str)

def test_split_top_level_args_complex_mixed_scenario():
    """Test _split_top_level_args with complex mixed nesting (100 mixed args)."""
    transformer = JavaAssertTransformer("testMethod")
    # Create 100 mixed arguments with various complexity levels
    args_list = []
    for i in range(100):
        if i % 4 == 0:
            args_list.append(str(i))
        elif i % 4 == 1:
            args_list.append(f'"string{i}"')
        elif i % 4 == 2:
            args_list.append(f'method{i}({i})')
        else:
            args_list.append(f'List<Map<String, Integer>>')
    args_str = ", ".join(args_list)
    result = transformer._split_top_level_args(args_str)

def test_infer_type_from_assertion_args_junit4_100_times():
    """Test _infer_type_from_assertion_args with JUnit4 format 100 times."""
    transformer = JavaAssertTransformer("testMethod")
    # Test JUnit4 format with message parameter multiple times
    for i in range(100):
        original_text = f'assertEquals("msg{i}", {i * 10}, obj.getValue{i}());'
        result = transformer._infer_type_from_assertion_args(original_text, "assertEquals")

def test_type_from_literal_regex_patterns_scale():
    """Test regex pattern matching in _type_from_literal with 1000 iterations."""
    transformer = JavaAssertTransformer("testMethod")
    # Test that the regex patterns work correctly at scale
    for i in range(1000):
        pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-pr1655-2026-03-04T03.21.59 and push.

Codeflash Static Badge

The optimization adds a fast-path check in `_infer_type_from_assertion_args` that looks for the first comma and extracts the substring before it directly when no special delimiter characters (quotes, parentheses, braces) precede that comma, bypassing the expensive full `_extract_first_arg` parser for simple literals like `42`, `100L`, or `true`. Line profiler shows the original `_extract_first_arg` call consumed 49% of method runtime (~11.1 ms); the optimized version reduces this to 6.7% (~0.98 ms) by handling 1518 of 1632 cases via the cheap substring path, cutting per-call overhead from ~6808 ns to ~140–280 ns for common assertions. Runtime improves 43% (4.73 ms → 3.30 ms) with no correctness regressions; the fallback preserves exact behavior for nested or quoted arguments.
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Mar 4, 2026
@codeflash-ai codeflash-ai bot mentioned this pull request Mar 4, 2026
5 tasks
@claude claude bot merged commit 953ef50 into feat/add/void/func Mar 4, 2026
18 of 30 checks passed
@claude claude bot deleted the codeflash/optimize-pr1655-2026-03-04T03.21.59 branch March 4, 2026 05:23
@claude claude bot mentioned this pull request Mar 4, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants