diff --git a/code_to_optimize/code_directories/retriever/import_test.py b/code_to_optimize/code_directories/retriever/import_test.py new file mode 100644 index 000000000..7f12f0a89 --- /dev/null +++ b/code_to_optimize/code_directories/retriever/import_test.py @@ -0,0 +1,5 @@ + +import code_to_optimize.code_directories.retriever.main + +def function_to_optimize(): + return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data() diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 516f3c94e..792a76885 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -14,6 +14,7 @@ from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages +from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import ( CodeContextType, @@ -182,14 +183,16 @@ def extract_code_string_context_from_files( helpers_of_helpers_qualified_names = { func.qualified_name for func in helpers_of_helpers.get(file_path, set()) } + code_without_unused_defs = remove_unused_definitions_by_function_names( + original_code, qualified_function_names | helpers_of_helpers_qualified_names + ) code_context = parse_code_and_prune_cst( - original_code, + code_without_unused_defs, code_context_type, qualified_function_names, helpers_of_helpers_qualified_names, remove_docstrings, ) - except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") continue @@ -214,8 +217,9 @@ def extract_code_string_context_from_files( continue try: qualified_helper_function_names = {func.qualified_name for func in helper_function_sources} + code_without_unused_defs = remove_unused_definitions_by_function_names(original_code, qualified_helper_function_names) code_context = parse_code_and_prune_cst( - original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings + code_without_unused_defs, code_context_type, set(), qualified_helper_function_names, remove_docstrings ) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") @@ -283,8 +287,11 @@ def extract_code_markdown_context_from_files( helpers_of_helpers_qualified_names = { func.qualified_name for func in helpers_of_helpers.get(file_path, set()) } + code_without_unused_defs = remove_unused_definitions_by_function_names( + original_code, qualified_function_names | helpers_of_helpers_qualified_names + ) code_context = parse_code_and_prune_cst( - original_code, + code_without_unused_defs, code_context_type, qualified_function_names, helpers_of_helpers_qualified_names, @@ -318,8 +325,9 @@ def extract_code_markdown_context_from_files( continue try: qualified_helper_function_names = {func.qualified_name for func in helper_function_sources} + code_without_unused_defs = remove_unused_definitions_by_function_names(original_code, qualified_helper_function_names) code_context = parse_code_and_prune_cst( - original_code, code_context_type, set(), qualified_helper_function_names, remove_docstrings + code_without_unused_defs, code_context_type, set(), qualified_helper_function_names, remove_docstrings ) except ValueError as e: logger.debug(f"Error while getting read-only code: {e}") diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py new file mode 100644 index 000000000..bfcbbaead --- /dev/null +++ b/codeflash/context/unused_definition_remover.py @@ -0,0 +1,476 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + +import libcst as cst + + +@dataclass +class UsageInfo: + """Information about a name and its usage.""" + + name: str + used_by_qualified_function: bool = False + dependencies: set[str] = field(default_factory=set) + + +def extract_names_from_targets(target: cst.CSTNode) -> list[str]: + """Extract all variable names from a target node, including from tuple unpacking.""" + names = [] + + # Handle a simple name + if isinstance(target, cst.Name): + names.append(target.value) + + # Handle any node with a value attribute (StarredElement, etc.) + elif hasattr(target, "value"): + names.extend(extract_names_from_targets(target.value)) + + # Handle any node with elements attribute (tuples, lists, etc.) + elif hasattr(target, "elements"): + for element in target.elements: + # Recursive call for each element + names.extend(extract_names_from_targets(element)) + + return names + + +def collect_top_level_definitions(node: cst.CSTNode, definitions: dict[str, UsageInfo] = None) -> dict[str, UsageInfo]: + """Recursively collect all top-level variable, function, and class definitions.""" + if definitions is None: + definitions = {} + + # Handle top-level function definitions + if isinstance(node, cst.FunctionDef): + name = node.name.value + definitions[name] = UsageInfo( + name=name, + used_by_qualified_function=False, # Will be marked later if in qualified functions + ) + return definitions + + # Handle top-level class definitions + if isinstance(node, cst.ClassDef): + name = node.name.value + definitions[name] = UsageInfo(name=name) + + # Also collect method definitions within the class + if hasattr(node, "body") and isinstance(node.body, cst.IndentedBlock): + for statement in node.body.body: + if isinstance(statement, cst.FunctionDef): + method_name = f"{name}.{statement.name.value}" + definitions[method_name] = UsageInfo(name=method_name) + + return definitions + + # Handle top-level variable assignments + if isinstance(node, cst.Assign): + for target in node.targets: + names = extract_names_from_targets(target.target) + for name in names: + definitions[name] = UsageInfo(name=name) + return definitions + + if isinstance(node, (cst.AnnAssign, cst.AugAssign)): + if isinstance(node.target, cst.Name): + name = node.target.value + definitions[name] = UsageInfo(name=name) + else: + names = extract_names_from_targets(node.target) + for name in names: + definitions[name] = UsageInfo(name=name) + return definitions + + # Recursively process children. Takes care of top level assignments in if/else/while/for blocks + section_names = get_section_names(node) + + if section_names: + for section in section_names: + original_content = getattr(node, section, None) + # If section contains a list of nodes + if isinstance(original_content, (list, tuple)): + for child in original_content: + collect_top_level_definitions(child, definitions) + # If section contains a single node + elif original_content is not None: + collect_top_level_definitions(original_content, definitions) + + return definitions + + +def get_section_names(node: cst.CSTNode) -> list[str]: + """Return the section attribute names (e.g., body, orelse) for a given node if they exist.""" + possible_sections = ["body", "orelse", "finalbody", "handlers"] + return [sec for sec in possible_sections if hasattr(node, sec)] + + +class DependencyCollector(cst.CSTVisitor): + """Collects dependencies between definitions using the visitor pattern with depth tracking.""" + + def __init__(self, definitions: dict[str, UsageInfo]) -> None: + super().__init__() + self.definitions = definitions + # Track function and class depths + self.function_depth = 0 + self.class_depth = 0 + # Track top-level qualified names + self.current_top_level_name = "" + self.current_class = "" + # Track if we're processing a top-level variable + self.processing_variable = False + self.current_variable_names = set() + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + function_name = node.name.value + + if self.function_depth == 0: + # This is a top-level function + if self.class_depth > 0: + # If inside a class, we're now tracking dependencies at the class level + self.current_top_level_name = f"{self.current_class}.{function_name}" + else: + # Regular top-level function + self.current_top_level_name = function_name + + # Check parameter type annotations for dependencies + if hasattr(node, "params") and node.params: + for param in node.params.params: + if param.annotation: + # Visit the annotation to extract dependencies + self._collect_annotation_dependencies(param.annotation) + + self.function_depth += 1 + + def _collect_annotation_dependencies(self, annotation: cst.Annotation) -> None: + """Extract dependencies from type annotations""" + if hasattr(annotation, "annotation"): + # Extract names from annotation (could be Name, Attribute, Subscript, etc.) + self._extract_names_from_annotation(annotation.annotation) + + def _extract_names_from_annotation(self, node: cst.CSTNode) -> None: + """Extract names from a type annotation node""" + # Simple name reference like 'int', 'str', or custom type + if isinstance(node, cst.Name): + name = node.value + if name in self.definitions and name != self.current_top_level_name and self.current_top_level_name: + self.definitions[self.current_top_level_name].dependencies.add(name) + + # Handle compound annotations like List[int], Dict[str, CustomType], etc. + elif isinstance(node, cst.Subscript): + if hasattr(node, "value"): + self._extract_names_from_annotation(node.value) + if hasattr(node, "slice"): + for slice_item in node.slice: + if hasattr(slice_item, "slice"): + self._extract_names_from_annotation(slice_item.slice) + + # Handle attribute access like module.Type + elif isinstance(node, cst.Attribute): + if hasattr(node, "value"): + self._extract_names_from_annotation(node.value) + # No need to check the attribute name itself as it's likely not a top-level definition + + def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: + self.function_depth -= 1 + + if self.function_depth == 0 and self.class_depth == 0: + # Exiting top-level function that's not in a class + self.current_top_level_name = "" + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + class_name = node.name.value + + if self.class_depth == 0: + # This is a top-level class + self.current_class = class_name + self.current_top_level_name = class_name + + self.class_depth += 1 + + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: + self.class_depth -= 1 + + if self.class_depth == 0: + # Exiting top-level class + self.current_class = "" + self.current_top_level_name = "" + + def visit_Assign(self, node: cst.Assign) -> None: + # Only handle top-level assignments + if self.function_depth == 0 and self.class_depth == 0: + for target in node.targets: + # Extract all variable names from the target + names = extract_names_from_targets(target.target) + + # Check if any of these names are top-level definitions we're tracking + tracked_names = [name for name in names if name in self.definitions] + if tracked_names: + self.processing_variable = True + self.current_variable_names.update(tracked_names) + # Use the first tracked name as the current top-level name (for dependency tracking) + self.current_top_level_name = tracked_names[0] + + def leave_Assign(self, original_node: cst.Assign) -> None: + if self.processing_variable: + self.processing_variable = False + self.current_variable_names.clear() + self.current_top_level_name = "" + + def visit_AnnAssign(self, node: cst.AnnAssign) -> None: + # Extract names from the variable annotations + if hasattr(node, "annotation") and node.annotation: + # First mark we're processing a variable to avoid recording it as a dependency of itself + self.processing_variable = True + if isinstance(node.target, cst.Name): + self.current_variable_names.add(node.target.value) + else: + self.current_variable_names.update(extract_names_from_targets(node.target)) + + # Process the annotation + self._collect_annotation_dependencies(node.annotation) + + # Reset processing state + self.processing_variable = False + self.current_variable_names.clear() + + def visit_Name(self, node: cst.Name) -> None: + name = node.value + + # Skip if we're not inside a tracked definition + if not self.current_top_level_name or self.current_top_level_name not in self.definitions: + return + + # Skip if we're looking at the variable name itself in an assignment + if self.processing_variable and name in self.current_variable_names: + return + + # Check if name is a top-level definition we're tracking + if name in self.definitions and name != self.current_top_level_name: + self.definitions[self.current_top_level_name].dependencies.add(name) + + +class QualifiedFunctionUsageMarker: + """Marks definitions that are used by specific qualified functions.""" + + def __init__(self, definitions: dict[str, UsageInfo], qualified_function_names: set[str]) -> None: + self.definitions = definitions + self.qualified_function_names = qualified_function_names + self.expanded_qualified_functions = self._expand_qualified_functions() + + def _expand_qualified_functions(self) -> set[str]: + """Expand the qualified function names to include related methods.""" + expanded = set(self.qualified_function_names) + + # Find class methods and add their containing classes and dunder methods + for qualified_name in list(self.qualified_function_names): + if "." in qualified_name: + class_name, method_name = qualified_name.split(".", 1) + + # Add the class itself + expanded.add(class_name) + + # Add all dunder methods of the class + for name in self.definitions: + if name.startswith(f"{class_name}.__") and name.endswith("__"): + expanded.add(name) + + return expanded + + def mark_used_definitions(self) -> None: + """Find all qualified functions and mark them and their dependencies as used.""" + # First identify all specified functions (including expanded ones) + functions_to_mark = [name for name in self.expanded_qualified_functions if name in self.definitions] + + # For each specified function, mark it and all its dependencies as used + for func_name in functions_to_mark: + self.definitions[func_name].used_by_qualified_function = True + for dep in self.definitions[func_name].dependencies: + self.mark_as_used_recursively(dep) + + def mark_as_used_recursively(self, name: str) -> None: + """Mark a name and all its dependencies as used recursively.""" + if name not in self.definitions: + return + + if self.definitions[name].used_by_qualified_function: + return # Already marked + + self.definitions[name].used_by_qualified_function = True + + # Mark all dependencies as used + for dep in self.definitions[name].dependencies: + self.mark_as_used_recursively(dep) + + +def remove_unused_definitions_recursively( + node: cst.CSTNode, definitions: dict[str, UsageInfo] +) -> tuple[cst.CSTNode | None, bool]: + """Recursively filter the node to remove unused definitions. + + Args: + node: The CST node to process + definitions: Dictionary of definition info + + Returns: + (filtered_node, used_by_function): + filtered_node: The modified CST node or None if it should be removed + used_by_function: True if this node or any child is used by qualified functions + + """ + # Skip import statements + if isinstance(node, (cst.Import, cst.ImportFrom)): + return node, True + + # Never remove function definitions + if isinstance(node, cst.FunctionDef): + return node, True + + # Never remove class definitions + if isinstance(node, cst.ClassDef): + class_name = node.name.value + + # Check if any methods or variables in this class are used + method_or_var_used = False + class_has_dependencies = False + + # Check if class itself is marked as used + if class_name in definitions and definitions[class_name].used_by_qualified_function: + class_has_dependencies = True + + if hasattr(node, "body") and isinstance(node.body, cst.IndentedBlock): + updates = {} + new_statements = [] + + for statement in node.body.body: + # Keep all function definitions + if isinstance(statement, cst.FunctionDef): + method_name = f"{class_name}.{statement.name.value}" + if method_name in definitions and definitions[method_name].used_by_qualified_function: + method_or_var_used = True + new_statements.append(statement) + # Only process variable assignments + elif isinstance(statement, (cst.Assign, cst.AnnAssign, cst.AugAssign)): + var_used = False + + # Check if any variable in this assignment is used + if isinstance(statement, cst.Assign): + for target in statement.targets: + names = extract_names_from_targets(target.target) + for name in names: + class_var_name = f"{class_name}.{name}" + if class_var_name in definitions and definitions[class_var_name].used_by_qualified_function: + var_used = True + method_or_var_used = True + break + elif isinstance(statement, (cst.AnnAssign, cst.AugAssign)): + names = extract_names_from_targets(statement.target) + for name in names: + class_var_name = f"{class_name}.{name}" + if class_var_name in definitions and definitions[class_var_name].used_by_qualified_function: + var_used = True + method_or_var_used = True + break + + if var_used or class_has_dependencies: + new_statements.append(statement) + else: + # Keep all other statements in the class + new_statements.append(statement) + + # Update the class body + new_body = node.body.with_changes(body=new_statements) + updates["body"] = new_body + + return node.with_changes(**updates), True + + return node, method_or_var_used or class_has_dependencies + + # Handle assignments (Assign and AnnAssign) + if isinstance(node, cst.Assign): + for target in node.targets: + names = extract_names_from_targets(target.target) + for name in names: + if name in definitions and definitions[name].used_by_qualified_function: + return node, True + return None, False + + if isinstance(node, (cst.AnnAssign, cst.AugAssign)): + names = extract_names_from_targets(node.target) + for name in names: + if name in definitions and definitions[name].used_by_qualified_function: + return node, True + return None, False + + # For other nodes, recursively process children + section_names = get_section_names(node) + if not section_names: + return node, False + + updates = {} + found_used = False + + for section in section_names: + original_content = getattr(node, section, None) + if isinstance(original_content, (list, tuple)): + new_children = [] + section_found_used = False + + for child in original_content: + filtered, used = remove_unused_definitions_recursively(child, definitions) + if filtered: + new_children.append(filtered) + section_found_used |= used + + if new_children or section_found_used: + found_used |= section_found_used + updates[section] = new_children + elif original_content is not None: + filtered, used = remove_unused_definitions_recursively(original_content, definitions) + found_used |= used + if filtered: + updates[section] = filtered + if not found_used: + return None, False + if updates: + return node.with_changes(**updates), found_used + + return node, False + + +def remove_unused_definitions_by_function_names(code: str, qualified_function_names: set[str]) -> str: + """Analyze a file and remove top level definitions not used by specified functions. + + Top level definitions, in this context, are only classes, variables or functions. + If a class is referenced by a qualified function, we keep the entire class. + + Args: + code: The code to process + qualified_function_names: Set of function names to keep. For methods, use format 'classname.methodname' + + """ + module = cst.parse_module(code) + # Collect all definitions (top level classes, variables or function) + definitions = collect_top_level_definitions(module) + + # Collect dependencies between definitions using the visitor pattern + dependency_collector = DependencyCollector(definitions) + module.visit(dependency_collector) + + # Mark definitions used by specified functions, and their dependencies recursively + usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names) + usage_marker.mark_used_definitions() + + # Apply the recursive removal transformation + modified_module, _ = remove_unused_definitions_recursively(module, definitions) + + return modified_module.code if modified_module else "" + + +def print_definitions(definitions: dict[str, UsageInfo]) -> None: + """Print information about each definition without the complex node object, used for debugging.""" + print(f"Found {len(definitions)} definitions:") + for name, info in sorted(definitions.items()): + print(f" - Name: {name}") + print(f" Used by qualified function: {info.used_by_qualified_function}") + print(f" Dependencies: {', '.join(sorted(info.dependencies)) if info.dependencies else 'None'}") + print() diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 069a8eb19..90356ac10 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -929,9 +929,6 @@ def fetch_and_process_data(): """ expected_read_only_context = f""" ```python:{path_to_utils.relative_to(project_root)} -GLOBAL_VAR = 10 - - class DataProcessor: \"\"\"A class for processing data.\"\"\" @@ -941,11 +938,6 @@ def __repr__(self) -> str: \"\"\"Return a string representation of the DataProcessor.\"\"\" return f"DataProcessor(default_prefix={{self.default_prefix!r}})" ``` -```python:{path_to_file.relative_to(project_root)} -if __name__ == "__main__": - result = fetch_and_process_data() - print("Processed data:", result) -``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() @@ -1006,9 +998,6 @@ def fetch_and_transform_data(): """ expected_read_only_context = f""" ```python:{path_to_utils.relative_to(project_root)} -GLOBAL_VAR = 10 - - class DataProcessor: \"\"\"A class for processing data.\"\"\" @@ -1018,11 +1007,6 @@ def __repr__(self) -> str: \"\"\"Return a string representation of the DataProcessor.\"\"\" return f"DataProcessor(default_prefix={{self.default_prefix!r}})" ``` -```python:{path_to_file.relative_to(project_root)} -if __name__ == "__main__": - result = fetch_and_process_data() - print("Processed data:", result) -``` ```python:{path_to_transform_utils.relative_to(project_root)} class DataTransformer: @@ -1084,9 +1068,6 @@ def transform(self, data): return self.data ``` ```python:{path_to_utils.relative_to(project_root)} -GLOBAL_VAR = 10 - - class DataProcessor: \"\"\"A class for processing data.\"\"\" @@ -1147,9 +1128,6 @@ def update_data(data): return data + " updated" ``` ```python:{path_to_utils.relative_to(project_root)} -GLOBAL_VAR = 10 - - class DataProcessor: \"\"\"A class for processing data.\"\"\" @@ -1252,9 +1230,6 @@ def circular_dependency(self, data): """ expected_read_only_context = f""" ```python:{path_to_utils.relative_to(project_root)} -GLOBAL_VAR = 10 - - class DataProcessor: \"\"\"A class for processing data.\"\"\" @@ -1320,6 +1295,500 @@ def target_method(self): def outside_method(): return 1 ``` +""" + assert read_write_context.strip() == expected_read_write_context.strip() + assert read_only_context.strip() == expected_read_only_context.strip() + +def test_direct_module_import() -> None: + project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" + path_to_main = project_root / "main.py" + path_to_fto = project_root / "import_test.py" + function_to_optimize = FunctionToOptimize( + function_name="function_to_optimize", + file_path=str(path_to_fto), + parents=[], + starting_line=None, + ending_line=None, + ) + + + code_ctx = get_code_optimization_context(function_to_optimize, project_root) + read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + + expected_read_only_context = """ +```python:utils.py +from transform_utils import DataTransformer + +class DataProcessor: + \"\"\"A class for processing data.\"\"\" + + number = 1 + + def __repr__(self) -> str: + \"\"\"Return a string representation of the DataProcessor.\"\"\" + return f"DataProcessor(default_prefix={self.default_prefix!r})" + + def process_data(self, raw_data: str) -> str: + \"\"\"Process raw data by converting it to uppercase.\"\"\" + return raw_data.upper() + + def transform_data(self, data: str) -> str: + \"\"\"Transform the processed data\"\"\" + return DataTransformer().transform(data) +```""" + expected_read_write_context = """ +import requests +from globals import API_URL +from utils import DataProcessor +import code_to_optimize.code_directories.retriever.main + +def fetch_and_transform_data(): + # Use the global variable for the request + response = requests.get(API_URL) + + raw_data = response.text + + # Use code from another file (utils.py) + processor = DataProcessor() + processed = processor.process_data(raw_data) + transformed = processor.transform_data(processed) + + return transformed + + + +def function_to_optimize(): + return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data() +""" + assert read_write_context.strip() == expected_read_write_context.strip() + assert read_only_context.strip() == expected_read_only_context.strip() + +def test_module_import_optimization() -> None: + main_code = ''' +import utility_module + +class Calculator: + def __init__(self, precision="high", fallback_precision=None, mode="standard"): + # This is where we use the imported module + self.precision = utility_module.select_precision(precision, fallback_precision) + self.mode = mode + + # Using variables from the utility module + self.backend = utility_module.CALCULATION_BACKEND + self.system = utility_module.SYSTEM_TYPE + self.default_precision = utility_module.DEFAULT_PRECISION + + def add(self, a, b): + return a + b + + def subtract(self, a, b): + return a - b + + def calculate(self, operation, x, y): + if operation == "add": + return self.add(x, y) + elif operation == "subtract": + return self.subtract(x, y) + else: + return None +''' + + utility_module_code = ''' +import sys +import platform +import logging + +DEFAULT_PRECISION = "medium" +DEFAULT_MODE = "standard" + +# Try-except block with variable definitions +try: + import numpy as np + # Used variable in try block + CALCULATION_BACKEND = "numpy" + # Unused variable in try block + VECTOR_DIMENSIONS = 3 +except ImportError: + # Used variable in except block + CALCULATION_BACKEND = "python" + # Unused variable in except block + FALLBACK_WARNING = "NumPy not available, using slower Python implementation" + +# Nested if-else with variable definitions +if sys.platform.startswith('win'): + # Used variable in outer if + SYSTEM_TYPE = "windows" + if platform.architecture()[0] == '64bit': + # Unused variable in nested if + MEMORY_MODEL = "x64" + else: + # Unused variable in nested else + MEMORY_MODEL = "x86" +elif sys.platform.startswith('linux'): + # Used variable in outer elif + SYSTEM_TYPE = "linux" + # Unused variable in outer elif + KERNEL_VERSION = platform.release() +else: + # Used variable in outer else + SYSTEM_TYPE = "other" + # Unused variable in outer else + UNKNOWN_SYSTEM_MSG = "Running on an unrecognized platform" + +# Function that will be used in the main code +def select_precision(precision, fallback_precision): + if precision is None: + return fallback_precision or DEFAULT_PRECISION + + # Using the variables defined above + if CALCULATION_BACKEND == "numpy": + # Higher precision available with NumPy + precision_options = ["low", "medium", "high", "ultra"] + else: + # Limited precision without NumPy + precision_options = ["low", "medium", "high"] + + if isinstance(precision, str): + if precision.lower() not in precision_options: + if fallback_precision: + return fallback_precision + else: + return DEFAULT_PRECISION + return precision.lower() + else: + return DEFAULT_PRECISION + +# Function that won't be used +def get_system_details(): + return { + "system": SYSTEM_TYPE, + "backend": CALCULATION_BACKEND, + "default_precision": DEFAULT_PRECISION, + "python_version": sys.version + } +''' + + # Create a temporary directory for the test + with tempfile.TemporaryDirectory() as temp_dir: + # Set up the package structure + package_dir = Path(temp_dir) / "package" + package_dir.mkdir() + + # Create the __init__.py file + with open(package_dir / "__init__.py", "w") as init_file: + init_file.write("") + + # Write the utility_module.py file + with open(package_dir / "utility_module.py", "w") as utility_file: + utility_file.write(utility_module_code) + utility_file.flush() + + # Write the main code file + main_file_path = package_dir / "main_module.py" + with open(main_file_path, "w") as main_file: + main_file.write(main_code) + main_file.flush() + + # Set up the optimizer + file_path = main_file_path.resolve() + opt = Optimizer( + Namespace( + project_root=package_dir.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + + # Define the function to optimize + function_to_optimize = FunctionToOptimize( + function_name="calculate", + file_path=file_path, + parents=[FunctionParent(name="Calculator", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + + # Get the code optimization context + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + # The expected contexts + expected_read_write_context = """ +import utility_module + +class Calculator: + def __init__(self, precision="high", fallback_precision=None, mode="standard"): + # This is where we use the imported module + self.precision = utility_module.select_precision(precision, fallback_precision) + self.mode = mode + + # Using variables from the utility module + self.backend = utility_module.CALCULATION_BACKEND + self.system = utility_module.SYSTEM_TYPE + self.default_precision = utility_module.DEFAULT_PRECISION + + def add(self, a, b): + return a + b + + def subtract(self, a, b): + return a - b + + def calculate(self, operation, x, y): + if operation == "add": + return self.add(x, y) + elif operation == "subtract": + return self.subtract(x, y) + else: + return None +""" + expected_read_only_context = """ +```python:utility_module.py +DEFAULT_PRECISION = "medium" + +# Try-except block with variable definitions +try: + # Used variable in try block + CALCULATION_BACKEND = "numpy" +except ImportError: + # Used variable in except block + CALCULATION_BACKEND = "python" + +# Function that will be used in the main code +def select_precision(precision, fallback_precision): + if precision is None: + return fallback_precision or DEFAULT_PRECISION + + # Using the variables defined above + if CALCULATION_BACKEND == "numpy": + # Higher precision available with NumPy + precision_options = ["low", "medium", "high", "ultra"] + else: + # Limited precision without NumPy + precision_options = ["low", "medium", "high"] + + if isinstance(precision, str): + if precision.lower() not in precision_options: + if fallback_precision: + return fallback_precision + else: + return DEFAULT_PRECISION + return precision.lower() + else: + return DEFAULT_PRECISION +``` +""" + # Verify the contexts match the expected values + assert read_write_context.strip() == expected_read_write_context.strip() + assert read_only_context.strip() == expected_read_only_context.strip() + +def test_module_import_init_fto() -> None: + main_code = ''' +import utility_module + +class Calculator: + def __init__(self, precision="high", fallback_precision=None, mode="standard"): + # This is where we use the imported module + self.precision = utility_module.select_precision(precision, fallback_precision) + self.mode = mode + + # Using variables from the utility module + self.backend = utility_module.CALCULATION_BACKEND + self.system = utility_module.SYSTEM_TYPE + self.default_precision = utility_module.DEFAULT_PRECISION + + def add(self, a, b): + return a + b + + def subtract(self, a, b): + return a - b + + def calculate(self, operation, x, y): + if operation == "add": + return self.add(x, y) + elif operation == "subtract": + return self.subtract(x, y) + else: + return None +''' + + utility_module_code = ''' +import sys +import platform +import logging + +DEFAULT_PRECISION = "medium" +DEFAULT_MODE = "standard" + +# Try-except block with variable definitions +try: + import numpy as np + # Used variable in try block + CALCULATION_BACKEND = "numpy" + # Unused variable in try block + VECTOR_DIMENSIONS = 3 +except ImportError: + # Used variable in except block + CALCULATION_BACKEND = "python" + # Unused variable in except block + FALLBACK_WARNING = "NumPy not available, using slower Python implementation" + +# Nested if-else with variable definitions +if sys.platform.startswith('win'): + # Used variable in outer if + SYSTEM_TYPE = "windows" + if platform.architecture()[0] == '64bit': + # Unused variable in nested if + MEMORY_MODEL = "x64" + else: + # Unused variable in nested else + MEMORY_MODEL = "x86" +elif sys.platform.startswith('linux'): + # Used variable in outer elif + SYSTEM_TYPE = "linux" + # Unused variable in outer elif + KERNEL_VERSION = platform.release() +else: + # Used variable in outer else + SYSTEM_TYPE = "other" + # Unused variable in outer else + UNKNOWN_SYSTEM_MSG = "Running on an unrecognized platform" + +# Function that will be used in the main code +def select_precision(precision, fallback_precision): + if precision is None: + return fallback_precision or DEFAULT_PRECISION + + # Using the variables defined above + if CALCULATION_BACKEND == "numpy": + # Higher precision available with NumPy + precision_options = ["low", "medium", "high", "ultra"] + else: + # Limited precision without NumPy + precision_options = ["low", "medium", "high"] + + if isinstance(precision, str): + if precision.lower() not in precision_options: + if fallback_precision: + return fallback_precision + else: + return DEFAULT_PRECISION + return precision.lower() + else: + return DEFAULT_PRECISION + +# Function that won't be used +def get_system_details(): + return { + "system": SYSTEM_TYPE, + "backend": CALCULATION_BACKEND, + "default_precision": DEFAULT_PRECISION, + "python_version": sys.version + } +''' + + # Create a temporary directory for the test + with tempfile.TemporaryDirectory() as temp_dir: + # Set up the package structure + package_dir = Path(temp_dir) / "package" + package_dir.mkdir() + + # Create the __init__.py file + with open(package_dir / "__init__.py", "w") as init_file: + init_file.write("") + + # Write the utility_module.py file + with open(package_dir / "utility_module.py", "w") as utility_file: + utility_file.write(utility_module_code) + utility_file.flush() + + # Write the main code file + main_file_path = package_dir / "main_module.py" + with open(main_file_path, "w") as main_file: + main_file.write(main_code) + main_file.flush() + + # Set up the optimizer + file_path = main_file_path.resolve() + opt = Optimizer( + Namespace( + project_root=package_dir.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + + # Define the function to optimize + function_to_optimize = FunctionToOptimize( + function_name="__init__", + file_path=file_path, + parents=[FunctionParent(name="Calculator", type="ClassDef")], + starting_line=None, + ending_line=None, + ) + + # Get the code optimization context + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code + # The expected contexts + expected_read_write_context = """ +# Function that will be used in the main code + +import utility_module + +def select_precision(precision, fallback_precision): + if precision is None: + return fallback_precision or DEFAULT_PRECISION + + # Using the variables defined above + if CALCULATION_BACKEND == "numpy": + # Higher precision available with NumPy + precision_options = ["low", "medium", "high", "ultra"] + else: + # Limited precision without NumPy + precision_options = ["low", "medium", "high"] + + if isinstance(precision, str): + if precision.lower() not in precision_options: + if fallback_precision: + return fallback_precision + else: + return DEFAULT_PRECISION + return precision.lower() + else: + return DEFAULT_PRECISION + + + +class Calculator: + def __init__(self, precision="high", fallback_precision=None, mode="standard"): + # This is where we use the imported module + self.precision = utility_module.select_precision(precision, fallback_precision) + self.mode = mode + + # Using variables from the utility module + self.backend = utility_module.CALCULATION_BACKEND + self.system = utility_module.SYSTEM_TYPE + self.default_precision = utility_module.DEFAULT_PRECISION +""" + expected_read_only_context = """ +```python:utility_module.py +DEFAULT_PRECISION = "medium" + +# Try-except block with variable definitions +try: + # Used variable in try block + CALCULATION_BACKEND = "numpy" +except ImportError: + # Used variable in except block + CALCULATION_BACKEND = "python" +``` """ assert read_write_context.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() \ No newline at end of file diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index ea221be78..d3c4d941a 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -748,10 +748,6 @@ def main_method(self): def test_code_replacement10() -> None: get_code_output = """from __future__ import annotations -import os - -os.environ["CODEFLASH_API_KEY"] = "cf-test-key" - class HelperClass: def __init__(self, name): diff --git a/tests/test_remove_unused_definitions.py b/tests/test_remove_unused_definitions.py new file mode 100644 index 000000000..86a57bb6d --- /dev/null +++ b/tests/test_remove_unused_definitions.py @@ -0,0 +1,424 @@ +import tempfile +from argparse import Namespace +from pathlib import Path + +import libcst as cst + +from codeflash.context.code_context_extractor import get_code_optimization_context +from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.models.models import FunctionParent +from codeflash.optimization.optimizer import Optimizer + + +def test_variable_removal_only() -> None: + """Test that only variables not used by specified functions are removed, not functions.""" + code = """ +def main_function(): + return USED_CONSTANT + 10 + +def helper_function(): + return 42 + +USED_CONSTANT = 42 +UNUSED_CONSTANT = 123 + +def another_function(): + return UNUSED_CONSTANT +""" + + expected = """ +def main_function(): + return USED_CONSTANT + 10 + +def helper_function(): + return 42 + +USED_CONSTANT = 42 + +def another_function(): + return UNUSED_CONSTANT +""" + + qualified_functions = {"main_function"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + # Normalize whitespace for comparison + assert result.strip() == expected.strip() + + +def test_class_variable_removal() -> None: + """Test that only class variables not used by specified functions are removed, not methods.""" + code = """ +class MyClass: + CLASS_USED = "used value" + CLASS_UNUSED = "unused value" + + def __init__(self): + self.value = self.CLASS_USED + self.other = self.CLASS_UNUSED + + def used_method(self): + return self.value + + def unused_method(self): + return "Not used but not removed" + +GLOBAL_USED = "global used" +GLOBAL_UNUSED = "global unused" + +def helper_function(): + return MyClass().used_method() + GLOBAL_USED +""" + + expected = """ +class MyClass: + CLASS_USED = "used value" + CLASS_UNUSED = "unused value" + + def __init__(self): + self.value = self.CLASS_USED + self.other = self.CLASS_UNUSED + + def used_method(self): + return self.value + + def unused_method(self): + return "Not used but not removed" + +GLOBAL_USED = "global used" + +def helper_function(): + return MyClass().used_method() + GLOBAL_USED +""" + + qualified_functions = {"helper_function"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + # Normalize whitespace for comparison + assert result.strip() == expected.strip() + + +def test_complex_variable_dependencies() -> None: + """Test that only variables with complex dependencies are properly handled.""" + code = """ +def main_function(): + return DIRECT_DEPENDENCY + +def unused_function(): + return "Not used but not removed" + +DIRECT_DEPENDENCY = INDIRECT_DEPENDENCY + "_suffix" +INDIRECT_DEPENDENCY = "base value" +UNUSED_VARIABLE = "This should be removed" + +TUPLE_USED, TUPLE_UNUSED = ("used", "unused") + +def tuple_user(): + return TUPLE_USED +""" + + expected = """ +def main_function(): + return DIRECT_DEPENDENCY + +def unused_function(): + return "Not used but not removed" + +DIRECT_DEPENDENCY = INDIRECT_DEPENDENCY + "_suffix" +INDIRECT_DEPENDENCY = "base value" + +def tuple_user(): + return TUPLE_USED +""" + + qualified_functions = {"main_function"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + assert result.strip() == expected.strip() + + +def test_type_annotation_usage() -> None: + """Test that variables used in type annotations are considered used.""" + code = """ +# Type definition +CustomType = int +UnusedType = str + +def main_function(param: CustomType) -> CustomType: + return param + 10 + +def unused_function(param: UnusedType) -> UnusedType: + return param + " suffix" + +UNUSED_CONSTANT = 123 +""" + + expected = """ +# Type definition +CustomType = int + +def main_function(param: CustomType) -> CustomType: + return param + 10 + +def unused_function(param: UnusedType) -> UnusedType: + return param + " suffix" + +""" + + qualified_functions = {"main_function"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + # Normalize whitespace for comparison + assert result.strip() == expected.strip() + + +def test_class_method_with_dunder_methods() -> None: + """Test that when a class method is used, dunder methods of that class are preserved.""" + code = """ +class MyClass: + CLASS_VAR = "class variable" + UNUSED_VAR = GLOBAL_VAR_2 + + def __init__(self, value): + self.value = GLOBAL_VAR + + def __str__(self): + return f"MyClass({self.value})" + + def target_method(self): + return self.value * 2 + + def unused_method(self): + return "Not used" + +GLOBAL_VAR = "global" +GLOBAL_VAR_2 = "global" +UNUSED_GLOBAL = "unused global" + +def helper_function(): + obj = MyClass(5) + return obj.target_method() +""" + + expected = """ +class MyClass: + CLASS_VAR = "class variable" + UNUSED_VAR = GLOBAL_VAR_2 + + def __init__(self, value): + self.value = GLOBAL_VAR + + def __str__(self): + return f"MyClass({self.value})" + + def target_method(self): + return self.value * 2 + + def unused_method(self): + return "Not used" + +GLOBAL_VAR = "global" +GLOBAL_VAR_2 = "global" + +def helper_function(): + obj = MyClass(5) + return obj.target_method() +""" + + qualified_functions = {"MyClass.target_method"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + # Normalize whitespace for comparison + assert result.strip() == expected.strip() + + +def test_complex_type_annotations() -> None: + """Test complex type annotations with nested types.""" + code = """ +from typing import List, Dict, Optional + +# Type aliases +ItemType = Dict[str, int] +ResultType = List[ItemType] +UnusedType = Optional[str] + +def process_data(items: ResultType) -> int: + total = 0 + for item in items: + for key, value in item.items(): + total += value + return total + +def unused_function(param: UnusedType) -> None: + pass + +# Variables +SAMPLE_DATA: ResultType = [{"a": 1, "b": 2}] +UNUSED_DATA: UnusedType = None +""" + + expected = """ +from typing import List, Dict, Optional + +# Type aliases +ItemType = Dict[str, int] +ResultType = List[ItemType] + +def process_data(items: ResultType) -> int: + total = 0 + for item in items: + for key, value in item.items(): + total += value + return total + +def unused_function(param: UnusedType) -> None: + pass +""" + + qualified_functions = {"process_data"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + assert result.strip() == expected.strip() + + +def test_try_except_finally_variables() -> None: + """Test handling of variables defined in try-except-finally blocks.""" + code = """ +import math +import os + +# Top-level try-except that defines variables +try: + MATH_CONSTANT = math.pi + USED_ERROR_MSG = "An error occurred" + UNUSED_CONST = 42 +except ImportError: + MATH_CONSTANT = 3.14 + USED_ERROR_MSG = "Math module not available" + UNUSED_CONST = 0 +finally: + CLEANUP_FLAG = True + UNUSED_CLEANUP = "Not used" + +def use_constants(): + return f"Pi is approximately {MATH_CONSTANT}, message: {USED_ERROR_MSG}" + +def use_cleanup(): + if CLEANUP_FLAG: + return "Cleanup performed" + return "No cleanup" + +def unused_function(): + return UNUSED_CONST +""" + + expected = """ +import math +import os + +# Top-level try-except that defines variables +try: + MATH_CONSTANT = math.pi + USED_ERROR_MSG = "An error occurred" +except ImportError: + MATH_CONSTANT = 3.14 + USED_ERROR_MSG = "Math module not available" +finally: + CLEANUP_FLAG = True + +def use_constants(): + return f"Pi is approximately {MATH_CONSTANT}, message: {USED_ERROR_MSG}" + +def use_cleanup(): + if CLEANUP_FLAG: + return "Cleanup performed" + return "No cleanup" + +def unused_function(): + return UNUSED_CONST +""" + + qualified_functions = {"use_constants", "use_cleanup"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + assert result.strip() == expected.strip() + +def test_conditional_and_loop_variables() -> None: + """Test handling of variables defined in if-else and while loops.""" + code = """ +import sys +import platform + +# Top-level if-else block defining variables +if sys.platform.startswith('win'): + OS_TYPE = "Windows" + OS_SEP = "" + UNUSED_WIN_VAR = "Unused Windows variable" +elif sys.platform.startswith('linux'): + OS_TYPE = "Linux" + OS_SEP = "/" + UNUSED_LINUX_VAR = "Unused Linux variable" +else: + OS_TYPE = "Other" + OS_SEP = "/" + UNUSED_OTHER_VAR = "Unused other variable" + +# While loop with variable definitions +counter = 0 +while counter < 5: + LOOP_RESULT = "Iteration " + str(counter) + UNUSED_LOOP_VAR = "Unused loop " + str(counter) + counter += 1 + +def get_platform_info(): + return "OS: " + OS_TYPE + ", Separator: " + OS_SEP + +def get_loop_result(): + return LOOP_RESULT + +def unused_function(): + result = "" + if sys.platform.startswith('win'): + result = UNUSED_WIN_VAR + elif sys.platform.startswith('linux'): + result = UNUSED_LINUX_VAR + else: + result = UNUSED_OTHER_VAR + return result +""" + + expected = """ +import sys +import platform + +# Top-level if-else block defining variables +if sys.platform.startswith('win'): + OS_TYPE = "Windows" + OS_SEP = "" +elif sys.platform.startswith('linux'): + OS_TYPE = "Linux" + OS_SEP = "/" +else: + OS_TYPE = "Other" + OS_SEP = "/" + +# While loop with variable definitions +counter = 0 +while counter < 5: + LOOP_RESULT = "Iteration " + str(counter) + counter += 1 + +def get_platform_info(): + return "OS: " + OS_TYPE + ", Separator: " + OS_SEP + +def get_loop_result(): + return LOOP_RESULT + +def unused_function(): + result = "" + if sys.platform.startswith('win'): + result = UNUSED_WIN_VAR + elif sys.platform.startswith('linux'): + result = UNUSED_LINUX_VAR + else: + result = UNUSED_OTHER_VAR + return result +""" + + qualified_functions = {"get_platform_info", "get_loop_result"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + assert result.strip() == expected.strip()