diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 3d5b587c6..f76f6c0ad 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -4,6 +4,7 @@ import shlex import subprocess from typing import TYPE_CHECKING, Optional + import isort from codeflash.cli_cmds.console import console, logger @@ -11,12 +12,14 @@ if TYPE_CHECKING: from pathlib import Path + def get_nth_line(text: str, n: int) -> str | None: for i, line in enumerate(text.splitlines(), start=1): if i == n: return line return None + def get_diff_output(cmd: list[str]) -> Optional[str]: try: result = subprocess.run(cmd, capture_output=True, text=True, check=True) @@ -27,7 +30,7 @@ def get_diff_output(cmd: list[str]) -> Optional[str]: is_ruff = cmd[0] == "ruff" if e.returncode == 0 and is_ruff: return "" - elif e.returncode == 1 and is_ruff: + if e.returncode == 1 and is_ruff: return e.stdout.strip() or None return None @@ -35,25 +38,30 @@ def get_diff_output(cmd: list[str]) -> Optional[str]: def get_diff_lines_output_by_black(filepath: str) -> Optional[str]: try: import black # type: ignore - return get_diff_output(['black', '--diff', filepath]) + + return get_diff_output(["black", "--diff", filepath]) except ImportError: return None + def get_diff_lines_output_by_ruff(filepath: str) -> Optional[str]: try: import ruff # type: ignore - return get_diff_output(['ruff', 'format', '--diff', filepath]) + + return get_diff_output(["ruff", "format", "--diff", filepath]) except ImportError: print("can't import ruff") return None def get_diff_lines_count(diff_output: str) -> int: - lines = diff_output.split('\n') - def is_diff_line(line: str) -> bool: - return line.startswith(('+', '-')) and not line.startswith(('+++', '---')) - diff_lines = [line for line in lines if is_diff_line(line)] - return len(diff_lines) + # Count lines that are diff changes (start with '+' or '-', but not '+++' or '---') + count = 0 + for line in diff_output.split("\n"): + if line and (line[0] in ("+", "-")) and not (line[:3] == "+++" or line[:3] == "---"): + count += 1 + return count + def is_safe_to_format(filepath: str, max_diff_lines: int = 100) -> bool: diff_changes_stdout = None @@ -61,20 +69,19 @@ def is_safe_to_format(filepath: str, max_diff_lines: int = 100) -> bool: diff_changes_stdout = get_diff_lines_output_by_black(filepath) if diff_changes_stdout is None: - logger.warning(f"black formatter not found, trying ruff instead...") + logger.warning("black formatter not found, trying ruff instead...") diff_changes_stdout = get_diff_lines_output_by_ruff(filepath) if diff_changes_stdout is None: - logger.warning(f"Both ruff, black formatters not found, skipping formatting diff check.") + logger.warning("Both ruff, black formatters not found, skipping formatting diff check.") return False - + diff_lines_count = get_diff_lines_count(diff_changes_stdout) - + if diff_lines_count > max_diff_lines: logger.debug(f"Skipping {filepath}: {diff_lines_count} lines would change (max: {max_diff_lines})") return False - else: - return True - + return True + def format_code(formatter_cmds: list[str], path: Path, print_status: bool = True) -> str: # noqa # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution