From 274f98b2093708a4a89db881b905427e93c21844 Mon Sep 17 00:00:00 2001 From: Alvin Ryanputra Date: Thu, 13 Mar 2025 18:52:11 -0700 Subject: [PATCH 1/3] changed preexisting objects to be a set. removes duplicates naturally and makes it easier to search for matches when replacing code. --- codeflash/code_utils/code_extractor.py | 20 ++++++++++++- codeflash/code_utils/code_replacer.py | 14 ++++----- codeflash/context/code_context_extractor.py | 6 ++-- tests/test_code_replacement.py | 32 ++++++++++----------- 4 files changed, 45 insertions(+), 27 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 409551d0a..782e60027 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -235,7 +235,7 @@ def extract_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str | return edited_code, contextual_dunder_methods -def find_preexisting_objects(source_code: str) -> list[tuple[str, list[FunctionParent]]]: +def find_preexisting_object_old(source_code: str) -> list[tuple[str, list[FunctionParent]]]: """Find all preexisting functions, classes or class methods in the source code""" preexisting_objects: list[tuple[str, list[FunctionParent]]] = [] try: @@ -252,3 +252,21 @@ def find_preexisting_objects(source_code: str) -> list[tuple[str, list[FunctionP if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)): preexisting_objects.append((cnode.name, [FunctionParent(node.name, "ClassDef")])) return preexisting_objects + +def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionParent, ...]]]: + """Find all preexisting functions, classes or class methods in the source code""" + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = set() + try: + module_node: ast.Module = ast.parse(source_code) + except SyntaxError: + logger.exception("find_preexisting_objects - Syntax error while parsing code") + return preexisting_objects + for node in module_node.body: + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + preexisting_objects.add((node.name, ())) + elif isinstance(node, ast.ClassDef): + preexisting_objects.add((node.name, ())) + for cnode in node.body: + if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)): + preexisting_objects.add((cnode.name, (FunctionParent(node.name, "ClassDef"),))) + return preexisting_objects diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 2c169136d..86f9bfb02 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -38,11 +38,11 @@ class OptimFunctionCollector(cst.CSTVisitor): def __init__( self, - preexisting_objects: list[tuple[str, list[FunctionParent]]] | None = None, + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] | None = None, function_names: set[tuple[str | None, str]] | None = None, ) -> None: super().__init__() - self.preexisting_objects = preexisting_objects if preexisting_objects is not None else [] + self.preexisting_objects = preexisting_objects if preexisting_objects is not None else set() self.function_names = function_names # set of (class_name, function_name) self.modified_functions: dict[ @@ -60,7 +60,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: self.modified_init_functions[self.current_class] = node elif ( self.preexisting_objects - and (node.name.value, []) not in self.preexisting_objects + and (node.name.value, ()) not in self.preexisting_objects and self.current_class is None ): self.new_functions.append(node) @@ -71,7 +71,7 @@ def visit_ClassDef(self, node: cst.ClassDef) -> bool: return False # If already in a class, do not recurse deeper self.current_class = node.name.value - parents = [FunctionParent(name=node.name.value, type="ClassDef")] + parents = (FunctionParent(name=node.name.value, type="ClassDef"),) for child_node in node.body.body: if ( self.preexisting_objects @@ -159,7 +159,7 @@ def replace_functions_in_file( source_code: str, original_function_names: list[str], optimized_code: str, - preexisting_objects: list[tuple[str, list[FunctionParent]]], + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]], ) -> str: parsed_function_names = [] for original_function_name in original_function_names: @@ -195,7 +195,7 @@ def replace_functions_and_add_imports( function_names: list[str], optimized_code: str, module_abspath: Path, - preexisting_objects: list[tuple[str, list[FunctionParent]]], + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]], project_root_path: Path, ) -> str: return add_needed_imports_from_module( @@ -211,7 +211,7 @@ def replace_function_definitions_in_module( function_names: list[str], optimized_code: str, module_abspath: Path, - preexisting_objects: list[tuple[str, list[FunctionParent]]], + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]], project_root_path: Path, ) -> bool: source_code: str = module_abspath.read_text(encoding="utf8") diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 7e7fbdb74..e9ace89e3 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -65,13 +65,13 @@ def get_code_optimization_context( if final_read_writable_tokens > optim_token_limit: raise ValueError("Read-writable code has exceeded token limit, cannot proceed") - # Setup preexisting objects for code replacer TODO: should remove duplicates - preexisting_objects = list( + # Setup preexisting objects for code replacer + preexisting_objects = list(set( chain( find_preexisting_objects(final_read_writable_code), *(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings), ) - ) + )) read_only_context_code = read_only_code_markdown.markdown read_only_code_markdown_tokens = len(tokenizer.encode(read_only_context_code)) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 59bdbcc23..ea221be78 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -74,7 +74,7 @@ def totally_new_function(value): """ function_name: str = "NewClass.new_function" - preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=[function_name], @@ -135,7 +135,7 @@ def other_function(st): """ function_name: str = "NewClass.new_function" - preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=[function_name], @@ -196,7 +196,7 @@ def totally_new_function(value): """ function_names: list[str] = ["other_function"] - preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=function_names, @@ -260,7 +260,7 @@ def totally_new_function(value): """ function_names: list[str] = ["yet_another_function", "other_function"] - preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=function_names, @@ -313,7 +313,7 @@ def supersort(doink): """ function_names: list[str] = ["sorter_deps"] - preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=function_names, @@ -388,7 +388,7 @@ def blab(st): print("Not cool") """ - preexisting_objects = find_preexisting_objects(original_code_main) + find_preexisting_objects(original_code_helper) + preexisting_objects = find_preexisting_objects(original_code_main) | find_preexisting_objects(original_code_helper) new_main_code: str = replace_functions_and_add_imports( source_code=original_code_main, function_names=["other_function"], @@ -591,7 +591,7 @@ def from_config(config: Optional[dict[str, Any]]): ) """ function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"] - preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, @@ -662,7 +662,7 @@ def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating: return np.sum(a != b) / a.size ''' function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"] - preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=function_names, @@ -715,7 +715,7 @@ def totally_new_function(value: Optional[str]): print("Hello world") """ function_name: str = "NewClass.__init__" - preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=[function_name], @@ -814,8 +814,8 @@ def real_bar(self) -> int: ''' function_name: str = "Fu.foo" - parents = [FunctionParent("Fu", "ClassDef")] - preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("foo", parents), ("real_bar", parents)] + parents = (FunctionParent("Fu", "ClassDef"),) + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = {("foo", parents), ("real_bar", parents)} new_code: str = replace_functions_in_file( source_code=original_code, original_function_names=[function_name], @@ -854,7 +854,7 @@ def real_bar(self) -> int: pass ''' - preexisting_objects: list[tuple[str, list[FunctionParent]]] = [] + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = [] new_code: str = replace_functions_in_file( source_code=original_code, original_function_names=["Fu.real_bar"], @@ -891,7 +891,7 @@ def __call__(self, value): """ function_names: list[str] = ["yet_another_function", "other_function"] - preexisting_objects: list[tuple[str, list[FunctionParent]]] = [] + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = [] new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=function_names, @@ -1278,7 +1278,7 @@ def cosine_similarity_top_k( return ret_idxs, scores ''' - preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) helper_functions = [ FakeFunctionSource( @@ -1579,7 +1579,7 @@ def nested_function(self): "NewClass.new_function2", "NestedClass.nested_function", ] # Nested classes should be ignored, even if provided as target - preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=function_names, @@ -1615,7 +1615,7 @@ def new_function2(value): """ function_names: list[str] = ["NewClass.__init__", "NewClass.__call__", "NewClass.new_function2"] - preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=function_names, From 5d5e77b5fb7e2a90ac3a52add8ed22e2dbda2891 Mon Sep 17 00:00:00 2001 From: Alvin Ryanputra Date: Thu, 13 Mar 2025 18:53:26 -0700 Subject: [PATCH 2/3] removed old function --- codeflash/code_utils/code_extractor.py | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 782e60027..a06a45c92 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -235,26 +235,8 @@ def extract_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str | return edited_code, contextual_dunder_methods -def find_preexisting_object_old(source_code: str) -> list[tuple[str, list[FunctionParent]]]: - """Find all preexisting functions, classes or class methods in the source code""" - preexisting_objects: list[tuple[str, list[FunctionParent]]] = [] - try: - module_node: ast.Module = ast.parse(source_code) - except SyntaxError: - logger.exception("find_preexisting_objects - Syntax error while parsing code") - return preexisting_objects - for node in module_node.body: - if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - preexisting_objects.append((node.name, [])) - elif isinstance(node, ast.ClassDef): - preexisting_objects.append((node.name, [])) - for cnode in node.body: - if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)): - preexisting_objects.append((cnode.name, [FunctionParent(node.name, "ClassDef")])) - return preexisting_objects - def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionParent, ...]]]: - """Find all preexisting functions, classes or class methods in the source code""" + """Find all preexisting functions, classes or class methods in the source code.""" preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = set() try: module_node: ast.Module = ast.parse(source_code) From 4aa194e4b13e5092909cbcfdd879b67c126eae9c Mon Sep 17 00:00:00 2001 From: Alvin Ryanputra Date: Thu, 13 Mar 2025 19:05:38 -0700 Subject: [PATCH 3/3] updated types --- codeflash/context/code_context_extractor.py | 4 ++-- codeflash/models/models.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index e9ace89e3..e58b372d6 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -66,12 +66,12 @@ def get_code_optimization_context( raise ValueError("Read-writable code has exceeded token limit, cannot proceed") # Setup preexisting objects for code replacer - preexisting_objects = list(set( + preexisting_objects = set( chain( find_preexisting_objects(final_read_writable_code), *(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings), ) - )) + ) read_only_context_code = read_only_code_markdown.markdown read_only_code_markdown_tokens = len(tokenizer.encode(read_only_context_code)) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 9cf51cf31..bd7fd3e05 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -99,7 +99,7 @@ class CodeOptimizationContext(BaseModel): read_writable_code: str = Field(min_length=1) read_only_context_code: str = "" helper_functions: list[FunctionSource] - preexisting_objects: list[tuple[str, list[FunctionParent]]] + preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] class CodeContextType(str, Enum): READ_WRITABLE = "READ_WRITABLE"