@@ -288,6 +288,9 @@ class test_mem_alloc_device : public test_switch_cuda_device {
288288 void init () override {
289289 test_switch_cuda_device::init ();
290290
291+ int current_device;
292+ ASSERT_EQ (cudaGetDevice (¤t_device), cudaSuccess);
293+
291294 for (auto device = 0 ; device < m_num_devices; ++device) {
292295 uct_md_mem_attr_t attr = {
293296 .field_mask = UCT_MD_MEM_ATTR_FIELD_SYS_DEV
@@ -305,6 +308,8 @@ class test_mem_alloc_device : public test_switch_cuda_device {
305308 m_sys_dev.push_back (attr.sys_dev );
306309 }
307310
311+ EXPECT_EQ (cudaSetDevice (current_device), cudaSuccess);
312+
308313 ASSERT_EQ (m_num_devices, m_sys_dev.size ());
309314 }
310315
@@ -349,6 +354,9 @@ class test_mem_alloc_device : public test_switch_cuda_device {
349354 {
350355 CUdevice current;
351356
357+ int current_device;
358+ ASSERT_EQ (cudaGetDevice (¤t_device), cudaSuccess);
359+
352360 // Ensure a valid context for each device
353361 for (auto device = 0 ; device < m_num_devices; ++device) {
354362 ASSERT_EQ (cudaSetDevice (device), cudaSuccess);
@@ -365,11 +373,16 @@ class test_mem_alloc_device : public test_switch_cuda_device {
365373 EXPECT_EQ (m_num_devices - 1 , current);
366374 ASSERT_UCS_OK (uct_mem_free (&mem));
367375 }
376+
377+ EXPECT_EQ (cudaSetDevice (current_device), cudaSuccess);
368378 }
369379
370380 void test_same_device_alloc (ucs_memory_type_t mem_type,
371381 bool set_sys_dev = true )
372382 {
383+ int current_device;
384+ ASSERT_EQ (cudaGetDevice (¤t_device), cudaSuccess);
385+
373386 for (auto device = 0 ; device < m_num_devices; ++device) {
374387 ASSERT_EQ (cudaSetDevice (device), cudaSuccess);
375388 ASSERT_UCS_OK (allocate (mem_type,
@@ -379,6 +392,8 @@ class test_mem_alloc_device : public test_switch_cuda_device {
379392 EXPECT_EQ (m_sys_dev[device], sys_device);
380393 ASSERT_UCS_OK (uct_mem_free (&mem));
381394 }
395+
396+ EXPECT_EQ (cudaSetDevice (current_device), cudaSuccess);
382397 }
383398
384399 void skip_if_no_fabric (ucs_memory_type_t mem_type)
@@ -526,7 +541,7 @@ class test_p2p_send_on_diff_device : public uct_p2p_test {
526541 rkey_release (sender (), rkey_dest);
527542
528543 // Set back an original device to avoid test failures
529- ASSERT_EQ ( cudaGetDevice (& current_device), cudaSuccess);
544+ EXPECT_EQ ( cudaSetDevice ( current_device), cudaSuccess);
530545 }
531546
532547 template <typename mem_alloc_t1, typename mem_alloc_t2>
0 commit comments