diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 4c167fb69..5fc9bd9e9 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -15,6 +15,8 @@ from codeflash.cli_cmds.console import logger from codeflash.code_utils.config_parser import find_pyproject_toml +ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE) + @contextmanager def custom_addopts() -> None: diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index b76e63a91..1fd86acce 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -14,9 +14,16 @@ import pytest from pydantic.dataclasses import dataclass +from rich.panel import Panel +from rich.text import Text from codeflash.cli_cmds.console import console, logger, test_files_progress_bar -from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file, module_name_from_file_path +from codeflash.code_utils.code_utils import ( + ImportErrorPattern, + custom_addopts, + get_run_tmp_file, + module_name_from_file_path, +) from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType @@ -180,6 +187,10 @@ def discover_tests_pytest( logger.warning( f"Failed to collect tests. Pytest Exit code: {exitcode}={pytest.ExitCode(exitcode).name}\n {error_section}" ) + if "ModuleNotFoundError" in result.stdout: + match = ImportErrorPattern.search(result.stdout).group() + panel = Panel(Text.from_markup(f"⚠️ {match} ", style="bold red"), expand=False) + console.print(panel) elif 0 <= exitcode <= 5: logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={pytest.ExitCode(exitcode).name}") diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index fe4357839..12aeff3fa 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -23,6 +23,7 @@ from codeflash.code_utils import env_utils from codeflash.code_utils.code_replacer import replace_function_definitions_in_module from codeflash.code_utils.code_utils import ( + ImportErrorPattern, cleanup_paths, file_name_from_test_module_name, get_run_tmp_file, @@ -1192,6 +1193,12 @@ def run_and_parse_tests( f"stdout: {run_result.stdout}\n" f"stderr: {run_result.stderr}\n" ) + if "ModuleNotFoundError" in run_result.stdout: + from rich.text import Text + + match = ImportErrorPattern.search(run_result.stdout).group() + panel = Panel(Text.from_markup(f"⚠️ {match} ", style="bold red"), expand=False) + console.print(panel) if testing_type in {TestingMode.BEHAVIOR, TestingMode.PERFORMANCE}: results, coverage_results = parse_test_results( test_xml_path=result_file_path, diff --git a/tests/test_test_runner.py b/tests/test_test_runner.py index 2c67a644c..0e80b76e0 100644 --- a/tests/test_test_runner.py +++ b/tests/test_test_runner.py @@ -1,7 +1,10 @@ +import re + import os import tempfile from pathlib import Path +from codeflash.code_utils.code_utils import ImportErrorPattern from codeflash.models.models import TestFile, TestFiles, TestType from codeflash.verification.parse_test_output import parse_test_xml from codeflash.verification.test_runner import run_behavioral_tests @@ -96,3 +99,51 @@ def test_sort(): ) assert results[0].did_pass, "Test did not pass as expected" result_file.unlink(missing_ok=True) + + code = """import torch +def sorter(arr): + print(torch.ones(1)) + arr.sort() + return arr + +def test_sort(): + arr = [5, 4, 3, 2, 1, 0] + output = sorter(arr) + assert output == [0, 1, 2, 3, 4, 5] + """ + cur_dir_path = Path(__file__).resolve().parent + config = TestConfig( + tests_root=cur_dir_path, + project_root_path=cur_dir_path, + test_framework="pytest", + tests_project_rootdir=cur_dir_path.parent, + ) + + test_env = os.environ.copy() + test_env["CODEFLASH_TEST_ITERATION"] = "0" + test_env["CODEFLASH_TRACER_DISABLE"] = "1" + if "PYTHONPATH" not in test_env: + test_env["PYTHONPATH"] = str(config.project_root_path) + else: + test_env["PYTHONPATH"] += os.pathsep + str(config.project_root_path) + + with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp: + test_files = TestFiles( + test_files=[TestFile(instrumented_behavior_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)] + ) + fp.write(code.encode("utf-8")) + fp.flush() + result_file, process, _, _ = run_behavioral_tests( + test_files, + test_framework=config.test_framework, + cwd=Path(config.project_root_path), + test_env=test_env, + pytest_timeout=1, + pytest_target_runtime_seconds=1, + ) + results = parse_test_xml( + test_xml_file_path=result_file, test_files=test_files, test_config=config, run_result=process + ) + match = ImportErrorPattern.search(process.stdout).group() + assert match=="ModuleNotFoundError: No module named 'torch'" + result_file.unlink(missing_ok=True)