[MLIR][GPU] subgroup_mma fp64 extension #165873
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Giacomo Castiglioni (castigli)

Changes

This PR extends the `gpu.subgroup_mma_*` ops to support the fp64 type. The extension requires special handling during the lowering to `nvvm` due to the return type of the load ops for the a and b fragments: they return a scalar instead of a struct.

Full diff: https://github.com/llvm/llvm-project/pull/165873.diff

7 Files Affected:
 diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index 4c8abea680b66..48982ac6efe7c 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -27,7 +27,7 @@ class MMAMatrixType;
 #define GEN_PASS_DECL_CONVERTGPUOPSTONVVMOPS
 #include "mlir/Conversion/Passes.h.inc"
 
-LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type);
+Type convertMMAToLLVMType(gpu::MMAMatrixType type);
 
 /// Configure target to convert from the GPU dialect to NVVM.
 void configureGpuToNVVMConversionLegality(ConversionTarget &target);
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index 860f893367203..2c29bb8a01a41 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -114,7 +114,7 @@ def GPU_MMAMatrix : DialectType<
   GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;
 
 // Memref type acceptable to gpu.subgroup_mma_{load|store}_matrix ops.
-def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, VectorOfRankAndType<[1], [I8, I32, F16, F32]>]>;
+def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, F64, VectorOfRankAndType<[1], [I8, I32, F16, F32, F64]>]>;
 
 class MMAMatrixOf<list<Type> allowedTypes> :
   ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index a6c6038e1e224..5c7df25c58cde 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1872,7 +1872,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
     ```
   }];
 
-  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32]>>:$src,
+  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, I32, F16, F32, F64]>>:$src,
                   Arg<GPU_MMAMemRef, "",[MemWriteAt<0, FullEffect>]>:$dstMemref,
                   Variadic<Index>:$indices,
                   IndexAttr:$leadDimension,
@@ -1919,9 +1919,9 @@ def GPU_SubgroupMmaComputeOp
     ```
   }];
 
-  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opA,
-                  Arg<MMAMatrixOf<[SI8, UI8, F16, F32]>>:$opB,
-                  Arg<MMAMatrixOf<[I32, F16, F32]>>:$opC,
+  let arguments = (ins Arg<MMAMatrixOf<[SI8, UI8, F16, F32, F64]>>:$opA,
+                  Arg<MMAMatrixOf<[SI8, UI8, F16, F32, F64]>>:$opB,
+                  Arg<MMAMatrixOf<[I32, F16, F32, F64]>>:$opC,
                   OptionalAttr<UnitAttr>:$a_transpose,
                   OptionalAttr<UnitAttr>:$b_transpose);
 
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 99c059cb03299..fb1a37a03fe4d 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
 
 using namespace mlir;
 
@@ -57,7 +58,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
   if (type.getElementType().isF32())
     return type.getOperand() == "COp" ? NVVM::MMATypes::f32
                                       : NVVM::MMATypes::tf32;
-
+  if (type.getElementType().isF64())
+    return NVVM::MMATypes::f64;
   if (type.getElementType().isSignedInteger(8))
     return NVVM::MMATypes::s8;
   if (type.getElementType().isUnsignedInteger(8))
@@ -212,8 +214,13 @@ struct WmmaMmaOpToNVVMLowering
     // then passed on to the intrinsic call. Emit llvm ops to extract individual
     // values form lowered memrefs.
     SmallVector<Value> unpackedOps;
-
     auto unpackOp = [&](Value operand) {
+      // f64 a and b fragments are not structs but scalars.
+      if (!isa<LLVM::LLVMStructType>(operand.getType())) {
+        unpackedOps.push_back(operand);
+        return;
+      }
+      // every other type is lowered to an LLVM struct, extract the values.
       auto structType = cast<LLVM::LLVMStructType>(operand.getType());
       for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
         Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i);
@@ -276,10 +283,16 @@ struct WmmaConstantOpToNVVMLowering
       return failure();
     Location loc = subgroupMmaConstantOp.getLoc();
     Value cst = adaptor.getOperands()[0];
-    LLVM::LLVMStructType type = convertMMAToLLVMType(
+    Type type = convertMMAToLLVMType(
         cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
+    // If the element is not a struct, it means it's a scalar f64.
+    LLVM::LLVMStructType structType = dyn_cast<LLVM::LLVMStructType>(type);
+    if (!structType) {
+      rewriter.replaceOp(subgroupMmaConstantOp, cst);
+      return success();
+    }
     // If the element type is a vector create a vector from the operand.
-    if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
+    if (auto vecType = dyn_cast<VectorType>(structType.getBody()[0])) {
       Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType);
       for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
         Value idx = LLVM::ConstantOp::create(rewriter, loc,
@@ -289,8 +302,8 @@ struct WmmaConstantOpToNVVMLowering
       }
       cst = vecCst;
     }
-    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type);
-    for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
+    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structType);
+    for (size_t i : llvm::seq(size_t(0), structType.getBody().size())) {
       matrixStruct =
           LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i);
     }
@@ -354,10 +367,24 @@ struct WmmaElementwiseOpToNVVMLowering
       return failure();
     Location loc = subgroupMmaElementwiseOp.getLoc();
     size_t numOperands = adaptor.getOperands().size();
-    LLVM::LLVMStructType destType = convertMMAToLLVMType(
+    Type destType = convertMMAToLLVMType(
         cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
-    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType);
-    for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
+
+    // If the element is not a struct, it means it's a scalar f64.
+    LLVM::LLVMStructType structDestTy = dyn_cast<LLVM::LLVMStructType>(destType);
+    if (!structDestTy) {
+      SmallVector<Value> operands;
+      for (auto operand : adaptor.getOperands()) {
+        operands.push_back(operand);
+      }
+      Value element =
+          createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
+                         operands);
+      rewriter.replaceOp(subgroupMmaElementwiseOp, element);
+      return success();
+    }
+    Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structDestTy);
+    for (size_t i = 0, e = structDestTy.getBody().size(); i < e; ++i) {
       SmallVector<Value> extractedOperands;
       for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
         extractedOperands.push_back(LLVM::ExtractValueOp::create(
@@ -377,13 +404,18 @@ struct WmmaElementwiseOpToNVVMLowering
 } // namespace
 
 /// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
-LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
+Type mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
   NVVM::MMAFrag frag = convertOperand(type.getOperand());
   NVVM::MMATypes eltType = getElementType(type);
   auto nRow = type.getShape()[0];
   auto nCol = type.getShape()[1];
   std::pair<Type, unsigned> typeInfo =
       NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
+  // Special handling for f64 a and b fragments
+  Type f64Ty = Float64Type::get(type.getContext());
+  if (typeInfo.first == f64Ty && typeInfo.second == 1) {
+    return f64Ty;
+  }
   return LLVM::LLVMStructType::getLiteral(
       type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
 }
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 6c6d8d2bad55d..61a630aa88960 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -208,7 +208,7 @@ Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
 StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
 
 bool MMAMatrixType::isValidElementType(Type elementType) {
-  return elementType.isF16() || elementType.isF32() ||
+  return elementType.isF16() || elementType.isF32() || elementType.isF64() ||
          elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) ||
          elementType.isInteger(32);
 }
@@ -225,7 +225,7 @@ MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
 
   if (!MMAMatrixType::isValidElementType(elementType))
     return emitError()
-           << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32";
+           << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64";
 
   return success();
 }
diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
index b479467efc208..83b5fb5e6ea54 100644
--- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -79,6 +79,28 @@ gpu.module @test_module {
 
 // -----
 
+gpu.module @test_module {
+
+  // CHECK-LABEL: func @gpu_wmma_f64_load_op() ->
+  // CHECK-SAME: f64
+  // CHECK32-LABEL: func @gpu_wmma_f64_load_op() ->
+  func.func @gpu_wmma_f64_load_op() -> (!gpu.mma_matrix<8x4xf64, "AOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf64, 3>
+    %i = arith.constant 16 : index
+    %j = arith.constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf64, 3> -> !gpu.mma_matrix<8x4xf64, "AOp">
+    return %0 : !gpu.mma_matrix<8x4xf64, "AOp">
+    // CHECK: %[[MUL:.*]] = llvm.mul %{{.*}}, %{{.*}} : i64
+    // CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
+    // CHECK: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f64
+    // CHECK: %[[C32_I32:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK: %[[LOAD:.*]] = nvvm.wmma.load %[[GEP]], %[[C32_I32]] {eltype = #nvvm.mma_type<f64>, frag = #nvvm.mma_frag<a>, k = 4 : i32, layout = #nvvm.mma_layout<row>, m = 8 : i32, n = 8 : i32} : (!llvm.ptr<3>) -> f64
+    // CHECK: llvm.return %[[LOAD]] : f64
+  }
+}
+
+// -----
+
 gpu.module @test_module {
 
   // CHECK-LABEL: func @gpu_wmma_store_op
diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir
new file mode 100644
index 0000000000000..a016a60022699
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+
+func.func @main() {
+  %a = memref.alloc() : memref<8x4xf64>
+  %b = memref.alloc() : memref<4x8xf64>
+  %c = memref.alloc() : memref<8x8xf64>
+  %d = memref.alloc() : memref<8x8xf64>
+
+  %f1 = arith.constant 1.0e+00 : f64
+  %fcst = arith.constant 3.14e+00 : f64
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %c4 = arith.constant 4 : index
+  %c1 = arith.constant 1 : index
+  %c32 = arith.constant 32 : index
+
+  // Initialize the input matrices with ones.
+  scf.for %arg0 = %c0 to %c8 step %c1 {
+    scf.for %arg1 = %c0 to %c4 step %c1 {
+      memref.store %f1, %a[%arg0, %arg1] : memref<8x4xf64>
+      memref.store %f1, %b[%arg1, %arg0] : memref<4x8xf64>
+    }
+  }
+  // Initialize the accumulator matrix with a constant.
+  scf.for %arg0 = %c0 to %c8 step %c1 {
+    scf.for %arg1 = %c0 to %c8 step %c1 {
+      memref.store %fcst, %c[%arg0, %arg1] : memref<8x8xf64>
+    }
+  }
+
+  %2 = memref.cast %a : memref<8x4xf64> to memref<*xf64>
+  %20 = memref.cast %b : memref<4x8xf64> to memref<*xf64>
+  %33 = memref.cast %c : memref<8x8xf64> to memref<*xf64>
+  %34 = memref.cast %d : memref<8x8xf64> to memref<*xf64>
+
+  gpu.host_register %2 : memref<*xf64>
+  gpu.host_register %20 : memref<*xf64>
+  gpu.host_register %33 : memref<*xf64>
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+    %A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 4 : index} : memref<8x4xf64> -> !gpu.mma_matrix<8x4xf64, "AOp">
+    %B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 8 : index} : memref<4x8xf64> -> !gpu.mma_matrix<4x8xf64, "BOp">
+    %C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 8 : index} : memref<8x8xf64> -> !gpu.mma_matrix<8x8xf64, "COp">
+
+    %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<8x4xf64, "AOp">, !gpu.mma_matrix<4x8xf64, "BOp"> -> !gpu.mma_matrix<8x8xf64, "COp">
+
+    gpu.subgroup_mma_store_matrix %R, %d[%c0, %c0] {leadDimension = 8 : index}: !gpu.mma_matrix<8x8xf64, "COp">, memref<8x8xf64>
+    gpu.terminator
+  }
+  // Print the memref after computation.
+  call @printMemrefF64(%34) : (memref<*xf64>) -> ()
+  // CHECK: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
+  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
+  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
+  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
+  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
+  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
+  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14],
+  // CHECK-NEXT: [7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14,   7.14]
+  return
+}
+
+func.func private @printMemrefF64(memref<*xf64>)
✅ With the latest revision this PR passed the C/C++ code formatter.
    
This PR extends the `gpu.subgroup_mma_*` ops to support the fp64 type. The extension requires special handling during the lowering to `nvvm` due to the return type of the load ops for the a and b fragments: they return a scalar instead of a struct.
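
For illustration only (not part of the PR), here is a minimal sketch of what that special case means at the IR level, assuming the op forms added in this patch. The module and function names (`@example`, `@load_a_f64`, `@load_a_f16`) are hypothetical. An f64 "AOp" fragment lowers to a bare `f64` value, while an f16 "AOp" fragment lowers to an LLVM struct that the compute lowering unpacks with `llvm.extractvalue`.

```mlir
gpu.module @example {
  // Hypothetical example: an f64 "AOp" fragment load. After
  // -convert-gpu-to-nvvm this becomes an nvvm.wmma.load returning a plain
  // f64, so no llvm.extractvalue is needed.
  func.func @load_a_f64(%m: memref<32x32xf64, 3>, %i: index) -> !gpu.mma_matrix<8x4xf64, "AOp"> {
    %a = gpu.subgroup_mma_load_matrix %m[%i, %i] {leadDimension = 32 : index}
        : memref<32x32xf64, 3> -> !gpu.mma_matrix<8x4xf64, "AOp">
    return %a : !gpu.mma_matrix<8x4xf64, "AOp">
  }

  // For comparison, an f16 "AOp" fragment load lowers to a value of an
  // !llvm.struct type containing vector<2xf16> elements, which the compute
  // lowering unpacks element by element.
  func.func @load_a_f16(%m: memref<32x32xf16, 3>, %i: index) -> !gpu.mma_matrix<16x16xf16, "AOp"> {
    %a = gpu.subgroup_mma_load_matrix %m[%i, %i] {leadDimension = 32 : index}
        : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
    return %a : !gpu.mma_matrix<16x16xf16, "AOp">
  }
}
```

This is why `convertMMAToLLVMType` now returns `Type` instead of `LLVM::LLVMStructType`, and why the compute, constant, and elementwise lowerings check for the non-struct case before unpacking.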