-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[mlir][spirv] Handle scalar shuffles in vector to spirv conversion #98809
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
These may not get canonicalized before conversion to spirv and need to be handled during vector to spirv conversion. Because spirv does not support 1-element vectors, we can't emit `spirv.VectorShuffle` and need to lower this to `spirv.CompositeExtract`.
@llvm/pr-subscribers-mlir-spirv Author: Jakub Kuderski (kuhar) ChangesThese may not get canonicalized before conversion to spirv and need to be handled during vector to spirv conversion. Because spirv does not support 1-element vectors, we can't emit Full diff: https://github.com/llvm/llvm-project/pull/98809.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c9363295ec32f..a4390447532a5 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -521,7 +521,7 @@ struct VectorShuffleOpConvert final
LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto oldResultType = shuffleOp.getResultVectorType();
+ VectorType oldResultType = shuffleOp.getResultVectorType();
Type newResultType = getTypeConverter()->convertType(oldResultType);
if (!newResultType)
return rewriter.notifyMatchFailure(shuffleOp,
@@ -532,20 +532,22 @@ struct VectorShuffleOpConvert final
return cast<IntegerAttr>(attr).getValue().getZExtValue();
});
- auto oldV1Type = shuffleOp.getV1VectorType();
- auto oldV2Type = shuffleOp.getV2VectorType();
+ VectorType oldV1Type = shuffleOp.getV1VectorType();
+ VectorType oldV2Type = shuffleOp.getV2VectorType();
- // When both operands are SPIR-V vectors, emit a SPIR-V shuffle.
- if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) {
+ // When both operands and the result are SPIR-V vectors, emit a SPIR-V
+ // shuffle.
+ if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
+ oldResultType.getNumElements() > 1) {
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
rewriter.getI32ArrayAttr(mask));
return success();
}
- // When at least one of the operands becomes a scalar after type conversion
- // for SPIR-V, extract all the required elements and construct the result
- // vector.
+ // When at least one of the operands or the result becomes a scalar after
+ // type conversion for SPIR-V, extract all the required elements and
+ // construct the result vector.
auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
Value scalarOrVec, int32_t idx) -> Value {
if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
@@ -569,9 +571,14 @@ struct VectorShuffleOpConvert final
newOperand = getElementAtIdx(vec, elementIdx);
}
+ // Handle the scalar result corner case.
+ if (newOperands.size() == 1) {
+ rewriter.replaceOp(shuffleOp, newOperands.front());
+ return success();
+ }
+
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
shuffleOp, newResultType, newOperands);
-
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 0d67851dfe41d..667aad7645c51 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -483,6 +483,30 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {
// -----
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+// CHECK: %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<4xi32>
+// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
+// CHECK: return %[[RES]] : vector<1xi32>
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
+ %shuffle = vector.shuffle %v0, %v1 [0] : vector<4xi32>, vector<4xi32>
+ return %shuffle : vector<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+// CHECK: %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG1]][1 : i32] : vector<4xi32>
+// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
+// CHECK: return %[[RES]] : vector<1xi32>
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
+ %shuffle = vector.shuffle %v0, %v1 [5] : vector<4xi32>, vector<4xi32>
+ return %shuffle : vector<1xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @interleave
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>
|
@llvm/pr-subscribers-mlir Author: Jakub Kuderski (kuhar) ChangesThese may not get canonicalized before conversion to spirv and need to be handled during vector to spirv conversion. Because spirv does not support 1-element vectors, we can't emit Full diff: https://github.com/llvm/llvm-project/pull/98809.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c9363295ec32f..a4390447532a5 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -521,7 +521,7 @@ struct VectorShuffleOpConvert final
LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto oldResultType = shuffleOp.getResultVectorType();
+ VectorType oldResultType = shuffleOp.getResultVectorType();
Type newResultType = getTypeConverter()->convertType(oldResultType);
if (!newResultType)
return rewriter.notifyMatchFailure(shuffleOp,
@@ -532,20 +532,22 @@ struct VectorShuffleOpConvert final
return cast<IntegerAttr>(attr).getValue().getZExtValue();
});
- auto oldV1Type = shuffleOp.getV1VectorType();
- auto oldV2Type = shuffleOp.getV2VectorType();
+ VectorType oldV1Type = shuffleOp.getV1VectorType();
+ VectorType oldV2Type = shuffleOp.getV2VectorType();
- // When both operands are SPIR-V vectors, emit a SPIR-V shuffle.
- if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) {
+ // When both operands and the result are SPIR-V vectors, emit a SPIR-V
+ // shuffle.
+ if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
+ oldResultType.getNumElements() > 1) {
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
rewriter.getI32ArrayAttr(mask));
return success();
}
- // When at least one of the operands becomes a scalar after type conversion
- // for SPIR-V, extract all the required elements and construct the result
- // vector.
+ // When at least one of the operands or the result becomes a scalar after
+ // type conversion for SPIR-V, extract all the required elements and
+ // construct the result vector.
auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
Value scalarOrVec, int32_t idx) -> Value {
if (auto vecTy = dyn_cast<VectorType>(scalarOrVec.getType()))
@@ -569,9 +571,14 @@ struct VectorShuffleOpConvert final
newOperand = getElementAtIdx(vec, elementIdx);
}
+ // Handle the scalar result corner case.
+ if (newOperands.size() == 1) {
+ rewriter.replaceOp(shuffleOp, newOperands.front());
+ return success();
+ }
+
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
shuffleOp, newResultType, newOperands);
-
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 0d67851dfe41d..667aad7645c51 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -483,6 +483,30 @@ func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {
// -----
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+// CHECK: %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<4xi32>
+// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
+// CHECK: return %[[RES]] : vector<1xi32>
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
+ %shuffle = vector.shuffle %v0, %v1 [0] : vector<4xi32>, vector<4xi32>
+ return %shuffle : vector<1xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+// CHECK: %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG1]][1 : i32] : vector<4xi32>
+// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
+// CHECK: return %[[RES]] : vector<1xi32>
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
+ %shuffle = vector.shuffle %v0, %v1 [5] : vector<4xi32>, vector<4xi32>
+ return %shuffle : vector<1xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @interleave
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>, %[[ARG1:.+]]: vector<2xf32>)
// CHECK: %[[SHUFFLE:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32, 1 : i32, 3 : i32] %[[ARG0]], %[[ARG1]] : vector<2xf32>, vector<2xf32> -> vector<4xf32>
|
…lvm#98809) These may not get canonicalized before conversion to spirv and need to be handled during vector to spirv conversion. Because spirv does not support 1-element vectors, we can't emit `spirv.VectorShuffle` and need to lower this to `spirv.CompositeExtract`.
These may not get canonicalized before conversion to spirv and need to be handled during vector to spirv conversion. Because spirv does not support 1-element vectors, we can't emit
spirv.VectorShuffle
and need to lower this tospirv.CompositeExtract
.