-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Extend vector.{insert|extract}_strided_slice #79052
[mlir][vector] Extend vector.{insert|extract}_strided_slice #79052
Conversation
Extends `vector.extract_strided_slice` and `vector.insert_strided_slice` to allow scalable input and output vectors. For scalable sizes, the corresponding slice size has to match the corresponding dimension in the output/input vector (insert/extract, respectively). This is supported: ```mlir vector.extract_strided_slice %1 { offsets = [0, 3, 0], sizes = [1, 1, 4], strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[4]xi32> ``` This is not supported: ```mlir vector.extract_strided_slice %1 { offsets = [0, 3, 0], sizes = [1, 1, 2], strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[2]xi32> ```
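@llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) Changes: Extends `vector.extract_strided_slice` and `vector.insert_strided_slice` to allow scalable input and output vectors. For scalable sizes, the corresponding slice size has to match the corresponding dimension in the output/input vector (insert/extract, respectively). This is supported: vector.extract_strided_slice %1 {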
offsets = [0, 3, 0],
sizes = [1, 1, 4],
strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[4]xi32> This is not supported: vector.extract_strided_slice %1 {
offsets = [0, 3, 0],
sizes = [1, 1, 2],
strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[2]xi32> Full diff: https://github.com/llvm/llvm-project/pull/79052.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 791924f92e8ad40..b168b7d7afe9db0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3194,6 +3194,7 @@ void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
// Inference works as follows:
// 1. Add 'sizes' from prefix of dims in 'offsets'.
// 2. Add sizes from 'vectorType' for remaining dims.
+// Scalable flags are inherited from 'vectorType'.
static Type inferStridedSliceOpResultType(VectorType vectorType,
ArrayAttr offsets, ArrayAttr sizes,
ArrayAttr strides) {
@@ -3206,7 +3207,8 @@ static Type inferStridedSliceOpResultType(VectorType vectorType,
for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
shape.push_back(vectorType.getShape()[idx]);
- return VectorType::get(shape, vectorType.getElementType());
+ return VectorType::get(shape, vectorType.getElementType(),
+ vectorType.getScalableDims());
}
void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
@@ -3265,6 +3267,20 @@ LogicalResult ExtractStridedSliceOp::verify() {
if (getResult().getType() != resultType)
return emitOpError("expected result type to be ") << resultType;
+ unsigned idx = 0;
+ for (unsigned ub = sizes.size(); idx < ub; ++idx) {
+ if (type.getScalableDims()[idx]) {
+ auto inputDim = type.getShape()[idx];
+ auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
+ if (inputDim != inputSize)
+ return emitOpError("expected size at idx=")
+ << idx
+ << (" to match the corresponding base size from the input "
+ "vector (")
+ << inputSize << (" vs ") << inputDim << (")");
+ }
+ }
+
return success();
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 09108ab31799984..394b4dea3dab253 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1142,6 +1142,29 @@ func.func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
// -----
+func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
+ %0 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[4]xi32>
+ return %0 : vector<1x1x[4]xi32>
+}
+
+// CHECK-LABEL: func.func @extract_strided_slice_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<1x1x[4]xi32>
+// CHECK: %[[VAL_4:.*]] = builtin.unrealized_conversion_cast %[[VAL_3]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
+// CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[VAL_6:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_7:.*]] = arith.constant dense<0> : vector<1x[4]xi32>
+// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : vector<1x[4]xi32> to !llvm.array<1 x vector<[4]xi32>>
+// CHECK: %[[VAL_9:.*]] = llvm.extractvalue %[[VAL_1]][0, 3] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[VAL_10:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_8]][0] : !llvm.array<1 x vector<[4]xi32>>
+// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_4]][0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+// CHECK: %[[VAL_12:.*]] = builtin.unrealized_conversion_cast %[[VAL_11]] : !llvm.array<1 x array<1 x vector<[4]xi32>>> to vector<1x1x[4]xi32>
+// CHECK: return %[[VAL_12]] : vector<1x1x[4]xi32>
+
+// -----
+
func.func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vector<4x4x4xf32> {
%0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32>
return %0 : vector<4x4x4xf32>
@@ -1207,6 +1230,26 @@ func.func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf3
// -----
+func.func @insert_strided_slice_scalable(%arg0 : vector<1x1x[4]xi32>, %arg1: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+ %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x[4]xi32> into vector<1x4x[4]xi32>
+ return %0 : vector<1x4x[4]xi32>
+}
+// CHECK-LABEL: func.func @insert_strided_slice_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x[4]xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
+// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_2]][0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+// CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_3]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[VAL_6:.*]] = llvm.extractvalue %[[VAL_2]][0, 0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+// CHECK: %[[VAL_7:.*]] = llvm.extractvalue %[[VAL_3]][0, 3] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_5]][3] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_3]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[VAL_10:.*]] = builtin.unrealized_conversion_cast %[[VAL_9]] : !llvm.array<1 x array<4 x vector<[4]xi32>>> to vector<1x4x[4]xi32>
+// CHECK: return %[[VAL_10]] : vector<1x4x[4]xi32>
+
+// -----
+
func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>, %d: vector<f32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>) {
// CHECK-LABEL: @vector_fma
// CHECK-SAME: %[[A:.*]]: vector<8xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5fa8ac245ce973b..2072262864c4cde 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -687,6 +687,14 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
// -----
+func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[2]xi32> {
+ // expected-error@+1 {{op expected size at idx=2 to match the corresponding base size from the input vector (2 vs 4)}}
+ %1 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 2], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[2]xi32>
+ return %1 : vector<1x1x[2]xi32>
+ }
+
+// -----
+
func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{op expected strides to be confined to [1, 2)}}
%1 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [1], strides = [100]} : vector<4x8x16xf32> to vector<1x8x16xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 03532c5c1ceb18d..c95d0dfba69ada2 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -326,6 +326,13 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32
return %1: vector<2x2x16xf32>
}
+// CHECK-LABEL: @extract_strided_slice_scalable
+func.func @extract_strided_slice_scalable(%arg0: vector<4x[8]x16xf32>) -> vector<2x[8]x16xf32> {
+ // CHECK: vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32>
+ %1 = vector.extract_strided_slice %arg0 {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32> to vector<2x[8]x16xf32>
+ return %1: vector<2x[8]x16xf32>
+}
+
#contraction_to_scalar_accesses = [
affine_map<(i) -> (i)>,
affine_map<(i) -> (i)>,
|
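@llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) Changes: Extends `vector.extract_strided_slice` and `vector.insert_strided_slice` to allow scalable input and output vectors. For scalable sizes, the corresponding slice size has to match the corresponding dimension in the output/input vector (insert/extract, respectively). This is supported: vector.extract_strided_slice %1 {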
offsets = [0, 3, 0],
sizes = [1, 1, 4],
strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[4]xi32> This is not supported: vector.extract_strided_slice %1 {
offsets = [0, 3, 0],
sizes = [1, 1, 2],
strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[2]xi32> Full diff: https://github.com/llvm/llvm-project/pull/79052.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 791924f92e8ad40..b168b7d7afe9db0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3194,6 +3194,7 @@ void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
// Inference works as follows:
// 1. Add 'sizes' from prefix of dims in 'offsets'.
// 2. Add sizes from 'vectorType' for remaining dims.
+// Scalable flags are inherited from 'vectorType'.
static Type inferStridedSliceOpResultType(VectorType vectorType,
ArrayAttr offsets, ArrayAttr sizes,
ArrayAttr strides) {
@@ -3206,7 +3207,8 @@ static Type inferStridedSliceOpResultType(VectorType vectorType,
for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
shape.push_back(vectorType.getShape()[idx]);
- return VectorType::get(shape, vectorType.getElementType());
+ return VectorType::get(shape, vectorType.getElementType(),
+ vectorType.getScalableDims());
}
void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
@@ -3265,6 +3267,20 @@ LogicalResult ExtractStridedSliceOp::verify() {
if (getResult().getType() != resultType)
return emitOpError("expected result type to be ") << resultType;
+ unsigned idx = 0;
+ for (unsigned ub = sizes.size(); idx < ub; ++idx) {
+ if (type.getScalableDims()[idx]) {
+ auto inputDim = type.getShape()[idx];
+ auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
+ if (inputDim != inputSize)
+ return emitOpError("expected size at idx=")
+ << idx
+ << (" to match the corresponding base size from the input "
+ "vector (")
+ << inputSize << (" vs ") << inputDim << (")");
+ }
+ }
+
return success();
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 09108ab31799984..394b4dea3dab253 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1142,6 +1142,29 @@ func.func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
// -----
+func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
+ %0 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[4]xi32>
+ return %0 : vector<1x1x[4]xi32>
+}
+
+// CHECK-LABEL: func.func @extract_strided_slice_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
+// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<1x1x[4]xi32>
+// CHECK: %[[VAL_4:.*]] = builtin.unrealized_conversion_cast %[[VAL_3]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
+// CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[VAL_6:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_7:.*]] = arith.constant dense<0> : vector<1x[4]xi32>
+// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : vector<1x[4]xi32> to !llvm.array<1 x vector<[4]xi32>>
+// CHECK: %[[VAL_9:.*]] = llvm.extractvalue %[[VAL_1]][0, 3] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[VAL_10:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_8]][0] : !llvm.array<1 x vector<[4]xi32>>
+// CHECK: %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_4]][0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+// CHECK: %[[VAL_12:.*]] = builtin.unrealized_conversion_cast %[[VAL_11]] : !llvm.array<1 x array<1 x vector<[4]xi32>>> to vector<1x1x[4]xi32>
+// CHECK: return %[[VAL_12]] : vector<1x1x[4]xi32>
+
+// -----
+
func.func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vector<4x4x4xf32> {
%0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32>
return %0 : vector<4x4x4xf32>
@@ -1207,6 +1230,26 @@ func.func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf3
// -----
+func.func @insert_strided_slice_scalable(%arg0 : vector<1x1x[4]xi32>, %arg1: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+ %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x[4]xi32> into vector<1x4x[4]xi32>
+ return %0 : vector<1x4x[4]xi32>
+}
+// CHECK-LABEL: func.func @insert_strided_slice_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1x[4]xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
+// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_2]][0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+// CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_3]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[VAL_6:.*]] = llvm.extractvalue %[[VAL_2]][0, 0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+// CHECK: %[[VAL_7:.*]] = llvm.extractvalue %[[VAL_3]][0, 3] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_5]][3] : !llvm.array<4 x vector<[4]xi32>>
+// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_3]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+// CHECK: %[[VAL_10:.*]] = builtin.unrealized_conversion_cast %[[VAL_9]] : !llvm.array<1 x array<4 x vector<[4]xi32>>> to vector<1x4x[4]xi32>
+// CHECK: return %[[VAL_10]] : vector<1x4x[4]xi32>
+
+// -----
+
func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>, %d: vector<f32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>) {
// CHECK-LABEL: @vector_fma
// CHECK-SAME: %[[A:.*]]: vector<8xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5fa8ac245ce973b..2072262864c4cde 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -687,6 +687,14 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
// -----
+func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[2]xi32> {
+ // expected-error@+1 {{op expected size at idx=2 to match the corresponding base size from the input vector (2 vs 4)}}
+ %1 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 2], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[2]xi32>
+ return %1 : vector<1x1x[2]xi32>
+ }
+
+// -----
+
func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{op expected strides to be confined to [1, 2)}}
%1 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [1], strides = [100]} : vector<4x8x16xf32> to vector<1x8x16xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 03532c5c1ceb18d..c95d0dfba69ada2 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -326,6 +326,13 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32
return %1: vector<2x2x16xf32>
}
+// CHECK-LABEL: @extract_strided_slice_scalable
+func.func @extract_strided_slice_scalable(%arg0: vector<4x[8]x16xf32>) -> vector<2x[8]x16xf32> {
+ // CHECK: vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32>
+ %1 = vector.extract_strided_slice %arg0 {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32> to vector<2x[8]x16xf32>
+ return %1: vector<2x[8]x16xf32>
+}
+
#contraction_to_scalar_accesses = [
affine_map<(i) -> (i)>,
affine_map<(i) -> (i)>,
|
Update the verifier for insert_strided_slice
Add a test in ops.mlir
Simplify implementation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM cheers
Extends `vector.extract_strided_slice` and `vector.insert_strided_slice` to allow scalable input and output vectors. For scalable sizes, the corresponding slice size has to match the corresponding dimension in the output/input vector (insert/extract, respectively).
This is supported:
This is not supported: