Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][vector] Extend vector.{insert|extract}_strided_slice #79052

Merged
merged 5 commits into from
Jan 25, 2024

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Jan 22, 2024

Extends vector.insert_strided_slice and vector.insert_strided_slice to allow scalable input and output vectors. For scalable sizes, the corresponding slice size has to match the corresponding dimension in the output/input vector (insert/extract, respectively).

This is supported:

vector.extract_strided_slice %1 {
  offsets = [0, 3, 0],
  sizes = [1, 1, 4],
  strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[4]xi32>

This is not supported:

vector.extract_strided_slice %1 {
  offsets = [0, 3, 0],
  sizes = [1, 1, 2],
  strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[2]xi32>

Extends `vector.insert_strided_slice` and `vector.insert_strided_slice`
to allow scalable input and output vectors. For scalable sizes, the
corresponding slice size has to match the corresponding dimension in the
output/input vector (insert/extract, respectively).

This is supported:
```mlir
vector.extract_strided_slice %1 {
  offsets = [0, 3, 0],
  sizes = [1, 1, 4],
  strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[4]xi32>
```

This is not supported:
```mlir
vector.extract_strided_slice %1 {
  offsets = [0, 3, 0],
  sizes = [1, 1, 2],
  strides = [1, 1, 1] } : vector<1x4x[4]xi32> to vector<1x1x[2]xi32>
```
@llvmbot
Copy link
Collaborator

llvmbot commented Jan 22, 2024

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

Extends vector.insert_strided_slice and vector.insert_strided_slice to allow scalable input and output vectors. For scalable sizes, the corresponding slice size has to match the corresponding dimension in the output/input vector (insert/extract, respectively).

This is supported:

vector.extract_strided_slice %1 {
  offsets = [0, 3, 0],
  sizes = [1, 1, 4],
  strides = [1, 1, 1] } : vector&lt;1x4x[4]xi32&gt; to vector&lt;1x1x[4]xi32&gt;

This is not supported:

vector.extract_strided_slice %1 {
  offsets = [0, 3, 0],
  sizes = [1, 1, 2],
  strides = [1, 1, 1] } : vector&lt;1x4x[4]xi32&gt; to vector&lt;1x1x[2]xi32&gt;

Full diff: https://github.com/llvm/llvm-project/pull/79052.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+17-1)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+43)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+8)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+7)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 791924f92e8ad40..b168b7d7afe9db0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3194,6 +3194,7 @@ void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
 // Inference works as follows:
 //   1. Add 'sizes' from prefix of dims in 'offsets'.
 //   2. Add sizes from 'vectorType' for remaining dims.
+// Scalable flags are inherited from 'vectorType'.
 static Type inferStridedSliceOpResultType(VectorType vectorType,
                                           ArrayAttr offsets, ArrayAttr sizes,
                                           ArrayAttr strides) {
@@ -3206,7 +3207,8 @@ static Type inferStridedSliceOpResultType(VectorType vectorType,
   for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
     shape.push_back(vectorType.getShape()[idx]);
 
-  return VectorType::get(shape, vectorType.getElementType());
+  return VectorType::get(shape, vectorType.getElementType(),
+                         vectorType.getScalableDims());
 }
 
 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
@@ -3265,6 +3267,20 @@ LogicalResult ExtractStridedSliceOp::verify() {
   if (getResult().getType() != resultType)
     return emitOpError("expected result type to be ") << resultType;
 
+  unsigned idx = 0;
+  for (unsigned ub = sizes.size(); idx < ub; ++idx) {
+    if (type.getScalableDims()[idx]) {
+      auto inputDim = type.getShape()[idx];
+      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
+      if (inputDim != inputSize)
+        return emitOpError("expected size at idx=")
+               << idx
+               << (" to match the corresponding base size from the input "
+                   "vector (")
+               << inputSize << (" vs ") << inputDim << (")");
+    }
+  }
+
   return success();
 }
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 09108ab31799984..394b4dea3dab253 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1142,6 +1142,29 @@ func.func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
 
 // -----
 
+func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
+  %0 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[4]xi32>
+  return %0 : vector<1x1x[4]xi32>
+}
+
+// CHECK-LABEL:   func.func @extract_strided_slice_scalable(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
+//      CHECK:      %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_2:.*]] = arith.constant 0 : i32
+//      CHECK:      %[[VAL_3:.*]] = arith.constant dense<0> : vector<1x1x[4]xi32>
+//      CHECK:      %[[VAL_4:.*]] = builtin.unrealized_conversion_cast %[[VAL_3]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_6:.*]] = arith.constant 0 : i32
+//      CHECK:      %[[VAL_7:.*]] = arith.constant dense<0> : vector<1x[4]xi32>
+//      CHECK:      %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : vector<1x[4]xi32> to !llvm.array<1 x vector<[4]xi32>>
+//      CHECK:      %[[VAL_9:.*]] = llvm.extractvalue %[[VAL_1]][0, 3] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_10:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_8]][0] : !llvm.array<1 x vector<[4]xi32>>
+//      CHECK:      %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_4]][0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_12:.*]] = builtin.unrealized_conversion_cast %[[VAL_11]] : !llvm.array<1 x array<1 x vector<[4]xi32>>> to vector<1x1x[4]xi32>
+//      CHECK:      return %[[VAL_12]] : vector<1x1x[4]xi32>
+
+// -----
+
 func.func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vector<4x4x4xf32> {
   %0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32>
   return %0 : vector<4x4x4xf32>
@@ -1207,6 +1230,26 @@ func.func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf3
 
 // -----
 
+func.func @insert_strided_slice_scalable(%arg0 : vector<1x1x[4]xi32>, %arg1: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x[4]xi32> into vector<1x4x[4]xi32>
+  return %0 : vector<1x4x[4]xi32>
+}
+// CHECK-LABEL:   func.func @insert_strided_slice_scalable(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<1x1x[4]xi32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+//      CHECK:      %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_2]][0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_3]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_6:.*]] = llvm.extractvalue %[[VAL_2]][0, 0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_7:.*]] = llvm.extractvalue %[[VAL_3]][0, 3] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_5]][3] : !llvm.array<4 x vector<[4]xi32>>
+//      CHECK:      %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_3]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_10:.*]] = builtin.unrealized_conversion_cast %[[VAL_9]] : !llvm.array<1 x array<4 x vector<[4]xi32>>> to vector<1x4x[4]xi32>
+//      CHECK:      return %[[VAL_10]] : vector<1x4x[4]xi32>
+
+// -----
+
 func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>, %d: vector<f32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>) {
   // CHECK-LABEL: @vector_fma
   //  CHECK-SAME: %[[A:.*]]: vector<8xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5fa8ac245ce973b..2072262864c4cde 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -687,6 +687,14 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
 
 // -----
 
+func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[2]xi32> {
+    // expected-error@+1 {{op expected size at idx=2 to match the corresponding base size from the input vector (2 vs 4)}}
+    %1 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 2], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[2]xi32>
+    return %1 : vector<1x1x[2]xi32>
+  }
+
+// -----
+
 func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{op expected strides to be confined to [1, 2)}}
   %1 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [1], strides = [100]} : vector<4x8x16xf32> to vector<1x8x16xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 03532c5c1ceb18d..c95d0dfba69ada2 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -326,6 +326,13 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32
   return %1: vector<2x2x16xf32>
 }
 
+// CHECK-LABEL: @extract_strided_slice_scalable
+func.func @extract_strided_slice_scalable(%arg0: vector<4x[8]x16xf32>) -> vector<2x[8]x16xf32> {
+  // CHECK: vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32>
+  %1 = vector.extract_strided_slice %arg0 {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32> to vector<2x[8]x16xf32>
+  return %1: vector<2x[8]x16xf32>
+}
+
 #contraction_to_scalar_accesses = [
   affine_map<(i) -> (i)>,
   affine_map<(i) -> (i)>,

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 22, 2024

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Extends vector.insert_strided_slice and vector.insert_strided_slice to allow scalable input and output vectors. For scalable sizes, the corresponding slice size has to match the corresponding dimension in the output/input vector (insert/extract, respectively).

This is supported:

vector.extract_strided_slice %1 {
  offsets = [0, 3, 0],
  sizes = [1, 1, 4],
  strides = [1, 1, 1] } : vector&lt;1x4x[4]xi32&gt; to vector&lt;1x1x[4]xi32&gt;

This is not supported:

vector.extract_strided_slice %1 {
  offsets = [0, 3, 0],
  sizes = [1, 1, 2],
  strides = [1, 1, 1] } : vector&lt;1x4x[4]xi32&gt; to vector&lt;1x1x[2]xi32&gt;

Full diff: https://github.com/llvm/llvm-project/pull/79052.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+17-1)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+43)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+8)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+7)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 791924f92e8ad40..b168b7d7afe9db0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3194,6 +3194,7 @@ void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
 // Inference works as follows:
 //   1. Add 'sizes' from prefix of dims in 'offsets'.
 //   2. Add sizes from 'vectorType' for remaining dims.
+// Scalable flags are inherited from 'vectorType'.
 static Type inferStridedSliceOpResultType(VectorType vectorType,
                                           ArrayAttr offsets, ArrayAttr sizes,
                                           ArrayAttr strides) {
@@ -3206,7 +3207,8 @@ static Type inferStridedSliceOpResultType(VectorType vectorType,
   for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
     shape.push_back(vectorType.getShape()[idx]);
 
-  return VectorType::get(shape, vectorType.getElementType());
+  return VectorType::get(shape, vectorType.getElementType(),
+                         vectorType.getScalableDims());
 }
 
 void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
@@ -3265,6 +3267,20 @@ LogicalResult ExtractStridedSliceOp::verify() {
   if (getResult().getType() != resultType)
     return emitOpError("expected result type to be ") << resultType;
 
+  unsigned idx = 0;
+  for (unsigned ub = sizes.size(); idx < ub; ++idx) {
+    if (type.getScalableDims()[idx]) {
+      auto inputDim = type.getShape()[idx];
+      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
+      if (inputDim != inputSize)
+        return emitOpError("expected size at idx=")
+               << idx
+               << (" to match the corresponding base size from the input "
+                   "vector (")
+               << inputSize << (" vs ") << inputDim << (")");
+    }
+  }
+
   return success();
 }
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 09108ab31799984..394b4dea3dab253 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1142,6 +1142,29 @@ func.func @extract_strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
 
 // -----
 
+func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
+  %0 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 4], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[4]xi32>
+  return %0 : vector<1x1x[4]xi32>
+}
+
+// CHECK-LABEL:   func.func @extract_strided_slice_scalable(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<1x4x[4]xi32>) -> vector<1x1x[4]xi32> {
+//      CHECK:      %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_2:.*]] = arith.constant 0 : i32
+//      CHECK:      %[[VAL_3:.*]] = arith.constant dense<0> : vector<1x1x[4]xi32>
+//      CHECK:      %[[VAL_4:.*]] = builtin.unrealized_conversion_cast %[[VAL_3]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_6:.*]] = arith.constant 0 : i32
+//      CHECK:      %[[VAL_7:.*]] = arith.constant dense<0> : vector<1x[4]xi32>
+//      CHECK:      %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_7]] : vector<1x[4]xi32> to !llvm.array<1 x vector<[4]xi32>>
+//      CHECK:      %[[VAL_9:.*]] = llvm.extractvalue %[[VAL_1]][0, 3] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_10:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_8]][0] : !llvm.array<1 x vector<[4]xi32>>
+//      CHECK:      %[[VAL_11:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_4]][0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_12:.*]] = builtin.unrealized_conversion_cast %[[VAL_11]] : !llvm.array<1 x array<1 x vector<[4]xi32>>> to vector<1x1x[4]xi32>
+//      CHECK:      return %[[VAL_12]] : vector<1x1x[4]xi32>
+
+// -----
+
 func.func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vector<4x4x4xf32> {
   %0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32>
   return %0 : vector<4x4x4xf32>
@@ -1207,6 +1230,26 @@ func.func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf3
 
 // -----
 
+func.func @insert_strided_slice_scalable(%arg0 : vector<1x1x[4]xi32>, %arg1: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 3, 0], strides = [1, 1, 1]} : vector<1x1x[4]xi32> into vector<1x4x[4]xi32>
+  return %0 : vector<1x4x[4]xi32>
+}
+// CHECK-LABEL:   func.func @insert_strided_slice_scalable(
+// CHECK-SAME:      %[[VAL_0:.*]]: vector<1x1x[4]xi32>,
+// CHECK-SAME:      %[[VAL_1:.*]]: vector<1x4x[4]xi32>) -> vector<1x4x[4]xi32> {
+//      CHECK:      %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<1x1x[4]xi32> to !llvm.array<1 x array<1 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<1x4x[4]xi32> to !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_2]][0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_3]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_6:.*]] = llvm.extractvalue %[[VAL_2]][0, 0] : !llvm.array<1 x array<1 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_7:.*]] = llvm.extractvalue %[[VAL_3]][0, 3] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_5]][3] : !llvm.array<4 x vector<[4]xi32>>
+//      CHECK:      %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_3]][0] : !llvm.array<1 x array<4 x vector<[4]xi32>>>
+//      CHECK:      %[[VAL_10:.*]] = builtin.unrealized_conversion_cast %[[VAL_9]] : !llvm.array<1 x array<4 x vector<[4]xi32>>> to vector<1x4x[4]xi32>
+//      CHECK:      return %[[VAL_10]] : vector<1x4x[4]xi32>
+
+// -----
+
 func.func @vector_fma(%a: vector<8xf32>, %b: vector<2x4xf32>, %c: vector<1x1x1xf32>, %d: vector<f32>) -> (vector<8xf32>, vector<2x4xf32>, vector<1x1x1xf32>, vector<f32>) {
   // CHECK-LABEL: @vector_fma
   //  CHECK-SAME: %[[A:.*]]: vector<8xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 5fa8ac245ce973b..2072262864c4cde 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -687,6 +687,14 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
 
 // -----
 
+func.func @extract_strided_slice_scalable(%arg0 : vector<1x4x[4]xi32>) -> vector<1x1x[2]xi32> {
+    // expected-error@+1 {{op expected size at idx=2 to match the corresponding base size from the input vector (2 vs 4)}}
+    %1 = vector.extract_strided_slice %arg0 {offsets = [0, 3, 0], sizes = [1, 1, 2], strides = [1, 1, 1]} : vector<1x4x[4]xi32> to vector<1x1x[2]xi32>
+    return %1 : vector<1x1x[2]xi32>
+  }
+
+// -----
+
 func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{op expected strides to be confined to [1, 2)}}
   %1 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [1], strides = [100]} : vector<4x8x16xf32> to vector<1x8x16xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 03532c5c1ceb18d..c95d0dfba69ada2 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -326,6 +326,13 @@ func.func @extract_strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32
   return %1: vector<2x2x16xf32>
 }
 
+// CHECK-LABEL: @extract_strided_slice_scalable
+func.func @extract_strided_slice_scalable(%arg0: vector<4x[8]x16xf32>) -> vector<2x[8]x16xf32> {
+  // CHECK: vector.extract_strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32>
+  %1 = vector.extract_strided_slice %arg0 {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]x16xf32> to vector<2x[8]x16xf32>
+  return %1: vector<2x[8]x16xf32>
+}
+
 #contraction_to_scalar_accesses = [
   affine_map<(i) -> (i)>,
   affine_map<(i) -> (i)>,

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir Outdated Show resolved Hide resolved
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir Outdated Show resolved Hide resolved
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir Outdated Show resolved Hide resolved
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir Outdated Show resolved Hide resolved
mlir/test/Dialect/Vector/ops.mlir Outdated Show resolved Hide resolved
mlir/lib/Dialect/Vector/IR/VectorOps.cpp Show resolved Hide resolved
mlir/lib/Dialect/Vector/IR/VectorOps.cpp Outdated Show resolved Hide resolved
Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM cheers

@banach-space banach-space merged commit 9ddbcee into llvm:main Jan 25, 2024
4 checks passed
@banach-space banach-space deleted the andrzej/extract_strided_scalable branch March 16, 2024 19:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants