diff --git a/.gitignore b/.gitignore index 535acfb3e..b4a99e8c2 100644 --- a/.gitignore +++ b/.gitignore @@ -254,3 +254,5 @@ fabric.properties # Mac .DS_Store + +scratch/ diff --git a/code_to_optimize/bubble_sort_method_preserve_bad_formatting_for_nonoptimized_code.py b/code_to_optimize/bubble_sort_method_preserve_bad_formatting_for_nonoptimized_code.py new file mode 100644 index 000000000..29a00a922 --- /dev/null +++ b/code_to_optimize/bubble_sort_method_preserve_bad_formatting_for_nonoptimized_code.py @@ -0,0 +1,38 @@ +import sys + + +def lol(): + print( "lol" ) + + + + + + + + + +class BubbleSorter: + def __init__(self, x=0): + self.x = x + + def lol(self): + print( "lol" ) + + + + + + + + + def sorter(self, arr): + print("codeflash stdout : BubbleSorter.sorter() called") + for i in range(len(arr)): + for j in range(len(arr) - 1): + if arr[j] > arr[j + 1]: + temp = arr[j] + arr[j] = arr[j + 1] + arr[j + 1] = temp + print("stderr test", file=sys.stderr) + return arr diff --git a/code_to_optimize/bubble_sort_preserve_bad_formatting_for_nonoptimized_code.py b/code_to_optimize/bubble_sort_preserve_bad_formatting_for_nonoptimized_code.py new file mode 100644 index 000000000..b506ddfbb --- /dev/null +++ b/code_to_optimize/bubble_sort_preserve_bad_formatting_for_nonoptimized_code.py @@ -0,0 +1,19 @@ +def lol(): + print( "lol" ) + + + + + + + +def sorter(arr): + print("codeflash stdout: Sorting list") + for i in range(len(arr)): + for j in range(len(arr) - 1): + if arr[j] > arr[j + 1]: + temp = arr[j] + arr[j] = arr[j + 1] + arr[j + 1] = temp + print(f"result: {arr}") + return arr diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 875fd0a1f..8f426ad8a 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -55,3 +55,12 @@ def sort_imports(code: str) -> str: return code # Fall back to original code if isort fails return sorted_code + + +def sort_imports_in_place(paths: list[Path]) -> None: + for path in paths: + if path.exists(): + content = path.read_text(encoding="utf8") + sorted_content = sort_imports(content) + if sorted_content != content: + path.write_text(sorted_content, encoding="utf8") diff --git a/codeflash/main.py b/codeflash/main.py index 02b13d5aa..78a66acf1 100644 --- a/codeflash/main.py +++ b/codeflash/main.py @@ -2,6 +2,7 @@ solved problem, please reach out to us at careers@codeflash.ai. We're hiring! """ +import os from pathlib import Path from codeflash.cli_cmds.cli import parse_args, process_pyproject_config @@ -20,12 +21,12 @@ def main() -> None: CODEFLASH_LOGO, panel_args={"title": "https://codeflash.ai", "expand": False}, text_args={"style": "bold gold3"} ) args = parse_args() + if args.command: - if args.config_file and Path.exists(args.config_file): + disable_telemetry = os.environ.get("CODEFLASH_DISABLE_TELEMETRY", "").lower() in {"true", "t", "1", "yes", "y"} + if (not disable_telemetry) and args.config_file and Path.exists(args.config_file): pyproject_config, _ = parse_config_file(args.config_file) disable_telemetry = pyproject_config.get("disable_telemetry", False) - else: - disable_telemetry = False init_sentry(not disable_telemetry, exclude_errors=True) posthog_cf.initialize_posthog(not disable_telemetry) args.func() diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 56124a9cb..369b081fd 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -2,9 +2,11 @@ import ast import concurrent.futures +import dataclasses import os import shutil import subprocess +import tempfile import time import uuid from collections import defaultdict, deque @@ -36,7 +38,7 @@ N_TESTS_TO_GENERATE, TOTAL_LOOPING_TIME, ) -from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.code_utils.formatter import format_code, sort_imports_in_place from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test from codeflash.code_utils.line_profile_utils import add_decorator_imports from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests @@ -124,6 +126,7 @@ def __init__( self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {} self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {} self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None + self.optimizer_temp_dir = Path(tempfile.mkdtemp(prefix="codeflash_opt_fmt_")) def optimize_function(self) -> Result[BestOptimization, str]: should_run_experiment = self.experiment_id is not None @@ -301,9 +304,18 @@ def optimize_function(self) -> Result[BestOptimization, str]: code_context=code_context, optimized_code=best_optimization.candidate.source_code ) - new_code, new_helper_code = self.reformat_code_and_helpers( - code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code - ) + if not self.args.disable_imports_sorting: + path_to_sort_imports_for = [self.function_to_optimize.file_path] + [hf.file_path for hf in code_context.helper_functions] + sort_imports_in_place(path_to_sort_imports_for) + + new_code = self.function_to_optimize.file_path.read_text(encoding="utf8") + new_helper_code: dict[Path, str] = {} + for helper_file_path_key in original_helper_code: + if helper_file_path_key.exists(): + new_helper_code[helper_file_path_key] = helper_file_path_key.read_text(encoding="utf8") + else: + logger.warning(f"Helper file {helper_file_path_key} not found after optimization. It will not be included in new_helper_code for PR.") + existing_tests = existing_tests_source_for( self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), @@ -405,6 +417,33 @@ def determine_best_candidate( future_line_profile_results = None candidate_index += 1 candidate = candidates.popleft() + + formatted_candidate_code = candidate.source_code + if self.args.formatter_cmds: + temp_code_file_path: Path | None = None + try: + with tempfile.NamedTemporaryFile( + mode="w", + suffix=".py", + delete=False, + encoding="utf8", + dir=self.optimizer_temp_dir + ) as tmp_file: + tmp_file.write(candidate.source_code) + temp_code_file_path = Path(tmp_file.name) + + formatted_candidate_code = format_code( + formatter_cmds=self.args.formatter_cmds, + path=temp_code_file_path + ) + except Exception as e: + logger.error(f"Error during formatting candidate code via temp file: {e}. Using original candidate code.") + finally: + if temp_code_file_path and temp_code_file_path.exists(): + temp_code_file_path.unlink(missing_ok=True) + + candidate = dataclasses.replace(candidate, source_code=formatted_candidate_code) + get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True) get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True) logger.info(f"Optimization candidate {candidate_index}/{original_len}:") @@ -580,27 +619,6 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, with Path(module_abspath).open("w", encoding="utf8") as f: f.write(original_helper_code[module_abspath]) - def reformat_code_and_helpers( - self, helper_functions: list[FunctionSource], path: Path, original_code: str - ) -> tuple[str, dict[Path, str]]: - should_sort_imports = not self.args.disable_imports_sorting - if should_sort_imports and isort.code(original_code) != original_code: - should_sort_imports = False - - new_code = format_code(self.args.formatter_cmds, path) - if should_sort_imports: - new_code = sort_imports(new_code) - - new_helper_code: dict[Path, str] = {} - helper_functions_paths = {hf.file_path for hf in helper_functions} - for module_abspath in helper_functions_paths: - formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath) - if should_sort_imports: - formatted_helper_code = sort_imports(formatted_helper_code) - new_helper_code[module_abspath] = formatted_helper_code - - return new_code, new_helper_code - def replace_function_and_helpers_with_optimized_code( self, code_context: CodeOptimizationContext, optimized_code: str ) -> bool: diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 5c0a91c38..4f2ac6d9d 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -5,7 +5,7 @@ import pytest from codeflash.code_utils.config_parser import parse_config_file -from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.code_utils.formatter import format_code, sort_imports, sort_imports_in_place def test_remove_duplicate_imports(): @@ -30,6 +30,23 @@ def test_sorting_imports(): new_code = sort_imports(original_code) assert new_code == "import os\nimport sys\nimport unittest\n" +def test_sort_imports_in_place(): + """Test that sorting imports in place in multiple files works.""" + original_code = "import sys\nimport unittest\nimport os\n" + expected_code = "import os\nimport sys\nimport unittest\n" + + with tempfile.TemporaryDirectory() as tmpdir: + file_paths = [] + for i in range(3): + file_path = Path(tmpdir) / f"test_file_{i}.py" + file_path.write_text(original_code, encoding="utf8") + file_paths.append(file_path) + + sort_imports_in_place(file_paths) + + for file_path in file_paths: + assert file_path.read_text(encoding="utf8") == expected_code + def test_sort_imports_without_formatting(): """Test that imports are sorted when formatting is disabled and should_sort_imports is True."""