Fix 2 cuda memory leaks in ggml-cuda.cu #5576

Closed
wants to merge 2 commits
10 changes: 10 additions & 0 deletions examples/server/server.cpp
@@ -309,6 +309,10 @@ struct llama_client_slot
}
};

#ifdef GGML_USE_CUBLAS
extern "C" GGML_CALL void ggml_free_cublas(void);
#endif

struct llama_server_context
{
llama_model *model = nullptr;
@@ -355,6 +359,10 @@ struct llama_server_context
llama_free_model(model);
model = nullptr;
}
#ifdef GGML_USE_CUBLAS
ggml_free_cublas();
#endif

}

bool load_model(const gpt_params &params_)
@@ -3217,6 +3225,7 @@ int main(int argc, char **argv)
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
sigaction(SIGUSR1, &sigint_action, NULL);
Comment from the PR author on the SIGUSR1 line:
This can be removed before we merge. I was having trouble getting some of my analysis tools to work with SIGINT, so SIGUSR1 made it easier to iterate while testing the leak fix.
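Since SIGUSR1 is registered here with the same `sigaction` as SIGINT, sending `kill -USR1 <server pid>` from a second shell drives the identical graceful-shutdown path, which is what reaches the `llama_server_context` destructor and the new `ggml_free_cublas()` call under test.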

#elif defined (_WIN32)
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
@@ -3230,3 +3239,4 @@ int main(int argc, char **argv)
llama_backend_free();
return 0;
}

28 changes: 22 additions & 6 deletions ggml-cuda.cu
@@ -39,6 +39,7 @@
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
@@ -7847,7 +7848,7 @@ static void ggml_cuda_pool_free_leg(int device, void * ptr, size_t size) {
g_cuda_pool_size[device] -= size;
}

-#if !defined(GGML_USE_HIPBLAS)
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
// pool with virtual memory
static CUdeviceptr g_cuda_pool_addr[GGML_CUDA_MAX_DEVICES] = {0};
static size_t g_cuda_pool_used[GGML_CUDA_MAX_DEVICES] = {0};
@@ -7948,7 +7949,7 @@ static void ggml_cuda_pool_free(int device, void * ptr, size_t size) {
#else
#define ggml_cuda_pool_malloc ggml_cuda_pool_malloc_leg
#define ggml_cuda_pool_free ggml_cuda_pool_free_leg
-#endif // !defined(GGML_USE_HIPBLAS)
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))

template<typename T>
struct cuda_pool_alloc {
@@ -7991,10 +7992,11 @@ GGML_CALL bool ggml_cublas_loaded(void) {
return g_cublas_loaded;
}

static bool g_cublas_initialized = false;

GGML_CALL void ggml_init_cublas() {
-static bool initialized = false;

-if (!initialized) {
+if (!g_cublas_initialized) {

#ifdef __HIP_PLATFORM_AMD__
// Workaround for a rocBLAS bug when using multiple graphics cards:
@@ -8004,7 +8006,7 @@ GGML_CALL void ggml_init_cublas() {
#endif

if (cudaGetDeviceCount(&g_device_count) != cudaSuccess) {
-initialized = true;
+g_cublas_initialized = true;
g_cublas_loaded = false;
fprintf(stderr, "%s: no " GGML_CUDA_NAME " devices found, " GGML_CUDA_NAME " will be disabled\n", __func__);
return;
@@ -8075,7 +8077,7 @@ GGML_CALL void ggml_init_cublas() {
// configure logging to stdout
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));

-initialized = true;
+g_cublas_initialized = true;
g_cublas_loaded = true;
}
}
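The switch from a function-local `static bool initialized` to the file-scope `g_cublas_initialized` is what makes the new teardown possible: a local static is invisible outside its function, so nothing could ever reset it, and re-initialization after a free was impossible. A minimal sketch of the pattern, with hypothetical names:

```cpp
// Before: the guard lives inside the function; no teardown can reset it,
// so the resources it protects can only be created once per process.
void init_once_local(void) {
    static bool initialized = false;
    if (!initialized) {
        // ... create handles, reserve pools ...
        initialized = true;
    }
}

// After: the guard is file-scope, so a matching free can clear it.
static bool g_initialized = false;

void init_once_global(void) {
    if (!g_initialized) {
        // ... create handles, reserve pools ...
        g_initialized = true;
    }
}

void free_everything(void) {
    // ... destroy handles, unmap pools ...
    g_initialized = false;   // a later init_once_global() re-initializes
}
```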
@@ -11604,3 +11606,17 @@ GGML_CALL int ggml_backend_cuda_reg_devices() {
}
return device_count;
}

extern "C" GGML_CALL void ggml_free_cublas(void);
GGML_CALL void ggml_free_cublas(void) {
for (int id = 0; id < g_device_count; ++id) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
CU_CHECK(cuMemUnmap(g_cuda_pool_addr[id], g_cuda_pool_size[id]));
g_cuda_pool_size[id] = 0;
g_cuda_pool_addr[id] = 0;
#endif
CUBLAS_CHECK(cublasDestroy(g_cublas_handles[id]));
g_cublas_handles[id] = nullptr;
}
g_cublas_initialized = false;
}
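Per device, the teardown mirrors `ggml_init_cublas()`: unmap the virtual-memory pool (skipped on HIP/AMD builds, which use the legacy pool), destroy the cuBLAS handle, and finally clear `g_cublas_initialized`. Clearing the flag is what permits the round trip sketched below (an assumed usage, not code from the PR; presumes a CUDA build):

```cpp
#include "ggml-cuda.h"

// Hypothetical round trip enabled by resetting g_cublas_initialized.
void reinit_round_trip(void) {
    ggml_init_cublas();   // creates per-device cuBLAS handles
    ggml_free_cublas();   // unmaps pool memory, destroys handles, clears flag
    ggml_init_cublas();   // guard no longer short-circuits; setup runs again
}
```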
3 changes: 3 additions & 0 deletions ggml-cuda.h
@@ -20,6 +20,9 @@ extern "C" {
// Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`.
GGML_API GGML_CALL void ggml_init_cublas(void);

// Release CUDA resources
GGML_API GGML_CALL void ggml_free_cublas(void);

// Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`.
GGML_API GGML_CALL bool ggml_cublas_loaded(void);

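A short usage sketch against this header (call order assumed, not taken from the PR): since `ggml_init_cublas()` always succeeds, callers gate GPU work and the new teardown on `ggml_cublas_loaded()`.

```cpp
#include "ggml-cuda.h"

int main(void) {
    ggml_init_cublas();          // always succeeds; probes for CUDA devices
    if (ggml_cublas_loaded()) {
        // ... offload work to CUDA ...
        ggml_free_cublas();      // new in this PR: release CUDA resources
    }
    return 0;
}
```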