Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
c75bbf6
check large diffs with black, and skipp formatting in such case (afte…
mohammedahmed18 Jun 3, 2025
5cd13ad
new line
mohammedahmed18 Jun 3, 2025
1522227
better log messages
mohammedahmed18 Jun 3, 2025
d3ca1cb
remove unnecessary check
mohammedahmed18 Jun 3, 2025
dcb084a
new line
mohammedahmed18 Jun 3, 2025
689a2d9
remove unused comment
mohammedahmed18 Jun 3, 2025
44c0f85
the max lines for formatting changes to 100
mohammedahmed18 Jun 3, 2025
73ef518
refactoring
mohammedahmed18 Jun 3, 2025
a5343fd
refactoring and improvements
mohammedahmed18 Jun 3, 2025
395855d
added black as dev dependency
mohammedahmed18 Jun 3, 2025
822d6cc
made some refactor changes that codeflash suggested
mohammedahmed18 Jun 3, 2025
ce15022
remove unused function
mohammedahmed18 Jun 3, 2025
d2a8711
formatting & using internal black dep
mohammedahmed18 Jun 3, 2025
f46b368
fix black import issue
mohammedahmed18 Jun 4, 2025
6504cc4
handle formatting files with no formatting issues
mohammedahmed18 Jun 4, 2025
aed490d
Merge branch 'main' into skip-formatting-for-large-diffs
Saga4 Jun 4, 2025
82a4ee1
use user pre-defined formatting commands, instead of using black
mohammedahmed18 Jun 4, 2025
90014bd
Merge branch 'skip-formatting-for-large-diffs' of github.com:codeflas…
mohammedahmed18 Jun 4, 2025
caeda49
make sure format_code recieves file path as path type not as str
mohammedahmed18 Jun 4, 2025
6967fcb
formatting and linting
mohammedahmed18 Jun 4, 2025
8248c8e
typo
mohammedahmed18 Jun 4, 2025
15aacdb
revert lock file changes
mohammedahmed18 Jun 4, 2025
c24fc90
remove comment
mohammedahmed18 Jun 4, 2025
b48e9e6
pass helper functions source code to the formatter for diff checking
mohammedahmed18 Jun 5, 2025
93070a9
Merge branch 'main' of github.com:codeflash-ai/codeflash into skip-fo…
mohammedahmed18 Jun 6, 2025
64f2dd9
more unit tests
mohammedahmed18 Jun 6, 2025
a1510a3
enhancements
mohammedahmed18 Jun 6, 2025
6f97004
Merge branch 'main' into skip-formatting-for-large-diffs
Saga4 Jun 10, 2025
6cb8469
Update formatter.py
Saga4 Jun 10, 2025
94e64d3
Update formatter.py
Saga4 Jun 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 108 additions & 12 deletions codeflash/code_utils/formatter.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,78 @@
from __future__ import annotations

import difflib
import os
import re
import shlex
import shutil
import subprocess
from typing import TYPE_CHECKING
import tempfile
from pathlib import Path
from typing import Optional, Union

import isort

from codeflash.cli_cmds.console import console, logger

if TYPE_CHECKING:
from pathlib import Path

def generate_unified_diff(original: str, modified: str, from_file: str, to_file: str) -> str:
line_pattern = re.compile(r"(.*?(?:\r\n|\n|\r|$))")

def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True) -> str: # noqa
def split_lines(text: str) -> list[str]:
lines = [match[0] for match in line_pattern.finditer(text)]
if lines and lines[-1] == "":
lines.pop()
return lines

original_lines = split_lines(original)
modified_lines = split_lines(modified)

diff_output = []
for line in difflib.unified_diff(original_lines, modified_lines, fromfile=from_file, tofile=to_file, n=5):
if line.endswith("\n"):
diff_output.append(line)
else:
diff_output.append(line + "\n")
diff_output.append("\\ No newline at end of file\n")

return "".join(diff_output)


def apply_formatter_cmds(
cmds: list[str],
path: Path,
test_dir_str: Optional[str],
print_status: bool, # noqa
) -> tuple[Path, str]:
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
formatter_name = formatter_cmds[0].lower()
formatter_name = cmds[0].lower()
should_make_copy = False
file_path = path

if test_dir_str:
should_make_copy = True
file_path = Path(test_dir_str) / "temp.py"

if not cmds or formatter_name == "disabled":
return path, path.read_text(encoding="utf8")

if not path.exists():
msg = f"File {path} does not exist. Cannot format the file."
msg = f"File {path} does not exist. Cannot apply formatter commands."
raise FileNotFoundError(msg)
if formatter_name == "disabled":
return path.read_text(encoding="utf8")

if should_make_copy:
shutil.copy2(path, file_path)

file_token = "$file" # noqa: S105
for command in formatter_cmds:

for command in cmds:
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
formatter_cmd_list = [path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
formatter_cmd_list = [file_path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
try:
result = subprocess.run(formatter_cmd_list, capture_output=True, check=False)
if result.returncode == 0:
if print_status:
console.rule(f"Formatted Successfully with: {formatter_name.replace('$file', path.name)}")
console.rule(f"Formatted Successfully with: {command.replace('$file', path.name)}")
else:
logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}")
except FileNotFoundError as e:
Expand All @@ -44,7 +87,60 @@ def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True

raise e from None

return path.read_text(encoding="utf8")
return file_path, file_path.read_text(encoding="utf8")


def get_diff_lines_count(diff_output: str) -> int:
lines = diff_output.split("\n")

def is_diff_line(line: str) -> bool:
return line.startswith(("+", "-")) and not line.startswith(("+++", "---"))

diff_lines = [line for line in lines if is_diff_line(line)]
return len(diff_lines)


def format_code(
formatter_cmds: list[str],
path: Union[str, Path],
optimized_function: str = "",
check_diff: bool = False, # noqa
print_status: bool = True, # noqa
) -> str:
with tempfile.TemporaryDirectory() as test_dir_str:
if isinstance(path, str):
path = Path(path)

original_code = path.read_text(encoding="utf8")
original_code_lines = len(original_code.split("\n"))

if check_diff and original_code_lines > 50:
# we dont' count the formatting diff for the optimized function as it should be well-formatted
original_code_without_opfunc = original_code.replace(optimized_function, "")

original_temp = Path(test_dir_str) / "original_temp.py"
original_temp.write_text(original_code_without_opfunc, encoding="utf8")

formatted_temp, formatted_code = apply_formatter_cmds(
formatter_cmds, original_temp, test_dir_str, print_status=False
)

diff_output = generate_unified_diff(
original_code_without_opfunc, formatted_code, from_file=str(original_temp), to_file=str(formatted_temp)
)
diff_lines_count = get_diff_lines_count(diff_output)

max_diff_lines = min(int(original_code_lines * 0.3), 50)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we hardcoding this 30% or 50 lines logic?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The overall formatting other than the optimized function could be too small but annoying if its around formatting.
And when we are only looking at code other than optimzied function, there could be changes in helper functions or imports or global variables too, so how are we sure of this count of 50?


if diff_lines_count > max_diff_lines and max_diff_lines != -1:
logger.debug(
f"Skipping formatting {path}: {diff_lines_count} lines would change (max: {max_diff_lines})"
)
return original_code
# TODO : We can avoid formatting the whole file again and only formatting the optimized code standalone and replace in formatted file above.
_, formatted_code = apply_formatter_cmds(formatter_cmds, path, test_dir_str=None, print_status=print_status)
logger.debug(f"Formatted {path} with commands: {formatter_cmds}")
return formatted_code


def sort_imports(code: str) -> str:
Expand Down
18 changes: 12 additions & 6 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,10 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
)

new_code, new_helper_code = self.reformat_code_and_helpers(
code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code
code_context.helper_functions,
explanation.file_path,
self.function_to_optimize_source_code,
optimized_function=best_optimization.candidate.source_code,
)

existing_tests = existing_tests_source_for(
Expand Down Expand Up @@ -642,20 +645,23 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path,
f.write(helper_code)

def reformat_code_and_helpers(
self, helper_functions: list[FunctionSource], path: Path, original_code: str
self, helper_functions: list[FunctionSource], path: Path, original_code: str, optimized_function: str
) -> tuple[str, dict[Path, str]]:
should_sort_imports = not self.args.disable_imports_sorting
if should_sort_imports and isort.code(original_code) != original_code:
should_sort_imports = False

new_code = format_code(self.args.formatter_cmds, path)
new_code = format_code(self.args.formatter_cmds, path, optimized_function=optimized_function, check_diff=True)
if should_sort_imports:
new_code = sort_imports(new_code)

new_helper_code: dict[Path, str] = {}
helper_functions_paths = {hf.file_path for hf in helper_functions}
for module_abspath in helper_functions_paths:
formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath)
for hp in helper_functions:
module_abspath = hp.file_path
hp_source_code = hp.source_code
formatted_helper_code = format_code(
self.args.formatter_cmds, module_abspath, optimized_function=hp_source_code, check_diff=True
)
if should_sort_imports:
formatted_helper_code = sort_imports(formatted_helper_code)
new_helper_code[module_abspath] = formatted_helper_code
Expand Down
Loading
Loading