diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 1fd86acce..7a5bf6f9d 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -1,6 +1,7 @@ # ruff: noqa: SLF001 from __future__ import annotations +import ast import hashlib import os import pickle @@ -12,6 +13,9 @@ from pathlib import Path from typing import TYPE_CHECKING, Callable, Optional +if TYPE_CHECKING: + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + import pytest from pydantic.dataclasses import dataclass from rich.panel import Panel @@ -86,7 +90,6 @@ def insert_test( line_number: int, col_number: int, ) -> None: - self.cur.execute("DELETE FROM discovered_tests WHERE file_path = ?", (file_path,)) test_type_value = test_type.value if hasattr(test_type, "value") else test_type self.cur.execute( "INSERT INTO discovered_tests VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", @@ -137,21 +140,172 @@ def close(self) -> None: self.connection.close() +class ImportAnalyzer(ast.NodeVisitor): + """AST-based analyzer to find all imports in a test file.""" + + def __init__(self, function_names_to_find: set[str]) -> None: + self.function_names_to_find = function_names_to_find + self.imported_names: set[str] = set() + self.imported_modules: set[str] = set() + self.found_target_functions: set[str] = set() + self.qualified_names_called: set[str] = set() + + def visit_Import(self, node: ast.Import) -> None: + """Handle 'import module' statements.""" + for alias in node.names: + module_name = alias.asname if alias.asname else alias.name + self.imported_modules.add(module_name) + self.imported_names.add(module_name) + self.generic_visit(node) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + """Handle 'from module import name' statements.""" + if node.module: + self.imported_modules.add(node.module) + + for alias in node.names: + if alias.name == "*": + continue + imported_name = alias.asname 
if alias.asname else alias.name + self.imported_names.add(imported_name) + if alias.name in self.function_names_to_find: + self.found_target_functions.add(alias.name) + # Check for qualified name matches + if node.module: + qualified_name = f"{node.module}.{alias.name}" + if qualified_name in self.function_names_to_find: + self.found_target_functions.add(qualified_name) + self.generic_visit(node) + + def visit_Call(self, node: ast.Call) -> None: + """Handle dynamic imports like importlib.import_module() or __import__().""" + if ( + isinstance(node.func, ast.Name) + and node.func.id == "__import__" + and node.args + and isinstance(node.args[0], ast.Constant) + and isinstance(node.args[0].value, str) + ): + # __import__("module_name") + self.imported_modules.add(node.args[0].value) + elif ( + isinstance(node.func, ast.Attribute) + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "importlib" + and node.func.attr == "import_module" + and node.args + and isinstance(node.args[0], ast.Constant) + and isinstance(node.args[0].value, str) + ): + # importlib.import_module("module_name") + self.imported_modules.add(node.args[0].value) + self.generic_visit(node) + + def visit_Name(self, node: ast.Name) -> None: + """Check if any name usage matches our target functions.""" + if node.id in self.function_names_to_find: + self.found_target_functions.add(node.id) + self.generic_visit(node) + + def visit_Attribute(self, node: ast.Attribute) -> None: + """Handle module.function_name patterns.""" + if node.attr in self.function_names_to_find: + self.found_target_functions.add(node.attr) + if isinstance(node.value, ast.Name): + qualified_name = f"{node.value.id}.{node.attr}" + self.qualified_names_called.add(qualified_name) + self.generic_visit(node) + + +def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> tuple[bool, set[str]]: + """Analyze imports in a test file to determine if it might test any target functions. 
+ + Args: + test_file_path: Path to the test file + target_functions: Set of function names we're looking for + + Returns: + Tuple of (should_process_with_jedi, found_function_names) + + """ + if isinstance(test_file_path, str): + test_file_path = Path(test_file_path) + + try: + with test_file_path.open("r", encoding="utf-8") as f: + content = f.read() + + tree = ast.parse(content, filename=str(test_file_path)) + analyzer = ImportAnalyzer(target_functions) + analyzer.visit(tree) + + if analyzer.found_target_functions: + return True, analyzer.found_target_functions + + return False, set() # noqa: TRY300 + + except (SyntaxError, UnicodeDecodeError, OSError) as e: + logger.debug(f"Failed to analyze imports in {test_file_path}: {e}") + return True, set() + + +def filter_test_files_by_imports( + file_to_test_map: dict[Path, list[TestsInFile]], target_functions: set[str] +) -> tuple[dict[Path, list[TestsInFile]], dict[Path, set[str]]]: + """Filter test files based on import analysis to reduce Jedi processing. 
+ + Args: + file_to_test_map: Original mapping of test files to test functions + target_functions: Set of function names we're optimizing + + Returns: + Tuple of (filtered_file_map, import_analysis_results) + + """ + if not target_functions: + return file_to_test_map, {} + + filtered_map = {} + import_results = {} + + for test_file, test_functions in file_to_test_map.items(): + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + import_results[test_file] = found_functions + + if should_process: + filtered_map[test_file] = test_functions + else: + logger.debug(f"Skipping {test_file} - no relevant imports found") + + logger.debug(f"Import filter: Processing {len(filtered_map)}/{len(file_to_test_map)} test files") + return filtered_map, import_results + + def discover_unit_tests( - cfg: TestConfig, discover_only_these_tests: list[Path] | None = None -) -> dict[str, list[FunctionCalledInTest]]: + cfg: TestConfig, + discover_only_these_tests: list[Path] | None = None, + file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] | None = None, +) -> tuple[dict[str, set[FunctionCalledInTest]], int]: framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest} strategy = framework_strategies.get(cfg.test_framework, None) if not strategy: error_message = f"Unsupported test framework: {cfg.test_framework}" raise ValueError(error_message) - return strategy(cfg, discover_only_these_tests) + # Extract all functions to optimize for import filtering + functions_to_optimize = None + if file_to_funcs_to_optimize: + functions_to_optimize = [func for funcs_list in file_to_funcs_to_optimize.values() for func in funcs_list] + + function_to_tests, num_discovered_tests = strategy(cfg, discover_only_these_tests, functions_to_optimize) + return function_to_tests, num_discovered_tests def discover_tests_pytest( - cfg: TestConfig, discover_only_these_tests: list[Path] | None = None -) -> 
dict[Path, list[FunctionCalledInTest]]: + cfg: TestConfig, + discover_only_these_tests: list[Path] | None = None, + functions_to_optimize: list[FunctionToOptimize] | None = None, +) -> tuple[dict[str, set[FunctionCalledInTest]], int]: tests_root = cfg.tests_root project_root = cfg.project_root_path @@ -220,12 +374,14 @@ def discover_tests_pytest( continue file_to_test_map[test_obj.test_file].append(test_obj) # Within these test files, find the project functions they are referring to and return their names/locations - return process_test_files(file_to_test_map, cfg) + return process_test_files(file_to_test_map, cfg, functions_to_optimize) def discover_tests_unittest( - cfg: TestConfig, discover_only_these_tests: list[str] | None = None -) -> dict[Path, list[FunctionCalledInTest]]: + cfg: TestConfig, + discover_only_these_tests: list[str] | None = None, + functions_to_optimize: list[FunctionToOptimize] | None = None, +) -> tuple[dict[str, set[FunctionCalledInTest]], int]: tests_root: Path = cfg.tests_root loader: unittest.TestLoader = unittest.TestLoader() tests: unittest.TestSuite = loader.discover(str(tests_root)) @@ -277,7 +433,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None: details = get_test_details(test) if details is not None: file_to_test_map[str(details.test_file)].append(details) - return process_test_files(file_to_test_map, cfg) + return process_test_files(file_to_test_map, cfg, functions_to_optimize) def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | None]: @@ -289,47 +445,50 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N def process_test_files( - file_to_test_map: dict[Path, list[TestsInFile]], cfg: TestConfig -) -> dict[str, list[FunctionCalledInTest]]: + file_to_test_map: dict[Path, list[TestsInFile]], + cfg: TestConfig, + functions_to_optimize: list[FunctionToOptimize] | None = None, +) -> tuple[dict[str, set[FunctionCalledInTest]], int]: import jedi 
project_root_path = cfg.project_root_path test_framework = cfg.test_framework + if functions_to_optimize: + target_function_names = set() + for func in functions_to_optimize: + target_function_names.add(func.qualified_name) + logger.debug(f"Target functions for import filtering: {target_function_names}") + file_to_test_map, import_results = filter_test_files_by_imports(file_to_test_map, target_function_names) + logger.debug(f"Import analysis results: {len(import_results)} files analyzed") + function_to_test_map = defaultdict(set) + num_discovered_tests = 0 jedi_project = jedi.Project(path=project_root_path) - goto_cache = {} - tests_cache = TestsCache() with test_files_progress_bar(total=len(file_to_test_map), description="Processing test files") as ( progress, task_id, ): for test_file, functions in file_to_test_map.items(): - file_hash = TestsCache.compute_file_hash(test_file) - cached_tests = tests_cache.get_tests_for_file(str(test_file), file_hash) - if cached_tests: - self_cur = tests_cache.cur - self_cur.execute( - "SELECT qualified_name_with_modules_from_root FROM discovered_tests WHERE file_path = ? 
AND file_hash = ?", - (str(test_file), file_hash), - ) - qualified_names = [row[0] for row in self_cur.fetchall()] - for cached, qualified_name in zip(cached_tests, qualified_names): - function_to_test_map[qualified_name].add(cached) - progress.advance(task_id) - continue - try: script = jedi.Script(path=test_file, project=jedi_project) test_functions = set() - all_names = script.get_names(all_scopes=True, references=True) - all_defs = script.get_names(all_scopes=True, definitions=True) - all_names_top = script.get_names(all_scopes=True) + # Single call to get all names with references and definitions + all_names = script.get_names(all_scopes=True, references=True, definitions=True) + + # Filter once and create lookup dictionaries + top_level_functions = {} + top_level_classes = {} + all_defs = [] - top_level_functions = {name.name: name for name in all_names_top if name.type == "function"} - top_level_classes = {name.name: name for name in all_names_top if name.type == "class"} + for name in all_names: + if name.type == "function": + top_level_functions[name.name] = name + all_defs.append(name) + elif name.type == "class": + top_level_classes[name.name] = name except Exception as e: logger.debug(f"Failed to get jedi script for {test_file}: {e}") progress.advance(task_id) @@ -394,31 +553,23 @@ def process_test_files( ) ) - test_functions_list = list(test_functions) - test_functions_raw = [elem.function_name for elem in test_functions_list] - test_functions_by_name = defaultdict(list) - for i, func_name in enumerate(test_functions_raw): - test_functions_by_name[func_name].append(i) + for func in test_functions: + test_functions_by_name[func.function_name].append(func) - for name in all_names: - if name.full_name is None: - continue - m = FUNCTION_NAME_REGEX.search(name.full_name) - if not m: - continue + test_function_names_set = set(test_functions_by_name.keys()) + relevant_names = [] - scope = m.group(1) - if scope not in test_functions_by_name: - continue + 
names_with_full_name = [name for name in all_names if name.full_name is not None] + + for name in names_with_full_name: + match = FUNCTION_NAME_REGEX.search(name.full_name) + if match and match.group(1) in test_function_names_set: + relevant_names.append((name, match.group(1))) - cache_key = (name.full_name, name.module_name) + for name, scope in relevant_names: try: - if cache_key in goto_cache: - definition = goto_cache[cache_key] - else: - definition = name.goto(follow_imports=True, follow_builtin_imports=False) - goto_cache[cache_key] = definition + definition = name.goto(follow_imports=True, follow_builtin_imports=False) except Exception as e: logger.debug(str(e)) continue @@ -426,54 +577,42 @@ def process_test_files( if not definition or definition[0].type != "function": continue - definition_path = str(definition[0].module_path) + definition_obj = definition[0] + definition_path = str(definition_obj.module_path) + + project_root_str = str(project_root_path) if ( - definition_path.startswith(str(project_root_path) + os.sep) - and definition[0].module_name != name.module_name - and definition[0].full_name is not None + definition_path.startswith(project_root_str + os.sep) + and definition_obj.module_name != name.module_name + and definition_obj.full_name is not None ): - for index in test_functions_by_name[scope]: - scope_test_function = test_functions_list[index].function_name - scope_test_class = test_functions_list[index].test_class - scope_parameters = test_functions_list[index].parameters - test_type = test_functions_list[index].test_type + # Pre-compute common values outside the inner loop + module_prefix = definition_obj.module_name + "." 
+ full_name_without_module_prefix = definition_obj.full_name.replace(module_prefix, "", 1) + qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition_obj.module_path, project_root_path)}.{full_name_without_module_prefix}" - if scope_parameters is not None: + for test_func in test_functions_by_name[scope]: + if test_func.parameters is not None: if test_framework == "pytest": - scope_test_function += "[" + scope_parameters + "]" - if test_framework == "unittest": - scope_test_function += "_" + scope_parameters - - full_name_without_module_prefix = definition[0].full_name.replace( - definition[0].module_name + ".", "", 1 - ) - qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition[0].module_path, project_root_path)}.{full_name_without_module_prefix}" - - tests_cache.insert_test( - file_path=str(test_file), - file_hash=file_hash, - qualified_name_with_modules_from_root=qualified_name_with_modules_from_root, - function_name=scope, - test_class=scope_test_class, - test_function=scope_test_function, - test_type=test_type, - line_number=name.line, - col_number=name.column, - ) + scope_test_function = f"{test_func.function_name}[{test_func.parameters}]" + else: # unittest + scope_test_function = f"{test_func.function_name}_{test_func.parameters}" + else: + scope_test_function = test_func.function_name function_to_test_map[qualified_name_with_modules_from_root].add( FunctionCalledInTest( tests_in_file=TestsInFile( test_file=test_file, - test_class=scope_test_class, + test_class=test_func.test_class, test_function=scope_test_function, - test_type=test_type, + test_type=test_func.test_type, ), position=CodePosition(line_no=name.line, col_no=name.column), ) ) + num_discovered_tests += 1 progress.advance(task_id) - tests_cache.close() - return {function: list(tests) for function, tests in function_to_test_map.items()} + return dict(function_to_test_map), num_discovered_tests diff --git 
a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 41d99ec2c..931b3a05a 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -268,7 +268,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt def get_all_replay_test_functions( replay_test: Path, test_cfg: TestConfig, project_root_path: Path ) -> dict[Path, list[FunctionToOptimize]]: - function_tests = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test]) + function_tests, _ = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test]) # Get the absolute file paths for each function, excluding class name if present filtered_valid_functions = defaultdict(list) file_to_functions_map = defaultdict(list) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 5922d6c1c..9f5781697 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -57,7 +57,6 @@ from codeflash.models.models import ( BestOptimization, CodeOptimizationContext, - FunctionCalledInTest, GeneratedTests, GeneratedTestsList, OptimizationSet, @@ -87,7 +86,13 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.either import Result - from codeflash.models.models import BenchmarkKey, CoverageData, FunctionSource, OptimizedCandidate + from codeflash.models.models import ( + BenchmarkKey, + CoverageData, + FunctionCalledInTest, + FunctionSource, + OptimizedCandidate, + ) from codeflash.verification.verification_utils import TestConfig @@ -97,7 +102,7 @@ def __init__( function_to_optimize: FunctionToOptimize, test_cfg: TestConfig, function_to_optimize_source_code: str = "", - function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None, + function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None, function_to_optimize_ast: 
ast.FunctionDef | None = None, aiservice_client: AiServiceClient | None = None, function_benchmark_timings: dict[BenchmarkKey, int] | None = None, @@ -213,7 +218,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 function_to_optimize_qualified_name = self.function_to_optimize.qualified_name function_to_all_tests = { - key: self.function_to_tests.get(key, []) + function_to_concolic_tests.get(key, []) + key: self.function_to_tests.get(key, set()) | function_to_concolic_tests.get(key, set()) for key in set(self.function_to_tests) | set(function_to_concolic_tests) } instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests) @@ -690,7 +695,7 @@ def cleanup_leftover_test_return_values() -> None: get_run_tmp_file(Path("test_return_values_0.bin")).unlink(missing_ok=True) get_run_tmp_file(Path("test_return_values_0.sqlite")).unlink(missing_ok=True) - def instrument_existing_tests(self, function_to_all_tests: dict[str, list[FunctionCalledInTest]]) -> set[Path]: + def instrument_existing_tests(self, function_to_all_tests: dict[str, set[FunctionCalledInTest]]) -> set[Path]: existing_test_files_count = 0 replay_test_files_count = 0 concolic_coverage_test_files_count = 0 @@ -701,7 +706,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, list[Functi logger.info(f"Did not find any pre-existing tests for '{func_qualname}', will only use generated tests.") console.rule() else: - test_file_invocation_positions = defaultdict(list[FunctionCalledInTest]) + test_file_invocation_positions = defaultdict(list) for tests_in_file in function_to_all_tests.get(func_qualname): test_file_invocation_positions[ (tests_in_file.tests_in_file.test_file, tests_in_file.tests_in_file.test_type) @@ -787,7 +792,7 @@ def generate_tests_and_optimizations( generated_test_paths: list[Path], generated_perf_test_paths: list[Path], run_experiment: bool = False, # noqa: FBT001, FBT002 - ) -> 
Result[tuple[GeneratedTestsList, dict[str, list[FunctionCalledInTest]], OptimizationSet], str]: + ) -> Result[tuple[GeneratedTestsList, dict[str, set[FunctionCalledInTest]], OptimizationSet], str]: assert len(generated_test_paths) == N_TESTS_TO_GENERATE max_workers = N_TESTS_TO_GENERATE + 2 if not run_experiment else N_TESTS_TO_GENERATE + 3 console.rule() diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 9e5715e2a..55ab14c35 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -48,7 +48,7 @@ def create_function_optimizer( self, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: ast.FunctionDef | None = None, - function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None, + function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None, function_to_optimize_source_code: str | None = "", function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None, total_benchmark_timings: dict[BenchmarkKey, float] | None = None, @@ -162,8 +162,9 @@ def run(self) -> None: console.rule() start_time = time.time() - function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg) - num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()]) + function_to_tests, num_discovered_tests = discover_unit_tests( + self.test_cfg, file_to_funcs_to_optimize=file_to_funcs_to_optimize + ) console.rule() logger.info( f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}" diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index 8524d397e..b9e05e660 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -25,7 +25,7 @@ def existing_tests_source_for( function_qualified_name_with_modules_from_root: str, - function_to_tests: dict[str, list[FunctionCalledInTest]], + function_to_tests: dict[str, 
set[FunctionCalledInTest]], tests_root: Path, ) -> str: test_files = function_to_tests.get(function_qualified_name_with_modules_from_root) diff --git a/codeflash/verification/concolic_testing.py b/codeflash/verification/concolic_testing.py index 5792a289d..014620f28 100644 --- a/codeflash/verification/concolic_testing.py +++ b/codeflash/verification/concolic_testing.py @@ -24,7 +24,7 @@ def generate_concolic_tests( test_cfg: TestConfig, args: Namespace, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: ast.AST -) -> tuple[dict[str, list[FunctionCalledInTest]], str]: +) -> tuple[dict[str, set[FunctionCalledInTest]], str]: start_time = time.perf_counter() function_to_concolic_tests = {} concolic_test_suite_code = "" @@ -78,8 +78,7 @@ def generate_concolic_tests( test_framework=args.test_framework, pytest_cmd=args.pytest_cmd, ) - function_to_concolic_tests = discover_unit_tests(concolic_test_cfg) - num_discovered_concolic_tests: int = sum([len(value) for value in function_to_concolic_tests.values()]) + function_to_concolic_tests, num_discovered_concolic_tests = discover_unit_tests(concolic_test_cfg) logger.info( f"Created {num_discovered_concolic_tests} " f"concolic unit test case{'s' if num_discovered_concolic_tests != 1 else ''} " diff --git a/tests/test_static_analysis.py b/tests/test_static_analysis.py index c4da29c03..b997edeab 100644 --- a/tests/test_static_analysis.py +++ b/tests/test_static_analysis.py @@ -1,4 +1,4 @@ -import ast +import ast from pathlib import Path from codeflash.code_utils.static_analysis import ( diff --git a/tests/test_unit_test_discovery.py b/tests/test_unit_test_discovery.py index 8c3bc35c8..4465acf79 100644 --- a/tests/test_unit_test_discovery.py +++ b/tests/test_unit_test_discovery.py @@ -2,8 +2,14 @@ import tempfile from pathlib import Path -from codeflash.discovery.discover_unit_tests import discover_unit_tests +from codeflash.discovery.discover_unit_tests import ( + analyze_imports_in_test_file, + 
discover_unit_tests, + filter_test_files_by_imports, +) +from codeflash.models.models import TestsInFile, TestType from codeflash.verification.verification_utils import TestConfig +from codeflash.discovery.functions_to_optimize import FunctionToOptimize def test_unit_test_discovery_pytest(): @@ -15,7 +21,7 @@ def test_unit_test_discovery_pytest(): test_framework="pytest", tests_project_rootdir=tests_path.parent, ) - tests = discover_unit_tests(test_config) + tests, _ = discover_unit_tests(test_config) assert len(tests) > 0 @@ -28,7 +34,7 @@ def test_benchmark_test_discovery_pytest(): test_framework="pytest", tests_project_rootdir=tests_path.parent, ) - tests = discover_unit_tests(test_config) + tests, _ = discover_unit_tests(test_config) assert len(tests) == 1 # Should not discover benchmark tests @@ -42,7 +48,7 @@ def test_unit_test_discovery_unittest(): tests_project_rootdir=project_path.parent, ) os.chdir(project_path) - tests = discover_unit_tests(test_config) + tests, _ = discover_unit_tests(test_config) # assert len(tests) > 0 # Unittest discovery within a pytest environment does not work @@ -80,7 +86,7 @@ def sorter(arr): ) # Discover tests - tests = discover_unit_tests(test_config) + tests, _ = discover_unit_tests(test_config) assert len(tests) == 1 assert 'bubble_sort.sorter' in tests assert len(tests['bubble_sort.sorter']) == 2 @@ -119,17 +125,14 @@ def test_discover_tests_pytest_with_temp_dir_root(): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Check if the dummy test file is discovered assert len(discovered_tests) == 1 assert len(discovered_tests["dummy_code.dummy_function"]) == 2 - assert discovered_tests["dummy_code.dummy_function"][0].tests_in_file.test_file == test_file_path - assert discovered_tests["dummy_code.dummy_function"][1].tests_in_file.test_file == test_file_path - assert { - discovered_tests["dummy_code.dummy_function"][0].tests_in_file.test_function, - 
discovered_tests["dummy_code.dummy_function"][1].tests_in_file.test_function, - } == {"test_dummy_parametrized_function[True]", "test_dummy_function"} + dummy_tests = discovered_tests["dummy_code.dummy_function"] + assert all(test.tests_in_file.test_file == test_file_path for test in dummy_tests) + assert {test.tests_in_file.test_function for test in dummy_tests} == {"test_dummy_parametrized_function[True]", "test_dummy_function"} def test_discover_tests_pytest_with_multi_level_dirs(): @@ -192,17 +195,17 @@ def test_discover_tests_pytest_with_multi_level_dirs(): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Check if the test files at all levels are discovered assert len(discovered_tests) == 3 - assert discovered_tests["root_code.root_function"][0].tests_in_file.test_file == root_test_file_path + assert next(iter(discovered_tests["root_code.root_function"])).tests_in_file.test_file == root_test_file_path assert ( - discovered_tests["level1.level1_code.level1_function"][0].tests_in_file.test_file == level1_test_file_path + next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file == level1_test_file_path ) assert ( - discovered_tests["level1.level2.level2_code.level2_function"][0].tests_in_file.test_file + next(iter(discovered_tests["level1.level2.level2_code.level2_function"])).tests_in_file.test_file == level2_test_file_path ) @@ -282,21 +285,21 @@ def test_discover_tests_pytest_dirs(): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Check if the test files at all levels are discovered assert len(discovered_tests) == 4 - assert discovered_tests["root_code.root_function"][0].tests_in_file.test_file == root_test_file_path + assert next(iter(discovered_tests["root_code.root_function"])).tests_in_file.test_file == root_test_file_path assert ( - 
discovered_tests["level1.level1_code.level1_function"][0].tests_in_file.test_file == level1_test_file_path + next(iter(discovered_tests["level1.level1_code.level1_function"])).tests_in_file.test_file == level1_test_file_path ) assert ( - discovered_tests["level1.level2.level2_code.level2_function"][0].tests_in_file.test_file + next(iter(discovered_tests["level1.level2.level2_code.level2_function"])).tests_in_file.test_file == level2_test_file_path ) assert ( - discovered_tests["level1.level3.level3_code.level3_function"][0].tests_in_file.test_file + next(iter(discovered_tests["level1.level3.level3_code.level3_function"])).tests_in_file.test_file == level3_test_file_path ) @@ -328,11 +331,11 @@ def test_discover_tests_pytest_with_class(): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Check if the test class and method are discovered assert len(discovered_tests) == 1 - assert discovered_tests["some_class_code.SomeClass.some_method"][0].tests_in_file.test_file == test_file_path + assert next(iter(discovered_tests["some_class_code.SomeClass.some_method"])).tests_in_file.test_file == test_file_path def test_discover_tests_pytest_with_double_nested_directories(): @@ -366,14 +369,12 @@ def test_discover_tests_pytest_with_double_nested_directories(): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Check if the test class and method are discovered assert len(discovered_tests) == 1 assert ( - discovered_tests["nested.more_nested.nested_class_code.NestedClass.nested_method"][ - 0 - ].tests_in_file.test_file + next(iter(discovered_tests["nested.more_nested.nested_class_code.NestedClass.nested_method"])).tests_in_file.test_file == test_file_path ) @@ -416,11 +417,11 @@ def test_discover_tests_with_code_in_dir_and_test_in_subdir(): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + 
discovered_tests, _ = discover_unit_tests(test_config) # Check if the test file is discovered and associated with the code file assert len(discovered_tests) == 1 - assert discovered_tests["code.some_code.some_function"][0].tests_in_file.test_file == test_file_path + assert next(iter(discovered_tests["code.some_code.some_function"])).tests_in_file.test_file == test_file_path def test_discover_tests_pytest_with_nested_class(): @@ -455,12 +456,12 @@ def test_discover_tests_pytest_with_nested_class(): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Check if the test for the nested class method is discovered assert len(discovered_tests) == 1 assert ( - discovered_tests["nested_class_code.OuterClass.InnerClass.inner_method"][0].tests_in_file.test_file + next(iter(discovered_tests["nested_class_code.OuterClass.InnerClass.inner_method"])).tests_in_file.test_file == test_file_path ) @@ -495,11 +496,11 @@ def test_discover_tests_pytest_separate_moduledir(): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Check if the test for the nested class method is discovered assert len(discovered_tests) == 1 - assert discovered_tests["mypackage.code.find_common_tags"][0].tests_in_file.test_file == test_file_path + assert next(iter(discovered_tests["mypackage.code.find_common_tags"])).tests_in_file.test_file == test_file_path def test_unittest_discovery_with_pytest(): @@ -537,14 +538,15 @@ def test_add(self): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Verify the unittest was discovered assert len(discovered_tests) == 1 assert "calculator.Calculator.add" in discovered_tests assert len(discovered_tests["calculator.Calculator.add"]) == 1 - assert discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_file == test_file_path - 
assert discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_function == "test_add" + calculator_test = next(iter(discovered_tests["calculator.Calculator.add"])) + assert calculator_test.tests_in_file.test_file == test_file_path + assert calculator_test.tests_in_file.test_function == "test_add" def test_unittest_discovery_with_pytest_parent_class(): @@ -604,14 +606,15 @@ def test_add(self): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Verify the unittest was discovered assert len(discovered_tests) == 2 assert "calculator.Calculator.add" in discovered_tests assert len(discovered_tests["calculator.Calculator.add"]) == 1 - assert discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_file == test_file_path - assert discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_function == "test_add" + calculator_test = next(iter(discovered_tests["calculator.Calculator.add"])) + assert calculator_test.tests_in_file.test_file == test_file_path + assert calculator_test.tests_in_file.test_function == "test_add" def test_unittest_discovery_with_pytest_private(): @@ -649,7 +652,7 @@ def _test_add(self): # Private test method should not be discovered ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Verify no tests were discovered assert len(discovered_tests) == 0 @@ -701,15 +704,16 @@ def test_add_with_parameters(self): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Verify the unittest was discovered assert len(discovered_tests) == 1 assert "calculator.Calculator.add" in discovered_tests assert len(discovered_tests["calculator.Calculator.add"]) == 1 - assert discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_file == test_file_path + calculator_test = 
next(iter(discovered_tests["calculator.Calculator.add"])) + assert calculator_test.tests_in_file.test_file == test_file_path assert ( - discovered_tests["calculator.Calculator.add"][0].tests_in_file.test_function == "test_add_with_parameters" + calculator_test.tests_in_file.test_function == "test_add_with_parameters" ) @@ -783,9 +787,432 @@ def test_add_mixed(self, name, a, b, expected): ) # Discover tests - discovered_tests = discover_unit_tests(test_config) + discovered_tests, _ = discover_unit_tests(test_config) # Verify the basic structure assert len(discovered_tests) == 2 # Should have tests for both add and multiply assert "calculator.Calculator.add" in discovered_tests assert "calculator.Calculator.multiply" in discovered_tests + + +# Import Filtering Tests + + +def test_analyze_imports_direct_function_import(): + """Test that direct function imports are detected.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import target_function, other_function + +def test_target(): + assert target_function() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function", "missing_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + assert "target_function" in found_functions + assert "missing_function" not in found_functions + + +def test_analyze_imports_star_import(): + """Test that star imports trigger conservative processing.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import * + +def test_something(): + assert something() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is False + assert 
found_functions == set() + + +def test_analyze_imports_module_import(): + """Test module imports with function access patterns.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +import mymodule + +def test_target(): + assert mymodule.target_function() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + assert "target_function" in found_functions + + +def test_analyze_imports_dynamic_import(): + """Test detection of dynamic imports.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +import importlib + +def test_dynamic(): + module = importlib.import_module("mymodule") + assert module.target_function() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + assert "target_function" in found_functions + + +def test_analyze_imports_builtin_import(): + """Test detection of __import__ calls.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +def test_builtin_import(): + module = __import__("mymodule") + assert module.target_function() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + assert "target_function" in found_functions + + +def test_analyze_imports_no_matching_imports(): + """Test that files with no matching imports are filtered out.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / 
"test_example.py" + test_content = """ +from unrelated_module import unrelated_function + +def test_unrelated(): + assert unrelated_function() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function", "another_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + assert should_process is False + assert found_functions == set() + + +def test_analyze_qualified_names(): + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from target_module import some_function + +def test_target(): + assert some_function() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_module.some_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + assert should_process is True + assert "target_module.some_function" in found_functions + + + +def test_analyze_imports_syntax_error(): + """Test handling of files with syntax errors.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import target_function +def test_target( + # Syntax error - missing closing parenthesis + assert target_function() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + # Should be conservative with unparseable files + assert should_process is True + assert found_functions == set() + + +def test_filter_test_files_by_imports(): + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdir = Path(tmpdirname) + + # Create test file that imports target function + relevant_test = tmpdir / "test_relevant.py" + relevant_test.write_text(""" +from mymodule import target_function + +def test_target(): + assert target_function() is True +""") + + # Create test file 
that doesn't import target function + irrelevant_test = tmpdir / "test_irrelevant.py" + irrelevant_test.write_text(""" +from othermodule import other_function + +def test_other(): + assert other_function() is True +""") + + # Create test file with star import (should not be processed) + star_test = tmpdir / "test_star.py" + star_test.write_text(""" +from mymodule import * + +def test_star(): + assert something() is True +""") + + file_to_test_map = { + relevant_test: [TestsInFile(test_file=relevant_test, test_function="test_target", test_class=None, test_type=TestType.EXISTING_UNIT_TEST)], + irrelevant_test: [TestsInFile(test_file=irrelevant_test, test_function="test_other", test_class=None, test_type=TestType.EXISTING_UNIT_TEST)], + star_test: [TestsInFile(test_file=star_test, test_function="test_star", test_class=None, test_type=TestType.EXISTING_UNIT_TEST)], + } + + target_functions = {"target_function"} + filtered_map, import_results = filter_test_files_by_imports(file_to_test_map, target_functions) + + # Should filter out irrelevant_test + assert len(filtered_map) == 1 + assert relevant_test in filtered_map + assert irrelevant_test not in filtered_map + + # Check import analysis results + assert "target_function" in import_results[relevant_test] + assert len(import_results[irrelevant_test]) == 0 + assert len(import_results[star_test]) == 0 + + +def test_filter_test_files_no_target_functions(): + """Test that filtering is skipped when no target functions are provided.""" + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdir = Path(tmpdirname) + + test_file = tmpdir / "test_example.py" + test_file.write_text("def test_something(): pass") + + file_to_test_map = { + test_file: [TestsInFile(test_file=test_file, test_function="test_something", test_class=None, test_type=TestType.EXISTING_UNIT_TEST)] + } + + # No target functions provided + filtered_map, import_results = filter_test_files_by_imports(file_to_test_map, set()) + + # Should return original map 
unchanged + assert filtered_map == file_to_test_map + assert import_results == {} + + +def test_discover_unit_tests_with_import_filtering(): + """Test the full discovery process with import filtering.""" + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdir = Path(tmpdirname) + + # Create a code file + code_file = tmpdir / "mycode.py" + code_file.write_text(""" +def target_function(): + return True + +def other_function(): + return False +""") + + # Create relevant test file + relevant_test = tmpdir / "test_relevant.py" + relevant_test.write_text(""" +from mycode import target_function + +def test_target(): + assert target_function() is True +""") + + # Create irrelevant test file + irrelevant_test = tmpdir / "test_irrelevant.py" + irrelevant_test.write_text(""" +from mycode import other_function + +def test_other(): + assert other_function() is False +""") + + # Configure test discovery + test_config = TestConfig( + tests_root=tmpdir, + project_root_path=tmpdir, + test_framework="pytest", + tests_project_rootdir=tmpdir.parent, + ) + + all_tests, _ = discover_unit_tests(test_config) + assert len(all_tests) == 2 + + + fto = FunctionToOptimize( + function_name="target_function", + file_path=code_file, + parents=[], + ) + + filtered_tests, _ = discover_unit_tests(test_config, file_to_funcs_to_optimize={code_file: [fto]}) + assert len(filtered_tests) >= 1 + assert "mycode.target_function" in filtered_tests + + +def test_analyze_imports_conditional_import(): + """Test detection of conditional imports within functions.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +def test_conditional(): + if some_condition: + from mymodule import target_function + assert target_function() is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process 
is True + assert "target_function" in found_functions + + +def test_analyze_imports_function_name_in_code(): + """Test detection of function names used directly in code.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +import mymodule + +def test_indirect(): + func_name = "target_function" + func = getattr(mymodule, func_name) + # The analyzer should detect target_function usage + result = target_function() + assert result is True +""" + test_file.write_text(test_content) + + target_functions = {"target_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + assert "target_function" in found_functions + + +def test_analyze_imports_aliased_imports(): + """Test handling of aliased imports.""" + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from mymodule import target_function as tf, other_function as of + +def test_aliased(): + assert tf() is True + assert of() is False +""" + test_file.write_text(test_content) + + target_functions = {"target_function", "missing_function"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is True + assert "target_function" in found_functions + assert "missing_function" not in found_functions + + +def test_analyze_imports_underscore_function_names(): + with tempfile.TemporaryDirectory() as tmpdirname: + test_file = Path(tmpdirname) / "test_example.py" + test_content = """ +from bubble_module import sort_function + +def test_bubble(): + assert sort_function([3,1,2]) == [1,2,3] +""" + test_file.write_text(test_content) + + target_functions = {"bubble_sort"} + should_process, found_functions = analyze_imports_in_test_file(test_file, target_functions) + + assert should_process is False + assert "bubble_sort" not in found_functions + 
def test_discover_unit_tests_filtering_different_modules():
    """Check that import filtering separates tests belonging to unrelated modules.

    Two source modules are written to a temporary project, each with its own
    test file. Discovery without a filter must report both functions, while
    discovery restricted to ``target_function`` must report that function only.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        tmpdir = Path(tmpdirname)
        target_file = tmpdir / "target_module.py"

        # Project layout, written in order: the two source modules first,
        # then one test module importing from each of them.
        project_files = {
            target_file: """
def target_function():
    return True
""",
            tmpdir / "unrelated_module.py": """
def unrelated_function():
    return False
""",
            tmpdir / "test_target.py": """
from target_module import target_function

def test_target():
    assert target_function() is True
""",
            tmpdir / "test_unrelated.py": """
from unrelated_module import unrelated_function

def test_unrelated():
    assert unrelated_function() is False
""",
        }
        for path, content in project_files.items():
            path.write_text(content)

        test_config = TestConfig(
            tests_root=tmpdir,
            project_root_path=tmpdir,
            test_framework="pytest",
            tests_project_rootdir=tmpdir.parent,
        )

        # Baseline: with no filter, both functions are mapped to a test.
        unfiltered_tests, _ = discover_unit_tests(test_config)
        assert len(unfiltered_tests) == 2

        # Restrict discovery to the single function under optimization.
        optimization_target = FunctionToOptimize(
            function_name="target_function",
            file_path=target_file,
            parents=[],
        )
        filtered_tests, _ = discover_unit_tests(
            test_config,
            file_to_funcs_to_optimize={target_file: [optimization_target]},
        )

        assert len(filtered_tests) == 1
        assert "target_module.target_function" in filtered_tests
        assert "unrelated_module.unrelated_function" not in filtered_tests