diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index e0fef69f4f944..9b2a8985a1a44 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -723,7 +723,8 @@ def NVVM_MBarrierCompleteTxOp : NVVM_VoidIntrinsicOp<"mbarrier.complete_tx"> { let hasVerifier = 1; } -def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive"> { +def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive", + [InferTypeOpAdaptorWithIsCompatible]> { let summary = "MBarrier Arrive Operation"; let description = [{ The `nvvm.mbarrier.arrive` operation performs an arrive-on operation on the @@ -771,7 +772,8 @@ def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive"> { let hasVerifier = 1; } -def NVVM_MBarrierArriveDropOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop"> { +def NVVM_MBarrierArriveDropOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop", + [InferTypeOpAdaptorWithIsCompatible]> { let summary = "MBarrier Arrive-Drop Operation"; let description = [{ The `nvvm.mbarrier.arrive_drop` operation decrements the expected arrival @@ -847,7 +849,8 @@ def NVVM_MBarrierArriveDropNocompleteOp : NVVM_SingleResultIntrinsicOp<"mbarrier let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)"; } -def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx"> { +def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx", + [InferTypeOpAdaptorWithIsCompatible]> { let summary = "MBarrier Arrive with Expected Transaction Count"; let description = [{ The `nvvm.mbarrier.arrive.expect_tx` operation performs an expect-tx operation @@ -913,7 +916,8 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_t }]; } -def NVVM_MBarrierArriveDropExpectTxOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop.expect_tx"> { +def NVVM_MBarrierArriveDropExpectTxOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop.expect_tx", + [InferTypeOpAdaptorWithIsCompatible]> { let summary = "MBarrier arrive_drop with expected transaction count"; let description = [{ The `nvvm.mbarrier.arrive_drop.expect_tx` operation is similar to the @@ -1126,7 +1130,8 @@ def BarrierReductionAttr let assemblyFormat = "`<` $value `>`"; } -def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier", [AttrSizedOperandSegments]> { +def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier", + [AttrSizedOperandSegments, InferTypeOpAdaptorWithIsCompatible]> { let summary = "CTA Barrier Synchronization Op"; let description = [{ The `nvvm.barrier` operation performs barrier synchronization and communication diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 303dc82a67374..c4175016ab30c 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -852,9 +852,7 @@ struct NVGPUMBarrierArriveLowering Value barrier = getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), adaptor.getMbarId(), rewriter); - Type tokenType = getTypeConverter()->convertType( - nvgpu::MBarrierTokenType::get(op->getContext())); - rewriter.replaceOpWithNewOp(op, tokenType, barrier); + rewriter.replaceOpWithNewOp(op, barrier); return success(); } }; @@ -911,12 +909,12 @@ struct NVGPUMBarrierArriveExpectTxLowering getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(), adaptor.getMbarId(), rewriter); Value txcount = truncToI32(b, adaptor.getTxcount()); - rewriter.replaceOpWithNewOp( - op, Type{}, // return-value is optional and is void by default - barrier, txcount, // barrier and txcount - NVVM::MemScopeKind::CTA, // default scope is CTA - false, // relaxed-semantics is false + NVVM::MBarrierArriveExpectTxOp::create( + rewriter, op->getLoc(), barrier, txcount, // barrier and txcount + NVVM::MemScopeKind::CTA, // default scope is CTA + false, // relaxed-semantics is false adaptor.getPredicate()); + rewriter.eraseOp(op); return success(); } }; diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 13f05b8f40ed8..31e7ff209db5c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -302,6 +302,82 @@ LogicalResult MBarrierArriveDropExpectTxOp::verify() { getRes()); } +//===----------------------------------------------------------------------===// +// inferReturnTypes for mbarrier arrive-like ops +//===----------------------------------------------------------------------===// + +/// Only shared_cluster (ptr<7>) produces zero results; all other address +/// spaces (including generic) return i64. +static LogicalResult +inferMBarrierArriveResultTypes(MLIRContext *context, Value addr, + SmallVectorImpl &inferredReturnTypes) { + if (!isPtrInSharedClusterSpace(addr)) + inferredReturnTypes.push_back(IntegerType::get(context, 64)); + return success(); +} + +LogicalResult +MBarrierArriveOp::inferReturnTypes(MLIRContext *context, + std::optional location, + MBarrierArriveOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + return inferMBarrierArriveResultTypes(context, adaptor.getAddr(), + inferredReturnTypes); +} + +LogicalResult MBarrierArriveDropOp::inferReturnTypes( + MLIRContext *context, std::optional location, + MBarrierArriveDropOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + return inferMBarrierArriveResultTypes(context, adaptor.getAddr(), + inferredReturnTypes); +} + +LogicalResult MBarrierArriveExpectTxOp::inferReturnTypes( + MLIRContext *context, std::optional location, + MBarrierArriveExpectTxOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + // Predicate forces no return value (inline PTX path). + // Note: predicate + shared_cluster is rejected by the verifier separately. + if (adaptor.getPredicate()) + return success(); + return inferMBarrierArriveResultTypes(context, adaptor.getAddr(), + inferredReturnTypes); +} + +LogicalResult MBarrierArriveDropExpectTxOp::inferReturnTypes( + MLIRContext *context, std::optional location, + MBarrierArriveDropExpectTxOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + return inferMBarrierArriveResultTypes(context, adaptor.getAddr(), + inferredReturnTypes); +} + +/// For ops with optional results, allow the user to omit the result even when +/// inference would produce one. This preserves backward compatibility: the +/// result can be silently discarded (e.g., for fire-and-forget arrive ops). +static bool isCompatibleReturnTypesOptionalResult(TypeRange inferred, + TypeRange actual) { + if (actual.empty()) + return true; + return inferred == actual; +} + +bool MBarrierArriveOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { + return isCompatibleReturnTypesOptionalResult(l, r); +} +bool MBarrierArriveDropOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { + return isCompatibleReturnTypesOptionalResult(l, r); +} +bool MBarrierArriveExpectTxOp::isCompatibleReturnTypes(TypeRange l, + TypeRange r) { + return isCompatibleReturnTypesOptionalResult(l, r); +} +bool MBarrierArriveDropExpectTxOp::isCompatibleReturnTypes(TypeRange l, + TypeRange r) { + return isCompatibleReturnTypesOptionalResult(l, r); +} + LogicalResult MBarrierExpectTxOp::verify() { return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope()); } @@ -2855,6 +2931,18 @@ LogicalResult NVVM::BarrierOp::verify() { return success(); } +LogicalResult BarrierOp::inferReturnTypes( + MLIRContext *context, std::optional location, + BarrierOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { + if (adaptor.getReductionOp()) + inferredReturnTypes.push_back(IntegerType::get(context, 32)); + return success(); +} + +bool BarrierOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { + return isCompatibleReturnTypesOptionalResult(l, r); +} + LogicalResult NVVM::Tcgen05CpOp::verify() { auto mc = getMulticast(); diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index b11ba944fe4ac..c039edc6b5de5 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -457,6 +457,13 @@ llvm.func private @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) { llvm.return } +// CHECK-LABEL: @mbarrier_arrive_expect_tx_predicate +llvm.func private @mbarrier_arrive_expect_tx_predicate(%barrier: !llvm.ptr<3>, %txcount: i32, %pred: i1) { + // CHECK: nvvm.mbarrier.arrive.expect_tx %{{.*}}, %{{.*}}, predicate = %{{.*}} : !llvm.ptr<3>, i32, i1{{$}} + nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1 + llvm.return +} + // CHECK-LABEL: @wgmma_fence_aligned func.func @wgmma_fence_aligned() { // CHECK: nvvm.wgmma.fence.aligned diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir index f406922ea3873..96c910b193f12 100644 --- a/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir @@ -108,8 +108,10 @@ llvm.func @mbarrier_arrive_ignore_retval(%count : i32, %barrier: !llvm.ptr<3>) { // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %1, i32 %0) // CHECK-NEXT: ret void // CHECK-NEXT: } + // Result silently discarded (backward compatible form) nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3> - nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3> + // Result explicitly captured + %0 = nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3> -> i64 llvm.return } diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py index b969faa088a46..24abf617548b8 100644 --- a/mlir/test/python/dialects/nvvm.py +++ b/mlir/test/python/dialects/nvvm.py @@ -93,6 +93,36 @@ def my_inline_ptx(a, b, c, d): arith.addf(wo0, wo1) +@constructAndPrintInModule +def test_mbarrier_arrive(): + ptr_shared = llvm.PointerType.get(3) + ptr_cluster = llvm.PointerType.get(7) + i32 = T.i32() + + @func.FuncOp.from_py_func(ptr_shared, ptr_cluster, i32) + def mbarrier_arrive_ops(barrier_shared, barrier_cluster, txcount): + token = nvvm.mbarrier_arrive(barrier_shared) + nvvm.mbarrier_arrive(barrier_cluster) + token2 = nvvm.mbarrier_arrive_drop(barrier_shared) + nvvm.mbarrier_arrive_drop(barrier_cluster) + token3 = nvvm.mbarrier_arrive_expect_tx(barrier_shared, txcount) + nvvm.mbarrier_arrive_expect_tx(barrier_cluster, txcount) + token4 = nvvm.mbarrier_arrive_drop_expect_tx(barrier_shared, txcount) + nvvm.mbarrier_arrive_drop_expect_tx(barrier_cluster, txcount) + + +# CHECK-LABEL: func.func @mbarrier_arrive_ops( +# CHECK-SAME: %[[SHARED:.*]]: !llvm.ptr<3>, %[[CLUSTER:.*]]: !llvm.ptr<7>, %[[TXCOUNT:.*]]: i32) +# CHECK: %{{.*}} = nvvm.mbarrier.arrive %[[SHARED]] : !llvm.ptr<3> -> i64 +# CHECK-NEXT: nvvm.mbarrier.arrive %[[CLUSTER]] : !llvm.ptr<7>{{$}} +# CHECK-NEXT: %{{.*}} = nvvm.mbarrier.arrive_drop %[[SHARED]] : !llvm.ptr<3> -> i64 +# CHECK-NEXT: nvvm.mbarrier.arrive_drop %[[CLUSTER]] : !llvm.ptr<7>{{$}} +# CHECK-NEXT: %{{.*}} = nvvm.mbarrier.arrive.expect_tx %[[SHARED]], %[[TXCOUNT]] : !llvm.ptr<3>, i32 -> i64 +# CHECK-NEXT: nvvm.mbarrier.arrive.expect_tx %[[CLUSTER]], %[[TXCOUNT]] : !llvm.ptr<7>, i32{{$}} +# CHECK-NEXT: %{{.*}} = nvvm.mbarrier.arrive_drop.expect_tx %[[SHARED]], %[[TXCOUNT]] : !llvm.ptr<3>, i32 -> i64 +# CHECK-NEXT: nvvm.mbarrier.arrive_drop.expect_tx %[[CLUSTER]], %[[TXCOUNT]] : !llvm.ptr<7>, i32{{$}} + + @constructAndPrintInModule def test_barriers(): i32 = T.i32() @@ -102,21 +132,20 @@ def test_barriers(): def barriers(mask, vi32, vf32): c0 = arith.constant(T.i32(), 0) cffff = arith.constant(T.i32(), 0xFFFF) - res = nvvm.barrier( - res=i32, + nvvm.barrier( barrier_id=c0, number_of_threads=cffff, ) + pred = arith.constant(T.i32(), 1) for reduction in ( nvvm.BarrierReduction.AND, nvvm.BarrierReduction.OR, nvvm.BarrierReduction.POPC, ): - res = nvvm.barrier( - res=i32, + pred = nvvm.barrier( reduction_op=reduction, - reduction_predicate=res, + reduction_predicate=pred, ) nvvm.barrier0() @@ -129,15 +158,16 @@ def barriers(mask, vi32, vf32): nvvm.cluster_wait(aligned=True) nvvm.fence_mbarrier_init() nvvm.bar_warp_sync(mask) - return res + return pred # CHECK-LABEL: func.func @barriers( # CHECK: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: f32) -> i32 { # CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : i32 # CHECK: %[[CONSTANT_1:.*]] = arith.constant 65535 : i32 -# CHECK: %[[BARRIER_0:.*]] = nvvm.barrier id = %[[CONSTANT_0]] number_of_threads = %[[CONSTANT_1]] -> i32 -# CHECK: %[[BARRIER_1:.*]] = nvvm.barrier #nvvm.reduction %[[BARRIER_0]] -> i32 +# CHECK: nvvm.barrier id = %[[CONSTANT_0]] number_of_threads = %[[CONSTANT_1]] +# CHECK: %[[PRED:.*]] = arith.constant 1 : i32 +# CHECK: %[[BARRIER_1:.*]] = nvvm.barrier #nvvm.reduction %[[PRED]] -> i32 # CHECK: %[[BARRIER_2:.*]] = nvvm.barrier #nvvm.reduction %[[BARRIER_1]] -> i32 # CHECK: %[[BARRIER_3:.*]] = nvvm.barrier #nvvm.reduction %[[BARRIER_2]] -> i32 # CHECK: nvvm.barrier0 @@ -151,7 +181,6 @@ def barriers(mask, vi32, vf32): # CHECK: nvvm.fence.mbarrier.init # CHECK: nvvm.bar.warp.sync %[[ARG0]] : i32 # CHECK: return %[[BARRIER_3]] : i32 -# CHECK: } @constructAndPrintInModule