Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions codeflash/context/unused_definition_remover.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass, field

import libcst as cst
Expand Down Expand Up @@ -255,24 +256,24 @@ class QualifiedFunctionUsageMarker:
def __init__(self, definitions: dict[str, UsageInfo], qualified_function_names: set[str]) -> None:
self.definitions = definitions
self.qualified_function_names = qualified_function_names
self.class_dunder_methods = self._preprocess_definitions()
self.expanded_qualified_functions = self._expand_qualified_functions()

def _expand_qualified_functions(self) -> set[str]:
"""Expand the qualified function names to include related methods."""
expanded = set(self.qualified_function_names)

# Find class methods and add their containing classes and dunder methods
for qualified_name in list(self.qualified_function_names):
for qualified_name in self.qualified_function_names:
if "." in qualified_name:
class_name, method_name = qualified_name.split(".", 1)

# Add the class itself
expanded.add(class_name)

# Add all dunder methods of the class
for name in self.definitions:
if name.startswith(f"{class_name}.__") and name.endswith("__"):
expanded.add(name)
if class_name in self.class_dunder_methods:
expanded.update(self.class_dunder_methods[class_name])

return expanded

Expand Down Expand Up @@ -301,9 +302,21 @@ def mark_as_used_recursively(self, name: str) -> None:
for dep in self.definitions[name].dependencies:
self.mark_as_used_recursively(dep)

def _preprocess_definitions(self) -> dict[str, set[str]]:
"""Preprocess definitions to find dunder methods for each class."""
class_dunder_methods = defaultdict(set)

for name in self.definitions:
if name.count(".") == 1:
class_name, method_name = name.split(".", 1)
if method_name.startswith("__") and method_name.endswith("__"):
class_dunder_methods[class_name].add(name)

return class_dunder_methods


def remove_unused_definitions_recursively(
node: cst.CSTNode, definitions: dict[str, UsageInfo]
node: cst.CSTNode, definitions: dict[str, UsageInfo]
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node to remove unused definitions.

Expand Down Expand Up @@ -358,7 +371,10 @@ def remove_unused_definitions_recursively(
names = extract_names_from_targets(target.target)
for name in names:
class_var_name = f"{class_name}.{name}"
if class_var_name in definitions and definitions[class_var_name].used_by_qualified_function:
if (
class_var_name in definitions
and definitions[class_var_name].used_by_qualified_function
):
var_used = True
method_or_var_used = True
break
Expand Down
Loading