diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3d3e49134363f..f78b3c37832c2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5850,6 +5850,11 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write,
   // Bail on potential out-of-bounds accesses.
   if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
     return failure();
+  // Bail on masked transfers: a masked read substitutes padding and a masked
+  // write leaves the destination untouched for masked-off elements, so the
+  // pair is not an identity fold, even when both ops share the same mask.
+  if (read.getMask() || write.getMask())
+    return failure();
   // Tensor types must be the same.
   if (read.getBase().getType() != rankedTensorType)
     return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index de32ab45aa7ea..6aa92ab79a0dd 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1816,6 +1816,63 @@ func.func @transfer_folding_1(%t0: tensor<2x3x4xf32>, %t1: tensor<2x3x4xf32>)
 
 // -----
 
+// CHECK-LABEL: func @negative_transfer_folding_masked_read
+// CHECK-SAME: %[[MASK:[0-9a-zA-Z_]+]]: vector<2x3x4xi1>
+// CHECK: vector.transfer_read {{.*}}, {{.*}}, %[[MASK]]
+// CHECK: %[[R:.*]] = vector.transfer_write
+// CHECK: return %[[R]]
+func.func @negative_transfer_folding_masked_read(
+    %t0: tensor<2x3x4xf32>, %t1: tensor<2x3x4xf32>,
+    %mask: vector<2x3x4xi1>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %v = vector.transfer_read %t0[%c0, %c0, %c0], %pad, %mask {in_bounds = [true, true, true]} :
+      tensor<2x3x4xf32>, vector<2x3x4xf32>
+  %r = vector.transfer_write %v, %t1[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+      vector<2x3x4xf32>, tensor<2x3x4xf32>
+  return %r : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_transfer_folding_masked_write
+// CHECK-SAME: %[[MASK:[0-9a-zA-Z_]+]]: vector<2x3x4xi1>
+// CHECK: vector.transfer_read
+// CHECK: %[[R:.*]] = vector.transfer_write {{.*}}, {{.*}}, %[[MASK]]
+// CHECK: return %[[R]]
+func.func @negative_transfer_folding_masked_write(
+    %t0: tensor<2x3x4xf32>, %t1: tensor<2x3x4xf32>,
+    %mask: vector<2x3x4xi1>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %v = vector.transfer_read %t0[%c0, %c0, %c0], %pad {in_bounds = [true, true, true]} :
+      tensor<2x3x4xf32>, vector<2x3x4xf32>
+  %r = vector.transfer_write %v, %t1[%c0, %c0, %c0], %mask {in_bounds = [true, true, true]} :
+      vector<2x3x4xf32>, tensor<2x3x4xf32>
+  return %r : tensor<2x3x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_transfer_folding_masked_read_and_write
+// CHECK-SAME: %[[MASK:[0-9a-zA-Z_]+]]: vector<2x3x4xi1>
+// CHECK: %[[V:.*]] = vector.transfer_read {{.*}}, {{.*}}, %[[MASK]]
+// CHECK: %[[R:.*]] = vector.transfer_write %[[V]], {{.*}}, %[[MASK]]
+// CHECK: return %[[R]]
+func.func @negative_transfer_folding_masked_read_and_write(
+    %t0: tensor<2x3x4xf32>, %t1: tensor<2x3x4xf32>,
+    %mask: vector<2x3x4xi1>) -> tensor<2x3x4xf32> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %v = vector.transfer_read %t0[%c0, %c0, %c0], %pad, %mask {in_bounds = [true, true, true]} :
+      tensor<2x3x4xf32>, vector<2x3x4xf32>
+  %r = vector.transfer_write %v, %t1[%c0, %c0, %c0], %mask {in_bounds = [true, true, true]} :
+      vector<2x3x4xf32>, tensor<2x3x4xf32>
+  return %r : tensor<2x3x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @store_after_load_tensor
 // CHECK-SAME: (%[[ARG:.*]]: tensor<4x4xf32>)
 // CHECK-NOT: vector.transfer_read