Skip to content

Commit f60a22c

Browse files
committed
UCT/GDA/MLX5: Check UAR is supported when querying resources
1 parent ad46593 commit f60a22c

File tree

5 files changed

+90
-25
lines changed

5 files changed

+90
-25
lines changed

src/ucp/wireup/select.c

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2446,7 +2446,7 @@ ucp_wireup_add_device_lanes(const ucp_wireup_select_params_t *select_params,
24462446
ucp_wireup_select_flags_t iface_rma_flags, peer_rma_flags;
24472447
ucp_wireup_select_bw_info_t bw_info = {};
24482448
ucp_tl_bitmap_t mem_type_tl_bitmap;
2449-
ucp_tl_bitmap_t tl_bitmap;
2449+
int found_lane;
24502450

24512451
if (!context->config.ext.proto_enable ||
24522452
(ep_init_flags &
@@ -2478,15 +2478,15 @@ ucp_wireup_add_device_lanes(const ucp_wireup_select_params_t *select_params,
24782478
*/
24792479
bw_info.max_lanes = ucp_wireup_bw_max_lanes(select_params);
24802480

2481-
UCS_STATIC_BITMAP_RESET_ALL(&tl_bitmap);
24822481
ucp_wireup_memaccess_bitmap(context, UCS_MEMORY_TYPE_CUDA,
24832482
&mem_type_tl_bitmap);
2484-
(void)ucp_wireup_add_bw_lanes(select_params, &bw_info,
2485-
UCP_TL_BITMAP_AND_NOT(mem_type_tl_bitmap,
2486-
tl_bitmap),
2487-
UCP_NULL_LANE, select_ctx, 0);
2488-
2489-
UCS_STATIC_BITMAP_OR_INPLACE(&tl_bitmap, mem_type_tl_bitmap);
2483+
found_lane = ucp_wireup_add_bw_lanes(select_params, &bw_info,
2484+
mem_type_tl_bitmap, UCP_NULL_LANE,
2485+
select_ctx, 0);
2486+
if (!found_lane) {
2487+
ucs_error("could not find device lanes");
2488+
return UCS_ERR_UNREACHABLE;
2489+
}
24902490

24912491
return UCS_OK;
24922492
}

src/uct/ib/mlx5/gdaki/gdaki.c

Lines changed: 74 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ static UCS_CLASS_INIT_FUNC(uct_rc_gdaki_ep_t, const uct_ep_params_t *params)
104104
if (self->umem == NULL) {
105105
uct_ib_check_memlock_limit_msg(md->super.dev.ibv_context,
106106
UCS_LOG_LEVEL_ERROR,
107-
"mlx5dv_devx_umem_reg(size=%zu)",
108-
dev_ep_size);
107+
"mlx5dv_devx_umem_reg(ptr=%p size=%zu)",
108+
self->ep_gpu, dev_ep_size);
109109
status = UCS_ERR_NO_MEMORY;
110110
goto err_mem;
111111
}
@@ -567,17 +567,60 @@ static UCS_CLASS_DEFINE_NEW_FUNC(uct_rc_gdaki_iface_t, uct_iface_t, uct_md_h,
567567
static UCS_CLASS_DEFINE_DELETE_FUNC(uct_rc_gdaki_iface_t, uct_iface_t);
568568

569569
static ucs_status_t
570-
uct_gdaki_query_tl_devices(uct_md_h md, uct_tl_device_resource_t **tl_devices_p,
570+
uct_gdaki_md_check_uar(uct_ib_mlx5_md_t *md, CUdevice cuda_dev)
571+
{
572+
struct mlx5dv_devx_uar *uar;
573+
ucs_status_t status;
574+
CUcontext cuda_ctx;
575+
unsigned flags;
576+
577+
status = uct_ib_mlx5_devx_alloc_uar(md, 0, &uar);
578+
if (status != UCS_OK) {
579+
goto out;
580+
}
581+
582+
status = UCT_CUDADRV_FUNC_LOG_ERR(
583+
cuDevicePrimaryCtxRetain(&cuda_ctx, cuda_dev));
584+
if (status != UCS_OK) {
585+
goto out_free_uar;
586+
}
587+
588+
status = UCT_CUDADRV_FUNC_LOG_ERR(cuCtxPushCurrent(cuda_ctx));
589+
if (status != UCS_OK) {
590+
goto out_ctx_release;
591+
}
592+
593+
flags = CU_MEMHOSTREGISTER_PORTABLE | CU_MEMHOSTREGISTER_DEVICEMAP |
594+
CU_MEMHOSTREGISTER_IOMEMORY;
595+
status = UCT_CUDADRV_FUNC_LOG_DEBUG(
596+
cuMemHostRegister(uar->reg_addr, UCT_IB_MLX5_BF_REG_SIZE, flags));
597+
if (status == UCS_OK) {
598+
UCT_CUDADRV_FUNC_LOG_DEBUG(cuMemHostUnregister(uar->reg_addr));
599+
}
600+
601+
UCT_CUDADRV_FUNC_LOG_WARN(cuCtxPopCurrent(NULL));
602+
out_ctx_release:
603+
UCT_CUDADRV_FUNC_LOG_WARN(cuDevicePrimaryCtxRelease(cuda_dev));
604+
out_free_uar:
605+
mlx5dv_devx_free_uar(uar);
606+
out:
607+
return status;
608+
}
609+
610+
static ucs_status_t
611+
uct_gdaki_query_tl_devices(uct_md_h tl_md,
612+
uct_tl_device_resource_t **tl_devices_p,
571613
unsigned *num_tl_devices_p)
572614
{
573-
uct_ib_md_t *ib_md = ucs_derived_of(md, uct_ib_md_t);
574-
unsigned num_tl_devices = 0;
615+
static int uar_supported = -1;
616+
uct_ib_mlx5_md_t *md = ucs_derived_of(tl_md, uct_ib_mlx5_md_t);
617+
unsigned num_tl_devices = 0;
575618
uct_tl_device_resource_t *tl_devices;
576619
ucs_status_t status;
577620
CUdevice device;
578621
ucs_sys_device_t dev;
579622
ucs_sys_dev_distance_t dist;
580-
int num_gpus;
623+
int i, num_gpus;
581624

582625
status = UCT_CUDADRV_FUNC_LOG_ERR(cuDeviceGetCount(&num_gpus));
583626
if (status != UCS_OK) {
@@ -589,14 +632,35 @@ uct_gdaki_query_tl_devices(uct_md_h md, uct_tl_device_resource_t **tl_devices_p,
589632
return UCS_ERR_NO_MEMORY;
590633
}
591634

592-
for (int i = 0; i < num_gpus; i++) {
635+
for (i = 0; i < num_gpus; i++) {
593636
status = UCT_CUDADRV_FUNC_LOG_ERR(cuDeviceGet(&device, i));
594637
if (status != UCS_OK) {
595638
goto err;
596639
}
597640

641+
/*
642+
* Save the result of UAR support in a global flag since to avoid the
643+
* overhead of checking UAR support for each GPU and MD. Assume the
644+
* support is the same for all GPUs and MDs in the system.
645+
*/
646+
if (uar_supported == -1) {
647+
status = uct_gdaki_md_check_uar(md, device);
648+
if (status == UCS_OK) {
649+
uar_supported = 1;
650+
} else {
651+
ucs_diag("GDAKI not supported, please add "
652+
"NVreg_RegistryDwords=\"PeerMappingOverride=1;\" "
653+
"option for nvidia kernel driver");
654+
uar_supported = 0;
655+
}
656+
}
657+
if (uar_supported == 0) {
658+
status = UCS_ERR_NO_DEVICE;
659+
goto err;
660+
}
661+
598662
uct_cuda_base_get_sys_dev(device, &dev);
599-
status = ucs_topo_get_distance(dev, ib_md->dev.sys_dev, &dist);
663+
status = ucs_topo_get_distance(dev, md->super.dev.sys_dev, &dist);
600664
if (status != UCS_OK) {
601665
goto err;
602666
}
@@ -608,8 +672,8 @@ uct_gdaki_query_tl_devices(uct_md_h md, uct_tl_device_resource_t **tl_devices_p,
608672

609673
snprintf(tl_devices[num_tl_devices].name,
610674
sizeof(tl_devices[num_tl_devices].name), "%s%d-%s:%d",
611-
UCT_DEVICE_CUDA_NAME, device, uct_ib_device_name(&ib_md->dev),
612-
ib_md->dev.first_port);
675+
UCT_DEVICE_CUDA_NAME, device,
676+
uct_ib_device_name(&md->super.dev), md->super.dev.first_port);
613677
tl_devices[num_tl_devices].type = UCT_DEVICE_TYPE_NET;
614678
tl_devices[num_tl_devices].sys_device = dev;
615679
num_tl_devices++;

src/uct/ib/mlx5/ib_mlx5.c

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -569,9 +569,8 @@ int uct_ib_mlx5_devx_uar_cmp(uct_ib_mlx5_devx_uar_t *uar,
569569
}
570570

571571
#if HAVE_DEVX
572-
static ucs_status_t uct_ib_mlx5_devx_alloc_uar(uct_ib_mlx5_md_t *md,
573-
uint32_t flags,
574-
struct mlx5dv_devx_uar **uar_p)
572+
ucs_status_t uct_ib_mlx5_devx_alloc_uar(uct_ib_mlx5_md_t *md, uint32_t flags,
573+
struct mlx5dv_devx_uar **uar_p)
575574
{
576575
const char *uar_type_str = (flags == UCT_IB_MLX5_UAR_ALLOC_TYPE_WC) ?
577576
"WC" : "NC_DEDICATED";

src/uct/ib/mlx5/ib_mlx5.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ typedef struct uct_ib_mlx5_md {
441441
uint8_t log_max_dci_stream_channels;
442442
uint32_t smkey_index;
443443
struct {
444-
/* Max dp ordering level per transport,
444+
/* Max dp ordering level per transport,
445445
as listed in uct_ib_mlx5_dp_ordering_t */
446446
uint8_t rc;
447447
uint8_t dc;
@@ -931,10 +931,12 @@ void uct_ib_mlx5_verbs_srq_cleanup(uct_ib_mlx5_srq_t *srq, struct ibv_srq *verbs
931931
/**
932932
* DEVX UAR API
933933
*/
934-
int uct_ib_mlx5_devx_uar_cmp(uct_ib_mlx5_devx_uar_t *uar,
935-
uct_ib_mlx5_md_t *md,
934+
int uct_ib_mlx5_devx_uar_cmp(uct_ib_mlx5_devx_uar_t *uar, uct_ib_mlx5_md_t *md,
936935
uct_ib_mlx5_mmio_mode_t mmio_mode);
937936

937+
ucs_status_t uct_ib_mlx5_devx_alloc_uar(uct_ib_mlx5_md_t *md, uint32_t flags,
938+
struct mlx5dv_devx_uar **uar_p);
939+
938940
ucs_status_t uct_ib_mlx5_devx_check_uar(uct_ib_mlx5_md_t *md);
939941

940942
ucs_status_t uct_ib_mlx5_devx_uar_init(uct_ib_mlx5_devx_uar_t *uar,

test/gtest/ucp/test_ucp_device.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,4 +189,4 @@ UCS_TEST_P(test_ucp_device, create_fail)
189189
ucp_device_mem_list_create(sender().ep(), &params1, &handle));
190190
}
191191

192-
UCP_INSTANTIATE_TEST_CASE_TLS_GPU_AWARE(test_ucp_device, gdaki, "rc,rc_gda")
192+
UCP_INSTANTIATE_TEST_CASE_TLS_GPU_AWARE(test_ucp_device, rc_gda, "rc,rc_gda")

0 commit comments

Comments
 (0)