diff --git a/clang/lib/DPCT/APINames.inc b/clang/lib/DPCT/APINames.inc index 7c4ddee08a15..b288c57f79bb 100644 --- a/clang/lib/DPCT/APINames.inc +++ b/clang/lib/DPCT/APINames.inc @@ -1708,7 +1708,7 @@ ENTRY(cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags, cuOccupancyMaxActive ENTRY(cuOccupancyMaxPotentialBlockSize, cuOccupancyMaxPotentialBlockSize, false, NO_FLAG, P4, "comment") ENTRY(cuOccupancyMaxPotentialBlockSizeWithFlags, cuOccupancyMaxPotentialBlockSizeWithFlags, false, NO_FLAG, P4, "comment") ENTRY(cuCtxDisablePeerAccess, cuCtxDisablePeerAccess, false, NO_FLAG, P4, "comment") -ENTRY(cuCtxEnablePeerAccess, cuCtxEnablePeerAccess, false, NO_FLAG, P4, "comment") +ENTRY(cuCtxEnablePeerAccess, cuCtxEnablePeerAccess, true, NO_FLAG, P0, "DPCT1026/DPCT1027") ENTRY(cuDeviceCanAccessPeer, cuDeviceCanAccessPeer, true, NO_FLAG, P4, "DPCT1031") ENTRY(cuDeviceGetP2PAttribute, cuDeviceGetP2PAttribute, false, NO_FLAG, P4, "comment") ENTRY(cuDevicePrimaryCtxGetState, cuDevicePrimaryCtxGetState, false, NO_FLAG, P4, "comment") diff --git a/clang/lib/DPCT/APINamesMemory.inc b/clang/lib/DPCT/APINamesMemory.inc index e5b111010e5c..f24c1cc0a0bb 100644 --- a/clang/lib/DPCT/APINamesMemory.inc +++ b/clang/lib/DPCT/APINamesMemory.inc @@ -673,6 +673,20 @@ CONDITIONAL_FACTORY_ENTRY( "You can migrate the code with peer access extension by not specifying " "-no-dpcpp-extensions=peer_access.")) +CONDITIONAL_FACTORY_ENTRY( + UsePeerAccess(), + ASSIGNABLE_FACTORY(MEMBER_CALL_FACTORY_ENTRY( + "cuCtxEnablePeerAccess", + CALL(MapNames::getDpctNamespace() + "get_current_device"), false, + "ext_oneapi_enable_peer_access", + MEMBER_CALL(CALL(MapNames::getDpctNamespace() + "dev_mgr::instance"), false, + "get_device", ARG_WC(0)))), + REMOVE_API_FACTORY_ENTRY_WITH_MSG( + "cuCtxEnablePeerAccess", + "SYCL currently does not support memory access across peer devices. " + "You can migrate the code with peer access extension by not specifying " + "-no-dpcpp-extensions=peer_access.")) + CALL_FACTORY_ENTRY( "make_cudaExtent", CALL(DpctGlobalInfo::getCtadClass(MapNames::getClNamespace() + "range", 3), diff --git a/clang/lib/DPCT/ASTTraversal.cpp b/clang/lib/DPCT/ASTTraversal.cpp index 1e1915b7915a..41c8cc4bdd4d 100644 --- a/clang/lib/DPCT/ASTTraversal.cpp +++ b/clang/lib/DPCT/ASTTraversal.cpp @@ -6243,7 +6243,7 @@ void FunctionCallRule::registerMatcher(MatchFinder &MF) { "cudaPointerGetAttributes", "cuCtxSetCacheConfig", "cuCtxSetLimit", "cudaCtxResetPersistingL2Cache", "cuCtxResetPersistingL2Cache", "cudaStreamSetAttribute", "cudaStreamGetAttribute", "cudaProfilerStart", - "cudaProfilerStop", "__trap"); + "cudaProfilerStop", "__trap", "cuCtxEnablePeerAccess"); }; MF.addMatcher( diff --git a/clang/test/dpct/disable-all-extensions.cu b/clang/test/dpct/disable-all-extensions.cu index 3caa9adcd4c0..987c78b3c030 100644 --- a/clang/test/dpct/disable-all-extensions.cu +++ b/clang/test/dpct/disable-all-extensions.cu @@ -175,6 +175,10 @@ int peer_access() { // CHECK: DPCT1026:{{[0-9]+}}: The call to cudaDeviceDisablePeerAccess was removed because SYCL currently does not support memory access across peer devices. You can migrate the code with peer access extension by not specifying -no-dpcpp-extensions=peer_access. // CHECK: */ cudaDeviceDisablePeerAccess(0); + // CHECK: /* + // CHECK: DPCT1026:{{[0-9]+}}: The call to cuCtxEnablePeerAccess was removed because SYCL currently does not support memory access across peer devices. You can migrate the code with peer access extension by not specifying -no-dpcpp-extensions=peer_access. + // CHECK: */ + cuCtxEnablePeerAccess(0, 0); return 0; } diff --git a/clang/test/dpct/driver_context.cu b/clang/test/dpct/driver_context.cu index 4842192aef1b..4b1d679402d0 100644 --- a/clang/test/dpct/driver_context.cu +++ b/clang/test/dpct/driver_context.cu @@ -52,6 +52,9 @@ int main(){ // CHECK: MY_SAFE_CALL(DPCT_CHECK_ERROR(dpct::select_device(ctx))); MY_SAFE_CALL(cuCtxSetCurrent(ctx)); + // CHECK: dpct::get_current_device().ext_oneapi_enable_peer_access(dpct::dev_mgr::instance().get_device(ctx2)); + cuCtxEnablePeerAccess(ctx2, 0); + // CHECK: ctx2 = dpct::dev_mgr::instance().current_device_id(); cuCtxGetCurrent(&ctx2);