Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s

// CHECK-LABEL: func @nop_shape_cast
Expand Down Expand Up @@ -124,9 +123,35 @@ func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
return %s : vector<2x1x3xf32>
}

// CHECK-LABEL: func.func @shape_cast_0d1d(
// CHECK-SAME: %[[VAL_0:.*]]: vector<f32>) -> vector<1xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
// CHECK: %[[VAL_2:.*]] = vector.extractelement %[[VAL_0]][] : vector<f32>
// CHECK: %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_1]] [0] : f32 into vector<1xf32>
// CHECK: return %[[VAL_3]] : vector<1xf32>
// CHECK: }

func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
%s = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
return %s : vector<1xf32>
}

// CHECK-LABEL: func.func @shape_cast_1d0d(
// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>) -> vector<f32> {
// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<f32>
// CHECK: %[[VAL_2:.*]] = vector.extract %[[VAL_0]][0] : vector<1xf32>
// CHECK: %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][] : vector<f32>
// CHECK: return %[[VAL_3]] : vector<f32>
// CHECK: }

func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector<f32> {
%s = vector.shape_cast %arg0 : vector<1xf32> to vector<f32>
return %s : vector<f32>
}

transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
%f = transform.structured.match ops{["func.func"]} in %module_op
%f = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op

%f2 = transform.vector.lower_shape_cast %f
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ func.func @transfer_read_rank_reducing(
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_read %[[SUBVIEW]]

transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
transform.vector.apply_rank_reducing_subview_patterns %module_op
: (!pdl.operation) -> !pdl.operation
}

// -----

func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
Expand All @@ -28,6 +36,97 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]

transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
transform.vector.apply_rank_reducing_subview_patterns %module_op
: (!pdl.operation) -> !pdl.operation
}

// -----

func.func @transfer_read_and_vector_rank_reducing(
%arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst :
memref<1x1x3x2x1xf32>, vector<3x2x1xf32>
return %v : vector<3x2x1xf32>
}

// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2x1xf32>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1]
// CHECK-SAME: memref<1x1x3x2x1xf32> to memref<3x2xf32>
// CHECK: vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : memref<3x2xf32>, vector<3x2xf32>

transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
transform.vector.apply_rank_reducing_subview_patterns %module_op
: (!pdl.operation) -> !pdl.operation
}

// -----

func.func @transfer_write_and_vector_rank_reducing(
%arg : memref<1x1x3x2x1xf32>,
%vec : vector<3x2x1xf32>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0, %c0] :
vector<3x2x1xf32>, memref<1x1x3x2x1xf32>
return
}

// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2x1xf32>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1]
// CHECK-SAME: memref<1x1x3x2x1xf32> to memref<3x2xf32>
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : vector<3x2xf32>, memref<3x2xf32>

transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
transform.vector.apply_rank_reducing_subview_patterns %module_op
: (!transform.any_op) -> !transform.any_op
}

// -----

func.func @transfer_read_and_vector_rank_reducing_to_0d(
%arg : memref<1x1x1x1x1xf32>) -> vector<1x1x1xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
%v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst :
memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
return %v : vector<1x1x1xf32>
}

// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d
// CHECK-SAME: %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32>
// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<f32>, vector<f32>
// CHECK: vector.shape_cast %[[READ]] : vector<f32> to vector<1x1x1xf32>

transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
transform.vector.apply_rank_reducing_subview_patterns %module_op
: (!pdl.operation) -> !pdl.operation
}

// -----

func.func @transfer_write_and_vector_rank_reducing_to_0d(
%arg : memref<1x1x1x1x1xf32>,
%vec : vector<1x1x1xf32>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0, %c0] :
vector<1x1x1xf32>, memref<1x1x1x1x1xf32>
return
}

// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d
// CHECK-SAME: %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>, %[[VECTOR:.+]]: vector<1x1x1xf32>
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32>
// CHECK: %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32>
// CHECK: vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>

transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
Expand Down