diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 6137bb087c576..2e8c0d15f098e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -55,10 +55,21 @@ enum NVVMMemorySpace {
   kSharedClusterMemorySpace = 7,
 };
 
-/// A pair type of LLVM's Intrinsic ID and args (which are llvm values).
-/// This type is returned by the getIntrinsicIDAndArgs() methods.
-using IDArgPair =
-    std::pair<llvm::Intrinsic::ID, llvm::SmallVector<llvm::Value *>>;
+/// A struct of LLVM's Intrinsic ID, args (which are llvm values),
+/// and arg types (which are llvm types).
+/// Arg types are only required for overloaded intrinsics, to provide the
+/// correct argument types to the createIntrinsicCall() method.
+/// This type is returned by the getIIDAndArgsWithTypes() methods.
+struct IIDArgsWithTypes {
+  IIDArgsWithTypes(llvm::Intrinsic::ID id,
+                   llvm::SmallVector<llvm::Value *> args,
+                   llvm::SmallVector<llvm::Type *> types)
+      : id(id), args(args), types(types) {}
+
+  llvm::Intrinsic::ID id;
+  llvm::SmallVector<llvm::Value *> args;
+  llvm::SmallVector<llvm::Type *> types;
+};
 
 /// Return the element type and number of elements associated with a wmma matrix
 /// of given chracteristics. This matches the logic in IntrinsicsNVVM.td
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 9d93b4efe7a5b..168060aae2c3e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -410,9 +410,16 @@ def NVVM_ReduxOp :
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-redux-sync)
   }];
 
+  let extraClassDeclaration = [{
+    static NVVM::IIDArgsWithTypes
+    getIIDAndArgsWithTypes(Operation &op,
+                           LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
   string llvmBuilder = [{
-    auto intId = getReduxIntrinsicId($_resultType, $kind, $abs, $nan);
-    $res = createIntrinsicCall(builder, intId, {$val, $mask_and_clamp});
+    auto [id, args, types] =
+        NVVM::ReduxOp::getIIDAndArgsWithTypes(
+            *op, moduleTranslation, builder);
+    $res = createIntrinsicCall(builder, id, args);
   }];
   let assemblyFormat = [{
     $kind $val `,` $mask_and_clamp attr-dict `:` type($val) `->` type($res)
@@ -876,11 +883,17 @@ def NVVM_FenceProxyAcquireOp : NVVM_Op<"fence.proxy.acquire">,
   }];
 
   let assemblyFormat = "$scope $addr `,` $size (`from_proxy` `=` $fromProxy^)? (`to_proxy` `=` $toProxy^)? attr-dict";
+
+  let extraClassDeclaration = [{
+    static NVVM::IIDArgsWithTypes
+    getIIDAndArgsWithTypes(Operation &op,
+                           LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
   let llvmBuilder = [{
-    createIntrinsicCall(
-        builder,
-        getUnidirectionalFenceProxyID($fromProxy, $toProxy, $scope, false),
-        {$addr, $size});
+    auto [intId, args, types] =
+        NVVM::FenceProxyAcquireOp::getIIDAndArgsWithTypes(
+            *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, intId, args);
   }];
 
   let hasVerifier = 1;
@@ -904,9 +917,16 @@ def NVVM_FenceProxyReleaseOp : NVVM_Op<"fence.proxy.release">,
   }];
 
   let assemblyFormat = "$scope (`from_proxy` `=` $fromProxy^)? (`to_proxy` `=` $toProxy^)? attr-dict";
+
+  let extraClassDeclaration = [{
+    static NVVM::IIDArgsWithTypes
+    getIIDAndArgsWithTypes(Operation &op,
+                           LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
   let llvmBuilder = [{
-    createIntrinsicCall(builder, getUnidirectionalFenceProxyID(
-                                     $fromProxy, $toProxy, $scope, true));
+    auto [intId, args, types] = NVVM::FenceProxyReleaseOp::getIIDAndArgsWithTypes(
+        *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, intId, args);
   }];
 
   let hasVerifier = 1;
@@ -985,11 +1005,15 @@ def NVVM_ShflOp :
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync)
   }];
 
+  let extraClassDeclaration = [{
+    static NVVM::IIDArgsWithTypes
+    getIIDAndArgsWithTypes(Operation &op,
+                           LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
   string llvmBuilder = [{
-    auto intId = getShflIntrinsicId(
-        $_resultType, $kind, static_cast<bool>($return_value_and_is_valid));
-    $res = createIntrinsicCall(builder,
-        intId, {$thread_mask, $val, $offset, $mask_and_clamp});
+    auto [intId, args, types] = NVVM::ShflOp::getIIDAndArgsWithTypes(
+        *op, moduleTranslation, builder);
+    $res = createIntrinsicCall(builder, intId, args);
   }];
   let assemblyFormat = [{
     $kind $thread_mask `,` $val `,` $offset `,` $mask_and_clamp attr-dict
@@ -1035,9 +1059,16 @@ def NVVM_VoteSyncOp
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-vote-sync)
   }];
 
+  let extraClassDeclaration = [{
+    static NVVM::IIDArgsWithTypes
+    getIIDAndArgsWithTypes(Operation &op,
+                           LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
   string llvmBuilder = [{
-    auto intId = getVoteSyncIntrinsicId($kind);
-    $res = createIntrinsicCall(builder, intId, {$mask, $pred});
+    auto [intId, args, types] =
+        NVVM::VoteSyncOp::getIIDAndArgsWithTypes(
+            *op, moduleTranslation, builder);
+    $res = createIntrinsicCall(builder, intId, args);
   }];
   let assemblyFormat = "$kind $mask `,` $pred attr-dict `->` type($res)";
   let hasVerifier = 1;
@@ -1108,15 +1139,14 @@ def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
   let assemblyFormat = "$dst `,` $src `,` $size `,` `cache` `=` $modifier (`,` $cpSize^)? attr-dict `:` type(operands)";
   let hasVerifier = 1;
   let extraClassDeclaration = [{
-    static llvm::Intrinsic::ID
-    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
-                          llvm::SmallVector<llvm::Value *> &args);
+    static NVVM::IIDArgsWithTypes
+    getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
+                           llvm::IRBuilderBase &builder);
   }];
   string llvmBuilder = [{
-    llvm::SmallVector<llvm::Value *> translatedOperands;
-    auto id = NVVM::CpAsyncOp::getIntrinsicIDAndArgs(
-        *op, moduleTranslation, translatedOperands);
-    createIntrinsicCall(builder, id, translatedOperands);
+    auto [id, args, types] = NVVM::CpAsyncOp::getIIDAndArgsWithTypes(
+        *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, id, args);
   }];
 }
@@ -2107,10 +2137,16 @@ def NVVM_StMatrixOp: NVVM_Op<"stmatrix">,
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix)
   }];
 
+  let extraClassDeclaration = [{
+    static NVVM::IIDArgsWithTypes
+    getIIDAndArgsWithTypes(Operation &op,
+                           LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
   string llvmBuilder = [{
-    auto operands = moduleTranslation.lookupValues(opInst.getOperands());
-    auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape, $eltType);
-    createIntrinsicCall(builder, intId, operands, operands[0]->getType());
+    auto [intId, args, types] =
+        NVVM::StMatrixOp::getIIDAndArgsWithTypes(
+            *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, intId, args, types);
   }];
   let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
   let hasVerifier = 1;
@@ -2125,10 +2161,16 @@ def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
   let summary = "cooperative matrix load";
 
+  let extraClassDeclaration = [{
+    static NVVM::IIDArgsWithTypes
+    getIIDAndArgsWithTypes(Operation &op,
+                           LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
   string llvmBuilder = [{
-    auto operands = moduleTranslation.lookupValues(opInst.getOperands());
-    auto intId = getLdMatrixIntrinsicId($layout, $num, $shape, $eltType);
-    $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
+    auto [intId, args, types] =
+        NVVM::LdMatrixOp::getIIDAndArgsWithTypes(
+            *op, moduleTranslation, builder);
+    $res = createIntrinsicCall(builder, intId, args, types);
   }];
 
   string baseDescription = [{
@@ -2543,8 +2585,8 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
   let extraClassDeclaration = [{
     bool hasIntrinsic() { return !getPredicate(); }
 
-    static mlir::NVVM::IDArgPair
-    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+    static mlir::NVVM::IIDArgsWithTypes
+    getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
                           llvm::IRBuilderBase& builder);
   }];
 
@@ -2565,7 +2607,7 @@ def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
   let hasVerifier = 1;
 
   string llvmBuilder = [{
-    auto [id, args] = NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
+    auto [id, args, types] = NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp::getIIDAndArgsWithTypes(
         *op, moduleTranslation, builder);
     createIntrinsicCall(builder, id, args);
   }];
@@ -2631,8 +2673,8 @@ def NVVM_PrefetchOp : NVVM_Op<"prefetch",
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
-    static NVVM::IDArgPair
-    getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,LLVM::ModuleTranslation &mt,
+    static NVVM::IIDArgsWithTypes
+    getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt,
                           llvm::IRBuilderBase &builder);
     bool hasIntrinsic() { return !getPredicate() || !getTensormap(); }
   }];
@@ -2643,7 +2685,7 @@
def NVVM_PrefetchOp : NVVM_Op<"prefetch", } }]; let llvmBuilder = [{ - auto [id, args] = NVVM::PrefetchOp::getIntrinsicIDAndArgs(op, + auto [id, args, types] = NVVM::PrefetchOp::getIIDAndArgsWithTypes(*op, moduleTranslation, builder); if(op.getTensormap()) @@ -2685,13 +2727,13 @@ def NVVM_CpAsyncBulkPrefetchOp : NVVM_Op<"cp.async.bulk.prefetch"> { }]; let extraClassDeclaration = [{ - static mlir::NVVM::IDArgPair - getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + static mlir::NVVM::IIDArgsWithTypes + getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase& builder); }]; string llvmBuilder = [{ - auto [id, args] = NVVM::CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs( + auto [id, args, types] = NVVM::CpAsyncBulkPrefetchOp::getIIDAndArgsWithTypes( *op, moduleTranslation, builder); createIntrinsicCall(builder, id, args); }]; @@ -2726,15 +2768,15 @@ def NVVM_CpAsyncBulkTensorPrefetchOp : }]; let extraClassDeclaration = [{ - static mlir::NVVM::IDArgPair - getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + static mlir::NVVM::IIDArgsWithTypes + getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase& builder); }]; let hasVerifier = 1; string llvmBuilder = [{ - auto [id, args] = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs( + auto [id, args, types] = NVVM::CpAsyncBulkTensorPrefetchOp::getIIDAndArgsWithTypes( *op, moduleTranslation, builder); createIntrinsicCall(builder, id, args); }]; @@ -2795,35 +2837,17 @@ def NVVM_CpAsyncBulkTensorReduceOp : }]; let extraClassDeclaration = [{ - static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, - NVVM::TMAReduxKind kind, - bool isIm2Col); + static mlir::NVVM::IIDArgsWithTypes + getIIDAndArgsWithTypes(Operation &op, + LLVM::ModuleTranslation &mt, llvm::IRBuilderBase& builder); }]; let hasVerifier = 1; string llvmBuilder = [{ - // Arguments to the intrinsic: - // shared_mem_ptr, tmaDesc, tensorDims - // cache_hint(if applicable) and flag(boolean) - llvm::SmallVector translatedOperands; - translatedOperands.push_back($srcMem); - translatedOperands.push_back($tmaDescriptor); - - for (auto v : op.getCoordinates()) - translatedOperands.push_back(moduleTranslation.lookupValue(v)); - - llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext(); - auto *i64Undef = llvm::UndefValue::get(llvm::IntegerType::get(ctx, 64)); - - bool isCacheHint = op.getL2CacheHint() ? true : false; - translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Undef); - translatedOperands.push_back(builder.getInt1(isCacheHint)); - - auto intId = NVVM::CpAsyncBulkTensorReduceOp::getIntrinsicID( - op.getCoordinates().size(), $redKind, - (op.getMode() == NVVM::TMAStoreMode::IM2COL)); - createIntrinsicCall(builder, intId, translatedOperands); + auto [id, args, types] = NVVM::CpAsyncBulkTensorReduceOp::getIIDAndArgsWithTypes( + *op, moduleTranslation, builder); + createIntrinsicCall(builder, id, args); }]; } @@ -2860,36 +2884,17 @@ def NVVM_CpAsyncBulkGlobalToSharedClusterOp : (`l2_cache_hint` `=` $l2CacheHint^ )? 
attr-dict `:` type($dstMem) `,` type($srcMem) }]; + + let extraClassDeclaration = [{ + static mlir::NVVM::IIDArgsWithTypes + getIIDAndArgsWithTypes(Operation &op, + LLVM::ModuleTranslation &mt, llvm::IRBuilderBase& builder); + }]; string llvmBuilder = [{ - // Arguments to the intrinsic: - // dst, mbar, src, size - // multicast_mask, cache_hint, - // flag for multicast_mask, - // flag for cache_hint - llvm::SmallVector translatedOperands; - translatedOperands.push_back($dstMem); - translatedOperands.push_back($mbar); - translatedOperands.push_back($srcMem); - translatedOperands.push_back($size); - - // Multicast, if available - llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext(); - auto *i16Unused = llvm::ConstantInt::get(llvm::Type::getInt16Ty(ctx), 0); - bool isMulticast = op.getMulticastMask() ? true : false; - translatedOperands.push_back(isMulticast ? $multicastMask : i16Unused); - - // Cachehint, if available - auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0); - bool isCacheHint = op.getL2CacheHint() ? true : false; - translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused); - - // Flag arguments for multicast and cachehint - translatedOperands.push_back(builder.getInt1(isMulticast)); - translatedOperands.push_back(builder.getInt1(isCacheHint)); - - createIntrinsicCall(builder, - llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster, translatedOperands); + auto [id, args, types] = NVVM::CpAsyncBulkGlobalToSharedClusterOp::getIIDAndArgsWithTypes( + *op, moduleTranslation, builder); + createIntrinsicCall(builder, id, args); }]; } @@ -2971,12 +2976,12 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp : }]; let extraClassDeclaration = [{ - static mlir::NVVM::IDArgPair - getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + static mlir::NVVM::IIDArgsWithTypes + getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase& builder); }]; string llvmBuilder = [{ - auto [id, args] = NVVM::CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs( + auto [id, args, types] = NVVM::CpAsyncBulkSharedCTAToGlobalOp::getIIDAndArgsWithTypes( *op, moduleTranslation, builder); createIntrinsicCall(builder, id, args); }]; @@ -3276,11 +3281,16 @@ def NVVM_MatchSyncOp : NVVM_Op<"match.sync">, [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-match-sync) }]; + let extraClassDeclaration = [{ + static NVVM::IIDArgsWithTypes + getIIDAndArgsWithTypes(Operation &op, + LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder); + }]; string llvmBuilder = [{ - auto intId = getMatchSyncIntrinsicId( - op.getVal().getType(), $kind); - $res = createIntrinsicCall(builder, - intId, {$thread_mask, $val}); + auto [intId, args, types] = + NVVM::MatchSyncOp::getIIDAndArgsWithTypes( + *op, moduleTranslation, builder); + $res = createIntrinsicCall(builder, intId, args); }]; let assemblyFormat = "$kind $thread_mask `,` $val attr-dict `:` type($val) `->` type($res)"; let hasVerifier = 1; @@ -3304,11 +3314,16 @@ def NVVM_BulkStoreOp: NVVM_Op<"st.bulk"> { [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-bulk) }]; + let extraClassDeclaration = [{ + static NVVM::IIDArgsWithTypes + getIIDAndArgsWithTypes(Operation &op, + LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder); + }]; string llvmBuilder = [{ - auto intId = getStBulkIntrinsicId( - 
llvm::cast(op.getAddr().getType())); - createIntrinsicCall(builder, intId, - {$addr, $size, builder.getInt64($initVal)}); + auto [intId, args, types] = + NVVM::BulkStoreOp::getIIDAndArgsWithTypes( + *op, moduleTranslation, builder); + createIntrinsicCall(builder, intId, args); }]; let assemblyFormat = "$addr `,` `size` `=` $size (`,` `init` `=` $initVal^)? attr-dict `:` type($addr)"; @@ -3392,14 +3407,13 @@ def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMa<[100, 101]>] let assemblyFormat = "$addr `,` $nCols attr-dict `:` type(operands)"; let extraClassDeclaration = [{ - static llvm::Intrinsic::ID - getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, - llvm::SmallVector &args); + static NVVM::IIDArgsWithTypes + getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder); }]; string llvmBuilder = [{ - llvm::SmallVector args; - auto id = NVVM::Tcgen05AllocOp::getIntrinsicIDAndArgs( - *op, moduleTranslation, args); + auto [id, args, types] = NVVM::Tcgen05AllocOp::getIIDAndArgsWithTypes( + *op, moduleTranslation, builder); createIntrinsicCall(builder, id, args); }]; } @@ -3420,14 +3434,13 @@ def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMa<[100, 10 let assemblyFormat = "$taddr `,` $nCols attr-dict `:` type(operands)"; let extraClassDeclaration = [{ - static llvm::Intrinsic::ID - getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, - llvm::SmallVector &args); + static NVVM::IIDArgsWithTypes + getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder); }]; string llvmBuilder = [{ - llvm::SmallVector args; - auto id = NVVM::Tcgen05DeallocOp::getIntrinsicIDAndArgs( - *op, moduleTranslation, args); + auto [id, args, types] = NVVM::Tcgen05DeallocOp::getIIDAndArgsWithTypes( + *op, moduleTranslation, builder); createIntrinsicCall(builder, id, args); }]; } @@ -3524,15 +3537,14 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMa<[100, 101] }]; let extraClassDeclaration = [{ - static llvm::Intrinsic::ID - getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, - llvm::SmallVector &args); + static NVVM::IIDArgsWithTypes + getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder); }]; string llvmBuilder = [{ - llvm::SmallVector args; - auto id = NVVM::Tcgen05CommitOp::getIntrinsicIDAndArgs( - *op, moduleTranslation, args); + auto [id, args, types] = NVVM::Tcgen05CommitOp::getIIDAndArgsWithTypes( + *op, moduleTranslation, builder); createIntrinsicCall(builder, id, args); }]; } @@ -3636,12 +3648,14 @@ def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>]> { let hasVerifier = 1; let extraClassDeclaration = [{ - static llvm::Intrinsic::ID getIntrinsicID(Operation &op); + static NVVM::IIDArgsWithTypes getIIDAndArgsWithTypes(Operation &op, + LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder); }]; string llvmBuilder = [{ - auto id = NVVM::Tcgen05CpOp::getIntrinsicID(*op); - createIntrinsicCall(builder, id, {$taddr, $smem_desc}); + auto [id, args, types] = NVVM::Tcgen05CpOp::getIIDAndArgsWithTypes(*op, + moduleTranslation, builder); + createIntrinsicCall(builder, id, args); }]; } @@ -3806,24 +3820,16 @@ def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSMa<[100, 101]>]> { let hasVerifier = 1; + let extraClassDeclaration = [{ + static NVVM::IIDArgsWithTypes + getIIDAndArgsWithTypes(Operation &op, + LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder); + }]; 
string llvmBuilder = [{ - llvm::LLVMContext &Context = moduleTranslation.getLLVMContext(); - auto Pack = llvm::ConstantInt::get(Context, llvm::APInt(1, $pack)); - - unsigned num = $_resultType->isVectorTy() - ? llvm::cast($_resultType) - ->getElementCount() - .getFixedValue() - : 1; - - auto ID = getTcgen05LdIntrinsicID($shape, num); - if (ID == llvm::Intrinsic::not_intrinsic) - llvm::report_fatal_error("unknow intrinsic signature for tcgen05.ld"); - - if ($offset) - $res = createIntrinsicCall(builder, ID, {$tmemAddr, $offset, Pack}); - else - $res = createIntrinsicCall(builder, ID, {$tmemAddr, Pack}); + auto [id, args, types] = + NVVM::Tcgen05LdOp::getIIDAndArgsWithTypes(*op, + moduleTranslation, builder); + $res = createIntrinsicCall(builder, id, args); }]; } @@ -3894,24 +3900,16 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> { [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st) }]; + let extraClassDeclaration = [{ + static NVVM::IIDArgsWithTypes + getIIDAndArgsWithTypes(Operation &op, + LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder); + }]; string llvmBuilder = [{ - llvm::LLVMContext &Context = moduleTranslation.getLLVMContext(); - auto Unpack = llvm::ConstantInt::get(Context, llvm::APInt(1, $unpack)); - - auto valTy = $val->getType(); - uint32_t num = valTy->isVectorTy() ? llvm::cast(valTy) - ->getElementCount() - .getFixedValue() - : 1; - - auto ID = getTcgen05StIntrinsicID($shape, num); - if (ID == llvm::Intrinsic::not_intrinsic) - llvm::report_fatal_error("unknow intrinsic signature for tcgen05.st"); - - if ($offset) - createIntrinsicCall(builder, ID, {$tmemAddr, $offset, $val, Unpack}); - else - createIntrinsicCall(builder, ID, {$tmemAddr, $val, Unpack}); + auto [id, args, types] = + NVVM::Tcgen05StOp::getIIDAndArgsWithTypes(*op, + moduleTranslation, builder); + createIntrinsicCall(builder, id, args); }]; let hasVerifier = 1; @@ -3969,13 +3967,13 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> { let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)"; let extraClassDeclaration = [{ - static mlir::NVVM::IDArgPair - getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + static mlir::NVVM::IIDArgsWithTypes + getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder); }]; string llvmBuilder = [{ - auto [id, args] = NVVM::DotAccumulate4WayOp::getIntrinsicIDAndArgs( + auto [id, args, types] = NVVM::DotAccumulate4WayOp::getIIDAndArgsWithTypes( *op, moduleTranslation, builder); $res = createIntrinsicCall(builder, id, args); }]; @@ -4023,13 +4021,13 @@ def NVVM_DotAccumulate2WayOp : NVVM_Op<"dot.accumulate.2way"> { let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)"; let extraClassDeclaration = [{ - static mlir::NVVM::IDArgPair - getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + static mlir::NVVM::IIDArgsWithTypes + getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder); }]; string llvmBuilder = [{ - auto [id, args] = NVVM::DotAccumulate2WayOp::getIntrinsicIDAndArgs( + auto [id, args, types] = NVVM::DotAccumulate2WayOp::getIIDAndArgsWithTypes( *op, moduleTranslation, builder); $res = createIntrinsicCall(builder, id, args); }]; diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 376e3c3e1fcbe..cac5df7f32f58 100644 --- 
a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1447,7 +1447,7 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op, } //===----------------------------------------------------------------------===// -// getIntrinsicID/getIntrinsicIDAndArgs methods +// getIntrinsicID/getIIDAndArgsWithTypes methods //===----------------------------------------------------------------------===// #define CP_ASYNC_ID_IMPL(mod, size, suffix) \ @@ -1456,9 +1456,10 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op, #define GET_CP_ASYNC_ID(mod, size, has_cpsize) \ has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, ) -llvm::Intrinsic::ID -CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, - llvm::SmallVector &args) { +NVVM::IIDArgsWithTypes +CpAsyncOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + llvm::SmallVector args; llvm::Intrinsic::ID id; auto cpAsyncOp = cast(op); @@ -1485,10 +1486,10 @@ CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, if (hasCpSize) args.push_back(mt.lookupValue(cpAsyncOp.getCpSize())); - return id; + return {id, std::move(args), {}}; } -mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs( +mlir::NVVM::IIDArgsWithTypes CpAsyncBulkPrefetchOp::getIIDAndArgsWithTypes( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast(op); llvm::SmallVector args; @@ -1505,10 +1506,11 @@ mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs( args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused); args.push_back(builder.getInt1(hasCacheHint)); - return {id, std::move(args)}; + return {id, std::move(args), {}}; } -mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs( +mlir::NVVM::IIDArgsWithTypes +CpAsyncBulkSharedCTAToGlobalOp::getIIDAndArgsWithTypes( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast(op); llvm::SmallVector args; @@ -1533,10 +1535,11 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs( id = llvm::Intrinsic::nvvm_cp_async_bulk_shared_cta_to_global_bytemask; } - return {id, std::move(args)}; + return {id, std::move(args), {}}; } -mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs( +mlir::NVVM::IIDArgsWithTypes +CpAsyncBulkTensorPrefetchOp::getIIDAndArgsWithTypes( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast(op); llvm::SmallVector args; @@ -1586,11 +1589,11 @@ mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs( if (id == llvm::Intrinsic::not_intrinsic) llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp."); - return {id, std::move(args)}; + return {id, std::move(args), {}}; } -mlir::NVVM::IDArgPair -CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs( +mlir::NVVM::IIDArgsWithTypes +CpAsyncBulkTensorSharedCTAToGlobalOp::getIIDAndArgsWithTypes( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto thisOp = cast(op); llvm::SmallVector args; @@ -1631,7 +1634,7 @@ CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs( llvm_unreachable( "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp."); - return {id, std::move(args)}; + return {id, std::move(args), {}}; } #define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \ @@ -1641,46 +1644,115 @@ 
CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs( is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col) \ : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile) -#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col) \ - [&]() -> auto { \ - switch (dims) { \ - case 1: \ - return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \ - case 2: \ - return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile); \ - case 3: \ - return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col); \ - case 4: \ - return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col); \ - case 5: \ - return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col); \ - default: \ - llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \ - } \ - }() +#define GET_CP_ASYNC_BULK_TENSOR_ID(iid, op, dims, is_im2col) \ + switch (dims) { \ + case 1: \ + iid = CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \ + break; \ + case 2: \ + iid = CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile); \ + break; \ + case 3: \ + iid = CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col); \ + break; \ + case 4: \ + iid = CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col); \ + break; \ + case 5: \ + iid = CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col); \ + break; \ + default: \ + llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \ + break; \ + } \ + break; + +NVVM::IIDArgsWithTypes CpAsyncBulkTensorReduceOp::getIIDAndArgsWithTypes( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast(op); + llvm::LLVMContext &ctx = mt.getLLVMContext(); + + llvm::SmallVector args; + + // Arguments to the intrinsic: + // shared_mem_ptr, tmaDesc, tensorDims + // cache_hint(if applicable) and flag(boolean) + args.push_back(mt.lookupValue(thisOp.getSrcMem())); + args.push_back(mt.lookupValue(thisOp.getTmaDescriptor())); + + for (auto v : thisOp.getCoordinates()) + args.push_back(mt.lookupValue(v)); + + mlir::Value cacheHint = thisOp.getL2CacheHint(); + const bool hasCacheHint = static_cast(cacheHint); + llvm::Value *i64Poison = + llvm::PoisonValue::get(llvm::IntegerType::get(ctx, 64)); + args.push_back(hasCacheHint ? 
mt.lookupValue(cacheHint) : i64Poison); + args.push_back(builder.getInt1(hasCacheHint)); + + llvm::Intrinsic::ID iid; + int tensorDims = thisOp.getCoordinates().size(); + bool isIm2Col = thisOp.getMode() == NVVM::TMAStoreMode::IM2COL; -llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID( - int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) { using RedTy = NVVM::TMAReduxKind; - switch (kind) { + switch (thisOp.getRedKind()) { case RedTy::ADD: - return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col); + GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_add, tensorDims, isIm2Col); case RedTy::MIN: - return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col); + GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_min, tensorDims, isIm2Col); case RedTy::MAX: - return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col); + GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_max, tensorDims, isIm2Col); case RedTy::INC: - return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col); + GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_inc, tensorDims, isIm2Col); case RedTy::DEC: - return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col); + GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_dec, tensorDims, isIm2Col); case RedTy::AND: - return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col); + GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_and, tensorDims, isIm2Col); case RedTy::OR: - return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col); + GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_or, tensorDims, isIm2Col); case RedTy::XOR: - return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col); + GET_CP_ASYNC_BULK_TENSOR_ID(iid, reduce_xor, tensorDims, isIm2Col); } - llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp"); + + return {iid, std::move(args), {}}; +} + +NVVM::IIDArgsWithTypes +CpAsyncBulkGlobalToSharedClusterOp::getIIDAndArgsWithTypes( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast(op); + llvm::SmallVector args; + + // Arguments to the intrinsic: + // dst, mbar, src, size + // multicast_mask, cache_hint, + // flag for multicast_mask, + // flag for cache_hint + args.push_back(mt.lookupValue(thisOp.getDstMem())); + args.push_back(mt.lookupValue(thisOp.getMbar())); + args.push_back(mt.lookupValue(thisOp.getSrcMem())); + args.push_back(mt.lookupValue(thisOp.getSize())); + + // Multicast, if available + llvm::LLVMContext &ctx = mt.getLLVMContext(); + auto *i16Unused = llvm::ConstantInt::get(llvm::Type::getInt16Ty(ctx), 0); + bool isMulticast = thisOp.getMulticastMask() ? true : false; + args.push_back(isMulticast ? mt.lookupValue(thisOp.getMulticastMask()) + : i16Unused); + + // Cachehint, if available + auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0); + bool isCacheHint = thisOp.getL2CacheHint() ? true : false; + args.push_back(isCacheHint ? 
mt.lookupValue(thisOp.getL2CacheHint()) + : i64Unused); + + // Flag arguments for multicast and cachehint + args.push_back(builder.getInt1(isMulticast)); + args.push_back(builder.getInt1(isCacheHint)); + + return {llvm::Intrinsic::nvvm_cp_async_bulk_global_to_shared_cluster, + std::move(args), + {}}; } #define _none @@ -1789,10 +1861,8 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd, } } -llvm::Intrinsic::ID -Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op, - LLVM::ModuleTranslation &mt, - llvm::SmallVector &args) { +NVVM::IIDArgsWithTypes Tcgen05AllocOp::getIIDAndArgsWithTypes( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto curOp = cast(op); unsigned as = llvm::cast(curOp.getAddr().getType()) .getAddressSpace(); @@ -1809,25 +1879,26 @@ Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op, } // Fill the Intrinsic Args + llvm::SmallVector args; args.push_back(mt.lookupValue(curOp.getAddr())); args.push_back(mt.lookupValue(curOp.getNCols())); - return id; + return {id, std::move(args), {}}; } -llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs( - Operation &op, LLVM::ModuleTranslation &mt, - llvm::SmallVector &args) { +NVVM::IIDArgsWithTypes Tcgen05DeallocOp::getIIDAndArgsWithTypes( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto curOp = cast(op); auto id = (curOp.getGroup() == CTAGroupKind::CTA_1) ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2; // Fill the Intrinsic Args + llvm::SmallVector args; args.push_back(mt.lookupValue(curOp.getTaddr())); args.push_back(mt.lookupValue(curOp.getNCols())); - return id; + return {id, std::move(args), {}}; } #define TCGEN05_COMMIT_IMPL(cg, is_shared, mc) \ @@ -1838,10 +1909,8 @@ llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs( has_mc ? 
TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \ : TCGEN05_COMMIT_IMPL(cta_group, is_shared, ) -llvm::Intrinsic::ID -Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op, - LLVM::ModuleTranslation &mt, - llvm::SmallVector &args) { +NVVM::IIDArgsWithTypes Tcgen05CommitOp::getIIDAndArgsWithTypes( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto curOp = cast(op); unsigned as = llvm::cast(curOp.getAddr().getType()) .getAddressSpace(); @@ -1854,11 +1923,12 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op, : GET_TCGEN05_COMMIT_ID(cg1, isShared, hasMulticast); // Fill the Intrinsic Args + llvm::SmallVector args; args.push_back(mt.lookupValue(curOp.getAddr())); if (hasMulticast) args.push_back(mt.lookupValue(curOp.getMulticastMask())); - return id; + return {id, std::move(args), {}}; } #define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \ @@ -1877,25 +1947,37 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op, return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \ }() -llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) { +NVVM::IIDArgsWithTypes +Tcgen05CpOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { auto curOp = cast(op); bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2; auto srcFmt = curOp.getSrcFormat(); auto mc = curOp.getMulticast(); + llvm::SmallVector args; + args.push_back(mt.lookupValue(curOp.getTaddr())); + args.push_back(mt.lookupValue(curOp.getSmemDesc())); + switch (curOp.getShape()) { case Tcgen05CpShape::SHAPE_128x256b: - return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA); + return {GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA), std::move(args), {}}; case Tcgen05CpShape::SHAPE_128x128b: - return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA); + return {GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA), std::move(args), {}}; case Tcgen05CpShape::SHAPE_4x256b: - return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA); + return {GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA), std::move(args), {}}; case Tcgen05CpShape::SHAPE_32x128b: - return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA); + return {GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA), + std::move(args), + {}}; case Tcgen05CpShape::SHAPE_64x128b: return (mc == Tcgen05CpMulticast::WARPX2_01_23) - ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA) - : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA); + ? 
NVVM::IIDArgsWithTypes( + GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA), + std::move(args), {}) + : NVVM::IIDArgsWithTypes( + GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA), + std::move(args), {}); } llvm_unreachable("Invalid shape in tcgen05 cp Op"); } @@ -1962,7 +2044,7 @@ static llvm::Value *getAsPackedI32(llvm::Value *arg, llvm::Type::getInt32Ty(builder.getContext())); } -NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs( +NVVM::IIDArgsWithTypes DotAccumulate4WayOp::getIIDAndArgsWithTypes( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto curOp = cast(op); @@ -1980,10 +2062,10 @@ NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs( llvm::Intrinsic::nvvm_idp4a_s_u, llvm::Intrinsic::nvvm_idp4a_s_s, }; - return {ids[type], args}; + return {ids[type], args, {}}; } -NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs( +NVVM::IIDArgsWithTypes DotAccumulate2WayOp::getIIDAndArgsWithTypes( Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { auto curOp = cast(op); @@ -2002,7 +2084,7 @@ NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs( llvm::Intrinsic::nvvm_idp2a_s_u, llvm::Intrinsic::nvvm_idp2a_s_s, }; - return {ids[type], args}; + return {ids[type], args, {}}; } static llvm::Value *getParamCastedAddr(llvm::Value *addr, @@ -2013,39 +2095,40 @@ static llvm::Value *getParamCastedAddr(llvm::Value *addr, llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM)); } -NVVM::IDArgPair -PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op, - LLVM::ModuleTranslation &mt, - llvm::IRBuilderBase &builder) { +NVVM::IIDArgsWithTypes +PrefetchOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { using MemSpace = NVVM::NVVMMemorySpace; using CacheLevel = NVVM::PrefetchCacheLevel; - std::optional cacheLevel = op.getCacheLevel(); + auto thisOp = cast(op); + + std::optional cacheLevel = thisOp.getCacheLevel(); std::optional evictPriority = - op.getEvictPriority(); + thisOp.getEvictPriority(); unsigned addressSpace = - llvm::cast(op.getAddr().getType()) + llvm::cast(thisOp.getAddr().getType()) .getAddressSpace(); llvm::SmallVector args; - llvm::Value *addr = mt.lookupValue(op.getAddr()); - args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder) - : addr); + llvm::Value *addr = mt.lookupValue(thisOp.getAddr()); + args.push_back(thisOp.getInParamSpace() ? 
getParamCastedAddr(addr, builder) + : addr); - if (op.getTensormap()) - return {llvm::Intrinsic::nvvm_prefetch_tensormap, args}; + if (thisOp.getTensormap()) + return {llvm::Intrinsic::nvvm_prefetch_tensormap, args, {addr->getType()}}; assert(cacheLevel && "expected cache level for non-tensormap prefetch"); - if (op.getUniform() && *cacheLevel == CacheLevel::L1) - return {llvm::Intrinsic::nvvm_prefetchu_L1, args}; + if (thisOp.getUniform() && *cacheLevel == CacheLevel::L1) + return {llvm::Intrinsic::nvvm_prefetchu_L1, args, {}}; if (evictPriority && *cacheLevel == CacheLevel::L2) { switch (*evictPriority) { case NVVM::CacheEvictionPriority::EvictLast: - return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args}; + return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args, {}}; case NVVM::CacheEvictionPriority::EvictNormal: - return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args}; + return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args, {}}; default: llvm_unreachable("Invalid cache eviction priority"); } @@ -2054,25 +2137,628 @@ PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op, switch (addressSpace) { case MemSpace::kGenericMemorySpace: return *cacheLevel == CacheLevel::L1 - ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args}) - : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args}); + ? NVVM::IIDArgsWithTypes(llvm::Intrinsic::nvvm_prefetch_L1, args, + {}) + : NVVM::IIDArgsWithTypes(llvm::Intrinsic::nvvm_prefetch_L2, args, + {}); case MemSpace::kGlobalMemorySpace: return *cacheLevel == CacheLevel::L1 - ? NVVM::IDArgPair( - {llvm::Intrinsic::nvvm_prefetch_global_L1, args}) - : NVVM::IDArgPair( - {llvm::Intrinsic::nvvm_prefetch_global_L2, args}); + ? NVVM::IIDArgsWithTypes( + llvm::Intrinsic::nvvm_prefetch_global_L1, args, {}) + : NVVM::IIDArgsWithTypes( + llvm::Intrinsic::nvvm_prefetch_global_L2, args, {}); case MemSpace::kLocalMemorySpace: return *cacheLevel == CacheLevel::L1 - ? NVVM::IDArgPair( - {llvm::Intrinsic::nvvm_prefetch_local_L1, args}) - : NVVM::IDArgPair( - {llvm::Intrinsic::nvvm_prefetch_local_L2, args}); + ? NVVM::IIDArgsWithTypes(llvm::Intrinsic::nvvm_prefetch_local_L1, + args, {}) + : NVVM::IIDArgsWithTypes(llvm::Intrinsic::nvvm_prefetch_local_L2, + args, {}); default: llvm_unreachable("Invalid pointer address space"); } } +#define REDUX_F32_ID_IMPL(op, abs, hasNaN) \ + hasNaN ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN \ + : llvm::Intrinsic::nvvm_redux_sync_f##op##abs + +#define GET_REDUX_F32_ID(op, hasAbs, hasNaN) \ + hasAbs ? 
REDUX_F32_ID_IMPL(op, _abs, hasNaN) : REDUX_F32_ID_IMPL(op, , hasNaN) + +NVVM::IIDArgsWithTypes +ReduxOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast(op); + + llvm::SmallVector args; + args.push_back(mt.lookupValue(thisOp.getVal())); + args.push_back(mt.lookupValue(thisOp.getMaskAndClamp())); + + bool hasAbs = thisOp.getAbs(); + bool hasNaN = thisOp.getNan(); + NVVM::ReduxKind kind = thisOp.getKind(); + + llvm::Intrinsic::ID id; + + switch (kind) { + case NVVM::ReduxKind::ADD: + id = llvm::Intrinsic::nvvm_redux_sync_add; + break; + case NVVM::ReduxKind::UMAX: + id = llvm::Intrinsic::nvvm_redux_sync_umax; + break; + case NVVM::ReduxKind::UMIN: + id = llvm::Intrinsic::nvvm_redux_sync_umin; + break; + case NVVM::ReduxKind::AND: + id = llvm::Intrinsic::nvvm_redux_sync_and; + break; + case NVVM::ReduxKind::OR: + id = llvm::Intrinsic::nvvm_redux_sync_or; + break; + case NVVM::ReduxKind::XOR: + id = llvm::Intrinsic::nvvm_redux_sync_xor; + break; + case NVVM::ReduxKind::MAX: + id = llvm::Intrinsic::nvvm_redux_sync_max; + break; + case NVVM::ReduxKind::MIN: + id = llvm::Intrinsic::nvvm_redux_sync_min; + break; + case NVVM::ReduxKind::FMIN: + id = GET_REDUX_F32_ID(min, hasAbs, hasNaN); + break; + case NVVM::ReduxKind::FMAX: + id = GET_REDUX_F32_ID(max, hasAbs, hasNaN); + break; + } + + return {id, std::move(args), {}}; +} + +NVVM::IIDArgsWithTypes +ShflOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast(op); + + llvm::SmallVector args; + args.push_back(mt.lookupValue(thisOp.getThreadMask())); + args.push_back(mt.lookupValue(thisOp.getVal())); + args.push_back(mt.lookupValue(thisOp.getOffset())); + args.push_back(mt.lookupValue(thisOp.getMaskAndClamp())); + + mlir::Type resultType = thisOp.getResult().getType(); + NVVM::ShflKind kind = thisOp.getKind(); + bool withPredicate = static_cast(thisOp.getReturnValueAndIsValid()); + + llvm::Intrinsic::ID id; + + if (withPredicate) { + resultType = cast(resultType).getBody()[0]; + switch (kind) { + case NVVM::ShflKind::bfly: + id = resultType.isFloat() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p + : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p; + break; + case NVVM::ShflKind::up: + id = resultType.isFloat() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p + : llvm::Intrinsic::nvvm_shfl_sync_up_i32p; + break; + case NVVM::ShflKind::down: + id = resultType.isFloat() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p + : llvm::Intrinsic::nvvm_shfl_sync_down_i32p; + break; + case NVVM::ShflKind::idx: + id = resultType.isFloat() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p + : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p; + break; + } + } else { + switch (kind) { + case NVVM::ShflKind::bfly: + id = resultType.isFloat() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32 + : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32; + break; + case NVVM::ShflKind::up: + id = resultType.isFloat() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32 + : llvm::Intrinsic::nvvm_shfl_sync_up_i32; + break; + case NVVM::ShflKind::down: + id = resultType.isFloat() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32 + : llvm::Intrinsic::nvvm_shfl_sync_down_i32; + break; + case NVVM::ShflKind::idx: + id = resultType.isFloat() ? 
llvm::Intrinsic::nvvm_shfl_sync_idx_f32 + : llvm::Intrinsic::nvvm_shfl_sync_idx_i32; + break; + } + } + + return {id, std::move(args), {}}; +} + +NVVM::IIDArgsWithTypes +MatchSyncOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast(op); + + llvm::SmallVector args; + args.push_back(mt.lookupValue(thisOp.getThreadMask())); + args.push_back(mt.lookupValue(thisOp.getVal())); + + llvm::Intrinsic::ID id; + + mlir::Type valType = thisOp.getVal().getType(); + NVVM::MatchSyncKind kind = thisOp.getKind(); + + switch (kind) { + case NVVM::MatchSyncKind::any: + id = valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32 + : llvm::Intrinsic::nvvm_match_any_sync_i64; + break; + case NVVM::MatchSyncKind::all: + // match.all instruction has two variants -- one returns a single value, + // another returns a pair {value, predicate}. We currently only implement + // the latter as that's the variant exposed by CUDA API. + id = valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_all_sync_i32p + : llvm::Intrinsic::nvvm_match_all_sync_i64p; + break; + } + + return {id, std::move(args), {}}; +} + +NVVM::IIDArgsWithTypes +VoteSyncOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast(op); + + llvm::SmallVector args; + args.push_back(mt.lookupValue(thisOp.getMask())); + args.push_back(mt.lookupValue(thisOp.getPred())); + + llvm::Intrinsic::ID id; + + NVVM::VoteSyncKind kind = thisOp.getKind(); + + switch (kind) { + case NVVM::VoteSyncKind::any: + id = llvm::Intrinsic::nvvm_vote_any_sync; + break; + case NVVM::VoteSyncKind::all: + id = llvm::Intrinsic::nvvm_vote_all_sync; + break; + case NVVM::VoteSyncKind::ballot: + id = llvm::Intrinsic::nvvm_vote_ballot_sync; + break; + case NVVM::VoteSyncKind::uni: + id = llvm::Intrinsic::nvvm_vote_uni_sync; + break; + } + + return {id, std::move(args), {}}; +} + +NVVM::IIDArgsWithTypes +LdMatrixOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast(op); + + llvm::SmallVector args; + args.reserve(op.getOperands().size()); + for (mlir::Value v : op.getOperands()) + args.push_back(mt.lookupValue(v)); + + llvm::SmallVector types = {args[0]->getType()}; + + llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic; + + NVVM::MMALayout layout = thisOp.getLayout(); + int32_t num = thisOp.getNum(); + NVVM::LdStMatrixShapeAttr shape = thisOp.getShape(); + NVVM::LdStMatrixEltType eltType = thisOp.getEltType(); + + if (shape.getM() == 8 && shape.getN() == 8) { + switch (num) { + case 1: + id = (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16 + : llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; + break; + case 2: + id = (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16 + : llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; + break; + case 4: + id = (layout == NVVM::MMALayout::row) + ? 
llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16 + : llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; + break; + } + } else if (shape.getM() == 8 && shape.getN() == 16) { + if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { + switch (num) { + case 1: + id = llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32; + break; + case 2: + id = llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32; + break; + case 4: + id = llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32; + break; + } + } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { + switch (num) { + case 1: + id = llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64; + break; + case 2: + id = llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64; + break; + case 4: + id = llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64; + break; + } + } + } else if (shape.getM() == 16 && shape.getN() == 16) { + if (eltType == NVVM::LdStMatrixEltType::B8) { + switch (num) { + case 1: + id = llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8; + break; + case 2: + id = llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8; + break; + } + } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { + switch (num) { + case 1: + id = llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32; + break; + case 2: + id = llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32; + break; + } + } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { + switch (num) { + case 1: + id = llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64; + break; + case 2: + id = llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64; + break; + } + } + } else { + llvm_unreachable("unknown ldmatrix kind"); + } + + return {id, std::move(args), types}; +} + +NVVM::IIDArgsWithTypes +StMatrixOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast(op); + + llvm::SmallVector args; + args.reserve(op.getOperands().size()); + for (mlir::Value v : op.getOperands()) + args.push_back(mt.lookupValue(v)); + + llvm::SmallVector types = {args[0]->getType()}; + + llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic; + + NVVM::MMALayout layout = thisOp.getLayout(); + int32_t num = thisOp.getSources().size(); + NVVM::LdStMatrixShapeAttr shape = thisOp.getShape(); + + if (shape.getM() == 8 && shape.getN() == 8) { + switch (num) { + case 1: + id = (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16 + : llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16; + break; + case 2: + id = (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16 + : llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16; + break; + case 4: + id = (layout == NVVM::MMALayout::row) + ? 
llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16 + : llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16; + break; + } + } else if (shape.getM() == 16 && shape.getN() == 8) { + switch (num) { + case 1: + id = llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8; + break; + case 2: + id = llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8; + break; + case 4: + id = llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8; + break; + } + } else { + llvm_unreachable("unknown stmatrix kind"); + } + + return {id, std::move(args), types}; +} + +NVVM::IIDArgsWithTypes +BulkStoreOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast(op); + + llvm::SmallVector args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getSize())); + args.push_back(builder.getInt64(thisOp.getInitVal())); + + llvm::Intrinsic::ID id; + + auto addrType = llvm::cast(thisOp.getAddr().getType()); + bool isSharedMemory = + addrType.getAddressSpace() == NVVM::NVVMMemorySpace::kSharedMemorySpace; + id = isSharedMemory ? llvm::Intrinsic::nvvm_st_bulk_shared_cta + : llvm::Intrinsic::nvvm_st_bulk; + + return {id, std::move(args), {}}; +} + +NVVM::IIDArgsWithTypes FenceProxyAcquireOp::getIIDAndArgsWithTypes( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast(op); + + llvm::SmallVector args; + args.push_back(mt.lookupValue(thisOp.getAddr())); + args.push_back(mt.lookupValue(thisOp.getSize())); + + llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic; + + NVVM::ProxyKind fromProxy = thisOp.getFromProxy(); + NVVM::ProxyKind toProxy = thisOp.getToProxy(); + NVVM::MemScopeKind scope = thisOp.getScope(); + + if (fromProxy == NVVM::ProxyKind::GENERIC && + toProxy == NVVM::ProxyKind::TENSORMAP) { + switch (scope) { + case NVVM::MemScopeKind::CTA: + id = llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta; + break; + case NVVM::MemScopeKind::CLUSTER: + id = llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cluster; + break; + case NVVM::MemScopeKind::GPU: + id = llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu; + break; + case NVVM::MemScopeKind::SYS: + id = llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys; + break; + } + } else { + llvm_unreachable("unsupported proxy kinds"); + } + + return {id, std::move(args), {}}; +} + +NVVM::IIDArgsWithTypes FenceProxyReleaseOp::getIIDAndArgsWithTypes( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast(op); + + llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic; + + NVVM::ProxyKind fromProxy = thisOp.getFromProxy(); + NVVM::ProxyKind toProxy = thisOp.getToProxy(); + NVVM::MemScopeKind scope = thisOp.getScope(); + + if (fromProxy == NVVM::ProxyKind::GENERIC && + toProxy == NVVM::ProxyKind::TENSORMAP) { + switch (scope) { + case NVVM::MemScopeKind::CTA: + id = llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta; + break; + case NVVM::MemScopeKind::CLUSTER: + id = llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cluster; + break; + case NVVM::MemScopeKind::GPU: + id = llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu; + break; + case NVVM::MemScopeKind::SYS: + id = llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys; + break; + } + } else { + llvm_unreachable("unsupported proxy kinds"); + } + + return {id, {}, {}}; +} + +#define TCGEN05LD(SHAPE, NUM) 
llvm::Intrinsic::nvvm_tcgen05_ld_##SHAPE##_##NUM + +NVVM::IIDArgsWithTypes +Tcgen05LdOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast(op); + + llvm::SmallVector args; + args.push_back(mt.lookupValue(thisOp.getTmemAddr())); + if (mt.lookupValue(thisOp.getOffset())) + args.push_back(mt.lookupValue(thisOp.getOffset())); + + llvm::LLVMContext &ctx = mt.getLLVMContext(); + auto Pack = llvm::ConstantInt::get(ctx, llvm::APInt(1, thisOp.getPack())); + args.push_back(Pack); + + mlir::Type resultType = thisOp.getResult().getType(); + uint32_t num = isa(resultType) + ? llvm::cast(resultType).getNumElements() + : 1; + NVVM::Tcgen05LdStShape shape = thisOp.getShape(); + + llvm::Intrinsic::ID id; + + llvm::Intrinsic::ID Shape16x64b[] = { + TCGEN05LD(16x64b, x1), TCGEN05LD(16x64b, x2), TCGEN05LD(16x64b, x4), + TCGEN05LD(16x64b, x8), TCGEN05LD(16x64b, x16), TCGEN05LD(16x64b, x32), + TCGEN05LD(16x64b, x64), TCGEN05LD(16x64b, x128), + }; + + llvm::Intrinsic::ID Shape16x128b[] = { + TCGEN05LD(16x128b, x1), TCGEN05LD(16x128b, x2), TCGEN05LD(16x128b, x4), + TCGEN05LD(16x128b, x8), TCGEN05LD(16x128b, x16), TCGEN05LD(16x128b, x32), + TCGEN05LD(16x128b, x64), + }; + + llvm::Intrinsic::ID Shape16x256b[] = { + TCGEN05LD(16x256b, x1), TCGEN05LD(16x256b, x2), TCGEN05LD(16x256b, x4), + TCGEN05LD(16x256b, x8), TCGEN05LD(16x256b, x16), TCGEN05LD(16x256b, x32), + }; + + llvm::Intrinsic::ID Shape16x32bx2[] = { + TCGEN05LD(16x32bx2, x1), TCGEN05LD(16x32bx2, x2), + TCGEN05LD(16x32bx2, x4), TCGEN05LD(16x32bx2, x8), + TCGEN05LD(16x32bx2, x16), TCGEN05LD(16x32bx2, x32), + TCGEN05LD(16x32bx2, x64), TCGEN05LD(16x32bx2, x128), + }; + + llvm::Intrinsic::ID Shape32x32b[] = { + TCGEN05LD(32x32b, x1), TCGEN05LD(32x32b, x2), TCGEN05LD(32x32b, x4), + TCGEN05LD(32x32b, x8), TCGEN05LD(32x32b, x16), TCGEN05LD(32x32b, x32), + TCGEN05LD(32x32b, x64), TCGEN05LD(32x32b, x128), + }; + + // `num` contains the length of vector and log2 of `num` returns the index + // into the shape array + unsigned Idx = std::log2(num); + + switch (shape) { + case NVVM::Tcgen05LdStShape::SHAPE_16X64B: + id = Shape16x64b[Idx]; + break; + case NVVM::Tcgen05LdStShape::SHAPE_16X128B: + id = Shape16x128b[Idx - 1]; + break; + case NVVM::Tcgen05LdStShape::SHAPE_16X256B: + id = Shape16x256b[Idx - 2]; + break; + case NVVM::Tcgen05LdStShape::SHAPE_32X32B: + id = Shape32x32b[Idx]; + break; + case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2: + id = Shape16x32bx2[Idx]; + break; + } + + if (id == llvm::Intrinsic::not_intrinsic) + llvm::report_fatal_error("unknow intrinsic signature for tcgen05.ld"); + + return {id, std::move(args), {}}; +} + +#define TCGEN05ST(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_st_##SHAPE##_##NUM + +NVVM::IIDArgsWithTypes +Tcgen05StOp::getIIDAndArgsWithTypes(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast(op); + + llvm::SmallVector args; + args.push_back(mt.lookupValue(thisOp.getTmemAddr())); + if (mt.lookupValue(thisOp.getOffset())) + args.push_back(mt.lookupValue(thisOp.getOffset())); + args.push_back(mt.lookupValue(thisOp.getVal())); + + llvm::LLVMContext &ctx = mt.getLLVMContext(); + auto Unpack = llvm::ConstantInt::get(ctx, llvm::APInt(1, thisOp.getUnpack())); + args.push_back(Unpack); + + mlir::Type resultType = thisOp.getVal().getType(); + uint32_t num = isa(resultType) + ? 
llvm::cast<mlir::VectorType>(resultType).getNumElements() + : 1; + NVVM::Tcgen05LdStShape shape = thisOp.getShape(); + + llvm::Intrinsic::ID id; + + llvm::Intrinsic::ID Shape16x64b[] = { + TCGEN05ST(16x64b, x1), TCGEN05ST(16x64b, x2), TCGEN05ST(16x64b, x4), + TCGEN05ST(16x64b, x8), TCGEN05ST(16x64b, x16), TCGEN05ST(16x64b, x32), + TCGEN05ST(16x64b, x64), TCGEN05ST(16x64b, x128), + }; + + llvm::Intrinsic::ID Shape16x128b[] = { + TCGEN05ST(16x128b, x1), TCGEN05ST(16x128b, x2), TCGEN05ST(16x128b, x4), + TCGEN05ST(16x128b, x8), TCGEN05ST(16x128b, x16), TCGEN05ST(16x128b, x32), + TCGEN05ST(16x128b, x64), + }; + + llvm::Intrinsic::ID Shape16x256b[] = { + TCGEN05ST(16x256b, x1), TCGEN05ST(16x256b, x2), TCGEN05ST(16x256b, x4), + TCGEN05ST(16x256b, x8), TCGEN05ST(16x256b, x16), TCGEN05ST(16x256b, x32), + }; + + llvm::Intrinsic::ID Shape16x32bx2[] = { + TCGEN05ST(16x32bx2, x1), TCGEN05ST(16x32bx2, x2), + TCGEN05ST(16x32bx2, x4), TCGEN05ST(16x32bx2, x8), + TCGEN05ST(16x32bx2, x16), TCGEN05ST(16x32bx2, x32), + TCGEN05ST(16x32bx2, x64), TCGEN05ST(16x32bx2, x128), + }; + + llvm::Intrinsic::ID Shape32x32b[] = { + TCGEN05ST(32x32b, x1), TCGEN05ST(32x32b, x2), TCGEN05ST(32x32b, x4), + TCGEN05ST(32x32b, x8), TCGEN05ST(32x32b, x16), TCGEN05ST(32x32b, x32), + TCGEN05ST(32x32b, x64), TCGEN05ST(32x32b, x128), + }; + + // `num` contains the length of the vector and log2 of `num` gives the index + // into the shape array + unsigned Idx = std::log2(num); + + switch (shape) { + case NVVM::Tcgen05LdStShape::SHAPE_16X64B: + id = Shape16x64b[Idx]; + break; + case NVVM::Tcgen05LdStShape::SHAPE_16X128B: + id = Shape16x128b[Idx - 1]; + break; + case NVVM::Tcgen05LdStShape::SHAPE_16X256B: + id = Shape16x256b[Idx - 2]; + break; + case NVVM::Tcgen05LdStShape::SHAPE_32X32B: + id = Shape32x32b[Idx]; + break; + case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2: + id = Shape16x32bx2[Idx]; + break; + } + + if (id == llvm::Intrinsic::not_intrinsic) + llvm::report_fatal_error("unknown intrinsic signature for tcgen05.st"); + + return {id, std::move(args), {}}; +} + bool NVVM::InlinePtxOp::getAsmValues( RewriterBase &rewriter, llvm::SmallVectorImpl> diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index 7f69af14df338..cadb1e7c67246 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -26,380 +26,6 @@ using namespace mlir; using namespace mlir::LLVM; using mlir::LLVM::detail::createIntrinsicCall; -#define REDUX_F32_ID_IMPL(op, abs, hasNaN) \ - hasNaN ? llvm::Intrinsic::nvvm_redux_sync_f##op##abs##_NaN \ - : llvm::Intrinsic::nvvm_redux_sync_f##op##abs - -#define GET_REDUX_F32_ID(op, hasAbs, hasNaN) \ - hasAbs ? 
REDUX_F32_ID_IMPL(op, _abs, hasNaN) : REDUX_F32_ID_IMPL(op, , hasNaN) - -static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType, - NVVM::ReduxKind kind, - bool hasAbs, bool hasNaN) { - if (!(resultType->isIntegerTy(32) || resultType->isFloatTy())) - llvm_unreachable("unsupported data type for redux"); - - switch (kind) { - case NVVM::ReduxKind::ADD: - return llvm::Intrinsic::nvvm_redux_sync_add; - case NVVM::ReduxKind::UMAX: - return llvm::Intrinsic::nvvm_redux_sync_umax; - case NVVM::ReduxKind::UMIN: - return llvm::Intrinsic::nvvm_redux_sync_umin; - case NVVM::ReduxKind::AND: - return llvm::Intrinsic::nvvm_redux_sync_and; - case NVVM::ReduxKind::OR: - return llvm::Intrinsic::nvvm_redux_sync_or; - case NVVM::ReduxKind::XOR: - return llvm::Intrinsic::nvvm_redux_sync_xor; - case NVVM::ReduxKind::MAX: - return llvm::Intrinsic::nvvm_redux_sync_max; - case NVVM::ReduxKind::MIN: - return llvm::Intrinsic::nvvm_redux_sync_min; - case NVVM::ReduxKind::FMIN: - return GET_REDUX_F32_ID(min, hasAbs, hasNaN); - case NVVM::ReduxKind::FMAX: - return GET_REDUX_F32_ID(max, hasAbs, hasNaN); - } - llvm_unreachable("unknown redux kind"); -} - -static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType, - NVVM::ShflKind kind, - bool withPredicate) { - - if (withPredicate) { - resultType = cast<llvm::StructType>(resultType)->getElementType(0); - switch (kind) { - case NVVM::ShflKind::bfly: - return resultType->isFloatTy() - ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p - : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p; - case NVVM::ShflKind::up: - return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p - : llvm::Intrinsic::nvvm_shfl_sync_up_i32p; - case NVVM::ShflKind::down: - return resultType->isFloatTy() - ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p - : llvm::Intrinsic::nvvm_shfl_sync_down_i32p; - case NVVM::ShflKind::idx: - return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p - : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p; - } - } else { - switch (kind) { - case NVVM::ShflKind::bfly: - return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32 - : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32; - case NVVM::ShflKind::up: - return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32 - : llvm::Intrinsic::nvvm_shfl_sync_up_i32; - case NVVM::ShflKind::down: - return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32 - : llvm::Intrinsic::nvvm_shfl_sync_down_i32; - case NVVM::ShflKind::idx: - return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32 - : llvm::Intrinsic::nvvm_shfl_sync_idx_i32; - } - } - llvm_unreachable("unknown shuffle kind"); -} - -static llvm::Intrinsic::ID getMatchSyncIntrinsicId(Type valType, - NVVM::MatchSyncKind kind) { - switch (kind) { - case NVVM::MatchSyncKind::any: - return valType.isInteger(32) ? llvm::Intrinsic::nvvm_match_any_sync_i32 - : llvm::Intrinsic::nvvm_match_any_sync_i64; - case NVVM::MatchSyncKind::all: - // match.all instruction has two variants -- one returns a single value, - // another returns a pair {value, predicate}. We currently only implement - // the latter as that's the variant exposed by CUDA API. - return valType.isInteger(32) ? 
llvm::Intrinsic::nvvm_match_all_sync_i32p - : llvm::Intrinsic::nvvm_match_all_sync_i64p; - } - llvm_unreachable("unsupported match sync kind"); -} - -static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) { - switch (kind) { - case NVVM::VoteSyncKind::any: - return llvm::Intrinsic::nvvm_vote_any_sync; - case NVVM::VoteSyncKind::all: - return llvm::Intrinsic::nvvm_vote_all_sync; - case NVVM::VoteSyncKind::ballot: - return llvm::Intrinsic::nvvm_vote_ballot_sync; - case NVVM::VoteSyncKind::uni: - return llvm::Intrinsic::nvvm_vote_uni_sync; - } - llvm_unreachable("unsupported vote kind"); -} - -static llvm::Intrinsic::ID -getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, - NVVM::LdStMatrixShapeAttr shape, - NVVM::LdStMatrixEltType eltType) { - if (shape.getM() == 8 && shape.getN() == 8) { - switch (num) { - case 1: - return (layout == NVVM::MMALayout::row) - ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16 - : llvm::Intrinsic:: - nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; - case 2: - return (layout == NVVM::MMALayout::row) - ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16 - : llvm::Intrinsic:: - nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; - case 4: - return (layout == NVVM::MMALayout::row) - ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16 - : llvm::Intrinsic:: - nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; - } - } else if (shape.getM() == 8 && shape.getN() == 16) { - if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { - switch (num) { - case 1: - return llvm::Intrinsic:: - nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32; - case 2: - return llvm::Intrinsic:: - nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32; - case 4: - return llvm::Intrinsic:: - nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32; - } - } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { - switch (num) { - case 1: - return llvm::Intrinsic:: - nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64; - case 2: - return llvm::Intrinsic:: - nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64; - case 4: - return llvm::Intrinsic:: - nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64; - } - } - } else if (shape.getM() == 16 && shape.getN() == 16) { - if (eltType == NVVM::LdStMatrixEltType::B8) { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8; - case 2: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8; - } - } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { - switch (num) { - case 1: - return llvm::Intrinsic:: - nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32; - case 2: - return llvm::Intrinsic:: - nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32; - } - } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { - switch (num) { - case 1: - return llvm::Intrinsic:: - nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64; - case 2: - return llvm::Intrinsic:: - nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64; - } - } - } - llvm_unreachable("unknown ldmatrix kind"); -} - -/// Return the intrinsic ID associated with stmatrix for the given paramters. -static llvm::Intrinsic::ID -getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, - NVVM::LdStMatrixShapeAttr shape, - NVVM::LdStMatrixEltType eltType) { - if (shape.getM() == 8 && shape.getN() == 8) { - switch (num) { - case 1: - return (layout == NVVM::MMALayout::row) - ? 
llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16 - : llvm::Intrinsic:: - nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16; - case 2: - return (layout == NVVM::MMALayout::row) - ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16 - : llvm::Intrinsic:: - nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16; - case 4: - return (layout == NVVM::MMALayout::row) - ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16 - : llvm::Intrinsic:: - nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16; - } - } else if (shape.getM() == 16 && shape.getN() == 8) { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8; - case 2: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8; - case 4: - return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8; - } - } - llvm_unreachable("unknown stmatrix kind"); -} - -/// Return the intrinsic ID associated with st.bulk for the given address type. -static llvm::Intrinsic::ID -getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) { - bool isSharedMemory = - addrType.getAddressSpace() == NVVM::NVVMMemorySpace::kSharedMemorySpace; - return isSharedMemory ? llvm::Intrinsic::nvvm_st_bulk_shared_cta - : llvm::Intrinsic::nvvm_st_bulk; -} - -static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy, - NVVM::ProxyKind toProxy, - NVVM::MemScopeKind scope, - bool isRelease) { - if (fromProxy == NVVM::ProxyKind::GENERIC && - toProxy == NVVM::ProxyKind::TENSORMAP) { - switch (scope) { - case NVVM::MemScopeKind::CTA: { - if (isRelease) - return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta; - return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta; - } - case NVVM::MemScopeKind::CLUSTER: { - if (isRelease) - return llvm::Intrinsic:: - nvvm_fence_proxy_tensormap_generic_release_cluster; - return llvm::Intrinsic:: - nvvm_fence_proxy_tensormap_generic_acquire_cluster; - } - case NVVM::MemScopeKind::GPU: { - if (isRelease) - return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu; - return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu; - } - case NVVM::MemScopeKind::SYS: { - if (isRelease) - return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys; - return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys; - } - } - llvm_unreachable("Unknown scope for uni-directional fence.proxy operation"); - } - llvm_unreachable("Unsupported proxy kinds"); -} - -#define TCGEN05LD(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_ld_##SHAPE##_##NUM - -static llvm::Intrinsic::ID -getTcgen05LdIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) { - llvm::Intrinsic::ID Shape16x64b[] = { - TCGEN05LD(16x64b, x1), TCGEN05LD(16x64b, x2), TCGEN05LD(16x64b, x4), - TCGEN05LD(16x64b, x8), TCGEN05LD(16x64b, x16), TCGEN05LD(16x64b, x32), - TCGEN05LD(16x64b, x64), TCGEN05LD(16x64b, x128), - }; - - llvm::Intrinsic::ID Shape16x128b[] = { - TCGEN05LD(16x128b, x1), TCGEN05LD(16x128b, x2), TCGEN05LD(16x128b, x4), - TCGEN05LD(16x128b, x8), TCGEN05LD(16x128b, x16), TCGEN05LD(16x128b, x32), - TCGEN05LD(16x128b, x64), - }; - - llvm::Intrinsic::ID Shape16x256b[] = { - TCGEN05LD(16x256b, x1), TCGEN05LD(16x256b, x2), TCGEN05LD(16x256b, x4), - TCGEN05LD(16x256b, x8), TCGEN05LD(16x256b, x16), TCGEN05LD(16x256b, x32), - }; - - llvm::Intrinsic::ID Shape16x32bx2[] = { - TCGEN05LD(16x32bx2, x1), TCGEN05LD(16x32bx2, x2), - TCGEN05LD(16x32bx2, x4), TCGEN05LD(16x32bx2, x8), - TCGEN05LD(16x32bx2, x16), TCGEN05LD(16x32bx2, x32), - TCGEN05LD(16x32bx2, x64), 
TCGEN05LD(16x32bx2, x128), - }; - - llvm::Intrinsic::ID Shape32x32b[] = { - TCGEN05LD(32x32b, x1), TCGEN05LD(32x32b, x2), TCGEN05LD(32x32b, x4), - TCGEN05LD(32x32b, x8), TCGEN05LD(32x32b, x16), TCGEN05LD(32x32b, x32), - TCGEN05LD(32x32b, x64), TCGEN05LD(32x32b, x128), - }; - - // `num` contains the length of vector and log2 of `num` returns the index - // into the shape array - unsigned Idx = std::log2(num); - - switch (shape) { - case NVVM::Tcgen05LdStShape::SHAPE_16X64B: - return Shape16x64b[Idx]; - case NVVM::Tcgen05LdStShape::SHAPE_16X128B: - return Shape16x128b[Idx - 1]; - case NVVM::Tcgen05LdStShape::SHAPE_16X256B: - return Shape16x256b[Idx - 2]; - case NVVM::Tcgen05LdStShape::SHAPE_32X32B: - return Shape32x32b[Idx]; - case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2: - return Shape16x32bx2[Idx]; - } - llvm_unreachable("unhandled tcgen05.ld lowering"); -} - -#define TCGEN05ST(SHAPE, NUM) llvm::Intrinsic::nvvm_tcgen05_st_##SHAPE##_##NUM - -static llvm::Intrinsic::ID -getTcgen05StIntrinsicID(mlir::NVVM::Tcgen05LdStShape shape, uint32_t num) { - llvm::Intrinsic::ID Shape16x64b[] = { - TCGEN05ST(16x64b, x1), TCGEN05ST(16x64b, x2), TCGEN05ST(16x64b, x4), - TCGEN05ST(16x64b, x8), TCGEN05ST(16x64b, x16), TCGEN05ST(16x64b, x32), - TCGEN05ST(16x64b, x64), TCGEN05ST(16x64b, x128), - }; - - llvm::Intrinsic::ID Shape16x128b[] = { - TCGEN05ST(16x128b, x1), TCGEN05ST(16x128b, x2), TCGEN05ST(16x128b, x4), - TCGEN05ST(16x128b, x8), TCGEN05ST(16x128b, x16), TCGEN05ST(16x128b, x32), - TCGEN05ST(16x128b, x64), - }; - - llvm::Intrinsic::ID Shape16x256b[] = { - TCGEN05ST(16x256b, x1), TCGEN05ST(16x256b, x2), TCGEN05ST(16x256b, x4), - TCGEN05ST(16x256b, x8), TCGEN05ST(16x256b, x16), TCGEN05ST(16x256b, x32), - }; - - llvm::Intrinsic::ID Shape16x32bx2[] = { - TCGEN05ST(16x32bx2, x1), TCGEN05ST(16x32bx2, x2), - TCGEN05ST(16x32bx2, x4), TCGEN05ST(16x32bx2, x8), - TCGEN05ST(16x32bx2, x16), TCGEN05ST(16x32bx2, x32), - TCGEN05ST(16x32bx2, x64), TCGEN05ST(16x32bx2, x128), - }; - - llvm::Intrinsic::ID Shape32x32b[] = { - TCGEN05ST(32x32b, x1), TCGEN05ST(32x32b, x2), TCGEN05ST(32x32b, x4), - TCGEN05ST(32x32b, x8), TCGEN05ST(32x32b, x16), TCGEN05ST(32x32b, x32), - TCGEN05ST(32x32b, x64), TCGEN05ST(32x32b, x128), - }; - - // `num` contains the length of vector and log2 of `num` returns the index - // into the shape array - unsigned Idx = std::log2(num); - - switch (shape) { - case NVVM::Tcgen05LdStShape::SHAPE_16X64B: - return Shape16x64b[Idx]; - case NVVM::Tcgen05LdStShape::SHAPE_16X128B: - return Shape16x128b[Idx - 1]; - case NVVM::Tcgen05LdStShape::SHAPE_16X256B: - return Shape16x256b[Idx - 2]; - case NVVM::Tcgen05LdStShape::SHAPE_32X32B: - return Shape32x32b[Idx]; - case NVVM::Tcgen05LdStShape::SHAPE_16X32BX2: - return Shape16x32bx2[Idx]; - } - llvm_unreachable("unhandled tcgen05.st lowering"); -} - namespace { /// Implementation of the dialect interface that converts operations belonging /// to the NVVM dialect to LLVM IR. 
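Reviewer note on the tcgen05.ld/st table indexing used both in the new NVVMDialect.cpp implementations and in the helpers removed above: the xN variant for an N-element vector sits at log2(N) in its per-shape array, with the 16x128b and 16x256b tables shifted by one and two slots because their smallest variants are x2 and x4. A minimal, self-contained sketch of that index computation follows; the helper name tcgen05Index is an assumption for illustration and is not part of this patch.

#include <cassert>
#include <cmath>
#include <cstdint>

// Illustrative only: mirrors the Idx / Idx - 1 / Idx - 2 indexing applied to
// the Shape16x64b / Shape16x128b / Shape16x256b tables in this patch.
static unsigned tcgen05Index(uint32_t num, uint32_t smallestNum) {
  unsigned idx = static_cast<unsigned>(std::log2(num));
  return idx - static_cast<unsigned>(std::log2(smallestNum));
}

int main() {
  assert(tcgen05Index(8, 1) == 3); // 16x64b/32x32b/16x32bx2: x8 -> table[3]
  assert(tcgen05Index(8, 2) == 2); // 16x128b: x8 -> table[2] (Idx - 1)
  assert(tcgen05Index(8, 4) == 1); // 16x256b: x8 -> table[1] (Idx - 2)
  return 0;
}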
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_store_reduce.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_store_reduce.mlir index 6e0b48489e8b0..71580ebed672b 100644 --- a/mlir/test/Target/LLVMIR/nvvm/tma_store_reduce.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/tma_store_reduce.mlir @@ -19,14 +19,14 @@ llvm.func @tma_store_reduce_1d(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, %d0 : nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 undef, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.tile.1d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i64 poison, i1 false) nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0] {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3> @@ -59,14 +59,14 @@ llvm.func @tma_store_reduce_2d(%src : !llvm.ptr<3>, %tma_desc : 
!llvm.ptr, %d0 : nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 undef, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.tile.2d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i64 poison, i1 false) nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> @@ -99,14 +99,14 @@ llvm.func @tma_store_reduce_3d_tile(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, 
box[%d0, %d1, %d2] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.tile.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.tile.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.tile.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.tile.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.tile.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.tile.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.tile.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 undef, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.tile.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.tile.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.tile.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.tile.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.tile.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.tile.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.tile.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 poison, i1 false) nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2] {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2] {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2] {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> @@ -137,14 +137,14 @@ llvm.func @tma_store_reduce_3d_im2col(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3> 
nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3> - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.im2col.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.im2col.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.im2col.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.im2col.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.im2col.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.im2col.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.im2col.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.im2col.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 undef, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.im2col.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.im2col.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.im2col.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.im2col.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.im2col.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.im2col.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.im2col.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.im2col.3d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i64 poison, i1 false) nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2] {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2] {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2] {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3> @@ -177,14 +177,14 @@ llvm.func @tma_store_reduce_4d_tile(%src : !llvm.ptr<3>, 
%tma_desc : !llvm.ptr, nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.tile.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.tile.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.tile.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.tile.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.tile.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.tile.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.tile.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 undef, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.tile.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.tile.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.tile.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.tile.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.tile.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.tile.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.tile.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 poison, i1 false) nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3] {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3] {redKind = 
#nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3] {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> @@ -215,14 +215,14 @@ llvm.func @tma_store_reduce_4d_im2col(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3> - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.im2col.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.im2col.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.im2col.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.im2col.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.im2col.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.im2col.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.im2col.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.im2col.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 undef, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.im2col.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.im2col.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.im2col.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.im2col.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.im2col.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.im2col.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.im2col.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 poison, i1 false) + // CHECK: call void 
@llvm.nvvm.cp.async.bulk.tensor.reduce.xor.im2col.4d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i64 poison, i1 false) nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3] {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3] {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3] {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3> @@ -255,14 +255,14 @@ llvm.func @tma_store_reduce_5d_tile(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3, %d4] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3, %d4] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.tile.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.tile.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.tile.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.tile.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.tile.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.tile.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.tile.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 undef, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.tile.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.tile.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.tile.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.tile.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.tile.5d(ptr 
addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.tile.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.tile.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.tile.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 poison, i1 false) nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3, %d4] {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3, %d4] {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3, %d4] {redKind = #nvvm.tma_redux_kind} : !llvm.ptr, !llvm.ptr<3> @@ -293,14 +293,14 @@ llvm.func @tma_store_reduce_5d_im2col(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3, %d4] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3, %d4] l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3> - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.im2col.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.im2col.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.im2col.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.im2col.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.im2col.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.im2col.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.im2col.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 undef, i1 false) - // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.im2col.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 undef, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.im2col.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.min.im2col.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 
%[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.max.im2col.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.inc.im2col.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.dec.im2col.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.and.im2col.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.or.im2col.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 poison, i1 false) + // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.xor.im2col.5d(ptr addrspace(3) %[[SRC]], ptr %[[DST]], i32 %[[D0]], i32 %[[D1]], i32 %[[D2]], i32 %[[D3]], i32 %[[D4]], i64 poison, i1 false) nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3, %d4] {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3, %d4] {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3> nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1, %d2, %d3, %d4] {redKind = #nvvm.tma_redux_kind, mode = #nvvm.tma_store_mode} : !llvm.ptr, !llvm.ptr<3>
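The test updates above only swap undef for poison in the expected i64 l2 cache-hint argument: when no l2_cache_hint operand is supplied, the lowering now materializes a poison value for that trailing argument. A hedged sketch of that pattern is shown below; the helper name cacheHintOrPoison is illustrative and is not the patch's code.

#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"

// Illustrative only: pass poison (rather than undef) for the i64 cache-hint
// argument when the optional hint value is absent, matching the updated
// CHECK lines above.
static llvm::Value *cacheHintOrPoison(llvm::Value *hint,
                                      llvm::IRBuilderBase &builder) {
  return hint ? hint : llvm::PoisonValue::get(builder.getInt64Ty());
}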