Conversation

Hsiangkai
Contributor

The current convertVectorToMMAOps starts from vector.contract operations and collects their dependencies as the ops to convert. The GPU dialect also provides a gpu.subgroup_mma_elementwise operation, so we should be able to lower element-wise operations to GPU MMA operations even when no vector.contract is present. This patch adds that case to the pattern.
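
For illustration, a minimal sketch distilled from the test added in this patch: an elementwise arith.addf over 16x16 vectors, with no vector.contract anywhere in the chain, can now anchor the conversion on its own. The commented "after" IR shows only the expected shape of the output; operands of the loads/stores are elided.

func.func @addf(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f16
  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
  // The addf itself is now a conversion root; previously only vector.contract was.
  %C = arith.addf %A, %B : vector<16x16xf16>
  vector.transfer_write %C, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
  return
}

// After -convert-vector-to-gpu, the result should look roughly like:
//   %a = gpu.subgroup_mma_load_matrix ... : !gpu.mma_matrix<16x16xf16, "COp">
//   %b = gpu.subgroup_mma_load_matrix ... : !gpu.mma_matrix<16x16xf16, "COp">
//   %c = gpu.subgroup_mma_elementwise addf %a, %b : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
//   gpu.subgroup_mma_store_matrix %c ...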

llvmbot
Member

llvmbot commented Sep 16, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Hsiangkai Wang (Hsiangkai)

Changes

The current convertVectorToMMAOps starts from vector.contract operations and collects their dependencies as the ops to convert. The GPU dialect also provides a gpu.subgroup_mma_elementwise operation, so we should be able to lower element-wise operations to GPU MMA operations even when no vector.contract is present. This patch adds that case to the pattern.


Full diff: https://github.com/llvm/llvm-project/pull/159091.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+6-3)
  • (modified) mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir (+19)
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 1d1904f717335..e373e9fe63500 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -355,11 +355,14 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
   forwardSliceOptions.filter = hasVectorSrc;
 
   SetVector<Operation *> opToConvert;
-  op->walk([&](vector::ContractionOp contract) {
-    if (opToConvert.contains(contract.getOperation()))
+  op->walk([&](Operation *nestedOp) {
+    if (!isa<vector::ContractionOp>(nestedOp) &&
+        !elementwiseSupportsMMAMatrixType(nestedOp))
+      return;
+    if (opToConvert.contains(nestedOp))
       return;
     SetVector<Operation *> dependentOps =
-        getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
+        getSliceContract(nestedOp, backwardSliceOptions, forwardSliceOptions);
     // If any instruction cannot use MMA matrix type drop the whole
     // chain. MMA matrix are stored in an opaque type so they cannot be used
     // by all operations.
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index b8ac63f89af33..ef72901750479 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -536,3 +536,22 @@ func.func @test_unsupported(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg
                         %0, %1, %arg2 : vector<4x4xi64>, vector<4x4xi64> into vector<4x4xi64>
   return %2 : vector<4x4xi64>
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: func @addf
+//       CHECK:   %[[A:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   %[[B:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   %[[C:.+]] = gpu.subgroup_mma_elementwise  addf %[[A]], %[[B]] : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:   gpu.subgroup_mma_store_matrix %[[C]]
+func.func @addf(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %C = arith.addf %A, %B : vector<16x16xf16>
+  vector.transfer_write %C, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
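
As a usage sketch, the new test can be exercised through mlir-opt's VectorToGPU conversion with something like the following; the authoritative flag set is whatever the RUN line at the top of the test file specifies (the conversion pass is -convert-vector-to-gpu, possibly paired with -canonicalize):

mlir-opt mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir -convert-vector-to-gpu | FileCheck mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir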

Hsiangkai merged commit 46ad540 into llvm:main on Sep 17, 2025
12 checks passed