Skip to content
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,5 @@ fabric.properties

# Mac
.DS_Store

scratch/
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import sys


def lol():
    """Print the literal string ``lol`` to standard output.

    Returns:
        None.
    """
    print("lol")









class BubbleSorter:
    """Sort mutable sequences in place using bubble sort.

    The instance stores a single value ``x``; none of the methods defined
    here read it.
    """

    def __init__(self, x=0):
        """Store ``x`` on the instance.

        Args:
            x: Arbitrary value kept as ``self.x``. Defaults to ``0``.
        """
        self.x = x

    def lol(self):
        """Print the literal string ``lol`` to standard output."""
        print("lol")

    def sorter(self, arr):
        """Sort ``arr`` in ascending order, in place, and return it.

        A fixed message is printed to standard output before sorting and
        another to standard error after sorting.

        Args:
            arr: A mutable, indexable sequence whose elements support ``>``.

        Returns:
            The same ``arr`` object, now sorted.
        """
        print("codeflash stdout : BubbleSorter.sorter() called")
        last_index = len(arr) - 1
        for _ in range(len(arr)):
            swapped = False
            for j in range(last_index):
                if arr[j] > arr[j + 1]:
                    arr[j], arr[j + 1] = arr[j + 1], arr[j]
                    swapped = True
            if not swapped:
                # A full pass without a swap leaves the sequence unchanged,
                # so every remaining pass would be a no-op as well.
                break
        print("stderr test", file=sys.stderr)
        return arr
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
def lol():
    """Print the literal string ``lol`` to standard output.

    Returns:
        None.
    """
    print("lol")







def sorter(arr):
    """Sort ``arr`` in ascending order, in place, using bubble sort.

    A fixed message is printed to standard output before sorting, and the
    sorted sequence is printed afterwards.

    Args:
        arr: A mutable, indexable sequence whose elements support ``>``.

    Returns:
        The same ``arr`` object, now sorted.
    """
    print("codeflash stdout: Sorting list")
    last_index = len(arr) - 1
    for _ in range(len(arr)):
        swapped = False
        for j in range(last_index):
            if arr[j] > arr[j + 1]:
                arr[j], arr[j + 1] = arr[j + 1], arr[j]
                swapped = True
        if not swapped:
            # A full pass without a swap leaves the sequence unchanged,
            # so every remaining pass would be a no-op as well.
            break
    print(f"result: {arr}")
    return arr
9 changes: 9 additions & 0 deletions codeflash/code_utils/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,12 @@ def sort_imports(code: str) -> str:
return code # Fall back to original code if isort fails

return sorted_code


def sort_imports_in_place(paths: list[Path]) -> None:
    """Sort the imports of each given file, rewriting it only when needed.

    Each file is read as UTF-8, passed through ``sort_imports``, and written
    back only if the returned text differs from what was read.

    Args:
        paths: Files to process. Entries that are not existing regular files
            (missing paths, directories) are skipped.
    """
    for path in paths:
        # is_file() instead of exists(): a directory exists but cannot be
        # read as text, which would raise partway through the batch.
        if not path.is_file():
            continue
        content = path.read_text(encoding="utf8")
        sorted_content = sort_imports(content)
        # Leave the file untouched when sorting produced no change.
        if sorted_content != content:
            path.write_text(sorted_content, encoding="utf8")
7 changes: 4 additions & 3 deletions codeflash/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
solved problem, please reach out to us at careers@codeflash.ai. We're hiring!
"""

import os
from pathlib import Path

from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
Expand All @@ -20,12 +21,12 @@ def main() -> None:
CODEFLASH_LOGO, panel_args={"title": "https://codeflash.ai", "expand": False}, text_args={"style": "bold gold3"}
)
args = parse_args()

if args.command:
if args.config_file and Path.exists(args.config_file):
disable_telemetry = os.environ.get("CODEFLASH_DISABLE_TELEMETRY", "").lower() in {"true", "t", "1", "yes", "y"}
if (not disable_telemetry) and args.config_file and Path.exists(args.config_file):
pyproject_config, _ = parse_config_file(args.config_file)
disable_telemetry = pyproject_config.get("disable_telemetry", False)
else:
disable_telemetry = False
init_sentry(not disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(not disable_telemetry)
args.func()
Expand Down
68 changes: 43 additions & 25 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import ast
import concurrent.futures
import dataclasses
import os
import shutil
import subprocess
import tempfile
import time
import uuid
from collections import defaultdict, deque
Expand Down Expand Up @@ -36,7 +38,7 @@
N_TESTS_TO_GENERATE,
TOTAL_LOOPING_TIME,
)
from codeflash.code_utils.formatter import format_code, sort_imports
from codeflash.code_utils.formatter import format_code, sort_imports_in_place
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.code_utils.line_profile_utils import add_decorator_imports
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
Expand Down Expand Up @@ -124,6 +126,7 @@ def __init__(
self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {}
self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {}
self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None
self.optimizer_temp_dir = Path(tempfile.mkdtemp(prefix="codeflash_opt_fmt_"))

def optimize_function(self) -> Result[BestOptimization, str]:
should_run_experiment = self.experiment_id is not None
Expand Down Expand Up @@ -301,9 +304,18 @@ def optimize_function(self) -> Result[BestOptimization, str]:
code_context=code_context, optimized_code=best_optimization.candidate.source_code
)

new_code, new_helper_code = self.reformat_code_and_helpers(
code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code
)
if not self.args.disable_imports_sorting:
path_to_sort_imports_for = [self.function_to_optimize.file_path] + [hf.file_path for hf in code_context.helper_functions]
sort_imports_in_place(path_to_sort_imports_for)

new_code = self.function_to_optimize.file_path.read_text(encoding="utf8")
new_helper_code: dict[Path, str] = {}
for helper_file_path_key in original_helper_code:
if helper_file_path_key.exists():
new_helper_code[helper_file_path_key] = helper_file_path_key.read_text(encoding="utf8")
else:
logger.warning(f"Helper file {helper_file_path_key} not found after optimization. It will not be included in new_helper_code for PR.")


existing_tests = existing_tests_source_for(
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
Expand Down Expand Up @@ -405,6 +417,33 @@ def determine_best_candidate(
future_line_profile_results = None
candidate_index += 1
candidate = candidates.popleft()

formatted_candidate_code = candidate.source_code
if self.args.formatter_cmds:
temp_code_file_path: Path | None = None
try:
with tempfile.NamedTemporaryFile(
mode="w",
suffix=".py",
delete=False,
encoding="utf8",
dir=self.optimizer_temp_dir
) as tmp_file:
tmp_file.write(candidate.source_code)
temp_code_file_path = Path(tmp_file.name)

formatted_candidate_code = format_code(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Formatting code is dependent on the cwd of the code: imports are grouped according to the module they belong to, and imports from the project's own module are grouped together. Which module an import belongs to is determined by the cwd.
So we should not format code in a temp directory; the results may not be the same. This is, by the way, why your unit tests were failing today.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that's true, as this is working.

This is only formatting the function snippet and the helper snippets -- I don't think this has to do with why those unit tests are failing.

formatter_cmds=self.args.formatter_cmds,
path=temp_code_file_path
)
except Exception as e:
logger.error(f"Error during formatting candidate code via temp file: {e}. Using original candidate code.")
finally:
if temp_code_file_path and temp_code_file_path.exists():
temp_code_file_path.unlink(missing_ok=True)

candidate = dataclasses.replace(candidate, source_code=formatted_candidate_code)

get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
Expand Down Expand Up @@ -580,27 +619,6 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path,
with Path(module_abspath).open("w", encoding="utf8") as f:
f.write(original_helper_code[module_abspath])

def reformat_code_and_helpers(
    self, helper_functions: list[FunctionSource], path: Path, original_code: str
) -> tuple[str, dict[Path, str]]:
    """Format the target file and every helper file, optionally sorting imports.

    Imports are sorted only when sorting is not disabled in ``self.args`` and
    ``original_code`` is already unchanged by ``isort.code``.

    Args:
        helper_functions: Helper functions whose ``file_path`` values name the
            additional files to format; duplicate paths are processed once.
        path: File passed to ``format_code`` for the main result.
        original_code: Source text used to decide whether imports get sorted.

    Returns:
        A pair ``(new_code, new_helper_code)`` where ``new_helper_code`` maps
        each helper file path to its formatted text.
    """
    sort_enabled = not self.args.disable_imports_sorting
    if sort_enabled:
        # Do not introduce import reordering into code that was not
        # isort-clean to begin with.
        sort_enabled = isort.code(original_code) == original_code

    def _reformat(target: Path) -> str:
        # Run the configured formatter, then sort imports if enabled.
        formatted = format_code(self.args.formatter_cmds, target)
        return sort_imports(formatted) if sort_enabled else formatted

    new_code = _reformat(path)

    unique_helper_paths = {helper.file_path for helper in helper_functions}
    new_helper_code: dict[Path, str] = {
        helper_path: _reformat(helper_path) for helper_path in unique_helper_paths
    }

    return new_code, new_helper_code

def replace_function_and_helpers_with_optimized_code(
self, code_context: CodeOptimizationContext, optimized_code: str
) -> bool:
Expand Down
19 changes: 18 additions & 1 deletion tests/test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.formatter import format_code, sort_imports
from codeflash.code_utils.formatter import format_code, sort_imports, sort_imports_in_place


def test_remove_duplicate_imports():
Expand All @@ -30,6 +30,23 @@ def test_sorting_imports():
new_code = sort_imports(original_code)
assert new_code == "import os\nimport sys\nimport unittest\n"

def test_sort_imports_in_place():
    """Check that in-place import sorting rewrites every file it is given."""
    unsorted_source = "import sys\nimport unittest\nimport os\n"
    sorted_source = "import os\nimport sys\nimport unittest\n"

    with tempfile.TemporaryDirectory() as tmpdir:
        # Three files with identical, unsorted imports.
        targets = [Path(tmpdir) / f"test_file_{index}.py" for index in range(3)]
        for target in targets:
            target.write_text(unsorted_source, encoding="utf8")

        sort_imports_in_place(targets)

        # Every file must now hold the sorted version.
        for target in targets:
            rewritten = target.read_text(encoding="utf8")
            assert rewritten == sorted_source


def test_sort_imports_without_formatting():
"""Test that imports are sorted when formatting is disabled and should_sort_imports is True."""
Expand Down
Loading