diff --git a/codeflash/code_utils/env_utils.py b/codeflash/code_utils/env_utils.py index d74c03135..adca67df8 100644 --- a/codeflash/code_utils/env_utils.py +++ b/codeflash/code_utils/env_utils.py @@ -22,13 +22,12 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = f.flush() tmp_file = Path(f.name) try: - format_code(formatter_cmds, tmp_file, print_status=False) + format_code(formatter_cmds, tmp_file, print_status=False, exit_on_failure=exit_on_failure) except Exception: exit_with_message( "⚠️ Codeflash requires a code formatter to be installed in your environment, but none was found. Please install a supported formatter, verify the formatter-cmds in your codeflash pyproject.toml config and try again.", error_on_exit=True, ) - return return_code diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index b1cb58540..f4f4c563e 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -43,6 +43,7 @@ def apply_formatter_cmds( path: Path, test_dir_str: Optional[str], print_status: bool, # noqa + exit_on_failure: bool = True, # noqa ) -> tuple[Path, str]: # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution formatter_name = cmds[0].lower() @@ -84,8 +85,8 @@ def apply_formatter_cmds( expand=False, ) console.print(panel) - - raise e from None + if exit_on_failure: + raise e from None return file_path, file_path.read_text(encoding="utf8") @@ -106,6 +107,7 @@ def format_code( optimized_function: str = "", check_diff: bool = False, # noqa print_status: bool = True, # noqa + exit_on_failure: bool = True, # noqa ) -> str: with tempfile.TemporaryDirectory() as test_dir_str: if isinstance(path, str): @@ -138,7 +140,9 @@ def format_code( ) return original_code # TODO : We can avoid formatting the whole file again and only formatting the optimized code standalone and replace in formatted file above. - _, formatted_code = apply_formatter_cmds(formatter_cmds, path, test_dir_str=None, print_status=print_status) + _, formatted_code = apply_formatter_cmds( + formatter_cmds, path, test_dir_str=None, print_status=print_status, exit_on_failure=exit_on_failure + ) logger.debug(f"Formatted {path} with commands: {formatter_cmds}") return formatted_code