[mlir][nvvm] Generalize wmma ops to handle more types and shapes
WMMA intrinsics have a large number of combinations; ideally we want to be able
to target all the different variants. To avoid a combinatorial explosion in the
number of MLIR ops, we use attributes to represent the different variations of
the load/store/mma ops. We also generate helpers with tablegen to know which
combinations are available. Using this, we avoid having to hardcode a path for
specific shapes and can support more types.
This patch also adds boilerplate for tf32 op support.

Differential Revision: https://reviews.llvm.org/D112689
ThomasRaoux committed Nov 1, 2021
1 parent 68bb4e1 commit 77eafb8
Showing 13 changed files with 646 additions and 549 deletions.
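
For orientation before the per-file diffs: each lowering now emits a single generalized NVVM WMMA op and asks tablegen-generated helpers whether an intrinsic exists for the requested shape/type combination, instead of dispatching to one hand-written op per shape. A minimal C++ sketch of that pattern, using the WMMAStoreOp helpers that appear in the diff below; the 16x16 f32 values are illustrative only, and `rewriter`, `op`, and `kInvalidCaseStr` stand for the usual rewrite-pattern context that is elided here:

// Sketch only: assume an f32 accumulator tile of shape 16x16, row-major layout.
int64_t m = 16, n = 16;
NVVM::MMATypes eltype = NVVM::MMATypes::f32;
NVVM::MMALayout layout = NVVM::MMALayout::row;
// Infer the missing dimension from the ones the fragment shape provides.
int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
// Tablegen-generated lookup: 0 means no intrinsic exists for this
// (m, n, k, layout, element type) combination, so the pattern must bail out.
if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
  return rewriter.notifyMatchFailure(op, kInvalidCaseStr);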
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -33,6 +33,8 @@ add_mlir_dialect(NVVMOps nvvm)
add_mlir_doc(NVVMOps NVVMDialect Dialects/ -gen-dialect-doc)
set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
mlir_tablegen(NVVMOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(NVVMOpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRNVVMConversionsIncGen)

add_mlir_dialect(ROCDLOps rocdl)
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -18,6 +18,16 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/IR/IntrinsicsNVPTX.h"

#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.h.inc"

/// Return the element type and number of elements associated with a wmma matrix
/// of given characteristics. This matches the logic in IntrinsicsNVVM.td
/// WMMA_REGS structure.
std::pair<mlir::Type, unsigned> inferMMAType(mlir::NVVM::MMATypes type,
mlir::NVVM::MMAFrag frag,
mlir::MLIRContext *context);

///// Ops /////
#define GET_OP_CLASSES
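
A hedged usage sketch for the inferMMAType helper declared above; `ctx` stands for an MLIRContext pointer assumed to be available at the call site, and the expected result for the f16 "a" fragment (8 registers of vector<2xf16>) is taken from the m16n16k16 case handled elsewhere in this patch:

// Query the per-register element type and register count of an f16 "a" fragment.
std::pair<mlir::Type, unsigned> info =
    inferMMAType(mlir::NVVM::MMATypes::f16, mlir::NVVM::MMAFrag::a, ctx);
// info.first  -> element type of each register, e.g. vector<2xf16>
// info.second -> number of registers in the fragment, e.g. 8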
511 changes: 344 additions & 167 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Large diffs are not rendered by default.

8 changes: 0 additions & 8 deletions mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -359,14 +359,6 @@ llvm::Value *createIntrinsicCall(llvm::IRBuilderBase &builder,
llvm::Intrinsic::ID intrinsic,
ArrayRef<llvm::Value *> args = {},
ArrayRef<llvm::Type *> tys = {});

/// Creates a call to an LLVM IR intrinsic function with the given arguments
/// for NVVM WMMA ops. Handles cases where the intrinsic name is overloaded
/// using the types of arguments supplied. Selects the correct intrinsic
/// by inspecting the argument types.
llvm::Value *createNvvmIntrinsicCall(llvm::IRBuilderBase &builder,
llvm::Intrinsic::ID intrinsic,
ArrayRef<llvm::Value *> args = {});
} // namespace detail

} // namespace LLVM
178 changes: 81 additions & 97 deletions mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -36,26 +36,36 @@ static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
return success();
}

/// Error string to emit when unimplemented WMMA variant is encountered.
static constexpr StringRef kInvalidCaseStr =
"Unimplemented WMMA variant, Only M16N16K16 version implemented.";
/// Error string to emit when an unimplemented WMMA variant is encountered.
static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";

static NVVM::MMAFrag convertOperand(StringRef operandName) {
if (operandName.equals("AOp"))
return NVVM::MMAFrag::a;
if (operandName.equals("BOp"))
return NVVM::MMAFrag::b;
if (operandName.equals("COp"))
return NVVM::MMAFrag::c;
llvm_unreachable("Unknown operand name");
}

static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
if (type.getElementType().isF16())
return NVVM::MMATypes::f16;
if (type.getElementType().isF32())
return type.getOperand().equals("COp") ? NVVM::MMATypes::f32
: NVVM::MMATypes::tf32;
llvm_unreachable("Unsupported type");
}

/// Return the LLVMStructType corresponding to the MMAMatrixType `type`.
static LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) {
StringRef operandStr = type.getOperand();
assert(type.getElementType().isa<FloatType>());
Type baseType = type.getElementType().isF16()
? VectorType::get(2, type.getElementType())
: type.getElementType();
auto getLLVMType = [&](int64_t numElements) {
return LLVM::LLVMStructType::getLiteral(
type.getContext(), SmallVector<Type, 8>(numElements, baseType));
};
if (operandStr.equals("AOp") || operandStr.equals("BOp"))
return getLLVMType(8);
if (type.getElementType().isF16())
return getLLVMType(4);
return getLLVMType(8);
NVVM::MMAFrag frag = convertOperand(type.getOperand());
NVVM::MMATypes eltType = getElementType(type);
std::pair<Type, unsigned> typeInfo =
inferMMAType(eltType, frag, type.getContext());
return LLVM::LLVMStructType::getLiteral(
type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
}
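
For comparison, the struct types the new inferMMAType-based path is expected to produce match the ones the replaced code hard-coded for the 16x16x16 configurations. The listing below is illustrative, and `matrixType` stands for a gpu::MMAMatrixType already in hand:

// Element counts taken from the code this function replaces (m16n16k16 cases):
//   AOp / BOp, f16 -> struct of 8 x vector<2xf16>
//   COp,       f16 -> struct of 4 x vector<2xf16>
//   COp,       f32 -> struct of 8 x f32
LLVM::LLVMStructType llvmTy = convertMMAToLLVMType(matrixType);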

/// This class implements the conversion of GPU MMA loadOp to wmma.load op
@@ -118,41 +128,41 @@ struct WmmaLoadOpToNVVMLowering
gpu::MMAMatrixType retType =
subgroupMmaLoadMatrixOp.res().getType().cast<gpu::MMAMatrixType>();
ArrayRef<int64_t> retTypeShape = retType.getShape();
int64_t m = 0;
int64_t n = 0;
int64_t k = 0;
NVVM::MMATypes eltype = getElementType(retType);
// NVVM intrinsics require the m x n x k dimensions; infer the missing
// dimension based on the valid intrinsics available.
if (retType.getOperand().equals("AOp")) {
m = retTypeShape[0];
k = retTypeShape[1];
n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
} else if (retType.getOperand().equals("BOp")) {
k = retTypeShape[0];
n = retTypeShape[1];
m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
} else if (retType.getOperand().equals("COp")) {
m = retTypeShape[0];
n = retTypeShape[1];
k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
}
NVVM::MMALayout layout = NVVM::MMALayout::row;
NVVM::MMAFrag frag = convertOperand(retType.getOperand());
// Check that there is an existing instruction for the combination we need.
if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

Type resType = convertMMAToLLVMType(retType);
StringRef operandStr = retType.getOperand();

// Create nvvm.mma_load op according to the operand types.
Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), leadDimension);
SmallVector<Value, 2> loadOpOperands({loadAddressCasted, leadingDim32});
if (operandStr.equals("AOp")) {
if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
rewriter.replaceOpWithNewOp<NVVM::WMMALoadAM16N16K16Op>(op, resType,
loadOpOperands);
} else {
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
}
} else if (operandStr.equals("BOp")) {
if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
rewriter.replaceOpWithNewOp<NVVM::WMMALoadBM16N16K16Op>(op, resType,
loadOpOperands);
} else {
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
}
} else {
if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
if (retType.getElementType().isF16()) {
rewriter.replaceOpWithNewOp<NVVM::WMMALoadCF16M16N16K16Op>(
op, resType, loadOpOperands);
} else if (retType.getElementType().isF32()) {
rewriter.replaceOpWithNewOp<NVVM::WMMALoadCF32M16N16K16Op>(
op, resType, loadOpOperands);
}
} else {
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
}
}

rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
op, resType, loadAddressCasted, leadingDim32, m, n, k, layout, eltype,
frag);

return success();
}
};
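
To make the inference above concrete, a small worked example assuming a 16x16 f16 "AOp" load; the values are illustrative and the helper names are the ones used in the pattern above:

// AOp fragment: its shape provides (m, k); n is recovered from the intrinsic table.
int64_t m = 16, k = 16;               // retTypeShape of the assumed 16x16 AOp
int64_t n = NVVM::WMMALoadOp::inferNDimension(m, k, NVVM::MMATypes::f16);
// A non-zero intrinsic ID confirms that a row-major m16n16k16 f16 load exists.
bool supported = NVVM::WMMALoadOp::getIntrinsicID(m, n, k, NVVM::MMALayout::row,
                                                  NVVM::MMATypes::f16,
                                                  NVVM::MMAFrag::a) != 0;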
@@ -212,13 +222,18 @@ struct WmmaStoreOpToNVVMLowering
storeAddress);

SmallVector<Value, 4> storeOpOperands;
storeOpOperands.push_back(storeAddressCasted);

// Get the shape of the MMAMatrix type being stored. The shape will
// choose which intrinsic this op will be lowered to.
gpu::MMAMatrixType srcType =
subgroupMmaStoreMatrixOp.src().getType().cast<gpu::MMAMatrixType>();
ArrayRef<int64_t> srcTypeShape = srcType.getShape();
NVVM::MMALayout layout = NVVM::MMALayout::row;
NVVM::MMATypes eltype = getElementType(srcType);
int64_t m = srcTypeShape[0];
int64_t n = srcTypeShape[1];
int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

auto matrixType = adaptor.src().getType().cast<LLVM::LLVMStructType>();
for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
@@ -229,29 +244,11 @@ }
}
Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), leadDimension);
storeOpOperands.push_back(leadingDim32);
// Unpack the results from the source.
if (srcType.getElementType().isF16()) {
// Create nvvm.mma_store op.
if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16) {
rewriter.create<NVVM::WMMAStoreF16M16N16K16Op>(loc, storeOpOperands);
} else {
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
}
rewriter.eraseOp(op);
return success();
}
if (srcType.getElementType().isF32()) {
// Create nvvm.mma_store op.
if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16)
rewriter.create<NVVM::WMMAStoreF32M16N16K16Op>(loc, storeOpOperands);
else {
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
}
rewriter.eraseOp(op);
return success();
}
return failure();
rewriter.create<NVVM::WMMAStoreOp>(loc, storeAddressCasted, m, n, k, layout,
eltype, storeOpOperands, leadingDim32);

rewriter.eraseOp(op);
return success();
}
};

@@ -292,40 +289,27 @@ struct WmmaMmaOpToNVVMLowering
gpu::MMAMatrixType aType =
subgroupMmaComputeOp.opA().getType().cast<gpu::MMAMatrixType>();
ArrayRef<int64_t> aTypeShape = aType.getShape();
gpu::MMAMatrixType bType =
subgroupMmaComputeOp.opB().getType().cast<gpu::MMAMatrixType>();
ArrayRef<int64_t> bTypeShape = bType.getShape();
gpu::MMAMatrixType cType =
subgroupMmaComputeOp.opC().getType().cast<gpu::MMAMatrixType>();
ArrayRef<int64_t> cTypeShape = cType.getShape();
int64_t m = cTypeShape[0];
int64_t n = cTypeShape[1];
int64_t k = aTypeShape[1];
NVVM::MMALayout layout = NVVM::MMALayout::row;
NVVM::MMATypes sourceType = getElementType(aType);
NVVM::MMATypes destType = getElementType(cType);
if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, layout, layout, sourceType,
destType) == 0)
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

unpackOp(adaptor.opA());
unpackOp(adaptor.opB());
unpackOp(adaptor.opC());

if (cType.getElementType().isF16()) {
if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 &&
bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
// Create nvvm.wmma.mma op.
rewriter.replaceOpWithNewOp<NVVM::WMMAMmaF16F16M16N16K16Op>(
op, adaptor.opC().getType(), unpackedOps);

return success();
}
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
}
if (cType.getElementType().isF32()) {
if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 &&
bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
// Create nvvm.wmma.mma op.
rewriter.replaceOpWithNewOp<NVVM::WMMAMmaF32F32M16N16K16Op>(
op, adaptor.opC().getType(), unpackedOps);

return success();
}
return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
}
return failure();
rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
op, adaptor.opC().getType(), m, n, k, layout, layout, sourceType,
destType, unpackedOps);
return success();
}
};
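
Note that for the mma case getIntrinsicID takes two layouts (one each for the A and B operands) and separate source/destination element types, which is what lets mixed-precision variants be expressed. A hedged sketch of the availability check for f16 multiplicands accumulating into f32; the values are illustrative only:

// Mixed-precision check: f16 A/B fragments, f32 accumulator, both row-major.
int64_t m = 16, n = 16, k = 16;
NVVM::MMALayout layout = NVVM::MMALayout::row;
bool available = NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, layout, layout,
                                                 NVVM::MMATypes::f16,
                                                 NVVM::MMATypes::f32) != 0;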

