
Commit 269f556

[MLIR][NVVM][v3] Update TMA Load Op
This patch adds im2col and gather mode support to the TMA Load Op. The lowering is also updated to use intrinsics, except when a predicate is given. This completes the Blackwell additions on this Op. Lit tests are added for all combinations.

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
1 parent: 76a11c7

File tree: 12 files changed (+1025, -119 lines)

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 28 additions & 12 deletions
@@ -2827,26 +2827,21 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
   NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
     [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
      AttrSizedOperandSegments, NVVMRequiresSM<90>]>,
-  Arguments<(ins LLVM_PointerShared:$dstMem,
-                 LLVM_AnyPointer:$tmaDescriptor,
+  Arguments<(ins AnyTypeOf<[LLVM_PointerShared, LLVM_PointerSharedCluster]>:$dstMem,
+                 LLVM_PointerGeneric:$tmaDescriptor,
                  Variadic<I32>:$coordinates,
                  LLVM_PointerShared:$mbar,
                  Variadic<I16>:$im2colOffsets,
                  Optional<I16>:$multicastMask,
                  Optional<I64>:$l2CacheHint,
+                 DefaultValuedAttr<TMALoadModeAttr, "TMALoadMode::TILE">:$mode,
+                 DefaultValuedAttr<BoolAttr, "false">:$isCTAOnly,
+                 OptionalAttr<CTAGroupKindAttr>:$group,
                  PtxPredicate:$predicate)> {
   let description = [{
     Initiates an asynchronous copy operation on the tensor data from global
-    memory to shared memory.
-
-    The Op operates has two load modes:
-    1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
-       layout is preserved at the destination.
-
-    2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
-       the elements in the Bounding Box of the source tensor are rearranged into
-       columns at the destination. In this mode, the tensor has to be at least
-       3-dimensional.
+    memory to shared::cluster (or) shared::cta memory. This Op supports all
+    the load modes specified in `TMALoadMode`.

     The `multicastMask` operand is optional. When it is present, the Op copies
     data from global memory to shared memory of multiple CTAs in the cluster.
@@ -2857,6 +2852,10 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
     The `l2CacheHint` operand is optional, and it is used to specify cache
     eviction policy that may be used during the memory access.

+    When the `isCTAOnly` attribute is set to true, the destination is
+    shared::cta only. Hence, `multicastMask` and `CTAGroup` are not applicable
+    when `isCTAOnly` is true.
+
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor)
   }];

@@ -2904,6 +2903,23 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
     }
   }];
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool hasIntrinsic() { return !getPredicate(); }
+
+    bool getAsmValues(RewriterBase &rewriter,
+        llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues);
+
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase& builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+        *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, id, args);
+  }];
 }

 def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
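For orientation, here is a rough sketch of how the updated Op could look at the MLIR level once the new `mode`, `isCTAOnly`, and `group` attributes come into play. The operand order mirrors the CHECK line in the NVGPU test further below (destination, descriptor, barrier, box of coordinates); the attribute spellings (`#nvvm.tma_load_mode<...>`, `#nvvm.cta_group<...>`), the type list, and the exact assembly syntax are assumptions for illustration, not the authoritative format.

    // Default TILE-mode load into shared::cluster (AS 7) -- assumed syntax.
    nvvm.cp.async.bulk.tensor.shared.cluster.global %dst_cluster, %tma_desc, %mbar,
        box[%crd0, %crd1] : !llvm.ptr<7>, !llvm.ptr

    // Hypothetical im2col-mode load with a CTA-group hint; SSA names and
    // attribute mnemonics are illustrative only.
    nvvm.cp.async.bulk.tensor.shared.cluster.global %dst_cluster, %tma_desc, %mbar,
        box[%crd0, %crd1, %crd2] im2col[%off0]
        {mode = #nvvm.tma_load_mode<im2col>, group = #nvvm.cta_group<cta_1>}
        : !llvm.ptr<7>, !llvm.ptr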

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 13 additions & 0 deletions
@@ -993,6 +993,14 @@ struct NVGPUTmaAsyncLoadOpLowering
     auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
     Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
                                       adaptor.getDst(), {});
+    // Intrinsics takes a shared-cluster pointer so we need an
+    // address space cast from 3 to 7.
+    // TODO: Introduce AS(7) in NVGPU.
+    auto ptrSharedClusterType = LLVM::LLVMPointerType::get(
+        op->getContext(),
+        static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster));
+    dest = LLVM::AddrSpaceCastOp::create(b, ptrSharedClusterType, dest);
+
     Value barrier =
         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
@@ -1001,9 +1009,14 @@ struct NVGPUTmaAsyncLoadOpLowering
     for (auto [index, value] : llvm::enumerate(coords)) {
       coords[index] = truncToI32(b, value);
     }
+
+    // TODO: Enhance the NVGPU Op for other modes too
     rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
         op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
         ValueRange{}, adaptor.getMulticastMask(), Value{},
+        NVVM::TMALoadMode::TILE, // default is TILE mode
+        false, // default is cluster-scope
+        nullptr, // default is no cta-group
        adaptor.getPredicate());
     return success();
   }
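The net effect of this change on the lowered IR: the destination pointer computed in shared memory (AS 3) is now cast to shared::cluster (AS 7) before it feeds the NVVM Op, matching the updated CHECK lines in the test below. A minimal sketch, with illustrative SSA names:

    // Destination is computed as a shared (AS 3) pointer by the memref lowering,
    // then cast so the intrinsic-based path sees a shared::cluster (AS 7) pointer.
    %dest = llvm.addrspacecast %shmem_ptr : !llvm.ptr<3> to !llvm.ptr<7>
    nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tma_desc, %mbar,
        box[%crd0, %crd1]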

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 159 additions & 8 deletions
@@ -45,12 +45,14 @@ using namespace NVVM;
 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

+static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
+
 //===----------------------------------------------------------------------===//
 // Verifier methods
 //===----------------------------------------------------------------------===//

 // This verifier is shared among the following Ops:
-// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
+// CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
 // CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
 static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                      bool isIm2Col,
@@ -74,13 +76,6 @@ static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
   return success();
 }

-LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
-  size_t numIm2ColOffsets = getIm2colOffsets().size();
-  bool isIm2Col = numIm2ColOffsets > 0;
-  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
-                                         numIm2ColOffsets, getLoc());
-}
-
 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
   TMAStoreMode mode = getMode();
   // We lower through inline-ptx when getPredicate() is true.
@@ -158,6 +153,38 @@ LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
                               getMode(), getLoc());
 }

+LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
+  TMALoadMode mode = getMode();
+  bool isCTAOnly = getIsCTAOnly();
+  if (getPredicate()) { // Inline-asm based lowering
+    if (isCTAOnly)
+      return emitError("Predicate is supported only for shared::cluster mode.");
+    if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
+      return emitError(
+          "Predicate is supported only for Tile and Im2col modes.");
+  } else { // Intrinsics-based lowering
+    NVVMMemorySpace expectedAS =
+        isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
+    unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
+                      .getAddressSpace();
+    if (AS != expectedAS)
+      return emitError()
+             << (isCTAOnly
+                     ? "Shared::cta destination requires address-space 3."
+                     : "Shared::cluster destination requires address-space 7.");
+    // Checks specific to shared::cta mode
+    if (isCTAOnly) {
+      if (getMulticastMask())
+        return emitError("Multicast is not supported with shared::cta mode.");
+      if (getGroup())
+        return emitError("CTAGroup is not supported with shared::cta mode.");
+    }
+  }
+
+  return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
+                             getMode(), getLoc());
+}
+
 LogicalResult CpAsyncBulkTensorReduceOp::verify() {
   TMAStoreMode mode = getMode();
   size_t dims = getCoordinates().size();
@@ -1553,6 +1580,130 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
   return {id, std::move(args)};
 }

+bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
+    RewriterBase &rewriter,
+    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+        &asmValues) {
+  // Add all the operands but not the attrs to the asmValues list.
+  // The attrs here are used to generate the right variants for
+  // intrinsics-lowering. So, we ignore them while generating inline-PTX.
+  for (auto val : getOperands())
+    asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
+
+  return false;
+}
+
+mlir::NVVM::IDArgPair
+CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
+  const bool isCTAOnly = thisOp.getIsCTAOnly();
+  llvm::SmallVector<llvm::Value *> args;
+
+  // Fill the Intrinsic Args
+  args.push_back(mt.lookupValue(thisOp.getDstMem()));
+  args.push_back(mt.lookupValue(thisOp.getMbar()));
+  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+  // Coordinates and im2col-offsets
+  for (mlir::Value v : thisOp.getCoordinates())
+    args.push_back(mt.lookupValue(v));
+  for (mlir::Value v : thisOp.getIm2colOffsets())
+    args.push_back(mt.lookupValue(v));
+
+  // MulticastMask, if available
+  mlir::Value mcMask = thisOp.getMulticastMask();
+  const bool hasMC = static_cast<bool>(mcMask);
+  llvm::Value *i16Zero =
+      llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
+
+  // CacheHint, if available
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64Zero =
+      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+
+  // Flag argument CTAGroup
+  // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
+  // Hence, the +1 to getGroup().
+  const int32_t val =
+      thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
+  llvm::Value *cg =
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
+
+  if (!isCTAOnly) {
+    // For shared::cluster, all the arguments that we build are applicable.
+    args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
+    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
+    args.push_back(builder.getInt1(hasMC));
+    args.push_back(builder.getInt1(hasCacheHint));
+    args.push_back(cg);
+  } else {
+    // For shared::cta, only cache-hint is applicable.
+    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
+    args.push_back(builder.getInt1(hasCacheHint));
+  }
+
+  constexpr size_t numDims = 5;  // 1D to 5D
+  constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4
+  using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
+  using TableTy = std::array<rowTy, numModes>;
+  static constexpr TableTy IDTable{
+      {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
+
+  static constexpr TableTy IDTableCTA{
+      {{notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
+
+  static_assert(
+      (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
+          (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
+      "TMALoadModes must match number of rows in IDTable and IDTableCTA");
+  size_t mode = static_cast<size_t>(thisOp.getMode());
+  size_t dim = thisOp.getCoordinates().size();
+  auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
+  assert(id != notIntrinsic &&
+         "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
+
+  return {id, std::move(args)};
+}
+
 mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
     Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
   auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
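To summarize the new verifier contract for the intrinsics path: the destination address space must match `isCTAOnly` (shared::cluster, AS 7, by default; shared::cta, AS 3, when `isCTAOnly` is set), and `multicastMask`/`group` are rejected in the CTA-only case. A hedged sketch of the two accepted shapes, with assumed attribute spelling and illustrative names:

    // Cluster-scope load (default): destination must be !llvm.ptr<7>;
    // multicastMask, l2CacheHint and group are all allowed here.
    nvvm.cp.async.bulk.tensor.shared.cluster.global %dst_cluster7, %tma_desc, %mbar,
        box[%crd0, %crd1]

    // CTA-only load: destination must be !llvm.ptr<3>; adding multicastMask or
    // group here would be rejected by the verifier above.
    nvvm.cp.async.bulk.tensor.shared.cluster.global %dst_cta3, %tma_desc, %mbar,
        box[%crd0, %crd1] {isCTAOnly = true}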

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 2 additions & 1 deletion
@@ -854,7 +854,8 @@ module @mymodule {
     // CHECK: %[[desc:.+]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
     // CHECK: %[[c8192:.+]] = llvm.mlir.constant(8192 : index) : i64
     // CHECK: %[[shmemOfset:.+]] = llvm.getelementptr %[[desc]][%[[c8192]]] : (!llvm.ptr<3>, i64)
-    // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %[[shmemOfset]], %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}]
+    // CHECK: %[[dest:.+]] = llvm.addrspacecast %[[shmemOfset]] : !llvm.ptr<3> to !llvm.ptr<7>
+    // CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %[[dest]], %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}]
     nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier[%c0] to %rhsShmem : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>
     return
   }
