diff --git a/code_to_optimize/tests/pytest/test_topological_sort.py b/code_to_optimize/tests/pytest/test_topological_sort.py index 038cab4ea..30c709d53 100644 --- a/code_to_optimize/tests/pytest/test_topological_sort.py +++ b/code_to_optimize/tests/pytest/test_topological_sort.py @@ -10,7 +10,7 @@ def test_topological_sort(): g.addEdge(2, 3) g.addEdge(3, 1) - assert g.topologicalSort() == [5, 4, 2, 3, 1, 0] + assert g.topologicalSort()[0] == [5, 4, 2, 3, 1, 0] def test_topological_sort_2(): @@ -20,7 +20,7 @@ def test_topological_sort_2(): for j in range(i + 1, 10): g.addEdge(i, j) - assert g.topologicalSort() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + assert g.topologicalSort()[0] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] g = Graph(10) @@ -28,7 +28,7 @@ def test_topological_sort_2(): for j in range(i + 1, 10): g.addEdge(i, j) - assert g.topologicalSort() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + assert g.topologicalSort()[0] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] def test_topological_sort_3(): @@ -38,4 +38,4 @@ def test_topological_sort_3(): for j in range(i + 1, 1000): g.addEdge(j, i) - assert g.topologicalSort() == list(reversed(range(1000))) + assert g.topologicalSort()[0] == list(reversed(range(1000))) diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 398efe461..382849a59 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -212,15 +212,25 @@ def __init__(self, function_names_to_find: set[str]) -> None: self.wildcard_modules: set[str] = set() # Track aliases: alias_name -> original_name self.alias_mapping: dict[str, str] = {} + # Track instances: variable_name -> class_name + self.instance_mapping: dict[str, str] = {} # Precompute function_names for prefix search # For prefix match, store mapping from prefix-root to candidates for O(1) matching self._exact_names = function_names_to_find self._prefix_roots: dict[str, list[str]] = {} + # Precompute sets for faster lookup during visit_Attribute() + self._dot_names: set[str] = set() + self._dot_methods: dict[str, set[str]] = {} + self._class_method_to_target: dict[tuple[str, str], str] = {} for name in function_names_to_find: if "." in name: - root = name.split(".", 1)[0] - self._prefix_roots.setdefault(root, []).append(name) + root, method = name.rsplit(".", 1) + self._dot_names.add(name) + self._dot_methods.setdefault(method, set()).add(root) + self._class_method_to_target[(root, method)] = name + root_prefix = name.split(".", 1)[0] + self._prefix_roots.setdefault(root_prefix, []).append(name) def visit_Import(self, node: ast.Import) -> None: """Handle 'import module' statements.""" @@ -247,6 +257,28 @@ def visit_Import(self, node: ast.Import) -> None: self.found_qualified_name = target_func return + def visit_Assign(self, node: ast.Assign) -> None: + """Track variable assignments, especially class instantiations.""" + if self.found_any_target_function: + return + + # Check if the assignment is a class instantiation + handled_assignment = False + if isinstance(node.value, ast.Call) and type(node.value.func) is ast.Name: + class_name = node.value.func.id + if class_name in self.imported_modules: + # Track all target variables as instances of the imported class + for target in node.targets: + if type(target) is ast.Name: + # Map the variable to the actual class name (handling aliases) + original_class = self.alias_mapping.get(class_name, class_name) + self.instance_mapping[target.id] = original_class + handled_assignment = True + + # Only traverse child nodes if we didn't handle a class instantiation assignment + if not handled_assignment: + self.generic_visit(node) + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: """Handle 'from module import name' statements.""" if self.found_any_target_function: @@ -287,6 +319,18 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: self.found_qualified_name = qname return + # Check if any target function is a method of the imported class/module + # Be conservative except when an alias is used (which requires exact method matching) + for target_func in fnames: + if "." in target_func: + class_name, method_name = target_func.split(".", 1) + if aname == class_name and not alias.asname: + # If an alias is used, don't match conservatively + # The actual method usage should be detected in visit_Attribute + self.found_any_target_function = True + self.found_qualified_name = target_func + return + prefix = qname + "." # Only bother if one of the targets startswith the prefix-root candidates = proots.get(qname, ()) @@ -301,33 +345,45 @@ def visit_Attribute(self, node: ast.Attribute) -> None: if self.found_any_target_function: return + # Check if this is accessing a target function through an imported module + + node_value = node.value + node_attr = node.attr + # Check if this is accessing a target function through an imported module if ( - isinstance(node.value, ast.Name) - and node.value.id in self.imported_modules - and node.attr in self.function_names_to_find + isinstance(node_value, ast.Name) + and node_value.id in self.imported_modules + and node_attr in self.function_names_to_find ): self.found_any_target_function = True - self.found_qualified_name = node.attr + self.found_qualified_name = node_attr return - if isinstance(node.value, ast.Name) and node.value.id in self.imported_modules: - for target_func in self.function_names_to_find: - if "." in target_func: - class_name, method_name = target_func.rsplit(".", 1) - if node.attr == method_name: - imported_name = node.value.id - original_name = self.alias_mapping.get(imported_name, imported_name) - if original_name == class_name: - self.found_any_target_function = True - self.found_qualified_name = target_func - return - - # Check if this is accessing a target function through a dynamically imported module - # Only if we've detected dynamic imports are being used - if self.has_dynamic_imports and node.attr in self.function_names_to_find: + # Check for methods via imported modules using precomputed _dot_methods and _class_method_to_target + if isinstance(node_value, ast.Name) and node_value.id in self.imported_modules: + roots_possible = self._dot_methods.get(node_attr) + if roots_possible: + imported_name = node_value.id + original_name = self.alias_mapping.get(imported_name, imported_name) + if original_name in roots_possible: + self.found_any_target_function = True + self.found_qualified_name = self._class_method_to_target[(original_name, node_attr)] + return + + # Check if this is accessing a method on an instance variable + if isinstance(node_value, ast.Name) and node_value.id in self.instance_mapping: + class_name = self.instance_mapping[node_value.id] + roots_possible = self._dot_methods.get(node_attr) + if roots_possible and class_name in roots_possible: + self.found_any_target_function = True + self.found_qualified_name = self._class_method_to_target[(class_name, node_attr)] + return + + # Check for dynamic import match + if self.has_dynamic_imports and node_attr in self.function_names_to_find: self.found_any_target_function = True - self.found_qualified_name = node.attr + self.found_qualified_name = node_attr return self.generic_visit(node) diff --git a/tests/scripts/end_to_end_test_async.py b/tests/scripts/end_to_end_test_async.py index 2de98c6f1..09bcd2bb7 100644 --- a/tests/scripts/end_to_end_test_async.py +++ b/tests/scripts/end_to_end_test_async.py @@ -7,7 +7,6 @@ def run_test(expected_improvement_pct: int) -> bool: config = TestConfig( file_path="main.py", - expected_unit_tests=0, min_improvement_x=0.1, enable_async=True, coverage_expectations=[ @@ -25,4 +24,4 @@ def run_test(expected_improvement_pct: int) -> bool: if __name__ == "__main__": - exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10)))) \ No newline at end of file + exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10)))) diff --git a/tests/scripts/end_to_end_test_futurehouse.py b/tests/scripts/end_to_end_test_futurehouse.py index 95271cff1..3c8a3c38c 100644 --- a/tests/scripts/end_to_end_test_futurehouse.py +++ b/tests/scripts/end_to_end_test_futurehouse.py @@ -7,7 +7,7 @@ def run_test(expected_improvement_pct: int) -> bool: config = TestConfig( file_path="src/aviary/common_tags.py", - expected_unit_tests=2, + expected_unit_tests=1, min_improvement_x=0.1, coverage_expectations=[ CoverageExpectation( diff --git a/tests/scripts/end_to_end_test_topological_sort_worktree.py b/tests/scripts/end_to_end_test_topological_sort_worktree.py index 7d5a3fec5..260417a00 100644 --- a/tests/scripts/end_to_end_test_topological_sort_worktree.py +++ b/tests/scripts/end_to_end_test_topological_sort_worktree.py @@ -18,6 +18,7 @@ def run_test(expected_improvement_pct: int) -> bool: expected_lines=[25, 26, 27, 28, 29, 30, 31], ) ], + expected_unit_tests=1, ) cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve() return_var = run_codeflash_command(cwd, config, expected_improvement_pct) diff --git a/tests/scripts/end_to_end_test_tracer_replay.py b/tests/scripts/end_to_end_test_tracer_replay.py index 118eb1a9c..72d6fe97f 100644 --- a/tests/scripts/end_to_end_test_tracer_replay.py +++ b/tests/scripts/end_to_end_test_tracer_replay.py @@ -8,7 +8,7 @@ def run_test(expected_improvement_pct: int) -> bool: config = TestConfig( trace_mode=True, min_improvement_x=0.1, - expected_unit_tests=8, + expected_unit_tests=0, coverage_expectations=[ CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[6, 7, 8, 9, 11, 14]) ], diff --git a/tests/scripts/end_to_end_test_utilities.py b/tests/scripts/end_to_end_test_utilities.py index afd74823d..ede7c3a49 100644 --- a/tests/scripts/end_to_end_test_utilities.py +++ b/tests/scripts/end_to_end_test_utilities.py @@ -170,7 +170,7 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int return False if config.expected_unit_tests is not None: - unit_test_match = re.search(r"Discovered (\d+) existing unit tests", stdout) + unit_test_match = re.search(r"Discovered (\d+) existing unit test file", stdout) if not unit_test_match: logging.error("Could not find unit test count") return False diff --git a/tests/test_unit_test_discovery.py b/tests/test_unit_test_discovery.py index 5af66ebc4..1a0c73ffa 100644 --- a/tests/test_unit_test_discovery.py +++ b/tests/test_unit_test_discovery.py @@ -1310,6 +1310,30 @@ def test_target(): assert should_process is True +def test_analyze_imports_method(): + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from code_to_optimize.topological_sort import Graph + + +def test_topological_sort(): + g = Graph(6) + g.addEdge(5, 2) + g.addEdge(5, 0) + g.addEdge(4, 0) + g.addEdge(4, 1) + g.addEdge(2, 3) + g.addEdge(3, 1) + + assert g.topologicalSort()[0] == [5, 4, 2, 3, 1, 0] +""" + test_file.write_text(test_content) + + target_functions = {"Graph.topologicalSort"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True def test_analyze_imports_aliased_class_method_negative(): with tempfile.TemporaryDirectory() as tmpdirname: @@ -1333,6 +1357,368 @@ def test_target(): +def test_analyze_imports_class_with_multiple_methods(): + """Test importing a class when looking for multiple methods of that class.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import MyClass + +def test_methods(): + obj = MyClass() + assert obj.method1() is True + assert obj.method2() is False + assert obj.method3() == 42 +""" + test_file.write_text(test_content) + + # Looking for multiple methods of the same class + target_functions = {"MyClass.method1", "MyClass.method2", "MyClass.method3"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + + +def test_analyze_imports_class_method_with_nested_classes(): + """Test importing nested classes and their methods.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import OuterClass + +def test_nested(): + outer = OuterClass() + inner = outer.InnerClass() + assert inner.inner_method() is True +""" + test_file.write_text(test_content) + + # This would require more complex analysis of nested classes + # Currently only direct class.method patterns are supported + target_functions = {"OuterClass.InnerClass.inner_method"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + # Our fix detects OuterClass from OuterClass.InnerClass.inner_method + # This is overly broad but conservative (better to include than exclude) + assert should_process is True + + +def test_analyze_imports_class_method_partial_match(): + """Test that partial class names don't match incorrectly.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import GraphBuilder + +def test_builder(): + builder = GraphBuilder() + assert builder.build() is not None +""" + test_file.write_text(test_content) + + # Looking for Graph.topologicalSort, not GraphBuilder + target_functions = {"Graph.topologicalSort"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is False + + +def test_analyze_imports_class_method_with_inheritance(): + """Test importing a child class when looking for parent class methods.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import ChildClass + +def test_inherited(): + child = ChildClass() + # Assuming ChildClass inherits from ParentClass + assert child.parent_method() is True +""" + test_file.write_text(test_content) + + # Looking for parent class method, but only child is imported + target_functions = {"ParentClass.parent_method"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is False + + +def test_analyze_imports_class_static_and_class_methods(): + """Test importing a class and calling static/class methods.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import MyClass + +def test_static_and_class_methods(): + # Static method call + assert MyClass.static_method() is True + + # Class method call + result = MyClass.class_method() + assert result == "expected" + + # Instance method call + obj = MyClass() + assert obj.instance_method() is False +""" + test_file.write_text(test_content) + + target_functions = {"MyClass.static_method", "MyClass.class_method", "MyClass.instance_method"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + + +def test_analyze_imports_multiple_classes_same_module(): + """Test importing multiple classes from the same module.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import ClassA, ClassB, ClassC + +def test_multiple_classes(): + a = ClassA() + b = ClassB() + c = ClassC() + + assert a.methodA() is True + assert b.methodB() is False + assert c.methodC() == 42 +""" + test_file.write_text(test_content) + + # Looking for methods from different classes + target_functions = {"ClassA.methodA", "ClassB.methodB", "ClassD.methodD"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True # ClassA and ClassB are imported + + +def test_analyze_imports_class_method_case_sensitive(): + """Test that class name matching is case-sensitive.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import graph + +def test_lowercase(): + g = graph() + assert g.topologicalSort() is not None +""" + test_file.write_text(test_content) + + # Looking for Graph (capital G), but imported graph (lowercase) + target_functions = {"Graph.topologicalSort"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is False + + +def test_analyze_imports_class_from_submodule(): + """Test importing a class from a submodule.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from package.subpackage.module import MyClass + +def test_submodule_class(): + obj = MyClass() + assert obj.my_method() is True +""" + test_file.write_text(test_content) + + target_functions = {"MyClass.my_method"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + + +def test_analyze_imports_aliased_class_with_methods(): + """Test importing a class with an alias and looking for its methods.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import Graph as G + +def test_aliased_class(): + graph = G(10) + result = graph.topologicalSort() + assert result is not None +""" + test_file.write_text(test_content) + + target_functions = {"Graph.topologicalSort"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + + +def test_analyze_imports_class_property_access(): + """Test importing a class and accessing properties (not methods).""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import MyClass + +def test_properties(): + obj = MyClass() + # Accessing properties, not methods + assert obj.size == 10 + assert obj.name == "test" +""" + test_file.write_text(test_content) + + # Looking for methods, but only properties are accessed + # Our fix conservatively includes when class is imported + target_functions = {"MyClass.get_size", "MyClass.get_name"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True # Conservative approach + + +def test_analyze_imports_class_constructor_params(): + """Test class import when looking for __init__ method.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import MyClass + +def test_constructor(): + # Testing the constructor + obj1 = MyClass() + obj2 = MyClass(10) + obj3 = MyClass(size=20, name="test") + + assert obj1 is not None + assert obj2 is not None + assert obj3 is not None +""" + test_file.write_text(test_content) + + # __init__ is a special method that would require additional logic + target_functions = {"MyClass.__init__"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + # Our fix now detects MyClass from MyClass.__init__ + assert should_process is True + + +def test_analyze_imports_class_method_chaining(): + """Test method chaining on imported classes.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import Builder + +def test_chaining(): + result = Builder().add_item("a").add_item("b").build() + assert result is not None +""" + test_file.write_text(test_content) + + # Method chaining requires tracking object types through chained calls + target_functions = {"Builder.add_item", "Builder.build"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + # Currently detects Builder import and methods + assert should_process is True + + +def test_analyze_imports_mixed_function_and_class_imports(): + """Test mixed imports of functions and classes from the same module.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import MyClass, standalone_function, AnotherClass + +def test_mixed(): + # Using class method + obj = MyClass() + assert obj.method() is True + + # Using standalone function + assert standalone_function() is False + + # Using another class + other = AnotherClass() + assert other.other_method() == 42 +""" + test_file.write_text(test_content) + + target_functions = {"MyClass.method", "standalone_function", "YetAnotherClass.method"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True # MyClass.method and standalone_function are imported + + +def test_analyze_imports_class_with_module_prefix(): + """Test looking for fully qualified class methods.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from code_to_optimize.topological_sort import Graph + +def test_fully_qualified(): + g = Graph(5) + assert g.topologicalSort() == [4, 3, 2, 1, 0] +""" + test_file.write_text(test_content) + + # Looking with full module path would require more complex module resolution + target_functions = {"code_to_optimize.topological_sort.Graph.topologicalSort"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + # Currently not supported - would need to match module path with imports + assert should_process is False + + +def test_analyze_imports_reimport_in_function(): + """Test class import inside a function.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +def test_local_import(): + from mymodule import MyClass + obj = MyClass() + assert obj.method() is True +""" + test_file.write_text(test_content) + + target_functions = {"MyClass.method"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + + +def test_analyze_imports_class_in_type_annotation(): + """Test class used only in type annotations.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from typing import Optional +from mymodule import MyClass + +def helper_function(obj: Optional[MyClass]) -> bool: + if obj: + return obj.method() + return False + +def test_with_type_annotation(): + # MyClass is imported but only used in type annotation + result = helper_function(None) + assert result is False +""" + test_file.write_text(test_content) + + target_functions = {"MyClass.method"} + should_process = analyze_imports_in_test_file(test_file, target_functions) + + # MyClass is imported, so class.method pattern should match + assert should_process is True + + def test_discover_unit_tests_caching(): tests_root = Path(__file__).parent.resolve() / "tests" project_root_path = tests_root.parent.resolve()