Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 40 additions & 31 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,18 @@ def __init__(self, function_names_to_find: set[str]) -> None:
# For prefix match, store mapping from prefix-root to candidates for O(1) matching
self._exact_names = function_names_to_find
self._prefix_roots: dict[str, list[str]] = {}
# Precompute sets for faster lookup during visit_Attribute()
self._dot_names: set[str] = set()
self._dot_methods: dict[str, set[str]] = {}
self._class_method_to_target: dict[tuple[str, str], str] = {}
for name in function_names_to_find:
if "." in name:
root = name.split(".", 1)[0]
self._prefix_roots.setdefault(root, []).append(name)
root, method = name.rsplit(".", 1)
self._dot_names.add(name)
self._dot_methods.setdefault(method, set()).add(root)
self._class_method_to_target[(root, method)] = name
root_prefix = name.split(".", 1)[0]
self._prefix_roots.setdefault(root_prefix, []).append(name)

def visit_Import(self, node: ast.Import) -> None:
"""Handle 'import module' statements."""
Expand Down Expand Up @@ -321,44 +329,45 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
if self.found_any_target_function:
return

# Check if this is accessing a target function through an imported module

node_value = node.value
node_attr = node.attr

# Check if this is accessing a target function through an imported module
if (
isinstance(node.value, ast.Name)
and node.value.id in self.imported_modules
and node.attr in self.function_names_to_find
isinstance(node_value, ast.Name)
and node_value.id in self.imported_modules
and node_attr in self.function_names_to_find
):
self.found_any_target_function = True
self.found_qualified_name = node.attr
self.found_qualified_name = node_attr
return

if isinstance(node.value, ast.Name) and node.value.id in self.imported_modules:
for target_func in self.function_names_to_find:
if "." in target_func:
class_name, method_name = target_func.rsplit(".", 1)
if node.attr == method_name:
imported_name = node.value.id
original_name = self.alias_mapping.get(imported_name, imported_name)
if original_name == class_name:
self.found_any_target_function = True
self.found_qualified_name = target_func
return
# Check for methods via imported modules using precomputed _dot_methods and _class_method_to_target
if isinstance(node_value, ast.Name) and node_value.id in self.imported_modules:
roots_possible = self._dot_methods.get(node_attr)
if roots_possible:
imported_name = node_value.id
original_name = self.alias_mapping.get(imported_name, imported_name)
if original_name in roots_possible:
self.found_any_target_function = True
self.found_qualified_name = self._class_method_to_target[(original_name, node_attr)]
return

# Check if this is accessing a method on an instance variable
if isinstance(node.value, ast.Name) and node.value.id in self.instance_mapping:
class_name = self.instance_mapping[node.value.id]
for target_func in self.function_names_to_find:
if "." in target_func:
target_class, method_name = target_func.rsplit(".", 1)
if node.attr == method_name and class_name == target_class:
self.found_any_target_function = True
self.found_qualified_name = target_func
return

# Check if this is accessing a target function through a dynamically imported module
# Only if we've detected dynamic imports are being used
if self.has_dynamic_imports and node.attr in self.function_names_to_find:
if isinstance(node_value, ast.Name) and node_value.id in self.instance_mapping:
class_name = self.instance_mapping[node_value.id]
roots_possible = self._dot_methods.get(node_attr)
if roots_possible and class_name in roots_possible:
self.found_any_target_function = True
self.found_qualified_name = self._class_method_to_target[(class_name, node_attr)]
return

# Check for dynamic import match
if self.has_dynamic_imports and node_attr in self.function_names_to_find:
self.found_any_target_function = True
self.found_qualified_name = node.attr
self.found_qualified_name = node_attr
return

self.generic_visit(node)
Expand Down
Loading