Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,52 @@
import os
import random

import pytest
import torch


def pytest_configure(config):
# register an additional marker (see pytest_collection_modifyitems)
config.addinivalue_line(
"markers", "needs_cuda: mark for tests that rely on a CUDA device"
)


def pytest_collection_modifyitems(items):
# This hook is called by pytest after it has collected the tests (google its
# name to check out its doc!). We can ignore some tests as we see fit here,
# or add marks, such as a skip mark.

out_items = []
for item in items:
# The needs_cuda mark will exist if the test was explicitly decorated
# with the @needs_cuda decorator. It will also exist if it was
# parametrized with a parameter that has the mark: for example if a test
# is parametrized with
# @pytest.mark.parametrize('device', cpu_and_cuda())
# the "instances" of the tests where device == 'cuda' will have the
# 'needs_cuda' mark, and the ones with device == 'cpu' won't have the
# mark.
needs_cuda = item.get_closest_marker("needs_cuda") is not None

if (
needs_cuda
and not torch.cuda.is_available()
and os.environ.get("FAIL_WITHOUT_CUDA") is None
):
# We skip CUDA tests on non-CUDA machines, but only if the
# FAIL_WITHOUT_CUDA env var wasn't set. If it's set, the test will
# typically fail with a "Unsupported device: cuda" error. This is
# normal and desirable: this env var is set on CI jobs that are
# supposed to run the CUDA tests, so if CUDA isn't available on
# those for whatever reason, we need to know.
item.add_marker(pytest.mark.skip(reason="CUDA not available."))

out_items.append(item)

items[:] = out_items


@pytest.fixture(autouse=True)
def prevent_leaking_rng():
# Prevent each test from leaking the rng to all other test when they call
Expand Down
Loading
Loading