Skip to content

Commit

Permalink
Added functions for resetting devices
Browse files Browse the repository at this point in the history
  • Loading branch information
crashr committed Apr 28, 2024
1 parent 4dba7e8 commit ea271b6
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 0 deletions.
13 changes: 13 additions & 0 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2701,6 +2701,19 @@ GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, si
CUDA_CHECK(cudaMemGetInfo(free, total));
}

// Resets the given CUDA device: makes it the current device, then calls
// cudaDeviceReset(), which destroys the device's primary context and frees
// every allocation made on it by this process.
// NOTE(review): any ggml/backend state cached for this device (buffers,
// streams, library handles) is invalidated by the reset — callers must
// re-initialize before using the device again; confirm against the backend's
// cached per-device state.
GGML_CALL void ggml_backend_cuda_reset_device(int device) {
    ggml_cuda_set_device(device);
    CUDA_CHECK(cudaDeviceReset());
}

// Resets every visible CUDA device (destroys each device's primary context
// and frees all of its allocations via cudaDeviceReset()).
// NOTE(review): all cached per-device backend state is invalid after this call.
GGML_CALL void ggml_backend_cuda_reset_all_devices(void) {
    // ggml_backend_cuda_get_device_count() returns int; using int here avoids
    // the signed/unsigned mismatch of storing it in a size_t and then
    // narrowing the index back to int for the per-device calls.
    const int device_count = ggml_backend_cuda_get_device_count();
    for (int i = 0; i < device_count; i++) {
        // delegate to the single-device entry point so the two stay in sync
        ggml_backend_cuda_reset_device(i);
    }
}

GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
return false;
Expand Down
3 changes: 3 additions & 0 deletions ggml-cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ GGML_API GGML_CALL int ggml_backend_cuda_get_device_count(void);
GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);

// Reset one CUDA device / all CUDA devices via cudaDeviceReset().
// NOTE(review): a reset destroys the device's primary context and frees all
// of its allocations — existing backend buffers on that device become invalid
// and must not be used afterwards.
GGML_API GGML_CALL void ggml_backend_cuda_reset_device(int device);
GGML_API GGML_CALL void ggml_backend_cuda_reset_all_devices(void);

GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer);

Expand Down
24 changes: 24 additions & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1748,6 +1748,30 @@ static size_t llama_get_device_memory(int device) {
#endif
}

// Resets the given device of the active GPU backend. Currently implemented
// for CUDA only; SYCL and Vulkan are TODO and the call is a no-op for them.
void llama_reset_device(int device) {
#if defined(GGML_USE_CUDA)
    ggml_backend_cuda_reset_device(device);
#else
    // suppress the unused-parameter warning on builds without CUDA, where
    // none of the backend branches below consume `device`
    (void) device;
#endif
#if defined(GGML_USE_SYCL)
    // TODO
#endif
#if defined(GGML_USE_VULKAN)
    // TODO
#endif
}

// Resets every device of the active GPU backend.
// Only the CUDA backend implements this today; for other backends the call
// compiles to a no-op.
void llama_reset_all_devices(void) {
#if defined(GGML_USE_CUDA)
    ggml_backend_cuda_reset_all_devices();
#endif
#if defined(GGML_USE_SYCL) || defined(GGML_USE_VULKAN)
    // TODO: device reset is not implemented for these backends yet
#endif
}

//
// globals
//
Expand Down
3 changes: 3 additions & 0 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,9 @@ extern "C" {
}
#endif

// Reset one / all GPU devices (currently CUDA only; no-op for other backends).
// NOTE(review): per the surrounding context these declarations sit AFTER the
// `}` / `#endif` that close the extern "C" block above, so C++ consumers get
// mangled symbols with no C linkage — they likely belong inside extern "C"
// with the rest of the public API; confirm placement.
// NOTE(review): unlike the neighboring public declarations these lack the
// LLAMA_API export macro, so they will not be exported from shared builds.
void llama_reset_device(int device);
void llama_reset_all_devices(void);

// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
#ifdef LLAMA_API_INTERNAL

Expand Down

0 comments on commit ea271b6

Please sign in to comment.