Skip to content

[mlir][spirv] Handle scalar shuffles in vector to spirv conversion #98809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 14, 2024

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Jul 14, 2024

These may not get canonicalized before conversion to spirv and need to be handled during vector to spirv conversion. Because spirv does not support 1-element vectors, we can't emit spirv.VectorShuffle and need to lower this to spirv.CompositeExtract.

These may not get canonicalized before conversion to spirv and need to
be handled during vector to spirv conversion. Because spirv does not
support 1-element vectors, we can't emit `spirv.VectorShuffle` and need
to lower this to `spirv.CompositeExtract`.
@llvmbot
Copy link
Member

llvmbot commented Jul 14, 2024

@llvm/pr-subscribers-mlir-spirv

Author: Jakub Kuderski (kuhar)

Changes

These may not get canonicalized before conversion to spirv and need to be handled during vector to spirv conversion. Because spirv does not support 1-element vectors, we can't emit spirv.VectorShuffle and need to lower this to spirv.CompositeExtract.


Full diff: https://github.com/llvm/llvm-project/pull/98809.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+16-9)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+24)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c9363295ec32f..a4390447532a5 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -521,7 +521,7 @@ struct VectorShuffleOpConvert final
   LogicalResult
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto oldResultType = shuffleOp.getResultVectorType();
+    VectorType oldResultType = shuffleOp.getResultVectorType();
     Type newResultType = getTypeConverter()->convertType(oldResultType);
     if (!newResultType)
       return rewriter.notifyMatchFailure(shuffleOp,
@@ -532,20 +532,22 @@ struct VectorShuffleOpConvert final
           return cast<IntegerAttr>(attr).getValue().getZExtValue();
         });
 
-    auto oldV1Type = shuffleOp.getV1VectorType();
-    auto oldV2Type = shuffleOp.getV2VectorType();
+    VectorType oldV1Type = shuffleOp.getV1VectorType();
+    VectorType oldV2Type = shuffleOp.getV2VectorType();
 
-    // When both operands are SPIR-V vectors, emit a SPIR-V shuffle.
-    if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) {
+    // When both operands and the result are SPIR-V vectors, emit a SPIR-V
+    // shuffle.
+    if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
+        oldResultType.getNumElements() > 1) {
       rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
           shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
           rewriter.getI32ArrayAttr(mask));
       return success();
     }
 
-    // When at least one of the operands becomes a scalar after type conversion
-    // for SPIR-V, extract all the required elements and construct the result
-    // vector.
+    // When at least one of the operands or the result becomes a scalar after
+    // type conversion for SPIR-V, extract all the required elements and
+    // construct the result vector.
     auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
                                Value scalarOrVec, int32_t idx) -> Value {
       if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
@@ -569,9 +571,14 @@ struct VectorShuffleOpConvert final
       newOperand = getElementAtIdx(vec, elementIdx);
     }
 
+    // Handle the scalar result corner case.
+    if (newOperands.size() == 1) {
+      rewriter.replaceOp(shuffleOp, newOperands.front());
+      return success();
+    }
+
     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
         shuffleOp, newResultType, newOperands);
-
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 0d67851dfe41d..667aad7645c51 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -483,6 +483,30 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {
 
 // -----
 
+// CHECK-LABEL:  func @shuffle
+//  CHECK-SAME:  %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+//       CHECK:    %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<4xi32>
+//       CHECK:    %[[RES:.+]]  = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
+//       CHECK:    return %[[RES]] : vector<1xi32>
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [0] : vector<4xi32>, vector<4xi32>
+  return %shuffle : vector<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL:  func @shuffle
+//  CHECK-SAME:  %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+//       CHECK:    %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG1]][1 : i32] : vector<4xi32>
+//       CHECK:    %[[RES:.+]]  = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
+//       CHECK:    return %[[RES]] : vector<1xi32>
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [5] : vector<4xi32>, vector<4xi32>
+  return %shuffle : vector<1xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @interleave
 //  CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
 //       CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>

@llvmbot
Copy link
Member

llvmbot commented Jul 14, 2024

@llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

Changes

These may not get canonicalized before conversion to spirv and need to be handled during vector to spirv conversion. Because spirv does not support 1-element vectors, we can't emit spirv.VectorShuffle and need to lower this to spirv.CompositeExtract.


Full diff: https://github.com/llvm/llvm-project/pull/98809.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+16-9)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+24)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c9363295ec32f..a4390447532a5 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -521,7 +521,7 @@ struct VectorShuffleOpConvert final
   LogicalResult
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto oldResultType = shuffleOp.getResultVectorType();
+    VectorType oldResultType = shuffleOp.getResultVectorType();
     Type newResultType = getTypeConverter()->convertType(oldResultType);
     if (!newResultType)
       return rewriter.notifyMatchFailure(shuffleOp,
@@ -532,20 +532,22 @@ struct VectorShuffleOpConvert final
           return cast<IntegerAttr>(attr).getValue().getZExtValue();
         });
 
-    auto oldV1Type = shuffleOp.getV1VectorType();
-    auto oldV2Type = shuffleOp.getV2VectorType();
+    VectorType oldV1Type = shuffleOp.getV1VectorType();
+    VectorType oldV2Type = shuffleOp.getV2VectorType();
 
-    // When both operands are SPIR-V vectors, emit a SPIR-V shuffle.
-    if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) {
+    // When both operands and the result are SPIR-V vectors, emit a SPIR-V
+    // shuffle.
+    if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
+        oldResultType.getNumElements() > 1) {
       rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
           shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
           rewriter.getI32ArrayAttr(mask));
       return success();
     }
 
-    // When at least one of the operands becomes a scalar after type conversion
-    // for SPIR-V, extract all the required elements and construct the result
-    // vector.
+    // When at least one of the operands or the result becomes a scalar after
+    // type conversion for SPIR-V, extract all the required elements and
+    // construct the result vector.
     auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
                                Value scalarOrVec, int32_t idx) -> Value {
       if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
@@ -569,9 +571,14 @@ struct VectorShuffleOpConvert final
       newOperand = getElementAtIdx(vec, elementIdx);
     }
 
+    // Handle the scalar result corner case.
+    if (newOperands.size() == 1) {
+      rewriter.replaceOp(shuffleOp, newOperands.front());
+      return success();
+    }
+
     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
         shuffleOp, newResultType, newOperands);
-
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 0d67851dfe41d..667aad7645c51 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -483,6 +483,30 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {
 
 // -----
 
+// CHECK-LABEL:  func @shuffle
+//  CHECK-SAME:  %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+//       CHECK:    %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<4xi32>
+//       CHECK:    %[[RES:.+]]  = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
+//       CHECK:    return %[[RES]] : vector<1xi32>
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [0] : vector<4xi32>, vector<4xi32>
+  return %shuffle : vector<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL:  func @shuffle
+//  CHECK-SAME:  %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+//       CHECK:    %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG1]][1 : i32] : vector<4xi32>
+//       CHECK:    %[[RES:.+]]  = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
+//       CHECK:    return %[[RES]] : vector<1xi32>
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [5] : vector<4xi32>, vector<4xi32>
+  return %shuffle : vector<1xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @interleave
 //  CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
 //       CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>

@kuhar kuhar merged commit a72eed7 into llvm:main Jul 14, 2024
10 checks passed
aaryanshukla pushed a commit to aaryanshukla/llvm-project that referenced this pull request Jul 14, 2024
…lvm#98809)

These may not get canonicalized before conversion to spirv and need to
be handled during vector to spirv conversion. Because spirv does not
support 1-element vectors, we can't emit `spirv.VectorShuffle` and need
to lower this to `spirv.CompositeExtract`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants