Skip to content

Commit

Permalink
[LAPACK][cuSOLVER] Get cuSOLVER backend to build with SYCL2020 changes (
Browse files Browse the repository at this point in the history
  • Loading branch information
sknepper committed Oct 19, 2023
1 parent a566a71 commit 6c5f7ea
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 16 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ if(WIN32 AND ONEMKL_SYCL_IMPLEMENTATION STREQUAL "dpc++")
endif()

# Temporary disable sycl 2020 deprecations warnings for cuSOLVER and rocSOLVER
if(ONEMKL_SYCL_IMPLEMENTATION STREQUAL "dpc++" AND (ENABLE_CUSOLVER_BACKEND OR ENABLE_ROCSOLVER_BACKEND))
if(ONEMKL_SYCL_IMPLEMENTATION STREQUAL "dpc++" AND (ENABLE_ROCSOLVER_BACKEND))
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSYCL2020_DISABLE_DEPRECATION_WARNINGS")
endif()

Expand Down
4 changes: 2 additions & 2 deletions src/lapack/backends/cusolver/cusolver_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ inline void getri_batch(const char *func_name, Func func, sycl::queue &queue, st
CUresult cuda_result;
cublasHandle_t cublas_handle;
CUBLAS_ERROR_FUNC(cublasCreate, err, &cublas_handle);
CUstream cu_stream = sycl::get_native<sycl::backend::cuda>(queue);
CUstream cu_stream = sycl::get_native<sycl::backend::ext_oneapi_cuda>(queue);
CUBLAS_ERROR_FUNC(cublasSetStream, err, cublas_handle, cu_stream);

auto a_ = sc.get_mem<cuDataType *>(a_acc);
Expand Down Expand Up @@ -838,7 +838,7 @@ sycl::event getri_batch(const char *func_name, Func func, sycl::queue &queue, st
CUresult cuda_result;
cublasHandle_t cublas_handle;
CUBLAS_ERROR_FUNC(cublasCreate, err, &cublas_handle);
CUstream cu_stream = sycl::get_native<sycl::backend::cuda>(queue);
CUstream cu_stream = sycl::get_native<sycl::backend::ext_oneapi_cuda>(queue);
CUBLAS_ERROR_FUNC(cublasSetStream, err, cublas_handle, cu_stream);

CUdeviceptr a_dev;
Expand Down
23 changes: 23 additions & 0 deletions src/lapack/backends/cusolver/cusolver_lapack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2458,6 +2458,7 @@ inline void gebrd_scratchpad_size(const char *func_name, Func func, sycl::queue
CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, scratch_size);
});
});
queue.wait();
}

#define GEBRD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -2509,6 +2510,7 @@ inline void geqrf_scratchpad_size(const char *func_name, Func func, sycl::queue
CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, nullptr, lda, scratch_size);
});
});
queue.wait();
}

#define GEQRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -2540,6 +2542,7 @@ inline void gesvd_scratchpad_size(const char *func_name, Func func, sycl::queue
CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, scratch_size);
});
});
queue.wait();
}

#define GESVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -2571,6 +2574,7 @@ inline void getrf_scratchpad_size(const char *func_name, Func func, sycl::queue
CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, nullptr, lda, scratch_size);
});
});
queue.wait();
}

#define GETRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -2633,6 +2637,7 @@ inline void heevd_scratchpad_size(const char *func_name, Func func, sycl::queue
scratch_size);
});
});
queue.wait();
}

#define HEEVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -2665,6 +2670,7 @@ inline void hegvd_scratchpad_size(const char *func_name, Func func, sycl::queue
lda, nullptr, ldb, nullptr, scratch_size);
});
});
queue.wait();
}

#define HEGVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -2695,6 +2701,7 @@ inline void hetrd_scratchpad_size(const char *func_name, Func func, sycl::queue
nullptr, lda, nullptr, nullptr, nullptr, scratch_size);
});
});
queue.wait();
}

#define HETRD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -2735,6 +2742,7 @@ inline void orgbr_scratchpad_size(const char *func_name, Func func, sycl::queue
nullptr, lda, nullptr, scratch_size);
});
});
queue.wait();
}

#define ORGBR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -2765,6 +2773,7 @@ inline void orgtr_scratchpad_size(const char *func_name, Func func, sycl::queue
nullptr, lda, nullptr, scratch_size);
});
});
queue.wait();
}

#define ORGTR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -2794,6 +2803,7 @@ inline void orgqr_scratchpad_size(const char *func_name, Func func, sycl::queue
scratch_size);
});
});
queue.wait();
}

#define ORGQR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -2840,6 +2850,7 @@ inline void ormqr_scratchpad_size(const char *func_name, Func func, sycl::queue
nullptr, ldc, scratch_size);
});
});
queue.wait();
}

#define ORMQRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -2872,6 +2883,7 @@ inline void ormtr_scratchpad_size(const char *func_name, Func func, sycl::queue
nullptr, lda, nullptr, nullptr, ldc, scratch_size);
});
});
queue.wait();
}

#define ORMTR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -2903,6 +2915,7 @@ inline void potrf_scratchpad_size(const char *func_name, Func func, sycl::queue
nullptr, lda, scratch_size);
});
});
queue.wait();
}

#define POTRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -2950,6 +2963,7 @@ inline void potri_scratchpad_size(const char *func_name, Func func, sycl::queue
nullptr, lda, scratch_size);
});
});
queue.wait();
}

#define POTRI_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -2980,6 +2994,7 @@ inline void sytrf_scratchpad_size(const char *func_name, Func func, sycl::queue
CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, n, nullptr, lda, scratch_size);
});
});
queue.wait();
}

#define SYTRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -3012,6 +3027,7 @@ inline void syevd_scratchpad_size(const char *func_name, Func func, sycl::queue
scratch_size);
});
});
queue.wait();
}

#define SYEVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -3044,6 +3060,7 @@ inline void sygvd_scratchpad_size(const char *func_name, Func func, sycl::queue
lda, nullptr, ldb, nullptr, scratch_size);
});
});
queue.wait();
}

#define SYGVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -3074,6 +3091,7 @@ inline void sytrd_scratchpad_size(const char *func_name, Func func, sycl::queue
nullptr, lda, nullptr, nullptr, nullptr, scratch_size);
});
});
queue.wait();
}

#define SYTRD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -3134,6 +3152,7 @@ inline void ungbr_scratchpad_size(const char *func_name, Func func, sycl::queue
nullptr, lda, nullptr, scratch_size);
});
});
queue.wait();
}

#define UNGBR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -3164,6 +3183,7 @@ inline void ungqr_scratchpad_size(const char *func_name, Func func, sycl::queue
scratch_size);
});
});
queue.wait();
}

#define UNGQR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -3193,6 +3213,7 @@ inline void ungtr_scratchpad_size(const char *func_name, Func func, sycl::queue
nullptr, lda, nullptr, scratch_size);
});
});
queue.wait();
}

#define UNGTR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -3241,6 +3262,7 @@ inline void unmqr_scratchpad_size(const char *func_name, Func func, sycl::queue
nullptr, ldc, scratch_size);
});
});
queue.wait();
}

#define UNMQR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -3273,6 +3295,7 @@ inline void unmtr_scratchpad_size(const char *func_name, Func func, sycl::queue
nullptr, lda, nullptr, nullptr, ldc, scratch_size);
});
});
queue.wait();
}

#define UNMTR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \
Expand Down
15 changes: 8 additions & 7 deletions src/lapack/backends/cusolver/cusolver_scope_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ thread_local cusolver_handle<pi_context> CusolverScopedContextHandler::handle_he
cusolver_handle<pi_context>{};

CusolverScopedContextHandler::CusolverScopedContextHandler(sycl::queue queue,
sycl::interop_handler &ih)
sycl::interop_handle &ih)
: ih(ih),
needToRecover_(false) {
placedContext_ = queue.get_context();
placedContext_ = new sycl::context(queue.get_context());
auto device = queue.get_device();
auto desired = sycl::get_native<sycl::backend::cuda>(placedContext_);
auto desired = sycl::get_native<sycl::backend::ext_oneapi_cuda>(*placedContext_);
CUresult err;
CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original_);
if (original_ != desired) {
Expand All @@ -65,6 +65,7 @@ CusolverScopedContextHandler::~CusolverScopedContextHandler() noexcept(false) {
CUresult err;
CUDA_ERROR_FUNC(cuCtxSetCurrent, err, original_);
}
delete placedContext_;
}

void ContextCallback(void *userData) {
Expand All @@ -87,8 +88,8 @@ void ContextCallback(void *userData) {
}

cusolverDnHandle_t CusolverScopedContextHandler::get_handle(const sycl::queue &queue) {
auto piPlacedContext_ =
reinterpret_cast<pi_context>(sycl::get_native<sycl::backend::cuda>(placedContext_));
auto piPlacedContext_ = reinterpret_cast<pi_context>(
sycl::get_native<sycl::backend::ext_oneapi_cuda>(*placedContext_));
CUstream streamId = get_stream(queue);
cusolverStatus_t err;
auto it = handle_helper.cusolver_handle_mapper_.find(piPlacedContext_);
Expand Down Expand Up @@ -120,14 +121,14 @@ cusolverDnHandle_t CusolverScopedContextHandler::get_handle(const sycl::queue &q
auto insert_iter = handle_helper.cusolver_handle_mapper_.insert(
std::make_pair(piPlacedContext_, new std::atomic<cusolverDnHandle_t>(handle)));

sycl::detail::pi::contextSetExtendedDeleter(placedContext_, ContextCallback,
sycl::detail::pi::contextSetExtendedDeleter(*placedContext_, ContextCallback,
insert_iter.first->second);

return handle;
}

CUstream CusolverScopedContextHandler::get_stream(const sycl::queue &queue) {
return sycl::get_native<sycl::backend::cuda>(queue);
return sycl::get_native<sycl::backend::ext_oneapi_cuda>(queue);
}
sycl::context CusolverScopedContextHandler::get_context(const sycl::queue &queue) {
return queue.get_context();
Expand Down
16 changes: 11 additions & 5 deletions src/lapack/backends/cusolver/cusolver_scope_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
#else
#include <CL/sycl.hpp>
#endif
#if __has_include(<sycl/backend/cuda.hpp>)
#if __has_include(<sycl/context.hpp>)
#if __SYCL_COMPILER_VERSION <= 20220930
#include <sycl/backend/cuda.hpp>
#endif
#include <sycl/context.hpp>
#include <sycl/detail/pi.hpp>
#else
Expand Down Expand Up @@ -77,15 +79,15 @@ cuSolver handle to the SYCL context.

class CusolverScopedContextHandler {
CUcontext original_;
sycl::context placedContext_;
sycl::context *placedContext_;
bool needToRecover_;
sycl::interop_handler &ih;
sycl::interop_handle &ih;
static thread_local cusolver_handle<pi_context> handle_helper;
CUstream get_stream(const sycl::queue &queue);
sycl::context get_context(const sycl::queue &queue);

public:
CusolverScopedContextHandler(sycl::queue queue, sycl::interop_handler &ih);
CusolverScopedContextHandler(sycl::queue queue, sycl::interop_handle &ih);

~CusolverScopedContextHandler() noexcept(false);
/**
Expand All @@ -100,9 +102,13 @@ class CusolverScopedContextHandler {
// will be fixed when SYCL-2020 has been implemented for Pi backend.
template <typename T, typename U>
inline T get_mem(U acc) {
CUdeviceptr cudaPtr = ih.get_mem<sycl::backend::cuda>(acc);
CUdeviceptr cudaPtr = ih.get_native_mem<sycl::backend::ext_oneapi_cuda>(acc);
return reinterpret_cast<T>(cudaPtr);
}

void wait_stream(const sycl::queue &queue) {
cuStreamSynchronize(get_stream(queue));
}
};

} // namespace cusolver
Expand Down
3 changes: 2 additions & 1 deletion src/lapack/backends/cusolver/cusolver_task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ namespace cusolver {

template <typename H, typename F>
static inline void host_task_internal(H &cgh, sycl::queue queue, F f) {
cgh.interop_task([f, queue](sycl::interop_handler ih) {
cgh.host_task([f, queue](sycl::interop_handle ih) {
auto sc = CusolverScopedContextHandler(queue, ih);
f(sc);
sc.wait_stream(queue);
});
}

Expand Down

0 comments on commit 6c5f7ea

Please sign in to comment.