Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 2 cuda memory leaks in ggml-cuda.cu #5576

Closed
wants to merge 2 commits into from

Conversation

dhiltgen
Copy link
Contributor

@dhiltgen dhiltgen commented Feb 19, 2024

In our downstream usage of the server, we noticed it wouldn't fully unload the GPU when idle. Using the cuda memory leak detection tool I was able to find where the leaks were located.

compute-sanitizer --tool memcheck --leak-check full ./bin/server ...
========= Leaked 8,388,608 bytes at 0x7faf2c000000
=========     Saved host backtrace up to driver entry point at allocation time
=========     Host Frame: [0x2db39f]
=========                in /lib/x86_64-linux-gnu/libcuda.so.1
=========     Host Frame: [0xc33c3e]
=========                in /usr/local/cuda/lib64/libcublas.so.12
=========     Host Frame: [0xc00373]
=========                in /usr/local/cuda/lib64/libcublas.so.12
=========     Host Frame: [0xc422f5]
=========                in /usr/local/cuda/lib64/libcublas.so.12
=========     Host Frame: [0x8aa9bd]
=========                in /usr/local/cuda/lib64/libcublas.so.12
=========     Host Frame:cublasCreate_v2 [0x7f66f1]
=========                in /usr/local/cuda/lib64/libcublas.so.12
=========     Host Frame:ggml_init_cublas.part.0 in /home/daniel/code/llama.cpp/ggml-cuda.cu:8008 [0x199ee2]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:ggml_init in /home/daniel/code/llama.cpp/ggml.c:2428 [0x159070]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:llama_backend_init in /home/daniel/code/llama.cpp/llama.cpp:11191 [0xf1f8e]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:main in /home/daniel/code/llama.cpp/examples/server/server.cpp:2546 [0x25093]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:__libc_start_call_main in ../sysdeps/nptl/libc_start_call_main.h:58 [0x29d90]
=========                in /lib/x86_64-linux-gnu/libc.so.6
=========     Host Frame:__libc_start_main in ../csu/libc-start.c:379 [0x29e40]
=========                in /lib/x86_64-linux-gnu/libc.so.6
=========     Host Frame:_start [0x2e345]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
========= 
========= Leaked 1,024 bytes at 0x7faf2dc00000
=========     Saved host backtrace up to driver entry point at allocation time
=========     Host Frame: [0x2db39f]
=========                in /lib/x86_64-linux-gnu/libcuda.so.1
=========     Host Frame: [0xc33c3e]
=========                in /usr/local/cuda/lib64/libcublas.so.12
=========     Host Frame: [0xc00373]
=========                in /usr/local/cuda/lib64/libcublas.so.12
=========     Host Frame: [0xc422f5]
=========                in /usr/local/cuda/lib64/libcublas.so.12
=========     Host Frame: [0x8aa9bd]
=========                in /usr/local/cuda/lib64/libcublas.so.12
=========     Host Frame: [0x8aa20b]
=========                in /usr/local/cuda/lib64/libcublas.so.12
=========     Host Frame:cublasCreate_v2 [0x7f66f1]
=========                in /usr/local/cuda/lib64/libcublas.so.12
=========     Host Frame:ggml_init_cublas.part.0 in /home/daniel/code/llama.cpp/ggml-cuda.cu:8008 [0x199ee2]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:ggml_init in /home/daniel/code/llama.cpp/ggml.c:2428 [0x159070]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:llama_backend_init in /home/daniel/code/llama.cpp/llama.cpp:11191 [0xf1f8e]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:main in /home/daniel/code/llama.cpp/examples/server/server.cpp:2546 [0x25093]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:__libc_start_call_main in ../sysdeps/nptl/libc_start_call_main.h:58 [0x29d90]
=========                in /lib/x86_64-linux-gnu/libc.so.6
=========     Host Frame:__libc_start_main in ../csu/libc-start.c:379 [0x29e40]
=========                in /lib/x86_64-linux-gnu/libc.so.6
=========     Host Frame:_start [0x2e345]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
========= 
========= Leaked 131,072 bytes at 0x7faf2dc00400
=========     Saved host backtrace up to driver entry point at allocation time
=========     Host Frame: [0x2db39f]
=========                in /lib/x86_64-linux-gnu/libcuda.so.1
=========     Host Frame: [0xc33c3e]
=========                in /usr/local/cuda/lib64/libcublas.so.12
=========     Host Frame: [0xc00373]
=========                in /usr/local/cuda/lib64/libcublas.so.12
=========     Host Frame: [0xc422f5]
=========                in /usr/local/cuda/lib64/libcublas.so.12
=========     Host Frame: [0x8aa9bd]
=========                in /usr/local/cuda/lib64/libcublas.so.12
=========     Host Frame: [0x8aa22e]
=========                in /usr/local/cuda/lib64/libcublas.so.12
=========     Host Frame:cublasCreate_v2 [0x7f66f1]
=========                in /usr/local/cuda/lib64/libcublas.so.12
=========     Host Frame:ggml_init_cublas.part.0 in /home/daniel/code/llama.cpp/ggml-cuda.cu:8008 [0x199ee2]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:ggml_init in /home/daniel/code/llama.cpp/ggml.c:2428 [0x159070]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:llama_backend_init in /home/daniel/code/llama.cpp/llama.cpp:11191 [0xf1f8e]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:main in /home/daniel/code/llama.cpp/examples/server/server.cpp:2546 [0x25093]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:__libc_start_call_main in ../sysdeps/nptl/libc_start_call_main.h:58 [0x29d90]
=========                in /lib/x86_64-linux-gnu/libc.so.6
=========     Host Frame:__libc_start_main in ../csu/libc-start.c:379 [0x29e40]
=========                in /lib/x86_64-linux-gnu/libc.so.6
=========     Host Frame:_start [0x2e345]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
========= 
========= Leaked 2,097,152 bytes at 0x4ea000000
=========     Saved host backtrace up to driver entry point at allocation time
=========     Host Frame: [0x2e90ad]
=========                in /lib/x86_64-linux-gnu/libcuda.so.1
=========     Host Frame:ggml_cuda_pool_malloc_vmm(int, unsigned long, unsigned long*) in /home/daniel/code/llama.cpp/ggml-cuda.cu:7834 [0x1b2e12]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:ggml_cuda_op_mul_mat(ggml_tensor const*, ggml_tensor const*, ggml_tensor*, void (*)(ggml_tensor const*, ggml_tensor const*, ggml_tensor*, char const*, float const*, char const*, float*, long, long, long, long, CUstream_st*), bool) in /home/daniel/code/llama.cpp/ggml-cuda.cu:9398 [0x1b4004]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:ggml_cuda_compute_forward.part.0 in /home/daniel/code/llama.cpp/ggml-cuda.cu:10632 [0x19a3f5]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) in /home/daniel/code/llama.cpp/ggml-cuda.cu:11323 [0x19a862]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:ggml_backend_sched_graph_compute in /home/daniel/code/llama.cpp/ggml-backend.c:1583 [0x179330]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:llama_decode_internal(llama_context&, llama_batch) in /home/daniel/code/llama.cpp/llama.cpp:7722 [0xf8eed]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:llama_decode in /home/daniel/code/llama.cpp/llama.cpp:12287 [0xf9aa3]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:llama_init_from_gpt_params(gpt_params&) in /home/daniel/code/llama.cpp/common/common.cpp:1361 [0xd8e6d]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:llama_server_context::load_model(gpt_params const&) in /home/daniel/code/llama.cpp/examples/server/server.cpp:383 [0x8024d]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:main in /home/daniel/code/llama.cpp/examples/server/server.cpp:2669 [0x262d4]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
=========     Host Frame:__libc_start_call_main in ../sysdeps/nptl/libc_start_call_main.h:58 [0x29d90]
=========                in /lib/x86_64-linux-gnu/libc.so.6
=========     Host Frame:__libc_start_main in ../csu/libc-start.c:379 [0x29e40]
=========                in /lib/x86_64-linux-gnu/libc.so.6
=========     Host Frame:_start [0x2e345]
=========                in /home/daniel/code/llama.cpp/build/./bin/server
========= 
========= LEAK SUMMARY: 10617856 bytes leaked in 4 allocations
========= ERROR SUMMARY: 4 errors

I'm not sure where the best place is to wire this up, so I'm open to suggestions from reviewers, however as coded it does pass the memcheck tool without leaks, and when we integrate this into our usage of server.cpp we do see the GPU unload when it goes idle.

This fixes 2 memory leaks in ggml-cuda.cu

  • cuMemUnmap called now for the pool allocation
  • cublasDestroy called to release cublas handles

@@ -3217,6 +3225,7 @@ int main(int argc, char **argv)
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
sigaction(SIGUSR1, &sigint_action, NULL);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be removed before we merge. I was having trouble with some of my analysis tools with INT, so USR1 made it easier to iterate while testing the leak fix.

This fixes 2 memory leaks in ggml-cuda.cu
- cuMemUnmap called now for the pool allocation
- cublasDestroy called to release cublas handles
Copy link
Collaborator

@cebtenzzre cebtenzzre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nits:

  • Extra blank line added at end of server.cpp, but missing newline at end of ggml-cuda.cu
  • ggml_free_cublas is declared three different times?

With the new ggml-backend interface, something like ggml_free_cublas isn't really supposed to be part of the API, but neither is ggml_init_cublas...

ggml-cuda.cu Outdated
extern "C" GGML_CALL void ggml_free_cublas(void);
GGML_CALL void ggml_free_cublas(void) {
for (int id = 0; id < g_device_count; ++id) {
#if !defined(GGML_USE_HIPBLAS)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be

Suggested change
#if !defined(GGML_USE_HIPBLAS)
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))

but then the condition for vmm in ggml-cuda.cu should also be changed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

@slaren
Copy link
Collaborator

slaren commented Feb 19, 2024

The correct way to fix would be:

  • Move the automatic offloading logic from ggml.c to ggml_backend_sched
  • Make the pool private to the ggml_backend instance
  • Free the pool when freeing the ggml_backend instance

I don't really see the point of doing this at this point and in this way, this memory is intentionally never freed, freeing it just before the application ends achieves nothing, the memory would be freed by the OS anyway.

@dhiltgen
Copy link
Contributor Author

dhiltgen commented Feb 19, 2024

I don't really see the point of doing this at this point and in this way, this memory is intentionally never freed, freeing it just before the application ends achieves nothing, the memory would be freed by the OS anyway.

We use this code in a downstream project as a library, in a long running process, and we need the library not to leak when we unload models and the system is idle so that the GPU resources get freed up and the GPU is allowed to return to a low power state.

My intent on wiring this into server.cpp is simply to demonstrate the leak and the fix so folks in this community can repro. Depending on how this ultimately gets wired into the shutdown logic, it should apply to any/all uses I would think.

@dhiltgen
Copy link
Contributor Author

With the new ggml-backend interface, something like ggml_free_cublas isn't really supposed to be part of the API, but neither is ggml_init_cublas...

I agree, I don't think I've wired ggml_free_cublas correctly, but I'm not sure what the right pattern is for this so I'm hoping maintainers can direct me. (perhaps automatically wired up in an existing shutdown/free routine, or maybe as a new API in ggml.h?)

@zsogitbe zsogitbe mentioned this pull request Mar 5, 2024
zsogitbe added a commit to zsogitbe/llama.cpp that referenced this pull request Mar 6, 2024
I have corrected the PR ggerganov#5576 which causes crash and streamlined the code.
Unfortunately, this does not free all occupied GPU memory yet (only 15% of it). We still need to find some objects which are not freed after releasing GPU memory.
@zsogitbe zsogitbe mentioned this pull request Mar 6, 2024
@zsogitbe
Copy link

zsogitbe commented Mar 6, 2024

This PR causes crash in the pooling code (cuMemUnmap). I have corrected the code and streamlined it for llama.cpp: #5898

@dhiltgen
Copy link
Contributor Author

Closing in favor of #5898

@dhiltgen dhiltgen closed this Mar 12, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants