From 15d2027bb08a1ce66a77f5ecaf8c2d8d9acb590a Mon Sep 17 00:00:00 2001 From: ali Date: Fri, 21 Nov 2025 20:10:26 +0200 Subject: [PATCH 1/4] keep the refrenced global definitions --- codeflash/context/code_context_extractor.py | 38 +++- .../context/unused_definition_remover.py | 60 ++++-- tests/test_code_context_extractor.py | 174 +++++++++++++++--- 3 files changed, 218 insertions(+), 54 deletions(-) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 54fda3e16..14d549633 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -12,7 +12,11 @@ from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages -from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names +from codeflash.context.unused_definition_remover import ( + collect_top_level_defs_with_usages, + extract_names_from_targets, + remove_unused_definitions_by_function_names, +) from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001 from codeflash.models.models import ( CodeContextType, @@ -29,6 +33,8 @@ from jedi.api.classes import Name from libcst import CSTNode + from codeflash.context.unused_definition_remover import UsageInfo + def get_code_optimization_context( function_to_optimize: FunctionToOptimize, @@ -498,8 +504,10 @@ def parse_code_and_prune_cst( ) -> str: """Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables.""" module = cst.parse_module(code) + defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions) + if code_context_type == CodeContextType.READ_WRITABLE: - 
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions) + filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions, defs_with_usages) elif code_context_type == CodeContextType.READ_ONLY: filtered_node, found_target = prune_cst_for_read_only_code( module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings @@ -524,7 +532,7 @@ def parse_code_and_prune_cst( def prune_cst_for_read_writable_code( # noqa: PLR0911 - node: cst.CSTNode, target_functions: set[str], prefix: str = "" + node: cst.CSTNode, target_functions: set[str], defs_with_usages: dict[str, UsageInfo], prefix: str = "" ) -> tuple[cst.CSTNode | None, bool]: """Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions. @@ -569,6 +577,21 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911 return node.with_changes(body=cst.IndentedBlock(body=new_body)), found_target + if isinstance(node, cst.Assign): + for target in node.targets: + names = extract_names_from_targets(target.target) + for name in names: + if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function: + return node, True + return None, False + + if isinstance(node, (cst.AnnAssign, cst.AugAssign)): + names = extract_names_from_targets(node.target) + for name in names: + if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function: + return node, True + return None, False + # For other nodes, we preserve them only if they contain target functions in their children. 
section_names = get_section_names(node) if not section_names: @@ -583,7 +606,9 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911 new_children = [] section_found_target = False for child in original_content: - filtered, found_target = prune_cst_for_read_writable_code(child, target_functions, prefix) + filtered, found_target = prune_cst_for_read_writable_code( + child, target_functions, defs_with_usages, prefix + ) if filtered: new_children.append(filtered) section_found_target |= found_target @@ -592,7 +617,9 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911 found_any_target = True updates[section] = new_children elif original_content is not None: - filtered, found_target = prune_cst_for_read_writable_code(original_content, target_functions, prefix) + filtered, found_target = prune_cst_for_read_writable_code( + original_content, target_functions, defs_with_usages, prefix + ) if found_target: found_any_target = True if filtered: @@ -600,7 +627,6 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911 if not found_any_target: return None, False - return (node.with_changes(**updates) if updates else node), True diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 2a288d861..30c5d0125 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union import libcst as cst @@ -122,6 +122,8 @@ def get_section_names(node: cst.CSTNode) -> list[str]: class DependencyCollector(cst.CSTVisitor): """Collects dependencies between definitions using the visitor pattern with depth tracking.""" + METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,) + def __init__(self, definitions: dict[str, UsageInfo]) -> None: super().__init__() self.definitions = 
definitions
@@ -259,8 +261,12 @@ def visit_Name(self, node: cst.Name) -> None:
         if self.processing_variable and name in self.current_variable_names:
             return
-        # Check if name is a top-level definition we're tracking
         if name in self.definitions and name != self.current_top_level_name:
+            # skip if we are referencing a class attribute and not a top-level definition
+            if self.class_depth > 0:
+                parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
+                if parent is not None and isinstance(parent, cst.Attribute):
+                    return
             self.definitions[self.current_top_level_name].dependencies.add(name)
@@ -293,13 +299,19 @@ def _expand_qualified_functions(self) -> set[str]:
     def mark_used_definitions(self) -> None:
         """Find all qualified functions and mark them and their dependencies as used."""
-        # First identify all specified functions (including expanded ones)
-        functions_to_mark = [name for name in self.expanded_qualified_functions if name in self.definitions]
+        # Avoid list comprehension for set intersection
+        expanded_names = self.expanded_qualified_functions
+        defs = self.definitions
+        functions_to_mark = (
+            expanded_names & defs.keys()
+            if isinstance(expanded_names, set)
+            else [name for name in expanded_names if name in defs]
+        )
         # For each specified function, mark it and all its dependencies as used
         for func_name in functions_to_mark:
-            self.definitions[func_name].used_by_qualified_function = True
-            for dep in self.definitions[func_name].dependencies:
+            defs[func_name].used_by_qualified_function = True
+            for dep in defs[func_name].dependencies:
                 self.mark_as_used_recursively(dep)
     def mark_as_used_recursively(self, name: str) -> None:
@@ -457,7 +469,28 @@ def remove_unused_definitions_recursively(  # noqa: PLR0911
     return node, False
-def remove_unused_definitions_by_function_names(code: str, qualified_function_names: set[str]) -> str:
+def collect_top_level_defs_with_usages(
+    code: Union[str, cst.Module], qualified_function_names: set[str]
+) -> dict[str, UsageInfo]:
+    
"""Collect all top level definitions (classes, variables or functions) and their usages.""" + module = code if isinstance(code, cst.Module) else cst.parse_module(code) + # Collect all definitions (top level classes, variables or function) + definitions = collect_top_level_definitions(module) + + # Collect dependencies between definitions using the visitor pattern + wrapper = cst.MetadataWrapper(module) + dependency_collector = DependencyCollector(definitions) + wrapper.visit(dependency_collector) + + # Mark definitions used by specified functions, and their dependencies recursively + usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names) + usage_marker.mark_used_definitions() + return definitions + + +def remove_unused_definitions_by_function_names( + code: str, qualified_function_names: set[str] +) -> tuple[str, dict[str, UsageInfo]]: """Analyze a file and remove top level definitions not used by specified functions. Top level definitions, in this context, are only classes, variables or functions. 
@@ -476,19 +509,10 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na return code try: - # Collect all definitions (top level classes, variables or function) - definitions = collect_top_level_definitions(module) - - # Collect dependencies between definitions using the visitor pattern - dependency_collector = DependencyCollector(definitions) - module.visit(dependency_collector) - - # Mark definitions used by specified functions, and their dependencies recursively - usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names) - usage_marker.mark_used_definitions() + defs_with_usages = collect_top_level_defs_with_usages(module, qualified_function_names) # Apply the recursive removal transformation - modified_module, _ = remove_unused_definitions_recursively(module, definitions) + modified_module, _ = remove_unused_definitions_recursively(module, defs_with_usages) return modified_module.code if modified_module else "" # noqa: TRY300 except Exception as e: diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 5b71e2736..4f4761e58 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -459,6 +459,9 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: hashing_context = code_ctx.hashing_code_context expected_read_write_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} +_P = ParamSpec("_P") +_KEY_T = TypeVar("_KEY_T") +_STORE_T = TypeVar("_STORE_T") class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): def __init__(self) -> None: ... @@ -517,6 +520,10 @@ def get_cache_or_call( # If encoding fails, we should still return the result. 
return result +_P = ParamSpec("_P") +_R = TypeVar("_R") +_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend) + class _PersistentCache(Generic[_P, _R, _CacheBackendT]): @@ -752,7 +759,7 @@ def test_example_class_token_limit_1(tmp_path: Path) -> None: ) code = f""" class MyClass: - \"\"\"A class with a helper method. + \"\"\"A class with a helper method. {docstring_filler}\"\"\" def __init__(self): self.x = 1 @@ -910,7 +917,17 @@ def helper_method(self): return self.x ``` """ - expected_read_only_context = "" + expected_read_only_context = f'''```python:{file_path.relative_to(opt.args.project_root)} +class MyClass: + """A class with a helper method. """ + +class HelperClass: + """A helper class for MyClass.""" + def __repr__(self): + """Return a string representation of the HelperClass.""" + return "HelperClass" + str(self.x) +``` +''' expected_hashing_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} class MyClass: @@ -987,7 +1004,8 @@ def test_example_class_token_limit_4(tmp_path: Path) -> None: class MyClass: \"\"\"A class with a helper method. \"\"\" def __init__(self): - self.x = 1 + global x + x = 1 def target_method(self): \"\"\"Docstring for target method\"\"\" y = HelperClass().helper_method() @@ -1026,11 +1044,98 @@ def helper_method(self): ending_line=None, ) - # In this scenario, the testgen code context is too long, so we abort. - with pytest.raises(ValueError, match="Testgen code context has exceeded token limit, cannot proceed"): + # In this scenario, the read-writable code context is too long because the __init_ function is reftencing the global x variable not the class attribute (x), so we abort. 
+ with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) +def test_example_class_token_limit_5(tmp_path: Path) -> None: + string_filler = " ".join( + ["This is a long string that will be used to fill up the token limit." for _ in range(1000)] + ) + code = f""" +class MyClass: + \"\"\"A class with a helper method. \"\"\" + def __init__(self): + self.x = 1 + def target_method(self): + \"\"\"Docstring for target method\"\"\" + y = HelperClass().helper_method() +x = '{string_filler}' + +class HelperClass: + \"\"\"A helper class for MyClass.\"\"\" + def __init__(self): + \"\"\"Initialize the HelperClass.\"\"\" + self.x = 1 + def __repr__(self): + \"\"\"Return a string representation of the HelperClass.\"\"\" + return "HelperClass" + str(self.x) + def helper_method(self): + return self.x +""" + # Create a temporary Python file using pytest's tmp_path fixture + file_path = tmp_path / "test_code.py" + file_path.write_text(code, encoding="utf-8") + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + function_to_optimize = FunctionToOptimize( + function_name="target_method", + file_path=file_path, + parents=[FunctionParent(name="MyClass", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + + # the global x variable shouldn't be included in any context type + assert code_ctx.read_writable_code.flat == '''# file: test_code.py +class MyClass: + def __init__(self): + self.x = 1 + def target_method(self): + """Docstring for target method""" + y = HelperClass().helper_method() + +class HelperClass: + def __init__(self): + """Initialize the HelperClass.""" + 
self.x = 1 + def helper_method(self): + return self.x +''' + assert code_ctx.testgen_context.flat == '''# file: test_code.py +class MyClass: + """A class with a helper method. """ + def __init__(self): + self.x = 1 + def target_method(self): + """Docstring for target method""" + y = HelperClass().helper_method() + +class HelperClass: + """A helper class for MyClass.""" + def __init__(self): + """Initialize the HelperClass.""" + self.x = 1 + def __repr__(self): + """Return a string representation of the HelperClass.""" + return "HelperClass" + str(self.x) + def helper_method(self): + return self.x +''' + + def test_repo_helper() -> None: project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" path_to_file = project_root / "main.py" @@ -2070,8 +2175,17 @@ def get_system_details(): relative_path = file_path.relative_to(project_root) expected_read_write_context = f""" ```python:utility_module.py -# Function that will be used in the main code +DEFAULT_PRECISION = "medium" + +# Try-except block with variable definitions +try: + # Used variable in try block + CALCULATION_BACKEND = "numpy" +except ImportError: + # Used variable in except block + CALCULATION_BACKEND = "python" +# Function that will be used in the main code def select_precision(precision, fallback_precision): if precision is None: return fallback_precision or DEFAULT_PRECISION @@ -2466,12 +2580,12 @@ def test_circular_deps(): project_root_path= Path(path_to_root), ) assert "import ApiClient" not in new_code, "Error: Circular dependency found" - - assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist" + + assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist" def test_global_assignment_collector_with_async_function(): """Test GlobalAssignmentCollector correctly identifies global assignments outside async functions.""" import libcst as cst - + source_code = """ 
# Global assignment GLOBAL_VAR = "global_value" @@ -2486,21 +2600,21 @@ async def async_function(): # Another global assignment ANOTHER_GLOBAL = "another_global" """ - + tree = cst.parse_module(source_code) collector = GlobalAssignmentCollector() tree.visit(collector) - + # Should collect global assignments but not the ones inside async function assert len(collector.assignments) == 3 assert "GLOBAL_VAR" in collector.assignments assert "OTHER_GLOBAL" in collector.assignments assert "ANOTHER_GLOBAL" in collector.assignments - + # Should not collect assignments from inside async function assert "local_var" not in collector.assignments assert "INNER_ASSIGNMENT" not in collector.assignments - + # Verify assignment order expected_order = ["GLOBAL_VAR", "OTHER_GLOBAL", "ANOTHER_GLOBAL"] assert collector.assignment_order == expected_order @@ -2509,7 +2623,7 @@ async def async_function(): def test_global_assignment_collector_nested_async_functions(): """Test GlobalAssignmentCollector handles nested async functions correctly.""" import libcst as cst - + source_code = """ # Global assignment CONFIG = {"key": "value"} @@ -2517,38 +2631,38 @@ def test_global_assignment_collector_nested_async_functions(): def sync_function(): # Inside sync function - should not be collected sync_local = "sync" - + async def nested_async(): # Inside nested async function - should not be collected nested_var = "nested" return nested_var - + return sync_local async def async_function(): # Inside async function - should not be collected async_local = "async" - + def nested_sync(): # Inside nested function - should not be collected deeply_nested = "deep" return deeply_nested - + return async_local # Another global assignment FINAL_GLOBAL = "final" """ - + tree = cst.parse_module(source_code) collector = GlobalAssignmentCollector() tree.visit(collector) - + # Should only collect global-level assignments assert len(collector.assignments) == 2 assert "CONFIG" in collector.assignments assert 
"FINAL_GLOBAL" in collector.assignments - + # Should not collect any assignments from inside functions assert "sync_local" not in collector.assignments assert "nested_var" not in collector.assignments @@ -2559,20 +2673,20 @@ def nested_sync(): def test_global_assignment_collector_mixed_async_sync_with_classes(): """Test GlobalAssignmentCollector with async functions, sync functions, and classes.""" import libcst as cst - + source_code = """ # Global assignments GLOBAL_CONSTANT = "constant" class TestClass: - # Class-level assignment - should not be collected + # Class-level assignment - should not be collected class_var = "class_value" - + def sync_method(self): # Method assignment - should not be collected method_var = "method" return method_var - + async def async_method(self): # Async method assignment - should not be collected async_method_var = "async_method" @@ -2592,24 +2706,24 @@ async def async_function(): ANOTHER_CONSTANT = 100 FINAL_ASSIGNMENT = {"data": "value"} """ - + tree = cst.parse_module(source_code) collector = GlobalAssignmentCollector() tree.visit(collector) - + # Should only collect global-level assignments assert len(collector.assignments) == 3 - assert "GLOBAL_CONSTANT" in collector.assignments + assert "GLOBAL_CONSTANT" in collector.assignments assert "ANOTHER_CONSTANT" in collector.assignments assert "FINAL_ASSIGNMENT" in collector.assignments - + # Should not collect assignments from inside any scoped blocks assert "class_var" not in collector.assignments assert "method_var" not in collector.assignments assert "async_method_var" not in collector.assignments assert "func_var" not in collector.assignments assert "async_func_var" not in collector.assignments - + # Verify correct order expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"] assert collector.assignment_order == expected_order From e210a31060684dff5564e51edb72e7e26c19d731 Mon Sep 17 00:00:00 2001 From: ali Date: Fri, 21 Nov 2025 20:23:36 +0200 Subject: 
[PATCH 2/4] fix typing issue --- codeflash/context/unused_definition_remover.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 30c5d0125..2ff484bf8 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -488,9 +488,7 @@ def collect_top_level_defs_with_usages( return definitions -def remove_unused_definitions_by_function_names( - code: str, qualified_function_names: set[str] -) -> tuple[str, dict[str, UsageInfo]]: +def remove_unused_definitions_by_function_names(code: str, qualified_function_names: set[str]) -> str: """Analyze a file and remove top level definitions not used by specified functions. Top level definitions, in this context, are only classes, variables or functions. From b464c4d869e4b0306b41147daadf8fdf51cefc53 Mon Sep 17 00:00:00 2001 From: ali Date: Fri, 21 Nov 2025 20:53:35 +0200 Subject: [PATCH 3/4] typo --- tests/test_code_context_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 4f4761e58..aa4e2880f 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -1044,7 +1044,7 @@ def helper_method(self): ending_line=None, ) - # In this scenario, the read-writable code context is too long because the __init_ function is reftencing the global x variable not the class attribute (x), so we abort. + # In this scenario, the read-writable code context becomes too large because the __init__ function is referencing the global x variable instead of the class attribute self.x, so we abort. 
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) From 73e43e264ca822a5e428fae5e1ccd693cbfcd435 Mon Sep 17 00:00:00 2001 From: ali Date: Fri, 21 Nov 2025 20:54:38 +0200 Subject: [PATCH 4/4] codeflash optimization --- .../context/unused_definition_remover.py | 64 +++++++++++++------ 1 file changed, 43 insertions(+), 21 deletions(-) diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 2ff484bf8..8e6ea057c 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -52,11 +52,21 @@ def collect_top_level_definitions( node: cst.CSTNode, definitions: Optional[dict[str, UsageInfo]] = None ) -> dict[str, UsageInfo]: """Recursively collect all top-level variable, function, and class definitions.""" + # Locally bind types and helpers for faster lookup + FunctionDef = cst.FunctionDef # noqa: N806 + ClassDef = cst.ClassDef # noqa: N806 + Assign = cst.Assign # noqa: N806 + AnnAssign = cst.AnnAssign # noqa: N806 + AugAssign = cst.AugAssign # noqa: N806 + IndentedBlock = cst.IndentedBlock # noqa: N806 + if definitions is None: definitions = {} - # Handle top-level function definitions - if isinstance(node, cst.FunctionDef): + # Speed: Single isinstance+local var instead of several type calls + node_type = type(node) + # Fast path: function def + if node_type is FunctionDef: name = node.name.value definitions[name] = UsageInfo( name=name, @@ -64,34 +74,42 @@ def collect_top_level_definitions( ) return definitions - # Handle top-level class definitions - if isinstance(node, cst.ClassDef): + # Fast path: class def + if node_type is ClassDef: name = node.name.value definitions[name] = UsageInfo(name=name) - # Also collect method definitions within the class - if hasattr(node, "body") and isinstance(node.body, cst.IndentedBlock): - 
for statement in node.body.body: - if isinstance(statement, cst.FunctionDef): - method_name = f"{name}.{statement.name.value}" + # Collect class methods + body = getattr(node, "body", None) + if body is not None and type(body) is IndentedBlock: + statements = body.body + # Precompute f-string template for efficiency + prefix = name + "." + for statement in statements: + if type(statement) is FunctionDef: + method_name = prefix + statement.name.value definitions[method_name] = UsageInfo(name=method_name) return definitions - # Handle top-level variable assignments - if isinstance(node, cst.Assign): - for target in node.targets: + # Fast path: assignment + if node_type is Assign: + # Inline extract_names_from_targets for single-target speed + targets = node.targets + append_def = definitions.__setitem__ + for target in targets: names = extract_names_from_targets(target.target) for name in names: - definitions[name] = UsageInfo(name=name) + append_def(name, UsageInfo(name=name)) return definitions - if isinstance(node, (cst.AnnAssign, cst.AugAssign)): - if isinstance(node.target, cst.Name): - name = node.target.value + if node_type is AnnAssign or node_type is AugAssign: + tgt = node.target + if type(tgt) is cst.Name: + name = tgt.value definitions[name] = UsageInfo(name=name) else: - names = extract_names_from_targets(node.target) + names = extract_names_from_targets(tgt) for name in names: definitions[name] = UsageInfo(name=name) return definitions @@ -100,12 +118,15 @@ def collect_top_level_definitions( section_names = get_section_names(node) if section_names: + getattr_ = getattr for section in section_names: - original_content = getattr(node, section, None) + original_content = getattr_(node, section, None) + # Instead of isinstance check for list/tuple, rely on duck-type via iter # If section contains a list of nodes if isinstance(original_content, (list, tuple)): + defs = definitions # Move out for minor speed for child in original_content: - 
collect_top_level_definitions(child, definitions) + collect_top_level_definitions(child, defs) # If section contains a single node elif original_content is not None: collect_top_level_definitions(original_content, definitions) @@ -302,14 +323,15 @@ def mark_used_definitions(self) -> None: # Avoid list comprehension for set intersection expanded_names = self.expanded_qualified_functions defs = self.definitions - functions_to_mark = ( + # Use set intersection but only if defs.keys is a set (Python 3.12 dict_keys supports it efficiently) + fnames = ( expanded_names & defs.keys() if isinstance(expanded_names, set) else [name for name in expanded_names if name in defs] ) # For each specified function, mark it and all its dependencies as used - for func_name in functions_to_mark: + for func_name in fnames: defs[func_name].used_by_qualified_function = True for dep in defs[func_name].dependencies: self.mark_as_used_recursively(dep)