Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 22 additions & 15 deletions codeflash/code_utils/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,22 @@
import shlex
import subprocess
from typing import TYPE_CHECKING, Optional

import isort

from codeflash.cli_cmds.console import console, logger

if TYPE_CHECKING:
from pathlib import Path


def get_nth_line(text: str, n: int) -> str | None:
for i, line in enumerate(text.splitlines(), start=1):
if i == n:
return line
return None


def get_diff_output(cmd: list[str]) -> Optional[str]:
try:
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
Expand All @@ -27,54 +30,58 @@ def get_diff_output(cmd: list[str]) -> Optional[str]:
is_ruff = cmd[0] == "ruff"
if e.returncode == 0 and is_ruff:
return ""
elif e.returncode == 1 and is_ruff:
if e.returncode == 1 and is_ruff:
return e.stdout.strip() or None
return None


def get_diff_lines_output_by_black(filepath: str) -> Optional[str]:
try:
import black # type: ignore
return get_diff_output(['black', '--diff', filepath])

return get_diff_output(["black", "--diff", filepath])
except ImportError:
return None


def get_diff_lines_output_by_ruff(filepath: str) -> Optional[str]:
try:
import ruff # type: ignore
return get_diff_output(['ruff', 'format', '--diff', filepath])

return get_diff_output(["ruff", "format", "--diff", filepath])
except ImportError:
print("can't import ruff")
return None


def get_diff_lines_count(diff_output: str) -> int:
lines = diff_output.split('\n')
def is_diff_line(line: str) -> bool:
return line.startswith(('+', '-')) and not line.startswith(('+++', '---'))
diff_lines = [line for line in lines if is_diff_line(line)]
return len(diff_lines)
# Count lines that are diff changes (start with '+' or '-', but not '+++' or '---')
count = 0
for line in diff_output.split("\n"):
if line and (line[0] in ("+", "-")) and not (line[:3] == "+++" or line[:3] == "---"):
count += 1
return count


def is_safe_to_format(filepath: str, max_diff_lines: int = 100) -> bool:
diff_changes_stdout = None

diff_changes_stdout = get_diff_lines_output_by_black(filepath)

if diff_changes_stdout is None:
logger.warning(f"black formatter not found, trying ruff instead...")
logger.warning("black formatter not found, trying ruff instead...")
diff_changes_stdout = get_diff_lines_output_by_ruff(filepath)
if diff_changes_stdout is None:
logger.warning(f"Both ruff, black formatters not found, skipping formatting diff check.")
logger.warning("Both ruff, black formatters not found, skipping formatting diff check.")
return False

diff_lines_count = get_diff_lines_count(diff_changes_stdout)

if diff_lines_count > max_diff_lines:
logger.debug(f"Skipping {filepath}: {diff_lines_count} lines would change (max: {max_diff_lines})")
return False
else:
return True

return True


def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True) -> str: # noqa
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
Expand Down
Loading