292 changes: 192 additions & 100 deletions mlir/test/Dialect/Linalg/reshape_fusion.mlir

Large diffs are not rendered by default.
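For reference, the pattern applied throughout this PR: tensor.expand_shape and memref.expand_shape now take an explicit output_shape operand list, so each result dimension is spelled out as a static size or an SSA index value instead of being inferred. A minimal before/after sketch (hypothetical IR, not lines from the unrendered diff above):

// Before: the dynamic result size was left implicit.
%e = tensor.expand_shape %t [[0, 1]] : tensor<?xf32> into tensor<?x8xf32>

// After: dynamic sizes are computed up front and passed via output_shape.
%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%d0 = tensor.dim %t, %c0 : tensor<?xf32>
%sz = arith.divui %d0, %c8 : index
%e = tensor.expand_shape %t [[0, 1]] output_shape [%sz, 8]
    : tensor<?xf32> into tensor<?x8xf32>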

5 changes: 2 additions & 3 deletions mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
@@ -199,13 +199,12 @@ func.func @empty_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,

// -----

func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>) -> (index, index, index)
func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (index, index, index)
{
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%0 = tensor.expand_shape %arg0 [[0, 1], [2], [3, 4, 5]]
: tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
%0 = tensor.expand_shape %arg0 [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 5, 4, %sz0, 7] : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
%1 = tensor.dim %0, %c1 : tensor<2x3x5x4x?x7xf32>
%2 = tensor.dim %0, %c3 : tensor<2x3x5x4x?x7xf32>
%3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
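The CHECK lines for this test are truncated above; a sketch of the resolution it exercises (names hypothetical): the two static queries fold to constants, and the dynamic size is recovered from the source tensor with the same floordiv-by-28 map that appears in fold-empty-op.mlir further down (28 = 4 * 7, the static factors of the third reassociation group).

%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index   // tensor.dim %0, %c1 folds to the static 3
%c4 = arith.constant 4 : index   // tensor.dim %0, %c3 folds to the static 4
%d = tensor.dim %arg0, %c2 : tensor<6x5x?xf32>
%1 = affine.apply affine_map<()[s0] -> (s0 floordiv 28)>()[%d]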
28 changes: 14 additions & 14 deletions mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir
@@ -13,8 +13,8 @@ func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: ten
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: @matmul_split
// CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x4x64xf32>
// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<4x64x32xf32>
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [16, 4, 64] : tensor<16x256xf32> into tensor<16x4x64xf32>
// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [4, 64, 32] : tensor<256x32xf32> into tensor<4x64x32xf32>
// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<16x32x4xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
@@ -65,7 +65,7 @@ func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: ten
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
//CHECK-LABEL: @generic_split_1d
// CHECK-DAG: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<4x8xf32>
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] output_shape [4, 8] : tensor<32xf32> into tensor<4x8xf32>
// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
// CHECK: %[[G:.*]] = linalg.generic
@@ -119,8 +119,8 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @generic_split_3d
// CHECK-DAG: %[[ID:.*]] = arith.constant 0xFF800000 : f32
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32>
// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [4, 8, 2] : tensor<32x2xf32> into tensor<4x8x2xf32>
// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [5, 4, 8] : tensor<5x32xf32> into tensor<5x4x8xf32>
// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
@@ -177,8 +177,8 @@ func.func @generic_split_3d_ninf(%input: tensor<32x2xf32>, %input_2: tensor<5x32
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @generic_split_3d_ninf
// CHECK-DAG: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32>
// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32>
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [4, 8, 2] : tensor<32x2xf32> into tensor<4x8x2xf32>
// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [5, 4, 8] : tensor<5x32xf32> into tensor<5x4x8xf32>
// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
@@ -218,8 +218,8 @@ func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: ten
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: @matmul_split
// CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x64x4xf32>
// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<64x4x32xf32>
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [16, 64, 4] : tensor<16x256xf32> into tensor<16x64x4xf32>
// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [64, 4, 32] : tensor<256x32xf32> into tensor<64x4x32xf32>
// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<16x32x4xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32>
// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
@@ -270,7 +270,7 @@ func.func @generic_split_1d(%arg0: tensor<32xf32>, %arg1: tensor<f32>, %out: ten
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> ()>
//CHECK-LABEL: @generic_split_1d
// CHECK-DAG: %[[ID:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<8x4xf32>
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] output_shape [8, 4] : tensor<32xf32> into tensor<8x4xf32>
// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<4xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32>
// CHECK: %[[G:.*]] = linalg.generic
@@ -324,8 +324,8 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @generic_split_3d
// CHECK-DAG: %[[ID:.*]] = arith.constant 0x7F800000 : f32
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<8x4x2xf32>
// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x8x4xf32>
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [8, 4, 2] : tensor<32x2xf32> into tensor<8x4x2xf32>
// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [5, 8, 4] : tensor<5x32xf32> into tensor<5x8x4xf32>
// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
@@ -382,8 +382,8 @@ func.func @generic_split_3d(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-LABEL: func @generic_split_3d
// CHECK-DAG: %[[ID:.*]] = arith.constant 3.40282347E+38 : f32
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<8x4x2xf32>
// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x8x4xf32>
// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] output_shape [8, 4, 2] : tensor<32x2xf32> into tensor<8x4x2xf32>
// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] output_shape [5, 8, 4] : tensor<5x32xf32> into tensor<5x8x4xf32>
// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32>
// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32>
// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]}
4 changes: 3 additions & 1 deletion mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -1710,10 +1710,12 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<(d0) -> (d0)>
// CHECK-LABEL: @not_vectorizable
func.func @not_vectorizable(%arg0: tensor<1x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> tensor<1x128xf32> {
%c0 = arith.constant 0 : index
%0 = tensor.empty() : tensor<1x128xf32>
%1 = scf.for %arg5 = %arg2 to %arg1 step %arg3 iter_args(%arg6 = %0) -> (tensor<1x128xf32>) {
%extracted_slice = tensor.extract_slice %arg6[0, 0] [1, %arg1] [1, 1] : tensor<1x128xf32> to tensor<?xf32>
%expanded = tensor.expand_shape %extracted_slice [[0, 1]] : tensor<?xf32> into tensor<1x?xf32>
%sz0 = tensor.dim %extracted_slice, %c0 : tensor<?xf32>
%expanded = tensor.expand_shape %extracted_slice [[0, 1]] output_shape [1, %sz0] : tensor<?xf32> into tensor<1x?xf32>
%extracted_slice_0 = tensor.extract_slice %arg0[0, %arg3] [1, %arg2] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
%extracted_slice_1 = tensor.extract_slice %expanded[0, %arg3] [1, %arg2] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
%2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%extracted_slice_0 : tensor<?xf32>) outs(%extracted_slice_1 : tensor<?xf32>) {
35 changes: 17 additions & 18 deletions mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -13,7 +13,7 @@ func.func @collapse_shape_identity_fold(%arg0 : memref<5xi8>) -> memref<5xi8> {
// CHECK-LABEL: expand_shape_identity_fold
// CHECK-NEXT: return
func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8> {
%0 = memref.expand_shape %arg0 [[0], [1]] : memref<5x4xi8> into memref<5x4xi8>
%0 = memref.expand_shape %arg0 [[0], [1]] output_shape [5, 4] : memref<5x4xi8> into memref<5x4xi8>
return %0 : memref<5x4xi8>
}

@@ -23,7 +23,7 @@ func.func @expand_shape_identity_fold(%arg0 : memref<5x4xi8>) -> memref<5x4xi8>
// CHECK-NEXT: return
func.func @collapse_expand_rank0_cancel(%arg0 : memref<1x1xi8>) -> memref<1x1xi8> {
%0 = memref.collapse_shape %arg0 [] : memref<1x1xi8> into memref<i8>
%1 = memref.expand_shape %0 [] : memref<i8> into memref<1x1xi8>
%1 = memref.expand_shape %0 [] output_shape [1, 1] : memref<i8> into memref<1x1xi8>
return %1 : memref<1x1xi8>
}

@@ -455,9 +455,9 @@ func.func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
// -----

func.func @do_not_compose_collapse_of_expand_non_identity_layout(
%arg0: memref<?x?xf32, strided<[?, 1], offset: 0>>)
%arg0: memref<?x?xf32, strided<[?, 1], offset: 0>>, %sz0: index, %sz1: index)
-> memref<?xf32, strided<[?], offset: 0>> {
%1 = memref.expand_shape %arg0 [[0, 1], [2]] :
%1 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, %sz1] :
memref<?x?xf32, strided<[?, 1], offset: 0>> into
memref<?x4x?xf32, strided<[?, ?, 1], offset: 0>>
%2 = memref.collapse_shape %1 [[0, 1, 2]] :
@@ -471,35 +471,34 @@ func.func @do_not_compose_collapse_of_expand_non_identity_layout(

// -----

func.func @compose_expand_of_expand(%arg0 : memref<?x?xf32>)
func.func @compose_expand_of_expand(%arg0 : memref<?x?xf32>, %sz0: index, %sz1: index, %sz2: index, %sz3: index)
-> memref<?x6x4x5x?xf32> {
%0 = memref.expand_shape %arg0 [[0, 1], [2]]
%0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, %sz1]
: memref<?x?xf32> into memref<?x4x?xf32>
%1 = memref.expand_shape %0 [[0, 1], [2], [3, 4]]
: memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
%1 = memref.expand_shape %0 [[0, 1], [2], [3, 4]] output_shape [%sz2, 6, 4, 5, %sz3] : memref<?x4x?xf32> into memref<?x6x4x5x?xf32>
return %1 : memref<?x6x4x5x?xf32>
}
// CHECK-LABEL: func @compose_expand_of_expand
// CHECK: memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
// CHECK: memref.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] output_shape [%{{.*}}, 6, 4, 5, %{{.*}}]
// CHECK-NOT: memref.expand_shape

// -----

func.func @compose_expand_of_expand_of_zero_dim(%arg0 : memref<f32>)
-> memref<1x1x1xf32> {
%0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1xf32>
%1 = memref.expand_shape %0 [[0, 1, 2]]
%0 = memref.expand_shape %arg0 [] output_shape [1] : memref<f32> into memref<1xf32>
%1 = memref.expand_shape %0 [[0, 1, 2]] output_shape [1, 1, 1]
: memref<1xf32> into memref<1x1x1xf32>
return %1 : memref<1x1x1xf32>
}
// CHECK-LABEL: func @compose_expand_of_expand_of_zero_dim
// CHECK: memref.expand_shape %{{.*}} []
// CHECK: memref.expand_shape %{{.*}} [] output_shape [1, 1, 1]
// CHECK-SAME: memref<f32> into memref<1x1x1xf32>

// -----

func.func @fold_collapse_of_expand(%arg0 : memref<12x4xf32>) -> memref<12x4xf32> {
%0 = memref.expand_shape %arg0 [[0, 1], [2]]
%0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 4, 4]
: memref<12x4xf32> into memref<3x4x4xf32>
%1 = memref.collapse_shape %0 [[0, 1], [2]]
: memref<3x4x4xf32> into memref<12x4xf32>
@@ -510,9 +509,9 @@ func.func @fold_collapse_of_expand(%arg0 : memref<12x4xf32>) -> memref<12x4xf32>

// -----

func.func @fold_collapse_collapse_of_expand(%arg0 : memref<?x?xf32>)
func.func @fold_collapse_collapse_of_expand(%arg0 : memref<?x?xf32>, %sz0: index, %sz1: index)
-> memref<?x?xf32> {
%0 = memref.expand_shape %arg0 [[0, 1], [2]]
%0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, %sz1]
: memref<?x?xf32> into memref<?x4x?xf32>
%1 = memref.collapse_shape %0 [[0, 1], [2]]
: memref<?x4x?xf32> into memref<?x?xf32>
@@ -525,7 +524,7 @@ func.func @fold_collapse_collapse_of_expand(%arg0 : memref<?x?xf32>)

func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32> {
%0 = memref.cast %arg0 : memref<?x?xf32> to memref<8x4xf32>
%1 = memref.expand_shape %0 [[0, 1], [2]]
%1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [2, 4, 4]
: memref<8x4xf32> into memref<2x4x4xf32>
return %1 : memref<2x4x4xf32>
}
@@ -981,10 +980,10 @@ func.func @memref_realloc_dead(%src : memref<2xf32>, %v : f32) -> memref<2xf32>{
// CHECK-SAME: %[[m:.*]]: memref<?xf32, strided<[1]>, 3>
// CHECK: %[[casted:.*]] = memref.cast %[[m]] : memref<?xf32, strided<[1]>, 3> to memref<?xf32, 3
// CHECK: return %[[casted]]
func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>)
func.func @collapse_expand_fold_to_cast(%m: memref<?xf32, strided<[1]>, 3>, %sz0: index)
-> (memref<?xf32, 3>)
{
%0 = memref.expand_shape %m [[0, 1]]
%0 = memref.expand_shape %m [[0, 1]] output_shape [1, %sz0]
: memref<?xf32, strided<[1]>, 3> into memref<1x?xf32, 3>
%1 = memref.collapse_shape %0 [[0, 1]]
: memref<1x?xf32, 3> into memref<?xf32, 3>
16 changes: 9 additions & 7 deletions mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -421,10 +421,11 @@ func.func @simplify_expand_shape(
%base: memref<?x?xf32, strided<[?,?], offset:?>>,
%offset0: index, %offset1: index, %offset2: index,
%size0: index, %size1: index, %size2: index,
%stride0: index, %stride1: index, %stride2: index)
%stride0: index, %stride1: index, %stride2: index,
%sz0: index, %sz1: index)
-> memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>> {

%subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] :
%subview = memref.expand_shape %base [[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] :
memref<?x?xf32, strided<[?,?], offset: ?>> into
memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>

@@ -491,7 +492,7 @@ func.func @extract_strided_metadata_of_expand_shape_all_static(
index, index, index, index, index,
index, index, index, index, index) {

%expand_shape = memref.expand_shape %arg[[0, 1, 2], [3, 4]] :
%expand_shape = memref.expand_shape %arg[[0, 1, 2], [3, 4]] output_shape [3, 5, 2, 2, 2] :
memref<30x4xi16> into memref<3x5x2x2x2xi16>

%base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape :
@@ -595,12 +596,13 @@ func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
%base: memref<?x?xf32, strided<[?,?], offset:?>>,
%offset0: index, %offset1: index, %offset2: index,
%size0: index, %size1: index, %size2: index,
%stride0: index, %stride1: index, %stride2: index)
%stride0: index, %stride1: index, %stride2: index,
%sz0: index, %sz1: index)
-> (memref<f32>, index,
index, index, index, index, index, index, index, index,
index, index, index, index, index, index, index, index) {

%subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] :
%subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] output_shape [%sz0, 7, 8, 9, 10, 2, %sz1, 3] :
memref<?x?xf32, strided<[?,?], offset: ?>> into
memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>

@@ -643,7 +645,7 @@ func.func @extract_strided_metadata_of_expand_shape_all_static_0_rank(
index, index, index, index, index,
index, index, index, index, index) {

%expand_shape = memref.expand_shape %arg[] :
%expand_shape = memref.expand_shape %arg[] output_shape [1, 1, 1, 1, 1] :
memref<i16, strided<[], offset: ?>> into memref<1x1x1x1x1xi16, strided<[1,1,1,1,1], offset: ?>>

%base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape :
@@ -1513,4 +1515,4 @@ func.func @zero_sized_memred(%arg0: f32) -> (memref<f16, 3>, index,index,index)
%sizes, %strides :
memref<f16,3>, index,
index, index
}
}
22 changes: 12 additions & 10 deletions mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -412,7 +412,7 @@ func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x3
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
%0 = memref.expand_shape %arg0 [[0, 1], [2]] : memref<12x32xf32> into memref<2x6x32xf32>
%0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 6, 32] : memref<12x32xf32> into memref<2x6x32xf32>
%1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
return %1 : f32
}
@@ -458,7 +458,7 @@ func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
%0 = memref.expand_shape %arg0 [[0, 1, 2], [3]] : memref<12x32xf32> into memref<2x2x3x32xf32>
%0 = memref.expand_shape %arg0 [[0, 1, 2], [3]] output_shape [2, 2, 3, 32] : memref<12x32xf32> into memref<2x2x3x32xf32>
%1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
return %1 : f32
}
@@ -469,15 +469,17 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
// -----

// CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape
func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index) -> f32 {
// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[SZ0:.*]]: index)
func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
%c0 = arith.constant 0 : index
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
%0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
return %0 : f32
}
// CHECK: %[[EXPAND_SHAPE:.+]] = memref.expand_shape {{.+}} : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
// CHECK: %[[LOAD:.+]] = memref.load %[[EXPAND_SHAPE]]
// CHECK: return %[[LOAD]]
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[EXPAND_SHAPE:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [1, 16, %[[SZ0]], 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
// CHECK: %[[VAL_0:.*]] = memref.load %[[EXPAND_SHAPE]][%[[C0]], %[[ARG1]], %[[ARG2]], %[[C0]]] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
// CHECK: return %[[VAL_0]] : f32

// -----

@@ -486,7 +488,7 @@
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
%0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
%0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
affine.for %arg3 = 0 to 1 {
affine.for %arg4 = 0 to 1024 {
affine.for %arg5 = 0 to 1020 {
@@ -515,7 +517,7 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0:
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
%0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
%0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
affine.for %arg3 = 0 to 1 {
affine.for %arg4 = 0 to 1024 {
affine.for %arg5 = 0 to 1020 {
@@ -544,7 +546,7 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_a
// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
%0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
%0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
%cst = arith.constant 0 : index
affine.for %arg3 = 0 to 1 {
affine.for %arg4 = 0 to 1024 {
38 changes: 14 additions & 24 deletions mlir/test/Dialect/MemRef/invalid.mlir
@@ -392,17 +392,17 @@ func.func @copy_different_eltype(%arg0: memref<2xf32>, %arg1: memref<2xf16>) {

// -----

func.func @expand_shape(%arg0: memref<?x?xf32>) {
func.func @expand_shape(%arg0: memref<?x?xf32>, %sz0: index, %sz1: index) {
// expected-error @+1 {{invalid number of reassociation groups: found 1, expected 2}}
%0 = memref.expand_shape %arg0 [[0, 1]] : memref<?x?xf32> into memref<?x5x?xf32>
%0 = memref.expand_shape %arg0 [[0, 1]] output_shape [%sz0, 5, %sz1] : memref<?x?xf32> into memref<?x5x?xf32>
return
}

// -----

func.func @expand_shape(%arg0: memref<f32>) {
// expected-error @+1 {{rank 0 memrefs can only be extended/collapsed with/from ones}}
%0 = memref.expand_shape %arg0 [] : memref<f32> into memref<1x2xf32>
%0 = memref.expand_shape %arg0 [] output_shape [1, 2] : memref<f32> into memref<1x2xf32>
return
}

@@ -415,17 +415,17 @@ func.func @collapse_shape_out_of_bounds(%arg0: memref<?x?xf32>) {

// -----

func.func @expand_shape_out_of_bounds(%arg0: memref<?xf32>) {
func.func @expand_shape_out_of_bounds(%arg0: memref<?xf32>, %sz0: index) {
// expected-error @+1 {{op reassociation index 2 is out of bounds}}
%0 = memref.expand_shape %arg0 [[0, 1, 2]] : memref<?xf32> into memref<4x?xf32>
%0 = memref.expand_shape %arg0 [[0, 1, 2]] output_shape [4, %sz0] : memref<?xf32> into memref<4x?xf32>
}

// -----

func.func @expand_shape_invalid_result_layout(
%arg0: memref<30x20xf32, strided<[4000, 2], offset: 100>>) {
// expected-error @+1 {{expected expanded type to be 'memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>' but found 'memref<2x15x20xf32, strided<[5000, 4000, 2], offset: 100>>'}}
%0 = memref.expand_shape %arg0 [[0, 1], [2]] :
%0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 15, 20] :
memref<30x20xf32, strided<[4000, 2], offset: 100>>
into memref<2x15x20xf32, strided<[5000, 4000, 2], offset: 100>>
}
@@ -462,7 +462,7 @@ func.func @collapse_shape_invalid_reassociation_expansion(%arg0: memref<?xf32>)
// like this. Verify that a sensible error is emitted in this case.
func.func @expand_shape_invalid_reassociation(%arg0: memref<2x3x1xf32>) {
// expected-error @+1 {{'memref.expand_shape' op has source rank 3 and result rank 2. This is not an expansion (3 > 2)}}
%0 = memref.expand_shape %arg0 [[0], [1], [1]] :
%0 = memref.expand_shape %arg0 [[0], [1], [1]] output_shape [2, 3] :
memref<2x3x1xf32> into memref<2x3xf32>
}

@@ -495,20 +495,10 @@ func.func @collapse_shape_wrong_collapsed_type(%arg0: memref<?x?x?xf32>) {

// -----

func.func @expand_shape_illegal_dynamic_memref
(%arg0: memref<?x?x?xf32>) -> memref<?x?x?x4x?xf32> {
// expected-error @+1 {{at most one dimension in a reassociation group may be dynamic}}
%0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]]
: memref<?x?x?xf32> into memref<?x?x?x4x?xf32>
return %0 : memref<?x?x?x4x?xf32>
}

// -----

func.func @expand_shape_illegal_static_memref
(%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> {
// expected-error @+1 {{collapsed dim size (20) must equal reassociation group size (40)}}
%0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]]
%0 = memref.expand_shape %arg0 [[0], [1], [2, 3, 4]] output_shape [2, 3, 2, 4, 5]
: memref<2x3x20xf32> into memref<2x3x2x4x5xf32>
return %0 : memref<2x3x2x4x5xf32>
}
@@ -525,30 +515,30 @@ func.func @collapse_shape_illegal_static_memref

// -----

func.func @expand_shape_illegal_mixed_memref(%arg0 : memref<?x?xf32>)
func.func @expand_shape_illegal_mixed_memref(%arg0 : memref<?x?xf32>, %sz0: index)
-> memref<?x4x5xf32> {
// expected-error @+1 {{collapsed dim (1) must be dynamic if and only if reassociation group is dynamic}}
%0 = memref.expand_shape %arg0 [[0, 1], [2]]
%0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, 5]
: memref<?x?xf32> into memref<?x4x5xf32>
return %0 : memref<?x4x5xf32>
}

// -----

func.func @expand_shape_illegal_mixed_memref_2(%arg0 : memref<?x?xf32>)
func.func @expand_shape_illegal_mixed_memref_2(%arg0 : memref<?x?xf32>, %sz0: index)
-> memref<?x4x5xf32> {
// expected-error @+1 {{collapsed dim (1) must be dynamic if and only if reassociation group is dynamic}}
%0 = memref.expand_shape %arg0 [[0], [1, 2]]
%0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [%sz0, 4, 5]
: memref<?x?xf32> into memref<?x4x5xf32>
return %0 : memref<?x4x5xf32>
}

// -----

func.func @expand_shape_invalid_static_dim_size(%arg0 : memref<?x21xf32>)
func.func @expand_shape_invalid_static_dim_size(%arg0 : memref<?x21xf32>, %sz0: index)
-> memref<?x4x5xf32> {
// expected-error @+1 {{collapsed dim size (21) must equal reassociation group size (20)}}
%0 = memref.expand_shape %arg0 [[0], [1, 2]]
%0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [%sz0, 4, 5]
: memref<?x21xf32> into memref<?x4x5xf32>
return %0 : memref<?x4x5xf32>
}
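For contrast with the rejected mixed cases above, a mixed expansion is legal when the dynamic result dimension sits alone in the group fed by the dynamic source dimension (a sketch, not part of the test file):

func.func @legal_mixed_expand(%arg0 : memref<?x20xf32>, %sz0: index)
    -> memref<?x4x5xf32> {
  // Group [0] is dynamic on both sides; group [1, 2] is fully static (4 * 5 = 20).
  %0 = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [%sz0, 4, 5]
      : memref<?x20xf32> into memref<?x4x5xf32>
  return %0 : memref<?x4x5xf32>
}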
72 changes: 40 additions & 32 deletions mlir/test/Dialect/MemRef/ops.mlir
@@ -106,49 +106,49 @@ func.func @expand_collapse_shape_static(
%0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
memref<3x4x5xf32> into memref<12x5xf32>

// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [3, 4, 5]
// CHECK-SAME: memref<12x5xf32> into memref<3x4x5xf32>
%r0 = memref.expand_shape %0 [[0, 1], [2]] :
%r0 = memref.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 5] :
memref<12x5xf32> into memref<3x4x5xf32>

// CHECK: memref.collapse_shape {{.*}} {{\[}}[0], [1, 2]]
// CHECK-SAME: memref<3x4x5xf32> into memref<3x20xf32>
%1 = memref.collapse_shape %arg0 [[0], [1, 2]] :
memref<3x4x5xf32> into memref<3x20xf32>

// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] output_shape [3, 4, 5]
// CHECK-SAME: memref<3x20xf32> into memref<3x4x5xf32>
%r1 = memref.expand_shape %1 [[0], [1, 2]] :
%r1 = memref.expand_shape %1 [[0], [1, 2]] output_shape [3, 4, 5] :
memref<3x20xf32> into memref<3x4x5xf32>

// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
// CHECK-SAME: memref<3x4x5xf32> into memref<60xf32>
%2 = memref.collapse_shape %arg0 [[0, 1, 2]] :
memref<3x4x5xf32> into memref<60xf32>

// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1, 2]]
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1, 2]] output_shape [3, 4, 5]
// CHECK-SAME: memref<60xf32> into memref<3x4x5xf32>
%r2 = memref.expand_shape %2 [[0, 1, 2]] :
%r2 = memref.expand_shape %2 [[0, 1, 2]] output_shape [3, 4, 5] :
memref<60xf32> into memref<3x4x5xf32>

// CHECK: memref.expand_shape {{.*}} []
// CHECK: memref.expand_shape {{.*}} [] output_shape [1, 1]
// CHECK-SAME: memref<f32> into memref<1x1xf32>
%r5 = memref.expand_shape %arg5 [] :
%r5 = memref.expand_shape %arg5 [] output_shape [1, 1] :
memref<f32> into memref<1x1xf32>

// Reshapes with a custom layout map.
// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
%l0 = memref.expand_shape %arg3 [[0], [1, 2]] :
// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] output_shape [30, 4, 5]
%l0 = memref.expand_shape %arg3 [[0], [1, 2]] output_shape [30, 4, 5] :
memref<30x20xf32, strided<[4000, 2], offset: 100>>
into memref<30x4x5xf32, strided<[4000, 10, 2], offset: 100>>

// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
%l1 = memref.expand_shape %arg3 [[0, 1], [2]] :
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [2, 15, 20]
%l1 = memref.expand_shape %arg3 [[0, 1], [2]] output_shape [2, 15, 20] :
memref<30x20xf32, strided<[4000, 2], offset: 100>>
into memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>

// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
%r4 = memref.expand_shape %arg4 [[0], [1, 2]] :
// CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]] output_shape [1, 1, 5]
%r4 = memref.expand_shape %arg4 [[0], [1, 2]] output_shape [1, 1, 5] :
memref<1x5xf32, strided<[5, 1], offset: ?>> into
memref<1x1x5xf32, strided<[5, 5, 1], offset: ?>>

@@ -164,9 +164,9 @@ func.func @expand_collapse_shape_static(
memref<2049xi64, strided<[?], offset: ?>>

// Reshapes that expand and collapse back a contiguous buffer with some 1's.
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
// CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
%3 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] :
%3 = memref.expand_shape %arg0 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5] :
memref<3x4x5xf32> into memref<1x3x4x1x5xf32>

// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
@@ -176,15 +176,18 @@ func.func @expand_collapse_shape_static(

// Reshapes on tensors.
// CHECK: tensor.expand_shape {{.*}}: tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>
%t0 = tensor.expand_shape %arg1 [[0, 1], [2], [3, 4]] :
%t0 = tensor.expand_shape %arg1 [[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5] :
tensor<3x4x5xf32> into tensor<1x3x4x1x5xf32>

// CHECK: tensor.collapse_shape {{.*}}: tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>
%rt0 = tensor.collapse_shape %t0 [[0, 1], [2], [3, 4]] :
tensor<1x3x4x1x5xf32> into tensor<3x4x5xf32>

// CHECK: tensor.dim %arg2, {{.*}} : tensor<3x?x5xf32>
// CHECK: tensor.expand_shape {{.*}}: tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>
%t1 = tensor.expand_shape %arg2 [[0, 1], [2], [3, 4]] :
%c1 = arith.constant 1 : index
%sz1 = tensor.dim %arg2, %c1 : tensor<3x?x5xf32>
%t1 = tensor.expand_shape %arg2 [[0, 1], [2], [3, 4]] output_shape [1, 3, %sz1, 1, 5] :
tensor<3x?x5xf32> into tensor<1x3x?x1x5xf32>

// CHECK: tensor.collapse_shape {{.*}}: tensor<1x3x?x1x5xf32> into tensor<1x?x5xf32>
@@ -197,15 +200,18 @@ func.func @expand_collapse_shape_static(
func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
%arg1: memref<?x?x?xf32, strided<[?, ?, 1], offset: 0>>,
%arg2: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>,
%arg3: memref<?x42xf32, strided<[42, 1], offset: 0>>) {
%arg3: memref<?x42xf32, strided<[42, 1], offset: 0>>,
%arg4: index,
%arg5: index,
%arg6: index) {
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
%0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
memref<?x?x?xf32> into memref<?x?xf32>

// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [%arg4, 4, %arg5]
// CHECK-SAME: memref<?x?xf32> into memref<?x4x?xf32>
%r0 = memref.expand_shape %0 [[0, 1], [2]] :
%r0 = memref.expand_shape %0 [[0, 1], [2]] output_shape [%arg4, 4, %arg5] :
memref<?x?xf32> into memref<?x4x?xf32>

// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
@@ -214,9 +220,9 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
memref<?x?x?xf32, strided<[?, ?, 1], offset: 0>> into
memref<?x?xf32, strided<[?, 1], offset: 0>>

// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [%arg4, 4, %arg5]
// CHECK-SAME: memref<?x?xf32, strided<[?, 1]>> into memref<?x4x?xf32, strided<[?, ?, 1]>>
%r1 = memref.expand_shape %1 [[0, 1], [2]] :
%r1 = memref.expand_shape %1 [[0, 1], [2]] output_shape [%arg4, 4, %arg5] :
memref<?x?xf32, strided<[?, 1], offset: 0>> into
memref<?x4x?xf32, strided<[?, ?, 1], offset: 0>>

@@ -226,9 +232,9 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> into
memref<?x?xf32, strided<[?, 1], offset: ?>>

// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2]] output_shape [%arg4, 4, %arg5]
// CHECK-SAME: memref<?x?xf32, strided<[?, 1], offset: ?>> into memref<?x4x?xf32, strided<[?, ?, 1], offset: ?>>
%r2 = memref.expand_shape %2 [[0, 1], [2]] :
%r2 = memref.expand_shape %2 [[0, 1], [2]] output_shape [%arg4, 4, %arg5] :
memref<?x?xf32, strided<[?, 1], offset: ?>> into
memref<?x4x?xf32, strided<[?, ?, 1], offset: ?>>

@@ -238,22 +244,22 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
memref<?x42xf32, strided<[42, 1], offset: 0>> into
memref<?xf32, strided<[1]>>

// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1]]
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1]] output_shape [%arg6, 42]
// CHECK-SAME: memref<?xf32, strided<[1]>> into memref<?x42xf32>
%r3 = memref.expand_shape %3 [[0, 1]] :
%r3 = memref.expand_shape %3 [[0, 1]] output_shape [%arg6, 42] :
memref<?xf32, strided<[1]>> into memref<?x42xf32>
return
}

func.func @expand_collapse_shape_zero_dim(%arg0 : memref<1x1xf32>, %arg1 : memref<f32>)
-> (memref<f32>, memref<1x1xf32>) {
%0 = memref.collapse_shape %arg0 [] : memref<1x1xf32> into memref<f32>
%1 = memref.expand_shape %0 [] : memref<f32> into memref<1x1xf32>
%1 = memref.expand_shape %0 [] output_shape [1, 1] : memref<f32> into memref<1x1xf32>
return %0, %1 : memref<f32>, memref<1x1xf32>
}
// CHECK-LABEL: func @expand_collapse_shape_zero_dim
// CHECK: memref.collapse_shape %{{.*}} [] : memref<1x1xf32> into memref<f32>
// CHECK: memref.expand_shape %{{.*}} [] : memref<f32> into memref<1x1xf32>
// CHECK: memref.expand_shape %{{.*}} [] output_shape [1, 1] : memref<f32> into memref<1x1xf32>

func.func @collapse_shape_to_dynamic
(%arg0: memref<?x?x?x4x?xf32>) -> memref<?x?x?xf32> {
@@ -270,16 +276,18 @@ func.func @collapse_shape_to_dynamic
// CHECK-LABEL: func @expand_collapse_shape_transposed_layout
func.func @expand_collapse_shape_transposed_layout(
%m0: memref<?x?xf32, strided<[1, 10], offset: 0>>,
%m1: memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>>) {
%m1: memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>>,
%sz0: index,
%sz1: index) {

%r0 = memref.expand_shape %m0 [[0], [1, 2]] :
%r0 = memref.expand_shape %m0 [[0], [1, 2]] output_shape [%sz0, %sz1, 5] :
memref<?x?xf32, strided<[1, 10], offset: 0>> into
memref<?x?x5xf32, strided<[1, 50, 10], offset: 0>>
%rr0 = memref.collapse_shape %r0 [[0], [1, 2]] :
memref<?x?x5xf32, strided<[1, 50, 10], offset: 0>> into
memref<?x?xf32, strided<[1, 10], offset: 0>>

%r1 = memref.expand_shape %m1 [[0, 1], [2], [3, 4]] :
%r1 = memref.expand_shape %m1 [[0, 1], [2], [3, 4]] output_shape [2, 2, 5, 2, 3] :
memref<4x5x6xf32, strided<[1, ?, 1000], offset: 0>> into
memref<2x2x5x2x3xf32, strided<[2, 1, ?, 3000, 1000], offset: 0>>
%rr1 = memref.collapse_shape %r1 [[0, 1], [2], [3, 4]] :
5 changes: 3 additions & 2 deletions mlir/test/Dialect/MemRef/runtime-verification.mlir
@@ -2,13 +2,14 @@

// CHECK-LABEL: func @expand_shape(
// CHECK-SAME: %[[m:.*]]: memref<?xf32>
// CHECK-SAME: %[[sz0:.*]]: index
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index
// CHECK-DAG: %[[dim:.*]] = memref.dim %[[m]], %[[c0]]
// CHECK: %[[mod:.*]] = arith.remsi %[[dim]], %[[c5]]
// CHECK: %[[cmpi:.*]] = arith.cmpi eq, %[[mod]], %[[c0]]
// CHECK: cf.assert %[[cmpi]], "ERROR: Runtime op verification failed
func.func @expand_shape(%m: memref<?xf32>) -> memref<?x5xf32> {
%0 = memref.expand_shape %m [[0, 1]] : memref<?xf32> into memref<?x5xf32>
func.func @expand_shape(%m: memref<?xf32>, %sz0: index) -> memref<?x5xf32> {
%0 = memref.expand_shape %m [[0, 1]] output_shape [%sz0, 5] : memref<?xf32> into memref<?x5xf32>
return %0 : memref<?x5xf32>
}
12 changes: 6 additions & 6 deletions mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
@@ -12,7 +12,7 @@
//
// CHECK-ROUND-LABEL: func.func @sparse_expand(
// CHECK-ROUND-SAME: %[[A:.*]]: tensor<100xf64, #sparse{{[0-9]*}}>) -> tensor<10x10xf64, #sparse{{[0-9]*}}>
// CHECK-ROUND: %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor<100xf64, #sparse{{[0-9]*}}> into tensor<10x10xf64, #sparse{{[0-9]*}}>
// CHECK-ROUND: %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] output_shape [10, 10] : tensor<100xf64, #sparse{{[0-9]*}}> into tensor<10x10xf64, #sparse{{[0-9]*}}>
// CHECK-ROUND: return %[[E]] : tensor<10x10xf64, #sparse{{[0-9]*}}>
//
// CHECK-LABEL: func.func @sparse_expand(
Expand All @@ -39,7 +39,7 @@
// CHECK: return %[[NT1]] : tensor<10x10xf64, #sparse{{[0-9]*}}>
//
func.func @sparse_expand(%arg0: tensor<100xf64, #SparseVector>) -> tensor<10x10xf64, #SparseMatrix> {
%0 = tensor.expand_shape %arg0 [[0, 1]] :
%0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [10, 10] :
tensor<100xf64, #SparseVector> into tensor<10x10xf64, #SparseMatrix>
return %0 : tensor<10x10xf64, #SparseMatrix>
}
@@ -94,8 +94,8 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
// roundtrip:
//
// CHECK-ROUND-LABEL: func.func @dynamic_sparse_expand(
// CHECK-ROUND-SAME: %[[A:.*]]: tensor<?xf64, #sparse{{[0-9]*}}>) -> tensor<?x10xf64, #sparse{{[0-9]*}}>
// CHECK-ROUND: %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] : tensor<?xf64, #sparse{{[0-9]*}}> into tensor<?x10xf64, #sparse{{[0-9]*}}>
// CHECK-ROUND-SAME: %[[A:.*]]: tensor<?xf64, #sparse{{[0-9]*}}>, %[[SZ0:.*]]: index) -> tensor<?x10xf64, #sparse{{[0-9]*}}>
// CHECK-ROUND: %[[E:.*]] = tensor.expand_shape %[[A]] {{\[\[}}0, 1]] output_shape [%[[SZ0]], 10] : tensor<?xf64, #sparse{{[0-9]*}}> into tensor<?x10xf64, #sparse{{[0-9]*}}>
// CHECK-ROUND: return %[[E]] : tensor<?x10xf64, #sparse{{[0-9]*}}>
//
// CHECK-LABEL: func.func @dynamic_sparse_expand(
@@ -127,8 +127,8 @@ func.func @sparse_collapse(%arg0: tensor<10x10xf64, #SparseMatrix>) -> tensor<10
// CHECK-NOT: sparse_tensor.convert
// CHECK: return %[[NT1]] : tensor<?x10xf64, #sparse{{[0-9]*}}>
//
func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?x10xf64, #SparseMatrix> {
%0 = tensor.expand_shape %arg0 [[0, 1]] :
func.func @dynamic_sparse_expand(%arg0: tensor<?xf64, #SparseVector>, %sz0: index) -> tensor<?x10xf64, #SparseMatrix> {
%0 = tensor.expand_shape %arg0 [[0, 1]] output_shape [%sz0, 10] :
tensor<?xf64, #SparseVector> into tensor<?x10xf64, #SparseMatrix>
return %0 : tensor<?x10xf64, #SparseMatrix>
}
24 changes: 14 additions & 10 deletions mlir/test/Dialect/Tensor/bufferize.mlir
@@ -367,11 +367,14 @@ func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5x

// CHECK-LABEL: func @tensor.expand_shape(
// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>
func.func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?x10xf32> {
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>
// CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] [
// CHECK-SAME: [0, 1], [2]] : memref<?x10xf32> into memref<2x?x10xf32>
%0 = tensor.expand_shape %t1 [[0, 1], [2]]
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = memref.dim %[[m1]], %[[C0]] : memref<?x10xf32>
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_1:.*]] = arith.divui %[[DIM]], %[[C2]] : index
// CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[VAL_1]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
%0 = tensor.expand_shape %t1 [[0, 1], [2]] output_shape [2, %sz0, 10]
: tensor<?x10xf32> into tensor<2x?x10xf32>

// CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
@@ -384,14 +387,15 @@ func.func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
// CHECK-LABEL: func @tensor.expand_shape_of_slice(
// CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>
func.func @tensor.expand_shape_of_slice(
%t1: tensor<?x20xf32>, %o1: index, %s1: index) -> tensor<?x7x2x5xf32> {
%t1: tensor<?x20xf32>, %o1: index, %s1: index, %sz0: index) -> tensor<?x7x2x5xf32> {
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x20xf32>
// CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, strided<[20, 1], offset: ?>>
%0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
tensor<?x20xf32> to tensor<?x10xf32>
// CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [
// CHECK-SAME: [0, 1], [2, 3]] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
%1 = tensor.expand_shape %0 [[0, 1], [2, 3]] :
// CHECK: %[[C7:.*]] = arith.constant 7 : index
// CHECK: %[[VAL_1:.*]] = arith.divui %{{.*}}, %[[C7]] : index
// CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
%1 = tensor.expand_shape %0 [[0, 1], [2, 3]] output_shape [%sz0, 7, 2, 5] :
tensor<?x10xf32> into tensor<?x7x2x5xf32>
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
// CHECK: return %[[r]]
@@ -407,8 +411,8 @@ func.func @tensor.expand_shape_of_scalar_slice(
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?xf32>
// CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}] [1] [1] : memref<?xf32> to memref<f32, strided<[], offset: ?>>
%0 = tensor.extract_slice %t1[%o1][1][1] : tensor<?xf32> to tensor<f32>
// CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] : memref<f32, strided{{.*}}> into memref<1xf32, strided<[1], offset: ?>>
%1 = tensor.expand_shape %0 [] : tensor<f32> into tensor<1xf32>
// CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] output_shape [1] : memref<f32, strided{{.*}}> into memref<1xf32, strided<[1], offset: ?>>
%1 = tensor.expand_shape %0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
// CHECK: return %[[r]]
return %1 : tensor<1xf32>
112 changes: 53 additions & 59 deletions mlir/test/Dialect/Tensor/canonicalize.mlir

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -13,10 +13,9 @@ module attributes {transform.with_named_sequence} {
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>

func.func @empty_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
%0 = tensor.empty(%arg0) : tensor<6x5x?xf32>
%1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]]
: tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
%1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 5, 4, %sz0, 7] : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
return %1 : tensor<2x3x5x4x?x7xf32>
}
// CHECK-LABEL: func @empty_reshape_expansion
6 changes: 4 additions & 2 deletions mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
@@ -11,9 +11,11 @@ func.func @expand_shape_of_rank_reducing_extract(
{
%0 = tensor.extract_slice %t[0, 0, 0, 0][%idx, 1, 1, 5][1, 1, 1, 1]
: tensor<?x?x?x?xf32> to tensor<?x1x5xf32>
%1 = tensor.expand_shape %0 [[0], [1, 2], [3]]
%c0 = arith.constant 0 : index
%sz0 = tensor.dim %0, %c0 : tensor<?x1x5xf32>
%1 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%sz0, 1, 1, 5]
: tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
%2 = tensor.expand_shape %0 [[0, 1], [2], [3]]
%2 = tensor.expand_shape %0 [[0, 1], [2], [3]] output_shape [%sz0, 1, 1, 5]
: tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
return %1, %2 : tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>
}
21 changes: 5 additions & 16 deletions mlir/test/Dialect/Tensor/invalid.mlir
@@ -273,21 +273,10 @@ func.func @insert_slice_wrong_dynamic_type(%t1: tensor<?x4x4xf32>, %t2: tensor<8

// -----

func.func @illegal_expanding_reshape_dynamic_tensor
(%arg0: tensor<?x?x?xf32>) -> tensor<?x?x?x4x?xf32> {
// expected-error @+1 {{invalid to have a single dimension (2) expanded into multiple dynamic dims (2,4)}}
%0 = tensor.expand_shape %arg0 [[0], [1], [2, 3, 4]]
: tensor<?x?x?xf32> into tensor<?x?x?x4x?xf32>
return %0 : tensor<?x?x?x4x?xf32>
}

// -----


func.func @illegal_expanding_reshape_static_tensor
(%arg0: tensor<2x3x20xf32>) -> tensor<2x3x2x4x5xf32> {
// expected-error @+1 {{expected dimension 2 of collapsed type to be static value of 40}}
%0 = tensor.expand_shape %arg0 [[0], [1], [2, 3, 4]]
%0 = tensor.expand_shape %arg0 [[0], [1], [2, 3, 4]] output_shape [2, 3, 2, 4, 5]
: tensor<2x3x20xf32> into tensor<2x3x2x4x5xf32>
return %0 : tensor<2x3x2x4x5xf32>
}
@@ -304,20 +293,20 @@ func.func @illegal_collapsing_reshape_static_tensor

// -----

func.func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor<?x?xf32>)
func.func @illegal_expanding_reshape_mixed_tensor(%arg0 : tensor<?x?xf32>, %sz0: index)
-> tensor<?x4x5xf32> {
// expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 5}}
%0 = tensor.expand_shape %arg0 [[0, 1], [2]]
%0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%sz0, 4, 5]
: tensor<?x?xf32> into tensor<?x4x5xf32>
return %0 : tensor<?x4x5xf32>
}

// -----

func.func @illegal_expanding_reshape_mixed_tensor_2(%arg0 : tensor<?x?xf32>)
func.func @illegal_expanding_reshape_mixed_tensor_2(%arg0 : tensor<?x?xf32>, %sz0: index)
-> tensor<?x4x5xf32> {
// expected-error @+1 {{expected dimension 1 of collapsed type to be static value of 20}}
%0 = tensor.expand_shape %arg0 [[0], [1, 2]]
%0 = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [%sz0, 4, 5]
: tensor<?x?xf32> into tensor<?x4x5xf32>
return %0 : tensor<?x4x5xf32>
}
18 changes: 16 additions & 2 deletions mlir/test/Dialect/Tensor/ops.mlir
@@ -194,12 +194,26 @@ func.func @insert_slice(
func.func @tensor_reshape_zero_dim(%arg0 : tensor<1x1xf32>, %arg1 : tensor<f32>)
-> (tensor<f32>, tensor<1x1xf32>) {
%0 = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
%1 = tensor.expand_shape %0 [] : tensor<f32> into tensor<1x1xf32>
%1 = tensor.expand_shape %0 [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
return %0, %1 : tensor<f32>, tensor<1x1xf32>
}
// CHECK-LABEL: func @tensor_reshape_zero_dim
// CHECK: tensor.collapse_shape %{{.*}} [] : tensor<1x1xf32> into tensor<f32>
// CHECK: tensor.expand_shape %{{.*}} [] : tensor<f32> into tensor<1x1xf32>
// CHECK: tensor.expand_shape %{{.*}} [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>

// -----

func.func @tensor_expand_shape_dynamic_dim(%arg0 : tensor<?x?xf32>, %sz0 : index, %sz1 : index, %sz2 : index)
-> (tensor<5x?x?x?xf32>) {
%1 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [5, %sz0, %sz1, %sz2] : tensor<?x?xf32> into tensor<5x?x?x?xf32>
return %1 : tensor<5x?x?x?xf32>
}

// CHECK-LABEL: func.func @tensor_expand_shape_dynamic_dim(%arg0: tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> tensor<5x?x?x?xf32> {
// CHECK: %expanded = tensor.expand_shape %arg0 {{\[\[}}0, 1], [2, 3{{\]\]}} output_shape [5, %arg1, %arg2, %arg3] : tensor<?x?xf32> into tensor<5x?x?x?xf32>
// CHECK: return %expanded : tensor<5x?x?x?xf32>
// CHECK: }


// -----

14 changes: 7 additions & 7 deletions mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -2,7 +2,7 @@

// CHECK-LABEL: func.func @single_dim_packing(
// CHECK-SAME: %[[ARG0:.+]]: tensor<256xf32>)
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<256xf32> into tensor<8x32xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [8, 32] : tensor<256xf32> into tensor<8x32xf32>
// CHECK: return %[[EXPANDED]] : tensor<8x32xf32>
func.func @single_dim_packing(%arg0: tensor<256xf32>) -> tensor<8x32xf32> {
%empty = tensor.empty() : tensor<8x32xf32>
@@ -27,7 +27,7 @@ func.func @single_dim_packing_with_padding(%arg0: tensor<255xf32>) -> tensor<8x3

// CHECK-LABEL: func.func @single_last_inner_dim_packing(
// CHECK-SAME: %[[ARG0:.+]]: tensor<5x256xf32>)
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x256xf32> into tensor<5x8x32xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [5, 8, 32] : tensor<5x256xf32> into tensor<5x8x32xf32>
// CHECK: return %[[EXPANDED]] : tensor<5x8x32xf32>
func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
%empty = tensor.empty() : tensor<5x8x32xf32>
@@ -39,7 +39,7 @@ func.func @single_last_inner_dim_packing(%arg0: tensor<5x256xf32>) -> tensor<5x8

// CHECK-LABEL: func.func @pack_1d_with_outer_dims_perm(
// CHECK-SAME: %[[ARG0:.+]]: tensor<64xf32>)
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<64xf32> into tensor<2x32xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [2, 32] : tensor<64xf32> into tensor<2x32xf32>
// CHECK: return %[[EXPANDED]] : tensor<2x32xf32>
func.func @pack_1d_with_outer_dims_perm(%arg0: tensor<64xf32>) -> tensor<2x32xf32> {
%empty = tensor.empty() : tensor<2x32xf32>
@@ -51,7 +51,7 @@ func.func @pack_1d_with_outer_dims_perm(%arg0: tensor<64xf32>) -> tensor<2x32xf3

// CHECK-LABEL: func.func @single_last_inner_dim_packing_with_identity_outer_dims_perm(
// CHECK-SAME: %[[ARG0:.+]]: tensor<5x256xf32>)
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x256xf32> into tensor<5x8x32xf32>
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] output_shape [5, 8, 32] : tensor<5x256xf32> into tensor<5x8x32xf32>
// CHECK: return %[[EXPANDED]] : tensor<5x8x32xf32>
func.func @single_last_inner_dim_packing_with_identity_outer_dims_perm(%arg0: tensor<5x256xf32>) -> tensor<5x8x32xf32> {
%empty = tensor.empty() : tensor<5x8x32xf32>
@@ -85,7 +85,7 @@ func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x

// CHECK-LABEL: func.func @pack_1x32_to_1x32x1x1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]] output_shape [1, 32, 1, 1]
// CHECK: return %[[EXPANDED]]
func.func @pack_1x32_to_1x32x1x1(%arg0 : tensor<1x32xf32>) -> tensor<1x32x1x1xf32> {
%empty = tensor.empty() : tensor<1x32x1x1xf32>
@@ -98,7 +98,7 @@ func.func @pack_1x32_to_1x32x1x1(%arg0 : tensor<1x32xf32>) -> tensor<1x32x1x1xf3

// CHECK-LABEL: func.func @pack_1x32_to_1x16x1x2
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2, 3]] output_shape [1, 16, 1, 2]
// CHECK: return %[[EXPANDED]]
func.func @pack_1x32_to_1x16x1x2(%arg0 : tensor<1x32xf32>) -> tensor<1x16x1x2xf32> {
%empty = tensor.empty() : tensor<1x16x1x2xf32>
@@ -111,7 +111,7 @@ func.func @pack_1x32_to_1x16x1x2(%arg0 : tensor<1x32xf32>) -> tensor<1x16x1x2xf3

// CHECK-LABEL: func.func @pack_32x1_to_16x1x2x1
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] output_shape [1, 16, 2, 1]
// CHECK: return %[[EXPANDED]]
func.func @pack_32x1_to_16x1x2x1(%arg0 : tensor<32x1xf32>) -> tensor<1x16x2x1xf32> {
%empty = tensor.empty() : tensor<1x16x2x1xf32>
1 change: 1 addition & 0 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3831,6 +3831,7 @@ cc_library(
includes = ["include"],
deps = [
":DialectUtilsIncGen",
":ArithDialect",
":IR",
":Support",
"//llvm:Support",