@@ -45,12 +45,14 @@ using namespace NVVM;
 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
 
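+// Sentinel for entries in the intrinsic-ID tables below that have no
+// corresponding intrinsic.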
+static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
+
 //===----------------------------------------------------------------------===//
 // Verifier methods
 //===----------------------------------------------------------------------===//
 
 // This verifier is shared among the following Ops:
-// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
+// CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
 // CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
 static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                      bool isIm2Col,
@@ -74,13 +76,6 @@ static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
   return success();
 }
 
-LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
-  size_t numIm2ColOffsets = getIm2colOffsets().size();
-  bool isIm2Col = numIm2ColOffsets > 0;
-  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
-                                         numIm2ColOffsets, getLoc());
-}
-
 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
   TMAStoreMode mode = getMode();
   // We lower through inline-ptx when getPredicate() is true.
@@ -158,6 +153,38 @@ LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
                              getMode(), getLoc());
 }
 
+LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
+  TMALoadMode mode = getMode();
+  bool isCTAOnly = getIsCTAOnly();
+  if (getPredicate()) { // Inline-asm based lowering
+    if (isCTAOnly)
+      return emitError("Predicate is supported only for shared::cluster mode.");
+    if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
+      return emitError(
+          "Predicate is supported only for Tile and Im2col modes.");
+  } else { // Intrinsics-based lowering
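+    // The destination address space must match the requested mode:
+    // shared::cta is addrspace(3), shared::cluster is addrspace(7).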
+    NVVMMemorySpace expectedAS =
+        isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
+    unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
+                      .getAddressSpace();
+    if (AS != expectedAS)
+      return emitError()
+             << (isCTAOnly
+                     ? "Shared::cta destination requires address-space 3."
+                     : "Shared::cluster destination requires address-space 7.");
+    // Checks specific to shared::cta mode
+    if (isCTAOnly) {
+      if (getMulticastMask())
+        return emitError("Multicast is not supported with shared::cta mode.");
+      if (getGroup())
+        return emitError("CTAGroup is not supported with shared::cta mode.");
+    }
+  }
+
+  return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
+                             getMode(), getLoc());
+}
+
 LogicalResult CpAsyncBulkTensorReduceOp::verify() {
   TMAStoreMode mode = getMode();
   size_t dims = getCoordinates().size();
@@ -1553,6 +1580,130 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
   return {id, std::move(args)};
 }
 
+bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
+    RewriterBase &rewriter,
+    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+        &asmValues) {
+  // Add all the operands but not the attrs to the asmValues list.
+  // The attrs here are used to generate the right variants for
+  // intrinsics-lowering. So, we ignore them while generating inline-PTX.
+  for (auto val : getOperands())
+    asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
+
+  return false;
+}
+
+mlir::NVVM::IDArgPair
+CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
+  const bool isCTAOnly = thisOp.getIsCTAOnly();
+  llvm::SmallVector<llvm::Value *> args;
+
+  // Fill the Intrinsic Args
+  args.push_back(mt.lookupValue(thisOp.getDstMem()));
+  args.push_back(mt.lookupValue(thisOp.getMbar()));
+  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+  // Coordinates and im2col-offsets
+  for (mlir::Value v : thisOp.getCoordinates())
+    args.push_back(mt.lookupValue(v));
+  for (mlir::Value v : thisOp.getIm2colOffsets())
+    args.push_back(mt.lookupValue(v));
+
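+  // Optional operands: when absent, a zero placeholder is passed instead and
+  // an i1 flag (pushed below) records whether the real value is present.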
+  // MulticastMask, if available
+  mlir::Value mcMask = thisOp.getMulticastMask();
+  const bool hasMC = static_cast<bool>(mcMask);
+  llvm::Value *i16Zero =
+      llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
+
+  // CacheHint, if available
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64Zero =
+      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+
+  // Flag argument CTAGroup
+  // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
+  // Hence, the +1 to getGroup().
+  const int32_t val =
+      thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
+  llvm::Value *cg =
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
+
+  if (!isCTAOnly) {
+    // For shared::cluster, all the arguments that we build are applicable.
+    args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
+    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
+    args.push_back(builder.getInt1(hasMC));
+    args.push_back(builder.getInt1(hasCacheHint));
+    args.push_back(cg);
+  } else {
+    // For shared::cta, only cache-hint is applicable.
+    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
+    args.push_back(builder.getInt1(hasCacheHint));
+  }
+
+  constexpr size_t numDims = 5;  // 1D to 5D
+  constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4
+  using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
+  using TableTy = std::array<rowTy, numModes>;
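+  // Rows are indexed by TMALoadMode, columns by tensor rank (index 0 unused,
+  // then 1D-5D); notIntrinsic marks unsupported mode/rank combinations.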
+  static constexpr TableTy IDTable{
+      {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
+
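+  // Same layout as IDTable, holding the shared::cta variants selected when
+  // isCTAOnly is true.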
+  static constexpr TableTy IDTableCTA{
+      {{notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
+
+  static_assert(
+      (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
+          (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
+      "TMALoadModes must match number of rows in IDTable and IDTableCTA");
+  size_t mode = static_cast<size_t>(thisOp.getMode());
+  size_t dim = thisOp.getCoordinates().size();
+  auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
+  assert(id != notIntrinsic &&
+         "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
+
+  return {id, std::move(args)};
+}
+
 mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
     Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
   auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);