Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 28 additions & 21 deletions codeflash/discovery/pytest_new_process_discovery.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# ruff: noqa
import sys
from pathlib import Path
from typing import Any
import pickle


# This script should not have any relation to the codeflash package, be careful with imports
cwd = sys.argv[1]
Expand All @@ -11,44 +14,48 @@
sys.path.insert(1, str(cwd))


def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, str]]:
test_results = []
for test in pytest_tests:
test_class = None
if test.cls:
test_class = test.parent.name
test_results.append({"test_file": str(test.path), "test_class": test_class, "test_function": test.name})
return test_results


class PytestCollectionPlugin:
def pytest_collection_finish(self, session) -> None:
global pytest_rootdir
global pytest_rootdir, collected_tests

collected_tests.extend(session.items)
pytest_rootdir = session.config.rootdir

# Write results immediately since pytest.main() will exit after this callback, not always with a success code
tests = parse_pytest_collection_results(collected_tests)
exit_code = getattr(session.config, "exitstatus", 0)
with Path(pickle_path).open("wb") as f:
pickle.dump((exit_code, tests, pytest_rootdir), f, protocol=pickle.HIGHEST_PROTOCOL)

def pytest_collection_modifyitems(self, items) -> None:
skip_benchmark = pytest.mark.skip(reason="Skipping benchmark tests")
for item in items:
if "benchmark" in item.fixturenames:
item.add_marker(skip_benchmark)


def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, str]]:
test_results = []
for test in pytest_tests:
test_class = None
if test.cls:
test_class = test.parent.name
test_results.append({"test_file": str(test.path), "test_class": test_class, "test_function": test.name})
return test_results


if __name__ == "__main__":
from pathlib import Path

import pytest

try:
exitcode = pytest.main(
[tests_root, "-p no:logging", "--collect-only", "-m", "not skip", "-p", "no:codeflash-benchmark"],
pytest.main(
[tests_root, "-p", "no:logging", "--collect-only", "-m", "not skip", "-p", "no:codeflash-benchmark"],
plugins=[PytestCollectionPlugin()],
)
except Exception as e:
print(f"Failed to collect tests: {e!s}")
exitcode = -1
tests = parse_pytest_collection_results(collected_tests)
import pickle

with Path(pickle_path).open("wb") as f:
pickle.dump((exitcode, tests, pytest_rootdir), f, protocol=pickle.HIGHEST_PROTOCOL)
try:
with Path(pickle_path).open("wb") as f:
pickle.dump((-1, [], None), f, protocol=pickle.HIGHEST_PROTOCOL)
except Exception as pickle_error:
print(f"Failed to write failure pickle: {pickle_error!s}", file=sys.stderr)
Loading