@@ -104,8 +104,8 @@ static UCS_CLASS_INIT_FUNC(uct_rc_gdaki_ep_t, const uct_ep_params_t *params)
104104 if (self -> umem == NULL ) {
105105 uct_ib_check_memlock_limit_msg (md -> super .dev .ibv_context ,
106106 UCS_LOG_LEVEL_ERROR ,
107- "mlx5dv_devx_umem_reg(size=%zu)" ,
108- dev_ep_size );
107+ "mlx5dv_devx_umem_reg(ptr=%p size=%zu)" ,
108+ self -> ep_gpu , dev_ep_size );
109109 status = UCS_ERR_NO_MEMORY ;
110110 goto err_mem ;
111111 }
@@ -567,17 +567,60 @@ static UCS_CLASS_DEFINE_NEW_FUNC(uct_rc_gdaki_iface_t, uct_iface_t, uct_md_h,
567567static UCS_CLASS_DEFINE_DELETE_FUNC (uct_rc_gdaki_iface_t , uct_iface_t ) ;
568568
569569static ucs_status_t
570- uct_gdaki_query_tl_devices (uct_md_h md , uct_tl_device_resource_t * * tl_devices_p ,
570+ uct_gdaki_md_check_uar (uct_ib_mlx5_md_t * md , CUdevice cuda_dev )
571+ {
572+ struct mlx5dv_devx_uar * uar ;
573+ ucs_status_t status ;
574+ CUcontext cuda_ctx ;
575+ unsigned flags ;
576+
577+ status = uct_ib_mlx5_devx_alloc_uar (md , 0 , & uar );
578+ if (status != UCS_OK ) {
579+ goto out ;
580+ }
581+
582+ status = UCT_CUDADRV_FUNC_LOG_ERR (
583+ cuDevicePrimaryCtxRetain (& cuda_ctx , cuda_dev ));
584+ if (status != UCS_OK ) {
585+ goto out_free_uar ;
586+ }
587+
588+ status = UCT_CUDADRV_FUNC_LOG_ERR (cuCtxPushCurrent (cuda_ctx ));
589+ if (status != UCS_OK ) {
590+ goto out_ctx_release ;
591+ }
592+
593+ flags = CU_MEMHOSTREGISTER_PORTABLE | CU_MEMHOSTREGISTER_DEVICEMAP |
594+ CU_MEMHOSTREGISTER_IOMEMORY ;
595+ status = UCT_CUDADRV_FUNC_LOG_DEBUG (
596+ cuMemHostRegister (uar -> reg_addr , UCT_IB_MLX5_BF_REG_SIZE , flags ));
597+ if (status == UCS_OK ) {
598+ UCT_CUDADRV_FUNC_LOG_DEBUG (cuMemHostUnregister (uar -> reg_addr ));
599+ }
600+
601+ UCT_CUDADRV_FUNC_LOG_WARN (cuCtxPopCurrent (NULL ));
602+ out_ctx_release :
603+ UCT_CUDADRV_FUNC_LOG_WARN (cuDevicePrimaryCtxRelease (cuda_dev ));
604+ out_free_uar :
605+ mlx5dv_devx_free_uar (uar );
606+ out :
607+ return status ;
608+ }
609+
610+ static ucs_status_t
611+ uct_gdaki_query_tl_devices (uct_md_h tl_md ,
612+ uct_tl_device_resource_t * * tl_devices_p ,
571613 unsigned * num_tl_devices_p )
572614{
573- uct_ib_md_t * ib_md = ucs_derived_of (md , uct_ib_md_t );
574- unsigned num_tl_devices = 0 ;
615+ static int uar_supported = -1 ;
616+ uct_ib_mlx5_md_t * md = ucs_derived_of (tl_md , uct_ib_mlx5_md_t );
617+ unsigned num_tl_devices = 0 ;
575618 uct_tl_device_resource_t * tl_devices ;
576619 ucs_status_t status ;
577620 CUdevice device ;
578621 ucs_sys_device_t dev ;
579622 ucs_sys_dev_distance_t dist ;
580- int num_gpus ;
623+ int i , num_gpus ;
581624
582625 status = UCT_CUDADRV_FUNC_LOG_ERR (cuDeviceGetCount (& num_gpus ));
583626 if (status != UCS_OK ) {
@@ -589,14 +632,35 @@ uct_gdaki_query_tl_devices(uct_md_h md, uct_tl_device_resource_t **tl_devices_p,
589632 return UCS_ERR_NO_MEMORY ;
590633 }
591634
592- for (int i = 0 ; i < num_gpus ; i ++ ) {
635+ for (i = 0 ; i < num_gpus ; i ++ ) {
593636 status = UCT_CUDADRV_FUNC_LOG_ERR (cuDeviceGet (& device , i ));
594637 if (status != UCS_OK ) {
595638 goto err ;
596639 }
597640
641+ /*
642+ * Save the result of UAR support in a global flag since to avoid the
643+ * overhead of checking UAR support for each GPU and MD. Assume the
644+ * support is the same for all GPUs and MDs in the system.
645+ */
646+ if (uar_supported == -1 ) {
647+ status = uct_gdaki_md_check_uar (md , device );
648+ if (status == UCS_OK ) {
649+ uar_supported = 1 ;
650+ } else {
651+ ucs_diag ("GDAKI not supported, please add "
652+ "NVreg_RegistryDwords=\"PeerMappingOverride=1;\" "
653+ "option for nvidia kernel driver" );
654+ uar_supported = 0 ;
655+ }
656+ }
657+ if (uar_supported == 0 ) {
658+ status = UCS_ERR_NO_DEVICE ;
659+ goto err ;
660+ }
661+
598662 uct_cuda_base_get_sys_dev (device , & dev );
599- status = ucs_topo_get_distance (dev , ib_md -> dev .sys_dev , & dist );
663+ status = ucs_topo_get_distance (dev , md -> super . dev .sys_dev , & dist );
600664 if (status != UCS_OK ) {
601665 goto err ;
602666 }
@@ -608,8 +672,8 @@ uct_gdaki_query_tl_devices(uct_md_h md, uct_tl_device_resource_t **tl_devices_p,
608672
609673 snprintf (tl_devices [num_tl_devices ].name ,
610674 sizeof (tl_devices [num_tl_devices ].name ), "%s%d-%s:%d" ,
611- UCT_DEVICE_CUDA_NAME , device , uct_ib_device_name ( & ib_md -> dev ),
612- ib_md -> dev .first_port );
675+ UCT_DEVICE_CUDA_NAME , device ,
676+ uct_ib_device_name ( & md -> super . dev ), md -> super . dev .first_port );
613677 tl_devices [num_tl_devices ].type = UCT_DEVICE_TYPE_NET ;
614678 tl_devices [num_tl_devices ].sys_device = dev ;
615679 num_tl_devices ++ ;
0 commit comments