diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index c694f4f58faa1..a2d59010a2901 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1563,10 +1563,6 @@ class DropInnerMostUnitDimsTransferRead if (readOp.getTransferRank() == 0) return failure(); - // TODO: support mask. - if (readOp.getMask()) - return failure(); - auto srcType = dyn_cast(readOp.getBase().getType()); if (!srcType) return failure(); @@ -1614,12 +1610,22 @@ class DropInnerMostUnitDimsTransferRead readOp.getBase(), offsets, sizes, strides); auto permMap = getTransferMinorIdentityMap( cast(rankedReducedView.getType()), resultTargetVecType); + + // If there is a mask, shape_cast it to drop the same inner unit dims. + Value mask = readOp.getMask(); + if (mask) { + auto maskType = cast(mask.getType()); + auto reducedMaskType = VectorType::get( + maskType.getShape().drop_back(dimsToDrop), maskType.getElementType(), + maskType.getScalableDims().drop_back(dimsToDrop)); + mask = rewriter.createOrFold(loc, reducedMaskType, + mask); + } + Value result = vector::TransferReadOp::create( rewriter, loc, resultTargetVecType, rankedReducedView, readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), - readOp.getPadding(), - // TODO: support mask. - /*mask=*/Value(), inBoundsAttr); + readOp.getPadding(), mask, inBoundsAttr); rewriter.replaceOpWithNewOp(readOp, targetType, result); return success(); @@ -1654,10 +1660,6 @@ class DropInnerMostUnitDimsTransferWrite if (writeOp.getTransferRank() == 0) return failure(); - // TODO: support mask. - if (writeOp.getMask()) - return failure(); - auto srcType = dyn_cast(writeOp.getBase().getType()); if (!srcType) return failure(); @@ -1709,11 +1711,22 @@ class DropInnerMostUnitDimsTransferWrite auto shapeCast = rewriter.createOrFold( loc, resultTargetVecType, writeOp.getVector()); + + // If there is a mask, shape_cast it to drop the same inner unit dims. + Value mask = writeOp.getMask(); + if (mask) { + auto maskType = cast(mask.getType()); + auto reducedMaskType = VectorType::get( + maskType.getShape().drop_back(dimsToDrop), maskType.getElementType(), + maskType.getScalableDims().drop_back(dimsToDrop)); + mask = rewriter.createOrFold(loc, reducedMaskType, + mask); + } + rewriter.replaceOpWithNewOp( writeOp, shapeCast, rankedReducedView, writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), - // TODO: support mask. - /*mask=*/Value(), inBoundsAttr); + mask, inBoundsAttr); return success(); } }; diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir index 18c28799a62e5..1bedce7ea6a67 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir @@ -266,6 +266,23 @@ func.func @contiguous_inner_most_dim_with_subview_2d_scalable_inner_dim(%src: me // ----- +func.func @contiguous_inner_most_with_mask(%src: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %mask: vector<1x8x1xi1>) -> vector<1x8x1xf32>{ + %c0 = arith.constant 0 : index + %pad = arith.constant 0.0 : f32 + %v = vector.transfer_read %src[%c0, %c0, %c0, %c0], %pad, %mask {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32> + return %v : vector<1x8x1xf32> +} +// CHECK: func @contiguous_inner_most_with_mask(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[MASK:.+]]: vector<1x8x1xi1>) +// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]] +// CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>> +// CHECK: %[[REDUCED_MASK:.+]] = vector.shape_cast %[[MASK]] : vector<1x8x1xi1> to vector<1x8xi1> +// CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]{{.*}}, %[[REDUCED_MASK]] +// CHECK-SAME: memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32> +// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]] +// CHECK: return %[[RESULT]] + +// ----- + // NOTE: This is an out-of-bounds access. func.func @negative_non_unit_inner_vec_dim(%src: memref<4x1xf32>) -> vector<4x8xf32> { @@ -580,6 +597,20 @@ func.func @contiguous_inner_most_dim_with_subview_2d_scalable(%dest: memref<1000 // ----- +func.func @contiguous_inner_most_with_mask(%dest: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %vec: vector<1x8x1xf32>, %mask: vector<1x8x1xi1>) { + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %dest[%c0, %c0, %c0, %c0], %mask {in_bounds = [true, true, true]} : vector<1x8x1xf32>, memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> + return +} +// CHECK: func @contiguous_inner_most_with_mask(%[[DEST:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, %[[VEC:.+]]: vector<1x8x1xf32>, %[[MASK:.+]]: vector<1x8x1xi1>) +// CHECK: %[[DEST_0:.+]] = memref.subview %[[DEST]] +// CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>> +// CHECK: %[[REDUCED_VEC:.+]] = vector.shape_cast %[[VEC]] : vector<1x8x1xf32> to vector<1x8xf32> +// CHECK: %[[REDUCED_MASK:.+]] = vector.shape_cast %[[MASK]] : vector<1x8x1xi1> to vector<1x8xi1> +// CHECK: vector.transfer_write %[[REDUCED_VEC]], %[[DEST_0]]{{.*}}, %[[REDUCED_MASK]] +// CHECK-SAME: vector<1x8xf32>, memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>> +// ----- + // NOTE: This is an out-of-bounds access. func.func @negative_non_unit_inner_vec_dim(%dest: memref<4x1xf32>, %vec: vector<4x8xf32>) {