Conversation

@schwarzschild-radius
Contributor

This commit adds support for the tcgen05.mma family of instructions in the NVVM MLIR dialect and lowers them to LLVM intrinsics. Please refer to the PTX ISA for more information.

@llvmbot
Member

llvmbot commented Oct 21, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Pradeep Kumar (schwarzschild-radius)

Changes

This commit adds support for the tcgen05.mma family of instructions in the NVVM MLIR dialect and lowers them to LLVM intrinsics. Please refer to the PTX ISA for more information.


Patch is 472.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/164356.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+639)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+546)
  • (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir (+229)
  • (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir (+229)
  • (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir (+119)
  • (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-shared.mlir (+466)
  • (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir (+229)
  • (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir (+229)
  • (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-shared.mlir (+442)
  • (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-tensor.mlir (+634)
  • (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-tensor.mlir (+634)
  • (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-shared.mlir (+133)
  • (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-shared.mlir (+133)
  • (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-sp-tensor.mlir (+133)
  • (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-ws-tensor.mlir (+133)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index d959464836043..a580a7f42bccc 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -4537,6 +4537,645 @@ def NVVM_ClusterLaunchControlQueryCancelOp
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma attributes
+//===----------------------------------------------------------------------===//
+
+def Tcgen05MMAKindF16          : I32EnumAttrCase<"F16",    0, "f16">;
+def Tcgen05MMAKindTF32         : I32EnumAttrCase<"TF32",   1, "tf32">;
+def Tcgen05MMAKindF8F6F4       : I32EnumAttrCase<"F8F6F4", 2, "f8f6f4">;
+def Tcgen05MMAKindINT8         : I32EnumAttrCase<"I8",     3, "i8">;
+
+def Tcgen05MMAKind : I32EnumAttr<
+  "Tcgen05MMAKind",
+  "tcgen05 MMA Supported Types",
+  [Tcgen05MMAKindF8F6F4, Tcgen05MMAKindINT8, Tcgen05MMAKindF16,
+   Tcgen05MMAKindTF32]> {
+    let cppNamespace = "::mlir::NVVM";
+    let genSpecializedAttr = 0;
+}
+
+def Tcgen05MMAKindAttr : EnumAttr<NVVM_Dialect, Tcgen05MMAKind, "tcgen05_mma_kind"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def Tcgen05MMACollectorOpDiscard  : I32EnumAttrCase<"DISCARD", 0, "discard">;
+def Tcgen05MMACollectorOpLastUse  : I32EnumAttrCase<"LASTUSE", 1, "lastuse">;
+def Tcgen05MMACollectorOpFill     : I32EnumAttrCase<"FILL",    2, "fill">;
+def Tcgen05MMACollectorOpUse      : I32EnumAttrCase<"USE",     3, "use">;
+
+def Tcgen05MMACollectorOp : I32EnumAttr<
+  "Tcgen05MMACollectorOp",
+  "tcgen05.mma Collector Buffer Operation",
+  [Tcgen05MMACollectorOpDiscard,
+   Tcgen05MMACollectorOpLastUse,
+   Tcgen05MMACollectorOpFill,
+   Tcgen05MMACollectorOpUse]> {
+    let cppNamespace = "::mlir::NVVM";
+    let genSpecializedAttr = 0;
+}
+
+def Tcgen05MMACollectorOpAttr : EnumAttr<NVVM_Dialect, Tcgen05MMACollectorOp, "tcgen05_mma_collectorop"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma Ops.
+//===----------------------------------------------------------------------===//
+
+def NVVM_Tcgen05MMAOp : NVVM_Op<"tcgen05.mma", [AttrSizedOperandSegments]> {
+
+  let summary = "Performs MMA operation on 5th-gen tensor cores";
+
+  let arguments = (ins
+      // Attributes
+      Tcgen05MMAKindAttr:$kind,
+      CTAGroupKindAttr:$ctaGroup,
+      DefaultValuedAttr<Tcgen05MMACollectorOpAttr,
+                        "Tcgen05MMACollectorOp::DISCARD">:$collectorOp,
+      UnitAttr:$ashift,
+      // Arguments
+      LLVM_PointerTensor:$d,
+      AnyTypeOf<[LLVM_PointerTensor, I64]>:$a,
+      I64:$b,
+      I32:$idesc,
+      I1:$enableInputD,
+      // Optional arguments
+      Optional<I64>:$scaleInputD,
+      Optional<FixedVectorOfLengthAndType<[4, 8], [I32]>>:$disableOutputLane
+    );
+
+  let description = [{
+    The `tcgen05.mma` is an asynchronous op that performs matrix multiplication
+    and accumulation using 5th-generation tensor cores.
+
+    ```
+    D = A * B + (D * 2^-scaleInputD)     // if `scaleInputD` is provided
+    D = A * B                            // if `enableInputD` is false
+    D = A * B + D                        // otherwise
+    ```
+
+    where:
+    - A is an `M x K` matrix in tensor memory or described using a shared memory descriptor
+    - B is a `K x N` matrix described using a shared memory descriptor
+    - D is an `M x N` accumulator matrix in tensor memory
+
+    A `shared memory descriptor` is a 64-bit value which describes the properties
+    of the multiplicand matrix in shared memory, including its location in the
+    shared memory of the current CTA. For more details, please refer to the
+    [PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-shared-memory-descriptor)
+
+    - `idesc` is a 32-bit value representing the [Instruction Descriptor](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instruction-descriptor)
+
+    Optional Operands:
+    - `scaleInputD` is an immediate value operand used for scaling the D matrix by 2^(-scaleInputD). The valid range is [0, 15]
+
+    - `disableOutputLane` is a vector mask for selective output
+      * vector<4 x i32> when ctaGroup is CTA_1
+      * vector<8 x i32> when ctaGroup is CTA_2
+
+    Required Attributes:
+    - `kind` specifies the computation data type and precision
+      * f16    : 16-bit floating point (half precision)
+      * tf32   : Tensor Float 32 (truncated 32-bit float)
+      * f8f6f4 : Mixed precision FP8/FP6/FP4
+      * i8     : 8-bit integer operations
+
+    - `ctaGroup` specifies CTA group configuration
+      * cta_1: MMA will be performed on the current thread's CTA
+      * cta_2: MMA will be performed on the current thread and its peer CTA
+
+    Default Attributes:
+    - `collectorOp` specifies the collector buffer operation for matrix A
+      * discard : Release buffer after use (default)
+      * lastuse : Mark buffer for last use
+      * fill    : Fill buffer
+      * use     : Use buffer without modification
+
+    - `ashift` shifts the rows of the A matrix down by one row
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma)
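+
+    Example (an illustrative sketch; the SSA value names, the `!llvm.ptr<6>`
+    tensor memory address space, and the `#nvvm.cta_group` attribute mnemonic
+    are assumptions, not taken from this patch's tests):
+
+    ```mlir
+    // A is given as a 64-bit shared memory descriptor; D lives in tensor memory.
+    nvvm.tcgen05.mma %d, %a, %b, %idesc, %enableD
+      {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>}
+      : (!llvm.ptr<6>, i64, i64, i32, i1)
+    ```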
+  }];
+
+  let assemblyFormat = [{
+    $d `,` $a `,` $b `,` $idesc `,` $enableInputD (`scale` `=` $scaleInputD^)?
+    (`mask` `=` $disableOutputLane^)? attr-dict `:` `(` type(operands) `)`
+  }];
+
+  let extraClassDeclaration = [{
+    static mlir::NVVM::IDArgPair getIntrinsicIDAndArgs(
+        Operation &op, LLVM::ModuleTranslation &mt,
+        llvm::IRBuilderBase &builder);
+  }];
+
+  let llvmBuilder = [{
+    auto [ID, args] = NVVM::Tcgen05MMAOp::getIntrinsicIDAndArgs(
+        *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, ID, args);
+  }];
+
+  let hasVerifier = true;
+}
+
+def NVVM_Tcgen05MMASpOp : NVVM_Op<"tcgen05.mma.sp", [AttrSizedOperandSegments]> {
+
+  let summary = "Performs MMA operation with sparse A matrix on 5th-gen tensor cores";
+
+  let arguments = (ins
+    // Attributes
+    Tcgen05MMAKindAttr:$kind,
+    CTAGroupKindAttr:$ctaGroup,
+    DefaultValuedAttr<Tcgen05MMACollectorOpAttr,
+                      "Tcgen05MMACollectorOp::DISCARD">:$collectorOp,
+    UnitAttr:$ashift,
+    // Arguments
+    LLVM_PointerTensor:$d,
+    AnyTypeOf<[LLVM_PointerTensor, I64]>:$a,
+    I64:$b,
+    I32:$idesc,
+    I1:$enableInputD,
+    LLVM_PointerTensor:$sparseMetadata,
+    Optional<I64>:$scaleInputD,
+    Optional<FixedVectorOfLengthAndType<[4, 8], [I32]>>:$disableOutputLane
+  );
+
+  let description = [{
+    The `tcgen05.mma.sp` op performs matrix multiplication and accumulation
+    with a sparse `A` matrix using 5th-generation tensor cores.
+
+    It executes a non-blocking `M x N x K` MMA operation:
+    ```
+    D = A * B + (D * 2^-scaleInputD)     // if `scaleInputD` is provided
+    D = A * B                            // if `enableInputD` is false
+    D = A * B + D                        // otherwise
+    ```
+
+    where:
+    - A is an `M x (K / 2)` matrix in tensor memory or described using a shared memory descriptor
+    - B is a `K x N` matrix described using a shared memory descriptor
+    - D is an `M x N` accumulator matrix in tensor memory
+    - `sparseMetadata` specifies the mapping of the `K / 2` non-zero elements to
+      the `K` elements before performing the MMA operation
+
+    A `shared memory descriptor` is a 64-bit value which describes the properties
+    of the multiplicand matrix in shared memory, including its location in the
+    shared memory of the current CTA. For more details, please refer to the
+    [PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-shared-memory-descriptor)
+
+    - `idesc` is a 32-bit value representing the [Instruction Descriptor](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instruction-descriptor)
+
+    Optional Operands:
+    - `scaleInputD` is an immediate value operand used for scaling the D matrix by 2^(-scaleInputD). The valid range is [0, 15]
+
+    - `disableOutputLane` is a vector mask for selective output
+      * vector<4 x i32> when ctaGroup is CTA_1
+      * vector<8 x i32> when ctaGroup is CTA_2
+
+    Required Attributes:
+    - `kind` specifies the computation data type and precision
+      * f16    : 16-bit floating point (half precision)
+      * tf32   : Tensor Float 32 (truncated 32-bit float)
+      * f8f6f4 : Mixed precision FP8/FP6/FP4
+      * i8     : 8-bit integer operations
+
+    - `ctaGroup` specifies CTA group configuration
+      * cta_1: MMA will be performed on the current thread's CTA
+      * cta_2: MMA will be performed on the current thread and its peer CTA
+
+    Default Attributes:
+    - `collectorOp` specifies the collector buffer operation for matrix A
+      * discard : Release buffer after use (default)
+      * lastuse : Mark buffer for last use
+      * fill    : Fill buffer
+      * use     : Use buffer without modification
+
+    - `ashift` shifts the rows of the A matrix down by one row
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma-sp)
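+
+    Example (an illustrative sketch under the same assumptions as for
+    `tcgen05.mma`; the sparse metadata pointer also lives in tensor memory):
+
+    ```mlir
+    nvvm.tcgen05.mma.sp %d, %a, %b, %idesc, %enableD, %spMeta
+      {kind = #nvvm.tcgen05_mma_kind<f16>, ctaGroup = #nvvm.cta_group<cta_1>}
+      : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>)
+    ```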
+  }];
+
+  let assemblyFormat = [{
+    $d `,` $a `,` $b `,` $idesc `,` $enableInputD `,` $sparseMetadata (`scale` `=` $scaleInputD^)? (`mask` `=` $disableOutputLane^)? attr-dict `:` `(` type(operands) `)`
+  }];
+
+  let extraClassDeclaration = [{
+    static mlir::NVVM::IDArgPair getIntrinsicIDAndArgs(
+        Operation &op, LLVM::ModuleTranslation &mt,
+        llvm::IRBuilderBase &builder);
+  }];
+
+  let llvmBuilder = [{
+    auto [ID, args] = NVVM::Tcgen05MMASpOp::getIntrinsicIDAndArgs(
+        *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, ID, args);
+  }];
+
+  let hasVerifier = true;
+}
+
+// tcgen05.mma.block_scale attribute
+def Tcgen05MMAKindMXF8F6F4 : I32EnumAttrCase<"MXF8F6F4", 0, "mxf8f6f4">;
+def Tcgen05MMAKindMXF4     : I32EnumAttrCase<"MXF4",     1, "mxf4">;
+def Tcgen05MMAKindMXF4NVF4 : I32EnumAttrCase<"MXF4NVF4", 2, "mxf4nvf4">;
+
+def Tcgen05MMABlockScaleKind : I32EnumAttr<
+  "Tcgen05MMABlockScaleKind",
+  "tcgen05.mma.block_scale supported types",
+  [Tcgen05MMAKindMXF8F6F4, Tcgen05MMAKindMXF4, Tcgen05MMAKindMXF4NVF4]> {
+    let cppNamespace = "::mlir::NVVM";
+    let genSpecializedAttr = 0;
+}
+
+def Tcgen05MMABlockScaleKindAttr : EnumAttr<NVVM_Dialect, Tcgen05MMABlockScaleKind,
+                                            "tcgen05_mma_block_scale_kind"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def Tcgen05MMABlockScaleDefault : I32EnumAttrCase<"DEFAULT", 0, "default">;
+def Tcgen05MMABlockScaleBlock16 : I32EnumAttrCase<"BLOCK16", 1, "block16">;
+def Tcgen05MMABlockScaleBlock32 : I32EnumAttrCase<"BLOCK32", 2, "block32">;
+
+def Tcgen05MMABlockScale
+    : I32EnumAttr<"Tcgen05MMABlockScale",
+                  "tcgen05.mma block scale attribute",
+                  [Tcgen05MMABlockScaleDefault, Tcgen05MMABlockScaleBlock16,
+                   Tcgen05MMABlockScaleBlock32]> {
+  let cppNamespace = "::mlir::NVVM";
+  let genSpecializedAttr = 0;
+}
+
+def Tcgen05MMABlockScaleAttr : EnumAttr<NVVM_Dialect, Tcgen05MMABlockScale,
+                                          "tcgen05_mma_block_scale"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.block_scale Op
+//===----------------------------------------------------------------------===//
+
+def NVVM_Tcgen05MMABlockScaleOp : NVVM_Op<"tcgen05.mma.block_scale"> {
+
+  let summary = "Performs block scaled MMA operation on 5th-gen tensor cores";
+
+  let arguments = (ins
+      // Attributes
+      Tcgen05MMABlockScaleKindAttr:$kind,
+      CTAGroupKindAttr:$ctaGroup,
+      DefaultValuedAttr<Tcgen05MMABlockScaleAttr,
+                      "Tcgen05MMABlockScale::DEFAULT">:$blockScale,
+      DefaultValuedAttr<Tcgen05MMACollectorOpAttr,
+                        "Tcgen05MMACollectorOp::DISCARD">:$collectorOp,
+      // Arguments
+      LLVM_PointerTensor:$d,
+      AnyTypeOf<[LLVM_PointerTensor, I64]>:$a,
+      I64:$b,
+      I32:$idesc,
+      I1:$enableInputD,
+      LLVM_PointerTensor:$scaleA,
+      LLVM_PointerTensor:$scaleB
+    );
+
+  let description = [{
+    `nvvm.tcgen05.mma.block_scale` performs matrix multiply-and-accumulate
+    (MMA) using 5th-generation tensor cores. The matrices `A` and `B` are
+    scaled before the matrix multiply and accumulate operation is performed.
+
+    It executes a non-blocking `M x N x K` MMA operation:
+
+    ```
+    D = (A * scale_a) * (B * scale_b)        // if `enableInputD` is false
+    D = (A * scale_a) * (B * scale_b) + D    // otherwise
+    ```
+
+    where:
+    - A is an `M x K` matrix in tensor memory or described using a shared memory descriptor
+    - B is a `K x N` matrix described using a shared memory descriptor
+    - D is an `M x N` accumulator matrix in tensor memory
+    - `scale_a` and `scale_b` are matrices in tensor memory used to scale `A` and `B` respectively
+
+    A `shared memory descriptor` is a 64-bit value which describes the properties
+    of the multiplicand matrix in shared memory, including its location in the
+    shared memory of the current CTA. For more details, please refer to the
+    [PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-shared-memory-descriptor)
+
+    - `idesc` is a 32-bit value representing the [Instruction Descriptor](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instruction-descriptor)
+
+    Required Attributes:
+    - `kind` specifies the computation data type and precision
+      * mxf8f6f4 - MX-floating point formats
+      * mxf4     - MX-floating point formats (FP4)
+      * mxf4nvf4 - MXF4 + custom NVIDIA 4-bit format (with common scaling factor)
+
+    - `ctaGroup` specifies CTA group configuration
+      * cta_1: MMA will be performed on the current thread's CTA
+      * cta_2: MMA will be performed on the current thread and its peer CTA
+
+    Default Attributes:
+    - `collectorOp` specifies the collector buffer operation for matrix A
+      * discard : Release buffer after use (default)
+      * lastuse : Mark buffer for last use
+      * fill    : Fill buffer
+      * use     : Use buffer without modification
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma)
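+
+    Example (an illustrative sketch; SSA names, the `!llvm.ptr<6>` tensor
+    memory address space, and the `#nvvm.cta_group` mnemonic are assumptions):
+
+    ```mlir
+    // scaleA and scaleB are scale matrices resident in tensor memory.
+    nvvm.tcgen05.mma.block_scale %d, %a, %b, %idesc, %enableD, %scaleA, %scaleB
+      {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf8f6f4>,
+       ctaGroup = #nvvm.cta_group<cta_1>}
+      : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+    ```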
+  }];
+
+  let assemblyFormat = [{
+    $d `,` $a `,` $b `,` $idesc `,` $enableInputD `,` $scaleA `,` $scaleB
+    attr-dict `:` `(` type(operands) `)`
+  }];
+
+  let extraClassDeclaration = [{
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase &builder);
+  }];
+
+  let llvmBuilder = [{
+    auto [ID, args] = NVVM::Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
+        *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, ID, args);
+  }];
+  let hasVerifier = true;
+}
+
+def NVVM_Tcgen05MMASpBlockScaleOp : NVVM_Op<"tcgen05.mma.sp.block_scale"> {
+
+  let summary = "Performs block scaled MMA operation with sparse A matrix on 5th-gen tensor cores";
+
+  let arguments = (ins
+    // Attributes
+    Tcgen05MMABlockScaleKindAttr:$kind,
+    CTAGroupKindAttr:$ctaGroup,
+    DefaultValuedAttr<Tcgen05MMABlockScaleAttr,
+                      "Tcgen05MMABlockScale::DEFAULT">:$blockScale,
+    DefaultValuedAttr<Tcgen05MMACollectorOpAttr,
+                      "Tcgen05MMACollectorOp::DISCARD">:$collectorOp,
+    // Arguments
+    LLVM_PointerTensor:$d,
+    AnyTypeOf<[LLVM_PointerTensor, I64]>:$a,
+    I64:$b,
+    I32:$idesc,
+    I1:$enableInputD,
+    LLVM_PointerTensor:$sparseMetadata,
+    LLVM_PointerTensor:$scaleA,
+    LLVM_PointerTensor:$scaleB
+  );
+
+  let description = [{
+    `nvvm.tcgen05.mma.sp.block_scale` is an asynchronous op that performs
+    matrix multiply-and-accumulate with a sparse `A` matrix using
+    5th-generation tensor cores.
+
+    ```
+    D = (A * scale_a) * (B * scale_b)        // if `enableInputD` is false
+    D = (A * scale_a) * (B * scale_b) + D    // otherwise
+    ```
+
+    where:
+    - A is an `M x (K / 2)` matrix in tensor memory or described using a shared memory descriptor
+    - B is a `K x N` matrix described using a shared memory descriptor
+    - D is an `M x N` accumulator matrix in tensor memory
+    - `scale_a` and `scale_b` are matrices in tensor memory used to scale `A` and `B` respectively
+
+    A `shared memory descriptor` is a 64-bit value which describes the properties
+    of the multiplicand matrix in shared memory, including its location in the
+    shared memory of the current CTA. For more details, please refer to the
+    [PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-shared-memory-descriptor)
+
+    Operands:
+    - `idesc` is a 32-bit value representing the [Instruction Descriptor](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instruction-descriptor)
+
+    - `sparseMetadata` specifies the mapping of the `K / 2` non-zero elements to
+      the `K` elements before performing the MMA operation
+
+    Required Attributes:
+    - `kind` specifies the computation data type and precision
+      * mxf8f6f4 - MX-floating point formats
+      * mxf4     - MX-floating point formats (FP4)
+      * mxf4nvf4 - MXF4 + custom NVIDIA 4-bit format (with common scaling factor)
+
+    - `ctaGroup` specifies CTA group configuration
+      * cta_1: MMA will be performed on the current thread's CTA
+      * cta_2: MMA will be performed on the current thread and its peer CTA
+
+    Default Attributes:
+    - `collectorOp` specifies the collector buffer operation for matrix A
+      * discard : Release buffer after use (default)
+      * lastuse : Mark buffer for last use
+      * fill    : Fill buffer
+      * use     : Use buffer without modification
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma-sp)
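+
+    Example (an illustrative sketch; SSA names, the `!llvm.ptr<6>` tensor
+    memory address space, and the `#nvvm.cta_group` mnemonic are assumptions):
+
+    ```mlir
+    nvvm.tcgen05.mma.sp.block_scale %d, %a, %b, %idesc, %enableD, %spMeta,
+      %scaleA, %scaleB
+      {kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>,
+       ctaGroup = #nvvm.cta_group<cta_1>}
+      : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
+    ```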
+  }];
+
+  let assemblyFormat = [{
+    $d `,` $a `,` $b `,` $idesc `,` $enableInputD `,` $sparseMetadata `,` $scaleA `,` $scaleB
+    attr-dict `:` `(` type(operands) `)`
+  }];
+
+  let extraClassDeclaration = [{
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase &builder);
+  }];
+
+  let llvmBuilder = [{
+    auto [ID, args] = NVVM::Tcgen05MMASpBlockScaleOp::getIntrinsicIDAndArgs(
+        *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, ID, args);
+  }];
+
+  let hasVerifier = true;
+}
+
+def Tcgen05MMACollectorBBuffer0  : I32EnumAttrCase<"B0", 0, "b0">;
+def Tcgen05MMACollectorBBuffer1  : I32EnumAttrCase<"B1", 1, "b1">;
+def Tcgen05MMACollectorBBuffer2  : I32EnumAttrCase<"B2", 2, "b2">;
+def Tcgen05MMACollectorBBuffer3  : I32EnumAttrCase<"B3", 3, "b3">;
+
+def Tcgen05MMACollectorBBuffer : I32EnumAttr<
+  "Tcgen05MMACollectorBBuffer",
+  "tcgen05 MMA Collector Buffer B Attribute",
+  [Tcgen05MMACollectorBBuffer0,
+   Tcgen05MMACollectorBBuffer1,
+   Tcgen05MMACollectorBBuffer2,
+   Tcgen05MMACollectorBBuffer3]> {
+    let cppNamespace = "::mlir::NVVM";
+    let genSpecializedAttr = 0;
+}
+
+def Tcgen05MMACollectorBBufferAttr : EnumAttr<NVVM_Dialect, Tcgen05MMACollectorBBuffer, "tcgen05_mma_collectorb"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM tcgen05.mma.ws Op
+//===----------------------------------------------------------------------===//
+
+def NVVM_Tcgen05MMAWsOp : NVVM_Op<"tcgen05.mma.ws"> {
+    let summary = "Performs weight stationary convolution MMA operation on 5th-gen tensor cores";
+
+  let arguments = (ins
+    // Attributes
+    Tcgen05MMAKindAttr:$kind,
+    DefaultValuedAttr<Tcgen05MMACollectorBBufferAttr,
+                      "Tcgen05MMACollectorBBuffer::B0">:$collectorBBuffer,
+    DefaultValuedAttr<Tcgen05MMACollectorOpAttr,
+                      "Tcgen05MMACollectorOp::DISCARD">:$collectorOp,
+    // Arguments
+    LLVM_PointerTensor:$d,
+    AnyTypeOf<[LLVM_PointerTensor, I64]>:$a,
+    I64:$b,
+    I32:$idesc,
+    I1:$enableInputD,
+    Optional<I64>:$zeroColMask
+  );...
[truncated]

@github-actions

github-actions bot commented Oct 21, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@schwarzschild-radius force-pushed the tcgen05_mma_mlir_support branch 3 times, most recently from 6efa8ca to edd9ac3 on October 21, 2025 at 05:24
Comment on lines +4544 to +4593
def Tcgen05MMAKindF16 : I32EnumAttrCase<"F16", 0, "f16">;
def Tcgen05MMAKindTF32 : I32EnumAttrCase<"TF32", 1, "tf32">;
def Tcgen05MMAKindF8F6F4 : I32EnumAttrCase<"F8F6F4", 2, "f8f6f4">;
def Tcgen05MMAKindINT8 : I32EnumAttrCase<"I8", 3, "i8">;
Member

should we use TypeAttr in that case?

Contributor

ah, sometimes it is a mix of types like "f8f6f4", which does not have any equivalent individual type representation. So, I suppose we need to use a separate Attr for this case.

Contributor Author

Added detailed descriptions for both the Kind attributes

// NVVM tcgen05.mma Ops.
//===----------------------------------------------------------------------===//

def NVVM_Tcgen05MMAOp : NVVM_Op<"tcgen05.mma", [AttrSizedOperandSegments]> {
Member

Can we also add NVVMRequiresSMa<100> verifier check here?

Contributor Author

Added the requires clause for sm_100 and sm_110

Comment on lines 4598 to 4601
LLVM_PointerTensor:$d,
AnyTypeOf<[LLVM_PointerTensor, I64]>:$a,
I64:$b,
I32:$idesc,
Member

very nit: it's harder to get code completion with single-character variable names. Maybe we write something explicit :)

Contributor

@durga4github commented Oct 21, 2025

ok, then, in that case, can we have matA and matB ?

Member

maybe matrixA, matrixB

Contributor Author

Renamed $d -> $matrixD, $a -> $matrixA, $b -> $matrixB

I64:$b,
I32:$idesc,
I1:$enableInputD,
// Optional arguments
Member

let's remove that comment; we already have Optional on the attributes

Contributor Author

Removed both comments for both Attributes and Optional

Comment on lines 4609 to 4610
The `tcgen05.mma` is an asynchronous op which performs matrix multiplication,
and accumulation using 5th generation tensor cores
Member

Suggested change
The `tcgen05.mma` is an asynchronous op which performs matrix multiplication,
and accumulation using 5th generation tensor cores
The `tcgen05.mma` operation is an asynchronous tensor core instruction
that performs matrix multiplication, accumulation in a single fused
operation. It targets 5th-generation tensor cores, providing developers
with fine-grained control over execution and scheduling.

Contributor Author

Updated the description for all of the tcgen05.mma Ops with the above suggestion

- B is a `K x N` matrix described using shared memory descriptor
- D is an `M x N` accumulator matrix in tensor memory

`shared memory descriptor` is a 64 bit value which describes the properties
Member

maybe mention the op tcgen05.mma_smem_desc that generates a descriptor

Contributor Author

Removed the doc line and pointed shared memory descriptor creation to the tcgen05.mma_smem_desc Op, which also contains detailed documentation

DefaultValuedAttr<Tcgen05MMACollectorOpAttr,
"Tcgen05MMACollectorOp::DISCARD">:$collectorOp,
UnitAttr:$ashift,
// Arguments
Member

remove the comments; it's clear what is an Attribute and what is an argument

@schwarzschild-radius force-pushed the tcgen05_mma_mlir_support branch 5 times, most recently from a0ddc62 to a605a51 on October 23, 2025 at 13:32
@rajatbajpai
Contributor

LGTM, thanks!

args.push_back(DisableOutputLane);
args.push_back(builder.getInt32(static_cast<unsigned>(thisOp.getKind())));
} else {
if (hasScaleInputD) {
Contributor

nit: you can move this out of both branches.


```
+--------+--------------------------------------------+
| Matrix | A / B |
Contributor

Matrix Kind | Supported types for A/B matrices

//===----------------------------------------------------------------------===//
// NVVM tcgen05.mma Ops.
//===----------------------------------------------------------------------===//

Contributor

can remove newline here

def NVVM_Tcgen05MMAOp : NVVM_Op<"tcgen05.mma",
[AttrSizedOperandSegments,
NVVMRequiresSMa<[100, 110]>]> {

Contributor

can remove newline here

CTAGroupKindAttr:$ctaGroup,
DefaultValuedAttr<Tcgen05MMACollectorOpAttr,
"Tcgen05MMACollectorOp::DISCARD">:$collectorOp,
UnitAttr:$ashift,
Contributor

optional:
Should we do aShift?

* f16 : 16-bit floating point (half precision)
* tf32 : Tensor Float 32 (truncated 32-bit float)
* f8f6f4 : Mixed precision FP8/FP6/FP4
* i8 : 8-bit integer operations
Contributor

Since we documented this in the EnumAttr itself, we do not need to repeat it here.
Only line 4699 should suffice.

createIntrinsicCall(builder, ID, args);
}];

let hasVerifier = true;
Contributor

Can we move it to line 4725 after the asm-format?
(to be consistent with most of the other Ops)

let hasVerifier = true;
}

def NVVM_Tcgen05MMASpOp : NVVM_Op<"tcgen05.mma.sp",
Contributor

Does NVVM_Tcgen05SparseMMAOp read better?

Let us keep "tcgen05.mma.sp" as is.


```
+------------+-------------------------------------------+
| | A / B |
Contributor

Can we add the same header and probably move this attr definition right after the previous one above?

createIntrinsicCall(builder, ID, args);
}];

let hasVerifier = true;
Contributor

Let us move it after the asm-format


def NVVM_Tcgen05MMABlockScaleOp : NVVM_Op<"tcgen05.mma.block_scale",
[NVVMRequiresSMa<[100, 110]>]> {

Contributor

we can remove the newline here

"Tcgen05MMACollectorBBuffer",
"tcgen05 MMA Collector Buffer B Attribute",
[Tcgen05MMACollectorBBuffer0,
Tcgen05MMACollectorBBuffer1,
Contributor

can we fold this to the previous line?


if (hasAShift && !isATensor)
res = emitError(loc,
"Only A operand in tensor memory support ashift attribute");
Contributor

can we rephrase slightly?
"A-shift can be applied only when matrix A is in tensor memory"


if (hasDisableOutputLane) {
if (hasScaleInputD)
args.push_back(ScaleInputD);
Contributor

this scaleInputD push seems to be common; should we move it out?

assert(ID != notIntrinsic && "Invalid intrinsic for Tcgen05MMAOp.");

args.push_back(
builder.getInt32(static_cast<unsigned>(thisOp.getCollectorOp())));
Contributor

I wonder why this is here and not at line 2386, as a sequence of args.push_back() calls?

unsigned ctaGroup =
static_cast<unsigned>(getNVVMCtaGroupKind(thisOp.getCtaGroup()));

bool isATensor = isa<llvm::PointerType>(A->getType());
Contributor

can we move this to line 2896, so that it is easier to see the use of 'A' ?
(like how it is done in 2905-2907)

static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;

static constexpr llvm::nvvm::CTAGroupKind
getNVVMCtaGroupKind(NVVM::CTAGroupKind ctaGroup) {
Contributor

Nice for implementing this utility!

[optional]

In all the uses of this utility below, I observe that we are casting the value to int.

Do you think it makes sense to do the cast here and return unsigned directly from here?
(Then we can name it getNVVMCtaGroupKindAsInt or something like that)

Contributor

@durga4github left a comment

I have a few tiny nits. It should be good to go after a refresh.


let assemblyFormat = [{
$matrixD `,` $matrixA `,` $matrixB `,` $idesc `,` $enableInputD `,` $sparseMetadata (`scale` `=` $scaleInputD^)? (`mask` `=` $disableOutputLane^)? attr-dict `:` `(` type(operands) `)`
}];
Contributor

Thanks for moving the asm-format right next to the args. It is much easier to read and relate quickly

Contributor

@durga4github left a comment

The latest revision LGTM.
Thank you for addressing the comments!

// NVVM tcgen05.mma attributes
//===----------------------------------------------------------------------===//

def Tcgen05MMAKindF16 : I32EnumAttrCase<"F16", 0, "f16">;
Member

Can we make this f16_bf16 so it's clear in the IR?


```
+-------------+--------------------------------------------+
| Matrix Kind | supported types for A / B |
Member

nice! thanks


static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;

static constexpr llvm::nvvm::CTAGroupKind
Member

nit: can you add a comment on top of the function?

args.push_back(mt.lookupValue(thisOp.getEnableInputD()));

// [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift];
static constexpr llvm::Intrinsic::ID tcgen05MMAIDs[2][2][2][2][2] = {
Member

We implement this table in a slightly more readable way elsewhere; we can follow the same style if you like:
https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp#L1797-L1801

Contributor Author

@grypp Are you saying that we should split the table into separate ones for scaled, disable output lane etc.?

disableOutputLaneType.getNumElements() != 4) ||
(ctaGroup == NVVM::CTAGroupKind::CTA_2 &&
disableOutputLaneType.getNumElements() != 8))
res = emitError(loc) << "Disable Output Lane of length "
Member

@grypp commented Oct 27, 2025

We follow the style below when returning errors; maybe we can follow the same style here.

verifyTcgen05MMAOp(...) {
 if ...   
   return emitError(loc) << "Disable O

  return success();

Contributor Author

Updated the verifiers in the latest revision

args.push_back(mt.lookupValue(thisOp.getSparseMetadata()));

// [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift];
static constexpr llvm::Intrinsic::ID tcgen05MMASparseIDs[2][2][2][2][2] = {
Member

Same comment here about creating the table in the same style

Contributor Author

@grypp Can you please confirm what the expected change is? I looked at the code and I noticed that we are using std::array but here we have a 5 dimensional matrix, so having a std::array type would be verbose (std::array<std::array<std::array<std::array<std::array<unsigned, 2>, 2>, 2>, 2>, 2>). What do you think?

Contributor

I think Guray meant something like below.

using EnableAShiftArray = std::array<llvm::Intrinsic::ID, 2>;
using CtaGroupArray = std::array<EnableAShiftArray, 2>;
using IsATensorArray = std::array<CtaGroupArray, 2>;
using HasScaleInputDArray = std::array<IsATensorArray, 2>;
using HasDisableOutputLaneArray = std::array<HasScaleInputDArray, 2>;

// [hasDisableOutputLane][hasScaleInputD][isATensor][CtaGroup][EnableAShift]
static constexpr HasDisableOutputLaneArray tcgen05MMAIDs = {

res = emitError(loc,
llvm::formatv("{} kind does not support block16 attribute",
stringifyEnum(kind)));
return res;
Member

same comment about the style of the verifier.

Member

@grypp left a comment

The PR is in good shape. I have 2 major comments that are mostly stylistic:

  1. Verifier
  2. Creating intrinsic table

This commit adds support for the tcgen05.mma family of instructions in the NVVM MLIR dialect and lowers them to LLVM intrinsics. Please refer to the [PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions) for more information.