Skip to content

Commit

Permalink
Implement cublasSdot.
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed Mar 25, 2024
1 parent 7c3891e commit 1122cc0
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 6 deletions.
18 changes: 16 additions & 2 deletions zluda_blas/src/cublas.rs
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,15 @@ pub unsafe extern "system" fn cublasSdot_v2(
incy: ::std::os::raw::c_int,
result: *mut f32,
) -> cublasStatus_t {
crate::unsupported()
crate::sdot_v2(
handle,
n,
x,
incx,
y,
incy,
result,
)
}

#[no_mangle]
Expand Down Expand Up @@ -4920,7 +4928,13 @@ pub unsafe extern "system" fn cublasSdot(
y: *const f32,
incy: ::std::os::raw::c_int,
) -> f32 {
unimplemented!()
crate::sdot(
n,
x,
incx,
y,
incy,
)
}

#[no_mangle]
Expand Down
56 changes: 52 additions & 4 deletions zluda_blas/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,13 @@ unsafe fn sgemm(
c: *mut f32,
ldc: i32,
) -> cublasStatus_t {
let transa = op_from_cuda(cublasOperation_t(transa as _));
let transb = op_from_cuda(cublasOperation_t(transb as _));
let mut handle = mem::zeroed();
let mut status = to_cuda(rocblas_create_handle(handle));
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return status;
}
let transa = op_from_cuda(cublasOperation_t(transa as _));
let transb = op_from_cuda(cublasOperation_t(transb as _));
status = to_cuda(rocblas_sgemm(
handle.cast(),
transa,
Expand Down Expand Up @@ -279,6 +279,34 @@ unsafe fn init() -> cublasStatus_t {
cublasStatus_t::CUBLAS_STATUS_SUCCESS
}

unsafe fn sdot(
n: i32,
x: *const f32,
incx: i32,
y: *const f32,
incy: i32,
) -> cublasStatus_t {
let mut handle = mem::zeroed();
let mut status = to_cuda(rocblas_create_handle(handle));
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return status;
}
let result = mem::zeroed();
status = to_cuda(rocblas_sdot(
handle.cast(),
n,
x,
incx,
y,
incy,
result,
));
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return status;
}
to_cuda(rocblas_destroy_handle(*handle))
}

unsafe fn dasum_v2(
handle: *mut cublasContext,
n: i32,
Expand Down Expand Up @@ -333,6 +361,26 @@ unsafe fn dnrm_v2(
to_cuda(rocblas_dnrm2(handle.cast(), n, x, incx, result))
}

/// Handle-based single-precision dot product backing `cublasSdot_v2`.
///
/// Forwards every argument to `rocblas_sdot`, reinterpreting the cuBLAS
/// handle pointer as the handle type rocBLAS expects, and converts the
/// returned status with `to_cuda`.
///
/// # Safety
///
/// All pointers are passed through unchanged; the caller must ensure
/// `handle`, `x`, `y` and `result` are valid for the `rocblas_sdot` call.
unsafe fn sdot_v2(
    handle: cublasHandle_t,
    n: i32,
    x: *const f32,
    incx: i32,
    y: *const f32,
    incy: i32,
    result: *mut f32,
) -> cublasStatus_t {
    // Pointer cast only; the target type is inferred from `rocblas_sdot`.
    let rocblas_ctx = handle.cast();
    let rocblas_status = rocblas_sdot(rocblas_ctx, n, x, incx, y, incy, result);
    to_cuda(rocblas_status)
}

unsafe fn idamax_v2(
handle: *mut cublasContext,
n: i32,
Expand Down Expand Up @@ -979,13 +1027,13 @@ unsafe fn dgemm(
c: *mut f64,
ldc: i32,
) -> cublasStatus_t {
let transa = op_from_cuda(cublasOperation_t(transa as _));
let transb = op_from_cuda(cublasOperation_t(transb as _));
let mut handle = mem::zeroed();
let mut status = to_cuda(rocblas_create_handle(handle));
if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
return status;
}
let transa = op_from_cuda(cublasOperation_t(transa as _));
let transb = op_from_cuda(cublasOperation_t(transb as _));
status = to_cuda(rocblas_dgemm(
handle.cast(),
transa,
Expand Down

0 comments on commit 1122cc0

Please sign in to comment.