From d05ca89fb09ee5f0ba37477f8099744f13a27218 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Mon, 24 Nov 2025 15:27:23 -0800 Subject: [PATCH] [flang][cuda] Add support for cluster_block_index in cooperative groups --- .../Optimizer/Builder/CUDAIntrinsicCall.h | 1 + .../Optimizer/Builder/CUDAIntrinsicCall.cpp | 61 +++++++++++++------ flang/module/cooperative_groups.f90 | 7 +++ flang/test/Lower/CUDA/cuda-cluster.cuf | 21 +++++++ 4 files changed, 73 insertions(+), 17 deletions(-) diff --git a/flang/include/flang/Optimizer/Builder/CUDAIntrinsicCall.h b/flang/include/flang/Optimizer/Builder/CUDAIntrinsicCall.h index cedc7a9437eb5..977bc0f4ee58c 100644 --- a/flang/include/flang/Optimizer/Builder/CUDAIntrinsicCall.h +++ b/flang/include/flang/Optimizer/Builder/CUDAIntrinsicCall.h @@ -47,6 +47,7 @@ struct CUDAIntrinsicLibrary : IntrinsicLibrary { void genBarrierInit(llvm::ArrayRef); mlir::Value genBarrierTryWait(mlir::Type, llvm::ArrayRef); mlir::Value genBarrierTryWaitSleep(mlir::Type, llvm::ArrayRef); + mlir::Value genClusterBlockIndex(mlir::Type, llvm::ArrayRef); mlir::Value genClusterDimBlocks(mlir::Type, llvm::ArrayRef); void genFenceProxyAsync(llvm::ArrayRef); template diff --git a/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp index a770e2d9cdeff..a0d9678683e44 100644 --- a/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp @@ -368,6 +368,11 @@ static constexpr IntrinsicHandler cudaHandlers[]{ &CI::genNVVMTime), {}, /*isElemental=*/false}, + {"cluster_block_index", + static_cast( + &CI::genClusterBlockIndex), + {}, + /*isElemental=*/false}, {"cluster_dim_blocks", static_cast( &CI::genClusterDimBlocks), @@ -990,6 +995,42 @@ CUDAIntrinsicLibrary::genBarrierTryWaitSleep(mlir::Type resultType, .getResult(0); } +static void insertValueAtPos(fir::FirOpBuilder &builder, mlir::Location loc, + fir::RecordType recTy, mlir::Value base, + mlir::Value dim, unsigned fieldPos) { + auto fieldName = recTy.getTypeList()[fieldPos].first; + mlir::Type fieldTy = recTy.getTypeList()[fieldPos].second; + mlir::Type fieldIndexType = fir::FieldType::get(base.getContext()); + mlir::Value fieldIndex = + fir::FieldIndexOp::create(builder, loc, fieldIndexType, fieldName, recTy, + /*typeParams=*/mlir::ValueRange{}); + mlir::Value coord = fir::CoordinateOp::create( + builder, loc, builder.getRefType(fieldTy), base, fieldIndex); + fir::StoreOp::create(builder, loc, dim, coord); +} + +// CLUSTER_BLOCK_INDEX +mlir::Value +CUDAIntrinsicLibrary::genClusterBlockIndex(mlir::Type resultType, + llvm::ArrayRef args) { + assert(args.size() == 0); + auto recTy = mlir::cast(resultType); + assert(recTy && "RecordType expepected"); + mlir::Value res = fir::AllocaOp::create(builder, loc, resultType); + mlir::Type i32Ty = builder.getI32Type(); + mlir::Value x = mlir::NVVM::BlockInClusterIdXOp::create(builder, loc, i32Ty); + mlir::Value one = builder.createIntegerConstant(loc, i32Ty, 1); + x = mlir::arith::AddIOp::create(builder, loc, x, one); + insertValueAtPos(builder, loc, recTy, res, x, 0); + mlir::Value y = mlir::NVVM::BlockInClusterIdYOp::create(builder, loc, i32Ty); + y = mlir::arith::AddIOp::create(builder, loc, y, one); + insertValueAtPos(builder, loc, recTy, res, y, 1); + mlir::Value z = mlir::NVVM::BlockInClusterIdZOp::create(builder, loc, i32Ty); + z = mlir::arith::AddIOp::create(builder, loc, z, one); + insertValueAtPos(builder, loc, recTy, res, z, 2); + return res; +} + // CLUSTER_DIM_BLOCKS mlir::Value CUDAIntrinsicLibrary::genClusterDimBlocks(mlir::Type resultType, @@ -998,27 +1039,13 @@ CUDAIntrinsicLibrary::genClusterDimBlocks(mlir::Type resultType, auto recTy = mlir::cast(resultType); assert(recTy && "RecordType expepected"); mlir::Value res = fir::AllocaOp::create(builder, loc, resultType); - - auto insertDim = [&](mlir::Value dim, unsigned fieldPos) { - auto fieldName = recTy.getTypeList()[fieldPos].first; - mlir::Type fieldTy = recTy.getTypeList()[fieldPos].second; - mlir::Type fieldIndexType = fir::FieldType::get(resultType.getContext()); - mlir::Value fieldIndex = fir::FieldIndexOp::create( - builder, loc, fieldIndexType, fieldName, recTy, - /*typeParams=*/mlir::ValueRange{}); - mlir::Value coord = fir::CoordinateOp::create( - builder, loc, builder.getRefType(fieldTy), res, fieldIndex); - fir::StoreOp::create(builder, loc, dim, coord); - }; - mlir::Type i32Ty = builder.getI32Type(); mlir::Value x = mlir::NVVM::ClusterDimBlocksXOp::create(builder, loc, i32Ty); - insertDim(x, 0); + insertValueAtPos(builder, loc, recTy, res, x, 0); mlir::Value y = mlir::NVVM::ClusterDimBlocksYOp::create(builder, loc, i32Ty); - insertDim(y, 1); + insertValueAtPos(builder, loc, recTy, res, y, 1); mlir::Value z = mlir::NVVM::ClusterDimBlocksZOp::create(builder, loc, i32Ty); - insertDim(z, 2); - + insertValueAtPos(builder, loc, recTy, res, z, 2); return res; } diff --git a/flang/module/cooperative_groups.f90 b/flang/module/cooperative_groups.f90 index 2631975837a5b..8bb4af3afa791 100644 --- a/flang/module/cooperative_groups.f90 +++ b/flang/module/cooperative_groups.f90 @@ -38,6 +38,13 @@ module cooperative_groups integer(4) :: rank end type thread_group +interface + attributes(device) function cluster_block_index() + import + type(dim3) :: cluster_block_index + end function +end interface + interface attributes(device) function cluster_dim_blocks() import diff --git a/flang/test/Lower/CUDA/cuda-cluster.cuf b/flang/test/Lower/CUDA/cuda-cluster.cuf index 51cc4208a35de..78cca15b11dab 100644 --- a/flang/test/Lower/CUDA/cuda-cluster.cuf +++ b/flang/test/Lower/CUDA/cuda-cluster.cuf @@ -32,3 +32,24 @@ end subroutine ! CHECK: %[[Z:.*]] = nvvm.read.ptx.sreg.cluster.nctaid.z : i32 ! CHECK: %[[COORD_Z:.*]] = fir.coordinate_of %{{.*}}, z : (!fir.ref>) -> !fir.ref ! CHECK: fir.store %[[Z]] to %[[COORD_Z]] : !fir.ref + +attributes(global) subroutine test_cluster_block_index() + use cooperative_groups + type(dim3) :: blockIndex + + blockIndex = cluster_block_index() +end subroutine + +! CHECK-LABEL: func.func @_QPtest_cluster_block_index() attributes {cuf.proc_attr = #cuf.cuda_proc} +! CHECK: %[[X:.*]] = nvvm.read.ptx.sreg.cluster.ctaid.x : i32 +! CHECK: %[[X1:.*]] = arith.addi %[[X]], %c1{{.*}} : i32 +! CHECK: %[[COORD_X:.*]] = fir.coordinate_of %{{.*}}, x : (!fir.ref>) -> !fir.ref +! CHECK: fir.store %[[X1]] to %[[COORD_X]] : !fir.ref +! CHECK: %[[Y:.*]] = nvvm.read.ptx.sreg.cluster.ctaid.y : i32 +! CHECK: %[[Y1:.*]] = arith.addi %[[Y]], %c1{{.*}} : i32 +! CHECK: %[[COORD_Y:.*]] = fir.coordinate_of %{{.*}}, y : (!fir.ref>) -> !fir.ref +! CHECK: fir.store %[[Y1]] to %[[COORD_Y]] : !fir.ref +! CHECK: %[[Z:.*]] = nvvm.read.ptx.sreg.cluster.ctaid.z : i32 +! CHECK: %[[Z1:.*]] = arith.addi %[[Z]], %c1{{.*}} : i32 +! CHECK: %[[COORD_Z:.*]] = fir.coordinate_of %{{.*}}, z : (!fir.ref>) -> !fir.ref +! CHECK: fir.store %[[Z1]] to %[[COORD_Z]] : !fir.ref