diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 69fefcf972065..434322ea22265 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -477,7 +477,7 @@ end subroutine
 ! CHECK: %[[DST_7:.*]] = llvm.addrspacecast %[[DST_PTR]] : !llvm.ptr to !llvm.ptr<7>
 ! CHECK: %[[SRC_PTR:.*]] = fir.convert %[[SRC]] : (!fir.ref) -> !llvm.ptr
 ! CHECK: %[[SRC_3:.*]] = llvm.addrspacecast %[[SRC_PTR]] : !llvm.ptr to !llvm.ptr<1>
-! CHECK: nvvm.cp.async.bulk.shared.cluster.global %[[DST_7]], %[[SRC_3]], %[[BARRIER_3]], %[[COUNT_LOAD]] : <7>, <1>
+! CHECK: nvvm.cp.async.bulk.shared.cluster.global %[[DST_7]], %[[SRC_3]], %[[BARRIER_3]], %[[COUNT_LOAD]] : !llvm.ptr<7>, <1>

 attributes(global) subroutine test_bulk_s2g(a)
   real(8), device :: a(*)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 995ade5c9b033..d4ef5104d3c1f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3342,16 +3342,17 @@ def NVVM_CpAsyncBulkTensorReduceOp :

 def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
   NVVM_Op<"cp.async.bulk.shared.cluster.global", [AttrSizedOperandSegments]> {
-  let summary = "Async bulk copy from global memory to Shared cluster memory";
+  let summary = "Async bulk copy from global to Shared {cta or cluster} memory";
   let description = [{
-    Initiates an asynchronous copy operation from global memory to cluster's
-    shared memory.
+    Initiates an asynchronous copy operation from global memory to shared
+    memory or shared_cluster memory.

-    The `multicastMask` operand is optional. When it is present, the Op copies
+    The `multicastMask` operand is optional and can be used only when the
+    destination is shared::cluster memory. When it is present, this Op copies
     data from global memory to shared memory of multiple CTAs in the cluster.
     Operand `multicastMask` specifies the destination CTAs in the cluster such
     that each bit position in the 16-bit `multicastMask` operand corresponds to
-    the `nvvm.read.ptx.sreg.ctaid` of the destination CTA. 
+    the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.

     The `l2CacheHint` operand is optional, and it is used to specify cache
     eviction policy that may be used during the memory access.
@@ -3360,7 +3361,7 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
   }];

   let arguments = (ins
-    LLVM_PointerSharedCluster:$dstMem,
+    AnyTypeOf<[LLVM_PointerShared, LLVM_PointerSharedCluster]>:$dstMem,
     LLVM_PointerGlobal:$srcMem,
     LLVM_PointerShared:$mbar,
     I32:$size,
@@ -3374,6 +3375,8 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp :
     attr-dict `:` type($dstMem) `,` type($srcMem)
   }];

+  let hasVerifier = 1;
+
   let extraClassDeclaration = [{
     static mlir::NVVM::IDArgPair
     getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 0f7b3638fb30d..7ac427dbe3941 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -212,6 +212,14 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() {
   return success();
 }

+LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
+  bool isSharedCTA = isPtrInSharedCTASpace(getDstMem());
+  if (isSharedCTA && getMulticastMask())
+    return emitError("Multicast is not supported with shared::cta mode.");
+
+  return success();
+}
+
 LogicalResult ConvertFloatToTF32Op::verify() {
   using RndMode = NVVM::FPRoundingMode;
   switch (getRnd()) {
@@ -1980,11 +1988,15 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
   args.push_back(mt.lookupValue(thisOp.getSrcMem()));
   args.push_back(mt.lookupValue(thisOp.getSize()));

-  // Multicast mask, if available.
+  // Multicast mask for shared::cluster only, if available.
   mlir::Value multicastMask = thisOp.getMulticastMask();
   const bool hasMulticastMask = static_cast<bool>(multicastMask);
-  llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
-  args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask) : i16Unused);
+  const bool isSharedCTA = isPtrInSharedCTASpace(thisOp.getDstMem());
+  if (!isSharedCTA) {
+    llvm::Value *i16Unused = llvm::ConstantInt::get(builder.getInt16Ty(), 0);
+    args.push_back(hasMulticastMask ? mt.lookupValue(multicastMask)
+                                    : i16Unused);
+  }

   // Cache hint, if available.
   mlir::Value cacheHint = thisOp.getL2CacheHint();
@@ -1993,11 +2005,14 @@ mlir::NVVM::IDArgPair CpAsyncBulkGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
   args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);

   // Flag arguments for multicast and cachehint.
-  args.push_back(builder.getInt1(hasMulticastMask));
+  if (!isSharedCTA)
+    args.push_back(builder.getInt1(hasMulticastMask));
   args.push_back(builder.getInt1(hasCacheHint));

   llvm::Intrinsic::ID id =
-      llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;
+      isSharedCTA
+          ? llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cta
+          : llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster;

   return {id, std::move(args)};
 }
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
index 0daf24536a672..240fab5b63908 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy.mlir
@@ -16,6 +16,17 @@ llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cluster(%dst : !llvm.ptr<7>,
   llvm.return
 }

+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_global_to_shared_cta
+llvm.func @llvm_nvvm_cp_async_bulk_global_to_shared_cta(%dst : !llvm.ptr<3>, %src : !llvm.ptr<1>, %mbar : !llvm.ptr<3>, %size : i32, %ch : i64) {
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cta(ptr addrspace(3) %[[DST:.*]], ptr addrspace(3) %[[MBAR:.*]], ptr addrspace(1) %[[SRC:.*]], i32 %[[SIZE:.*]], i64 0, i1 false)
+  // CHECK: call void @llvm.nvvm.cp.async.bulk.global.to.shared.cta(ptr addrspace(3) %[[DST]], ptr addrspace(3) %[[MBAR]], ptr addrspace(1) %[[SRC]], i32 %[[SIZE]], i64 %[[CH:.*]], i1 true)
+  nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size : !llvm.ptr<3>, !llvm.ptr<1>
+
+  nvvm.cp.async.bulk.shared.cluster.global %dst, %src, %mbar, %size l2_cache_hint = %ch : !llvm.ptr<3>, !llvm.ptr<1>
+
+  llvm.return
+}
+
 // CHECK-LABEL: @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster
 llvm.func @llvm_nvvm_cp_async_bulk_shared_cta_to_shared_cluster(%dst : !llvm.ptr<7>, %src : !llvm.ptr<3>, %mbar : !llvm.ptr<3>, %size : i32) {
   // CHECK: call void @llvm.nvvm.cp.async.bulk.shared.cta.to.cluster(ptr addrspace(7) %0, ptr addrspace(3) %2, ptr addrspace(3) %1, i32 %3)
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir
new file mode 100644
index 0000000000000..d762ff3ff1e76
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_bulk_copy_invalid.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+llvm.func @tma_bulk_copy_g2s_mc(%src : !llvm.ptr<1>, %dest : !llvm.ptr<3>, %bar : !llvm.ptr<3>, %size : i32, %ctamask : i16) {
+  // expected-error @below {{Multicast is not supported with shared::cta mode.}}
+  nvvm.cp.async.bulk.shared.cluster.global %dest, %src, %bar, %size multicast_mask = %ctamask : !llvm.ptr<3>, !llvm.ptr<1>
+
+  llvm.return
+}