Skip to content

Commit

Permalink
Refactor using StreamsAndDevices
Browse files Browse the repository at this point in the history
  • Loading branch information
masterleinad committed Feb 15, 2024
1 parent 49a7ea6 commit f8ff523
Showing 1 changed file with 38 additions and 37 deletions.
75 changes: 38 additions & 37 deletions core/unit_test/cuda/TestCuda_InterOp_StreamsMultiGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,47 @@

#include <TestCuda_Category.hpp>
#include <Test_InterOp_Streams.hpp>
#include <memory>

namespace {

std::array<TEST_EXECSPACE, 2> get_execution_spaces(int n_devices) {
KOKKOS_IMPL_CUDA_SAFE_CALL(cudaSetDevice(0));
cudaStream_t stream0;
KOKKOS_IMPL_CUDA_SAFE_CALL(cudaStreamCreate(&stream0));

KOKKOS_IMPL_CUDA_SAFE_CALL(cudaSetDevice(n_devices - 1));
cudaStream_t stream;
KOKKOS_IMPL_CUDA_SAFE_CALL(cudaStreamCreate(&stream));
// Creates one CUDA stream on each of two devices: device 0 and device
// `n_devices - 1`, where `n_devices` comes from cudaGetDeviceCount (both
// entries name the same device when only one device is reported).
//
// The struct is copyable. Copies share the streams through `streams`, and the
// streams are destroyed when the last copy releases the shared_ptr.
struct StreamsAndDevices {
  // Shared handle to the two streams; (*streams)[i] is created on devices[i].
  std::shared_ptr<std::array<cudaStream_t, 2>> streams;
  // Device ids the two streams are created on.
  std::array<int, 2> devices;

  StreamsAndDevices() {
    int n_devices;
    KOKKOS_IMPL_CUDA_SAFE_CALL(cudaGetDeviceCount(&n_devices));

    devices = {0, n_devices - 1};

    // The deleter captures the device ids by value instead of `this`: the
    // shared_ptr is shared with every copy of this object, so the deleter may
    // run after the object that created it has been destroyed, and a captured
    // `this` would dangle at that point.
    streams = std::shared_ptr<std::array<cudaStream_t, 2>>(
        new std::array<cudaStream_t, 2>{},
        [stream_devices = devices](std::array<cudaStream_t, 2> *s) {
          for (int i = 0; i < 2; ++i) {
            // Make the owning device current before destroying its stream.
            KOKKOS_IMPL_CUDA_SAFE_CALL(cudaSetDevice(stream_devices[i]));
            KOKKOS_IMPL_CUDA_SAFE_CALL(cudaStreamDestroy((*s)[i]));
          }
          // A custom deleter replaces the default one, so the heap-allocated
          // array has to be released here as well.
          delete s;
        });
    for (int i = 0; i < 2; ++i) {
      // Streams are created on whichever device is current.
      KOKKOS_IMPL_CUDA_SAFE_CALL(cudaSetDevice(devices[i]));
      KOKKOS_IMPL_CUDA_SAFE_CALL(cudaStreamCreate(&((*streams)[i])));
    }
  }
};

TEST_EXECSPACE exec0(stream0);
TEST_EXECSPACE exec(stream);
// Builds two TEST_EXECSPACE instances, one from each stream held by
// `streams_and_devices`, checks with ASSERT_EQ that each instance reports the
// expected device id via cuda_device(), and returns the pair as an array.
//
// `streams_and_devices` is taken by value, so the call operates on a copy of
// the caller's object.
std::array<TEST_EXECSPACE, 2> get_execution_spaces(
const StreamsAndDevices streams_and_devices) {
TEST_EXECSPACE exec0((*streams_and_devices.streams)[0]);
TEST_EXECSPACE exec1((*streams_and_devices.streams)[1]);

// ASSERT_EQ may only be used in a function returning void, so the checks are
// wrapped in an immediately-invoked lambda (this function returns an array).
[&]() {
// NOTE(review): `exec` and `n_devices` are not declared in this function as
// shown; the next two lines look like the pre-refactor form of the two checks
// that follow them -- confirm against the applied revision.
ASSERT_EQ(exec0.cuda_device(), 0);
ASSERT_EQ(exec.cuda_device(), n_devices - 1);
ASSERT_EQ(exec0.cuda_device(), streams_and_devices.devices[0]);
ASSERT_EQ(exec1.cuda_device(), streams_and_devices.devices[1]);
}();

// NOTE(review): two consecutive return statements; the first names `exec`,
// which is not declared here -- presumably the pre-refactor line. Confirm.
return {exec0, exec};
return {exec0, exec1};
}

// Test Interoperability with Cuda Streams
Expand Down Expand Up @@ -114,36 +133,25 @@ void test_policies(TEST_EXECSPACE exec0, Kokkos::View<int *, TEST_EXECSPACE> v0,
}

// Runs test_policies on two Views of 100 ints, each allocated with
// Kokkos::view_alloc on one of the two execution space instances returned by
// get_execution_spaces.
//
// NOTE(review): as shown, this body contains both manual stream bookkeeping
// (stream0/stream, n_devices, explicit cudaStreamDestroy) and a
// StreamsAndDevices object, plus two declarations of `execs` -- this looks
// like pre- and post-refactor lines side by side; confirm which set applies.
TEST(cuda_multi_gpu, managed_views) {
cudaStream_t stream0;
cudaStream_t stream;
int n_devices;
KOKKOS_IMPL_CUDA_SAFE_CALL(cudaGetDeviceCount(&n_devices));
StreamsAndDevices streams_and_devices;
// Inner scope: presumably so the execution space instances and Views are
// destroyed before the streams they use are released -- TODO confirm.
{
std::array<TEST_EXECSPACE, 2> execs = get_execution_spaces(n_devices);
std::array<TEST_EXECSPACE, 2> execs =
get_execution_spaces(streams_and_devices);

// One View per execution space instance, labelled "v0" and "v".
Kokkos::View<int *, TEST_EXECSPACE> view0(
Kokkos::view_alloc("v0", execs[0]), 100);
Kokkos::View<int *, TEST_EXECSPACE> view(Kokkos::view_alloc("v", execs[1]),
100);

test_policies(execs[0], view0, execs[1], view);
// Keep the raw stream handles so they can be destroyed after the scope ends.
stream0 = execs[0].cuda_stream();
stream = execs[1].cuda_stream();
}
// Manual teardown: select each device, then destroy the stream taken from the
// corresponding execution space instance.
KOKKOS_IMPL_CUDA_SAFE_CALL(cudaSetDevice(0));
KOKKOS_IMPL_CUDA_SAFE_CALL(cudaStreamDestroy(stream0));

KOKKOS_IMPL_CUDA_SAFE_CALL(cudaSetDevice(n_devices - 1));
KOKKOS_IMPL_CUDA_SAFE_CALL(cudaStreamDestroy(stream));
}

TEST(cuda_multi_gpu, unmanaged_views) {
cudaStream_t stream0;
cudaStream_t stream;
int n_devices;
KOKKOS_IMPL_CUDA_SAFE_CALL(cudaGetDeviceCount(&n_devices));
StreamsAndDevices streams_and_devices;
{
std::array<TEST_EXECSPACE, 2> execs = get_execution_spaces(n_devices);
std::array<TEST_EXECSPACE, 2> execs =
get_execution_spaces(streams_and_devices);

KOKKOS_IMPL_CUDA_SAFE_CALL(cudaSetDevice(execs[0].cuda_device()));
int *p0;
Expand All @@ -160,13 +168,6 @@ TEST(cuda_multi_gpu, unmanaged_views) {
test_policies(execs[0], view0, execs[1], view);
KOKKOS_IMPL_CUDA_SAFE_CALL(cudaFree(p0));
KOKKOS_IMPL_CUDA_SAFE_CALL(cudaFree(p));
stream0 = execs[0].cuda_stream();
stream = execs[1].cuda_stream();
}
KOKKOS_IMPL_CUDA_SAFE_CALL(cudaSetDevice(0));
KOKKOS_IMPL_CUDA_SAFE_CALL(cudaStreamDestroy(stream0));

KOKKOS_IMPL_CUDA_SAFE_CALL(cudaSetDevice(n_devices - 1));
KOKKOS_IMPL_CUDA_SAFE_CALL(cudaStreamDestroy(stream));
}
} // namespace

0 comments on commit f8ff523

Please sign in to comment.