Skip to content

Commit 935ad58

Browse files
committed
TEST/CUDA: Restore original cuda device
1 parent f3366d6 commit 935ad58

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

test/gtest/uct/cuda/test_switch_cuda_device.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,9 @@ class test_mem_alloc_device : public test_switch_cuda_device {
288288
void init() override {
289289
test_switch_cuda_device::init();
290290

291+
int current_device;
292+
ASSERT_EQ(cudaGetDevice(&current_device), cudaSuccess);
293+
291294
for (auto device = 0; device < m_num_devices; ++device) {
292295
uct_md_mem_attr_t attr = {
293296
.field_mask = UCT_MD_MEM_ATTR_FIELD_SYS_DEV
@@ -305,6 +308,8 @@ class test_mem_alloc_device : public test_switch_cuda_device {
305308
m_sys_dev.push_back(attr.sys_dev);
306309
}
307310

311+
EXPECT_EQ(cudaSetDevice(current_device), cudaSuccess);
312+
308313
ASSERT_EQ(m_num_devices, m_sys_dev.size());
309314
}
310315

@@ -349,6 +354,9 @@ class test_mem_alloc_device : public test_switch_cuda_device {
349354
{
350355
CUdevice current;
351356

357+
int current_device;
358+
ASSERT_EQ(cudaGetDevice(&current_device), cudaSuccess);
359+
352360
// Ensure a valid context for each device
353361
for (auto device = 0; device < m_num_devices; ++device) {
354362
ASSERT_EQ(cudaSetDevice(device), cudaSuccess);
@@ -365,11 +373,16 @@ class test_mem_alloc_device : public test_switch_cuda_device {
365373
EXPECT_EQ(m_num_devices - 1, current);
366374
ASSERT_UCS_OK(uct_mem_free(&mem));
367375
}
376+
377+
EXPECT_EQ(cudaSetDevice(current_device), cudaSuccess);
368378
}
369379

370380
void test_same_device_alloc(ucs_memory_type_t mem_type,
371381
bool set_sys_dev = true)
372382
{
383+
int current_device;
384+
ASSERT_EQ(cudaGetDevice(&current_device), cudaSuccess);
385+
373386
for (auto device = 0; device < m_num_devices; ++device) {
374387
ASSERT_EQ(cudaSetDevice(device), cudaSuccess);
375388
ASSERT_UCS_OK(allocate(mem_type,
@@ -379,6 +392,8 @@ class test_mem_alloc_device : public test_switch_cuda_device {
379392
EXPECT_EQ(m_sys_dev[device], sys_device);
380393
ASSERT_UCS_OK(uct_mem_free(&mem));
381394
}
395+
396+
EXPECT_EQ(cudaSetDevice(current_device), cudaSuccess);
382397
}
383398

384399
void skip_if_no_fabric(ucs_memory_type_t mem_type)
@@ -526,7 +541,7 @@ class test_p2p_send_on_diff_device : public uct_p2p_test {
526541
rkey_release(sender(), rkey_dest);
527542

528543
// Set back an original device to avoid test failures
529-
ASSERT_EQ(cudaGetDevice(&current_device), cudaSuccess);
544+
EXPECT_EQ(cudaSetDevice(current_device), cudaSuccess);
530545
}
531546

532547
template<typename mem_alloc_t1, typename mem_alloc_t2>

0 commit comments

Comments
 (0)