diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index d78145d690fc8..c427b0942aafd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2131,6 +2131,12 @@ class WMMA_NAME_LDST {
 /// Generate the signature part of the mma intrinsic name.
 class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
   list<WMMA_REGS> id_frags = !cond(
+     // FP8/F8F6F4 ops are identified by A,B inputs & accumulator & result type.
+     !or(!eq(A.ptx_elt_type, "e4m3"),
+         !eq(A.ptx_elt_type, "e5m2"),
+         !eq(A.ptx_elt_type, "e3m2"),
+         !eq(A.ptx_elt_type, "e2m3"),
+         !eq(A.ptx_elt_type, "e2m1")): [D, A, B, C],
     // FP16 ops are identified by accumulator & result type.
     !eq(A.ptx_elt_type, "f16") : [D, C],
     // other ops are identified by input types.
@@ -2257,6 +2263,31 @@ class NVVM_MMA_OPS {
   list<list<WMMA_REGS>> all_mma_sync_ops = !listconcat(
     tf32_mma_ops, bf16_mma_ops, f64_mma_ops, fp_mma_ops, int_mma_ops,
     subint_mma_ops, bit_mma_ops);
+
+  list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
+    [GEOM<16,8,16>, GEOM<16,8,32>],
+    ["bf16"], [], ["f32"], []>.ret;
+  list<list<WMMA_REGS>> tf32_mma_sp_ops = MMA_OPS<
+    [GEOM<16,8,8>, GEOM<16,8,16>],
+    ["tf32"], [], ["f32"], []>.ret;
+  list<list<WMMA_REGS>> fp_mma_sp_ops = MMA_OPS<
+    [GEOM<16,8,16>, GEOM<16,8,32>],
+    ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
+  list<list<WMMA_REGS>> fp8_mma_sp_ops = MMA_OPS<
+    [GEOM<16,8,64>],
+    ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+    ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+    ["f16", "f32"], ["f16", "f32"]>.ret;
+  list<list<WMMA_REGS>> subint_mma_sp_ops = MMA_OPS<
+    [GEOM<16,8,64>, GEOM<16,8,128>],
+    ["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
+  list<list<WMMA_REGS>> int_mma_sp_ops = MMA_OPS<
+    [GEOM<16,8,32>, GEOM<16,8,64>],
+    ["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
+  list<list<WMMA_REGS>> all_mma_sp_sync_ops = !listconcat(
+    bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
+    subint_mma_sp_ops, int_mma_sp_ops);
+
 }
 def NVVM_MMA_OPS : NVVM_MMA_OPS;
 
@@ -2362,6 +2393,16 @@ def MMAIntOverflow : I32EnumAttr<"MMAIntOverflow", "MMA overflow options",
 def MMAIntOverflowAttr : EnumAttr<NVVM_Dialect, MMAIntOverflow, "mma_int_overflow"> {
   let assemblyFormat = "`<` $value `>`";
 }
+/// MMA kind types (for mixed-precision FP8 operations)
+def MMAKindF8F6F4 : I32EnumAttrCase<"f8f6f4", 0>;
+def MMAKind : I32EnumAttr<"MMAKind", "MMA operation kind",
+  [MMAKindF8F6F4]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def MMAKindAttr : EnumAttr<NVVM_Dialect, MMAKind, "mma_kind"> {
+  let assemblyFormat = "`<` $value `>`";
+}
 
 /// Attribute to hold the MMA shape
 def NVVM_MMAShapeAttr : NVVM_Attr<"MMAShape", "shape"> {
@@ -2506,12 +2547,18 @@ def MMATypeU4 : I32EnumAttrCase<"u4", 7>;
 def MMATypeS4 : I32EnumAttrCase<"s4", 8>;
 def MMATypeBF16 : I32EnumAttrCase<"bf16", 9>;
 def MMATypeF64 : I32EnumAttrCase<"f64", 10>;
+def MMATypeE4M3 : I32EnumAttrCase<"e4m3", 11>;
+def MMATypeE5M2 : I32EnumAttrCase<"e5m2", 12>;
+def MMATypeE3M2 : I32EnumAttrCase<"e3m2", 13>;
+def MMATypeE2M3 : I32EnumAttrCase<"e2m3", 14>;
+def MMATypeE2M1 : I32EnumAttrCase<"e2m1", 15>;
 def MMATypes : I32EnumAttr<"MMATypes", "NVVM MMA types",
   [MMATypeF16, MMATypeF32, MMATypeTF32, MMATypeBF16,
    MMATypeS8, MMATypeU8, MMATypeS32, MMATypeS4, MMATypeU4,
-   MMATypeB1, MMATypeF64]> {
+   MMATypeB1, MMATypeF64,
+   MMATypeE4M3, MMATypeE5M2, MMATypeE3M2, MMATypeE2M3, MMATypeE2M1]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::NVVM";
 }
@@ -2948,6 +2995,216 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
   let hasVerifier = 1;
 }
 
+/// Generate enum value of the mma.sp.sync intrinsic.
+class MMA_SP_SYNC_NAME { + string signature = MMA_SIGNATURE.ret; + string id = "llvm::Intrinsic::nvvm_mma" + # "_" # !subst("::", "_", Metadata) + # "_" # A.geom + # "_row_col" + # !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "") + # !if(Satfinite, "_satfinite", "") + # signature; +} + +// Returns true if this combination of layout/kind/satf for MMA.SP ops is supported; +// false otherwise. +// E.g. +// if NVVM_MMA_SP_SUPPORTED<...>.ret then +// def : FOO<>; // The record will only be defined for supported ops. +// +class NVVM_MMA_SP_SUPPORTED frags, string metadata, + string kind, int satf> { + // MMA.SP ops check both layouts. + string a_type = frags[0].ptx_elt_type; + string b_type = frags[1].ptx_elt_type; + string c_type = frags[2].ptx_elt_type; + string d_type = frags[3].ptx_elt_type; + string geom = frags[0].geom; + + bit is_int = !or(!eq(a_type, "s8"), + !eq(a_type, "u8"), + !eq(a_type, "s4"), + !eq(a_type, "u4")); + + bit ret = !cond( + + // Limit satf to valid types + !and(!eq(satf, 1), + !eq(is_int, 0)): false, + + // f16/bf16/tf32 requires A and B to be the same type. + !and(!or(!eq(a_type, "f16"), + !eq(a_type, "bf16"), + !eq(a_type, "tf32")), + !ne(a_type, b_type)): false, + + // m16n8k16, m16n8k32 and m16n8k64 requires C and D to be the same type. + !and(!or(!eq(geom, "m16n8k16"), + !eq(geom, "m16n8k32"), + !eq(geom, "m16n8k64")), + !ne(c_type, d_type)): false, + + !and(!eq(kind, ""), + !or(!eq(a_type, "e3m2"), + !eq(a_type, "e2m3"), + !eq(a_type, "e2m1"), + !eq(b_type, "e3m2"), + !eq(b_type, "e2m3"), + !eq(b_type, "e2m1"))): false, + + !and(!eq(kind, ""), + !eq(geom, "m16n8k64"), + !or(!eq(c_type, "f16"), + !eq(d_type, "f16"))): false, + + !and(!ne(kind, ""), + !or(!eq(metadata, "sp"), + !ne(geom, "m16n8k64"), + !eq(is_int, 1))): false, + + // All other are OK. + true: true + ); +} + +/// Helper to create the mapping between the configuration and the mma.sp.sync +/// intrinsic enum value. +class MMA_SP_SYNC_INTR { + list>>> cond0 = + !foreach(op, NVVM_MMA_OPS.all_mma_sp_sync_ops, + !foreach(metadata, ["sp", "sp::ordered_metadata"], + !foreach(kind, ["", "kind::f8f6f4"], + !foreach (satf, [0, 1], + !if(NVVM_MMA_SP_SUPPORTED.ret, + "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k + # " && \"" # op[0].ptx_elt_type # "\" == eltypeA" + # " && \"" # op[1].ptx_elt_type # "\" == eltypeB" + # " && \"" # op[2].ptx_elt_type # "\" == eltypeC" + # " && \"" # op[3].ptx_elt_type # "\" == eltypeD" + # " && (satf.has_value() ? " # satf # " == static_cast(*satf) : true)" + # " && " # !if(!eq(metadata, "sp"), "!orderedMetadata", "orderedMetadata") # ")\n" + # " return " # + MMA_SP_SYNC_NAME.id # ";", + "") // if supported + ) // satf + ) // kind + ) // metadata + ); // all_mma_sp_sync_ops + list>> f1 = !foldl([[[""]]], cond0, acc, el, + !listconcat(acc, el)); + list> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el)); + list f3 = !foldl([""], f2, acc, el, !listconcat(acc, el)); + string id = !foldl("", f3, acc, el, acc # "\n" # el); +} + +def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> { + + let summary = "cooperative sparse matrix-multiply and accumulate"; + + let description = [{ + The `nvvm.mma.sp.sync` operation collectively performs the sparse operation + `D = matmul(A_sparse, B) + C` using all threads in a warp. + + This operation is similar to `nvvm.mma.sync` but with structured sparsity + in the A operand. The sparsity follows the 2:4 structured sparse pattern + where 2 out of every 4 elements are non-zero. 
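+
+    For example, with f16 elements a group of four consecutive values along the K
+    dimension such as `[a, 0, 0, d]` is stored in the A operand as just the two
+    non-zero values `[a, d]`, and the metadata records their positions within the
+    group as two 2-bit indices (here 0 and 3). This is only a sketch of the general
+    2:4 scheme; the exact packing of the indices into the `sparseMetadata` value and
+    their distribution across the threads of the warp are defined by the PTX ISA and
+    depend on the shape and element types.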
+ + All the threads in the warp must execute the same `mma.sp.sync` operation. + + The `sparseMetadata` operand provides the sparsity indices that indicate + which elements in the A operand are non-zero. The `sparsitySelector` + controls how the indices are distributed among threads in the warp and + should typically be 0 or 1. + + The optional `orderedMetadata` attribute specifies the metadata ordering: + - Absence (default): Uses standard sparse metadata ordering + - Presence: Uses ordered metadata (PTX ISA 8.5+, sm_90+) + + The optional `kind` attribute specifies mixed-precision modes for FP8 operations: + - `f8f6f4`: Enables e3m2, e2m3, e2m1 FP8 types and f16 accumulator (PTX ISA 8.7+, sm_90+) + - Only valid with ordered metadata and m16n8k64 shape + + The shapes, layouts, and data types follow the same constraints as the + regular `nvvm.mma.sync` operation, but the A operand contains only the + non-zero elements in compressed format. + + Example: + ```mlir + %d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} + : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + + // With ordered metadata: + %d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} + : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + ``` + }]; + + let results = (outs LLVM_AnyStruct:$res); + let arguments = (ins NVVM_MMAShapeAttr:$shape, + OptionalAttr:$intOverflowBehavior, + OptionalAttr:$multiplicandAPtxType, + OptionalAttr:$multiplicandBPtxType, + UnitAttr:$orderedMetadata, + OptionalAttr:$kind, + Variadic:$operandA, + Variadic:$operandB, + Variadic:$operandC, + I32:$sparseMetadata, + I32:$sparsitySelector); + + let extraClassDeclaration = !strconcat([{ + static llvm::Intrinsic::ID getIntrinsicID( + int64_t m, int64_t n, uint64_t k, + std::optional satf, + bool orderedMetadata, + std::optional kind, + mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum, + mlir::NVVM::MMATypes eltypeCEnum, mlir::NVVM::MMATypes eltypeDEnum) { + llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum); + llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum); + llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum); + llvm::StringRef eltypeD = stringifyEnum(eltypeDEnum); + }], + MMA_SP_SYNC_INTR<>.id, [{ + return 0; + } + + static std::optional inferOperandMMAType(Type operandElType, + bool isAccumulator); + + MMATypes accumPtxType(); + MMATypes resultPtxType(); + + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase& builder); + }]); + + let builders = [ + OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA, + "ValueRange":$operandB, "ValueRange":$operandC, + "Value":$sparseMetadata, "Value":$sparsitySelector, + "ArrayRef":$shape, + "std::optional":$intOverflow, + "std::optional>":$multiplicandPtxTypes)> + ]; + + string llvmBuilder = [{ + auto [id, args] = NVVM::MmaSpOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + $res = createIntrinsicCall(builder, id, args); + }]; + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // NVVM TMA Ops //===----------------------------------------------------------------------===// diff --git 
a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 428bc72c88a30..846b299312d67 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1026,6 +1026,482 @@ LogicalResult MmaOp::verify() { return success(); } +MMATypes MmaSpOp::accumPtxType() { + std::optional val = MmaOp::inferOperandMMAType( + getODSOperands(2).getTypes().front(), /*isAccumulator=*/true); + assert(val.has_value() && "accumulator PTX type should always be inferrable"); + return val.value(); +} + +MMATypes MmaSpOp::resultPtxType() { + std::optional val = + MmaOp::inferOperandMMAType(getResult().getType(), /*isAccumulator=*/true); + assert(val.has_value() && "result PTX type should always be inferrable"); + return val.value(); +} + +mlir::NVVM::IDArgPair +MmaSpOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { + auto thisOp = cast(op); + + // Get operands + llvm::SmallVector args; + for (mlir::Value v : thisOp.getOperands()) + args.push_back(mt.lookupValue(v)); + + // Get intrinsic ID using the existing getIntrinsicID method + auto intId = MmaSpOp::getIntrinsicID( + thisOp.getShape().getM(), thisOp.getShape().getN(), + thisOp.getShape().getK(), thisOp.getIntOverflowBehavior(), + thisOp.getOrderedMetadata(), thisOp.getKind(), + *thisOp.getMultiplicandAPtxType(), *thisOp.getMultiplicandBPtxType(), + thisOp.accumPtxType(), thisOp.resultPtxType()); + + return {intId, args}; +} + +void MmaSpOp::print(OpAsmPrinter &p) { + SmallVector regTypes; + struct OperandFragment { + StringRef operandName; + StringRef ptxTypeAttr; + SmallVector regs; + explicit OperandFragment(StringRef name, StringRef ptxTypeName) + : operandName(name), ptxTypeAttr(ptxTypeName) {} + }; + + std::array frags{ + OperandFragment("A", getMultiplicandAPtxTypeAttrName()), + OperandFragment("B", getMultiplicandBPtxTypeAttrName()), + OperandFragment("C", ""), OperandFragment("sparseMetadata", ""), + OperandFragment("selector", "")}; + SmallVector ignoreAttrNames{ + mlir::NVVM::MmaSpOp::getOperandSegmentSizeAttr()}; + + // Handle variadic operands A, B, C + for (unsigned fragIdx = 0; fragIdx < 3; fragIdx++) { + auto &frag = frags[fragIdx]; + auto varOperandSpec = getODSOperandIndexAndLength(fragIdx); + for (auto operandIdx = varOperandSpec.first; + operandIdx < varOperandSpec.first + varOperandSpec.second; + operandIdx++) { + frag.regs.push_back(this->getOperand(operandIdx)); + if (operandIdx == varOperandSpec.first) { + regTypes.push_back(this->getOperand(operandIdx).getType()); + } + } + std::optional inferredType = MmaOp::inferOperandMMAType( + regTypes.back(), /*isAccumulator=*/fragIdx >= 2); + if (inferredType) + ignoreAttrNames.push_back(frag.ptxTypeAttr); + } + + // Handle sparse metadata and selector (single operands) + frags[3].regs.push_back(getSparseMetadata()); + frags[4].regs.push_back(getSparsitySelector()); + + auto printMmaSpOperand = [&](const OperandFragment &frag) -> void { + p << " " << frag.operandName; + p << "["; + p.printOperands(frag.regs); + p << "]"; + }; + + for (const auto &frag : frags) + printMmaSpOperand(frag); + + p.printOptionalAttrDict((*this)->getAttrs(), ignoreAttrNames); + p << " : "; + p << "("; + for (int i = 0; i < 3; ++i) { + p << regTypes[i]; + if (i < 2) + p << ", "; + } + p << ") -> " << getResult().getType(); +} + +void MmaSpOp::build( + OpBuilder &builder, OperationState &result, Type resultType, + ValueRange operandA, ValueRange operandB, ValueRange operandC, + Value 
sparseMetadata, Value sparsitySelector, ArrayRef shape, + std::optional intOverflow, + std::optional> multiplicandPtxTypes) { + + assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)"); + MLIRContext *ctx = builder.getContext(); + result.addAttribute( + "shape", builder.getAttr(shape[0], shape[1], shape[2])); + + result.addOperands(operandA); + result.addOperands(operandB); + result.addOperands(operandC); + result.addOperands(sparseMetadata); + result.addOperands(sparsitySelector); + + if (multiplicandPtxTypes) { + result.addAttribute("multiplicandAPtxType", + MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0])); + result.addAttribute("multiplicandBPtxType", + MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1])); + } else { + if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false)) + result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res)); + if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false)) + result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res)); + } + + if (intOverflow.has_value()) + result.addAttribute("intOverflowBehavior", + MMAIntOverflowAttr::get(ctx, *intOverflow)); + + result.addTypes(resultType); + result.addAttribute( + MmaSpOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({static_cast(operandA.size()), + static_cast(operandB.size()), + static_cast(operandC.size()), 1, + 1})); // sparseMetadata and sparsitySelector +} + +ParseResult MmaSpOp::parse(OpAsmParser &parser, OperationState &result) { + struct OperandFragment { + std::optional elemtype; + SmallVector regs; + SmallVector regTypes; + }; + + Builder &builder = parser.getBuilder(); + std::array frags; // A, B, C, sparseMetadata, selector + + NamedAttrList namedAttributes; + + // A helper to parse the operand segments. + auto parseMmaSpOperand = [&](StringRef operandName, + OperandFragment &frag) -> LogicalResult { + if (parser.parseKeyword(operandName).failed()) + return failure(); + if (parser + .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare) + .failed()) + return failure(); + return success(); + }; + + // Parse the operand segments. + if (parseMmaSpOperand("A", frags[0]).failed()) + return failure(); + if (parseMmaSpOperand("B", frags[1]).failed()) + return failure(); + if (parseMmaSpOperand("C", frags[2]).failed()) + return failure(); + if (parseMmaSpOperand("sparseMetadata", frags[3]).failed()) + return failure(); + if (parseMmaSpOperand("selector", frags[4]).failed()) + return failure(); + + if (parser.parseOptionalAttrDict(namedAttributes).failed()) + return failure(); + + // Parse the type specification and resolve operands. 
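+  // The trailing type signature has the form `: (aType, bType, cType) -> resultType`;
+  // each of the three listed types applies to every register of the corresponding
+  // operand segment.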
+ SmallVector operandTypes; + if (failed(parser.parseColon())) + return failure(); + if (failed(parser.parseLParen())) + return failure(); + if (failed(parser.parseTypeList(operandTypes))) + return failure(); + if (failed(parser.parseRParen())) + return failure(); + if (operandTypes.size() != 3) + return parser.emitError( + parser.getNameLoc(), + "expected one type for each operand segment but got " + + Twine(operandTypes.size()) + " types"); + for (const auto &iter : llvm::enumerate(operandTypes)) { + auto &frag = frags[iter.index()]; + frag.regTypes.resize(frag.regs.size(), iter.value()); + if (failed(parser.resolveOperands(frag.regs, frag.regTypes, + parser.getNameLoc(), result.operands))) + return failure(); + frag.elemtype = + MmaOp::inferOperandMMAType(frag.regTypes[0], + /*isAccumulator*/ iter.index() >= 2); + } + + Type resultType; + if (parser.parseArrow() || parser.parseType(resultType)) + return failure(); + frags[5].elemtype = + MmaOp::inferOperandMMAType(resultType, /*isAccumulator*/ true); + + // Resolve sparse metadata and selector (assume i32 type) + Type i32Type = builder.getIntegerType(32); + if (parser + .resolveOperands(frags[3].regs, i32Type, parser.getCurrentLocation(), + result.operands) + .failed()) + return failure(); + if (parser + .resolveOperands(frags[4].regs, i32Type, parser.getCurrentLocation(), + result.operands) + .failed()) + return failure(); + + std::array names{"multiplicandAPtxType", + "multiplicandBPtxType"}; + for (unsigned idx = 0; idx < names.size(); idx++) { + const auto &frag = frags[idx]; + std::optional attr = namedAttributes.getNamed(names[idx]); + if (!frag.elemtype.has_value() && !attr.has_value()) { + return parser.emitError( + parser.getNameLoc(), + "attribute " + names[idx] + + " is not provided explicitly and cannot be inferred"); + } + if (!attr.has_value()) + result.addAttribute( + names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype)); + } + + result.addTypes(resultType); + if (!namedAttributes.empty()) + result.addAttributes(namedAttributes); + result.addAttribute(MmaSpOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({ + static_cast(frags[0].regs.size()), + static_cast(frags[1].regs.size()), + static_cast(frags[2].regs.size()), + 1, // sparseMetadata + 1 // sparsitySelector + })); + return success(); +} + +LogicalResult MmaSpOp::verify() { + MLIRContext *context = getContext(); + auto f16Ty = Float16Type::get(context); + auto i32Ty = IntegerType::get(context, 32); + auto f16x2Ty = VectorType::get(2, f16Ty); + auto f32Ty = Float32Type::get(context); + auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral( + context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); + + auto s32x4StructTy = + LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty}); + auto f32x8StructTy = + LLVM::LLVMStructType::getLiteral(context, SmallVector(8, f32Ty)); + auto f16x2x2StructTy = + LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty}); + auto f32x4StructTy = + LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty}); + auto s32x2StructTy = + LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty}); + + std::array mmaShape{getShapeAttr().getM(), getShapeAttr().getN(), + getShapeAttr().getK()}; + + // These variables define the set of allowed data types for matrices A, B, C, + // and result. 
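+  // Each entry below is one legal register-fragment signature for the inferred
+  // multiplicand type, and allowedShapes lists the geometries that are valid for it.
+  // Because of the 2:4 compression, the expected A fragment is half the size of the
+  // corresponding dense mma.sync fragment.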
+ using AllowedShapes = SmallVector, 2>; + using AllowedTypes = SmallVector, 2>; + AllowedShapes allowedShapes; + AllowedTypes expectedA; + AllowedTypes expectedB; + AllowedTypes expectedC; + SmallVector expectedResult; + + // When M = 16, we just need to calculate the number of 8xk tiles, where + // k is a factor that depends on the data type. + if (mmaShape[0] == 16) { + int64_t kFactor; + Type multiplicandFragType; + switch (*getMultiplicandAPtxType()) { + case MMATypes::tf32: + kFactor = 4; + multiplicandFragType = i32Ty; + expectedResult.push_back(LLVM::LLVMStructType::getLiteral( + context, {f32Ty, f32Ty, f32Ty, f32Ty})); + // Sparse MMA supports m16n8k8 and m16n8k16 for tf32 + allowedShapes.push_back({16, 8, 8}); + allowedShapes.push_back({16, 8, 16}); + break; + case MMATypes::bf16: + kFactor = 8; + multiplicandFragType = i32Ty; + expectedResult.push_back(LLVM::LLVMStructType::getLiteral( + context, {f32Ty, f32Ty, f32Ty, f32Ty})); + // Sparse MMA supports m16n8k16 and m16n8k32 for bf16 + allowedShapes.push_back({16, 8, 16}); + allowedShapes.push_back({16, 8, 32}); + break; + case MMATypes::f16: + kFactor = 8; + multiplicandFragType = f16x2Ty; + expectedResult.push_back(f16x2x2StructTy); + expectedResult.push_back(f32x4StructTy); + // Sparse MMA supports m16n8k16 and m16n8k32 for f16 + allowedShapes.push_back({16, 8, 16}); + allowedShapes.push_back({16, 8, 32}); + break; + case MMATypes::s4: + case MMATypes::u4: + kFactor = 32; + // Sparse MMA supports m16n8k64 and m16n8k128 for s4/u4 + allowedShapes.push_back({16, 8, 64}); + allowedShapes.push_back({16, 8, 128}); + break; + case MMATypes::s8: + case MMATypes::u8: + kFactor = 16; + // Sparse MMA supports m16n8k32 and m16n8k64 for s8/u8 + allowedShapes.push_back({16, 8, 32}); + allowedShapes.push_back({16, 8, 64}); + break; + case MMATypes::e4m3: + case MMATypes::e5m2: + case MMATypes::e3m2: + case MMATypes::e2m3: + case MMATypes::e2m1: + kFactor = 32; + multiplicandFragType = i32Ty; + expectedResult.push_back(f16x2x2StructTy); + expectedResult.push_back(f32x4StructTy); + // Sparse MMA supports m16n8k64 for FP8 types + allowedShapes.push_back({16, 8, 64}); + break; + default: + return emitError("invalid shape or multiplicand type: " + + stringifyEnum(getMultiplicandAPtxType().value())); + } + + if (isIntegerPtxType(getMultiplicandAPtxType().value())) { + expectedResult.push_back(s32x4StructTy); + expectedC.emplace_back(4, i32Ty); + multiplicandFragType = i32Ty; + } else if (*getMultiplicandAPtxType() >= MMATypes::e4m3 && + *getMultiplicandAPtxType() <= MMATypes::e2m1) { + // FP8 types + expectedC.emplace_back(2, f16x2Ty); + expectedC.emplace_back(4, f32Ty); + } else { + expectedC.emplace_back(2, f16x2Ty); + expectedC.emplace_back(4, f32Ty); + } + + // For sparse MMA, A operand is compressed (2:4 sparsity means half the + // elements) + int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor) / 2; + int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor); + expectedA.emplace_back(unitA, multiplicandFragType); + expectedB.emplace_back(unitB, multiplicandFragType); + + if (resultPtxType() != accumPtxType()) + return emitOpError("ctype does not match dtype"); + } + + // In the M=8 case, there is only 1 possible case per data type. 
+ if (mmaShape[0] == 8) { + if (*getMultiplicandAPtxType() == MMATypes::f16) { + expectedA.emplace_back(2, f16x2Ty); + expectedB.emplace_back(2, f16x2Ty); + expectedResult.push_back(f16x2x4StructTy); + expectedResult.push_back(f32x8StructTy); + expectedC.emplace_back(4, f16x2Ty); + expectedC.emplace_back(8, f32Ty); + allowedShapes.push_back({8, 8, 4}); + } + if (*getMultiplicandAPtxType() == MMATypes::f64) { + Type f64Ty = Float64Type::get(context); + expectedA.emplace_back(1, f64Ty); + expectedB.emplace_back(1, f64Ty); + expectedC.emplace_back(2, f64Ty); + expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral( + context, SmallVector(2, f64Ty))); + allowedShapes.push_back({8, 8, 4}); + } + if (isIntegerPtxType(getMultiplicandAPtxType().value())) { + expectedA.push_back({i32Ty}); + expectedB.push_back({i32Ty}); + expectedC.push_back({i32Ty, i32Ty}); + expectedResult.push_back(s32x2StructTy); + if (isInt4PtxType(getMultiplicandAPtxType().value())) + allowedShapes.push_back({8, 8, 32}); + if (isInt8PtxType(getMultiplicandAPtxType().value())) + allowedShapes.push_back({8, 8, 16}); + } + } + + std::string errorMessage; + llvm::raw_string_ostream errorStream(errorMessage); + + // Check that we matched an existing shape/dtype combination. + if (expectedA.empty() || expectedB.empty() || expectedC.empty() || + !llvm::is_contained(allowedShapes, mmaShape)) { + errorStream << "unimplemented variant for MMA shape <"; + llvm::interleaveComma(mmaShape, errorStream); + errorStream << ">"; + return emitOpError(errorMessage); + } + + // Verify the operand types for segments of A, B, and C operands. + std::array operandNames{"A", "B", "C"}; + for (const auto &iter : llvm::enumerate( + SmallVector{expectedA, expectedB, expectedC})) { + auto spec = this->getODSOperandIndexAndLength(iter.index()); + SmallVector operandTySeg(operand_type_begin() + spec.first, + operand_type_begin() + spec.first + + spec.second); + bool match = llvm::is_contained(iter.value(), operandTySeg); + + if (!match) { + errorStream << "Could not match types for the " + << operandNames[iter.index()] + << " operands; expected one of "; + for (const auto &x : iter.value()) { + errorStream << x.size() << "x" << x[0] << " "; + } + errorStream << "but got "; + llvm::interleaveComma(operandTySeg, errorStream); + return emitOpError(errorMessage); + } + } + + // Check the result type + if (!llvm::any_of(expectedResult, [&](Type expectedResultType) { + return expectedResultType == getResult().getType(); + })) { + errorStream + << "Could not match allowed types for the result; expected one of "; + llvm::interleaveComma(expectedResult, errorStream); + errorStream << " but got " << getResult().getType(); + return emitOpError(errorMessage); + } + + // Ensure int4/int8 MMA variants specify the accum overflow behavior + // attribute. 
+ if (isInt4PtxType(*getMultiplicandAPtxType()) || + isInt8PtxType(*getMultiplicandAPtxType())) { + if (!getIntOverflowBehavior()) + return emitOpError("op requires " + + getIntOverflowBehaviorAttrName().strref() + + " attribute"); + } + + // Validate sparse metadata type (should be i32) + if (!getSparseMetadata().getType().isInteger(32)) { + return emitOpError() << "sparse metadata must be i32 type"; + } + + // Validate sparsity selector type (should be i32) + if (!getSparsitySelector().getType().isInteger(32)) { + return emitOpError() << "sparsity selector must be i32 type"; + } + + return success(); +} + LogicalResult ShflOp::verify() { auto returnStructType = llvm::dyn_cast(getType()); diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir new file mode 100644 index 0000000000000..ff3e91b89016d --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-kind.mlir @@ -0,0 +1,221 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// This file contains tests for sparse MMA (mma.sp.sync) operations with KIND variants. +// The kind::f8f6f4 variant was introduced in PTX ISA 8.7 for sm_90+ architectures. +// +// Based on PTX ISA documentation: +// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-sparse-mma +// +// KIND::F8F6F4 enables: +// - Additional FP8 types: e3m2, e2m3, e2m1 +// - F16 accumulator for m16n8k64 FP8 operations +// - Mixed-precision FP8 computations +// +// Requirements: +// - ONLY works with ordered metadata (sp::ordered_metadata) +// - ONLY for shape m16n8k64 +// - ONLY for FP8 types (not integers or other floats) + +// ============================================================================= +// FP8 e4m3 Sparse MMA with KIND (m16n8k64) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e4m3_f16 +func.func @nvvm_mma_sp_kind_m16n8k64_e4m3_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e4m3_f32 +func.func @nvvm_mma_sp_kind_m16n8k64_e4m3_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = 
#nvvm.mma_type, + shape = #nvvm.shape} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// FP8 e5m2 Sparse MMA with KIND (m16n8k64) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e5m2_f16 +func.func @nvvm_mma_sp_kind_m16n8k64_e5m2_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e5m2_f32 +func.func @nvvm_mma_sp_kind_m16n8k64_e5m2_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// FP8 e3m2 Sparse MMA with KIND (m16n8k64) +// NOTE: e3m2 is ONLY available with kind::f8f6f4 +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e3m2_f16 +func.func @nvvm_mma_sp_kind_m16n8k64_e3m2_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e3m2_f32 +func.func @nvvm_mma_sp_kind_m16n8k64_e3m2_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] 
B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// FP8 e2m3 Sparse MMA with KIND (m16n8k64) +// NOTE: e2m3 is ONLY available with kind::f8f6f4 +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m3_f16 +func.func @nvvm_mma_sp_kind_m16n8k64_e2m3_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m3_f32 +func.func @nvvm_mma_sp_kind_m16n8k64_e2m3_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// FP8 e2m1 Sparse MMA with KIND (m16n8k64) +// NOTE: e2m1 is ONLY available with kind::f8f6f4 +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m1_f16 +func.func @nvvm_mma_sp_kind_m16n8k64_e2m1_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, 
%c1] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_kind_m16n8k64_e2m1_f32 +func.func @nvvm_mma_sp_kind_m16n8k64_e2m1_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {kind = #nvvm.mma_kind, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {kind = #nvvm.mma_kind, + orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir new file mode 100644 index 0000000000000..a4e2812e54c12 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp-ordered.mlir @@ -0,0 +1,411 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// This file contains tests for sparse MMA (mma.sp.sync) operations with ORDERED metadata. +// The ordered metadata variant was introduced in PTX ISA 8.5 for sm_90+ architectures. +// +// Based on PTX ISA documentation: +// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-sparse-mma +// +// Ordered metadata provides an alternative metadata ordering for 2:4 structured sparsity +// that can offer better performance on newer architectures. 
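+//
+// NOTE: The RUN line above only round-trips the IR through mlir-opt and FileCheck, so
+// these tests exercise the nvvm.mma.sp.sync parser, printer, and verifier; translation
+// to the corresponding LLVM intrinsics is not checked in this file.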
+ +// ============================================================================= +// F16 Sparse MMA Operations with Ordered Metadata (m16n8k16) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_f16_f16 +func.func @nvvm_mma_sp_ordered_m16n8k16_f16_f16( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + shape = #nvvm.shape} + : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_f16_f32 +func.func @nvvm_mma_sp_ordered_m16n8k16_f16_f32( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + shape = #nvvm.shape} + : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// F16 Sparse MMA Operations with Ordered Metadata (m16n8k32) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_f16_f16 +func.func @nvvm_mma_sp_ordered_m16n8k32_f16_f16( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + shape = #nvvm.shape} + : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_f16_f32 +func.func @nvvm_mma_sp_ordered_m16n8k32_f16_f32( + %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>, + %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {orderedMetadata, shape = #nvvm.shape} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, 
f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + shape = #nvvm.shape} + : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// BF16 Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_bf16_f32 +func.func @nvvm_mma_sp_ordered_m16n8k16_bf16_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_bf16_f32 +func.func @nvvm_mma_sp_ordered_m16n8k32_bf16_f32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// TF32 Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k8_tf32_f32 +func.func @nvvm_mma_sp_ordered_m16n8k8_tf32_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k16_tf32_f32 +func.func @nvvm_mma_sp_ordered_m16n8k16_tf32_f32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : 
f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// Integer (s8) Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_s8_s32 +func.func @nvvm_mma_sp_ordered_m16n8k32_s8_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + intOverflowBehavior = #nvvm.mma_int_overflow, + shape = #nvvm.shape} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_s8_s32_satfinite +func.func @nvvm_mma_sp_ordered_m16n8k32_s8_s32_satfinite( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + intOverflowBehavior = #nvvm.mma_int_overflow, + shape = #nvvm.shape} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_s8_s32 +func.func @nvvm_mma_sp_ordered_m16n8k64_s8_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, 
%a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + intOverflowBehavior = #nvvm.mma_int_overflow, + shape = #nvvm.shape} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// Integer (u8) Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k32_u8_s32 +func.func @nvvm_mma_sp_ordered_m16n8k32_u8_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + intOverflowBehavior = #nvvm.mma_int_overflow, + shape = #nvvm.shape} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_u8_s32 +func.func @nvvm_mma_sp_ordered_m16n8k64_u8_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + intOverflowBehavior = #nvvm.mma_int_overflow, + shape = #nvvm.shape} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// Sub-byte Integer (s4) Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_s4_s32 +func.func @nvvm_mma_sp_ordered_m16n8k64_s4_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType 
= #nvvm.mma_type, + intOverflowBehavior = #nvvm.mma_int_overflow, + shape = #nvvm.shape} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k128_s4_s32 +func.func @nvvm_mma_sp_ordered_m16n8k128_s4_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + intOverflowBehavior = #nvvm.mma_int_overflow, + shape = #nvvm.shape} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// Sub-byte Integer (u4) Sparse MMA Operations with Ordered Metadata +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_u4_s32 +func.func @nvvm_mma_sp_ordered_m16n8k64_u4_s32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + intOverflowBehavior = #nvvm.mma_int_overflow, + shape = #nvvm.shape} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k128_u4_s32 +func.func @nvvm_mma_sp_ordered_m16n8k128_u4_s32( + %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32, + %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32, + %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + intOverflowBehavior = #nvvm.mma_int_overflow, + shape = #nvvm.shape} + : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)> + return +} + +// ============================================================================= +// FP8 (e4m3) Sparse MMA Operations with Ordered Metadata +// 
NOTE: FP8 ordered metadata requires PTX ISA 8.7+ and sm_90+ +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e4m3_f16 +func.func @nvvm_mma_sp_ordered_m16n8k64_e4m3_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e4m3_f32 +func.func @nvvm_mma_sp_ordered_m16n8k64_e4m3_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} + : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + return +} + +// ============================================================================= +// FP8 (e5m2) Sparse MMA Operations with Ordered Metadata +// NOTE: FP8 ordered metadata requires PTX ISA 8.7+ and sm_90+ +// ============================================================================= + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e5m2_f16 +func.func @nvvm_mma_sp_ordered_m16n8k64_e5m2_f16( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : vector<2xf16>, %c1 : vector<2xf16>, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + {orderedMetadata, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + shape = #nvvm.shape} + : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + return +} + +// CHECK-LABEL: @nvvm_mma_sp_ordered_m16n8k64_e5m2_f32 +func.func @nvvm_mma_sp_ordered_m16n8k64_e5m2_f32( + %a0 : i32, %a1 : i32, + %b0 : i32, %b1 : i32, + %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32, + %meta : i32, %sel : i32) { + // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, orderedMetadata, shape = #nvvm.shape} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)> + %0 = 
+  %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {orderedMetadata,
+     multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+     multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+    : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  return
+}
+
diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir
new file mode 100644
index 0000000000000..e7122aac61baf
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sp.mlir
@@ -0,0 +1,390 @@
+// RUN: mlir-opt %s -split-input-file | FileCheck %s
+
+// This file contains tests for all sparse MMA (mma.sp.sync) operations in the
+// NVVM dialect. Based on the PTX ISA documentation:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-for-sparse-mma
+//
+// Sparse MMA operations follow the 2:4 structured-sparsity pattern, in which
+// 2 out of every 4 elements of the A operand are non-zero. The A operand is
+// provided in compressed form, and sparseMetadata supplies the indices of the
+// non-zero elements.
+//
+// NOTE: These tests use the default (standard) metadata ordering.
+// For ordered metadata tests (PTX ISA 8.5+, sm_90+), see nvvm-mma-sp-ordered.mlir.
+
+// =============================================================================
+// F16 Sparse MMA Operations (m16n8k16)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_f16_f16
+func.func @nvvm_mma_sp_m16n8k16_f16_f16(
+    %a0 : vector<2xf16>, %a1 : vector<2xf16>,
+    %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+    %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+    sparseMetadata[%meta] selector[%sel]
+    {shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+    : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_f16_f32
+func.func @nvvm_mma_sp_m16n8k16_f16_f32(
+    %a0 : vector<2xf16>, %a1 : vector<2xf16>,
+    %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+    %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+    : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  return
+}
+
+// =============================================================================
+// F16 Sparse MMA Operations (m16n8k32)
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_f16_f16
+func.func @nvvm_mma_sp_m16n8k32_f16_f16(
+    %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+    %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>,
+    %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1]
+    sparseMetadata[%meta] selector[%sel]
+    {shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+    : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_f16_f32
+func.func @nvvm_mma_sp_m16n8k32_f16_f32(
+    %a0 : vector<2xf16>, %a1 : vector<2xf16>, %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+    %b0 : vector<2xf16>, %b1 : vector<2xf16>, %b2 : vector<2xf16>, %b3 : vector<2xf16>,
+    %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+    : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  return
+}
+
+// =============================================================================
+// BF16 Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_bf16_f32
+func.func @nvvm_mma_sp_m16n8k16_bf16_f32(
+    %a0 : i32, %a1 : i32,
+    %b0 : i32, %b1 : i32,
+    %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {multiplicandAPtxType = #nvvm.mma_type<bf16>,
+     multiplicandBPtxType = #nvvm.mma_type<bf16>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+    : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_bf16_f32
+func.func @nvvm_mma_sp_m16n8k32_bf16_f32(
+    %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+    %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+    %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<bf16>, multiplicandBPtxType = #nvvm.mma_type<bf16>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {multiplicandAPtxType = #nvvm.mma_type<bf16>,
+     multiplicandBPtxType = #nvvm.mma_type<bf16>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+    : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  return
+}
+
+// =============================================================================
+// TF32 Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k8_tf32_f32
+func.func @nvvm_mma_sp_m16n8k8_tf32_f32(
+    %a0 : i32, %a1 : i32,
+    %b0 : i32, %b1 : i32,
+    %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 8>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {multiplicandAPtxType = #nvvm.mma_type<tf32>,
+     multiplicandBPtxType = #nvvm.mma_type<tf32>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 8>}
+    : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k16_tf32_f32
+func.func @nvvm_mma_sp_m16n8k16_tf32_f32(
+    %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+    %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+    %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {multiplicandAPtxType = #nvvm.mma_type<tf32>,
+     multiplicandBPtxType = #nvvm.mma_type<tf32>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>}
+    : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  return
+}
+
+// =============================================================================
+// Integer (s8) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_s8_s32
+func.func @nvvm_mma_sp_m16n8k32_s8_s32(
+    %a0 : i32, %a1 : i32,
+    %b0 : i32, %b1 : i32,
+    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {multiplicandAPtxType = #nvvm.mma_type<s8>,
+     multiplicandBPtxType = #nvvm.mma_type<s8>,
+     intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+    : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_s8_s32_satfinite
+func.func @nvvm_mma_sp_m16n8k32_s8_s32_satfinite(
+    %a0 : i32, %a1 : i32,
+    %b0 : i32, %b1 : i32,
+    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {multiplicandAPtxType = #nvvm.mma_type<s8>,
+     multiplicandBPtxType = #nvvm.mma_type<s8>,
+     intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+    : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_s8_s32
+func.func @nvvm_mma_sp_m16n8k64_s8_s32(
+    %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+    %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {multiplicandAPtxType = #nvvm.mma_type<s8>,
+     multiplicandBPtxType = #nvvm.mma_type<s8>,
+     intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+    : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  return
+}
+
+// =============================================================================
+// Integer (u8) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k32_u8_s32
+func.func @nvvm_mma_sp_m16n8k32_u8_s32(
+    %a0 : i32, %a1 : i32,
+    %b0 : i32, %b1 : i32,
+    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {multiplicandAPtxType = #nvvm.mma_type<u8>,
+     multiplicandBPtxType = #nvvm.mma_type<u8>,
+     intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 32>}
+    : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_u8_s32
+func.func @nvvm_mma_sp_m16n8k64_u8_s32(
+    %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+    %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {multiplicandAPtxType = #nvvm.mma_type<u8>,
+     multiplicandBPtxType = #nvvm.mma_type<u8>,
+     intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+    : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  return
+}
+
+// =============================================================================
+// Sub-byte Integer (s4) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_s4_s32
+func.func @nvvm_mma_sp_m16n8k64_s4_s32(
+    %a0 : i32, %a1 : i32,
+    %b0 : i32, %b1 : i32,
+    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {multiplicandAPtxType = #nvvm.mma_type<s4>,
+     multiplicandBPtxType = #nvvm.mma_type<s4>,
+     intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+    : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k128_s4_s32
+func.func @nvvm_mma_sp_m16n8k128_s4_s32(
+    %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+    %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {multiplicandAPtxType = #nvvm.mma_type<s4>,
+     multiplicandBPtxType = #nvvm.mma_type<s4>,
+     intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+    : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  return
+}
+
+// =============================================================================
+// Sub-byte Integer (u4) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_u4_s32
+func.func @nvvm_mma_sp_m16n8k64_u4_s32(
+    %a0 : i32, %a1 : i32,
+    %b0 : i32, %b1 : i32,
+    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {multiplicandAPtxType = #nvvm.mma_type<u4>,
+     multiplicandBPtxType = #nvvm.mma_type<u4>,
+     intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+    : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k128_u4_s32
+func.func @nvvm_mma_sp_m16n8k128_u4_s32(
+    %a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
+    %b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32,
+    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}, {{.*}}, {{.*}}] B[{{.*}}, {{.*}}, {{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, multiplicandAPtxType = #nvvm.mma_type<u4>, multiplicandBPtxType = #nvvm.mma_type<u4>, shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1, %b2, %b3] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {multiplicandAPtxType = #nvvm.mma_type<u4>,
+     multiplicandBPtxType = #nvvm.mma_type<u4>,
+     intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 128>}
+    : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
+  return
+}
+
+// =============================================================================
+// FP8 (e4m3) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e4m3_f16
+func.func @nvvm_mma_sp_m16n8k64_e4m3_f16(
+    %a0 : i32, %a1 : i32,
+    %b0 : i32, %b1 : i32,
+    %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+    sparseMetadata[%meta] selector[%sel]
+    {multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+     multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+    : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e4m3_f32
+func.func @nvvm_mma_sp_m16n8k64_e4m3_f32(
+    %a0 : i32, %a1 : i32,
+    %b0 : i32, %b1 : i32,
+    %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e4m3>, multiplicandBPtxType = #nvvm.mma_type<e4m3>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+     multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+    : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  return
+}
+
+// =============================================================================
+// FP8 (e5m2) Sparse MMA Operations
+// =============================================================================
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e5m2_f16
+func.func @nvvm_mma_sp_m16n8k64_e5m2_f16(
+    %a0 : i32, %a1 : i32,
+    %b0 : i32, %b1 : i32,
+    %c0 : vector<2xf16>, %c1 : vector<2xf16>,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+    sparseMetadata[%meta] selector[%sel]
+    {multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+     multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+    : (i32, i32, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  return
+}
+
+// CHECK-LABEL: @nvvm_mma_sp_m16n8k64_e5m2_f32
+func.func @nvvm_mma_sp_m16n8k64_e5m2_f32(
+    %a0 : i32, %a1 : i32,
+    %b0 : i32, %b1 : i32,
+    %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32,
+    %meta : i32, %sel : i32) {
+  // CHECK: nvvm.mma.sp.sync A[{{.*}}, {{.*}}] B[{{.*}}, {{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] {multiplicandAPtxType = #nvvm.mma_type<e5m2>, multiplicandBPtxType = #nvvm.mma_type<e5m2>, shape = #nvvm.shape<m = 16, n = 8, k = 64>} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  %0 = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    sparseMetadata[%meta] selector[%sel]
+    {multiplicandAPtxType = #nvvm.mma_type<e5m2>,
+     multiplicandBPtxType = #nvvm.mma_type<e5m2>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 64>}
+    : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
+  return
+}
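Note (not part of the patch): the test file headers above describe the 2:4 structured sparsity that the sparseMetadata operand encodes, but do not show how such a metadata word is built. The sketch below is only a rough host-side illustration of that compression under stated assumptions; the exact per-thread register layout of the metadata and the role of the selector operand are defined by the PTX ISA chapter linked in the headers, and every name here (Compressed24Row, compress24) is hypothetical rather than an existing MLIR or CUDA API.

    // 2:4 structured-sparsity compression sketch (illustrative only).
    // Assumes the row length is a multiple of 4 and at most 32 elements, so the
    // 2-bit indices of the kept elements fit in a single 32-bit metadata word.
    #include <cstdint>
    #include <vector>

    struct Compressed24Row {
      std::vector<float> values; // two kept values per group of four
      uint32_t metadata = 0;     // 2-bit in-group index per kept value, LSB first
    };

    inline Compressed24Row compress24(const std::vector<float> &row) {
      Compressed24Row out;
      unsigned bit = 0;
      for (size_t g = 0; g + 4 <= row.size(); g += 4) {
        unsigned kept = 0;
        for (unsigned i = 0; i < 4 && kept < 2; ++i) {
          if (row[g + i] != 0.0f) {
            out.values.push_back(row[g + i]);
            out.metadata |= (i & 0x3u) << bit; // record which slot was kept
            bit += 2;
            ++kept;
          }
        }
        // A well-formed 2:4 row has exactly two non-zeros per group of four;
        // pad with zeros (index 0) if a group happens to have fewer.
        for (; kept < 2; ++kept, bit += 2)
          out.values.push_back(0.0f);
      }
      return out;
    }

How these packed words are distributed across the threads of a warp, and which thread's metadata the selector operand picks, is specified in the PTX ISA documentation referenced by the tests.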