Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@ def visit_FunctionDef(self, node: FunctionDef) -> None:
self.functions.append(
FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
)
# Continue visiting the body of the function to find nested functions
self.generic_visit(node)

def generic_visit(self, node: ast.AST) -> None:
if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
Expand Down
147 changes: 147 additions & 0 deletions tests/test_function_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@ def test_function_eligible_for_optimization() -> None:
assert len(functions_found[Path(f.name)]) == 0


# we want to trigger an error in the function discovery
function = """def test_invalid_code():"""
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
f.write(function)
f.flush()
functions_found = find_all_functions_in_file(Path(f.name))
assert functions_found == {}




def test_find_top_level_function_or_method():
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
f.write(
Expand Down Expand Up @@ -82,6 +93,15 @@ def non_classmethod_function(cls, name):
).is_top_level
# needed because this will be traced with a class_name being passed

# we want to write invalid code to ensure that the function discovery does not crash
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
f.write(
"""def functionA():
"""
)
f.flush()
path_obj_name = Path(f.name)
assert not inspect_top_level_functions_or_methods(path_obj_name, "functionA")

def test_class_method_discovery():
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
Expand Down Expand Up @@ -152,6 +172,133 @@ def functionA():
assert functions[file][0].function_name == "functionA"


def test_nested_function():
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
f.write(
"""
import copy

def propagate_attributes(
nodes: dict[str, dict], edges: list[dict], source_node_id: str, attribute: str
) -> dict[str, dict]:
modified_nodes = copy.deepcopy(nodes)

# Build an adjacency list for faster traversal
adjacency = {}
for edge in edges:
src = edge["source"]
tgt = edge["target"]
if src not in adjacency:
adjacency[src] = []
adjacency[src].append(tgt)

# Track visited nodes to avoid cycles
visited = set()

def traverse(node_id):
if node_id in visited:
return
visited.add(node_id)

# Propagate attribute from source node
if (
node_id != source_node_id
and source_node_id in modified_nodes
and attribute in modified_nodes[source_node_id]
):
if node_id in modified_nodes:
modified_nodes[node_id][attribute] = modified_nodes[source_node_id][
attribute
]

# Continue propagation to neighbors
for neighbor in adjacency.get(node_id, []):
traverse(neighbor)

traverse(source_node_id)
return modified_nodes
"""
)
f.flush()
test_config = TestConfig(
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
)
path_obj_name = Path(f.name)
functions, functions_count = get_functions_to_optimize(
optimize_all=None,
replay_test=None,
file=path_obj_name,
test_cfg=test_config,
only_get_this_function=None,
ignore_paths=[Path("/bruh/")],
project_root=path_obj_name.parent,
module_root=path_obj_name.parent,
)

assert len(functions) == 1
assert functions_count == 1

with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
f.write(
"""
def outer_function():
def inner_function():
pass

return inner_function
"""
)
f.flush()
test_config = TestConfig(
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
)
path_obj_name = Path(f.name)
functions, functions_count = get_functions_to_optimize(
optimize_all=None,
replay_test=None,
file=path_obj_name,
test_cfg=test_config,
only_get_this_function=None,
ignore_paths=[Path("/bruh/")],
project_root=path_obj_name.parent,
module_root=path_obj_name.parent,
)

assert len(functions) == 1
assert functions_count == 1

with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
f.write(
"""
def outer_function():
def inner_function():
pass

def another_inner_function():
pass
return inner_function, another_inner_function
"""
)
f.flush()
test_config = TestConfig(
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
)
path_obj_name = Path(f.name)
functions, functions_count = get_functions_to_optimize(
optimize_all=None,
replay_test=None,
file=path_obj_name,
test_cfg=test_config,
only_get_this_function=None,
ignore_paths=[Path("/bruh/")],
project_root=path_obj_name.parent,
module_root=path_obj_name.parent,
)

assert len(functions) == 1
assert functions_count == 1


def test_filter_files_optimized():
tests_root = Path("tests").resolve()
module_root = Path().resolve()
Expand Down
Loading