Skip to content

Commit 2fb4921

Browse files
committed
UCM/CUDA: Added hook for cuLibraryGetGlobal.
1 parent 897daff commit 2fb4921

File tree

2 files changed

+38
-23
lines changed

2 files changed

+38
-23
lines changed

src/ucm/cuda/cudamem.c

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,29 @@
6161
return ret; \
6262
}
6363

64+
#define UCM_CUDA_GET_GLOBAL_FUNC(_name, _obj_type) \
65+
CUresult ucm_##_name(CUdeviceptr *dptr, size_t *bytes, _obj_type obj, \
66+
const char *name) \
67+
{ \
68+
CUresult ret; \
69+
size_t size; \
70+
size_t *size_ptr; \
71+
if (dptr == NULL) { \
72+
/* Device pointer is not returned */ \
73+
return ucm_orig_##_name(dptr, bytes, obj, name); \
74+
} \
75+
size_ptr = (bytes == NULL) ? &size : bytes; \
76+
ucm_event_enter(); \
77+
ret = ucm_orig_##_name(dptr, size_ptr, obj, name); \
78+
if (ret == CUDA_SUCCESS) { \
79+
ucm_trace("%s(size_ptr=%p, obj=%p, name=%s) returned dptr=%p", \
80+
__func__, bytes, obj, name, (void*)(*dptr)); \
81+
ucm_cuda_dispatch_mem_alloc(*dptr, *size_ptr); \
82+
} \
83+
ucm_event_leave(); \
84+
return ret; \
85+
}
86+
6487
#define UCM_CUDA_FUNC_ENTRY(_func) \
6588
{ \
6689
{#_func, ucm_override_##_func}, (void**)&ucm_orig_##_func \
@@ -95,6 +118,10 @@ UCM_DEFINE_REPLACE_DLSYM_PTR_FUNC(cuMemAllocAsync, CUresult, -1, CUdeviceptr*,
95118
UCM_DEFINE_REPLACE_DLSYM_PTR_FUNC(cuMemAllocFromPoolAsync, CUresult, -1,
96119
CUdeviceptr*, size_t, CUmemoryPool, CUstream)
97120
#endif
121+
#if CUDART_VERSION >= 12000
122+
UCM_DEFINE_REPLACE_DLSYM_PTR_FUNC(cuLibraryGetGlobal, CUresult, -1,
123+
CUdeviceptr*, size_t*, CUlibrary, const char*)
124+
#endif
98125
UCM_DEFINE_REPLACE_DLSYM_PTR_FUNC(cuMemFree, CUresult, -1, CUdeviceptr)
99126
UCM_DEFINE_REPLACE_DLSYM_PTR_FUNC(cuMemFree_v2, CUresult, -1, CUdeviceptr)
100127
UCM_DEFINE_REPLACE_DLSYM_PTR_FUNC(cuMemFreeHost, CUresult, -1, void*)
@@ -210,29 +237,10 @@ UCM_CUDA_FREE_FUNC(cuMemFreeAsync, UCS_MEMORY_TYPE_CUDA, CUresult, arg0, 0,
210237
"ptr=0x%llx, stream=%p", CUdeviceptr, CUstream)
211238
#endif
212239

213-
CUresult ucm_cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t *bytes,
214-
CUmodule hmod, const char *name)
215-
{
216-
CUresult ret;
217-
size_t size;
218-
size_t *size_ptr;
219-
220-
if (dptr == NULL) {
221-
/* Device pointer is not returned */
222-
return ucm_orig_cuModuleGetGlobal_v2(dptr, bytes, hmod, name);
223-
}
224-
225-
size_ptr = (bytes == NULL) ? &size : bytes;
226-
ucm_event_enter();
227-
ret = ucm_orig_cuModuleGetGlobal_v2(dptr, size_ptr, hmod, name);
228-
if (ret == CUDA_SUCCESS) {
229-
ucm_trace("%s(size_ptr=%p, hmod=%p, name=%s) returned dptr=%p",
230-
__func__, bytes, hmod, name, (void*)(*dptr));
231-
ucm_cuda_dispatch_mem_alloc(*dptr, *size_ptr);
232-
}
233-
ucm_event_leave();
234-
return ret;
235-
}
240+
UCM_CUDA_GET_GLOBAL_FUNC(cuModuleGetGlobal_v2, CUmodule)
241+
#if CUDART_VERSION >= 12000
242+
UCM_CUDA_GET_GLOBAL_FUNC(cuLibraryGetGlobal, CUlibrary)
243+
#endif
236244

237245
static ucm_cuda_func_t ucm_cuda_driver_funcs[] = {
238246
UCM_CUDA_FUNC_ENTRY(cuMemAlloc),
@@ -245,6 +253,9 @@ static ucm_cuda_func_t ucm_cuda_driver_funcs[] = {
245253
#if CUDA_VERSION >= 11020
246254
UCM_CUDA_FUNC_ENTRY(cuMemAllocAsync),
247255
UCM_CUDA_FUNC_ENTRY(cuMemAllocFromPoolAsync),
256+
#endif
257+
#if CUDART_VERSION >= 12000
258+
UCM_CUDA_FUNC_ENTRY(cuLibraryGetGlobal),
248259
#endif
249260
UCM_CUDA_FUNC_ENTRY(cuMemFree),
250261
UCM_CUDA_FUNC_ENTRY(cuMemFree_v2),

src/ucm/cuda/cudamem.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ CUresult ucm_cuMemAllocAsync(CUdeviceptr *dptr, size_t size, CUstream hStream);
3030
CUresult ucm_cuMemAllocFromPoolAsync(CUdeviceptr *dptr, size_t size,
3131
CUmemoryPool pool, CUstream hStream);
3232
#endif
33+
#if CUDART_VERSION >= 12000
34+
CUresult ucm_cuLibraryGetGlobal(CUdeviceptr *dptr, size_t *bytes,
35+
CUlibrary library, const char *name);
36+
#endif
3337
CUresult ucm_cuMemFree(CUdeviceptr dptr);
3438
CUresult ucm_cuMemFree_v2(CUdeviceptr dptr);
3539
CUresult ucm_cuMemFreeHost(void *p);

0 commit comments

Comments
 (0)