Skip to content

Commit

Permalink
Use cudaMallocAsync and cudaMemcpyAsync for cublas_batched_gemm
Browse files Browse the repository at this point in the history
  • Loading branch information
Yudi Sun committed Feb 6, 2024
1 parent 3d58fbd commit 69abe16
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 9 deletions.
4 changes: 3 additions & 1 deletion include/hidet/runtime/cuda/cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ typedef enum {
DLL int hidet_cuda_device_count();
DLL int hidet_cuda_get_device();
DLL void hidet_cuda_set_device(int device);
DLL void hidet_cuda_malloc(void **devPtr, size_t size);
DLL void* hidet_cuda_malloc(size_t size);
DLL void* hidet_cuda_malloc_async(size_t size, void *stream);
DLL void hidet_cuda_free(void *devPtr);
DLL void hidet_cuda_memcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind);
DLL void hidet_cuda_memcpy_async(void* dst, const void* src, size_t count, cudaMemcpyKind kind, void *stream);

14 changes: 8 additions & 6 deletions src/hidet/runtime/cuda/cublas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ DLL void hidet_cublas_batched_gemm(
static void **ptr_a_device, **ptr_b_device, **ptr_c_device;
static int cur_device_ptr_size; // Size of device memory currently allocated for each of the three a,b,c arrays.

void *cur_stream = get_cuda_stream();
// Allocate device memory
// first use synchronous versions of malloc and memcpy, later switch to async versions
if (cur_device_ptr_size != 0 && b > cur_device_ptr_size) {
Expand All @@ -356,16 +357,17 @@ DLL void hidet_cublas_batched_gemm(
hidet_cuda_free((void *)ptr_c_device);
}
if (ptr_a_device == NULL || b > cur_device_ptr_size) {
hidet_cuda_malloc((void **) &ptr_a_device, b * sizeof(void*));
hidet_cuda_malloc((void **) &ptr_b_device, b * sizeof(void*));
hidet_cuda_malloc((void **) &ptr_c_device, b * sizeof(void*));
ptr_a_device = (void **) hidet_cuda_malloc_async(b * sizeof(void*), cur_stream);
ptr_b_device = (void **) hidet_cuda_malloc_async(b * sizeof(void*), cur_stream);
ptr_c_device = (void **) hidet_cuda_malloc_async(b * sizeof(void*), cur_stream);

cur_device_ptr_size = b;
}

// Copy input arrays (A and B) from host to device
hidet_cuda_memcpy((void *)ptr_a_device, (void *)ptr_a, b * sizeof(void*), cudaMemcpyHostToDevice);
hidet_cuda_memcpy((void *)ptr_b_device, (void *)ptr_b, b * sizeof(void*), cudaMemcpyHostToDevice);
hidet_cuda_memcpy((void *)ptr_c_device, (void *)ptr_c, b * sizeof(void*), cudaMemcpyHostToDevice);
hidet_cuda_memcpy_async((void *)ptr_a_device, (void *)ptr_a, b * sizeof(void*), cudaMemcpyHostToDevice, cur_stream);
hidet_cuda_memcpy_async((void *)ptr_b_device, (void *)ptr_b, b * sizeof(void*), cudaMemcpyHostToDevice, cur_stream);
hidet_cuda_memcpy_async((void *)ptr_c_device, (void *)ptr_c, b * sizeof(void*), cudaMemcpyHostToDevice, cur_stream);

CHECK_CUBLAS(cublasGemmBatchedEx(
CublasContext::current_handle(),
Expand Down
24 changes: 22 additions & 2 deletions src/hidet/runtime/cuda/cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ typedef cudaError_t (*cudaGetDeviceCount_t)(int* count);
typedef cudaError_t (*cudaGetDevice_t)(int* device);
typedef cudaError_t (*cudaSetDevice_t)(int device);
typedef cudaError_t (*cudaMalloc_t)(void **devPtr, size_t size);
typedef cudaError_t (*cudaMallocAsync_t)(void **devPtr, size_t size, void *stream);
typedef cudaError_t (*cudaFree_t)(void *devPtr);
typedef cudaError_t (*cudaMemcpy_t)(void* dst, const void* src, size_t count, cudaMemcpyKind kind);
typedef cudaError_t (*cudaMemcpyAsync_t)(void* dst, const void* src, size_t count, cudaMemcpyKind kind, void *stream);
typedef const char* (*cudaGetErrorString_t)(cudaError_t error);

static std::string library_path;
Expand All @@ -29,8 +31,10 @@ static cudaGetDeviceCount_t cudaGetDeviceCount = nullptr;
static cudaGetDevice_t cudaGetDevice = nullptr;
static cudaSetDevice_t cudaSetDevice = nullptr;
static cudaMalloc_t cudaMalloc = nullptr;
static cudaMallocAsync_t cudaMallocAsync = nullptr;
static cudaFree_t cudaFree = nullptr;
static cudaMemcpy_t cudaMemcpy = nullptr;
static cudaMemcpyAsync_t cudaMemcpyAsync = nullptr;
static cudaGetErrorString_t cudaGetErrorString = nullptr;

// load cuda runtime APIs
Expand All @@ -52,8 +56,10 @@ static inline void lazy_load_cuda_runtime() {
cudaGetDevice = get_symbol<cudaGetDevice_t>(libcudart, "cudaGetDevice");
cudaSetDevice = get_symbol<cudaSetDevice_t>(libcudart, "cudaSetDevice");
cudaMalloc = get_symbol<cudaMalloc_t>(libcudart, "cudaMalloc");
cudaMallocAsync = get_symbol<cudaMallocAsync_t>(libcudart, "cudaMallocAsync");
cudaFree = get_symbol<cudaFree_t>(libcudart, "cudaFree");
cudaMemcpy = get_symbol<cudaMemcpy_t>(libcudart, "cudaMemcpy");
cudaMemcpyAsync = get_symbol<cudaMemcpyAsync_t>(libcudart, "cudaMemcpyAsync");
cudaGetErrorString = get_symbol<cudaGetErrorString_t>(libcudart, "cudaGetErrorString");
}
}
Expand Down Expand Up @@ -89,9 +95,18 @@ DLL void hidet_cuda_set_device(int device) {
CHECK_CUDA(cudaSetDevice(device));
}

DLL void hidet_cuda_malloc(void **devPtr, size_t size) {
DLL void* hidet_cuda_malloc(size_t size) {
lazy_load_cuda_runtime();
CHECK_CUDA(cudaMalloc(devPtr, size));
void* devPtr = malloc(sizeof(void*));
CHECK_CUDA(cudaMalloc(&devPtr, size));
return devPtr;
}

DLL void* hidet_cuda_malloc_async(size_t size, void *stream) {
lazy_load_cuda_runtime();
void* devPtr = malloc(sizeof(void*));
CHECK_CUDA(cudaMallocAsync(&devPtr, size, stream));
return devPtr;
}

DLL void hidet_cuda_free(void *devPtr) {
Expand All @@ -103,3 +118,8 @@ DLL void hidet_cuda_memcpy(void* dst, const void* src, size_t count, cudaMemcpyK
lazy_load_cuda_runtime();
CHECK_CUDA(cudaMemcpy(dst, src, count, kind));
}

DLL void hidet_cuda_memcpy_async(void* dst, const void* src, size_t count, cudaMemcpyKind kind, void *stream) {
lazy_load_cuda_runtime();
CHECK_CUDA(cudaMemcpyAsync(dst, src, count, kind, stream));
}

0 comments on commit 69abe16

Please sign in to comment.