diff --git a/include/hidet/runtime/cuda/cublas.h b/include/hidet/runtime/cuda/cublas.h index ed1108957..16119cd72 100644 --- a/include/hidet/runtime/cuda/cublas.h +++ b/include/hidet/runtime/cuda/cublas.h @@ -36,3 +36,8 @@ DLL void hidet_cublas_strided_gemm( int64_t sa, int64_t sb, int64_t sc, bool trans_a, bool trans_b, int compute_type ); + +DLL void hidet_cublas_batched_gemm( + int b, int m, int n, int k, int ta, int tb, int tc, void **ptr_a, void **ptr_b, void **ptr_c, + bool trans_a, bool trans_b, int compute_type +); diff --git a/include/hidet/runtime/cuda/cuda.h b/include/hidet/runtime/cuda/cuda.h index d386ae533..b906806fd 100644 --- a/include/hidet/runtime/cuda/cuda.h +++ b/include/hidet/runtime/cuda/cuda.h @@ -12,7 +12,23 @@ #pragma once #include +typedef enum { + cudaMemcpyHostToHost = 0, + cudaMemcpyHostToDevice = 1, + cudaMemcpyDeviceToHost = 2, + cudaMemcpyDeviceToDevice = 3, + cudaMemcpyDefault = 4 +} cudaMemcpyKind; + +typedef void* cudaStream_t; + DLL int hidet_cuda_device_count(); DLL int hidet_cuda_get_device(); DLL void hidet_cuda_set_device(int device); +DLL void* hidet_cuda_malloc(size_t size); +DLL void* hidet_cuda_malloc_async(size_t size, cudaStream_t stream); +DLL void hidet_cuda_free(void *devPtr); +DLL void hidet_cuda_free_async(void *devPtr, cudaStream_t stream); +DLL void hidet_cuda_memcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind); +DLL void hidet_cuda_memcpy_async(void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream); diff --git a/python/hidet/cuda/cublas/__init__.py b/python/hidet/cuda/cublas/__init__.py index e08d49ee0..fb32e996c 100644 --- a/python/hidet/cuda/cublas/__init__.py +++ b/python/hidet/cuda/cublas/__init__.py @@ -10,4 +10,4 @@ # See the License for the specific language governing permissions and # limitations under the License. from .ffi import cublasComputeType, cudaDataType -from .kernels import gemm, strided_gemm +from .kernels import gemm, strided_gemm, batched_gemm diff --git a/python/hidet/cuda/cublas/ffi.py b/python/hidet/cuda/cublas/ffi.py index 96b948632..fce3a494c 100644 --- a/python/hidet/cuda/cublas/ffi.py +++ b/python/hidet/cuda/cublas/ffi.py @@ -14,6 +14,7 @@ import glob from enum import IntEnum from ctypes import c_int32, c_int64, c_void_p, c_bool, c_char_p +from hidet.ffi.utils import c_pointer_compatible from hidet.ffi.ffi import get_func from hidet.utils.py import initialize @@ -117,6 +118,26 @@ class cublasComputeType(IntEnum): restype=None, ) +batched_gemm = get_func( + func_name='hidet_cublas_batched_gemm', + arg_types=[ + c_int32, # batch size + c_int32, # m + c_int32, # n + c_int32, # k + c_int32, # type a + c_int32, # type b + c_int32, # type c + c_pointer_compatible, # a array + c_pointer_compatible, # b array + c_pointer_compatible, # c array + c_bool, # trans a + c_bool, # trans b + c_int32, # compute type + ], + restype=None, +) + @initialize() def set_cublas_library_path(): diff --git a/python/hidet/cuda/cublas/kernels.py b/python/hidet/cuda/cublas/kernels.py index 47a9c13a4..159b868ad 100644 --- a/python/hidet/cuda/cublas/kernels.py +++ b/python/hidet/cuda/cublas/kernels.py @@ -9,8 +9,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import Union, List from hidet.ir.dtypes import DataType +from hidet.ir.type import void_p +from hidet.ffi.utils import Array from .utils import as_pointer, as_type_code from .ffi import cublasComputeType, cudaDataType from . import ffi @@ -164,3 +166,88 @@ def strided_gemm( trans_b, compute_type, ) + + +def batched_gemm( + bs: int, + m: int, + n: int, + k: int, + type_a: Union[int, cudaDataType, DataType], + type_b: Union[int, cudaDataType, DataType], + type_c: Union[int, cudaDataType, DataType], + a: Union[Array, List], + b: Union[Array, List], + c: Union[Array, List], + trans_a: bool, + trans_b: bool, + compute_type: Union[int, cublasComputeType], +): + """ + Batch matrix multiplication of two matrices using cublas in row major order by default. + + The matrix of A, B, and C are stored as arrays where each array element is one matrix in + row-major order (if not transposed), and the length of the array is the batch size. + + A: bs x m x k + B: bs x k x n + C: bs x m x n + + Parameters + ---------- + bs: int + Batch size. + m: int + Number of rows of matrix op(A) and of matrix C. + n: int + Number of columns of matrix op(B) and of matrix C. + k: int + Number of columns of matrix op(A) and of rows of matrix op(B). + type_a: Union[int, DataType] + Type of elements in matrix A. + type_b: Union[int, DataType] + Type of elements in matrix B. + type_c: Union[int, DataType] + Type of elements in matrix C. + a: hidet.ffi.utils.Array or List[Tensor] + Matrix A, can be either a list of Tensors or an Array object constructed from a list of Tensors. + b: hidet.ffi.utils.Array or List[Tensor] + Matrix B, can be either a list of Tensors or an Array object constructed from a list of Tensors. + c: hidet.ffi.utils.Array or List[Tensor] + Matrix C, can be either a list of Tensors or an Array object constructed from a list of Tensors. + trans_a: bool + Whether matrix A is transposed. + trans_b: bool + Whether matrix B is transposed. + compute_type: Union[int, cublasComputeType] + The compute type of the operation. + """ + + def convert_list_to_array(l): + ret = Array(void_p, len(l)) + for i in range(len(l)): + ret[i] = l[i].storage.addr + return ret + + if isinstance(a, List): + a = convert_list_to_array(a) + if isinstance(b, List): + b = convert_list_to_array(b) + if isinstance(c, List): + c = convert_list_to_array(c) + + ffi.batched_gemm( + bs, + m, + n, + k, + as_type_code(type_a), + as_type_code(type_b), + as_type_code(type_c), + a, + b, + c, + trans_a, + trans_b, + compute_type, + ) diff --git a/python/hidet/ir/library/cuda/cublas/__init__.py b/python/hidet/ir/library/cuda/cublas/__init__.py index abf4eece9..4cc421749 100644 --- a/python/hidet/ir/library/cuda/cublas/__init__.py +++ b/python/hidet/ir/library/cuda/cublas/__init__.py @@ -11,5 +11,5 @@ # limitations under the License. from hidet.cuda.cublas.utils import as_type_code from hidet.cuda.cublas.kernels import cublasComputeType, cudaDataType -from .kernels import gemm, strided_gemm +from .kernels import gemm, strided_gemm, batched_gemm from . import regs as _regs # register functions diff --git a/python/hidet/ir/library/cuda/cublas/kernels.py b/python/hidet/ir/library/cuda/cublas/kernels.py index a335ef33b..7dc18b770 100644 --- a/python/hidet/ir/library/cuda/cublas/kernels.py +++ b/python/hidet/ir/library/cuda/cublas/kernels.py @@ -76,3 +76,25 @@ def strided_gemm( compute_type, ], ) + + +def batched_gemm( + bs: Union[Expr, int], + m: Union[Expr, int], + n: Union[Expr, int], + k: Union[Expr, int], + type_a: Union[Expr, DataType, int], + type_b: Union[Expr, DataType, int], + type_c: Union[Expr, DataType, int], + a: Expr, + b: Expr, + c: Expr, + trans_a: Union[Expr, bool], + trans_b: Union[Expr, bool], + compute_type: Union[Expr, int], +): + type_a, type_b, type_c = [as_type_code(t) if isinstance(t, DataType) else t for t in [type_a, type_b, type_c]] + return call_primitive_func( + func_name='cublas.batched_gemm', + args=[bs, m, n, k, type_a, type_b, type_c, a, b, c, trans_a, trans_b, compute_type], + ) diff --git a/python/hidet/ir/library/cuda/cublas/regs.py b/python/hidet/ir/library/cuda/cublas/regs.py index 884f7002d..23182af74 100644 --- a/python/hidet/ir/library/cuda/cublas/regs.py +++ b/python/hidet/ir/library/cuda/cublas/regs.py @@ -63,3 +63,25 @@ def register_cublas_kernels(): ), codegen_name='hidet_cublas_strided_gemm', ) + register_primitive_function( + name='cublas.batched_gemm', + func_or_type=FuncType( + param_types=[ + int32, # bs + int32, # m + int32, # n + int32, # k + int32, # type_a (cudaDataType) + int32, # type_b (cudaDataType) + int32, # type_c (cudaDataType) + void_p, # a + void_p, # b + void_p, # c + boolean, # trans_a + boolean, # trans_b + int32, # compute_type (cublasComputeType) + ], + ret_type=void, + ), + codegen_name='hidet_cublas_batched_gemm', + ) diff --git a/src/hidet/runtime/cuda/cublas.cpp b/src/hidet/runtime/cuda/cublas.cpp index e42bec40c..5f2eb947d 100644 --- a/src/hidet/runtime/cuda/cublas.cpp +++ b/src/hidet/runtime/cuda/cublas.cpp @@ -104,6 +104,7 @@ typedef enum { typedef const char* (*cublasGetStatusName_t)(cublasStatus_t status); typedef const char* (*cublasGetStatusString_t)(cublasStatus_t status); typedef cublasStatus_t (*cublasCreate_t)(cublasHandle_t *handle); +typedef cublasStatus_t (*cublasSetStream_t)(cublasHandle_t handle, cudaStream_t streamId); typedef cublasStatus_t (*cublasGemmEx_t)( cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, @@ -129,14 +130,29 @@ typedef cublasStatus_t (*cublasGemmStridedBatchedEx_t)( cublasComputeType_t computeType, cublasGemmAlgo_t algo ); +typedef cublasStatus_t (*cublasGemmBatchedEx_t)( + cublasHandle_t handle, + cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, + const void *alpha, + const void *const Aarray[], cudaDataType_t Atype, int lda, + const void *const Barray[], cudaDataType_t Btype, int ldb, + const void *beta, + void *const Carray[], cudaDataType_t Ctype, int ldc, + int batchCount, + cublasComputeType_t computeType, + cublasGemmAlgo_t algo +); // cublas api functions static cublasCreate_t cublasCreate; +static cublasSetStream_t cublasSetStream; static cublasGetStatusName_t cublasGetStatusName; static cublasGetStatusString_t cublasGetStatusString; static cublasGemmEx_t cublasGemmEx; static cublasGemmStridedBatchedEx_t cublasGemmStridedBatchedEx; +static cublasGemmBatchedEx_t cublasGemmBatchedEx; static std::string library_path; static void* libcublas = nullptr; @@ -199,12 +215,14 @@ static void lazy_load_cublas() { // load api functions cublasCreate = get_symbol(libcublas, "cublasCreate_v2"); + cublasSetStream = get_symbol(libcublas, "cublasSetStream_v2"); cublasGetStatusName = get_symbol(libcublas, "cublasGetStatusName"); cublasGetStatusString = get_symbol(libcublas, "cublasGetStatusString"); cublasGemmEx = get_symbol(libcublas, "cublasGemmEx"); cublasGemmStridedBatchedEx = get_symbol( libcublas, "cublasGemmStridedBatchedEx" ); + cublasGemmBatchedEx = get_symbol(libcublas, "cublasGemmBatchedEx"); } } @@ -248,6 +266,10 @@ DLL void hidet_cublas_gemm( ) { lazy_load_cublas(); + // Set the stream to the current stream + cudaStream_t cur_stream = get_cuda_stream(); + CHECK_CUBLAS(cublasSetStream(CublasContext::current_handle(), cur_stream)); + const void *p_alpha = nullptr; const void *p_beta = nullptr; @@ -284,6 +306,10 @@ DLL void hidet_cublas_strided_gemm( ) { lazy_load_cublas(); + // Set the stream to the current stream + cudaStream_t cur_stream = get_cuda_stream(); + CHECK_CUBLAS(cublasSetStream(CublasContext::current_handle(), cur_stream)); + const void *p_alpha = nullptr; const void *p_beta = nullptr; @@ -319,4 +345,70 @@ DLL void hidet_cublas_strided_gemm( )); } +DLL void hidet_cublas_batched_gemm( + int b, int m, int n, int k, int ta, int tb, int tc, void **ptr_a, void **ptr_b, void **ptr_c, + bool trans_a, bool trans_b, int compute_type +) { + lazy_load_cublas(); + + // Set the stream to the current stream + cudaStream_t cur_stream = get_cuda_stream(); + CHECK_CUBLAS(cublasSetStream(CublasContext::current_handle(), cur_stream)); + + const void *p_alpha = nullptr; + const void *p_beta = nullptr; + + set_alpha_beta(&p_alpha, &p_beta, cublasComputeType_t(compute_type), cudaDataType_t(tc)); + + static void **ptr_a_device, **ptr_b_device, **ptr_c_device; + static int cur_device_ptr_size; // Size of device memory currently allocated for each of the three a,b,c arrays. + + // Allocate device memory + // first use synchronous versions of malloc and memcpy, later switch to async versions + if (b > cur_device_ptr_size) { + if (cur_device_ptr_size > 0) { + hidet_cuda_free_async((void *)ptr_a_device, cur_stream); + hidet_cuda_free_async((void *)ptr_b_device, cur_stream); + hidet_cuda_free_async((void *)ptr_c_device, cur_stream); + } + ptr_a_device = (void **) hidet_cuda_malloc_async(b * sizeof(void*), cur_stream); + ptr_b_device = (void **) hidet_cuda_malloc_async(b * sizeof(void*), cur_stream); + ptr_c_device = (void **) hidet_cuda_malloc_async(b * sizeof(void*), cur_stream); + + cur_device_ptr_size = b; + } + + // Copy input arrays (A and B) from host to device + hidet_cuda_memcpy_async((void *)ptr_a_device, (void *)ptr_a, b * sizeof(void*), cudaMemcpyHostToDevice, cur_stream); + hidet_cuda_memcpy_async((void *)ptr_b_device, (void *)ptr_b, b * sizeof(void*), cudaMemcpyHostToDevice, cur_stream); + hidet_cuda_memcpy_async((void *)ptr_c_device, (void *)ptr_c, b * sizeof(void*), cudaMemcpyHostToDevice, cur_stream); + + CHECK_CUBLAS(cublasGemmBatchedEx( + CublasContext::current_handle(), + trans_a ? cublasOperation_t::CUBLAS_OP_T : cublasOperation_t::CUBLAS_OP_N, + trans_b ? cublasOperation_t::CUBLAS_OP_T : cublasOperation_t::CUBLAS_OP_N, + n, + m, + k, + p_alpha, + // b^t + ptr_b_device, + cudaDataType(tb), + n, // ldb + // a^t + ptr_a_device, + cudaDataType(ta), + k, // lda + p_beta, + // c^t + ptr_c_device, + cudaDataType(tc), + n, // ldc + b, // batchCount + cublasComputeType_t(compute_type), + cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT + )); + +} + diff --git a/src/hidet/runtime/cuda/cuda.cpp b/src/hidet/runtime/cuda/cuda.cpp index 47e4c6576..67f12aa9b 100644 --- a/src/hidet/runtime/cuda/cuda.cpp +++ b/src/hidet/runtime/cuda/cuda.cpp @@ -18,6 +18,12 @@ typedef int cudaError_t; typedef cudaError_t (*cudaGetDeviceCount_t)(int* count); typedef cudaError_t (*cudaGetDevice_t)(int* device); typedef cudaError_t (*cudaSetDevice_t)(int device); +typedef cudaError_t (*cudaMalloc_t)(void **devPtr, size_t size); +typedef cudaError_t (*cudaMallocAsync_t)(void **devPtr, size_t size, cudaStream_t stream); +typedef cudaError_t (*cudaFree_t)(void *devPtr); +typedef cudaError_t (*cudaFreeAsync_t)(void *devPtr, cudaStream_t stream); +typedef cudaError_t (*cudaMemcpy_t)(void* dst, const void* src, size_t count, cudaMemcpyKind kind); +typedef cudaError_t (*cudaMemcpyAsync_t)(void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream); typedef const char* (*cudaGetErrorString_t)(cudaError_t error); static std::string library_path; @@ -25,6 +31,12 @@ static void* libcudart = nullptr; static cudaGetDeviceCount_t cudaGetDeviceCount = nullptr; static cudaGetDevice_t cudaGetDevice = nullptr; static cudaSetDevice_t cudaSetDevice = nullptr; +static cudaMalloc_t cudaMalloc = nullptr; +static cudaMallocAsync_t cudaMallocAsync = nullptr; +static cudaFree_t cudaFree = nullptr; +static cudaFreeAsync_t cudaFreeAsync = nullptr; +static cudaMemcpy_t cudaMemcpy = nullptr; +static cudaMemcpyAsync_t cudaMemcpyAsync = nullptr; static cudaGetErrorString_t cudaGetErrorString = nullptr; // load cuda runtime APIs @@ -45,6 +57,12 @@ static inline void lazy_load_cuda_runtime() { cudaGetDeviceCount = get_symbol(libcudart, "cudaGetDeviceCount"); cudaGetDevice = get_symbol(libcudart, "cudaGetDevice"); cudaSetDevice = get_symbol(libcudart, "cudaSetDevice"); + cudaMalloc = get_symbol(libcudart, "cudaMalloc"); + cudaMallocAsync = get_symbol(libcudart, "cudaMallocAsync"); + cudaFree = get_symbol(libcudart, "cudaFree"); + cudaFreeAsync = get_symbol(libcudart, "cudaFreeAsync"); + cudaMemcpy = get_symbol(libcudart, "cudaMemcpy"); + cudaMemcpyAsync = get_symbol(libcudart, "cudaMemcpyAsync"); cudaGetErrorString = get_symbol(libcudart, "cudaGetErrorString"); } } @@ -79,3 +97,37 @@ DLL void hidet_cuda_set_device(int device) { lazy_load_cuda_runtime(); CHECK_CUDA(cudaSetDevice(device)); } + +DLL void* hidet_cuda_malloc(size_t size) { + lazy_load_cuda_runtime(); + void *devPtr; + CHECK_CUDA(cudaMalloc(&devPtr, size)); + return devPtr; +} + +DLL void* hidet_cuda_malloc_async(size_t size, cudaStream_t stream) { + lazy_load_cuda_runtime(); + void *devPtr; + CHECK_CUDA(cudaMallocAsync(&devPtr, size, stream)); + return devPtr; +} + +DLL void hidet_cuda_free(void *devPtr) { + lazy_load_cuda_runtime(); + CHECK_CUDA(cudaFree(devPtr)); +} + +DLL void hidet_cuda_free_async(void *devPtr, cudaStream_t stream) { + lazy_load_cuda_runtime(); + CHECK_CUDA(cudaFreeAsync(devPtr, stream)); +} + +DLL void hidet_cuda_memcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind) { + lazy_load_cuda_runtime(); + CHECK_CUDA(cudaMemcpy(dst, src, count, kind)); +} + +DLL void hidet_cuda_memcpy_async(void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream) { + lazy_load_cuda_runtime(); + CHECK_CUDA(cudaMemcpyAsync(dst, src, count, kind, stream)); +} \ No newline at end of file diff --git a/tests/cuda/test_cublas.py b/tests/cuda/test_cublas.py index 6155f79c0..5534f2228 100644 --- a/tests/cuda/test_cublas.py +++ b/tests/cuda/test_cublas.py @@ -51,6 +51,28 @@ def test_cublas_strided_gemm(bs, m, n, k, dtype, compute_type, tol): hidet.utils.assert_close(actual=c, expected=a @ b, rtol=tol, atol=tol) +@pytest.mark.parametrize('bs, m, n, k', [[3, 4, 4, 4], [4, 128, 128, 128], [5, 123, 234, 345]]) +@pytest.mark.parametrize( + 'dtype, compute_type, tol', + [ + (hidet.float16, cublasComputeType.CUBLAS_COMPUTE_16F, 1e-2), + (hidet.float32, cublasComputeType.CUBLAS_COMPUTE_32F, 1e-5), + (hidet.float64, cublasComputeType.CUBLAS_COMPUTE_64F, 1e-8), + ], +) +def test_cublas_batched_gemm(bs, m, n, k, dtype, compute_type, tol): + a, b, c = [], [], [] + for i in range(bs): + a.append(hidet.randn((m, k), device='cuda', dtype=dtype) / math.sqrt(k)) + b.append(hidet.randn((k, n), device='cuda', dtype=dtype) / math.sqrt(k)) + c.append(hidet.empty((m, n), device='cuda', dtype=dtype)) + + hidet.cuda.cublas.batched_gemm(bs, m, n, k, a[0].dtype, b[0].dtype, c[0].dtype, a, b, c, False, False, compute_type) + + for i in range(bs): + hidet.utils.assert_close(actual=c[i], expected=a[i] @ b[i], rtol=tol, atol=tol) + + def test_cublas_library_gemm(): from hidet.lang import attrs from hidet.lang.cuda import cublas