diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index bfcbbaead..6ba299a01 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import defaultdict from dataclasses import dataclass, field import libcst as cst @@ -255,6 +256,7 @@ class QualifiedFunctionUsageMarker: def __init__(self, definitions: dict[str, UsageInfo], qualified_function_names: set[str]) -> None: self.definitions = definitions self.qualified_function_names = qualified_function_names + self.class_dunder_methods = self._preprocess_definitions() self.expanded_qualified_functions = self._expand_qualified_functions() def _expand_qualified_functions(self) -> set[str]: @@ -262,7 +264,7 @@ def _expand_qualified_functions(self) -> set[str]: expanded = set(self.qualified_function_names) # Find class methods and add their containing classes and dunder methods - for qualified_name in list(self.qualified_function_names): + for qualified_name in self.qualified_function_names: if "." in qualified_name: class_name, method_name = qualified_name.split(".", 1) @@ -270,9 +272,8 @@ def _expand_qualified_functions(self) -> set[str]: expanded.add(class_name) # Add all dunder methods of the class - for name in self.definitions: - if name.startswith(f"{class_name}.__") and name.endswith("__"): - expanded.add(name) + if class_name in self.class_dunder_methods: + expanded.update(self.class_dunder_methods[class_name]) return expanded @@ -301,9 +302,21 @@ def mark_as_used_recursively(self, name: str) -> None: for dep in self.definitions[name].dependencies: self.mark_as_used_recursively(dep) + def _preprocess_definitions(self) -> dict[str, set[str]]: + """Preprocess definitions to find dunder methods for each class.""" + class_dunder_methods = defaultdict(set) + + for name in self.definitions: + if name.count(".") == 1: + class_name, method_name = name.split(".", 1) + if method_name.startswith("__") and method_name.endswith("__"): + class_dunder_methods[class_name].add(name) + + return class_dunder_methods + def remove_unused_definitions_recursively( - node: cst.CSTNode, definitions: dict[str, UsageInfo] + node: cst.CSTNode, definitions: dict[str, UsageInfo] ) -> tuple[cst.CSTNode | None, bool]: """Recursively filter the node to remove unused definitions. @@ -358,7 +371,10 @@ def remove_unused_definitions_recursively( names = extract_names_from_targets(target.target) for name in names: class_var_name = f"{class_name}.{name}" - if class_var_name in definitions and definitions[class_var_name].used_by_qualified_function: + if ( + class_var_name in definitions + and definitions[class_var_name].used_by_qualified_function + ): var_used = True method_or_var_used = True break