diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index a96d65d3fcacd..1faa435fca6f9 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -2499,6 +2499,30 @@ class NVVM_MMA_OPS { bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops, subint_mma_sp_ops, int_mma_sp_ops); + // Block scale MMA operations (dense) + list> mxf4_mma_ops = MMA_OPS< + [GEOM<16,8,64>], + ["e2m1"], ["e2m1"], ["f32"], []>.ret; + list> mxf8f6f4_mma_ops = MMA_OPS< + [GEOM<16,8,32>], + ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"], + ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"], + ["f32"], []>.ret; + list> all_mma_block_scale_ops = !listconcat( + mxf4_mma_ops, mxf8f6f4_mma_ops); + + // Block scale sparse MMA operations + list> mxf4xx_mma_sp_ops = MMA_OPS< + [GEOM<16,8,128>], + ["e2m1"], ["e2m1"], ["f32"], []>.ret; + list> mxf8f6f4_mma_sp_ops = MMA_OPS< + [GEOM<16,8,64>], + ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"], + ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"], + ["f32"], []>.ret; + list> all_mma_sp_block_scale_ops = !listconcat( + mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops); + } def NVVM_MMA_OPS : NVVM_MMA_OPS; @@ -3332,7 +3356,7 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> { The optional `orderedMetadata` attribute specifies the metadata ordering: - Absence (default): Uses standard sparse metadata ordering - Presence: Uses ordered metadata (PTX ISA 8.5+, sm_90+) - + The optional `kind` attribute specifies mixed-precision modes for FP8 operations: - `f8f6f4`: Enables e3m2, e2m3, e2m1 FP8 types and f16 accumulator (PTX ISA 8.7+, sm_90+) - Only valid with ordered metadata and m16n8k64 shape @@ -3347,7 +3371,7 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> { sparseMetadata[%meta] selector[%sel] {shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> - + // With ordered metadata: %d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] sparseMetadata[%meta] selector[%sel] @@ -3416,6 +3440,429 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> { let hasVerifier = 1; } +def ScaleVecSize1X : I32EnumAttrCase<"X1", 0, "x1">; +def ScaleVecSize2X : I32EnumAttrCase<"X2", 1, "x2">; +def ScaleVecSize4X : I32EnumAttrCase<"X4", 2, "x4">; + +def ScaleVecSize : I32EnumAttr< + "ScaleVecSize", + "MMA Scale Vector Sizes", + [ScaleVecSize1X, ScaleVecSize2X, ScaleVecSize4X]> { + let cppNamespace = "::mlir::NVVM"; + let genSpecializedAttr = 0; +} + +def ScaleVecSizeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def UE8M0 : I32EnumAttrCase<"UE8M0", 0, "ue8m0">; +def UE4M3 : I32EnumAttrCase<"UE4M3", 1, "ue4m3">; + +def BlockScaleFormat : I32EnumAttr< + "BlockScaleFormat", + "MMA Block Scale Format", + [UE8M0, UE4M3] +> { + let cppNamespace = "::mlir::NVVM"; + let genSpecializedAttr = 0; +} + +def BlockScaleFormatAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def MMABlockScaleKindMXF8F6F4 : I32EnumAttrCase<"MXF8F6F4", 0, "mxf8f6f4">; +def MMABlockScaleKindMXF4 : I32EnumAttrCase<"MXF4", 1, "mxf4">; +def MMABlockScaleKindMXF4NVF4 : I32EnumAttrCase<"MXF4NVF4", 2, "mxf4nvf4">; + +def MMABlockScaleKind : I32EnumAttr< + "MMABlockScaleKind", + "Block Scale Kind", + [MMABlockScaleKindMXF8F6F4, MMABlockScaleKindMXF4, MMABlockScaleKindMXF4NVF4]> { + let cppNamespace = "::mlir::NVVM"; + let genSpecializedAttr = 0; +} + +def MMABlockScaleKindAttr : EnumAttr 
{ + let assemblyFormat = "`<` $value `>`"; +} + +/// Generate enum value of the mma.block_scale intrinsic. +class MMA_BLOCK_SCALE_NAME { + string signature = MMA_SIGNATURE.ret; + string id = "llvm::Intrinsic::nvvm_mma_block_scale" + # "_" # A.geom + # "_row_col" + # "_" # Kind + # !subst(".", "_", ScaleVecSize) + # signature + # "_" # SType; +} + +/// Generate enum value of the mma.sp.block_scale intrinsic. +class MMA_SP_BLOCK_SCALE_NAME { + string signature = MMA_SIGNATURE.ret; + string id = "llvm::Intrinsic::nvvm_mma_sp_ordered_metadata_block_scale" + # "_" # A.geom + # "_row_col" + # "_" # Kind + # !subst(".", "_", ScaleVecSize) + # signature + # "_" # SType; +} + +// Returns true if this combination is supported for MMA.BLOCK_SCALE ops. +// This references the NVVM_MMA_BLOCK_SCALE_SUPPORTED class from IntrinsicsNVVM.td +class NVVM_MMA_BLOCK_SCALE_SUPPORTED frags, string kind, + string stype, string scale_vec_size> { + string geom = frags[0].geom; + bit ret = !cond( + !and(!eq(geom, "m16n8k64"), + !eq(kind, "mxf4"), + !or(!eq(scale_vec_size, ""), + !eq(scale_vec_size, ".scale_2x")), + !eq(stype, "ue8m0")) : true, + !and(!eq(geom, "m16n8k64"), + !eq(kind, "mxf4nvf4"), + !eq(scale_vec_size, ".scale_2x"), + !eq(stype, "ue8m0")) : true, + !and(!eq(geom, "m16n8k64"), + !eq(kind, "mxf4nvf4"), + !eq(scale_vec_size, ".scale_4x"), + !eq(stype, "ue4m3")) : true, + !and(!eq(geom, "m16n8k32"), + !eq(kind, "mxf8f6f4"), + !or(!eq(scale_vec_size, ""), + !eq(scale_vec_size, ".scale_1x")), + !eq(stype, "ue8m0")) : true, + true: false + ); +} + +// Returns true if this combination is supported for MMA.SP.BLOCK_SCALE ops. +// This references the NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED class from IntrinsicsNVVM.td +class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED frags, string kind, + string stype, string scale_vec_size> { + string geom = frags[0].geom; + bit ret = !cond( + !and(!eq(geom, "m16n8k128"), + !eq(kind, "mxf4"), + !eq(stype, "ue8m0"), + !or(!eq(scale_vec_size, ""), + !eq(scale_vec_size, ".scale_2x"))): true, + !and(!eq(geom, "m16n8k128"), + !eq(kind, "mxf4nvf4"), + !eq(stype, "ue8m0"), + !eq(scale_vec_size, ".scale_2x")): true, + !and(!eq(geom, "m16n8k128"), + !eq(kind, "mxf4nvf4"), + !eq(stype, "ue4m3"), + !eq(scale_vec_size, ".scale_4x")): true, + !and(!eq(geom, "m16n8k64"), + !eq(kind, "mxf8f6f4"), + !eq(stype, "ue8m0"), + !or(!eq(scale_vec_size, ""), + !eq(scale_vec_size, ".scale_1x"))): true, + true: false + ); +} + +/// Helper to create the mapping between the configuration and the mma.block_scale +/// intrinsic enum value. 
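As a reading aid for the two `*_SUPPORTED` predicates above, the same dense-MMA validity rules are written out below as a standalone C++ sketch. This is not part of the patch; the function name and the string encodings of kind/format/scale-vector-size are illustrative only and simply mirror the TableGen `!cond` cases.

```cpp
// Sketch: mirrors NVVM_MMA_BLOCK_SCALE_SUPPORTED above. Not part of the patch;
// the function name and the string encodings are illustrative only.
#include <string>

static bool isDenseBlockScaleComboSupported(const std::string &geom,
                                            const std::string &kind,
                                            const std::string &stype,
                                            const std::string &scaleVecSize) {
  // m16n8k64 with kind=mxf4: default or .scale_2x, ue8m0 scales only.
  if (geom == "m16n8k64" && kind == "mxf4" &&
      (scaleVecSize.empty() || scaleVecSize == ".scale_2x") && stype == "ue8m0")
    return true;
  // m16n8k64 with kind=mxf4nvf4: either .scale_2x/ue8m0 or .scale_4x/ue4m3.
  if (geom == "m16n8k64" && kind == "mxf4nvf4" &&
      ((scaleVecSize == ".scale_2x" && stype == "ue8m0") ||
       (scaleVecSize == ".scale_4x" && stype == "ue4m3")))
    return true;
  // m16n8k32 with kind=mxf8f6f4: default or .scale_1x, ue8m0 scales only.
  if (geom == "m16n8k32" && kind == "mxf8f6f4" &&
      (scaleVecSize.empty() || scaleVecSize == ".scale_1x") && stype == "ue8m0")
    return true;
  return false;
}
```

The sparse predicate follows the same pattern with the larger geometries (m16n8k128 and m16n8k64).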
+class MMA_BLOCK_SCALE_INTR { + list>>> cond0 = + !foreach(op, NVVM_MMA_OPS.all_mma_block_scale_ops, + !foreach(kind, ["mxf4", "mxf4nvf4", "mxf8f6f4"], + !foreach(scale_vec_size, ["", ".scale_1x", ".scale_2x", ".scale_4x"], + !foreach(stype, ["ue8m0", "ue4m3"], + !if(NVVM_MMA_BLOCK_SCALE_SUPPORTED.ret, + "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k + # " && \"" # op[0].ptx_elt_type # "\" == eltypeA" + # " && \"" # op[1].ptx_elt_type # "\" == eltypeB" + # " && \"" # op[2].ptx_elt_type # "\" == eltypeC" + # " && \"" # kind # "\" == stringifyEnum(kind)" + # " && \"" # stype # "\" == stringifyEnum(blockScaleFormat)" + # " && \"" # scale_vec_size # "\" == getScaleVecSizeStr(scaleVecSize))\n" + # " return " # + MMA_BLOCK_SCALE_NAME.id # ";", + "") // if supported + ) // stype + ) // scale_vec_size + ) // kind + ); // all_mma_block_scale_ops + list>> f1 = !foldl([[[""]]], cond0, acc, el, + !listconcat(acc, el)); + list> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el)); + list f3 = !foldl([""], f2, acc, el, !listconcat(acc, el)); + string id = !foldl("", f3, acc, el, acc # "\n" # el); +} + +/// Helper to create the mapping between the configuration and the mma.sp.block_scale +/// intrinsic enum value. +class MMA_SP_BLOCK_SCALE_INTR { + list>>> cond0 = + !foreach(op, NVVM_MMA_OPS.all_mma_sp_block_scale_ops, + !foreach(kind, ["mxf4", "mxf4nvf4", "mxf8f6f4"], + !foreach(scale_vec_size, ["", ".scale_1x", ".scale_2x", ".scale_4x"], + !foreach(stype, ["ue8m0", "ue4m3"], + !if(NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED.ret, + "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k + # " && \"" # op[0].ptx_elt_type # "\" == eltypeA" + # " && \"" # op[1].ptx_elt_type # "\" == eltypeB" + # " && \"" # op[2].ptx_elt_type # "\" == eltypeC" + # " && \"" # kind # "\" == stringifyEnum(kind)" + # " && \"" # stype # "\" == stringifyEnum(blockScaleFormat)" + # " && \"" # scale_vec_size # "\" == getScaleVecSizeStr(scaleVecSize))\n" + # " return " # + MMA_SP_BLOCK_SCALE_NAME.id # ";", + "") // if supported + ) // stype + ) // scale_vec_size + ) // kind + ); // all_mma_sp_block_scale_ops + list>> f1 = !foldl([[[""]]], cond0, acc, el, + !listconcat(acc, el)); + list> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el)); + list f3 = !foldl([""], f2, acc, el, !listconcat(acc, el)); + string id = !foldl("", f3, acc, el, acc # "\n" # el); +} + +// Common base class for MMA block scale operations (dense and sparse) +class NVVM_MmaBlockScaleBase traits = []> : + NVVM_Op { + + let results = (outs LLVM_AnyStruct:$res); + + // Common attributes shared by both dense and sparse variants + dag commonArguments = (ins + NVVM_MMAShapeAttr:$shape, + OptionalAttr:$multiplicandAPtxType, + OptionalAttr:$multiplicandBPtxType, + ScaleVecSizeAttr:$scaleVecSize, + BlockScaleFormatAttr:$blockScaleFormat, + MMABlockScaleKindAttr:$kind); + + // Common variadic operands for A, B, C matrices + dag commonVariadicOperands = (ins + Variadic:$operandA, + Variadic:$operandB, + Variadic:$operandC); + + // Common scale operands for both A and B + dag commonScaleOperands = (ins + I32:$scaleAData, + I16:$byteIdA, + I16:$threadIdA, + I32:$scaleBData, + I16:$byteIdB, + I16:$threadIdB); + + let extraClassDeclaration = !strconcat([{ + static llvm::Intrinsic::ID getIntrinsicID( + int64_t m, int64_t n, uint64_t k, + mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum, + mlir::NVVM::MMATypes eltypeCEnum, + mlir::NVVM::ScaleVecSize scaleVecSize, + mlir::NVVM::BlockScaleFormat blockScaleFormat, + 
mlir::NVVM::MMABlockScaleKind kind) { + llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum); + llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum); + llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum); + + auto getScaleVecSizeStr = [](ScaleVecSize svs) -> std::string { + switch (svs) { + case ScaleVecSize::X1: return ".scale_1x"; + case ScaleVecSize::X2: return ".scale_2x"; + case ScaleVecSize::X4: return ".scale_4x"; + } + return ""; + }; + }], + MMA_BLOCK_SCALE_INTR<>.id, [{ + return 0; + } + + // Common declarations - implementations in NVVMDialect.cpp + MMATypes accumPtxType(); + MMATypes resultPtxType(); + + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase& builder); + }]); + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def NVVM_MmaBlockScaleOp : NVVM_MmaBlockScaleBase<"mma.block_scale"> { + + let summary = "cooperative matrix-multiply and accumulate with block scaling"; + + let description = [{ + The `nvvm.mma.block_scale` operation collectively performs the operation + `D = matmul(A * SF_A, B * SF_B) + C` using all threads in a warp. + + A, B, C and D are dense matrices and SF_A and SF_B are scaling factors. + Dimensions of SF_A and SF_B are based on scale vector sizes (x1, x2, x4), + and the data type must be either ue8m0 or ue4m3. + + All the threads in the warp must execute the same `mma.block_scale` operation. + + This operation follows the same design pattern as `nvvm.mma.sync`, with additional + scaling operands for both A and B matrices. + + Example: + ```mlir + %d = nvvm.mma.block_scale A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xf16>, vector<2xf16>, vector<2xf32>) -> !llvm.struct<(f32, f32)> + ``` + }]; + + // Combine common attributes and operands + let arguments = !con(commonArguments, commonVariadicOperands, commonScaleOperands); + + let builders = [ + OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA, + "ValueRange":$operandB, "ValueRange":$operandC, + "Value":$scaleAData, "Value":$byteIdA, "Value":$threadIdA, + "Value":$scaleBData, "Value":$byteIdB, "Value":$threadIdB, + "ArrayRef":$shape, + "std::optional>":$multiplicandPtxTypes, + "ScaleVecSize":$scaleVecSize, + "BlockScaleFormat":$blockScaleFormat, + "MMABlockScaleKind":$kind)> + ]; + + string llvmBuilder = [{ + auto [id, args] = NVVM::MmaBlockScaleOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + $res = createIntrinsicCall(builder, id, args); + }]; +} + +def NVVM_MmaSpBlockScaleOp : NVVM_MmaBlockScaleBase<"mma.sp.block_scale"> { + + let summary = "cooperative sparse matrix-multiply and accumulate with block scaling"; + + let description = [{ + The `nvvm.mma.sp.block_scale` operation collectively performs the operation + `D = matmul(A_sparse * SF_A, B * SF_B) + C` using all threads in a warp. + + A is a sparse matrix, and B, C and D are dense matrices. + SF_A and SF_B are scaling factors. + Dimensions of SF_A and SF_B are based on scale vector sizes (x1, x2, x4), + and the data type must be either ue8m0 or ue4m3. + + This operation is similar to `nvvm.mma.block_scale` but with structured sparsity + in the A operand. 
The sparsity follows the 2:4 structured sparse pattern + where 2 out of every 4 elements are non-zero. + + All the threads in the warp must execute the same `mma.sp.block_scale` operation. + + The `sparseMetadata` operand provides the sparsity indices that indicate + which elements in the A operand are non-zero. The `sparsitySelector` + controls how the indices are distributed among threads in the warp and + should typically be 0 or 1. + + This operation follows the same design pattern as `nvvm.mma.sp.sync`, with additional + scaling operands for both A and B matrices. Note that sparse block scale operations + always use ordered metadata (sm_90+). + + Example: + ```mlir + %d = nvvm.mma.sp.block_scale A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<2xf16>, vector<2xf16>, vector<2xf32>) -> !llvm.struct<(f32, f32)> + ``` + }]; + + // Sparse-specific attributes and operands + dag sparseSpecificArguments = (ins + UnitAttr:$orderedMetadata); + + dag sparseSpecificOperands = (ins + I32:$sparseMetadata, + I32:$sparsitySelector); + + // Combine common and sparse-specific attributes and operands + let arguments = !con(commonArguments, sparseSpecificArguments, + commonVariadicOperands, sparseSpecificOperands, + commonScaleOperands); + + // Override extraClassDeclaration to use sparse intrinsics + let extraClassDeclaration = !strconcat([{ + static llvm::Intrinsic::ID getIntrinsicID( + int64_t m, int64_t n, uint64_t k, + mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum, + mlir::NVVM::MMATypes eltypeCEnum, + mlir::NVVM::ScaleVecSize scaleVecSize, + mlir::NVVM::BlockScaleFormat blockScaleFormat, + mlir::NVVM::MMABlockScaleKind kind) { + llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum); + llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum); + llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum); + + auto getScaleVecSizeStr = [](ScaleVecSize svs) -> std::string { + switch (svs) { + case ScaleVecSize::X1: return ".scale_1x"; + case ScaleVecSize::X2: return ".scale_2x"; + case ScaleVecSize::X4: return ".scale_4x"; + } + return ""; + }; + }], + MMA_SP_BLOCK_SCALE_INTR<>.id, [{ + return 0; + } + + MMATypes accumPtxType(); + MMATypes resultPtxType(); + + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase& builder); + }]); + + let builders = [ + OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA, + "ValueRange":$operandB, "ValueRange":$operandC, + "Value":$sparseMetadata, "Value":$sparsitySelector, + "Value":$scaleAData, "Value":$byteIdA, "Value":$threadIdA, + "Value":$scaleBData, "Value":$byteIdB, "Value":$threadIdB, + "ArrayRef":$shape, + "std::optional>":$multiplicandPtxTypes, + "ScaleVecSize":$scaleVecSize, + "BlockScaleFormat":$blockScaleFormat, + "MMABlockScaleKind":$kind)> + ]; + + string llvmBuilder = [{ + auto [id, args] = NVVM::MmaSpBlockScaleOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + $res = createIntrinsicCall(builder, id, args); + }]; +} + //===----------------------------------------------------------------------===// // NVVM TMA Ops //===----------------------------------------------------------------------===// diff --git 
a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index ada4223ac12de..3387c69025e20 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -693,8 +693,8 @@ void MmaOp::print(OpAsmPrinter &p) { regTypes.push_back(this->getOperand(operandIdx).getType()); } } - std::optional inferredType = - inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2); + std::optional inferredType = MmaOp::inferOperandMMAType( + regTypes.back(), /*isAccumulator=*/fragIdx >= 2); if (inferredType) ignoreAttrNames.push_back(frag.ptxTypeAttr); } @@ -1559,6 +1559,650 @@ LogicalResult MmaSpOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// MMA Block Scale Operations - Shared Helpers +//===----------------------------------------------------------------------===// + +namespace { +// Shared structure for MMA operand fragments (A, B, C) +struct OperandFragment { + StringRef operandName; + StringRef ptxTypeAttr; + SmallVector regs; + explicit OperandFragment(StringRef name, StringRef ptxTypeName) + : operandName(name), ptxTypeAttr(ptxTypeName) {} +}; + +// Helper to print operand list in the format: name[operands] +void printOperandList(OpAsmPrinter &p, StringRef name, + ArrayRef operands) { + p << " " << name << "["; + p.printOperands(operands); + p << "]"; +} + +// Helper to parse operand list in the format: name[operands] +LogicalResult +parseMmaOperand(OpAsmParser &parser, StringRef operandName, + SmallVectorImpl ®s) { + if (parser.parseKeyword(operandName).failed()) + return failure(); + if (parser.parseOperandList(regs, OpAsmParser::Delimiter::OptionalSquare) + .failed()) + return failure(); + return success(); +} + +// Helper to process operand fragments and determine which attributes can be +// inferred +template +void processOperandFragments(Op &op, std::array &frags, + SmallVectorImpl ®Types, + SmallVectorImpl &ignoreAttrNames) { + for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) { + auto &frag = frags[fragIdx]; + auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx); + for (auto operandIdx = varOperandSpec.first; + operandIdx < varOperandSpec.first + varOperandSpec.second; + operandIdx++) { + frag.regs.push_back(op.getOperand(operandIdx)); + if (fragIdx == 0 && operandIdx == varOperandSpec.first) { + regTypes.push_back(op.getOperand(operandIdx).getType()); + } + } + if (fragIdx < 2) { + regTypes.push_back(frag.regs[0].getType()); + } + std::optional inferredType = + MmaOp::inferOperandMMAType(regTypes.back(), + /*isAccumulator=*/fragIdx >= 2); + if (inferredType) + ignoreAttrNames.push_back(frag.ptxTypeAttr); + } +} + +// Helper to parse type signature: (A_type, B_type, C_type) +LogicalResult parseMmaTypeSignature(OpAsmParser &parser, + SmallVectorImpl &operandTypes) { + if (parser.parseColon().failed() || parser.parseLParen().failed()) + return failure(); + + for (int i = 0; i < 3; i++) { + if (i > 0 && parser.parseComma().failed()) + return failure(); + Type ty; + if (parser.parseType(ty).failed()) + return failure(); + operandTypes.push_back(ty); + } + + return parser.parseRParen(); +} + +// Helper to infer and set multiplicand PTX type attributes +void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs, + const SmallVectorImpl &operandTypes) { + if (!attrs.get("multiplicandAPtxType")) { + if (auto inferredType = + MmaOp::inferOperandMMAType(operandTypes[0], false)) { + 
attrs.set("multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType)); + } + } + if (!attrs.get("multiplicandBPtxType")) { + if (auto inferredType = + MmaOp::inferOperandMMAType(operandTypes[1], false)) { + attrs.set("multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType)); + } + } +} + +// Helper to add common block scale attributes +void addBlockScaleAttributes(OpBuilder &builder, OperationState &result, + ArrayRef shape, ScaleVecSize scaleVecSize, + BlockScaleFormat blockScaleFormat, + MMABlockScaleKind kind) { + MLIRContext *ctx = builder.getContext(); + result.addAttribute( + "shape", builder.getAttr(shape[0], shape[1], shape[2])); + result.addAttribute("scaleVecSize", ScaleVecSizeAttr::get(ctx, scaleVecSize)); + result.addAttribute("blockScaleFormat", + BlockScaleFormatAttr::get(ctx, blockScaleFormat)); + result.addAttribute("kind", MMABlockScaleKindAttr::get(ctx, kind)); +} + +// Helper to infer and add multiplicand PTX types to builder +void addInferredMultiplicandTypes( + MLIRContext *ctx, OperationState &result, ValueRange operandA, + ValueRange operandB, + std::optional> multiplicandPtxTypes) { + if (multiplicandPtxTypes) { + result.addAttribute("multiplicandAPtxType", + MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0])); + result.addAttribute("multiplicandBPtxType", + MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1])); + } else { + if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false)) + result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res)); + if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false)) + result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res)); + } +} + +// Template helper for common accumPtxType/resultPtxType implementation +template +MMATypes inferPtxTypeFromResult(OpTy op) { + return *MmaOp::inferOperandMMAType( + cast(op.getRes().getType()).getBody()[0], + /*isAccumulator=*/true); +} +} // namespace + +//===----------------------------------------------------------------------===// +// MmaBlockScaleOp +//===----------------------------------------------------------------------===// + +void MmaBlockScaleOp::print(OpAsmPrinter &p) { + SmallVector regTypes; + std::array frags{ + OperandFragment("A", getMultiplicandAPtxTypeAttrName()), + OperandFragment("B", getMultiplicandBPtxTypeAttrName()), + OperandFragment("C", "")}; + SmallVector ignoreAttrNames{ + mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()}; + + processOperandFragments(*this, frags, regTypes, ignoreAttrNames); + + // Print A, B, C operands + for (const auto &frag : frags) + printOperandList(p, frag.operandName, frag.regs); + + // Print scale operands + printOperandList(p, "scaleA", + {getScaleAData(), getByteIdA(), getThreadIdA()}); + printOperandList(p, "scaleB", + {getScaleBData(), getByteIdB(), getThreadIdB()}); + + p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames); + + // Print type signature + p << " : ("; + llvm::interleaveComma(SmallVector{frags[0].regs[0].getType(), + frags[1].regs[0].getType(), + frags[2].regs[0].getType()}, + p); + p << ")"; + p.printArrowTypeList(TypeRange{this->getRes().getType()}); +} + +ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser, + OperationState &result) { + struct LocalOperandFragment { + std::optional elemtype; + SmallVector regs; + }; + + Builder &builder = parser.getBuilder(); + std::array frags; + NamedAttrList namedAttributes; + + // Parse A[...] B[...] C[...] 
+ if (parseMmaOperand(parser, "A", frags[0].regs).failed() || + parseMmaOperand(parser, "B", frags[1].regs).failed() || + parseMmaOperand(parser, "C", frags[2].regs).failed()) + return failure(); + + // Parse scale operands: scaleA[...] scaleB[...] + SmallVector scaleAOperands, scaleBOperands; + if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() || + parseMmaOperand(parser, "scaleB", scaleBOperands).failed()) + return failure(); + + if (parser.parseOptionalAttrDict(namedAttributes).failed()) + return failure(); + + // Parse type signature + SmallVector operandTypes; + if (parseMmaTypeSignature(parser, operandTypes).failed()) + return failure(); + + // Parse result type + SmallVector resultTypes; + if (parser.parseArrowTypeList(resultTypes).failed()) + return failure(); + + // Infer element types and resolve operands + for (const auto &[idx, frag] : llvm::enumerate(frags)) { + frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx], + /*isAccumulator=*/idx >= 2); + if (parser + .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(), + result.operands) + .failed()) + return failure(); + } + + // Resolve scale operands + SmallVector scaleTypes = {builder.getI32Type(), builder.getI16Type(), + builder.getI16Type()}; + if (parser + .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(), + result.operands) + .failed() || + parser + .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(), + result.operands) + .failed()) + return failure(); + + // Add attributes + result.addAttributes(namedAttributes); + inferAndSetMultiplicandTypes(parser.getContext(), result.attributes, + operandTypes); + + result.addTypes(resultTypes); + result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({ + static_cast(frags[0].regs.size()), + static_cast(frags[1].regs.size()), + static_cast(frags[2].regs.size()), + 1, // scaleAData + 1, // byteIdA + 1, // threadIdA + 1, // scaleBData + 1, // byteIdB + 1 // threadIdB + })); + return success(); +} + +void MmaBlockScaleOp::build( + OpBuilder &builder, OperationState &result, Type resultType, + ValueRange operandA, ValueRange operandB, ValueRange operandC, + Value scaleAData, Value byteIdA, Value threadIdA, Value scaleBData, + Value byteIdB, Value threadIdB, ArrayRef shape, + std::optional> multiplicandPtxTypes, + ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat, + MMABlockScaleKind kind) { + assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)"); + + addBlockScaleAttributes(builder, result, shape, scaleVecSize, + blockScaleFormat, kind); + + result.addOperands(operandA); + result.addOperands(operandB); + result.addOperands(operandC); + result.addOperands( + {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB}); + + addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB, + multiplicandPtxTypes); + + result.addTypes(resultType); + result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({ + static_cast(operandA.size()), + static_cast(operandB.size()), + static_cast(operandC.size()), + 1, // scaleAData + 1, // byteIdA + 1, // threadIdA + 1, // scaleBData + 1, // byteIdB + 1 // threadIdB + })); +} + +MMATypes MmaBlockScaleOp::accumPtxType() { + return inferPtxTypeFromResult(*this); +} + +MMATypes MmaBlockScaleOp::resultPtxType() { + return inferPtxTypeFromResult(*this); +} + +NVVM::IDArgPair MmaBlockScaleOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, 
llvm::IRBuilderBase &builder) { + auto curOp = cast(op); + + SmallVector args; + // Add A, B, C operands + for (Value operand : curOp.getOperandA()) + args.push_back(mt.lookupValue(operand)); + for (Value operand : curOp.getOperandB()) + args.push_back(mt.lookupValue(operand)); + for (Value operand : curOp.getOperandC()) + args.push_back(mt.lookupValue(operand)); + + // Add scale operands + args.push_back(mt.lookupValue(curOp.getScaleAData())); + args.push_back(mt.lookupValue(curOp.getByteIdA())); + args.push_back(mt.lookupValue(curOp.getThreadIdA())); + args.push_back(mt.lookupValue(curOp.getScaleBData())); + args.push_back(mt.lookupValue(curOp.getByteIdB())); + args.push_back(mt.lookupValue(curOp.getThreadIdB())); + + unsigned intId = MmaBlockScaleOp::getIntrinsicID( + curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(), + *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(), + curOp.accumPtxType(), curOp.getScaleVecSize(), + curOp.getBlockScaleFormat(), curOp.getKind()); + + return {intId, args}; +} + +LogicalResult MmaBlockScaleOp::verify() { + LogicalResult result = success(); + int m = getShape().getM(); + int n = getShape().getN(); + int k = getShape().getK(); + + if (m == 16 && n == 8 && k == 64) { + if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 || + getMultiplicandBPtxType() != NVVM::MMATypes::e2m1) + result = emitOpError( + "unsupported MMATypes attribute for mma.m16n8k64.(mxf4nvf4|mxf4)"); + if (getKind() == NVVM::MMABlockScaleKind::MXF4) { + if (getScaleVecSize() != NVVM::ScaleVecSize::X2) + result = emitOpError( + "unsupported ScaleVecSize attribute for mma.m16n8k64.mxf4"); + if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0) + result = emitOpError( + "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4"); + } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) { + if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 && + getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) || + (getScaleVecSize() == NVVM::ScaleVecSize::X4 && + getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3))) + result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat " + "attributes for mma.m16n8k64.mxf4nvf4"); + } else + result = emitOpError("unsupported Kind attribute for mma.m16n8k64"); + } else if (m == 16 && n == 8 && k == 32) { + if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 && + getScaleVecSize() == NVVM::ScaleVecSize::X1 && + getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0)) + result = + emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat " + "attributes for mma.m16n8k32"); + } else + result = emitOpError("unsupported Geom for mma with block scaling"); + return result; +} + +//===----------------------------------------------------------------------===// +// MmaSpBlockScaleOp +//===----------------------------------------------------------------------===// + +void MmaSpBlockScaleOp::print(OpAsmPrinter &p) { + SmallVector regTypes; + std::array frags{ + OperandFragment("A", getMultiplicandAPtxTypeAttrName()), + OperandFragment("B", getMultiplicandBPtxTypeAttrName()), + OperandFragment("C", "")}; + SmallVector ignoreAttrNames{ + mlir::NVVM::MmaSpBlockScaleOp::getOperandSegmentSizeAttr()}; + + processOperandFragments(*this, frags, regTypes, ignoreAttrNames); + + // Print A, B, C operands + for (const auto &frag : frags) + printOperandList(p, frag.operandName, frag.regs); + + // Print sparse-specific operands + printOperandList(p, "sparseMetadata", {getSparseMetadata()}); + printOperandList(p, 
"selector", {getSparsitySelector()}); + + // Print scale operands + printOperandList(p, "scaleA", + {getScaleAData(), getByteIdA(), getThreadIdA()}); + printOperandList(p, "scaleB", + {getScaleBData(), getByteIdB(), getThreadIdB()}); + + p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames); + + // Print type signature + p << " : ("; + llvm::interleaveComma(SmallVector{frags[0].regs[0].getType(), + frags[1].regs[0].getType(), + frags[2].regs[0].getType()}, + p); + p << ")"; + p.printArrowTypeList(TypeRange{this->getRes().getType()}); +} + +ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser, + OperationState &result) { + struct LocalOperandFragment { + std::optional elemtype; + SmallVector regs; + }; + + Builder &builder = parser.getBuilder(); + std::array frags; + NamedAttrList namedAttributes; + + // Parse A[...] B[...] C[...] + if (parseMmaOperand(parser, "A", frags[0].regs).failed() || + parseMmaOperand(parser, "B", frags[1].regs).failed() || + parseMmaOperand(parser, "C", frags[2].regs).failed()) + return failure(); + + // Parse sparse-specific operands + SmallVector metadataOperands, + selectorOperands; + if (parseMmaOperand(parser, "sparseMetadata", metadataOperands).failed() || + parseMmaOperand(parser, "selector", selectorOperands).failed()) + return failure(); + + // Parse scale operands + SmallVector scaleAOperands, scaleBOperands; + if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() || + parseMmaOperand(parser, "scaleB", scaleBOperands).failed()) + return failure(); + + if (parser.parseOptionalAttrDict(namedAttributes).failed()) + return failure(); + + // Parse type signature + SmallVector operandTypes; + if (parseMmaTypeSignature(parser, operandTypes).failed()) + return failure(); + + // Parse result type + SmallVector resultTypes; + if (parser.parseArrowTypeList(resultTypes).failed()) + return failure(); + + // Infer element types and resolve operands + for (const auto &[idx, frag] : llvm::enumerate(frags)) { + frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx], + /*isAccumulator=*/idx >= 2); + if (parser + .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(), + result.operands) + .failed()) + return failure(); + } + + // Resolve sparse metadata and selector + Type i32Type = builder.getI32Type(); + if (parser + .resolveOperands(metadataOperands, i32Type, parser.getNameLoc(), + result.operands) + .failed() || + parser + .resolveOperands(selectorOperands, i32Type, parser.getNameLoc(), + result.operands) + .failed()) + return failure(); + + // Resolve scale operands + SmallVector scaleTypes = {i32Type, builder.getI16Type(), + builder.getI16Type()}; + if (parser + .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(), + result.operands) + .failed() || + parser + .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(), + result.operands) + .failed()) + return failure(); + + // Add attributes + result.addAttributes(namedAttributes); + inferAndSetMultiplicandTypes(parser.getContext(), result.attributes, + operandTypes); + + // orderedMetadata is mandatory + if (!result.attributes.get("orderedMetadata")) + result.addAttribute("orderedMetadata", builder.getUnitAttr()); + + result.addTypes(resultTypes); + result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({ + static_cast(frags[0].regs.size()), + static_cast(frags[1].regs.size()), + static_cast(frags[2].regs.size()), + 1, // sparseMetadata + 1, // sparsitySelector + 1, // scaleAData + 1, // byteIdA + 1, // 
threadIdA + 1, // scaleBData + 1, // byteIdB + 1 // threadIdB + })); + return success(); +} + +void MmaSpBlockScaleOp::build( + OpBuilder &builder, OperationState &result, Type resultType, + ValueRange operandA, ValueRange operandB, ValueRange operandC, + Value sparseMetadata, Value sparsitySelector, Value scaleAData, + Value byteIdA, Value threadIdA, Value scaleBData, Value byteIdB, + Value threadIdB, ArrayRef shape, + std::optional> multiplicandPtxTypes, + ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat, + MMABlockScaleKind kind) { + assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)"); + + addBlockScaleAttributes(builder, result, shape, scaleVecSize, + blockScaleFormat, kind); + result.addAttribute("orderedMetadata", builder.getUnitAttr()); + + result.addOperands(operandA); + result.addOperands(operandB); + result.addOperands(operandC); + result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA, + threadIdA, scaleBData, byteIdB, threadIdB}); + + addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB, + multiplicandPtxTypes); + + result.addTypes(resultType); + result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({ + static_cast(operandA.size()), + static_cast(operandB.size()), + static_cast(operandC.size()), + 1, // sparseMetadata + 1, // sparsitySelector + 1, // scaleAData + 1, // byteIdA + 1, // threadIdA + 1, // scaleBData + 1, // byteIdB + 1 // threadIdB + })); +} + +MMATypes MmaSpBlockScaleOp::accumPtxType() { + return inferPtxTypeFromResult(*this); +} + +MMATypes MmaSpBlockScaleOp::resultPtxType() { + return inferPtxTypeFromResult(*this); +} + +NVVM::IDArgPair MmaSpBlockScaleOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast(op); + + SmallVector args; + // Add A, B, C operands + for (Value operand : curOp.getOperandA()) + args.push_back(mt.lookupValue(operand)); + for (Value operand : curOp.getOperandB()) + args.push_back(mt.lookupValue(operand)); + for (Value operand : curOp.getOperandC()) + args.push_back(mt.lookupValue(operand)); + + // Add sparse metadata and selector + args.push_back(mt.lookupValue(curOp.getSparseMetadata())); + args.push_back(mt.lookupValue(curOp.getSparsitySelector())); + + // Add scale operands + args.push_back(mt.lookupValue(curOp.getScaleAData())); + args.push_back(mt.lookupValue(curOp.getByteIdA())); + args.push_back(mt.lookupValue(curOp.getThreadIdA())); + args.push_back(mt.lookupValue(curOp.getScaleBData())); + args.push_back(mt.lookupValue(curOp.getByteIdB())); + args.push_back(mt.lookupValue(curOp.getThreadIdB())); + + unsigned intId = MmaSpBlockScaleOp::getIntrinsicID( + curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(), + *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(), + curOp.accumPtxType(), curOp.getScaleVecSize(), + curOp.getBlockScaleFormat(), curOp.getKind()); + + return {intId, args}; +} + +LogicalResult MmaSpBlockScaleOp::verify() { + // Check that orderedMetadata is present + if (!getOrderedMetadata()) { + return emitOpError("'orderedMetadata' attribute is mandatory"); + } + + LogicalResult result = success(); + int m = getShape().getM(); + int n = getShape().getN(); + int k = getShape().getK(); + + if (m == 16 && n == 8 && k == 128) { + if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 || + getMultiplicandBPtxType() != NVVM::MMATypes::e2m1) + result = emitOpError( + "unsupported MMATypes attribute for 
mma.m16n8k128.(mxf4nvf4|mxf4)"); + if (getKind() == NVVM::MMABlockScaleKind::MXF4) { + if (getScaleVecSize() != NVVM::ScaleVecSize::X2) + result = emitOpError( + "unsupported ScaleVecSize attribute for mma.m16n8k128.mxf4"); + if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0) + result = emitOpError( + "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4"); + } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) { + if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 && + getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) || + (getScaleVecSize() == NVVM::ScaleVecSize::X4 && + getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3))) + result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat " + "attributes for mma.m16n8k128.mxf4nvf4"); + } else + result = emitOpError("unsupported Kind attribute for mma.m16n8k128"); + } else if (m == 16 && n == 8 && k == 64) { + if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 && + getScaleVecSize() == NVVM::ScaleVecSize::X1 && + getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0)) + result = + emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat " + "attributes for mma.m16n8k64"); + } else + result = emitOpError("unsupported Geom for sparse mma with block scaling"); + return result; +} + LogicalResult ShflOp::verify() { auto returnStructType = llvm::dyn_cast(getType()); diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir new file mode 100644 index 0000000000000..fbd0203d19904 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir @@ -0,0 +1,525 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// This file contains tests for all dense MMA block scale operations in the NVVM dialect +// Based on PTX ISA documentation: +// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-with-block-scaling +// +// MMA block scale operations perform matrix multiply-accumulate with block scaling: +// D = matmul(A * SF_A, B * SF_B) + C +// where SF_A and SF_B are scaling factors with dimensions based on scale vector size. 
+ +// ============================================================================= +// MXF8F6F4 Block Scale MMA Operations (m16n8k32) - All Type Combinations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m1 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e3m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e4m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} 
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e5m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m1 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e3m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: 
@nvvm_mxf8f6f4_blockscale_mma_e2m3_e4m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e5m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m1 +func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e3m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: 
vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e4m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e5m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m1 +func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: 
nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e3m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e4m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e5m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m1 +func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] 
B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e3m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e4m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e5m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType 
= #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// ============================================================================= +// MXF4 Block Scale MMA Operations (m16n8k64) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mxf4_blockscale_mma +func.func @nvvm_mxf4_blockscale_mma(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// ============================================================================= +// MXF4NVF4 Block Scale MMA Operations (m16n8k64) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mxf4nvf4_blockscale_mma_ue8m0 +func.func @nvvm_mxf4nvf4_blockscale_mma_ue8m0(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf4nvf4_blockscale_mma_ue4m3 +func.func @nvvm_mxf4nvf4_blockscale_mma_ue4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir new file mode 100644 index 0000000000000..2e72012bcf722 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir @@ -0,0 +1,637 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// 
This file contains tests for all sparse MMA block scale operations in the NVVM dialect +// Based on PTX ISA documentation: +// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-with-block-scaling +// +// Sparse MMA block scale operations perform matrix multiply-accumulate with block scaling +// on sparse matrices: D = matmul(A * SF_A, B * SF_B) + C +// where A follows 2:4 structured sparsity and SF_A, SF_B are scaling factors. + +// ============================================================================= +// MXF8F6F4 Sparse Block Scale MMA Operations (m16n8k64) - All Type Combinations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m1 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e3m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + 
multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e4m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e5m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m1 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + 
%sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e3m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e4m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e5m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + 
scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m1 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e3m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: 
@nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e4m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e5m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m1 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] 
scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e3m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e4m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e5m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = 
#nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m1 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e3m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e4m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, 
%threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e5m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// ============================================================================= +// MXF4 Sparse Block Scale MMA Operations (m16n8k128) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mxf4_sp_blockscale_mma +func.func @nvvm_mxf4_sp_blockscale_mma(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// ============================================================================= +// MXF4NVF4 Sparse Block Scale MMA Operations (m16n8k128) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mxf4nvf4_sp_blockscale_mma_ue8m0 +func.func @nvvm_mxf4nvf4_sp_blockscale_mma_ue8m0(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: 
nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf4nvf4_sp_blockscale_mma_ue4m3 +func.func @nvvm_mxf4nvf4_sp_blockscale_mma_ue4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +}
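
For reference, a minimal sketch of one fully-spelled-out sparse block-scale op of the form exercised above. The concrete attribute parameters here (m16n8k64 shape, e4m3 inputs, x1 scale vectors, ue8m0 scale format, mxf8f6f4 kind) are illustrative assumptions chosen to match the test naming, not a normative statement of which combinations the op verifier accepts:

  // Hypothetical fully-specified example; attribute values are assumptions.
  %d = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c]
         sparseMetadata[%meta] selector[%sel]
         scaleA[%scaleAData, %byteIdA, %threadIdA]
         scaleB[%scaleBData, %byteIdB, %threadIdB]
         {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
          multiplicandAPtxType = #nvvm.mma_type<e4m3>,
          multiplicandBPtxType = #nvvm.mma_type<e4m3>,
          scaleVecSize = #nvvm.scale_vec_size<x1>,
          blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
          kind = #nvvm.block_scale_kind<mxf8f6f4>,
          orderedMetadata}
         : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)>

As the operand names in the tests suggest, each scaleA/scaleB bracket carries a triple of (packed scale-factor data, byte-id, thread-id), mirroring the scale-factor selection operands of the corresponding PTX block-scaling instructions.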