diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index fa95f96b88177..26a702ef0f512 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -537,6 +537,101 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern { return success(); } }; +} // namespace + +// Drops `dropDim` leading dimensions from `operand` using vector.extract when +// those dims are all non-scalable units (the cheap, structural rewrite); falls +// back to vector.shape_cast otherwise. +static Value dropLeadingOneDimsFromOperand(OpBuilder &b, Location loc, + Value operand, int64_t nDropped) { + auto oldType = cast(operand.getType()); + ArrayRef leadingShape = oldType.getShape().take_front(nDropped); + ArrayRef leadingScalable = + oldType.getScalableDims().take_front(nDropped); + bool extractable = + llvm::all_of(leadingShape, [](int64_t d) { return d == 1; }) && + llvm::none_of(leadingScalable, [](bool s) { return s; }); + if (extractable) + return vector::ExtractOp::create(b, loc, operand, splatZero(nDropped)); + VectorType newType = VectorType::get( + oldType.getShape().drop_front(nDropped), oldType.getElementType(), + oldType.getScalableDims().drop_front(nDropped)); + return vector::ShapeCastOp::create(b, loc, newType, operand); +} + +namespace { + +// Drops leading 1 dimensions from load-like memory operaitons. REmoves leading +// unit dimensions from the result types and then broadcasts back in those 1s, +// while also extracting (or shape_cast-ing) any leading unit dimensions on +// the input operands. +template +struct CastAwayLoadLikeLeadingOneDim : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + VectorType oldResultType = op.getVectorType(); + VectorType newResultType = trimLeadingOneDims(oldResultType); + if (newResultType == oldResultType) + return failure(); + int64_t nDropped = oldResultType.getRank() - newResultType.getRank(); + + Location loc = op.getLoc(); + SmallVector newOperands; + newOperands.reserve(op->getNumOperands()); + for (Value operand : op->getOperands()) { + if (isa(operand.getType())) { + newOperands.push_back( + dropLeadingOneDimsFromOperand(rewriter, loc, operand, nDropped)); + } else { + newOperands.push_back(operand); + } + } + + Operation *newOp = + rewriter.create(loc, op->getName().getIdentifier(), newOperands, + TypeRange{newResultType}, op->getAttrs()); + rewriter.replaceOpWithNewOp(op, oldResultType, + newOp->getResult(0)); + return success(); + } +}; + +// Drops leading 1 dimensions from store-like memory ops. Extracts or +// `shape_cast`s away those leading unit dimensions and leaves any scalar +// operands alone. +template +struct CastAwayStoreLikeLeadingOneDim : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + VectorType oldVecType = op.getVectorType(); + VectorType newVecType = trimLeadingOneDims(oldVecType); + if (newVecType == oldVecType) + return failure(); + int64_t nDropped = oldVecType.getRank() - newVecType.getRank(); + + Location loc = op.getLoc(); + SmallVector newOperands; + newOperands.reserve(op->getNumOperands()); + for (Value operand : op->getOperands()) { + if (isa(operand.getType())) { + newOperands.push_back( + dropLeadingOneDimsFromOperand(rewriter, loc, operand, nDropped)); + } else { + newOperands.push_back(operand); + } + } + + Operation *newOp = + rewriter.create(loc, op->getName().getIdentifier(), newOperands, + op->getResultTypes(), op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; // Drops leading 1 dimensions from vector.constant_mask and inserts a // vector.broadcast back to the original shape. @@ -578,5 +673,14 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( CastAwayInsertStridedSliceLeadingOneDim, CastAwayInsertLeadingOneDim, CastAwayConstantMaskLeadingOneDim, CastAwayTransferReadLeadingOneDim, CastAwayTransferWriteLeadingOneDim, CastAwayElementwiseLeadingOneDim, - CastAwayContractionLeadingOneDim>(patterns.getContext(), benefit); + CastAwayContractionLeadingOneDim, + CastAwayLoadLikeLeadingOneDim, + CastAwayLoadLikeLeadingOneDim, + CastAwayLoadLikeLeadingOneDim, + CastAwayLoadLikeLeadingOneDim, + CastAwayStoreLikeLeadingOneDim, + CastAwayStoreLikeLeadingOneDim, + CastAwayStoreLikeLeadingOneDim, + CastAwayStoreLikeLeadingOneDim>( + patterns.getContext(), benefit); } diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir index aee77ce3da553..bf01c8a8589d9 100644 --- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir @@ -693,3 +693,98 @@ func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>, %sel = arith.select %cond, %arg0, %arg1 : vector<1x16xi1> return %sel : vector<1x16xi1> } + +// ----- + +// CHECK-LABEL: func.func @cast_away_load_leading_one_dims +// CHECK: %[[L:.+]] = vector.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x16xf32>, vector<4xf32> +// CHECK: %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32> +// CHECK: return %[[B]] : vector<1x4xf32> +func.func @cast_away_load_leading_one_dims(%base: memref<8x16xf32>, %i: index, %j: index) -> vector<1x4xf32> { + %0 = vector.load %base[%i, %j] : memref<8x16xf32>, vector<1x4xf32> + return %0 : vector<1x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @cast_away_maskedload_leading_one_dims +// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1> +// CHECK: %[[P:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> +// CHECK: %[[L:.+]] = vector.maskedload %{{.*}}[%{{.*}}], %[[M]], %[[P]] : memref<16xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32> +// CHECK: %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32> +// CHECK: return %[[B]] : vector<1x4xf32> +func.func @cast_away_maskedload_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> { + %0 = vector.maskedload %base[%i], %mask, %pass : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> + return %0 : vector<1x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @cast_away_expandload_leading_one_dims +// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1> +// CHECK: %[[P:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> +// CHECK: %[[L:.+]] = vector.expandload %{{.*}}[%{{.*}}], %[[M]], %[[P]] : memref<16xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32> +// CHECK: %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32> +// CHECK: return %[[B]] : vector<1x4xf32> +func.func @cast_away_expandload_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> { + %0 = vector.expandload %base[%i], %mask, %pass : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> + return %0 : vector<1x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @cast_away_gather_leading_one_dims +// CHECK: %[[I:.+]] = vector.extract %{{.*}}[0] : vector<4xi32> from vector<1x4xi32> +// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1> +// CHECK: %[[P:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> +// CHECK: %[[G:.+]] = vector.gather %{{.*}}[%{{.*}}] [%[[I]]], %[[M]], %[[P]] : memref<16xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32> +// CHECK: %[[B:.+]] = vector.broadcast %[[G]] : vector<4xf32> to vector<1x4xf32> +// CHECK: return %[[B]] : vector<1x4xf32> +func.func @cast_away_gather_leading_one_dims(%base: memref<16xf32>, %i: index, %idx: vector<1x4xi32>, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> { + %0 = vector.gather %base[%i] [%idx], %mask, %pass : memref<16xf32>, vector<1x4xi32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> + return %0 : vector<1x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @cast_away_store_leading_one_dims +// CHECK: %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> +// CHECK: vector.store %[[V]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x16xf32>, vector<4xf32> +func.func @cast_away_store_leading_one_dims(%val: vector<1x4xf32>, %base: memref<8x16xf32>, %i: index, %j: index) { + vector.store %val, %base[%i, %j] : memref<8x16xf32>, vector<1x4xf32> + return +} + +// ----- + +// CHECK-LABEL: func.func @cast_away_maskedstore_leading_one_dims +// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1> +// CHECK: %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> +// CHECK: vector.maskedstore %{{.*}}[%{{.*}}], %[[M]], %[[V]] : memref<16xf32>, vector<4xi1>, vector<4xf32> +func.func @cast_away_maskedstore_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) { + vector.maskedstore %base[%i], %mask, %val : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32> + return +} + +// ----- + +// CHECK-LABEL: func.func @cast_away_compressstore_leading_one_dims +// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1> +// CHECK: %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> +// CHECK: vector.compressstore %{{.*}}[%{{.*}}], %[[M]], %[[V]] : memref<16xf32>, vector<4xi1>, vector<4xf32> +func.func @cast_away_compressstore_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) { + vector.compressstore %base[%i], %mask, %val : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32> + return +} + +// ----- + +// CHECK-LABEL: func.func @cast_away_scatter_leading_one_dims +// CHECK: %[[I:.+]] = vector.extract %{{.*}}[0] : vector<4xi32> from vector<1x4xi32> +// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1> +// CHECK: %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> +// CHECK: vector.scatter %{{.*}}[%{{.*}}] [%[[I]]], %[[M]], %[[V]] : memref<16xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> +func.func @cast_away_scatter_leading_one_dims(%base: memref<16xf32>, %i: index, %idx: vector<1x4xi32>, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) { + vector.scatter %base[%i] [%idx], %mask, %val : memref<16xf32>, vector<1x4xi32>, vector<1x4xi1>, vector<1x4xf32> + return +}