diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index a06a45c92..bcbc0e29d 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -2,7 +2,7 @@ import ast from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, Optional, Set import libcst as cst import libcst.matchers as m @@ -18,6 +18,227 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from typing import List, Union + +class GlobalAssignmentCollector(cst.CSTVisitor): + """Collects all global assignment statements.""" + + def __init__(self): + super().__init__() + self.assignments: Dict[str, cst.Assign] = {} + self.assignment_order: List[str] = [] + # Track scope depth to identify global assignments + self.scope_depth = 0 + self.if_else_depth = 0 + + def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: + self.scope_depth += 1 + return True + + def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: + self.scope_depth -= 1 + + def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: + self.scope_depth += 1 + return True + + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: + self.scope_depth -= 1 + + def visit_If(self, node: cst.If) -> Optional[bool]: + self.if_else_depth += 1 + return True + + def leave_If(self, original_node: cst.If) -> None: + self.if_else_depth -= 1 + + def visit_Else(self, node: cst.Else) -> Optional[bool]: + # Else blocks are already counted as part of the if statement + return True + + def visit_Assign(self, node: cst.Assign) -> Optional[bool]: + # Only process global assignments (not inside functions, classes, etc.) + if self.scope_depth == 0 and self.if_else_depth == 0: # We're at module level + for target in node.targets: + if isinstance(target.target, cst.Name): + name = target.target.value + self.assignments[name] = node + if name not in self.assignment_order: + self.assignment_order.append(name) + return True + + +class GlobalAssignmentTransformer(cst.CSTTransformer): + """Transforms global assignments in the original file with those from the new file.""" + + def __init__(self, new_assignments: Dict[str, cst.Assign], new_assignment_order: List[str]): + super().__init__() + self.new_assignments = new_assignments + self.new_assignment_order = new_assignment_order + self.processed_assignments: Set[str] = set() + self.scope_depth = 0 + self.if_else_depth = 0 + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + self.scope_depth += 1 + + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + self.scope_depth -= 1 + return updated_node + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + self.scope_depth += 1 + + def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: + self.scope_depth -= 1 + return updated_node + + def visit_If(self, node: cst.If) -> None: + self.if_else_depth += 1 + + def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If: + self.if_else_depth -= 1 + return updated_node + + def visit_Else(self, node: cst.Else) -> None: + # Else blocks are already counted as part of the if statement + pass + + def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.CSTNode: + if self.scope_depth > 0 or self.if_else_depth > 0: + return updated_node + + # Check if this is a global assignment we need to replace + for target in original_node.targets: + if isinstance(target.target, cst.Name): + name = target.target.value + if name in self.new_assignments: + self.processed_assignments.add(name) + return self.new_assignments[name] + + return updated_node + + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: + # Add any new assignments that weren't in the original file + new_statements = list(updated_node.body) + + # Find assignments to append + assignments_to_append = [] + for name in self.new_assignment_order: + if name not in self.processed_assignments and name in self.new_assignments: + assignments_to_append.append(self.new_assignments[name]) + + if assignments_to_append: + # Add a blank line before appending new assignments if needed + if new_statements and not isinstance(new_statements[-1], cst.EmptyLine): + new_statements.append(cst.SimpleStatementLine([cst.Pass()], leading_lines=[cst.EmptyLine()])) + new_statements.pop() # Remove the Pass statement but keep the empty line + + # Add the new assignments + for assignment in assignments_to_append: + new_statements.append( + cst.SimpleStatementLine( + [assignment], + leading_lines=[cst.EmptyLine()] + ) + ) + + return updated_node.with_changes(body=new_statements) + +class GlobalStatementCollector(cst.CSTVisitor): + """Visitor that collects all global statements (excluding imports and functions/classes).""" + + def __init__(self): + super().__init__() + self.global_statements = [] + self.in_function_or_class = False + + def visit_ClassDef(self, node: cst.ClassDef) -> bool: + # Don't visit inside classes + self.in_function_or_class = True + return False + + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: + self.in_function_or_class = False + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + # Don't visit inside functions + self.in_function_or_class = True + return False + + def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: + self.in_function_or_class = False + + def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None: + if not self.in_function_or_class: + for statement in node.body: + # Skip imports + if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign)): + self.global_statements.append(node) + break + + +class LastImportFinder(cst.CSTVisitor): + """Finds the position of the last import statement in the module.""" + + def __init__(self): + super().__init__() + self.last_import_line = 0 + self.current_line = 0 + + def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None: + self.current_line += 1 + for statement in node.body: + if isinstance(statement, (cst.Import, cst.ImportFrom)): + self.last_import_line = self.current_line + + +class ImportInserter(cst.CSTTransformer): + """Transformer that inserts global statements after the last import.""" + + def __init__(self, global_statements: List[cst.SimpleStatementLine], last_import_line: int): + super().__init__() + self.global_statements = global_statements + self.last_import_line = last_import_line + self.current_line = 0 + self.inserted = False + + def leave_SimpleStatementLine( + self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine + ) -> cst.Module: + self.current_line += 1 + + # If we're right after the last import and haven't inserted yet + if self.current_line == self.last_import_line and not self.inserted: + self.inserted = True + return cst.Module(body=[updated_node] + self.global_statements) + + return cst.Module(body=[updated_node]) + + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: + # If there were no imports, add at the beginning of the module + if self.last_import_line == 0 and not self.inserted: + updated_body = list(updated_node.body) + for stmt in reversed(self.global_statements): + updated_body.insert(0, stmt) + return updated_node.with_changes(body=updated_body) + return updated_node + + +def extract_global_statements(source_code: str) -> List[cst.SimpleStatementLine]: + """Extract global statements from source code.""" + module = cst.parse_module(source_code) + collector = GlobalStatementCollector() + module.visit(collector) + return collector.global_statements + + +def find_last_import_line(target_code: str) -> int: + """Find the line number of the last import statement.""" + module = cst.parse_module(target_code) + finder = LastImportFinder() + module.visit(finder) + return finder.last_import_line class FutureAliasedImportTransformer(cst.CSTTransformer): def leave_ImportFrom( @@ -38,6 +259,38 @@ def delete___future___aliased_imports(module_code: str) -> str: return cst.parse_module(module_code).visit(FutureAliasedImportTransformer()).code +def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: + non_assignment_global_statements = extract_global_statements(src_module_code) + + # Find the last import line in target + last_import_line = find_last_import_line(dst_module_code) + + # Parse the target code + target_module = cst.parse_module(dst_module_code) + + # Create transformer to insert non_assignment_global_statements + transformer = ImportInserter(non_assignment_global_statements, last_import_line) + # + # # Apply transformation + modified_module = target_module.visit(transformer) + dst_module_code = modified_module.code + + # Parse the code + original_module = cst.parse_module(dst_module_code) + new_module = cst.parse_module(src_module_code) + + # Collect assignments from the new file + new_collector = GlobalAssignmentCollector() + new_module.visit(new_collector) + + # Transform the original file + transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order) + transformed_module = original_module.visit(transformer) + + dst_module_code = transformed_module.code + return dst_module_code + + def add_needed_imports_from_module( src_module_code: str, dst_module_code: str, diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index ad37bfbd2..ccb935f42 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -8,7 +8,7 @@ import libcst as cst from codeflash.cli_cmds.console import logger -from codeflash.code_utils.code_extractor import add_needed_imports_from_module +from codeflash.code_utils.code_extractor import add_needed_imports_from_module, add_global_assignments from codeflash.models.models import FunctionParent if TYPE_CHECKING: @@ -220,7 +220,8 @@ def replace_function_definitions_in_module( ) if is_zero_diff(source_code, new_code): return False - module_abspath.write_text(new_code, encoding="utf8") + code_with_global_assignments = add_global_assignments(optimized_code, new_code) + module_abspath.write_text(code_with_global_assignments, encoding="utf8") return True diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 792a76885..36b1613e4 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -360,23 +360,26 @@ def get_function_to_optimize_as_function_source( # Find the name that matches our function for name in names: - if ( - name.type == "function" - and name.full_name - and name.name == function_to_optimize.function_name - and name.full_name.startswith(name.module_name) - and get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name - ): - function_source = FunctionSource( - file_path=function_to_optimize.file_path, - qualified_name=function_to_optimize.qualified_name, - fully_qualified_name=name.full_name, - only_function_name=name.name, - source_code=name.get_line_code(), - jedi_definition=name, - ) - return function_source - + try: + if ( + name.type == "function" + and name.full_name + and name.name == function_to_optimize.function_name + and name.full_name.startswith(name.module_name) + and get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name + ): + function_source = FunctionSource( + file_path=function_to_optimize.file_path, + qualified_name=function_to_optimize.qualified_name, + fully_qualified_name=name.full_name, + only_function_name=name.name, + source_code=name.get_line_code(), + jedi_definition=name, + ) + return function_source + except Exception as e: + logger.exception(f"Error while getting function source: {e}") + continue raise ValueError( f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}" ) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index d3c4d941a..2e8c2f6fd 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -12,7 +12,7 @@ replace_functions_in_file, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import FunctionParent +from codeflash.models.models import CodeOptimizationContext, FunctionParent from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig @@ -34,6 +34,48 @@ class FakeFunctionSource: jedi_definition: JediDefinition +class Args: + disable_imports_sorting = True + formatter_cmds = ["disabled"] + + +def test_code_replacement_global_statements(): + optimized_code = """import numpy as np +inconsequential_var = '123' +def sorter(arr): + return arr.sort()""" + code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_optimized.py").resolve() + original_code_str = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").read_text( + encoding="utf-8" + ) + code_path.write_text(original_code_str, encoding="utf-8") + tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/") + project_root_path = (Path(__file__).parent / "..").resolve() + func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path) + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", + ) + func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + func_optimizer.args = Args() + func_optimizer.replace_function_and_helpers_with_optimized_code( + code_context=code_context, optimized_code=optimized_code + ) + final_output = code_path.read_text(encoding="utf-8") + assert "inconsequential_var = '123'" in final_output + code_path.unlink(missing_ok=True) + + def test_test_libcst_code_replacement() -> None: optim_code = """import libcst as cst from typing import Optional @@ -74,7 +116,7 @@ def totally_new_function(value): """ function_name: str = "NewClass.new_function" - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=[function_name], @@ -135,7 +177,7 @@ def other_function(st): """ function_name: str = "NewClass.new_function" - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=[function_name], @@ -196,7 +238,7 @@ def totally_new_function(value): """ function_names: list[str] = ["other_function"] - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=function_names, @@ -260,7 +302,7 @@ def totally_new_function(value): """ function_names: list[str] = ["yet_another_function", "other_function"] - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=function_names, @@ -313,7 +355,7 @@ def supersort(doink): """ function_names: list[str] = ["sorter_deps"] - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=function_names, @@ -591,7 +633,7 @@ def from_config(config: Optional[dict[str, Any]]): ) """ function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"] - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, @@ -662,7 +704,7 @@ def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating: return np.sum(a != b) / a.size ''' function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"] - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=function_names, @@ -715,7 +757,7 @@ def totally_new_function(value: Optional[str]): print("Hello world") """ function_name: str = "NewClass.__init__" - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=[function_name], @@ -762,8 +804,7 @@ def __init__(self, name): self.name = name def main_method(self): - return HelperClass(self.name).helper_method() -""" + return HelperClass(self.name).helper_method()""" file_path = Path(__file__).resolve() func_top_optimize = FunctionToOptimize( function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")] @@ -811,7 +852,7 @@ def real_bar(self) -> int: function_name: str = "Fu.foo" parents = (FunctionParent("Fu", "ClassDef"),) - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = {("foo", parents), ("real_bar", parents)} + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = {("foo", parents), ("real_bar", parents)} new_code: str = replace_functions_in_file( source_code=original_code, original_function_names=[function_name], @@ -850,7 +891,7 @@ def real_bar(self) -> int: pass ''' - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = [] + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = [] new_code: str = replace_functions_in_file( source_code=original_code, original_function_names=["Fu.real_bar"], @@ -887,7 +928,7 @@ def __call__(self, value): """ function_names: list[str] = ["yet_another_function", "other_function"] - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = [] + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = [] new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=function_names, @@ -1098,8 +1139,8 @@ def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]: ) assert ( - new_code - == """from __future__ import annotations + new_code + == """from __future__ import annotations import sys from codeflash.verification.comparator import comparator from enum import Enum @@ -1274,7 +1315,7 @@ def cosine_similarity_top_k( return ret_idxs, scores ''' - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code) helper_functions = [ FakeFunctionSource( @@ -1304,8 +1345,8 @@ def cosine_similarity_top_k( project_root_path=Path(__file__).parent.parent.resolve(), ) assert ( - new_code - == '''import numpy as np + new_code + == '''import numpy as np from pydantic.dataclasses import dataclass from typing import List, Optional, Tuple, Union @dataclass(config=dict(arbitrary_types_allowed=True)) @@ -1363,8 +1404,8 @@ def cosine_similarity_top_k( ) assert ( - new_helper_code - == '''import numpy as np + new_helper_code + == '''import numpy as np from pydantic.dataclasses import dataclass from typing import List, Optional, Tuple, Union @dataclass(config=dict(arbitrary_types_allowed=True)) @@ -1575,7 +1616,7 @@ def nested_function(self): "NewClass.new_function2", "NestedClass.nested_function", ] # Nested classes should be ignored, even if provided as target - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=function_names, @@ -1609,9 +1650,8 @@ def new_function2(value): print("Hello world") """ - function_names: list[str] = ["NewClass.__init__", "NewClass.__call__", "NewClass.new_function2"] - preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]] = find_preexisting_objects(original_code) + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code) new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=function_names, @@ -1621,3 +1661,474 @@ def new_function2(value): project_root_path=Path(__file__).resolve().parent.resolve(), ) assert new_code == original_code + +def test_global_reassignment() -> None: + original_code = """a=1 +print("Hello world") +def some_fn(): + print("did noting") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) + """ + optimized_code = """import numpy as np +def some_fn(): + a=np.zeros(10) + print("did something") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) +a=2 +print("Hello world") + """ + expected_code = """import numpy as np +print("Hello world") + +a=2 +print("Hello world") +def some_fn(): + a=np.zeros(10) + print("did something") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) +""" + code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve() + code_path.write_text(original_code, encoding="utf-8") + tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/") + project_root_path = (Path(__file__).parent / "..").resolve() + func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path) + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", + ) + func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + func_optimizer.args = Args() + func_optimizer.replace_function_and_helpers_with_optimized_code( + code_context=code_context, optimized_code=optimized_code + ) + new_code = code_path.read_text(encoding="utf-8") + code_path.unlink(missing_ok=True) + assert new_code.rstrip() == expected_code.rstrip() + + original_code = """print("Hello world") +def some_fn(): + print("did noting") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) +a=1 +""" + optimized_code = """a=2 +import numpy as np +def some_fn(): + a=np.zeros(10) + print("did something") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) +print("Hello world") + """ + expected_code = """import numpy as np +print("Hello world") + +print("Hello world") +def some_fn(): + a=np.zeros(10) + print("did something") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) +a=2 +""" + code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve() + code_path.write_text(original_code, encoding="utf-8") + tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/") + project_root_path = (Path(__file__).parent / "..").resolve() + func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path) + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", + ) + func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + func_optimizer.args = Args() + func_optimizer.replace_function_and_helpers_with_optimized_code( + code_context=code_context, optimized_code=optimized_code + ) + new_code = code_path.read_text(encoding="utf-8") + code_path.unlink(missing_ok=True) + assert new_code.rstrip() == expected_code.rstrip() + + original_code = """a=1 +print("Hello world") +def some_fn(): + print("did noting") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) +""" + optimized_code = """import numpy as np +a=2 +def some_fn(): + a=np.zeros(10) + print("did something") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) +a=3 +print("Hello world") + """ + expected_code = """import numpy as np +print("Hello world") + +a=3 +print("Hello world") +def some_fn(): + a=np.zeros(10) + print("did something") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) +""" + code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve() + code_path.write_text(original_code, encoding="utf-8") + tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/") + project_root_path = (Path(__file__).parent / "..").resolve() + func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path) + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", + ) + func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + func_optimizer.args = Args() + func_optimizer.replace_function_and_helpers_with_optimized_code( + code_context=code_context, optimized_code=optimized_code + ) + new_code = code_path.read_text(encoding="utf-8") + code_path.unlink(missing_ok=True) + assert new_code.rstrip() == expected_code.rstrip() + + original_code = """a=1 +print("Hello world") +def some_fn(): + print("did noting") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) +""" + optimized_code = """a=2 +import numpy as np +def some_fn(): + a=np.zeros(10) + print("did something") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) +print("Hello world") + """ + expected_code = """import numpy as np +print("Hello world") + +a=2 +print("Hello world") +def some_fn(): + a=np.zeros(10) + print("did something") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) +""" + code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve() + code_path.write_text(original_code, encoding="utf-8") + tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/") + project_root_path = (Path(__file__).parent / "..").resolve() + func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path) + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", + ) + func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + func_optimizer.args = Args() + func_optimizer.replace_function_and_helpers_with_optimized_code( + code_context=code_context, optimized_code=optimized_code + ) + new_code = code_path.read_text(encoding="utf-8") + code_path.unlink(missing_ok=True) + assert new_code.rstrip() == expected_code.rstrip() + + original_code = """a=1 +print("Hello world") +def some_fn(): + print("did noting") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) +""" + optimized_code = """import numpy as np +a=2 +def some_fn(): + a=np.zeros(10) + print("did something") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) +a=3 +print("Hello world") + """ + expected_code = """import numpy as np +print("Hello world") + +a=3 +print("Hello world") +def some_fn(): + a=np.zeros(10) + print("did something") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) +""" + code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve() + code_path.write_text(original_code, encoding="utf-8") + tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/") + project_root_path = (Path(__file__).parent / "..").resolve() + func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path) + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", + ) + func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + func_optimizer.args = Args() + func_optimizer.replace_function_and_helpers_with_optimized_code( + code_context=code_context, optimized_code=optimized_code + ) + new_code = code_path.read_text(encoding="utf-8") + code_path.unlink(missing_ok=True) + assert new_code.rstrip() == expected_code.rstrip() + + original_code = """if 2<3: + a=4 +else: + a=5 +print("Hello world") +def some_fn(): + print("did noting") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) +""" + optimized_code = """import numpy as np +if 1<2: + a=2 +else: + a=3 +a = 6 +def some_fn(): + a=np.zeros(10) + print("did something") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) +print("Hello world") +""" + expected_code = """import numpy as np +print("Hello world") + +if 2<3: + a=4 +else: + a=5 +print("Hello world") +def some_fn(): + a=np.zeros(10) + print("did something") +class NewClass: + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) + def __init__(self, name): + self.name = name + def __call__(self, value): + return "I am still old" + def new_function2(value): + return cst.ensure_type(value, str) + +a = 6 +""" + code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve() + code_path.write_text(original_code, encoding="utf-8") + tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/") + project_root_path = (Path(__file__).parent / "..").resolve() + func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path) + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", + ) + func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + func_optimizer.args = Args() + func_optimizer.replace_function_and_helpers_with_optimized_code( + code_context=code_context, optimized_code=optimized_code + ) + new_code = code_path.read_text(encoding="utf-8") + code_path.unlink(missing_ok=True) + assert new_code.rstrip() == expected_code.rstrip() \ No newline at end of file