40 changes: 28 additions & 12 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2827,26 +2827,21 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
AttrSizedOperandSegments, NVVMRequiresSM<90>]>,
Arguments<(ins LLVM_PointerShared:$dstMem,
LLVM_AnyPointer:$tmaDescriptor,
Arguments<(ins AnyTypeOf<[LLVM_PointerShared, LLVM_PointerSharedCluster]>:$dstMem,
LLVM_PointerGeneric:$tmaDescriptor,
Variadic<I32>:$coordinates,
LLVM_PointerShared:$mbar,
Variadic<I16>:$im2colOffsets,
Optional<I16>:$multicastMask,
Optional<I64>:$l2CacheHint,
DefaultValuedAttr<TMALoadModeAttr, "TMALoadMode::TILE">:$mode,
DefaultValuedAttr<BoolAttr, "false">:$isCTAOnly,
OptionalAttr<CTAGroupKindAttr>:$group,
PtxPredicate:$predicate)> {
let description = [{
Initiates an asynchronous copy operation on the tensor data from global
memory to shared memory.

The Op operates has two load modes:
1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
layout is preserved at the destination.

2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
the elements in the Bounding Box of the source tensor are rearranged into
columns at the destination. In this mode, the tensor has to be at least
3-dimensional.
memory to shared::cluster (or) shared::cta memory. This Op supports all
the load modes specified in `TMALoadMode`.

The `multicastMask` operand is optional. When it is present, the Op copies
data from global memory to shared memory of multiple CTAs in the cluster.
@@ -2857,6 +2852,10 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
The `l2CacheHint` operand is optional, and it is used to specify cache
eviction policy that may be used during the memory access.

When the `isCTAOnly` attribute is set to true, the destination is
shared::cta only. Hence, `multicastMask` and `CTAGroup` are not applicable
when `isCTAOnly` is true.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor)
}];

@@ -2904,6 +2903,23 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
}
}];
let hasVerifier = 1;

let extraClassDeclaration = [{
bool hasIntrinsic() { return !getPredicate(); }

bool getAsmValues(RewriterBase &rewriter,
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues);

static mlir::NVVM::IDArgPair
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase& builder);
}];

string llvmBuilder = [{
auto [id, args] = NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
*op, moduleTranslation, builder);
createIntrinsicCall(builder, id, args);
}];
}
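
To make the revised argument list concrete, here is a minimal C++ construction sketch. It assumes the generated builder follows the order of the `Arguments` clause above (the same order used by the NVGPU lowering later in this patch); `b`, `loc`, `dest`, `desc`, `coords`, and `mbar` are placeholder values, not identifiers from the patch.

```cpp
// Hedged sketch: creating the op for a plain shared::cluster TILE load.
// Optional operands are passed as null Values; attributes keep their
// declared defaults (TILE mode, isCTAOnly = false, no CTA group).
auto load = NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp::create(
    b, loc,
    /*dstMem=*/dest,                 // !llvm.ptr<7> (shared::cluster)
    /*tmaDescriptor=*/desc,
    /*coordinates=*/coords,          // up to five i32 values
    /*mbar=*/mbar,                   // !llvm.ptr<3> mbarrier
    /*im2colOffsets=*/ValueRange{},
    /*multicastMask=*/Value{},
    /*l2CacheHint=*/Value{},
    /*mode=*/NVVM::TMALoadMode::TILE,
    /*isCTAOnly=*/false,
    /*group=*/nullptr,
    /*predicate=*/Value{});
```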

def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp :
13 changes: 13 additions & 0 deletions mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -993,6 +993,14 @@ struct NVGPUTmaAsyncLoadOpLowering
auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
Value dest = getStridedElementPtr(rewriter, op->getLoc(), srcMemrefType,
adaptor.getDst(), {});
// The intrinsic takes a shared-cluster pointer, so we need an
// address space cast from 3 to 7.
// TODO: Introduce AS(7) in NVGPU.
auto ptrSharedClusterType = LLVM::LLVMPointerType::get(
op->getContext(),
static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster));
dest = LLVM::AddrSpaceCastOp::create(b, ptrSharedClusterType, dest);

Value barrier =
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
@@ -1001,9 +1009,14 @@ struct NVGPUTmaAsyncLoadOpLowering
for (auto [index, value] : llvm::enumerate(coords)) {
coords[index] = truncToI32(b, value);
}

// TODO: Enhance the NVGPU Op for other modes too
rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
ValueRange{}, adaptor.getMulticastMask(), Value{},
NVVM::TMALoadMode::TILE, // default is TILE mode
false, // default is cluster-scope
nullptr, // default is no cta-group
adaptor.getPredicate());
return success();
}
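
The address-space cast above is the only NVGPU-side change needed to satisfy the new pointer type. As a standalone illustration of the same pattern, a hypothetical helper (not part of the patch) could look like the sketch below, assuming the numeric spaces shared::cta = 3 and shared::cluster = 7 noted in the comment.

```cpp
// Hypothetical helper mirroring the cast performed in the lowering above.
// The cast changes only the pointer's address space, not the data.
static Value castToSharedCluster(ImplicitLocOpBuilder &b, Value sharedCtaPtr) {
  auto clusterPtrTy = LLVM::LLVMPointerType::get(
      b.getContext(),
      static_cast<unsigned>(NVVM::NVVMMemorySpace::SharedCluster)); // AS 7
  return LLVM::AddrSpaceCastOp::create(b, clusterPtrTy, sharedCtaPtr);
}
```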
167 changes: 159 additions & 8 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -45,12 +45,14 @@ using namespace NVVM;
#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"

static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;

//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//

// This verifier is shared among the following Ops:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
// CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
bool isIm2Col,
@@ -74,13 +76,6 @@ static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
return success();
}

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
size_t numIm2ColOffsets = getIm2colOffsets().size();
bool isIm2Col = numIm2ColOffsets > 0;
return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
numIm2ColOffsets, getLoc());
}

LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
TMAStoreMode mode = getMode();
// We lower through inline-ptx when getPredicate() is true.
@@ -158,6 +153,38 @@ LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
getMode(), getLoc());
}

LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
TMALoadMode mode = getMode();
bool isCTAOnly = getIsCTAOnly();
if (getPredicate()) { // Inline-asm based lowering
if (isCTAOnly)
return emitError("Predicate is supported only for shared::cluster mode.");
if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
return emitError(
"Predicate is supported only for Tile and Im2col modes.");
} else { // Intrinsics-based lowering
NVVMMemorySpace expectedAS =
isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
.getAddressSpace();
if (AS != expectedAS)
return emitError()
<< (isCTAOnly
? "Shared::cta destination requires address-space 3."
: "Shared::cluster destination requires address-space 7.");
// Checks specific to shared::cta mode
if (isCTAOnly) {
if (getMulticastMask())
return emitError("Multicast is not supported with shared::cta mode.");
if (getGroup())
return emitError("CTAGroup is not supported with shared::cta mode.");
}
}

return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
getMode(), getLoc());
}
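
The rules enforced by this verifier for the intrinsics path can be summarized as a small decision table. The sketch below is an illustrative restatement, not code from the patch; the numeric address spaces follow NVVMMemorySpace (shared::cta = 3, shared::cluster = 7).

```cpp
// Illustrative restatement of the intrinsics-path constraints above.
struct G2SLoadConstraints {
  unsigned expectedDstAddrSpace; // 3 = shared::cta, 7 = shared::cluster
  bool multicastAllowed;
  bool ctaGroupAllowed;
};

constexpr G2SLoadConstraints constraintsFor(bool isCTAOnly) {
  return isCTAOnly
             ? G2SLoadConstraints{3, /*multicastAllowed=*/false,
                                  /*ctaGroupAllowed=*/false}
             : G2SLoadConstraints{7, /*multicastAllowed=*/true,
                                  /*ctaGroupAllowed=*/true};
}

static_assert(constraintsFor(/*isCTAOnly=*/true).expectedDstAddrSpace == 3);
static_assert(constraintsFor(/*isCTAOnly=*/false).multicastAllowed);
```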

LogicalResult CpAsyncBulkTensorReduceOp::verify() {
TMAStoreMode mode = getMode();
size_t dims = getCoordinates().size();
@@ -1553,6 +1580,130 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
return {id, std::move(args)};
}

bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
RewriterBase &rewriter,
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
&asmValues) {
// Add all the operands but not the attrs to the asmValues list.
// The attrs here are used to generate the right variants for
// intrinsics-lowering. So, we ignore them while generating inline-PTX.
for (auto val : getOperands())
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});

return false;
}

mlir::NVVM::IDArgPair
CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
const bool isCTAOnly = thisOp.getIsCTAOnly();
llvm::SmallVector<llvm::Value *> args;

// Fill the Intrinsic Args
args.push_back(mt.lookupValue(thisOp.getDstMem()));
args.push_back(mt.lookupValue(thisOp.getMbar()));
args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

// Coordinates and im2col-offsets
for (mlir::Value v : thisOp.getCoordinates())
args.push_back(mt.lookupValue(v));
for (mlir::Value v : thisOp.getIm2colOffsets())
args.push_back(mt.lookupValue(v));

// MulticastMask, if available
mlir::Value mcMask = thisOp.getMulticastMask();
const bool hasMC = static_cast<bool>(mcMask);
llvm::Value *i16Zero =
llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);

// CacheHint, if available
mlir::Value cacheHint = thisOp.getL2CacheHint();
const bool hasCacheHint = static_cast<bool>(cacheHint);
llvm::Value *i64Zero =
llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);

// Flag argument CTAGroup
// CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
// Hence, the +1 to getGroup().
const int32_t val =
thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
llvm::Value *cg =
llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);

if (!isCTAOnly) {
// For shared::cluster, all the arguments that we build are applicable.
args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
args.push_back(builder.getInt1(hasMC));
args.push_back(builder.getInt1(hasCacheHint));
args.push_back(cg);
} else {
// For shared::cta, only cache-hint is applicable.
args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
args.push_back(builder.getInt1(hasCacheHint));
}

constexpr size_t numDims = 5; // 1D to 5D
constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4
using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
using TableTy = std::array<rowTy, numModes>;
static constexpr TableTy IDTable{
{{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
{notIntrinsic, notIntrinsic, notIntrinsic,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
{notIntrinsic, notIntrinsic, notIntrinsic,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
{notIntrinsic, notIntrinsic, notIntrinsic,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
{notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};

static constexpr TableTy IDTableCTA{
{{notIntrinsic,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
{notIntrinsic, notIntrinsic, notIntrinsic,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
{notIntrinsic, notIntrinsic, notIntrinsic,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
{notIntrinsic, notIntrinsic, notIntrinsic,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
{notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};

static_assert(
(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
(getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
"TMALoadModes must match number of rows in IDTable and IDTableCTA");
size_t mode = static_cast<size_t>(thisOp.getMode());
size_t dim = thisOp.getCoordinates().size();
auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
assert(id != notIntrinsic &&
"Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");

return {id, std::move(args)};
}
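
The two tables above implement a simple [mode][dims] lookup, with `notIntrinsic` marking unsupported combinations (im2col variants need at least three coordinates, gather4 exactly five). The standalone sketch below illustrates the same pattern with hypothetical IDs; it is not taken from the patch.

```cpp
// Standalone illustration of the [mode][dims] intrinsic-ID lookup pattern.
// The numeric IDs are hypothetical; 0 plays the role of notIntrinsic.
#include <array>
#include <cstddef>

enum class Mode : std::size_t { Tile, Im2col, Im2colW, Im2colW128, Gather4 };
constexpr unsigned kNone = 0;

using Row = std::array<unsigned, 6>;   // index = number of coordinates (1-5)
constexpr std::array<Row, 5> kTable{{
    {kNone, 101, 102, 103, 104, 105},          // Tile: 1D-5D
    {kNone, kNone, kNone, 203, 204, 205},      // Im2col: 3D-5D only
    {kNone, kNone, kNone, 303, 304, 305},      // Im2col_w
    {kNone, kNone, kNone, 403, 404, 405},      // Im2col_w_128
    {kNone, kNone, kNone, kNone, kNone, 502},  // Gather4: 5 coordinates only
}};

constexpr unsigned lookup(Mode mode, std::size_t dims) {
  return kTable[static_cast<std::size_t>(mode)][dims];
}

static_assert(lookup(Mode::Tile, 3) == 103);
static_assert(lookup(Mode::Im2col, 2) == kNone); // too few coordinates
static_assert(lookup(Mode::Gather4, 5) == 502);
```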

mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
3 changes: 2 additions & 1 deletion mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -854,7 +854,8 @@ module @mymodule {
// CHECK: %[[desc:.+]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[c8192:.+]] = llvm.mlir.constant(8192 : index) : i64
// CHECK: %[[shmemOfset:.+]] = llvm.getelementptr %[[desc]][%[[c8192]]] : (!llvm.ptr<3>, i64)
// CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %[[shmemOfset]], %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}]
// CHECK: %[[dest:.+]] = llvm.addrspacecast %[[shmemOfset]] : !llvm.ptr<3> to !llvm.ptr<7>
// CHECK: nvvm.cp.async.bulk.tensor.shared.cluster.global %[[dest]], %{{.*}}, %{{.*}}, box[%{{.*}}, %{{.*}}]
nvgpu.tma.async.load %rhsTensorMap[%c0, %c0], %mbarrier[%c0] to %rhsShmem : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[64, 1], offset: 8192>, 3>
return
}