diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 498a8078b..3ff487517 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -8,7 +8,7 @@ import subprocess import tempfile from pathlib import Path -from typing import Optional, Union +from typing import Any, Optional, Union import isort @@ -163,10 +163,10 @@ def format_code( return formatted_code -def sort_imports(code: str, *, float_to_top: bool = False) -> str: +def sort_imports(code: str, **kwargs: Any) -> str: # noqa : ANN401 try: # Deduplicate and sort imports, modify the code in memory, not on disk - sorted_code = isort.code(code=code, float_to_top=float_to_top) + sorted_code = isort.code(code, **kwargs) except Exception: # this will also catch the FileSkipComment exception, use this fn everywhere logger.exception("Failed to sort imports with isort.") return code # Fall back to original code if isort fails diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index a776c4d45..b16606357 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import TYPE_CHECKING -import isort import libcst as cst from codeflash.cli_cmds.console import logger @@ -741,7 +740,7 @@ def inject_async_profiling_into_existing_test( new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")])) tree.body = [*new_imports, *tree.body] - return True, isort.code(ast.unparse(tree), float_to_top=True) + return True, sort_imports(ast.unparse(tree), float_to_top=True) def inject_profiling_into_existing_test( @@ -789,7 +788,7 @@ def inject_profiling_into_existing_test( additional_functions = [create_wrapper_function(mode)] tree.body = [*new_imports, *additional_functions, *tree.body] - return True, isort.code(ast.unparse(tree), float_to_top=True) + return True, sort_imports(ast.unparse(tree), float_to_top=True) def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.FunctionDef: diff --git a/tests/test_instrumentation_run_results_aiservice.py b/tests/test_instrumentation_run_results_aiservice.py index 78d9973f1..cae2c76f1 100644 --- a/tests/test_instrumentation_run_results_aiservice.py +++ b/tests/test_instrumentation_run_results_aiservice.py @@ -8,6 +8,7 @@ import isort from code_to_optimize.bubble_sort_method import BubbleSorter from codeflash.code_utils.code_utils import get_run_tmp_file +from codeflash.code_utils.formatter import sort_imports from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import FunctionParent, TestFile, TestFiles, TestingMode, TestType, VerificationType from codeflash.optimization.optimizer import Optimizer @@ -115,7 +116,7 @@ def test_single_element_list(): ) """ ) - instrumented_behavior_test_source = isort.code( + instrumented_behavior_test_source = sort_imports( instrumented_behavior_test_source, config=isort.Config(float_to_top=True) ) @@ -257,7 +258,7 @@ def test_single_element_list(): ) """ ) - instrumented_behavior_test_source = isort.code( + instrumented_behavior_test_source = sort_imports( instrumented_behavior_test_source, config=isort.Config(float_to_top=True) )