From 1122cc0e8308977145a0fd9b2c571f1cc9448017 Mon Sep 17 00:00:00 2001
From: Seunghoon Lee
Date: Mon, 25 Mar 2024 11:30:23 +0900
Subject: [PATCH] Implement cublasSdot.

---
Note: `sdot` in this patch differs from upstream commit 1122cc0e: it returns
the f32 dot product expected by `cublasSdot`, passes valid handle/result
pointers to rocBLAS, and destroys the temporary handle on every path.

 zluda_blas/src/cublas.rs | 18 +++++++++--
 zluda_blas/src/lib.rs    | 68 +++++++++++++++++++++++++++++++++++++---
 2 files changed, 80 insertions(+), 6 deletions(-)

diff --git a/zluda_blas/src/cublas.rs b/zluda_blas/src/cublas.rs
index 39410602..38fb1650 100644
--- a/zluda_blas/src/cublas.rs
+++ b/zluda_blas/src/cublas.rs
@@ -802,7 +802,15 @@ pub unsafe extern "system" fn cublasSdot_v2(
     incy: ::std::os::raw::c_int,
     result: *mut f32,
 ) -> cublasStatus_t {
-    crate::unsupported()
+    crate::sdot_v2(
+        handle,
+        n,
+        x,
+        incx,
+        y,
+        incy,
+        result,
+    )
 }
 
 #[no_mangle]
@@ -4920,7 +4928,13 @@ pub unsafe extern "system" fn cublasSdot(
     y: *const f32,
     incy: ::std::os::raw::c_int,
 ) -> f32 {
-    unimplemented!()
+    crate::sdot(
+        n,
+        x,
+        incx,
+        y,
+        incy,
+    )
 }
 
 #[no_mangle]
diff --git a/zluda_blas/src/lib.rs b/zluda_blas/src/lib.rs
index 7644340f..67b51aa7 100644
--- a/zluda_blas/src/lib.rs
+++ b/zluda_blas/src/lib.rs
@@ -209,13 +209,13 @@ unsafe fn sgemm(
     c: *mut f32,
     ldc: i32,
 ) -> cublasStatus_t {
-    let transa = op_from_cuda(cublasOperation_t(transa as _));
-    let transb = op_from_cuda(cublasOperation_t(transb as _));
     let mut handle = mem::zeroed();
     let mut status = to_cuda(rocblas_create_handle(handle));
     if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
         return status;
     }
+    let transa = op_from_cuda(cublasOperation_t(transa as _));
+    let transb = op_from_cuda(cublasOperation_t(transb as _));
     status = to_cuda(rocblas_sgemm(
         handle.cast(),
         transa,
@@ -279,6 +279,44 @@ unsafe fn init() -> cublasStatus_t {
     cublasStatus_t::CUBLAS_STATUS_SUCCESS
 }
 
+/// Legacy (handle-less) single-precision dot product backing `cublasSdot`.
+///
+/// Creates a temporary rocBLAS handle, runs `rocblas_sdot` with a host-side
+/// result slot, destroys the handle again and returns the computed value.
+/// The legacy entry point returns `f32`, so a failure cannot be reported
+/// through the return value; `0.0` is returned in that case.
+unsafe fn sdot(
+    n: i32,
+    x: *const f32,
+    incx: i32,
+    y: *const f32,
+    incy: i32,
+) -> f32 {
+    // `rocblas_create_handle` writes the new handle through its argument, so
+    // it needs the address of a local handle rather than a null pointer.
+    let mut handle = std::ptr::null_mut();
+    if to_cuda(rocblas_create_handle(&mut handle)) != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
+        return 0.0;
+    }
+    let mut result = 0.0f32;
+    let status = to_cuda(rocblas_sdot(
+        handle,
+        n,
+        x,
+        incx,
+        y,
+        incy,
+        &mut result,
+    ));
+    // Destroy the temporary handle on every path so a failed dot product does
+    // not leak it.
+    let _ = rocblas_destroy_handle(handle);
+    if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
+        return 0.0;
+    }
+    result
+}
+
 unsafe fn dasum_v2(
     handle: *mut cublasContext,
     n: i32,
@@ -333,6 +371,28 @@ unsafe fn dnrm_v2(
     to_cuda(rocblas_dnrm2(handle.cast(), n, x, incx, result))
 }
 
+/// Single-precision dot product for `cublasSdot_v2`: forwards the caller's
+/// handle and arguments to `rocblas_sdot` and converts the returned status.
+unsafe fn sdot_v2(
+    handle: cublasHandle_t,
+    n: i32,
+    x: *const f32,
+    incx: i32,
+    y: *const f32,
+    incy: i32,
+    result: *mut f32,
+) -> cublasStatus_t {
+    to_cuda(rocblas_sdot(
+        handle.cast(),
+        n,
+        x,
+        incx,
+        y,
+        incy,
+        result,
+    ))
+}
+
 unsafe fn idamax_v2(
     handle: *mut cublasContext,
     n: i32,
@@ -979,13 +1039,13 @@ unsafe fn dgemm(
     c: *mut f64,
     ldc: i32,
 ) -> cublasStatus_t {
-    let transa = op_from_cuda(cublasOperation_t(transa as _));
-    let transb = op_from_cuda(cublasOperation_t(transb as _));
     let mut handle = mem::zeroed();
     let mut status = to_cuda(rocblas_create_handle(handle));
     if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
         return status;
     }
+    let transa = op_from_cuda(cublasOperation_t(transa as _));
+    let transb = op_from_cuda(cublasOperation_t(transb as _));
     status = to_cuda(rocblas_dgemm(
         handle.cast(),
         transa,