1,235 changes: 0 additions & 1,235 deletions mlir/test/Dialect/Vector/vector-contract-transforms.mlir

This file was deleted.

101 changes: 101 additions & 0 deletions mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir
@@ -0,0 +1,101 @@
// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s

// CHECK-LABEL: func @genbool_1d
// CHECK: %[[T0:.*]] = arith.constant dense<[true, true, true, true, false, false, false, false]> : vector<8xi1>
// CHECK: return %[[T0]] : vector<8xi1>

func.func @genbool_1d() -> vector<8xi1> {
%0 = vector.constant_mask [4] : vector<8xi1>
return %0 : vector<8xi1>
}

// CHECK-LABEL: func @genbool_2d
// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<4x4xi1>
// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1>
// CHECK: %[[T1:.*]] = vector.insert %[[C1]], %[[T0]] [1] : vector<4xi1> into vector<4x4xi1>
// CHECK: return %[[T1]] : vector<4x4xi1>

func.func @genbool_2d() -> vector<4x4xi1> {
%v = vector.constant_mask [2, 2] : vector<4x4xi1>
return %v: vector<4x4xi1>
}

// CHECK-LABEL: func @genbool_3d
// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<3x4xi1>
// CHECK: %[[C3:.*]] = arith.constant dense<false> : vector<2x3x4xi1>
// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1>
// CHECK: return %[[T1]] : vector<2x3x4xi1>

func.func @genbool_3d() -> vector<2x3x4xi1> {
%v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1>
return %v: vector<2x3x4xi1>
}

// CHECK-LABEL: func @genbool_var_1d(
// CHECK-SAME: %[[A:.*]]: index)
// CHECK: %[[T0:.*]] = vector.create_mask %[[A]] : vector<3xi1>
// CHECK: return %[[T0]] : vector<3xi1>

func.func @genbool_var_1d(%arg0: index) -> vector<3xi1> {
%0 = vector.create_mask %arg0 : vector<3xi1>
return %0 : vector<3xi1>
}

// CHECK-LABEL: func @genbool_var_2d(
// CHECK-SAME: %[[A:.*0]]: index,
// CHECK-SAME: %[[B:.*1]]: index)
// CHECK: %[[C1:.*]] = arith.constant dense<false> : vector<3xi1>
// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<2x3xi1>
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[T0:.*]] = vector.create_mask %[[B]] : vector<3xi1>
// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index
// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<3xi1>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<3xi1> into vector<2x3xi1>
// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index
// CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T0]], %[[C1]] : vector<3xi1>
// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : vector<3xi1> into vector<2x3xi1>
// CHECK: return %[[T6]] : vector<2x3xi1>

func.func @genbool_var_2d(%arg0: index, %arg1: index) -> vector<2x3xi1> {
%0 = vector.create_mask %arg0, %arg1 : vector<2x3xi1>
return %0 : vector<2x3xi1>
}

// CHECK-LABEL: func @genbool_var_3d(
// CHECK-SAME: %[[A:.*0]]: index,
// CHECK-SAME: %[[B:.*1]]: index,
// CHECK-SAME: %[[C:.*2]]: index)
// CHECK-DAG: %[[C1:.*]] = arith.constant dense<false> : vector<7xi1>
// CHECK-DAG: %[[C2:.*]] = arith.constant dense<false> : vector<1x7xi1>
// CHECK-DAG: %[[C3:.*]] = arith.constant dense<false> : vector<2x1x7xi1>
// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[T0:.*]] = vector.create_mask %[[C]] : vector<7xi1>
// CHECK: %[[T1:.*]] = arith.cmpi sgt, %[[B]], %[[c0]] : index
// CHECK: %[[T2:.*]] = arith.select %[[T1]], %[[T0]], %[[C1]] : vector<7xi1>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C2]] [0] : vector<7xi1> into vector<1x7xi1>
// CHECK: %[[T4:.*]] = arith.cmpi sgt, %[[A]], %[[c0]] : index
// CHECK: %[[T5:.*]] = arith.select %[[T4]], %[[T3]], %[[C2]] : vector<1x7xi1>
// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[C3]] [0] : vector<1x7xi1> into vector<2x1x7xi1>
// CHECK: %[[T7:.*]] = arith.cmpi sgt, %[[A]], %[[c1]] : index
// CHECK: %[[T8:.*]] = arith.select %[[T7]], %[[T3]], %[[C2]] : vector<1x7xi1>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [1] : vector<1x7xi1> into vector<2x1x7xi1>
// CHECK: return %[[T9]] : vector<2x1x7xi1>

func.func @genbool_var_3d(%arg0: index, %arg1: index, %arg2: index) -> vector<2x1x7xi1> {
%0 = vector.create_mask %arg0, %arg1, %arg2 : vector<2x1x7xi1>
return %0 : vector<2x1x7xi1>
}

transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%f = transform.structured.match ops{["func.func"]} in %module_op
: (!pdl.operation) -> !pdl.operation

transform.vector.lower_mask %f
: (!pdl.operation) -> !pdl.operation
}
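
The `transform.vector.lower_mask` op drives the same mask-lowering rewrites that the removed test passes reached through `populateVectorMaskOpLoweringPatterns` (see the deleted TestVectorTransforms.cpp hunks below). As a rough C++ sketch of the equivalent direct pattern application — header paths and namespace qualification are assumptions based on the deleted test-pass code, not a verified API:

// Sketch only: apply the mask-lowering patterns the transform op drives.
// Headers/namespaces are best guesses; adapt to the tree you build against.
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult lowerMasks(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::vector::populateVectorMaskOpLoweringPatterns(patterns);
  // Rewrite greedily to a fixpoint, as the deleted test pass did.
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
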
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns -split-input-file | FileCheck %s
// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s

func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
%0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
@@ -264,3 +264,10 @@ func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc
// CHECK-LABEL: func @vector_multi_reduction_parallel_middle
// CHECK-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
// CHECK: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>

transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
transform.vector.lower_multi_reduction %module_op
lowering_strategy = "innerreduction"
: (!pdl.operation) -> !pdl.operation
}
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns="use-outer-reductions" | FileCheck %s
// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s

func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
%0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
@@ -187,3 +187,10 @@ func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) -
}
// CHECK-LABEL: func @vector_multi_reduction_to_scalar
// CHECK: return %{{.+}}

transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
transform.vector.lower_multi_reduction %module_op
lowering_strategy = "innerparallel"
: (!pdl.operation) -> !pdl.operation
}
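
Both multi-reduction test files exercise the same rewrite under a different `lowering_strategy`, mirroring the `VectorMultiReductionLowering` enum the deleted `-test-vector-multi-reduction-lowering-patterns` pass switched on. A hedged sketch of the direct form (driver boilerplate as in the mask-lowering sketch above; `funcOp` is a hypothetical `func::FuncOp` handle):

// "innerreduction" <-> InnerReduction, "innerparallel" <-> InnerParallel.
mlir::RewritePatternSet patterns(funcOp.getContext());
mlir::vector::populateVectorMultiReductionLoweringPatterns(
    patterns, mlir::vector::VectorMultiReductionLowering::InnerParallel);
(void)mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
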
148 changes: 148 additions & 0 deletions mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir
@@ -0,0 +1,148 @@
// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s

// CHECK-LABEL: func @outerproduct_noacc
// CHECK-SAME: %[[A:.*0]]: vector<2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
// CHECK: %[[T2:.*]] = arith.mulf %[[T1]], %[[B]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xf32>
// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xf32>
// CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK: return %[[T7]] : vector<2x3xf32>

func.func @outerproduct_noacc(%arg0: vector<2xf32>,
%arg1: vector<3xf32>) -> vector<2x3xf32> {
%0 = vector.outerproduct %arg0, %arg1 : vector<2xf32>, vector<3xf32>
return %0: vector<2x3xf32>
}

// CHECK-LABEL: func @outerproduct_acc
// CHECK-SAME: %[[A:.*0]]: vector<2xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xf32>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xf32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xf32>
// CHECK: %[[T3:.*]] = vector.fma %[[T1]], %[[B]], %[[T2]] : vector<3xf32>
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0] : vector<3xf32> into vector<2x3xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2xf32>
// CHECK: %[[T6:.*]] = vector.splat %[[T5]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[C]][1] : vector<2x3xf32>
// CHECK: %[[T8:.*]] = vector.fma %[[T6]], %[[B]], %[[T7]] : vector<3xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : vector<3xf32> into vector<2x3xf32>
// CHECK: return %[[T9]] : vector<2x3xf32>

func.func @outerproduct_acc(%arg0: vector<2xf32>,
%arg1: vector<3xf32>,
%arg2: vector<2x3xf32>) -> vector<2x3xf32> {
%0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32>
return %0: vector<2x3xf32>
}

// CHECK-LABEL: func @outerproduct_noacc_int
// CHECK-SAME: %[[A:.*0]]: vector<2xi32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
// CHECK: %[[T2:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2xi32>
// CHECK: %[[T5:.*]] = vector.splat %[[T4]] : vector<3xi32>
// CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xi32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T3]] [1] : vector<3xi32> into vector<2x3xi32>
// CHECK: return %[[T7]] : vector<2x3xi32>
func.func @outerproduct_noacc_int(%arg0: vector<2xi32>,
%arg1: vector<3xi32>) -> vector<2x3xi32> {
%0 = vector.outerproduct %arg0, %arg1 : vector<2xi32>, vector<3xi32>
return %0: vector<2x3xi32>
}

// CHECK-LABEL: func @outerproduct_acc_int
// CHECK-SAME: %[[A:.*0]]: vector<2xi32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
// CHECK-SAME: %[[C:.*2]]: vector<2x3xi32>
// CHECK: %[[C0:.*]] = arith.constant dense<0> : vector<2x3xi32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2xi32>
// CHECK: %[[T1:.*]] = vector.splat %[[T0]] : vector<3xi32>
// CHECK: %[[T2:.*]] = vector.extract %[[C]][0] : vector<2x3xi32>
// CHECK: %[[T3:.*]] = arith.muli %[[T1]], %[[B]] : vector<3xi32>
// CHECK: %[[T4:.*]] = arith.addi %[[T3]], %[[T2]] : vector<3xi32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[C0]] [0] : vector<3xi32> into vector<2x3xi32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xi32>
// CHECK: %[[T7:.*]] = vector.splat %[[T6]] : vector<3xi32>
// CHECK: %[[T8:.*]] = vector.extract %[[C]][1] : vector<2x3xi32>
// CHECK: %[[T9:.*]] = arith.muli %[[T7]], %[[B]] : vector<3xi32>
// CHECK: %[[T10:.*]] = arith.addi %[[T9]], %[[T8]] : vector<3xi32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T5]] [1] : vector<3xi32> into vector<2x3xi32>
// CHECK: return %[[T11]] : vector<2x3xi32>
func.func @outerproduct_acc_int(%arg0: vector<2xi32>,
%arg1: vector<3xi32>,
%arg2: vector<2x3xi32>) -> vector<2x3xi32> {
%0 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xi32>, vector<3xi32>
return %0: vector<2x3xi32>
}

// CHECK-LABEL: func @axpy_fp(
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32)
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
// CHECK: %[[T1:.*]] = arith.mulf %[[A]], %[[T0]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp(%arg0: vector<16xf32>, %arg1: f32) -> vector<16xf32> {
%0 = vector.outerproduct %arg0, %arg1: vector<16xf32>, f32
return %0: vector<16xf32>
}

// CHECK-LABEL: func @axpy_fp_add(
// CHECK-SAME: %[[A:.*0]]: vector<16xf32>,
// CHECK-SAME: %[[B:.*1]]: f32,
// CHECK-SAME: %[[C:.*2]]: vector<16xf32>)
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xf32>
// CHECK: %[[T1:.*]] = vector.fma %[[A]], %[[T0]], %[[C]] : vector<16xf32>
// CHECK: return %[[T1]] : vector<16xf32>
func.func @axpy_fp_add(%arg0: vector<16xf32>, %arg1: f32, %arg2 : vector<16xf32>) -> vector<16xf32> {
%0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xf32>, f32
return %0: vector<16xf32>
}

// CHECK-LABEL: func @axpy_int(
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32)
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: return %[[T1]] : vector<16xi32>
func.func @axpy_int(%arg0: vector<16xi32>, %arg1: i32) -> vector<16xi32> {
%0 = vector.outerproduct %arg0, %arg1: vector<16xi32>, i32
return %0: vector<16xi32>
}

// CHECK-LABEL: func @axpy_int_add(
// CHECK-SAME: %[[A:.*0]]: vector<16xi32>,
// CHECK-SAME: %[[B:.*1]]: i32,
// CHECK-SAME: %[[C:.*2]]: vector<16xi32>)
// CHECK: %[[T0:.*]] = vector.splat %[[B]] : vector<16xi32>
// CHECK: %[[T1:.*]] = arith.muli %[[A]], %[[T0]] : vector<16xi32>
// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[C]] : vector<16xi32>
// CHECK: return %[[T2]] : vector<16xi32>
func.func @axpy_int_add(%arg0: vector<16xi32>, %arg1: i32, %arg2: vector<16xi32>) -> vector<16xi32> {
%0 = vector.outerproduct %arg0, %arg1, %arg2: vector<16xi32>, i32
return %0: vector<16xi32>
}

transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%f = transform.structured.match ops{["func.func"]} in %module_op
: (!pdl.operation) -> !pdl.operation

%f2 = transform.vector.lower_outerproduct %f
: (!pdl.operation) -> !pdl.operation

%f3 = transform.vector.lower_broadcast %f2
: (!pdl.operation) -> !pdl.operation
}
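
The deleted test pass reached outer-product lowering as one option of the contract-lowering bundle; `transform.vector.lower_outerproduct` and `lower_broadcast` appear to expose the corresponding populate calls individually. A sketch, assuming the option and populate signatures shown in the removed TestVectorTransforms.cpp hunks below (`funcOp` as above):

// OuterProduct strategy plus broadcast lowering, applied greedily.
mlir::RewritePatternSet patterns(funcOp.getContext());
mlir::vector::VectorTransformsOptions options{
    mlir::vector::VectorContractLowering::OuterProduct};
mlir::vector::populateVectorContractLoweringPatterns(patterns, options);
mlir::vector::populateVectorBroadcastLoweringPatterns(patterns);
(void)mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
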
134 changes: 134 additions & 0 deletions mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -0,0 +1,134 @@

// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s

// CHECK-LABEL: func @nop_shape_cast
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
// CHECK: return %[[A]] : vector<16xf32>
func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<16xf32>
return %0 : vector<16xf32>
}

// CHECK-LABEL: func @cancel_shape_cast
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
// CHECK: return %[[A]] : vector<16xf32>

func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> {
%0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32>
%1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32>
return %1 : vector<16xf32>
}

// Shape up- and down-casts for 2-D vectors, to support conversion to
// llvm.matrix operations.
// CHECK-LABEL: func @shape_casts
func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) {
// CHECK-DAG: %[[cst22:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
// CHECK-DAG: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
// CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2x2xf32>
//
// CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[cst]]
// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
//
// CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2x2xf32>
//
// CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]]
// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
//
%0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32>
// CHECK: %[[add:.*]] = arith.addf %[[in2]], %[[in2]] : vector<4xf32>
%r0 = arith.addf %0, %0: vector<4xf32>
//
// CHECK: %[[ss0:.*]] = vector.extract_strided_slice %[[add]]
// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} :
// CHECK-SAME: vector<4xf32> to vector<2xf32>
//
// CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[cst22]] [0] :
// CHECK-SAME: vector<2xf32> into vector<2x2xf32>
//
// CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[add]]
// CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} :
// CHECK-SAME: vector<4xf32> to vector<2xf32>
//
// CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] :
// CHECK-SAME: vector<2xf32> into vector<2x2xf32>
//
%1 = vector.shape_cast %r0 : vector<4xf32> to vector<2x2xf32>
// CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32>
return %r0, %1 : vector<4xf32>, vector<2x2xf32>
}

// CHECK-LABEL: func @shape_cast_2d2d
// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<3x2xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : f32 into vector<2x3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<3x2xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : vector<3x2xf32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<3x2xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : vector<3x2xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32>
// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : vector<3x2xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32>
// CHECK: return %[[T11]] : vector<2x3xf32>

func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
%s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
return %s : vector<2x3xf32>
}

// CHECK-LABEL: func @shape_cast_3d1d
// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : vector<1x3x2xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : vector<1x3x2xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : vector<1x3x2xf32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : vector<1x3x2xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : vector<1x3x2xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32>
// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : vector<1x3x2xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32>
// CHECK: return %[[T11]] : vector<6xf32>

func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
%s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
return %s : vector<6xf32>
}

// CHECK-LABEL: func @shape_cast_1d3d
// CHECK-SAME: %[[A:.*]]: vector<6xf32>
// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<6xf32>
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32>
// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : vector<6xf32>
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : vector<6xf32>
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : vector<6xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : vector<6xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32>
// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : vector<6xf32>
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32>
// CHECK: return %[[T11]] : vector<2x1x3xf32>

func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
%s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
return %s : vector<2x1x3xf32>
}

transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%f = transform.structured.match ops{["func.func"]} in %module_op
: (!pdl.operation) -> !pdl.operation

%f2 = transform.vector.lower_shape_cast %f
: (!pdl.operation) -> !pdl.operation
}
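
`transform.vector.lower_shape_cast` corresponds to the `populateVectorShapeCastLoweringPatterns` call in the removed test pass; a sketch under the same assumptions as the earlier snippets:

mlir::RewritePatternSet patterns(funcOp.getContext());
mlir::vector::populateVectorShapeCastLoweringPatterns(patterns);
(void)mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
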
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-vector-transfer-drop-unit-dims-patterns -split-input-file | FileCheck %s
// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s

func.func @transfer_read_rank_reducing(
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> {
@@ -15,8 +15,6 @@ func.func @transfer_read_rank_reducing(
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_read %[[SUBVIEW]]

// -----

func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
@@ -28,4 +26,11 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]


transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
transform.vector.apply_rank_reducing_subview_patterns %module_op
: (!pdl.operation) -> !pdl.operation
}
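
`transform.vector.apply_rank_reducing_subview_patterns` stands in for the deleted `-test-vector-transfer-drop-unit-dims-patterns` pass; the direct form it replaces was simply:

mlir::RewritePatternSet patterns(funcOp.getContext());
mlir::vector::populateVectorTransferDropUnitDimsPatterns(patterns);
(void)mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
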
@@ -0,0 +1,243 @@
// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s

// CHECK-DAG: #[[$map_p4:.*]] = affine_map<()[s0] -> (s0 + 4)>
// CHECK-DAG: #[[$map_p8:.*]] = affine_map<()[s0] -> (s0 + 8)>
// CHECK-DAG: #[[$bounds_map_4:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)>
// CHECK-DAG: #[[$bounds_map_8:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)>

// CHECK-LABEL: split_vector_transfer_read_2d(
// CHECK-SAME: %[[A:[a-zA-Z0-9_]*]]: memref
// CHECK-SAME: %[[i:[a-zA-Z0-9_]*]]: index
// CHECK-SAME: %[[j:[a-zA-Z0-9_]*]]: index
func.func @split_vector_transfer_read_2d(%A: memref<?x8xf32>, %i: index, %j: index) -> vector<4x8xf32> {
%c0 = arith.constant 0 : index
%f0 = arith.constant 0.0 : f32

// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
// alloca for boundary full tile
// CHECK: %[[alloc:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
// %i + 4 <= dim(%A, 0)
// CHECK: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
// CHECK: %[[d0:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x8xf32>
// CHECK: %[[cmp0:.*]] = arith.cmpi sle, %[[idx0]], %[[d0]] : index
// %j + 8 <= dim(%A, 1)
// CHECK: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
// CHECK: %[[cmp1:.*]] = arith.cmpi sle, %[[idx1]], %[[c8]] : index
// are both conditions true
// CHECK: %[[cond:.*]] = arith.andi %[[cmp0]], %[[cmp1]] : i1
// CHECK: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32>, index, index) {
// inBounds, just yield %A
// CHECK: scf.yield %[[A]], %[[i]], %[[j]] : memref<?x8xf32>, index, index
// CHECK: } else {
// slow path, fill tmp alloc and yield a memref_casted version of it
// CHECK: linalg.fill ins(%cst : f32) outs(%[[alloc]] : memref<4x8xf32>)
// CHECK: %[[d0:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x8xf32>
// CHECK: %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[d0]], %[[i]], %[[c4]])
// CHECK: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
// CHECK: %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1]
// CHECK-SAME: memref<?x8xf32> to memref<?x?xf32, strided<[8, 1], offset: ?>>
// CHECK: %[[alloc_view:.*]] = memref.subview %[[alloc]][0, 0] [%[[sv0]], %[[sv1]]] [1, 1]
// CHECK: memref.copy %[[sv]], %[[alloc_view]] : memref<?x?xf32, strided<[8, 1], offset: ?>> to memref<?x?xf32, strided{{.*}}>
// CHECK: %[[yielded:.*]] = memref.cast %[[alloc]] :
// CHECK-SAME: memref<4x8xf32> to memref<?x8xf32>
// CHECK: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
// CHECK-SAME: memref<?x8xf32>, index, index
// CHECK: }
// CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %cst
// CHECK-SAME: {in_bounds = [true, true]} : memref<?x8xf32>, vector<4x8xf32>
%1 = vector.transfer_read %A[%i, %j], %f0 : memref<?x8xf32>, vector<4x8xf32>

// CHECK: return %[[res]] : vector<4x8xf32>
return %1: vector<4x8xf32>
}

// CHECK-LABEL: split_vector_transfer_read_strided_2d(
// CHECK-SAME: %[[A:[a-zA-Z0-9_]*]]: memref
// CHECK-SAME: %[[i:[a-zA-Z0-9_]*]]: index
// CHECK-SAME: %[[j:[a-zA-Z0-9_]*]]: index
func.func @split_vector_transfer_read_strided_2d(
%A: memref<7x8xf32, strided<[?, 1], offset: ?>>,
%i: index, %j: index) -> vector<4x8xf32> {
%c0 = arith.constant 0 : index
%f0 = arith.constant 0.0 : f32


// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[c7:.*]] = arith.constant 7 : index
// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index
// alloca for boundary full tile
// CHECK: %[[alloc:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
// %i + 4 <= dim(%A, 0)
// CHECK: %[[idx0:.*]] = affine.apply #[[$map_p4]]()[%[[i]]]
// CHECK: %[[cmp0:.*]] = arith.cmpi sle, %[[idx0]], %[[c7]] : index
// %j + 8 <= dim(%A, 1)
// CHECK: %[[idx1:.*]] = affine.apply #[[$map_p8]]()[%[[j]]]
// CHECK: %[[cmp1:.*]] = arith.cmpi sle, %[[idx1]], %[[c8]] : index
// are both conditions true
// CHECK: %[[cond:.*]] = arith.andi %[[cmp0]], %[[cmp1]] : i1
// CHECK: %[[ifres:.*]]:3 = scf.if %[[cond]] -> (memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index) {
// inBounds but not cast-compatible: yield a memref_casted form of %A
// CHECK: %[[casted:.*]] = memref.cast %arg0 :
// CHECK-SAME: memref<7x8xf32, strided<[?, 1], offset: ?>> to memref<?x8xf32, strided<[?, 1], offset: ?>>
// CHECK: scf.yield %[[casted]], %[[i]], %[[j]] :
// CHECK-SAME: memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index
// CHECK: } else {
// slow path, fill tmp alloc and yield a memref_casted version of it
// CHECK: linalg.fill ins(%cst : f32) outs(%[[alloc]] : memref<4x8xf32>)
// CHECK: %[[sv0:.*]] = affine.min #[[$bounds_map_4]](%[[c7]], %[[i]], %[[c4]])
// CHECK: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
// CHECK: %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1]
// CHECK-SAME: memref<7x8xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
// CHECK: %[[alloc_view:.*]] = memref.subview %[[alloc]][0, 0] [%[[sv0]], %[[sv1]]] [1, 1]
// CHECK: memref.copy %[[sv]], %[[alloc_view]] : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32, strided{{.*}}>
// CHECK: %[[yielded:.*]] = memref.cast %[[alloc]] :
// CHECK-SAME: memref<4x8xf32> to memref<?x8xf32, strided<[?, 1], offset: ?>>
// CHECK: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
// CHECK-SAME: memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index
// CHECK: }
// CHECK: %[[res:.*]] = vector.transfer_read {{.*}} {in_bounds = [true, true]} :
// CHECK-SAME: memref<?x8xf32, strided<[?, 1], offset: ?>>, vector<4x8xf32>
%1 = vector.transfer_read %A[%i, %j], %f0 :
memref<7x8xf32, strided<[?, 1], offset: ?>>, vector<4x8xf32>

return %1 : vector<4x8xf32>
}

transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
transform.vector.split_transfer_full_partial %module_op
split_transfer_strategy = "linalg-copy"
: (!pdl.operation) -> !pdl.operation
}
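
The `split_transfer_strategy = "linalg-copy"` attribute corresponds to `VectorTransferSplit::LinalgCopy` in the removed `-test-vector-transfer-full-partial-split` pass (its `use-memref-copy` option). A sketch of the direct form, under the same assumptions as the earlier snippets:

mlir::vector::VectorTransformsOptions options;
options.setVectorTransferSplit(mlir::vector::VectorTransferSplit::LinalgCopy);
mlir::RewritePatternSet patterns(funcOp.getContext());
mlir::vector::populateVectorTransferFullPartialPatterns(patterns, options);
(void)mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
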

// -----

func.func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref<?x8xf32>, %i: index, %j: index) {
vector.transfer_write %V, %A[%i, %j] :
vector<4x8xf32>, memref<?x8xf32>
return
}

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 + 4)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 + 8)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)>

// CHECK-LABEL: func @split_vector_transfer_write_2d(
// CHECK-SAME: %[[VEC:.*]]: vector<4x8xf32>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x8xf32>,
// CHECK-SAME: %[[I:.*]]: index,
// CHECK-SAME: %[[J:.*]]: index) {
// CHECK-DAG: %[[CT:.*]] = arith.constant true
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[TEMP:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
// CHECK: %[[IDX0:.*]] = affine.apply #[[$MAP0]]()[%[[I]]]
// CHECK: %[[DIM0:.*]] = memref.dim %[[DEST]], %[[C0]] : memref<?x8xf32>
// CHECK: %[[DIM0_IN:.*]] = arith.cmpi sle, %[[IDX0]], %[[DIM0]] : index
// CHECK: %[[DIM1:.*]] = affine.apply #[[$MAP1]]()[%[[J]]]
// CHECK: %[[DIM1_IN:.*]] = arith.cmpi sle, %[[DIM1]], %[[C8]] : index
// CHECK: %[[IN_BOUNDS:.*]] = arith.andi %[[DIM0_IN]], %[[DIM1_IN]] : i1
// CHECK: %[[IN_BOUND_DEST:.*]]:3 = scf.if %[[IN_BOUNDS]]
// CHECK-SAME: -> (memref<?x8xf32>, index, index) {
// CHECK: scf.yield %[[DEST]], %[[I]], %[[J]] : memref<?x8xf32>, index, index
// CHECK: } else {
// CHECK: %[[VAL_16:.*]] = memref.cast %[[TEMP]] : memref<4x8xf32> to memref<?x8xf32>
// CHECK: scf.yield %[[VAL_16]], %[[C0]], %[[C0]] : memref<?x8xf32>, index, index
// CHECK: }
// CHECK: vector.transfer_write %[[VEC]],
// CHECK-SAME: %[[IN_BOUND_DEST:.*]]#0[%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2]
// CHECK-SAME: {in_bounds = [true, true]} : vector<4x8xf32>, memref<?x8xf32>
// CHECK: %[[OUT_BOUNDS:.*]] = arith.xori %[[IN_BOUNDS]], %[[CT]] : i1
// CHECK: scf.if %[[OUT_BOUNDS]] {
// CHECK: %[[VAL_19:.*]] = memref.dim %[[DEST]], %[[C0]] : memref<?x8xf32>
// CHECK-DAG: %[[VAL_20:.*]] = affine.min #[[$MAP2]](%[[VAL_19]], %[[I]], %[[C4]])
// CHECK-DAG: %[[VAL_21:.*]] = affine.min #[[$MAP3]](%[[C8]], %[[J]], %[[C8]])
// CHECK: %[[VAL_22:.*]] = memref.subview %[[TEMP]]
// CHECK-SAME: [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]]
// CHECK-SAME: [1, 1] : memref<4x8xf32> to memref<?x?xf32, strided<[8, 1], offset: ?>>
// CHECK: %[[DEST_VIEW:.*]] = memref.subview %[[DEST]][0, 0] [%[[VAL_20]], %[[VAL_21]]] [1, 1]
// CHECK: memref.copy %[[VAL_22]], %[[DEST_VIEW]]
// CHECK-SAME: : memref<?x?xf32, strided<[8, 1], offset: ?>> to memref<?x?xf32, strided{{.*}}>
// CHECK: }
// CHECK: return
// CHECK: }

transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
transform.vector.split_transfer_full_partial %module_op
split_transfer_strategy = "linalg-copy"
: (!pdl.operation) -> !pdl.operation
}

// -----

func.func @split_vector_transfer_write_strided_2d(
%V: vector<4x8xf32>, %A: memref<7x8xf32, strided<[?, 1], offset: ?>>,
%i: index, %j: index) {
vector.transfer_write %V, %A[%i, %j] :
vector<4x8xf32>, memref<7x8xf32, strided<[?, 1], offset: ?>>
return
}

// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 + 4)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 8)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 4)>
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0 - d1, 8)>
// CHECK-LABEL: func @split_vector_transfer_write_strided_2d(
// CHECK-SAME: %[[VEC:.*]]: vector<4x8xf32>,
// CHECK-SAME: %[[DEST:.*]]: memref<7x8xf32, strided<[?, 1], offset: ?>>,
// CHECK-SAME: %[[I:.*]]: index,
// CHECK-SAME: %[[J:.*]]: index) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CT:.*]] = arith.constant true
// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[TEMP:.*]] = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
// CHECK: %[[DIM0:.*]] = affine.apply #[[$MAP1]]()[%[[I]]]
// CHECK: %[[DIM0_IN:.*]] = arith.cmpi sle, %[[DIM0]], %[[C7]] : index
// CHECK: %[[DIM1:.*]] = affine.apply #[[$MAP2]]()[%[[J]]]
// CHECK: %[[DIM1_IN:.*]] = arith.cmpi sle, %[[DIM1]], %[[C8]] : index
// CHECK: %[[IN_BOUNDS:.*]] = arith.andi %[[DIM0_IN]], %[[DIM1_IN]] : i1
// CHECK: %[[IN_BOUND_DEST:.*]]:3 = scf.if %[[IN_BOUNDS]]
// CHECK-SAME: -> (memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index) {
// CHECK: %[[VAL_16:.*]] = memref.cast %[[DEST]]
// CHECK-SAME: : memref<7x8xf32, strided<[?, 1], offset: ?>> to memref<?x8xf32, strided<[?, 1], offset: ?>>
// CHECK: scf.yield %[[VAL_16]], %[[I]], %[[J]]
// CHECK-SAME: : memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index
// CHECK: } else {
// CHECK: %[[VAL_17:.*]] = memref.cast %[[TEMP]]
// CHECK-SAME: : memref<4x8xf32> to memref<?x8xf32, strided<[?, 1], offset: ?>>
// CHECK: scf.yield %[[VAL_17]], %[[C0]], %[[C0]]
// CHECK-SAME: : memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index
// CHECK: }
// CHECK: vector.transfer_write %[[VEC]],
// CHECK-SAME: %[[IN_BOUND_DEST:.*]]#0
// CHECK-SAME: [%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2]
// CHECK-SAME: {in_bounds = [true, true]}
// CHECK-SAME: : vector<4x8xf32>, memref<?x8xf32, strided<[?, 1], offset: ?>>
// CHECK: %[[OUT_BOUNDS:.*]] = arith.xori %[[IN_BOUNDS]], %[[CT]] : i1
// CHECK: scf.if %[[OUT_BOUNDS]] {
// CHECK-DAG: %[[VAL_20:.*]] = affine.min #[[$MAP3]](%[[C7]], %[[I]], %[[C4]])
// CHECK-DAG: %[[VAL_21:.*]] = affine.min #[[$MAP4]](%[[C8]], %[[J]], %[[C8]])
// CHECK: %[[VAL_22:.*]] = memref.subview %[[TEMP]]
// CHECK-SAME: [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]]
// CHECK-SAME: [1, 1] : memref<4x8xf32> to memref<?x?xf32, strided<[8, 1], offset: ?>>
// CHECK: %[[DEST_VIEW:.*]] = memref.subview %[[DEST]][0, 0] [%[[VAL_20]], %[[VAL_21]]] [1, 1]
// CHECK: memref.copy %[[VAL_22]], %[[DEST_VIEW]]
// CHECK-SAME: : memref<?x?xf32, strided<[8, 1], offset: ?>> to memref<?x?xf32, strided<[?, 1], offset: ?>>
// CHECK: }
// CHECK: return
// CHECK: }

transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
transform.vector.split_transfer_full_partial %module_op
split_transfer_strategy = "linalg-copy"
: (!pdl.operation) -> !pdl.operation
}
214 changes: 29 additions & 185 deletions mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir

Large diffs are not rendered by default.

69 changes: 32 additions & 37 deletions mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s
// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize --split-input-file | FileCheck %s

// CHECK-LABEL: func @vector_transfer_ops_0d_memref(
// CHECK-SAME: %[[MEM:.*]]: memref<f32>
@@ -21,8 +21,6 @@ func.func @vector_transfer_ops_0d_memref(%M: memref<f32>, %v: vector<1x1x1xf32>)
return
}

// -----

// CHECK-LABEL: func @vector_transfer_ops_0d_tensor(
// CHECK-SAME: %[[SOURCE:.*]]: tensor<f32>
func.func @vector_transfer_ops_0d_tensor(%M: tensor<f32>) -> vector<1xf32> {
@@ -37,8 +35,6 @@ func.func @vector_transfer_ops_0d_tensor(%M: tensor<f32>) -> vector<1xf32> {
return %0: vector<1xf32>
}

// -----

// transfer_read/write are lowered to vector.load/store
// CHECK-LABEL: func @transfer_to_load(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
@@ -55,8 +51,6 @@ func.func @transfer_to_load(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32>
return %res : vector<4xf32>
}

// -----

// n-D results are also supported.
// CHECK-LABEL: func @transfer_2D(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
@@ -73,8 +67,6 @@ func.func @transfer_2D(%mem : memref<8x8xf32>, %i : index) -> vector<2x4xf32> {
return %res : vector<2x4xf32>
}

// -----

// Vector element types are supported when the result has the same type.
// CHECK-LABEL: func @transfer_vector_element(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xvector<2x4xf32>>,
@@ -91,8 +83,6 @@ func.func @transfer_vector_element(%mem : memref<8x8xvector<2x4xf32>>, %i : inde
return %res : vector<2x4xf32>
}

// -----

// TODO: Vector element types are not supported yet when the result has a
// different type.
// CHECK-LABEL: func @transfer_vector_element_different_types(
@@ -111,8 +101,6 @@ func.func @transfer_vector_element_different_types(%mem : memref<8x8xvector<2x4x
return %res : vector<1x2x4xf32>
}

// -----

// TODO: transfer_read/write cannot be lowered because there is a dimension
// that is not guaranteed to be in-bounds.
// CHECK-LABEL: func @transfer_2D_not_inbounds(
@@ -131,8 +119,6 @@ func.func @transfer_2D_not_inbounds(%mem : memref<8x8xf32>, %i : index) -> vecto
return %res : vector<2x4xf32>
}

// -----

// TODO: transfer_read/write cannot be lowered because they are not guaranteed
// to be in-bounds.
// CHECK-LABEL: func @transfer_not_inbounds(
@@ -151,8 +137,6 @@ func.func @transfer_not_inbounds(%mem : memref<8x8xf32>, %i : index) -> vector<4
return %res : vector<4xf32>
}

// -----

// CHECK-LABEL: func @transfer_nondefault_layout(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32, #{{.*}}>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
@@ -169,8 +153,6 @@ func.func @transfer_nondefault_layout(%mem : memref<8x8xf32, #layout>, %i : inde
return %res : vector<4xf32>
}

// -----

// TODO: transfer_read/write cannot be lowered to vector.load/store yet when the
// permutation map is not the minor identity map (up to broadcasting).
// CHECK-LABEL: func @transfer_perm_map(
@@ -187,8 +169,6 @@ func.func @transfer_perm_map(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32
return %res : vector<4xf32>
}

// -----

// Lowering of transfer_read with broadcasting is supported (note that a `load`
// is generated instead of a `vector.load`).
// CHECK-LABEL: func @transfer_broadcasting(
@@ -199,15 +179,15 @@ func.func @transfer_perm_map(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32
// CHECK-NEXT: return %[[RES]] : vector<4xf32>
// CHECK-NEXT: }

#broadcast = affine_map<(d0, d1) -> (0)>
#broadcast_1d = affine_map<(d0, d1) -> (0)>
func.func @transfer_broadcasting(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
%cf0 = arith.constant 0.0 : f32
%res = vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true], permutation_map = #broadcast} : memref<8x8xf32>, vector<4xf32>
%res = vector.transfer_read %mem[%i, %i], %cf0
{in_bounds = [true], permutation_map = #broadcast_1d}
: memref<8x8xf32>, vector<4xf32>
return %res : vector<4xf32>
}

// -----

// CHECK-LABEL: func @transfer_scalar(
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<1xf32> {
Expand All @@ -221,8 +201,6 @@ func.func @transfer_scalar(%mem : memref<?x?xf32>, %i : index) -> vector<1xf32>
return %res : vector<1xf32>
}

// -----

// An example with two broadcasted dimensions.
// CHECK-LABEL: func @transfer_broadcasting_2D(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
@@ -232,15 +210,15 @@ func.func @transfer_scalar(%mem : memref<?x?xf32>, %i : index) -> vector<1xf32>
// CHECK-NEXT: return %[[RES]] : vector<4x4xf32>
// CHECK-NEXT: }

#broadcast = affine_map<(d0, d1) -> (0, 0)>
#broadcast_2d = affine_map<(d0, d1) -> (0, 0)>
func.func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %i : index) -> vector<4x4xf32> {
%cf0 = arith.constant 0.0 : f32
%res = vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true, true], permutation_map = #broadcast} : memref<8x8xf32>, vector<4x4xf32>
%res = vector.transfer_read %mem[%i, %i], %cf0
{in_bounds = [true, true], permutation_map = #broadcast_2d}
: memref<8x8xf32>, vector<4x4xf32>
return %res : vector<4x4xf32>
}

// -----

// More complex broadcasting case (here a `vector.load` is generated).
// CHECK-LABEL: func @transfer_broadcasting_complex(
// CHECK-SAME: %[[MEM:.*]]: memref<10x20x30x8x8xf32>,
@@ -250,13 +228,25 @@ func.func @transfer_broadcasting_2D(%mem : memref<8x8xf32>, %i : index) -> vecto
// CHECK-NEXT: return %[[RES]] : vector<3x2x4x5xf32>
// CHECK-NEXT: }

#broadcast = affine_map<(d0, d1, d2, d3, d4) -> (d1, 0, 0, d4)>
#broadcast_2d_in_4d = affine_map<(d0, d1, d2, d3, d4) -> (d1, 0, 0, d4)>
func.func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : index) -> vector<3x2x4x5xf32> {
%cf0 = arith.constant 0.0 : f32
%res = vector.transfer_read %mem[%i, %i, %i, %i, %i], %cf0 {in_bounds = [true, true, true, true], permutation_map = #broadcast} : memref<10x20x30x8x8xf32>, vector<3x2x4x5xf32>
%res = vector.transfer_read %mem[%i, %i, %i, %i, %i], %cf0
{in_bounds = [true, true, true, true], permutation_map = #broadcast_2d_in_4d}
: memref<10x20x30x8x8xf32>, vector<3x2x4x5xf32>
return %res : vector<3x2x4x5xf32>
}


transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%m2 = transform.vector.lower_transfer %module_op
max_transfer_rank = 99
: (!pdl.operation) -> !pdl.operation
transform.vector.apply_transfer_permutation_patterns %m2
: (!pdl.operation) -> !pdl.operation
}

// -----

#map0 = affine_map<(d0, d1, d2, d3) -> (d1, d0, 0, 0)>
@@ -321,8 +311,6 @@ func.func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x
vector<7x14x8x16xf32>, vector<8xf32>
}

// -----

// CHECK-LABEL: func @transfer_write_permutations
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?x?x?xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<?x?x?x?xf32>
@@ -348,8 +336,6 @@ func.func @transfer_write_permutations(
return %0 : tensor<?x?x?x?xf32>
}

// -----

// CHECK-LABEL: func @transfer_write_broadcast_unit_dim
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?x?x?xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<?x?x?x?xf32>
@@ -374,3 +360,12 @@ func.func @transfer_write_broadcast_unit_dim(

return %0 : tensor<?x?x?x?xf32>
}

transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%m2 = transform.vector.lower_transfer %module_op
max_transfer_rank = 99
: (!pdl.operation) -> !pdl.operation
transform.vector.apply_transfer_permutation_patterns %m2
: (!pdl.operation) -> !pdl.operation
}
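
`transform.vector.lower_transfer` plus `apply_transfer_permutation_patterns` replace the two populate calls of the deleted `-test-vector-transfer-lowering-patterns` pass; `max_transfer_rank = 99` presumably feeds an optional rank cap on the load/store rewrite (an assumption — the deleted pass used the default). Sketch, with the same hypothetical `funcOp` as above:

mlir::RewritePatternSet patterns(funcOp.getContext());
mlir::vector::populateVectorTransferLoweringPatterns(patterns);
mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
(void)mlir::applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
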
1,117 changes: 531 additions & 586 deletions mlir/test/Dialect/Vector/vector-transpose-lowering.mlir

Large diffs are not rendered by default.

274 changes: 0 additions & 274 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -113,75 +113,6 @@ struct TestVectorToVectorLowering
}
};

struct TestVectorContractionLowering
: public PassWrapper<TestVectorContractionLowering,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)

StringRef getArgument() const final {
return "test-vector-contraction-lowering";
}
StringRef getDescription() const final {
return "Test lowering patterns that lower contract ops in the vector "
"dialect";
}
TestVectorContractionLowering() = default;
TestVectorContractionLowering(const TestVectorContractionLowering &pass)
: PassWrapper(pass) {}

Option<bool> lowerToFlatMatrix{
*this, "vector-lower-matrix-intrinsics",
llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
llvm::cl::init(false)};
Option<bool> lowerToOuterProduct{
*this, "vector-outerproduct",
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
llvm::cl::init(false)};
Option<bool> lowerToParallelArith{
*this, "vector-parallel-arith",
llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
llvm::cl::init(false)};

void runOnOperation() override {
RewritePatternSet patterns(&getContext());

// Test on one pattern in isolation.
if (lowerToOuterProduct) {
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
VectorTransformsOptions options{lowering};
populateVectorContractLoweringPatterns(
patterns, options, /*benefit=*/1,
/*disableOuterProductlowering=*/true);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
return;
}

if (lowerToParallelArith) {
vector::populateVectorContractLoweringPatterns(
patterns,
vector::VectorTransformsOptions().setVectorTransformsOptions(
vector::VectorContractLowering::ParallelArith));
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
return;
}

// Test on all contract lowering patterns.
VectorContractLowering contractLowering = VectorContractLowering::Dot;
if (lowerToFlatMatrix)
contractLowering = VectorContractLowering::Matmul;
VectorMultiReductionLowering vectorMultiReductionLowering =
VectorMultiReductionLowering::InnerParallel;
VectorTransformsOptions options{contractLowering,
vectorMultiReductionLowering,
VectorTransposeLowering()};
populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns, options);
populateVectorMaskOpLoweringPatterns(patterns);
populateVectorShapeCastLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

struct TestVectorContractionPrepareForMMTLowering
: public PassWrapper<TestVectorContractionPrepareForMMTLowering,
OperationPass<func::FuncOp>> {
@@ -209,82 +140,6 @@ struct TestVectorContractionPrepareForMMTLowering
}
};

struct TestVectorTransposeLowering
: public PassWrapper<TestVectorTransposeLowering,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)

StringRef getArgument() const final {
return "test-vector-transpose-lowering";
}
StringRef getDescription() const final {
return "Test lowering patterns that lower contract ops in the vector "
"dialect";
}
TestVectorTransposeLowering() = default;
TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
: PassWrapper(pass) {}

Option<bool> lowerToEltwise{
*this, "eltwise",
llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
llvm::cl::init(false)};
Option<bool> lowerToFlatTranspose{
*this, "flat",
llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
llvm::cl::init(false)};
Option<bool> lowerToShuffleTranspose{
*this, "shuffle",
llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
llvm::cl::init(false)};
Option<bool> lowerToAvx2{
*this, "avx2",
llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
llvm::cl::init(false)};

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LLVM::LLVMDialect>();
}

void runOnOperation() override {
func::FuncOp funcOp = getOperation();
MLIRContext *context = funcOp.getContext();
RewritePatternSet patterns(context);

vector::VectorTransformsOptions vectorTransformOptions;
if (lowerToEltwise) {
vectorTransformOptions =
vectorTransformOptions.setVectorTransposeLowering(
VectorTransposeLowering::EltWise);
}
if (lowerToFlatTranspose) {
vectorTransformOptions =
vectorTransformOptions.setVectorTransposeLowering(
VectorTransposeLowering::Flat);
}
if (lowerToShuffleTranspose) {
vectorTransformOptions =
vectorTransformOptions.setVectorTransposeLowering(
VectorTransposeLowering::Shuffle);
}
vector::populateVectorTransposeLoweringPatterns(patterns,
vectorTransformOptions);

if (lowerToAvx2) {
auto avx2LoweringOptions =
x86vector::avx2::LoweringOptions().setTransposeOptions(
x86vector::avx2::TransposeLoweringOptions()
.lower4x8xf32()
.lower8x8xf32());
x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
patterns, avx2LoweringOptions, /*benefit=*/10);
}

if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
return signalPassFailure();
}
};

struct TestVectorUnrollingPatterns
: public PassWrapper<TestVectorUnrollingPatterns,
OperationPass<func::FuncOp>> {
@@ -435,47 +290,6 @@ struct TestVectorTransferUnrollingPatterns
llvm::cl::init(false)};
};

struct TestVectorTransferFullPartialSplitPatterns
: public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorTransferFullPartialSplitPatterns)

StringRef getArgument() const final {
return "test-vector-transfer-full-partial-split";
}
StringRef getDescription() const final {
return "Test lowering patterns to split "
"transfer ops via scf.if + linalg ops";
}
TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns(
const TestVectorTransferFullPartialSplitPatterns &pass)
: PassWrapper(pass) {}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
scf::SCFDialect>();
}

Option<bool> useLinalgOps{
*this, "use-memref-copy",
llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
"memref.copy operations."),
llvm::cl::init(false)};
void runOnOperation() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
VectorTransformsOptions options;
if (useLinalgOps)
options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
else
options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
populateVectorTransferFullPartialPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

struct TestScalarVectorTransferLoweringPatterns
: public PassWrapper<TestScalarVectorTransferLoweringPatterns,
OperationPass<func::FuncOp>> {
@@ -514,63 +328,6 @@ struct TestVectorTransferOpt
void runOnOperation() override { transferOpflowOpt(getOperation()); }
};

struct TestVectorTransferLoweringPatterns
: public PassWrapper<TestVectorTransferLoweringPatterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorTransferLoweringPatterns)

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
}
StringRef getArgument() const final {
return "test-vector-transfer-lowering-patterns";
}
StringRef getDescription() const final {
return "Test lowering patterns to lower transfer ops to other vector ops";
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorTransferLoweringPatterns(patterns);
populateVectorTransferPermutationMapLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

struct TestVectorMultiReductionLoweringPatterns
: public PassWrapper<TestVectorMultiReductionLoweringPatterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorMultiReductionLoweringPatterns)

TestVectorMultiReductionLoweringPatterns() = default;
TestVectorMultiReductionLoweringPatterns(
const TestVectorMultiReductionLoweringPatterns &pass)
: PassWrapper(pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect>();
}
StringRef getArgument() const final {
return "test-vector-multi-reduction-lowering-patterns";
}
StringRef getDescription() const final {
return "Test lowering patterns to lower vector.multi_reduction to other "
"vector ops";
}
Option<bool> useOuterReductions{
*this, "use-outer-reductions",
llvm::cl::desc("Move reductions to outer most dimensions"),
llvm::cl::init(false)};
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorMultiReductionLoweringPatterns(
patterns, useOuterReductions
? vector::VectorMultiReductionLowering::InnerParallel
: vector::VectorMultiReductionLowering::InnerReduction);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

struct TestVectorTransferCollapseInnerMostContiguousDims
: public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
OperationPass<func::FuncOp>> {
@@ -621,25 +378,6 @@ struct TestVectorReduceToContractPatternsPatterns
}
};

struct TestVectorTransferDropUnitDimsPatterns
: public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorTransferDropUnitDimsPatterns)

StringRef getArgument() const final {
return "test-vector-transfer-drop-unit-dims-patterns";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect>();
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorTransferDropUnitDimsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

struct TestFlattenVectorTransferPatterns
: public PassWrapper<TestFlattenVectorTransferPatterns,
OperationPass<func::FuncOp>> {
@@ -923,32 +661,20 @@ namespace test {
void registerTestVectorLowerings() {
PassRegistration<TestVectorToVectorLowering>();

PassRegistration<TestVectorContractionLowering>();

PassRegistration<TestVectorContractionPrepareForMMTLowering>();

PassRegistration<TestVectorTransposeLowering>();

PassRegistration<TestVectorUnrollingPatterns>();

PassRegistration<TestVectorTransferUnrollingPatterns>();

PassRegistration<TestVectorTransferFullPartialSplitPatterns>();

PassRegistration<TestScalarVectorTransferLoweringPatterns>();

PassRegistration<TestVectorTransferOpt>();

PassRegistration<TestVectorTransferLoweringPatterns>();

PassRegistration<TestVectorMultiReductionLoweringPatterns>();

PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();

PassRegistration<TestVectorReduceToContractPatternsPatterns>();

PassRegistration<TestVectorTransferDropUnitDimsPatterns>();

PassRegistration<TestFlattenVectorTransferPatterns>();

PassRegistration<TestVectorScanLowering>();
4 changes: 2 additions & 2 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3426,8 +3426,7 @@ cc_library(
":ArithDialect",
":AsmParser",
":IR",
":PDLDialect",
":Parser",
":LLVMDialect",
":SideEffectInterfaces",
":TransformDialect",
":TransformUtils",
@@ -9363,6 +9362,7 @@ cc_library(
":TransformDialectInterfacesIncGen",
":TransformDialectUtils",
":TransformOpsIncGen",
":Transforms",
":TransformTypesIncGen",
"//llvm:Support",
],