From d1ae9329eb4241b8bfa47eba47f2ffbdac8b92a5 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Fri, 30 May 2025 12:41:02 -0700 Subject: [PATCH 1/2] don't count nested functions as it's own func --- codeflash/discovery/functions_to_optimize.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 8aa052ab0..3f0d72bcd 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -94,8 +94,6 @@ def visit_FunctionDef(self, node: FunctionDef) -> None: self.functions.append( FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:]) ) - # Continue visiting the body of the function to find nested functions - self.generic_visit(node) def generic_visit(self, node: ast.AST) -> None: if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)): From d184036466d0c8127620f2cd075948bbf26df191 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Fri, 30 May 2025 15:13:10 -0700 Subject: [PATCH 2/2] add tests --- tests/test_function_discovery.py | 147 +++++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) diff --git a/tests/test_function_discovery.py b/tests/test_function_discovery.py index a02dec2dd..da40121bc 100644 --- a/tests/test_function_discovery.py +++ b/tests/test_function_discovery.py @@ -35,6 +35,17 @@ def test_function_eligible_for_optimization() -> None: assert len(functions_found[Path(f.name)]) == 0 + # we want to trigger an error in the function discovery + function = """def test_invalid_code():""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: + f.write(function) + f.flush() + functions_found = find_all_functions_in_file(Path(f.name)) + assert functions_found == {} + + + + def test_find_top_level_function_or_method(): with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: f.write( @@ -82,6 +93,15 @@ def non_classmethod_function(cls, name): ).is_top_level # needed because this will be traced with a class_name being passed + # we want to write invalid code to ensure that the function discovery does not crash + with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: + f.write( + """def functionA(): +""" + ) + f.flush() + path_obj_name = Path(f.name) + assert not inspect_top_level_functions_or_methods(path_obj_name, "functionA") def test_class_method_discovery(): with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: @@ -152,6 +172,133 @@ def functionA(): assert functions[file][0].function_name == "functionA" +def test_nested_function(): + with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: + f.write( +""" +import copy + +def propagate_attributes( + nodes: dict[str, dict], edges: list[dict], source_node_id: str, attribute: str +) -> dict[str, dict]: + modified_nodes = copy.deepcopy(nodes) + + # Build an adjacency list for faster traversal + adjacency = {} + for edge in edges: + src = edge["source"] + tgt = edge["target"] + if src not in adjacency: + adjacency[src] = [] + adjacency[src].append(tgt) + + # Track visited nodes to avoid cycles + visited = set() + + def traverse(node_id): + if node_id in visited: + return + visited.add(node_id) + + # Propagate attribute from source node + if ( + node_id != source_node_id + and source_node_id in modified_nodes + and attribute in modified_nodes[source_node_id] + ): + if node_id in modified_nodes: + modified_nodes[node_id][attribute] = modified_nodes[source_node_id][ + attribute + ] + + # Continue propagation to neighbors + for neighbor in adjacency.get(node_id, []): + traverse(neighbor) + + traverse(source_node_id) + return modified_nodes +""" + ) + f.flush() + test_config = TestConfig( + tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path() + ) + path_obj_name = Path(f.name) + functions, functions_count = get_functions_to_optimize( + optimize_all=None, + replay_test=None, + file=path_obj_name, + test_cfg=test_config, + only_get_this_function=None, + ignore_paths=[Path("/bruh/")], + project_root=path_obj_name.parent, + module_root=path_obj_name.parent, + ) + + assert len(functions) == 1 + assert functions_count == 1 + + with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: + f.write( +""" +def outer_function(): + def inner_function(): + pass + + return inner_function +""" + ) + f.flush() + test_config = TestConfig( + tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path() + ) + path_obj_name = Path(f.name) + functions, functions_count = get_functions_to_optimize( + optimize_all=None, + replay_test=None, + file=path_obj_name, + test_cfg=test_config, + only_get_this_function=None, + ignore_paths=[Path("/bruh/")], + project_root=path_obj_name.parent, + module_root=path_obj_name.parent, + ) + + assert len(functions) == 1 + assert functions_count == 1 + + with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f: + f.write( +""" +def outer_function(): + def inner_function(): + pass + + def another_inner_function(): + pass + return inner_function, another_inner_function +""" + ) + f.flush() + test_config = TestConfig( + tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path() + ) + path_obj_name = Path(f.name) + functions, functions_count = get_functions_to_optimize( + optimize_all=None, + replay_test=None, + file=path_obj_name, + test_cfg=test_config, + only_get_this_function=None, + ignore_paths=[Path("/bruh/")], + project_root=path_obj_name.parent, + module_root=path_obj_name.parent, + ) + + assert len(functions) == 1 + assert functions_count == 1 + + def test_filter_files_optimized(): tests_root = Path("tests").resolve() module_root = Path().resolve()