Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cuBLAS] Add cublas_gemm_batched and use cublasSetStream to set stream to the current stream in all cublas API calls #423

Merged
merged 4 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/hidet/runtime/cuda/cublas.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ DLL void hidet_cublas_strided_gemm(
int64_t sa, int64_t sb, int64_t sc,
bool trans_a, bool trans_b, int compute_type
);

DLL void hidet_cublas_batched_gemm(
int b, int m, int n, int k, int ta, int tb, int tc, void **ptr_a, void **ptr_b, void **ptr_c,
bool trans_a, bool trans_b, int compute_type
);
11 changes: 11 additions & 0 deletions include/hidet/runtime/cuda/cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,18 @@
#pragma once
#include <hidet/runtime/common.h>

typedef enum {
cudaMemcpyHostToHost = 0,
cudaMemcpyHostToDevice = 1,
cudaMemcpyDeviceToHost = 2,
cudaMemcpyDeviceToDevice = 3,
cudaMemcpyDefault = 4
} cudaMemcpyKind;

DLL int hidet_cuda_device_count();
DLL int hidet_cuda_get_device();
DLL void hidet_cuda_set_device(int device);
DLL void hidet_cuda_malloc(void **devPtr, size_t size);
DLL void hidet_cuda_free(void *devPtr);
DLL void hidet_cuda_memcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind);

2 changes: 1 addition & 1 deletion python/hidet/cuda/cublas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .ffi import cublasComputeType, cudaDataType
from .kernels import gemm, strided_gemm
from .kernels import gemm, strided_gemm, batched_gemm
21 changes: 21 additions & 0 deletions python/hidet/cuda/cublas/ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import glob
from enum import IntEnum
from ctypes import c_int32, c_int64, c_void_p, c_bool, c_char_p
from hidet.ffi.utils import c_pointer_compatible
from hidet.ffi.ffi import get_func
from hidet.utils.py import initialize

Expand Down Expand Up @@ -117,6 +118,26 @@ class cublasComputeType(IntEnum):
restype=None,
)

batched_gemm = get_func(
func_name='hidet_cublas_batched_gemm',
arg_types=[
c_int32, # batch size
c_int32, # m
c_int32, # n
c_int32, # k
c_int32, # type a
c_int32, # type b
c_int32, # type c
c_pointer_compatible, # a array
c_pointer_compatible, # b array
c_pointer_compatible, # c array
c_bool, # trans a
c_bool, # trans b
c_int32, # compute type
],
restype=None,
)


@initialize()
def set_cublas_library_path():
Expand Down
89 changes: 88 additions & 1 deletion python/hidet/cuda/cublas/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from typing import Union, List
from hidet.ir.dtypes import DataType
from hidet.ir.type import void_p
from hidet.ffi.utils import Array
from .utils import as_pointer, as_type_code
from .ffi import cublasComputeType, cudaDataType
from . import ffi
Expand Down Expand Up @@ -164,3 +166,88 @@ def strided_gemm(
trans_b,
compute_type,
)


def batched_gemm(
bs: int,
m: int,
n: int,
k: int,
type_a: Union[int, cudaDataType, DataType],
type_b: Union[int, cudaDataType, DataType],
type_c: Union[int, cudaDataType, DataType],
a: Union[Array, List],
b: Union[Array, List],
c: Union[Array, List],
trans_a: bool,
trans_b: bool,
compute_type: Union[int, cublasComputeType],
):
"""
Batch matrix multiplication of two matrices using cublas in row major order by default.

The matrix of A, B, and C are stored as arrays where each array element is one matrix in
row-major order (if not transposed), and the length of the array is the batch size.

A: bs x m x k
B: bs x k x n
C: bs x m x n

Parameters
----------
bs: int
Batch size.
m: int
Number of rows of matrix op(A) and of matrix C.
n: int
Number of columns of matrix op(B) and of matrix C.
k: int
Number of columns of matrix op(A) and of rows of matrix op(B).
type_a: Union[int, DataType]
Type of elements in matrix A.
type_b: Union[int, DataType]
Type of elements in matrix B.
type_c: Union[int, DataType]
Type of elements in matrix C.
a: hidet.ffi.utils.Array or List[Tensor]
Matrix A, can be either a list of Tensors or an Array object constructed from a list of Tensors.
b: hidet.ffi.utils.Array or List[Tensor]
Matrix B, can be either a list of Tensors or an Array object constructed from a list of Tensors.
c: hidet.ffi.utils.Array or List[Tensor]
Matrix C, can be either a list of Tensors or an Array object constructed from a list of Tensors.
trans_a: bool
Whether matrix A is transposed.
trans_b: bool
Whether matrix B is transposed.
compute_type: Union[int, cublasComputeType]
The compute type of the operation.
"""

def convert_list_to_array(l):
ret = Array(void_p, len(l))
for i in range(len(l)):
ret[i] = l[i].storage.addr
return ret

if isinstance(a, List):
a = convert_list_to_array(a)
if isinstance(b, List):
b = convert_list_to_array(b)
if isinstance(c, List):
c = convert_list_to_array(c)

ffi.batched_gemm(
bs,
m,
n,
k,
as_type_code(type_a),
as_type_code(type_b),
as_type_code(type_c),
a,
b,
c,
trans_a,
trans_b,
compute_type,
)
2 changes: 1 addition & 1 deletion python/hidet/ir/library/cuda/cublas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@
# limitations under the License.
from hidet.cuda.cublas.utils import as_type_code
from hidet.cuda.cublas.kernels import cublasComputeType, cudaDataType
from .kernels import gemm, strided_gemm
from .kernels import gemm, strided_gemm, batched_gemm
from . import regs as _regs # register functions
22 changes: 22 additions & 0 deletions python/hidet/ir/library/cuda/cublas/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,25 @@ def strided_gemm(
compute_type,
],
)


def batched_gemm(
bs: Union[Expr, int],
m: Union[Expr, int],
n: Union[Expr, int],
k: Union[Expr, int],
type_a: Union[Expr, DataType, int],
type_b: Union[Expr, DataType, int],
type_c: Union[Expr, DataType, int],
a: Expr,
b: Expr,
c: Expr,
trans_a: Union[Expr, bool],
trans_b: Union[Expr, bool],
compute_type: Union[Expr, int],
):
type_a, type_b, type_c = [as_type_code(t) if isinstance(t, DataType) else t for t in [type_a, type_b, type_c]]
return call_primitive_func(
func_name='cublas.batched_gemm',
args=[bs, m, n, k, type_a, type_b, type_c, a, b, c, trans_a, trans_b, compute_type],
)
22 changes: 22 additions & 0 deletions python/hidet/ir/library/cuda/cublas/regs.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,25 @@ def register_cublas_kernels():
),
codegen_name='hidet_cublas_strided_gemm',
)
register_primitive_function(
name='cublas.batched_gemm',
func_or_type=FuncType(
param_types=[
int32, # bs
int32, # m
int32, # n
int32, # k
int32, # type_a (cudaDataType)
int32, # type_b (cudaDataType)
int32, # type_c (cudaDataType)
void_p, # a
void_p, # b
void_p, # c
boolean, # trans_a
boolean, # trans_b
int32, # compute_type (cublasComputeType)
],
ret_type=void,
),
codegen_name='hidet_cublas_batched_gemm',
)
76 changes: 76 additions & 0 deletions src/hidet/runtime/cuda/cublas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,19 @@ typedef cublasStatus_t (*cublasGemmStridedBatchedEx_t)(
cublasComputeType_t computeType,
cublasGemmAlgo_t algo
);
typedef cublasStatus_t (*cublasGemmBatchedEx_t)(
cublasHandle_t handle,
cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k,
const void *alpha,
const void *const Aarray[], cudaDataType_t Atype, int lda,
const void *const Barray[], cudaDataType_t Btype, int ldb,
const void *beta,
void *const Carray[], cudaDataType_t Ctype, int ldc,
int batchCount,
cublasComputeType_t computeType,
cublasGemmAlgo_t algo
);


// cublas api functions
Expand All @@ -137,6 +150,7 @@ static cublasGetStatusName_t cublasGetStatusName;
static cublasGetStatusString_t cublasGetStatusString;
static cublasGemmEx_t cublasGemmEx;
static cublasGemmStridedBatchedEx_t cublasGemmStridedBatchedEx;
static cublasGemmBatchedEx_t cublasGemmBatchedEx;

static std::string library_path;
static void* libcublas = nullptr;
Expand Down Expand Up @@ -205,6 +219,7 @@ static void lazy_load_cublas() {
cublasGemmStridedBatchedEx = get_symbol<cublasGemmStridedBatchedEx_t>(
libcublas, "cublasGemmStridedBatchedEx"
);
cublasGemmBatchedEx = get_symbol<cublasGemmBatchedEx_t>(libcublas, "cublasGemmBatchedEx");
}
}

Expand Down Expand Up @@ -319,4 +334,65 @@ DLL void hidet_cublas_strided_gemm(
));
}

DLL void hidet_cublas_batched_gemm(
int b, int m, int n, int k, int ta, int tb, int tc, void **ptr_a, void **ptr_b, void **ptr_c,
bool trans_a, bool trans_b, int compute_type
) {
lazy_load_cublas();

const void *p_alpha = nullptr;
const void *p_beta = nullptr;

set_alpha_beta(&p_alpha, &p_beta, cublasComputeType_t(compute_type), cudaDataType_t(tc));

static void **ptr_a_device, **ptr_b_device, **ptr_c_device;
static int cur_device_ptr_size; // Size of device memory currently allocated for each of the three a,b,c arrays.

// Allocate device memory
// first use synchronous versions of malloc and memcpy, later switch to async versions
if (cur_device_ptr_size != 0 && b > cur_device_ptr_size) {
hidet_cuda_free((void *)ptr_a_device);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not hidet_cuda_free_async?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following logic is more readable to me, just as a reference.

if(b > cur_device_ptr_size) {
  if(cur_device_ptr_size > 0) {
    free the three ptrs
  }
  alloc three ptrs
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestions! I'll modify these in the next revision.

hidet_cuda_free((void *)ptr_b_device);
hidet_cuda_free((void *)ptr_c_device);
}
if (ptr_a_device == NULL || b > cur_device_ptr_size) {
hidet_cuda_malloc((void **) &ptr_a_device, b * sizeof(void*));
hidet_cuda_malloc((void **) &ptr_b_device, b * sizeof(void*));
hidet_cuda_malloc((void **) &ptr_c_device, b * sizeof(void*));
cur_device_ptr_size = b;
}

// Copy input arrays (A and B) from host to device
hidet_cuda_memcpy((void *)ptr_a_device, (void *)ptr_a, b * sizeof(void*), cudaMemcpyHostToDevice);
hidet_cuda_memcpy((void *)ptr_b_device, (void *)ptr_b, b * sizeof(void*), cudaMemcpyHostToDevice);
hidet_cuda_memcpy((void *)ptr_c_device, (void *)ptr_c, b * sizeof(void*), cudaMemcpyHostToDevice);

CHECK_CUBLAS(cublasGemmBatchedEx(
CublasContext::current_handle(),
trans_a ? cublasOperation_t::CUBLAS_OP_T : cublasOperation_t::CUBLAS_OP_N,
trans_b ? cublasOperation_t::CUBLAS_OP_T : cublasOperation_t::CUBLAS_OP_N,
n,
m,
k,
p_alpha,
// b^t
ptr_b_device,
cudaDataType(tb),
n, // ldb
// a^t
ptr_a_device,
cudaDataType(ta),
k, // lda
p_beta,
// c^t
ptr_c_device,
cudaDataType(tc),
n, // ldc
b, // batchCount
cublasComputeType_t(compute_type),
cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT
));

}


24 changes: 24 additions & 0 deletions src/hidet/runtime/cuda/cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,19 @@ typedef int cudaError_t;
typedef cudaError_t (*cudaGetDeviceCount_t)(int* count);
typedef cudaError_t (*cudaGetDevice_t)(int* device);
typedef cudaError_t (*cudaSetDevice_t)(int device);
typedef cudaError_t (*cudaMalloc_t)(void **devPtr, size_t size);
typedef cudaError_t (*cudaFree_t)(void *devPtr);
typedef cudaError_t (*cudaMemcpy_t)(void* dst, const void* src, size_t count, cudaMemcpyKind kind);
typedef const char* (*cudaGetErrorString_t)(cudaError_t error);

static std::string library_path;
static void* libcudart = nullptr;
static cudaGetDeviceCount_t cudaGetDeviceCount = nullptr;
static cudaGetDevice_t cudaGetDevice = nullptr;
static cudaSetDevice_t cudaSetDevice = nullptr;
static cudaMalloc_t cudaMalloc = nullptr;
static cudaFree_t cudaFree = nullptr;
static cudaMemcpy_t cudaMemcpy = nullptr;
static cudaGetErrorString_t cudaGetErrorString = nullptr;

// load cuda runtime APIs
Expand All @@ -45,6 +51,9 @@ static inline void lazy_load_cuda_runtime() {
cudaGetDeviceCount = get_symbol<cudaGetDeviceCount_t>(libcudart, "cudaGetDeviceCount");
cudaGetDevice = get_symbol<cudaGetDevice_t>(libcudart, "cudaGetDevice");
cudaSetDevice = get_symbol<cudaSetDevice_t>(libcudart, "cudaSetDevice");
cudaMalloc = get_symbol<cudaMalloc_t>(libcudart, "cudaMalloc");
cudaFree = get_symbol<cudaFree_t>(libcudart, "cudaFree");
cudaMemcpy = get_symbol<cudaMemcpy_t>(libcudart, "cudaMemcpy");
cudaGetErrorString = get_symbol<cudaGetErrorString_t>(libcudart, "cudaGetErrorString");
}
}
Expand Down Expand Up @@ -79,3 +88,18 @@ DLL void hidet_cuda_set_device(int device) {
lazy_load_cuda_runtime();
CHECK_CUDA(cudaSetDevice(device));
}

DLL void hidet_cuda_malloc(void **devPtr, size_t size) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider directly return the allocated memory address like

DLL void* hidet_cuda_malloc(size_t size) {
    ...
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like hidet_cuda_get_device(...).

lazy_load_cuda_runtime();
CHECK_CUDA(cudaMalloc(devPtr, size));
}

DLL void hidet_cuda_free(void *devPtr) {
lazy_load_cuda_runtime();
CHECK_CUDA(cudaFree(devPtr));
}

DLL void hidet_cuda_memcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind) {
lazy_load_cuda_runtime();
CHECK_CUDA(cudaMemcpy(dst, src, count, kind));
}