@@ -50,7 +50,7 @@ using namespace NVVM;
//===----------------------------------------------------------------------===//

// This verifier is shared among the following Ops:
- // CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
+ // CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                     bool isIm2Col,
@@ -74,13 +74,6 @@ static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
  return success();
}

- LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
-   size_t numIm2ColOffsets = getIm2colOffsets().size();
-   bool isIm2Col = numIm2ColOffsets > 0;
-   return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
-                                          numIm2ColOffsets, getLoc());
- }
-
LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
  TMAStoreMode mode = getMode();
  // We lower through inline-ptx when getPredicate() is true.
@@ -158,6 +151,38 @@ LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
                             getMode(), getLoc());
}

+ LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
+   TMALoadMode mode = getMode();
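+   // getIsCTAOnly() distinguishes a shared::cta destination (CTA-local
+   // shared memory) from a shared::cluster destination; the checks below
+   // differ between the two forms.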
+   bool isCTAOnly = getIsCTAOnly();
+   if (getPredicate()) { // Inline-asm based lowering
+     if (isCTAOnly)
+       return emitError("Predicate is supported only for shared::cluster mode.");
+     if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
+       return emitError(
+           "Predicate is supported only for Tile and Im2col modes.");
+   } else { // Intrinsics-based lowering
+     NVVMMemorySpace expectedAS =
+         isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
+     unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
+                       .getAddressSpace();
+     if (AS != expectedAS)
+       return emitError()
+              << (isCTAOnly
+                      ? "Shared::cta destination requires address-space 3."
+                      : "Shared::cluster destination requires address-space 7.");
+     // Checks specific to shared::cta mode
+     if (isCTAOnly) {
+       if (getMulticastMask())
+         return emitError("Multicast is not supported with shared::cta mode.");
+       if (getGroup())
+         return emitError("CTAGroup is not supported with shared::cta mode.");
+     }
+   }
+
+   return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
+                              getMode(), getLoc());
+ }
+

LogicalResult CpAsyncBulkTensorReduceOp::verify() {
  TMAStoreMode mode = getMode();
  size_t dims = getCoordinates().size();
@@ -1553,6 +1578,131 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
  return {id, std::move(args)};
}

+ bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
+     RewriterBase &rewriter,
+     llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+         &asmValues) {
+   // Add all the operands but not the attrs to the asmValues list.
+   // The attrs here are used to generate the right variants for
+   // intrinsics-lowering. So, we ignore them while generating inline-PTX.
+   for (auto val : getOperands())
+     asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
+
+   return false;
+ }
+
+ mlir::NVVM::IDArgPair
+ CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+     Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+   auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
+   const bool isCTAOnly = thisOp.getIsCTAOnly();
+   llvm::SmallVector<llvm::Value *> args;
+
+   // Fill the Intrinsic Args
+   args.push_back(mt.lookupValue(thisOp.getDstMem()));
+   args.push_back(mt.lookupValue(thisOp.getMbar()));
+   args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+   // Coordinates and im2col-offsets
+   for (mlir::Value v : thisOp.getCoordinates())
+     args.push_back(mt.lookupValue(v));
+   for (mlir::Value v : thisOp.getIm2colOffsets())
+     args.push_back(mt.lookupValue(v));
+
+   // MulticastMask, if available
+   mlir::Value mcMask = thisOp.getMulticastMask();
+   const bool hasMC = static_cast<bool>(mcMask);
+   llvm::Value *i16Zero =
+       llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
+
+   // CacheHint, if available
+   mlir::Value cacheHint = thisOp.getL2CacheHint();
+   const bool hasCacheHint = static_cast<bool>(cacheHint);
+   llvm::Value *i64Zero =
+       llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+
+   // Flag argument CTAGroup
+   // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
+   // Hence, the +1 to getGroup().
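+   // For example, assuming CTA_1 and CTA_2 carry the enum values 0 and 1, a
+   // CTAGroup::CTA_1 attribute is encoded as 1, CTA_2 as 2, and an absent
+   // attribute as 0 (no cta_group).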
+   const int32_t val =
+       thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
+   llvm::Value *cg =
+       llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
+
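+   // When the optional multicast-mask or cache-hint operand is absent, a zero
+   // constant is passed in its place and the matching i1 flag is set to false
+   // to mark the value as unused.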
+   if (!isCTAOnly) {
+     // For shared::cluster, all the arguments that we build are applicable.
+     args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
+     args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
+     args.push_back(builder.getInt1(hasMC));
+     args.push_back(builder.getInt1(hasCacheHint));
+     args.push_back(cg);
+   } else {
+     // For shared::cta, only cache-hint is applicable.
+     args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
+     args.push_back(builder.getInt1(hasCacheHint));
+   }
+
+   constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
+   constexpr size_t numDims = 5;  // 1D to 5D
+   constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4
+   using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
+   using TableTy = std::array<rowTy, numModes>;
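+   // Each row corresponds to a TMALoadMode (Tile, Im2col, w, w_128, gather4)
+   // and each column to the tensor dimensionality, so [mode][dim] indexes the
+   // intrinsic directly; column 0 stays notIntrinsic since dims start at 1.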
+   static constexpr TableTy IDTable{
+       {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
+        {notIntrinsic, notIntrinsic, notIntrinsic,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
+        {notIntrinsic, notIntrinsic, notIntrinsic,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
+        {notIntrinsic, notIntrinsic, notIntrinsic,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
+        {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
+
+   static constexpr TableTy IDTableCTA{
+       {{notIntrinsic,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
+        {notIntrinsic, notIntrinsic, notIntrinsic,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
+        {notIntrinsic, notIntrinsic, notIntrinsic,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
+        {notIntrinsic, notIntrinsic, notIntrinsic,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
+        {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
+         llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
+
+   static_assert(
+       (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) ||
+           (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
+       "TMALoadModes must match number of rows in IDTable");
+   size_t mode = static_cast<size_t>(thisOp.getMode());
+   size_t dim = thisOp.getCoordinates().size();
+   auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
+   assert(id != notIntrinsic &&
+          "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
+
+   return {id, std::move(args)};
+ }
+

mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);