From 06e9969fe21b3f9a4926936eda970098385f4fbd Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Tue, 1 Mar 2022 12:43:16 +0100 Subject: [PATCH] Restore CUDA_VISIBLE_DEVICES after test Signed-off-by: Enrico Minack --- test/single/test_ray.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/test/single/test_ray.py b/test/single/test_ray.py index baa98ea34a..77a0220084 100644 --- a/test/single/test_ray.py +++ b/test/single/test_ray.py @@ -5,7 +5,6 @@ import os import socket import sys -import time import pytest import ray @@ -39,19 +38,27 @@ def ray_start_4_cpus(): @pytest.fixture def ray_start_6_cpus(): address_info = ray.init(num_cpus=6) - yield address_info - # The code after the yield will run as teardown code. - ray.shutdown() + try: + yield address_info + finally: + # The code after the yield will run as teardown code. + ray.shutdown() @pytest.fixture def ray_start_4_cpus_4_gpus(): + orig_devices = os.environ.get("CUDA_VISIBLE_DEVICES") os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" address_info = ray.init(num_cpus=4, num_gpus=4) - yield address_info - # The code after the yield will run as teardown code. - ray.shutdown() - del os.environ["CUDA_VISIBLE_DEVICES"] + try: + yield address_info + # The code after the yield will run as teardown code. + ray.shutdown() + finally: + if orig_devices: + os.environ["CUDA_VISIBLE_DEVICES"] = orig_devices + else: + del os.environ["CUDA_VISIBLE_DEVICES"] @pytest.fixture