diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 15c467b21c81e..4db6057519da5 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -524,6 +524,23 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask, if (!mask) { LDBG() << "No mask required"; + if (assumeDynamicDimsMatchVecSizes) { + LDBG() << "Assuming dynamic dimensions match vector sizes!"; + // Set inbounds to all-true. + llvm::TypeSwitch(opToMask) + .Case( + [&](auto xferOp) { + SmallVector inBoundsMap(xferOp.getInBounds().size(), + true); + rewriter.modifyOpInPlace(xferOp, [&]() { + xferOp.setInBoundsAttr( + rewriter.getBoolArrayAttr(inBoundsMap)); + }); + }) + .Default([](Operation *op) { + // No-op if the operation is not an xfer read or write. + }); + } return opToMask; } diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir index 62bf1f55c9af2..7fe707902072a 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir @@ -918,12 +918,16 @@ func.func @mmt4d_scalable_with_assume(%A: memref<16x16x8x1xf32>, %B: memref<16x1 // CHECK-SAME: %[[B:.*]]: memref<16x16x?x1xf32>, // CHECK-SAME: %[[C_IN:.*]]: memref<16x16x8x?xf32>) { // CHECK-NOT: mask -// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32> -// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> -// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> +// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]] +// CHECK-SAME: in_bounds = [true, true, true, true, true, true]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x[4]x1xf32> +// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]] +// CHECK-SAME: in_bounds = [true, true, true, true, true, true]{{.*}} : memref<16x16x?x1xf32>, vector<16x16x16x8x[4]x1xf32> +// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]] +// CHECK-SAME: in_bounds = [true, true, true, true]{{.*}} : memref<16x16x8x?xf32>, vector<16x16x8x[4]xf32> // CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x[4]x1xf32> // CHECK: %[[RED:.*]] = vector.multi_reduction , %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x[4]x1xf32> to vector<16x16x8x[4]xf32> -// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> +// CHECK: vector.transfer_write %[[RED]], %[[C_IN]] +// CHECK-SAME: in_bounds = [true, true, true, true]{{.*}} : vector<16x16x8x[4]xf32>, memref<16x16x8x?xf32> module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { @@ -1011,12 +1015,16 @@ func.func @batch_mmt4d_scalable_with_assume(%A: memref<2x16x16x8x1xf32>, %B: mem // CHECK-SAME: %[[B:.*]]: memref<2x16x16x?x1xf32>, // CHECK-SAME: %[[C_IN:.*]]: memref<2x16x16x8x?xf32>) { // CHECK-NOT: mask -// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32> -// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32> -// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32> +// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]] +// CHECK-SAME: in_bounds = [true, true, true, true, true, true, true]{{.*}} : memref<2x16x16x8x1xf32>, vector<2x16x16x16x8x[4]x1xf32> +// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]] +// CHECK-SAME: in_bounds = [true, true, true, true, true, true, true]{{.*}} : memref<2x16x16x?x1xf32>, vector<2x16x16x16x8x[4]x1xf32> +// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C_IN]] +// CHECK-SAME: in_bounds = [true, true, true, true, true]{{.*}} : memref<2x16x16x8x?xf32>, vector<2x16x16x8x[4]xf32> // CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<2x16x16x16x8x[4]x1xf32> // CHECK: %[[RED:.*]] = vector.multi_reduction , %[[MUL]], %[[VEC_C]] [3, 6] : vector<2x16x16x16x8x[4]x1xf32> to vector<2x16x16x8x[4]xf32> -// CHECK: vector.transfer_write %[[RED]], %[[C_IN]]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32> +// CHECK: vector.transfer_write %[[RED]], %[[C_IN]] +// CHECK-SAME: in_bounds = [true, true, true, true, true]{{.*}} : vector<2x16x16x8x[4]xf32>, memref<2x16x16x8x?xf32> module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {