Skip to content

Commit

Permalink
[cuBLAS] Add cublas_gemm_batched and use cublasSetStream to set strea…
Browse files Browse the repository at this point in the history
…m to the current stream in all cublas API calls (#423)

Co-authored-by: Yudi Sun <yudi@eco-12.syslab.sandbox>
  • Loading branch information
yudi0201 and Yudi Sun committed Feb 12, 2024
1 parent a7446f0 commit 5f76caf
Show file tree
Hide file tree
Showing 11 changed files with 342 additions and 3 deletions.
5 changes: 5 additions & 0 deletions include/hidet/runtime/cuda/cublas.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ DLL void hidet_cublas_strided_gemm(
int64_t sa, int64_t sb, int64_t sc,
bool trans_a, bool trans_b, int compute_type
);

DLL void hidet_cublas_batched_gemm(
int b, int m, int n, int k, int ta, int tb, int tc, void **ptr_a, void **ptr_b, void **ptr_c,
bool trans_a, bool trans_b, int compute_type
);
16 changes: 16 additions & 0 deletions include/hidet/runtime/cuda/cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,23 @@
#pragma once
#include <hidet/runtime/common.h>

typedef enum {
cudaMemcpyHostToHost = 0,
cudaMemcpyHostToDevice = 1,
cudaMemcpyDeviceToHost = 2,
cudaMemcpyDeviceToDevice = 3,
cudaMemcpyDefault = 4
} cudaMemcpyKind;

typedef void* cudaStream_t;

DLL int hidet_cuda_device_count();
DLL int hidet_cuda_get_device();
DLL void hidet_cuda_set_device(int device);
DLL void* hidet_cuda_malloc(size_t size);
DLL void* hidet_cuda_malloc_async(size_t size, cudaStream_t stream);
DLL void hidet_cuda_free(void *devPtr);
DLL void hidet_cuda_free_async(void *devPtr, cudaStream_t stream);
DLL void hidet_cuda_memcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind);
DLL void hidet_cuda_memcpy_async(void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream);

2 changes: 1 addition & 1 deletion python/hidet/cuda/cublas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .ffi import cublasComputeType, cudaDataType
from .kernels import gemm, strided_gemm
from .kernels import gemm, strided_gemm, batched_gemm
21 changes: 21 additions & 0 deletions python/hidet/cuda/cublas/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import glob
from enum import IntEnum
from ctypes import c_int32, c_int64, c_void_p, c_bool, c_char_p
from hidet.ffi.utils import c_pointer_compatible
from hidet.ffi.ffi import get_func
from hidet.utils.py import initialize

Expand Down Expand Up @@ -117,6 +118,26 @@ class cublasComputeType(IntEnum):
restype=None,
)

batched_gemm = get_func(
func_name='hidet_cublas_batched_gemm',
arg_types=[
c_int32, # batch size
c_int32, # m
c_int32, # n
c_int32, # k
c_int32, # type a
c_int32, # type b
c_int32, # type c
c_pointer_compatible, # a array
c_pointer_compatible, # b array
c_pointer_compatible, # c array
c_bool, # trans a
c_bool, # trans b
c_int32, # compute type
],
restype=None,
)


@initialize()
def set_cublas_library_path():
Expand Down
89 changes: 88 additions & 1 deletion python/hidet/cuda/cublas/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from typing import Union, List
from hidet.ir.dtypes import DataType
from hidet.ir.type import void_p
from hidet.ffi.utils import Array
from .utils import as_pointer, as_type_code
from .ffi import cublasComputeType, cudaDataType
from . import ffi
Expand Down Expand Up @@ -164,3 +166,88 @@ def strided_gemm(
trans_b,
compute_type,
)


def batched_gemm(
bs: int,
m: int,
n: int,
k: int,
type_a: Union[int, cudaDataType, DataType],
type_b: Union[int, cudaDataType, DataType],
type_c: Union[int, cudaDataType, DataType],
a: Union[Array, List],
b: Union[Array, List],
c: Union[Array, List],
trans_a: bool,
trans_b: bool,
compute_type: Union[int, cublasComputeType],
):
"""
Batch matrix multiplication of two matrices using cublas in row major order by default.
The matrix of A, B, and C are stored as arrays where each array element is one matrix in
row-major order (if not transposed), and the length of the array is the batch size.
A: bs x m x k
B: bs x k x n
C: bs x m x n
Parameters
----------
bs: int
Batch size.
m: int
Number of rows of matrix op(A) and of matrix C.
n: int
Number of columns of matrix op(B) and of matrix C.
k: int
Number of columns of matrix op(A) and of rows of matrix op(B).
type_a: Union[int, DataType]
Type of elements in matrix A.
type_b: Union[int, DataType]
Type of elements in matrix B.
type_c: Union[int, DataType]
Type of elements in matrix C.
a: hidet.ffi.utils.Array or List[Tensor]
Matrix A, can be either a list of Tensors or an Array object constructed from a list of Tensors.
b: hidet.ffi.utils.Array or List[Tensor]
Matrix B, can be either a list of Tensors or an Array object constructed from a list of Tensors.
c: hidet.ffi.utils.Array or List[Tensor]
Matrix C, can be either a list of Tensors or an Array object constructed from a list of Tensors.
trans_a: bool
Whether matrix A is transposed.
trans_b: bool
Whether matrix B is transposed.
compute_type: Union[int, cublasComputeType]
The compute type of the operation.
"""

def convert_list_to_array(l):
ret = Array(void_p, len(l))
for i in range(len(l)):
ret[i] = l[i].storage.addr
return ret

if isinstance(a, List):
a = convert_list_to_array(a)
if isinstance(b, List):
b = convert_list_to_array(b)
if isinstance(c, List):
c = convert_list_to_array(c)

ffi.batched_gemm(
bs,
m,
n,
k,
as_type_code(type_a),
as_type_code(type_b),
as_type_code(type_c),
a,
b,
c,
trans_a,
trans_b,
compute_type,
)
2 changes: 1 addition & 1 deletion python/hidet/ir/library/cuda/cublas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@
# limitations under the License.
from hidet.cuda.cublas.utils import as_type_code
from hidet.cuda.cublas.kernels import cublasComputeType, cudaDataType
from .kernels import gemm, strided_gemm
from .kernels import gemm, strided_gemm, batched_gemm
from . import regs as _regs # register functions
22 changes: 22 additions & 0 deletions python/hidet/ir/library/cuda/cublas/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,25 @@ def strided_gemm(
compute_type,
],
)


def batched_gemm(
bs: Union[Expr, int],
m: Union[Expr, int],
n: Union[Expr, int],
k: Union[Expr, int],
type_a: Union[Expr, DataType, int],
type_b: Union[Expr, DataType, int],
type_c: Union[Expr, DataType, int],
a: Expr,
b: Expr,
c: Expr,
trans_a: Union[Expr, bool],
trans_b: Union[Expr, bool],
compute_type: Union[Expr, int],
):
type_a, type_b, type_c = [as_type_code(t) if isinstance(t, DataType) else t for t in [type_a, type_b, type_c]]
return call_primitive_func(
func_name='cublas.batched_gemm',
args=[bs, m, n, k, type_a, type_b, type_c, a, b, c, trans_a, trans_b, compute_type],
)
22 changes: 22 additions & 0 deletions python/hidet/ir/library/cuda/cublas/regs.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,25 @@ def register_cublas_kernels():
),
codegen_name='hidet_cublas_strided_gemm',
)
register_primitive_function(
name='cublas.batched_gemm',
func_or_type=FuncType(
param_types=[
int32, # bs
int32, # m
int32, # n
int32, # k
int32, # type_a (cudaDataType)
int32, # type_b (cudaDataType)
int32, # type_c (cudaDataType)
void_p, # a
void_p, # b
void_p, # c
boolean, # trans_a
boolean, # trans_b
int32, # compute_type (cublasComputeType)
],
ret_type=void,
),
codegen_name='hidet_cublas_batched_gemm',
)
92 changes: 92 additions & 0 deletions src/hidet/runtime/cuda/cublas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ typedef enum {
typedef const char* (*cublasGetStatusName_t)(cublasStatus_t status);
typedef const char* (*cublasGetStatusString_t)(cublasStatus_t status);
typedef cublasStatus_t (*cublasCreate_t)(cublasHandle_t *handle);
typedef cublasStatus_t (*cublasSetStream_t)(cublasHandle_t handle, cudaStream_t streamId);
typedef cublasStatus_t (*cublasGemmEx_t)(
cublasHandle_t handle,
cublasOperation_t transa, cublasOperation_t transb,
Expand All @@ -129,14 +130,29 @@ typedef cublasStatus_t (*cublasGemmStridedBatchedEx_t)(
cublasComputeType_t computeType,
cublasGemmAlgo_t algo
);
typedef cublasStatus_t (*cublasGemmBatchedEx_t)(
cublasHandle_t handle,
cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k,
const void *alpha,
const void *const Aarray[], cudaDataType_t Atype, int lda,
const void *const Barray[], cudaDataType_t Btype, int ldb,
const void *beta,
void *const Carray[], cudaDataType_t Ctype, int ldc,
int batchCount,
cublasComputeType_t computeType,
cublasGemmAlgo_t algo
);


// cublas api functions
static cublasCreate_t cublasCreate;
static cublasSetStream_t cublasSetStream;
static cublasGetStatusName_t cublasGetStatusName;
static cublasGetStatusString_t cublasGetStatusString;
static cublasGemmEx_t cublasGemmEx;
static cublasGemmStridedBatchedEx_t cublasGemmStridedBatchedEx;
static cublasGemmBatchedEx_t cublasGemmBatchedEx;

static std::string library_path;
static void* libcublas = nullptr;
Expand Down Expand Up @@ -199,12 +215,14 @@ static void lazy_load_cublas() {

// load api functions
cublasCreate = get_symbol<cublasCreate_t>(libcublas, "cublasCreate_v2");
cublasSetStream = get_symbol<cublasSetStream_t>(libcublas, "cublasSetStream_v2");
cublasGetStatusName = get_symbol<cublasGetStatusName_t>(libcublas, "cublasGetStatusName");
cublasGetStatusString = get_symbol<cublasGetStatusString_t>(libcublas, "cublasGetStatusString");
cublasGemmEx = get_symbol<cublasGemmEx_t>(libcublas, "cublasGemmEx");
cublasGemmStridedBatchedEx = get_symbol<cublasGemmStridedBatchedEx_t>(
libcublas, "cublasGemmStridedBatchedEx"
);
cublasGemmBatchedEx = get_symbol<cublasGemmBatchedEx_t>(libcublas, "cublasGemmBatchedEx");
}
}

Expand Down Expand Up @@ -248,6 +266,10 @@ DLL void hidet_cublas_gemm(
) {
lazy_load_cublas();

// Set the stream to the current stream
cudaStream_t cur_stream = get_cuda_stream();
CHECK_CUBLAS(cublasSetStream(CublasContext::current_handle(), cur_stream));

const void *p_alpha = nullptr;
const void *p_beta = nullptr;

Expand Down Expand Up @@ -284,6 +306,10 @@ DLL void hidet_cublas_strided_gemm(
) {
lazy_load_cublas();

// Set the stream to the current stream
cudaStream_t cur_stream = get_cuda_stream();
CHECK_CUBLAS(cublasSetStream(CublasContext::current_handle(), cur_stream));

const void *p_alpha = nullptr;
const void *p_beta = nullptr;

Expand Down Expand Up @@ -319,4 +345,70 @@ DLL void hidet_cublas_strided_gemm(
));
}

DLL void hidet_cublas_batched_gemm(
int b, int m, int n, int k, int ta, int tb, int tc, void **ptr_a, void **ptr_b, void **ptr_c,
bool trans_a, bool trans_b, int compute_type
) {
lazy_load_cublas();

// Set the stream to the current stream
cudaStream_t cur_stream = get_cuda_stream();
CHECK_CUBLAS(cublasSetStream(CublasContext::current_handle(), cur_stream));

const void *p_alpha = nullptr;
const void *p_beta = nullptr;

set_alpha_beta(&p_alpha, &p_beta, cublasComputeType_t(compute_type), cudaDataType_t(tc));

static void **ptr_a_device, **ptr_b_device, **ptr_c_device;
static int cur_device_ptr_size; // Size of device memory currently allocated for each of the three a,b,c arrays.

// Allocate device memory
// first use synchronous versions of malloc and memcpy, later switch to async versions
if (b > cur_device_ptr_size) {
if (cur_device_ptr_size > 0) {
hidet_cuda_free_async((void *)ptr_a_device, cur_stream);
hidet_cuda_free_async((void *)ptr_b_device, cur_stream);
hidet_cuda_free_async((void *)ptr_c_device, cur_stream);
}
ptr_a_device = (void **) hidet_cuda_malloc_async(b * sizeof(void*), cur_stream);
ptr_b_device = (void **) hidet_cuda_malloc_async(b * sizeof(void*), cur_stream);
ptr_c_device = (void **) hidet_cuda_malloc_async(b * sizeof(void*), cur_stream);

cur_device_ptr_size = b;
}

// Copy input arrays (A and B) from host to device
hidet_cuda_memcpy_async((void *)ptr_a_device, (void *)ptr_a, b * sizeof(void*), cudaMemcpyHostToDevice, cur_stream);
hidet_cuda_memcpy_async((void *)ptr_b_device, (void *)ptr_b, b * sizeof(void*), cudaMemcpyHostToDevice, cur_stream);
hidet_cuda_memcpy_async((void *)ptr_c_device, (void *)ptr_c, b * sizeof(void*), cudaMemcpyHostToDevice, cur_stream);

CHECK_CUBLAS(cublasGemmBatchedEx(
CublasContext::current_handle(),
trans_a ? cublasOperation_t::CUBLAS_OP_T : cublasOperation_t::CUBLAS_OP_N,
trans_b ? cublasOperation_t::CUBLAS_OP_T : cublasOperation_t::CUBLAS_OP_N,
n,
m,
k,
p_alpha,
// b^t
ptr_b_device,
cudaDataType(tb),
n, // ldb
// a^t
ptr_a_device,
cudaDataType(ta),
k, // lda
p_beta,
// c^t
ptr_c_device,
cudaDataType(tc),
n, // ldc
b, // batchCount
cublasComputeType_t(compute_type),
cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT
));

}


0 comments on commit 5f76caf

Please sign in to comment.