Skip to content

Commit 8d4f43e

Browse files
authored
UCT/CUDA: Removed unsafe usage of cuCtxGetId. (#10852)
1 parent 8897ae4 commit 8d4f43e

File tree

4 files changed

+91
-176
lines changed

4 files changed

+91
-176
lines changed

src/uct/cuda/base/cuda_iface.c

Lines changed: 5 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -236,35 +236,11 @@ ucs_status_t uct_cuda_base_iface_flush(uct_iface_h tl_iface, unsigned flags,
236236
return UCS_OK;
237237
}
238238

239-
static int uct_cuda_base_is_ctx_rsc_valid(const uct_cuda_ctx_rsc_t *ctx_rsc)
239+
void uct_cuda_base_stream_destroy(CUstream *stream)
240240
{
241-
#if CUDA_VERSION >= 12000
242-
unsigned long long ctx_id;
243-
CUresult result;
244-
245-
result = uct_cuda_base_ctx_get_id(ctx_rsc->ctx, &ctx_id);
246-
if (result == CUDA_ERROR_CONTEXT_IS_DESTROYED) {
247-
return 0;
248-
} else if (result != CUDA_SUCCESS) {
249-
UCT_CUDADRV_LOG(cuCtxGetId, UCS_LOG_LEVEL_WARN, result);
250-
return 0;
251-
}
252-
253-
return ctx_id == ctx_rsc->ctx_id;
254-
#else
255-
/* Best effort check on older Cuda versions */
256-
return uct_cuda_base_is_context_valid(ctx_rsc->ctx);
257-
#endif
258-
}
259-
260-
void uct_cuda_base_stream_destroy(const uct_cuda_ctx_rsc_t *ctx_rsc,
261-
CUstream *stream)
262-
{
263-
if ((*stream == NULL) || !uct_cuda_base_is_ctx_rsc_valid(ctx_rsc)) {
264-
return;
241+
if (*stream != NULL) {
242+
(void)UCT_CUDADRV_FUNC_LOG_WARN(cuStreamDestroy(*stream));
265243
}
266-
267-
UCT_CUDADRV_FUNC_LOG_WARN(cuStreamDestroy(*stream));
268244
}
269245

270246
static void
@@ -279,12 +255,8 @@ uct_cuda_base_event_desc_init(ucs_mpool_t *mp, void *obj, void *chunk)
279255
static void uct_cuda_base_event_desc_cleanup(ucs_mpool_t *mp, void *obj)
280256
{
281257
uct_cuda_event_desc_t *event_desc = obj;
282-
uct_cuda_ctx_rsc_t *ctx_rsc = ucs_container_of(mp, uct_cuda_ctx_rsc_t,
283-
event_mp);
284258

285-
if (uct_cuda_base_is_ctx_rsc_valid(ctx_rsc)) {
286-
UCT_CUDADRV_FUNC_LOG_WARN(cuEventDestroy(event_desc->event));
287-
}
259+
(void)UCT_CUDADRV_FUNC_LOG_WARN(cuEventDestroy(event_desc->event));
288260
}
289261

290262
void uct_cuda_base_queue_desc_init(uct_cuda_queue_desc_t *qdesc)
@@ -302,7 +274,7 @@ void uct_cuda_base_queue_desc_destroy(const uct_cuda_ctx_rsc_t *ctx_rsc,
302274
ucs_queue_length(&qdesc->event_queue));
303275
}
304276

305-
uct_cuda_base_stream_destroy(ctx_rsc, &qdesc->stream);
277+
uct_cuda_base_stream_destroy(&qdesc->stream);
306278
}
307279

308280
static ucs_mpool_ops_t uct_cuda_event_desc_mpool_ops = {

src/uct/cuda/base/cuda_iface.h

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -57,39 +57,15 @@ static UCS_F_ALWAYS_INLINE int uct_cuda_base_is_context_active()
5757
}
5858

5959

60-
static UCS_F_ALWAYS_INLINE int uct_cuda_base_is_context_valid(CUcontext ctx)
61-
{
62-
unsigned version;
63-
ucs_status_t status;
64-
65-
/* Check if CUDA context is valid by running a dummy operation on it */
66-
status = UCT_CUDADRV_FUNC_LOG_DEBUG(cuCtxGetApiVersion(ctx, &version));
67-
return (status == UCS_OK);
68-
}
69-
70-
71-
static UCS_F_ALWAYS_INLINE int uct_cuda_base_context_match(CUcontext ctx1,
72-
CUcontext ctx2)
73-
{
74-
return ((ctx1 != NULL) && (ctx1 == ctx2) &&
75-
uct_cuda_base_is_context_valid(ctx1));
76-
}
77-
78-
7960
static UCS_F_ALWAYS_INLINE CUresult
8061
uct_cuda_base_ctx_get_id(CUcontext ctx, unsigned long long *ctx_id_p)
8162
{
82-
unsigned long long ctx_id = 0;
83-
8463
#if CUDA_VERSION >= 12000
85-
CUresult result = cuCtxGetId(ctx, &ctx_id);
86-
if (ucs_unlikely(result != CUDA_SUCCESS)) {
87-
return result;
88-
}
89-
#endif
90-
91-
*ctx_id_p = ctx_id;
64+
return cuCtxGetId(ctx, ctx_id_p);
65+
#else
66+
*ctx_id_p = 0;
9267
return CUDA_SUCCESS;
68+
#endif
9369
}
9470

9571

@@ -192,8 +168,7 @@ void uct_cuda_base_queue_desc_init(uct_cuda_queue_desc_t *qdesc);
192168
void uct_cuda_base_queue_desc_destroy(const uct_cuda_ctx_rsc_t *ctx_rsc,
193169
uct_cuda_queue_desc_t *qdesc);
194170

195-
void uct_cuda_base_stream_destroy(const uct_cuda_ctx_rsc_t *ctx_rsc,
196-
CUstream *stream);
171+
void uct_cuda_base_stream_destroy(CUstream *stream);
197172

198173
#if (__CUDACC_VER_MAJOR__ >= 100000)
199174
void CUDA_CB uct_cuda_base_iface_stream_cb_fxn(void *arg);

src/uct/cuda/cuda_copy/cuda_copy_iface.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ static void uct_cuda_copy_ctx_rsc_destroy(uct_iface_h tl_iface,
313313
}
314314
}
315315

316-
uct_cuda_base_stream_destroy(cuda_ctx_rsc, &ctx_rsc->short_stream);
316+
uct_cuda_base_stream_destroy(&ctx_rsc->short_stream);
317317
ucs_free(ctx_rsc);
318318
}
319319

0 commit comments

Comments
 (0)