diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index aad42039300e3..26a702ef0f512 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -7,9 +7,7 @@ //===----------------------------------------------------------------------===// #include -#include -#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" @@ -17,7 +15,6 @@ #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/TypeUtilities.h" -#include "llvm/ADT/Repeated.h" #include "llvm/ADT/STLExtras.h" #define DEBUG_TYPE "vector-drop-unit-dim" @@ -25,9 +22,9 @@ using namespace mlir; using namespace mlir::vector; -// Trims leading unit dimensions from `oldType` and returns the result type. -static VectorType trimLeadingUnitDims(VectorType oldType, - bool zeroDimsAllowed) { +// Trims leading one dimensions from `oldType` and returns the result type. +// Returns `vector<1xT>` if `oldType` only has one element. +static VectorType trimLeadingOneDims(VectorType oldType) { ArrayRef oldShape = oldType.getShape(); ArrayRef newShape = oldShape; @@ -40,117 +37,22 @@ static VectorType trimLeadingUnitDims(VectorType oldType, newScalableDims = newScalableDims.drop_front(1); } - // Some vector ops forbid 0-D vectors. - if (!zeroDimsAllowed && newShape.empty()) { + // Make sure we have at least 1 dimension per vector type requirements. + if (newShape.empty()) { newShape = oldShape.take_back(); newScalableDims = oldType.getScalableDims().take_back(); } return VectorType::get(newShape, oldType.getElementType(), newScalableDims); } -static bool isNonScalableUnitDim(VectorType type, int64_t dim) { - assert(dim >= 0 && dim < type.getRank() && - "expected a valid vector dimension"); - return type.getShape()[dim] == 1 && !type.getScalableDims()[dim]; +/// Return a smallVector of size `rank` containing all zeros. +static SmallVector splatZero(int64_t rank) { + return SmallVector(rank, 0); } - -/// Returns true if the first `k` dimensions of `type` are non-scalable unit -/// dimensions. -static bool areLeadingDimsUnit(VectorType type, int64_t k) { - assert(k >= 0 && k <= type.getRank() && - "expected a valid leading dimension count"); - return llvm::all_of(llvm::seq(0, k), [&](int64_t dim) { - return isNonScalableUnitDim(type, dim); - }); -} - -static bool areLeadingDimsUnitAfterPermutation(VectorType type, - ArrayRef permutation, - int64_t k) { - assert(k >= 0 && k <= static_cast(permutation.size()) && - "expected a valid leading dimension count"); - return llvm::all_of(permutation.take_front(k), [&](int64_t dim) { - return isNonScalableUnitDim(type, dim); - }); -} - -/// Shape-casts `operand` to the vector type obtained by dropping dimension -/// `dim`, which must be non-scalable and unit-sized. -static Value dropUnitDim(OpBuilder &b, Location loc, Value operand, - int64_t dimToDrop, bool zeroDimsAllowed) { - auto oldType = cast(operand.getType()); - assert(isNonScalableUnitDim(oldType, dimToDrop) && - "expected a non-scalable unit dim to drop"); - int64_t rank = oldType.getRank(); - assert((zeroDimsAllowed || rank > 1) && - "target op does not allow 0-D vectors"); - - SmallVector newShape; - SmallVector newScalableDims; - newShape.reserve(rank - 1); - newScalableDims.reserve(rank - 1); - for (auto [i, size, scalable] : - llvm::enumerate(oldType.getShape(), oldType.getScalableDims())) { - if (static_cast(i) == dimToDrop) - continue; - newShape.push_back(size); - newScalableDims.push_back(scalable); - } - - return b.createOrFold( - loc, VectorType::get(newShape, oldType.getElementType(), newScalableDims), - operand); -} - -/// Shape-casts `operand` to the vector type obtained by dropping the first -/// `k` non-scalable unit dimensions. -static Value dropLeadingUnitDims(OpBuilder &b, Location loc, Value operand, - int64_t k, bool zeroDimsAllowed) { - auto oldType = cast(operand.getType()); - assert(areLeadingDimsUnit(oldType, k) && - "expected non-scalable leading unit dims to drop"); - assert((zeroDimsAllowed || k < oldType.getRank()) && - "target op does not allow 0-D vectors"); - VectorType newType = VectorType::get(oldType.getShape().drop_front(k), - oldType.getElementType(), - oldType.getScalableDims().drop_front(k)); - return b.createOrFold(loc, newType, operand); -} - -/// Returns the vector type obtained by applying `permutation` to `type`. -static VectorType permuteVectorType(VectorType type, - ArrayRef permutation) { - assert(static_cast(permutation.size()) == type.getRank() && - "expected a permutation matching the operand rank"); - SmallVector permutedShape = - applyPermutation(type.getShape(), permutation); - SmallVector permutedScalableDims = - applyPermutation(type.getScalableDims(), permutation); - return VectorType::get(permutedShape, type.getElementType(), - permutedScalableDims); -} - -/// Like `dropLeadingUnitDims` except that if all dimensions would be dropped, -/// the single element inside that vector is extracted and returned. -static Value dropLeadingUnitDims0DIsScalar(OpBuilder &b, Location loc, - Value operand, int64_t k) { - auto oldType = cast(operand.getType()); - assert(areLeadingDimsUnit(oldType, k) && - "expected non-scalable leading unit dims to drop"); - - if (k == oldType.getRank()) { - SmallVector zeros(k, static_cast(0)); - return vector::ExtractOp::create(b, loc, operand, zeros); - } - - return dropLeadingUnitDims(b, loc, operand, k, - /*zeroDimsAllowed=*/true); -} - namespace { // Casts away leading one dimensions in vector.extract_strided_slice's vector -// input by inserting vector.shape_cast. +// input by inserting vector.broadcast. struct CastAwayExtractStridedSliceLeadingOneDim : public OpRewritePattern { using Base::Base; @@ -161,8 +63,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim // the same rank. Here we drop leading one dimensions from the input vector // type to make sure we don't cause mismatch. VectorType oldSrcType = extractOp.getSourceVectorType(); - VectorType newSrcType = - trimLeadingUnitDims(oldSrcType, /*zeroDimsAllowed=*/false); + VectorType newSrcType = trimLeadingOneDims(oldSrcType); if (newSrcType.getRank() == oldSrcType.getRank()) return failure(); @@ -177,8 +78,8 @@ struct CastAwayExtractStridedSliceLeadingOneDim Location loc = extractOp.getLoc(); - Value newSrcVector = rewriter.createOrFold( - loc, newSrcType, extractOp.getSource()); + Value newSrcVector = vector::ExtractOp::create( + rewriter, loc, extractOp.getSource(), splatZero(dropCount)); // The offsets/sizes/strides attribute can have a less number of elements // than the input vector's rank: it is meant for the leading dimensions. @@ -193,7 +94,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim rewriter, loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides); - rewriter.replaceOpWithNewOp(extractOp, oldDstType, + rewriter.replaceOpWithNewOp(extractOp, oldDstType, newExtractOp); return success(); @@ -201,7 +102,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim }; // Casts away leading one dimensions in vector.insert_strided_slice's vector -// inputs by inserting vector.shape_cast. +// inputs by inserting vector.broadcast. struct CastAwayInsertStridedSliceLeadingOneDim : public OpRewritePattern { using Base::Base; @@ -209,11 +110,9 @@ struct CastAwayInsertStridedSliceLeadingOneDim LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, PatternRewriter &rewriter) const override { VectorType oldSrcType = insertOp.getSourceVectorType(); - VectorType newSrcType = - trimLeadingUnitDims(oldSrcType, /*zeroDimsAllowed=*/false); + VectorType newSrcType = trimLeadingOneDims(oldSrcType); VectorType oldDstType = insertOp.getDestVectorType(); - VectorType newDstType = - trimLeadingUnitDims(oldDstType, /*zeroDimsAllowed=*/false); + VectorType newDstType = trimLeadingOneDims(oldDstType); int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank(); int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); @@ -223,10 +122,10 @@ struct CastAwayInsertStridedSliceLeadingOneDim // Trim leading one dimensions from both operands. Location loc = insertOp.getLoc(); - Value newSrcVector = rewriter.createOrFold( - loc, newSrcType, insertOp.getValueToStore()); - Value newDstVector = rewriter.createOrFold( - loc, newDstType, insertOp.getDest()); + Value newSrcVector = vector::ExtractOp::create( + rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount)); + Value newDstVector = vector::ExtractOp::create( + rewriter, loc, insertOp.getDest(), splatZero(dstDropCount)); auto newOffsets = rewriter.getArrayAttr( insertOp.getOffsets().getValue().take_back(newDstType.getRank())); @@ -237,7 +136,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim rewriter, loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides); - rewriter.replaceOpWithNewOp(insertOp, oldDstType, + rewriter.replaceOpWithNewOp(insertOp, oldDstType, newInsertOp); return success(); @@ -245,7 +144,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim }; // Casts away leading one dimensions in vector.insert's vector inputs by -// inserting vector.shape_cast. +// inserting vector.broadcast. struct CastAwayInsertLeadingOneDim : public OpRewritePattern { using Base::Base; @@ -255,14 +154,13 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern { Type newSrcType = oldSrcType; int64_t oldSrcRank = 0, newSrcRank = 0; if (auto type = dyn_cast(oldSrcType)) { - newSrcType = trimLeadingUnitDims(type, /*zeroDimsAllowed=*/false); + newSrcType = trimLeadingOneDims(type); oldSrcRank = type.getRank(); newSrcRank = cast(newSrcType).getRank(); } VectorType oldDstType = insertOp.getDestVectorType(); - VectorType newDstType = - trimLeadingUnitDims(oldDstType, /*zeroDimsAllowed=*/oldSrcRank == 0); + VectorType newDstType = trimLeadingOneDims(oldDstType); int64_t srcDropCount = oldSrcRank - newSrcRank; int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank(); @@ -273,11 +171,12 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern { Location loc = insertOp.getLoc(); Value newSrcVector = insertOp.getValueToStore(); - if (oldSrcRank != 0) - newSrcVector = rewriter.createOrFold( - loc, cast(newSrcType), insertOp.getValueToStore()); - Value newDstVector = rewriter.createOrFold( - loc, newDstType, insertOp.getDest()); + if (oldSrcRank != 0) { + newSrcVector = vector::ExtractOp::create( + rewriter, loc, insertOp.getValueToStore(), splatZero(srcDropCount)); + } + Value newDstVector = vector::ExtractOp::create( + rewriter, loc, insertOp.getDest(), splatZero(dstDropCount)); // New position rank needs to be computed in two steps: (1) if destination // type has leading unit dims, we also trim the position array accordingly, @@ -294,7 +193,7 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern { auto newInsertOp = vector::InsertOp::create(rewriter, loc, newSrcVector, newDstVector, newPosition); - rewriter.replaceOpWithNewOp(insertOp, oldDstType, + rewriter.replaceOpWithNewOp(insertOp, oldDstType, newInsertOp); return success(); @@ -302,10 +201,20 @@ struct CastAwayInsertLeadingOneDim : public OpRewritePattern { }; static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask, - VectorType newType, AffineMap newMap) { + VectorType newType, AffineMap newMap, + VectorType oldMaskType) { // Infer the type of the new mask from the new map. VectorType newMaskType = inferTransferOpMaskType(newType, newMap); - return b.createOrFold(loc, newMaskType, mask); + + // If the new mask is broadcastable to the old result type, we can safely + // use a `vector.extract` to get the new mask. Otherwise the best we can + // do is shape cast. + if (vector::isBroadcastableTo(newMaskType, oldMaskType) == + BroadcastableToResult::Success) { + int64_t dropDim = oldMaskType.getRank() - newMaskType.getRank(); + return vector::ExtractOp::create(b, loc, mask, splatZero(dropDim)); + } + return vector::ShapeCastOp::create(b, loc, newMaskType, mask); } // Turns vector.transfer_read on vector with leading 1 dimensions into @@ -320,7 +229,7 @@ struct CastAwayTransferReadLeadingOneDim // TODO(#78787): Not supported masked op yet. if (cast(read.getOperation()).isMasked()) return failure(); - // Nothing to trim when the transfer itself has rank zero. + // TODO: support 0-d corner case. if (read.getTransferRank() == 0) return failure(); @@ -329,7 +238,7 @@ struct CastAwayTransferReadLeadingOneDim return failure(); VectorType oldType = read.getVectorType(); - VectorType newType = trimLeadingUnitDims(oldType, /*zeroDimsAllowed=*/true); + VectorType newType = trimLeadingOneDims(oldType); if (newType == oldType) return failure(); @@ -347,14 +256,16 @@ struct CastAwayTransferReadLeadingOneDim read.getInBoundsAttr().getValue().take_back(newType.getRank())); Value mask = Value(); - if (read.getMask()) + if (read.getMask()) { + VectorType maskType = read.getMaskType(); mask = dropUnitDimsFromMask(rewriter, read.getLoc(), read.getMask(), - newType, newMap); + newType, newMap, maskType); + } auto newRead = vector::TransferReadOp::create( rewriter, read.getLoc(), newType, read.getBase(), read.getIndices(), AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr); - rewriter.replaceOpWithNewOp(read, oldType, newRead); + rewriter.replaceOpWithNewOp(read, oldType, newRead); return success(); } @@ -372,7 +283,7 @@ struct CastAwayTransferWriteLeadingOneDim // TODO(#78787): Not supported masked op yet. if (cast(write.getOperation()).isMasked()) return failure(); - // Nothing to trim when the transfer itself has rank zero. + // TODO: support 0-d corner case. if (write.getTransferRank() == 0) return failure(); @@ -381,9 +292,11 @@ struct CastAwayTransferWriteLeadingOneDim return failure(); VectorType oldType = write.getVectorType(); - VectorType newType = trimLeadingUnitDims(oldType, /*zeroDimsAllowed=*/true); + VectorType newType = trimLeadingOneDims(oldType); if (newType == oldType) return failure(); + int64_t dropDim = oldType.getRank() - newType.getRank(); + AffineMap oldMap = write.getPermutationMap(); ArrayRef newResults = oldMap.getResults().take_back(newType.getRank()); @@ -396,12 +309,13 @@ struct CastAwayTransferWriteLeadingOneDim inBoundsAttr = rewriter.getArrayAttr( write.getInBoundsAttr().getValue().take_back(newType.getRank())); - auto newVector = rewriter.createOrFold( - write.getLoc(), newType, write.getVector()); + auto newVector = vector::ExtractOp::create( + rewriter, write.getLoc(), write.getVector(), splatZero(dropDim)); if (write.getMask()) { - Value newMask = dropUnitDimsFromMask(rewriter, write.getLoc(), - write.getMask(), newType, newMap); + VectorType maskType = write.getMaskType(); + Value newMask = dropUnitDimsFromMask( + rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType); rewriter.replaceOpWithNewOp( write, newVector, write.getBase(), write.getIndices(), AffineMapAttr::get(newMap), newMask, inBoundsAttr); @@ -417,15 +331,6 @@ struct CastAwayTransferWriteLeadingOneDim } // namespace -namespace { -struct VectorContractOperandCastPlan { - AffineMap map; - SmallVector permutation; - bool dropLeadingUnitDim = false; - bool permuteOperand = false; -}; -} // namespace - FailureOr mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, MaskingOpInterface maskingOp, @@ -435,7 +340,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, return failure(); if (oldAccType.getRank() < 1) return failure(); - if (!isNonScalableUnitDim(oldAccType, 0)) + if (oldAccType.getShape()[0] != 1) return failure(); // currently we support only dropping one dim but the pattern can be applied // greedily to drop more. @@ -462,70 +367,74 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, SmallVector operands = {contractOp.getLhs(), contractOp.getRhs(), contractOp.getAcc()}; - SmallVector operandCastPlans; SmallVector newOperands; auto loc = contractOp.getLoc(); - if (maskingOp) { - auto oldMaskType = cast(maskingOp.getMask().getType()); - if (oldMaskType.getRank() <= 1 || dimToDrop >= oldMaskType.getRank() || - !isNonScalableUnitDim(oldMaskType, dimToDrop)) - return failure(); - } - for (const auto &it : llvm::enumerate(oldIndexingMaps)) { // Check if the dim to be dropped exists as a leading dim in the operand - // if it does then we use vector.shape_cast to drop it. - VectorContractOperandCastPlan plan; + // if it does then we use vector.extract to drop it. + bool validExtract = false; SmallVector results; - plan.map = it.value(); - int64_t originalZeroDim = plan.map.getDimPosition(0); - if (originalZeroDim != dimToDrop) { + auto map = it.value(); + int64_t orginalZeroDim = it.value().getDimPosition(0); + if (orginalZeroDim != dimToDrop) { // There are two reasons to be in this path, 1. We need to - // permute the operand type to make the dim to be dropped + // transpose the operand to make the dim to be dropped // leading. 2. The dim to be dropped does not exist and in - // that case we dont want to add a unit permutation but we must + // that case we dont want to add a unit transpose but we must // check all the indices to make sure this is the case. - SmallVector permutedResults; + bool transposeNeeded = false; + SmallVector perm; + SmallVector transposeResults; - for (int64_t i = 0, e = plan.map.getNumResults(); i < e; ++i) { - int64_t currDim = plan.map.getDimPosition(i); + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { + int64_t currDim = map.getDimPosition(i); if (currDim == dimToDrop) { - plan.permuteOperand = true; - plan.permutation.insert(plan.permutation.begin(), i); + transposeNeeded = true; + perm.insert(perm.begin(), i); auto targetExpr = rewriter.getAffineDimExpr(currDim); - permutedResults.insert(permutedResults.begin(), targetExpr); + transposeResults.insert(transposeResults.begin(), targetExpr); } else { - plan.permutation.push_back(i); + perm.push_back(i); auto targetExpr = rewriter.getAffineDimExpr(currDim); - permutedResults.push_back(targetExpr); + transposeResults.push_back(targetExpr); } } - // Update the map now so that the later shape_cast drops the correct dim. - if (plan.permuteOperand) { - plan.map = AffineMap::get(plan.map.getNumDims(), 0, permutedResults, - contractOp.getContext()); - if (plan.map.getDimPosition(0) == dimToDrop) { - auto operandType = cast(operands[it.index()].getType()); - if (!areLeadingDimsUnitAfterPermutation(operandType, plan.permutation, - dropDim)) - return failure(); + // Checks if only the outer, unit dimensions (of size 1) are permuted. + // Such transposes do not materially effect the underlying vector and can + // be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32> + bool transposeNonOuterUnitDims = false; + auto operandShape = cast(operands[it.index()].getType()); + for (auto [index, dim] : + llvm::enumerate(ArrayRef(perm).drop_back(1))) { + if (dim != static_cast(index) && + operandShape.getDimSize(index) != 1) { + transposeNonOuterUnitDims = true; + break; + } + } + + // Do the transpose now if needed so that we can drop the + // correct dim using extract later. + if (transposeNeeded) { + map = AffineMap::get(map.getNumDims(), 0, transposeResults, + contractOp.getContext()); + if (transposeNonOuterUnitDims) { + operands[it.index()] = rewriter.createOrFold( + loc, operands[it.index()], perm); } } } // We have taken care to have the dim to be dropped be // the leading dim. If its still not leading that means it - // does not exist in this operand and hence we do not need a shape_cast. - if (plan.map.getDimPosition(0) == dimToDrop) - plan.dropLeadingUnitDim = true; - if (plan.dropLeadingUnitDim && originalZeroDim == dimToDrop && - !areLeadingDimsUnit(cast(operands[it.index()].getType()), - dropDim)) - return failure(); + // does not exist in this operand and hence we do not need + // an extract. + if (map.getDimPosition(0) == dimToDrop) + validExtract = true; - for (int64_t i = 0, e = plan.map.getNumResults(); i < e; ++i) { - int64_t currDim = plan.map.getDimPosition(i); + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { + int64_t currDim = map.getDimPosition(i); if (currDim == dimToDrop) // This is the dim we are dropping. continue; @@ -533,23 +442,15 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, currDim < dimToDrop ? currDim : currDim - 1); results.push_back(targetExpr); } - newIndexingMaps.push_back(AffineMap::get(plan.map.getNumDims() - 1, 0, - results, contractOp.getContext())); - operandCastPlans.push_back(std::move(plan)); - } - - for (auto [plan, operand] : llvm::zip_equal(operandCastPlans, operands)) { - Value newOperand = operand; - if (plan.permuteOperand) - newOperand = rewriter.createOrFold( - loc, - permuteVectorType(cast(newOperand.getType()), - plan.permutation), - newOperand); - if (plan.dropLeadingUnitDim) - newOperand = - dropLeadingUnitDims0DIsScalar(rewriter, loc, newOperand, dropDim); - newOperands.push_back(newOperand); + newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results, + contractOp.getContext())); + // Extract if its a valid extraction, otherwise use the operand + // without extraction. + newOperands.push_back(validExtract + ? vector::ExtractOp::create(rewriter, loc, + operands[it.index()], + splatZero(dropDim)) + : operands[it.index()]); } // Depending on whether this vector.contract is masked, the replacing Op @@ -560,19 +461,13 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind()); if (maskingOp) { - Value newMask = dropUnitDim(rewriter, loc, maskingOp.getMask(), dimToDrop, - /*zeroDimsAllowed=*/false); + auto newMask = vector::ExtractOp::create(rewriter, loc, maskingOp.getMask(), + splatZero(dropDim)); newOp = mlir::vector::maskOperation(rewriter, newOp, newMask); } - if (!isa(newOp->getResults()[0].getType())) - return vector::BroadcastOp::create(rewriter, loc, - contractOp->getResultTypes()[0], - newOp->getResults()[0]) - .getResult(); - - return vector::ShapeCastOp::create(rewriter, loc, + return vector::BroadcastOp::create(rewriter, loc, contractOp->getResultTypes()[0], newOp->getResults()[0]) .getResult(); @@ -581,9 +476,9 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, namespace { /// Turns vector.contract on vector with leading 1 dimensions into -/// vector.shape_cast followed by vector.contract on vector without leading -/// 1 dimensions. Non-leading unit dimensions are dropped via direct -/// shape_casts. +/// vector.extract followed by vector.contract on vector without leading +/// 1 dimensions. Also performs transpose of lhs and rhs operands if required +/// prior to extract. struct CastAwayContractionLeadingOneDim : public MaskableOpRewritePattern { using MaskableOpRewritePattern::MaskableOpRewritePattern; @@ -598,15 +493,14 @@ struct CastAwayContractionLeadingOneDim /// Looks at elementwise operations on vectors with at least one leading /// dimension equal 1, e.g. vector<1x[4]x1xf32> (but not vector<2x[4]x1xf32>), -/// and casts away the leading one dimensions (_plural_) with shape_cast. +/// and cast aways the leading one dimensions (_plural_) and then broadcasts +/// the results. /// /// Example before: /// %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32> /// Example after: -/// %2 = vector.shape_cast %arg0 : vector<1x4x1xf32> to vector<4x1xf32> -/// %3 = vector.shape_cast %arg1 : vector<1x4x1xf32> to vector<4x1xf32> -/// %4 = arith.mulf %2, %3 : vector<4x1xf32> -/// %5 = vector.shape_cast %4 : vector<4x1xf32> to vector<1x4x1xf32> +/// %2 = arith.mulf %0, %1 : vector<4x1xf32> +/// %3 = vector.broadcast %2 : vector<4x1xf32> to vector<1x4x1xf32> /// /// Does support scalable vectors. class CastAwayElementwiseLeadingOneDim : public RewritePattern { @@ -622,34 +516,55 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern { auto vecType = dyn_cast(op->getResultTypes()[0]); if (!vecType) return failure(); - VectorType newVecType = - trimLeadingUnitDims(vecType, /*zeroDimsAllowed=*/true); + VectorType newVecType = trimLeadingOneDims(vecType); if (newVecType == vecType) return failure(); + int64_t dropDim = vecType.getRank() - newVecType.getRank(); SmallVector newOperands; for (Value operand : op->getOperands()) { - if (auto opVecType = dyn_cast(operand.getType())) - newOperands.push_back(rewriter.createOrFold( - op->getLoc(), - trimLeadingUnitDims(opVecType, /*zeroDimsAllowed=*/true), operand)); - else + if (auto opVecType = dyn_cast(operand.getType())) { + newOperands.push_back(vector::ExtractOp::create( + rewriter, op->getLoc(), operand, splatZero(dropDim))); + } else { newOperands.push_back(operand); + } } Operation *newOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(), newOperands, newVecType, op->getAttrs()); - rewriter.replaceOpWithNewOp(op, vecType, + rewriter.replaceOpWithNewOp(op, vecType, newOp->getResult(0)); return success(); } }; } // namespace +// Drops `dropDim` leading dimensions from `operand` using vector.extract when +// those dims are all non-scalable units (the cheap, structural rewrite); falls +// back to vector.shape_cast otherwise. +static Value dropLeadingOneDimsFromOperand(OpBuilder &b, Location loc, + Value operand, int64_t nDropped) { + auto oldType = cast(operand.getType()); + ArrayRef leadingShape = oldType.getShape().take_front(nDropped); + ArrayRef leadingScalable = + oldType.getScalableDims().take_front(nDropped); + bool extractable = + llvm::all_of(leadingShape, [](int64_t d) { return d == 1; }) && + llvm::none_of(leadingScalable, [](bool s) { return s; }); + if (extractable) + return vector::ExtractOp::create(b, loc, operand, splatZero(nDropped)); + VectorType newType = VectorType::get( + oldType.getShape().drop_front(nDropped), oldType.getElementType(), + oldType.getScalableDims().drop_front(nDropped)); + return vector::ShapeCastOp::create(b, loc, newType, operand); +} + namespace { -// Drops leading unit dimensions from load-like memory operations by -// shape_casting each vector operand and shape_casting the result back to the -// original type. +// Drops leading 1 dimensions from load-like memory operaitons. REmoves leading +// unit dimensions from the result types and then broadcasts back in those 1s, +// while also extracting (or shape_cast-ing) any leading unit dimensions on +// the input operands. template struct CastAwayLoadLikeLeadingOneDim : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -657,10 +572,7 @@ struct CastAwayLoadLikeLeadingOneDim : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { VectorType oldResultType = op.getVectorType(); - constexpr bool zeroDimsAllowed = - llvm::is_one_of::value; - VectorType newResultType = - trimLeadingUnitDims(oldResultType, zeroDimsAllowed); + VectorType newResultType = trimLeadingOneDims(oldResultType); if (newResultType == oldResultType) return failure(); int64_t nDropped = oldResultType.getRank() - newResultType.getRank(); @@ -670,8 +582,8 @@ struct CastAwayLoadLikeLeadingOneDim : public OpRewritePattern { newOperands.reserve(op->getNumOperands()); for (Value operand : op->getOperands()) { if (isa(operand.getType())) { - newOperands.push_back(dropLeadingUnitDims(rewriter, loc, operand, - nDropped, zeroDimsAllowed)); + newOperands.push_back( + dropLeadingOneDimsFromOperand(rewriter, loc, operand, nDropped)); } else { newOperands.push_back(operand); } @@ -680,14 +592,15 @@ struct CastAwayLoadLikeLeadingOneDim : public OpRewritePattern { Operation *newOp = rewriter.create(loc, op->getName().getIdentifier(), newOperands, TypeRange{newResultType}, op->getAttrs()); - rewriter.replaceOpWithNewOp(op, oldResultType, + rewriter.replaceOpWithNewOp(op, oldResultType, newOp->getResult(0)); return success(); } }; -// Drops leading unit dimensions from store-like memory operations by -// shape_casting each vector operand and leaving any scalar operands alone. +// Drops leading 1 dimensions from store-like memory ops. Extracts or +// `shape_cast`s away those leading unit dimensions and leaves any scalar +// operands alone. template struct CastAwayStoreLikeLeadingOneDim : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -695,9 +608,7 @@ struct CastAwayStoreLikeLeadingOneDim : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { VectorType oldVecType = op.getVectorType(); - constexpr bool zeroDimsAllowed = - llvm::is_one_of::value; - VectorType newVecType = trimLeadingUnitDims(oldVecType, zeroDimsAllowed); + VectorType newVecType = trimLeadingOneDims(oldVecType); if (newVecType == oldVecType) return failure(); int64_t nDropped = oldVecType.getRank() - newVecType.getRank(); @@ -707,8 +618,8 @@ struct CastAwayStoreLikeLeadingOneDim : public OpRewritePattern { newOperands.reserve(op->getNumOperands()); for (Value operand : op->getOperands()) { if (isa(operand.getType())) { - newOperands.push_back(dropLeadingUnitDims(rewriter, loc, operand, - nDropped, zeroDimsAllowed)); + newOperands.push_back( + dropLeadingOneDimsFromOperand(rewriter, loc, operand, nDropped)); } else { newOperands.push_back(operand); } @@ -722,8 +633,8 @@ struct CastAwayStoreLikeLeadingOneDim : public OpRewritePattern { } }; -// Drops leading 1 dimensions from vector.constant_mask and shape_casts back to -// the original shape. +// Drops leading 1 dimensions from vector.constant_mask and inserts a +// vector.broadcast back to the original shape. struct CastAwayConstantMaskLeadingOneDim : public OpRewritePattern { using Base::Base; @@ -731,8 +642,7 @@ struct CastAwayConstantMaskLeadingOneDim LogicalResult matchAndRewrite(vector::ConstantMaskOp mask, PatternRewriter &rewriter) const override { VectorType oldType = mask.getType(); - VectorType newType = trimLeadingUnitDims(oldType, - /*zeroDimsAllowed=*/true); + VectorType newType = trimLeadingOneDims(oldType); if (newType == oldType) return failure(); @@ -740,22 +650,16 @@ struct CastAwayConstantMaskLeadingOneDim int64_t dropDim = oldType.getRank() - newType.getRank(); ArrayRef dimSizes = mask.getMaskDimSizes(); - // If any of the folded unit dims has a size of `0`, the entire leading - // mask region is zero. Otherwise the folded unit dims have no effect on - // the mask. - SmallVector newDimSizes; - if (newType.getRank() == 0) { - newDimSizes.push_back(llvm::product_of(dimSizes)); - } else { - int64_t flatLeadingSize = - llvm::product_of(dimSizes.take_front(dropDim + 1)); - newDimSizes.push_back(flatLeadingSize); - newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end()); - } + // If any of the dropped unit dims has a size of `0`, the entire mask is a + // zero mask, else the unit dim has no effect on the mask. + int64_t flatLeadingSize = + llvm::product_of(dimSizes.take_front(dropDim + 1)); + SmallVector newDimSizes = {flatLeadingSize}; + newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end()); auto newMask = vector::ConstantMaskOp::create(rewriter, mask.getLoc(), newType, newDimSizes); - rewriter.replaceOpWithNewOp(mask, oldType, newMask); + rewriter.replaceOpWithNewOp(mask, oldType, newMask); return success(); } }; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 2575c9e4a85b9..752610efc6992 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1931,12 +1931,12 @@ struct ChainedReduction final : OpRewritePattern { } }; -// Helper function dropping unit non-scalable dimension from a VectorType. -// Scalable unit dimensions are not dropped. Folding such dimensions would -// require "shifting" the scalable flag onto some other fixed-width dim (e.g. -// vector<[1]x4xf32> -> vector<[4]xf32>). -static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy, - bool zeroDimsAllowed) { +// Helper function dropping unit non-scalable dimension from a VectorType +// keeping at least 1 dimension to avoid generating 0-D vectors. Scalable unit +// dimensions are not dropped. Folding such dimensions would require "shifting" +// the scalable flag onto some other fixed-width dim (e.g. vector<[1]x4xf32> -> +// vector<[4]xf32>). This could be implemented in the future. +static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) { auto inVecShape = inVecTy.getShape(); SmallVector newShape; SmallVector newScalableDims; @@ -1948,8 +1948,8 @@ static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy, newShape.push_back(dim); newScalableDims.push_back(isScalable); } - // Some vector ops forbid 0-D vectors. - if (!zeroDimsAllowed && newShape.empty()) { + // All dims have been dropped, return vector<1xeType>. + if (newShape.empty()) { newShape.push_back(1); newScalableDims.push_back(false); } @@ -2000,12 +2000,14 @@ struct DropUnitDimFromElementwiseOps final auto sourceVectorType = dyn_cast(op->getOperand(0).getType()); if (!sourceVectorType) return failure(); + if (sourceVectorType.getRank() < 2) + return failure(); + SmallVector newOperands; auto loc = op->getLoc(); for (auto operand : op->getOperands()) { auto opVectorType = cast(operand.getType()); - auto newVType = dropNonScalableUnitDimFromType(opVectorType, - /*zeroDimsAllowed=*/true); + auto newVType = dropNonScalableUnitDimFromType(opVectorType); if (newVType == opVectorType) return rewriter.notifyMatchFailure(op, "No unit dimension to remove."); @@ -2014,8 +2016,7 @@ struct DropUnitDimFromElementwiseOps final } VectorType newResultVectorType = - dropNonScalableUnitDimFromType(resultVectorType, - /*zeroDimsAllowed=*/true); + dropNonScalableUnitDimFromType(resultVectorType); // Create an updated elementwise Op without unit dim. Operation *elementwiseOp = rewriter.create(loc, op->getName().getIdentifier(), newOperands, @@ -2056,8 +2057,7 @@ struct DropUnitDimsFromTransposeOp final PatternRewriter &rewriter) const override { VectorType sourceType = op.getSourceVectorType(); VectorType sourceTypeWithoutUnitDims = - dropNonScalableUnitDimFromType(sourceType, - /*zeroDimsAllowed=*/true); + dropNonScalableUnitDimFromType(sourceType); if (sourceType == sourceTypeWithoutUnitDims) return failure(); @@ -2082,9 +2082,9 @@ struct DropUnitDimsFromTransposeOp final } // Fixup for `newPerm`. The `sourceTypeWithoutUnitDims` could be vector<1xT> - // type when the dimensions are unit dimensions and 0-D vectors are not - // allowed. In this case, the newPerm should be [0]. - if (newPerm.empty() && sourceTypeWithoutUnitDims.getRank() > 0) { + // type when the dimensions are unit dimensions. In this case, the newPerm + // should be [0]. + if (newPerm.empty()) { newPerm.push_back(0); } @@ -2139,9 +2139,7 @@ struct DropUnitDimsFromScfForOp final : OpRewritePattern { if (!vectorType) continue; - VectorType newVectorType = - dropNonScalableUnitDimFromType(vectorType, - /*zeroDimsAllowed=*/true); + VectorType newVectorType = dropNonScalableUnitDimFromType(vectorType); if (vectorType == newVectorType) continue; diff --git a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir index 4e800ab169bf6..34a155fbf2fc1 100644 --- a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir +++ b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir @@ -150,26 +150,10 @@ func.func @fold_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1xf32> { // CHECK-LABEL: func.func @fold_all_unit_dims( // CHECK-SAME: %[[VAL_0:.*]]: vector<1x1xf32>) -> vector<1xf32> -// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector -// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector -// CHECK: %[[VAL_3:.*]] = arith.mulf %[[VAL_1]], %[[VAL_2]] : vector -// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector to vector<1xf32> -// CHECK: return %[[VAL_4]] : vector<1xf32> - -// ----- - -func.func @fold_rank1_unit_dim(%vec: vector<1xf32>) -> vector<1xf32> { - %res = arith.addf %vec, %vec : vector<1xf32> - return %res : vector<1xf32> -} - -// CHECK-LABEL: func.func @fold_rank1_unit_dim( -// CHECK-SAME: %[[VAL_0:.*]]: vector<1xf32>) -> vector<1xf32> -// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1xf32> to vector -// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1xf32> to vector -// CHECK: %[[VAL_3:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : vector -// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector to vector<1xf32> -// CHECK: return %[[VAL_4]] : vector<1xf32> +// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32> +// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32> +// CHECK: %[[VAL_3:.*]] = arith.mulf %[[VAL_1]], %[[VAL_2]] : vector<1xf32> +// CHECK: return %[[VAL_3]] : vector<1xf32> ///---------------------------------------------------------------------------------------- /// [Pattern: DropUnitDimsFromTransposeOp] @@ -265,11 +249,11 @@ func.func @scf_for_with_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1x1xf32> // CHECK-LABEL: func.func @scf_for_with_all_unit_dims // CHECK-SAME: %[[VEC:[A-Za-z0-9]+]]: vector<1x1xf32> -// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x1xf32> to vector +// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x1xf32> to vector<1xf32> // CHECK: %[[LOOP:.+]] = scf.for {{.*}} iter_args(%[[ITER:.+]] = %[[CAST]]) -// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector +// CHECK: %[[SQRT:.+]] = math.sqrt %[[ITER]] : vector<1xf32> // CHECK: scf.yield %[[SQRT]] -// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector to vector<1x1xf32> +// CHECK: %[[CASTBACK:.+]] = vector.shape_cast %[[LOOP]] : vector<1xf32> to vector<1x1xf32> // CHECK: return %[[CASTBACK]] // ----- diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir index cd1ecec455896..bf01c8a8589d9 100644 --- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir @@ -5,13 +5,13 @@ // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: cast_away_contraction_leading_one_dims -// CHECK-NEXT: %[[R0:.+]] = vector.shape_cast %{{.*}} : vector<1x16x8xf32> to vector<16x8xf32> -// CHECK-NEXT: %[[R1:.+]] = vector.shape_cast %{{.*}} : vector<1x8x16xf32> to vector<8x16xf32> -// CHECK-NEXT: %[[R2:.+]] = vector.shape_cast %{{.*}} : vector<1x16x16xf32> to vector<16x16xf32> +// CHECK-NEXT: %[[R0:.+]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32> +// CHECK-NEXT: %[[R1:.+]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32> +// CHECK-NEXT: %[[R2:.+]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32> // CHECK-NEXT: %[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} // CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32> -// CHECK-NEXT: %[[R4:.+]] = vector.shape_cast %[[R3]] : vector<16x16xf32> to vector<1x16x16xf32> +// CHECK-NEXT: %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16x16xf32> to vector<1x16x16xf32> // CHECK-NEXT: return %[[R4]] : vector<1x16x16xf32> #contraction_accesses0 = [ @@ -36,14 +36,14 @@ func.func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %ar // CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_const_mask // CHECK: %[[MASK:.*]] = vector.constant_mask [15, 15, 8] : vector<16x16x8xi1> -// CHECK: %[[R0:.*]] = vector.shape_cast %{{.*}} : vector<1x16x8xf32> to vector<16x8xf32> -// CHECK: %[[R1:.*]] = vector.shape_cast %{{.*}} : vector<1x8x16xf32> to vector<8x16xf32> -// CHECK: %[[R2:.*]] = vector.shape_cast %{{.*}} : vector<1x16x16xf32> to vector<16x16xf32> +// CHECK: %[[R0:.*]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32> +// CHECK: %[[R1:.*]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32> +// CHECK: %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32> // CHECK: %[[CONTRACT:.*]] = vector.mask %[[MASK]] { // CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} // CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32> // CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32> -// CHECK: %[[RES:.*]] = vector.shape_cast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32> +// CHECK: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32> // CHECK: return %[[RES]] : vector<1x16x16xf32> #contraction_accesses0 = [ @@ -70,15 +70,15 @@ func.func @cast_away_contraction_leading_one_dim_under_const_mask(%arg0: vector< // CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_mask -// CHECK: %[[R0:.*]] = vector.shape_cast %{{.*}} : vector<1x16x8xf32> to vector<16x8xf32> -// CHECK: %[[R1:.*]] = vector.shape_cast %{{.*}} : vector<1x8x16xf32> to vector<8x16xf32> -// CHECK: %[[R2:.*]] = vector.shape_cast %{{.*}} : vector<1x16x16xf32> to vector<16x16xf32> -// CHECK: %[[M:.*]] = vector.shape_cast %{{.*}} : vector<1x16x16x8xi1> to vector<16x16x8xi1> +// CHECK: %[[R0:.*]] = vector.extract %{{.*}} : vector<16x8xf32> from vector<1x16x8xf32> +// CHECK: %[[R1:.*]] = vector.extract %{{.*}} : vector<8x16xf32> from vector<1x8x16xf32> +// CHECK: %[[R2:.*]] = vector.extract %{{.*}} : vector<16x16xf32> from vector<1x16x16xf32> +// CHECK: %[[M:.*]] = vector.extract %{{.*}} : vector<16x16x8xi1> from vector<1x16x16x8xi1> // CHECK: %[[CONTRACT:.*]] = vector.mask %[[M]] { // CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} // CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32> // CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32> -// CHECK-NEXT: %[[RES:.*]] = vector.shape_cast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32> +// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32> // CHECK-NEXT: return %[[RES]] : vector<1x16x16xf32> #contraction_accesses0 = [ @@ -109,14 +109,15 @@ func.func @cast_away_contraction_leading_one_dim_under_mask( // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0)> // CHECK-LABEL: cast_away_contraction_leading_one_dims_transposeneeded -// CHECK-NEXT: %[[R0:.+]] = vector.shape_cast %{{.*}} : vector<1x8x16xf32> to vector<8x16xf32> -// CHECK-NEXT: %[[R1:.+]] = vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32> -// CHECK-NEXT: %[[R2:.+]] = vector.shape_cast %{{.*}} : vector<1x1x16xf32> to vector<16xf32> +// CHECK-NEXT: %[[R0:.+]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32> +// CHECK-NEXT: %[[R1:.+]] = vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32> +// CHECK-NEXT: %[[R2:.+]] = vector.extract %{{.*}}[0, 0] : vector<16xf32> from vector<1x1x16xf32> // CHECK-NEXT: %[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], // CHECK-SAME: iterator_types = ["parallel", "reduction"], kind = #vector.kind} // CHECK-SAME: %[[R1]], %[[R0]], %[[R2]] : vector<8xf32>, vector<8x16xf32> into vector<16xf32> -// CHECK-NEXT: %[[R4:.+]] = vector.shape_cast %[[R3]] : vector<16xf32> to vector<1x1x16xf32> -// CHECK-NEXT: return %[[R4]] : vector<1x1x16xf32> +// CHECK-NEXT: %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16xf32> to vector<1x16xf32> +// CHECK-NEXT: %[[R5:.+]] = vector.broadcast %[[R4]] : vector<1x16xf32> to vector<1x1x16xf32> +// CHECK-NEXT: return %[[R5]] : vector<1x1x16xf32> #contraction_accesses1 = [ affine_map<(l, i, j, k) -> (i, l, k)>, @@ -140,13 +141,15 @@ func.func @cast_away_contraction_leading_one_dims_transposeneeded(%arg0: vector< // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: cast_away_contraction_leading_one_dims_transposeneeded2 -// CHECK-NEXT: %[[R1:.+]] = vector.shape_cast %{{.*}} : vector<8x1x16xf32> to vector<8x16xf32> -// CHECK-NEXT: %[[R3:.+]] = vector.shape_cast %{{.*}} : vector<2x8x1xf32> to vector<2x8xf32> -// CHECK-NEXT: %[[R4:.+]] = vector.shape_cast %{{.*}} : vector<1x2x16xf32> to vector<2x16xf32> +// CHECK-NEXT: %[[R0:.+]] = vector.transpose %{{.*}}[1, 0, 2] : vector<8x1x16xf32> to vector<1x8x16xf32> +// CHECK-NEXT: %[[R1:.+]] = vector.extract %[[R0]][0] : vector<8x16xf32> from vector<1x8x16xf32> +// CHECK-NEXT: %[[R2:.+]] = vector.transpose %{{.*}}[2, 0, 1] : vector<2x8x1xf32> to vector<1x2x8xf32> +// CHECK-NEXT: %[[R3:.+]] = vector.extract %[[R2]][0] : vector<2x8xf32> from vector<1x2x8xf32> +// CHECK-NEXT: %[[R4:.+]] = vector.extract %{{.*}}[0] : vector<2x16xf32> from vector<1x2x16xf32> // CHECK-NEXT: %[[R5:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} // CHECK-SAME: %[[R1]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32> -// CHECK-NEXT: %[[R6:.+]] = vector.shape_cast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32> +// CHECK-NEXT: %[[R6:.+]] = vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32> // CHECK-NEXT: return %[[R6]] : vector<1x2x16xf32> #contraction_accesses2 = [ @@ -172,14 +175,19 @@ func.func @cast_away_contraction_leading_one_dims_transposeneeded2(%arg0: vector // CHECK-LABEL: cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4 -// CHECK-NEXT: %[[R3:.+]] = vector.shape_cast %{{.*}} : vector<1x8x1x16xf32> to vector<8x16xf32> -// CHECK-NEXT: %[[R5:.+]] = vector.shape_cast %{{.*}} : vector<1x2x8x1xf32> to vector<2x8xf32> -// CHECK-NEXT: %[[R6:.+]] = vector.shape_cast %{{.*}} : vector<1x1x2x16xf32> to vector<2x16xf32> +// CHECK-NEXT: %[[R0:.+]] = vector.extract %{{.*}}[0] : vector<8x1x16xf32> from vector<1x8x1x16xf32> +// CHECK-NEXT: %[[R1:.+]] = vector.extract %{{.*}}[0] : vector<2x8x1xf32> from vector<1x2x8x1xf32> +// CHECK-NEXT: %[[R2:.+]] = vector.transpose %[[R0]], [1, 0, 2] : vector<8x1x16xf32> to vector<1x8x16xf32> +// CHECK-NEXT: %[[R3:.+]] = vector.extract %[[R2]][0] : vector<8x16xf32> from vector<1x8x16xf32> +// CHECK-NEXT: %[[R4:.+]] = vector.transpose %[[R1]], [2, 0, 1] : vector<2x8x1xf32> to vector<1x2x8xf32> +// CHECK-NEXT: %[[R5:.+]] = vector.extract %[[R4]][0] : vector<2x8xf32> from vector<1x2x8xf32> +// CHECK-NEXT: %[[R6:.+]] = vector.extract %{{.*}}[0, 0] : vector<2x16xf32> from vector<1x1x2x16xf32> // CHECK-NEXT: %[[R7:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} // CHECK-SAME: %[[R3]], %[[R5]], %[[R6]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32> -// CHECK-NEXT: %[[R8:.+]] = vector.shape_cast %[[R7]] : vector<2x16xf32> to vector<1x1x2x16xf32> -// CHECK-NEXT: return %[[R8]] : vector<1x1x2x16xf32> +// CHECK-NEXT: %[[R8:.+]] = vector.broadcast %[[R7]] : vector<2x16xf32> to vector<1x2x16xf32> +// CHECK-NEXT: %[[R9:.+]] = vector.broadcast %[[R8]] : vector<1x2x16xf32> to vector<1x1x2x16xf32> +// CHECK-NEXT: return %[[R9]] : vector<1x1x2x16xf32> #contraction_accesses2 = [ affine_map<(m, l, i, j, k) -> (m, k, l, j)>, @@ -203,14 +211,17 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4(%arg0: // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> // CHECK-LABEL: cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctranspose -// CHECK-NEXT: %[[R2:.+]] = vector.shape_cast %{{.*}} : vector<1x8x1x16xf32> to vector<8x16xf32> -// CHECK-NEXT: %[[R3:.+]] = vector.shape_cast %{{.*}} : vector<1x2x8x1xf32> to vector<2x8xf32> -// CHECK-NEXT: %[[R4:.+]] = vector.shape_cast %{{.*}} : vector<1x1x2x16xf32> to vector<2x16xf32> +// CHECK-NEXT: %[[R0:.+]] = vector.transpose %{{.*}}, [2, 0, 1, 3] : vector<1x8x1x16xf32> to vector<1x1x8x16xf32> +// CHECK-NEXT: %[[R1:.+]] = vector.transpose %{{.*}}, [3, 0, 1, 2] : vector<1x2x8x1xf32> to vector<1x1x2x8xf32> +// CHECK-NEXT: %[[R2:.+]] = vector.extract %[[R0]][0, 0] : vector<8x16xf32> from vector<1x1x8x16xf32> +// CHECK-NEXT: %[[R3:.+]] = vector.extract %[[R1]][0, 0] : vector<2x8xf32> from vector<1x1x2x8xf32> +// CHECK-NEXT: %[[R4:.+]] = vector.extract %{{.*}}[0, 0] : vector<2x16xf32> from vector<1x1x2x16xf32> // CHECK-NEXT: %[[R5:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], // CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} // CHECK-SAME: %[[R2]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32> -// CHECK-NEXT: %[[R6:.+]] = vector.shape_cast %[[R5]] : vector<2x16xf32> to vector<1x1x2x16xf32> -// CHECK-NEXT: return %[[R6]] : vector<1x1x2x16xf32> +// CHECK-NEXT: %[[R6:.+]] = vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32> +// CHECK-NEXT: %[[R7:.+]] = vector.broadcast %[[R6]] : vector<1x2x16xf32> to vector<1x1x2x16xf32> +// CHECK-NEXT: return %[[R7]] : vector<1x1x2x16xf32> #contraction_accesses3 = [ affine_map<(m, l, i, j, k) -> (m, k, l, j)>, @@ -245,7 +256,7 @@ func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(%lhs: vect // CHECK-DAG: #[[$map_dp1:.*]] = affine_map<(d0) -> ()> // CHECK-LABEL: cast_away_contraction_leading_one_dims_to_dot_product -// CHECK-NEXT: %[[R0:.+]] = vector.shape_cast %{{.*}} : vector<1x64xf32> to vector<64xf32> +// CHECK-NEXT: %[[R0:.+]] = vector.extract %{{.*}}[0] : vector<64xf32> from vector<1x64xf32> // CHECK-NEXT: %[[R1:.+]] = vector.extract %{{.*}}[0] : f32 from vector<1xf32> // CHECK-NEXT: %[[R2:.+]] = vector.contract {indexing_maps = [#[[$map_dp0]], #[[$map_dp0]], #[[$map_dp1]]], // CHECK-SAME: iterator_types = ["reduction"], kind = #vector.kind} @@ -259,96 +270,44 @@ func.func @cast_away_contraction_leading_one_dims_to_dot_product(%arg0: vector<6 } // ----- - -// CHECK-DAG: #[[$DOT_MAP:.*]] = affine_map<(d0) -> (d0)> -// CHECK-DAG: #[[$SCALAR_MAP:.*]] = affine_map<(d0) -> ()> - -// CHECK-LABEL: cast_away_masked_contraction_with_rank1_acc -// CHECK-NEXT: %[[RHS:.+]] = vector.shape_cast %{{.*}} : vector<1x64xf32> to vector<64xf32> -// CHECK-NEXT: %[[ACC:.+]] = vector.extract %{{.*}}[0] : f32 from vector<1xf32> -// CHECK-NEXT: %[[MASK:.+]] = vector.shape_cast %{{.*}} : vector<64x1xi1> to vector<64xi1> -// CHECK-NEXT: %[[DOT:.+]] = vector.mask %[[MASK]] { -// CHECK-SAME: vector.contract {indexing_maps = [#[[$DOT_MAP]], #[[$DOT_MAP]], #[[$SCALAR_MAP]]], iterator_types = ["reduction"], kind = #vector.kind} -// CHECK-SAME: %{{.*}}, %[[RHS]], %[[ACC]] : vector<64xf32>, vector<64xf32> into f32 -// CHECK-SAME: } : vector<64xi1> -> f32 -// CHECK-NEXT: %[[RES:.+]] = vector.broadcast %[[DOT]] : f32 to vector<1xf32> -// CHECK-NEXT: return %[[RES]] : vector<1xf32> - -func.func @cast_away_masked_contraction_with_rank1_acc(%arg0: vector<64xf32>, %arg1: vector<1x64xf32>, %arg2: vector<1xf32>, %mask: vector<64x1xi1>) -> vector<1xf32> { - %0 = vector.mask %mask { - vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<64xf32>, vector<1x64xf32> into vector<1xf32> - } : vector<64x1xi1> -> vector<1xf32> - return %0 : vector<1xf32> -} - -// ----- - -// CHECK-LABEL: negative_cast_away_contraction_with_scalable_rank1_acc -// CHECK-NOT: vector.shape_cast -// CHECK-NOT: vector.extract -// CHECK-NOT: vector.broadcast -// CHECK-NEXT: vector.contract -// CHECK-NEXT: return - -func.func @negative_cast_away_contraction_with_scalable_rank1_acc(%arg0: vector<64xf32>, %arg1: vector<[1]x64xf32>, %arg2: vector<[1]xf32>) -> vector<[1]xf32> { - %0 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<64xf32>, vector<[1]x64xf32> into vector<[1]xf32> - return %0 : vector<[1]xf32> -} - -// ----- - -// CHECK-LABEL: negative_cast_away_contraction_with_scalable_operand_dim -// CHECK-NOT: vector.shape_cast -// CHECK-NOT: vector.extract -// CHECK-NOT: vector.broadcast -// CHECK-NEXT: vector.contract -// CHECK-NEXT: return - -func.func @negative_cast_away_contraction_with_scalable_operand_dim(%arg0: vector<64xf32>, %arg1: vector<[1]x64xf32>, %arg2: vector<1xf32>) -> vector<1xf32> { - %0 = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"], kind = #vector.kind} %arg0, %arg1, %arg2 : vector<64xf32>, vector<[1]x64xf32> into vector<1xf32> - return %0 : vector<1xf32> -} - -// ----- - // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims func.func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> { - // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16> + // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8x8xf16> from vector<1x8x8xf16> // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16> %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16> - // CHECK: %[[RET:.+]] = vector.shape_cast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16> + // CHECK: %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16> // CHECK: return %[[RET]] return %0: vector<1x1x8xf16> } // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims_scalable func.func @cast_away_extract_strided_slice_leading_one_dims_scalable(%arg0: vector<1x8x[8]xf16>) -> vector<1x1x[8]xf16> { - // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8x[8]xf16> to vector<8x[8]xf16> + // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16> // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x[8]xf16> to vector<1x[8]xf16> %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x[8]xf16> to vector<1x1x[8]xf16> - // CHECK: %[[RET:.+]] = vector.shape_cast %[[EXTRACT]] : vector<1x[8]xf16> to vector<1x1x[8]xf16> + // CHECK: %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x[8]xf16> to vector<1x1x[8]xf16> // CHECK: return %[[RET]] return %0: vector<1x1x[8]xf16> } // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims func.func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> { - // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8xf16> to vector<8xf16> - // CHECK: %[[DST:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16> + // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8xf16> from vector<1x8xf16> + // CHECK: %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<8x8xf16> from vector<1x8x8xf16> // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16> %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16> - // CHECK: %[[RET:.+]] = vector.shape_cast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16> + // CHECK: %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16> // CHECK: return %[[RET]] return %0: vector<1x8x8xf16> } // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_scalable func.func @cast_away_insert_strided_slice_leading_one_dims_scalable(%arg0: vector<1x[8]xf16>, %arg1: vector<1x8x[8]xf16>) -> vector<1x8x[8]xf16> { - // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x[8]xf16> to vector<[8]xf16> - // CHECK: %[[DST:.+]] = vector.shape_cast %{{.*}} : vector<1x8x[8]xf16> to vector<8x[8]xf16> + // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<[8]xf16> from vector<1x[8]xf16> + // CHECK: %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16> // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<[8]xf16> into vector<8x[8]xf16> %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x[8]xf16> into vector<1x8x[8]xf16> - // CHECK: %[[RET:.+]] = vector.shape_cast %[[INSERT]] : vector<8x[8]xf16> to vector<1x8x[8]xf16> + // CHECK: %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x[8]xf16> to vector<1x8x[8]xf16> // CHECK: return %[[RET]] return %0: vector<1x8x[8]xf16> } @@ -356,7 +315,8 @@ func.func @cast_away_insert_strided_slice_leading_one_dims_scalable(%arg0: vecto // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element // CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16> func.func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> { - // CHECK: %[[B:.+]] = vector.shape_cast %{{.*}} : vector<1x1xf16> to vector<1x1x1xf16> + // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<1xf16> from vector<1x1xf16> + // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<1xf16> to vector<1x1x1xf16> %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16> // CHECK: return %[[B]] return %0: vector<1x1x1xf16> @@ -365,7 +325,8 @@ func.func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: ve // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element_scalable // CHECK-SAME: %[[ARG0:.+]]: vector<1x[1]xf16>, %{{.+}}: vector<1x1x[1]xf16> func.func @cast_away_insert_strided_slice_leading_one_dims_one_element_scalable(%arg0: vector<1x[1]xf16>, %arg1: vector<1x1x[1]xf16>) -> vector<1x1x[1]xf16> { - // CHECK: %[[B:.+]] = vector.shape_cast %{{.*}} : vector<1x[1]xf16> to vector<1x1x[1]xf16> + // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<[1]xf16> from vector<1x[1]xf16> + // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<[1]xf16> to vector<1x1x[1]xf16> %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x[1]xf16> into vector<1x1x[1]xf16> // CHECK: return %[[B]] return %0: vector<1x1x[1]xf16> @@ -378,7 +339,7 @@ func.func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16 %f0 = arith.constant 0. : f16 // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16> - // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x4xf16> + // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16> %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16> // CHECK: return %[[CAST]] return %0: vector<1x4xf16> @@ -390,9 +351,9 @@ func.func @cast_away_masked_transfer_read_leading_one_dims(%arg0: memref<1x4x8x1 %c0 = arith.constant 0 : index // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16 %f0 = arith.constant 0. : f16 - // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1> + // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1> // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16> - // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x4xf16> + // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16> %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16> // CHECK: return %[[CAST]] return %0: vector<1x4xf16> @@ -402,7 +363,7 @@ func.func @cast_away_masked_transfer_read_leading_one_dims(%arg0: memref<1x4x8x1 func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> { %c0 = arith.constant 0 : index %f0 = arith.constant 0. : f16 - // CHECK: vector.shape_cast %{{.+}} : vector to vector<1x1xf16> + // CHECK: vector.broadcast %{{.+}} : vector<1xf16> to vector<1x1xf16> %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x1x1xf16>, vector<1x1xf16> return %0: vector<1x1xf16> } @@ -419,7 +380,7 @@ func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16 // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1> // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true] // CHECK-SAME: permutation_map = #[[$MAP]]} : memref<1x4x8xf16>, vector<4xf16> - // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x1x4xf16> + // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x1x4xf16> %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : memref<1x4x8xf16>, vector<1x1x4xf16> // CHECK: return %[[CAST]] @@ -430,7 +391,7 @@ func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16 // CHECK-LABEL: func @not_insert_cast_fo4_transfer_read_under_mask // CHECK: %[[MASK:.+]] = vector.constant_mask -// CHECK: %[[CASTED_MASK:.+]] = vector.shape_cast %[[MASK]] +// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]] // CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] { // CHECK-SAME: vector.transfer_read {{.*}} : memref<1x1x4xf16>, vector<1x4xf16> } // CHECK: return %[[RET]] : vector<1x4xf16> @@ -450,7 +411,7 @@ func.func @not_insert_cast_fo4_transfer_read_under_mask(%arg0: memref<1x1x4xf16> func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) { // CHECK: %[[C0:.+]] = arith.constant 0 : index %c0 = arith.constant 0 : index - // CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf16> to vector<4xf16> + // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16> // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16> vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16> @@ -461,8 +422,8 @@ func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16> func.func @cast_away_masked_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>, %arg2: vector<1x4xi1>) { // CHECK: %[[C0:.+]] = arith.constant 0 : index %c0 = arith.constant 0 : index - // CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf16> to vector<4xf16> - // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1> + // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16> + // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1> // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16> vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0], %arg2 {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16> @@ -472,7 +433,7 @@ func.func @cast_away_masked_transfer_write_leading_one_dims(%arg0: memref<1x4x8x // CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) { %c0 = arith.constant 0 : index - // CHECK: vector.shape_cast %{{.+}} : vector<1x1xf16> to vector + // CHECK: vector.extract %{{.+}}[0] : vector<1xf16> from vector<1x1xf16> vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x1xf16>, memref<1x1x1x1xf16> return } @@ -481,7 +442,7 @@ func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1 // CHECK-LABEL: func @not_insert_cast_for_transfer_write_under_mask // CHECK: %[[MASK:.+]] = vector.constant_mask -// CHECK: %[[CASTED_MASK:.+]] = vector.shape_cast %[[MASK]] +// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]] // CHECK: vector.mask %[[CASTED_MASK]] { // CHECK-SAME: vector.transfer_write {{.*}} : vector<1x4xf16>, memref<1x1x4xf16> } // CHECK: return @@ -501,7 +462,7 @@ func.func @not_insert_cast_for_transfer_write_under_mask(%arg0: memref<1x1x4xf16 func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) { // CHECK: %[[C0:.+]] = arith.constant 0 : index %c0 = arith.constant 0 : index - // CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x1x4xf16> to vector<4xf16> + // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0, 0] : vector<4xf16> from vector<1x1x4xf16> // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1> // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true] // CHECK-SAME: permutation_map = #[[$MAP]]} : vector<4xf16>, memref<1x4x8xf16> @@ -518,25 +479,25 @@ func.func @cast_away_elementwise_leading_one_dims( %arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>, %arg3: vector<1x4xf32>, %arg4: i1) -> (vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>) { - // CHECK: vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32> - // CHECK: vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32> + // CHECK: vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32> + // CHECK: vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32> // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8xf32> - // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32> + // CHECK: vector.broadcast %{{.*}} : vector<8xf32> to vector<1x1x8xf32> %0 = arith.addf %arg0, %arg0 : vector<1x1x8xf32> - // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> - // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> + // CHECK: vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> + // CHECK: vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> // CHECK: arith.cmpf ogt, %{{.*}}, %{{.*}} : vector<4xf32> - // CHECK: vector.shape_cast %{{.*}} : vector<4xi1> to vector<1x4xi1> + // CHECK: vector.broadcast %{{.*}} : vector<4xi1> to vector<1x4xi1> %1 = arith.cmpf ogt, %arg2, %arg3 : vector<1x4xf32> - // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> - // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> + // CHECK: vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> + // CHECK: vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> // CHECK: select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xf32> - // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32> + // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32> %2 = arith.select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32> - // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> - // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> + // CHECK: vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> + // CHECK: vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> // CHECK: select %arg4, %12, %{{.*}} : vector<4xf32> - // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32> + // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32> %3 = arith.select %arg4, %arg3, %arg2 : vector<1x4xf32> return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32> } @@ -545,10 +506,10 @@ func.func @cast_away_elementwise_leading_one_dims( // CHECK-LABEL: func @cast_away_insert_leading_one_dims_scalar // CHECK-SAME: (%[[S:.+]]: f32, %[[V:.+]]: vector<1x1x4xf32>) -// CHECK: %[[DST_CAST:.+]] = vector.shape_cast %[[V]] : vector<1x1x4xf32> to vector<4xf32> -// CHECK: %[[INSERT:.+]] = vector.insert %[[S]], %[[DST_CAST]] [0] : f32 into vector<4xf32> -// CHECK: %[[RESULT_CAST:.+]] = vector.shape_cast %[[INSERT]] : vector<4xf32> to vector<1x1x4xf32> -// CHECK: return %[[RESULT_CAST]] +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[V]][0, 0] : vector<4xf32> from vector<1x1x4xf32> +// CHECK: %[[INSERT:.+]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<4xf32> +// CHECK: %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<4xf32> to vector<1x1x4xf32> +// CHECK: return %[[BCAST]] func.func @cast_away_insert_leading_one_dims_scalar(%s: f32, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> { %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x4xf32> return %0: vector<1x1x4xf32> @@ -556,27 +517,14 @@ func.func @cast_away_insert_leading_one_dims_scalar(%s: f32, %v: vector<1x1x4xf3 // ----- -// CHECK-LABEL: func @cast_away_insert_leading_one_dims_scalar_0d_dest -// CHECK-SAME: (%[[S:.+]]: f32, %[[V:.+]]: vector<1x1xf32>) -// CHECK: %[[DST_CAST:.+]] = vector.shape_cast %[[V]] : vector<1x1xf32> to vector -// CHECK: %[[INSERT:.+]] = vector.insert %[[S]], %[[DST_CAST]] [] : f32 into vector -// CHECK: %[[RESULT_CAST:.+]] = vector.shape_cast %[[INSERT]] : vector to vector<1x1xf32> -// CHECK: return %[[RESULT_CAST]] -func.func @cast_away_insert_leading_one_dims_scalar_0d_dest(%s: f32, %v: vector<1x1xf32>) -> vector<1x1xf32> { - %0 = vector.insert %s, %v [0, 0] : f32 into vector<1x1xf32> - return %0: vector<1x1xf32> -} - -// ----- - // CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_scalar_scalable( // CHECK-SAME: %[[S:.*]]: f32, // CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> { func.func @cast_away_insert_leading_one_dims_scalar_scalable(%s: f32, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> { -// CHECK: %[[DST_CAST:.*]] = vector.shape_cast %[[V]] : vector<1x1x[4]xf32> to vector<[4]xf32> -// CHECK: %[[INSERT:.*]] = vector.insert %[[S]], %[[DST_CAST]] [0] : f32 into vector<[4]xf32> -// CHECK: %[[RESULT_CAST:.*]] = vector.shape_cast %[[INSERT]] : vector<[4]xf32> to vector<1x1x[4]xf32> -// CHECK: return %[[RESULT_CAST]] : vector<1x1x[4]xf32> +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[V]][0, 0] : vector<[4]xf32> from vector<1x1x[4]xf32> +// CHECK: %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<[4]xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[4]xf32> to vector<1x1x[4]xf32> +// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32> %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x[4]xf32> return %0: vector<1x1x[4]xf32> } @@ -587,10 +535,10 @@ func.func @cast_away_insert_leading_one_dims_scalar_scalable(%s: f32, %v: vector // CHECK-SAME: %[[S:.*]]: f32, // CHECK-SAME: %[[V:.*]]: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> { func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(%s: f32, %v: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> { -// CHECK: %[[DST_CAST:.*]] = vector.shape_cast %[[V]] : vector<1x[1]x4xf32> to vector<[1]x4xf32> -// CHECK: %[[INSERT:.*]] = vector.insert %[[S]], %[[DST_CAST]] [0, 0] : f32 into vector<[1]x4xf32> -// CHECK: %[[RESULT_CAST:.*]] = vector.shape_cast %[[INSERT]] : vector<[1]x4xf32> to vector<1x[1]x4xf32> -// CHECK: return %[[RESULT_CAST]] : vector<1x[1]x4xf32> +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[V]][0] : vector<[1]x4xf32> from vector<1x[1]x4xf32> +// CHECK: %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0, 0] : f32 into vector<[1]x4xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[1]x4xf32> to vector<1x[1]x4xf32> +// CHECK: return %[[BCAST]] : vector<1x[1]x4xf32> %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x[1]x4xf32> return %0: vector<1x[1]x4xf32> } @@ -599,8 +547,8 @@ func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(%s: f32, % // CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank1 // CHECK-SAME: (%[[S:.+]]: vector<4xf32>, %[[V:.+]]: vector<1x1x4xf32>) -// CHECK: %[[RESULT_CAST:.+]] = vector.shape_cast %[[S]] : vector<4xf32> to vector<1x1x4xf32> -// CHECK: return %[[RESULT_CAST]] +// CHECK: %[[BCAST:.+]] = vector.broadcast %[[S]] : vector<4xf32> to vector<1x1x4xf32> +// CHECK: return %[[BCAST]] func.func @cast_away_insert_leading_one_dims_rank1(%s: vector<4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> { %0 = vector.insert %s, %v [0, 0] : vector<4xf32> into vector<1x1x4xf32> return %0: vector<1x1x4xf32> @@ -611,8 +559,8 @@ func.func @cast_away_insert_leading_one_dims_rank1(%s: vector<4xf32>, %v: vector // CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank1_scalable( // CHECK-SAME: %[[S:.*]]: vector<[4]xf32>, // CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> { -// CHECK: %[[RESULT_CAST:.*]] = vector.shape_cast %[[S]] : vector<[4]xf32> to vector<1x1x[4]xf32> -// CHECK: return %[[RESULT_CAST]] : vector<1x1x[4]xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[S]] : vector<[4]xf32> to vector<1x1x[4]xf32> +// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32> func.func @cast_away_insert_leading_one_dims_rank1_scalable(%s: vector<[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> { %0 = vector.insert %s, %v [0, 0] : vector<[4]xf32> into vector<1x1x[4]xf32> return %0: vector<1x1x[4]xf32> @@ -622,8 +570,9 @@ func.func @cast_away_insert_leading_one_dims_rank1_scalable(%s: vector<[4]xf32>, // CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2 // CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x1x4xf32>) -// CHECK: %[[SRC_CAST:.+]] = vector.shape_cast %[[S]] : vector<1x4xf32> to vector<1x1x4xf32> -// CHECK: return %[[SRC_CAST]] +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32> +// CHECK: %[[BCAST:.+]] = vector.broadcast %[[EXTRACT]] : vector<4xf32> to vector<1x1x4xf32> +// CHECK: return %[[BCAST]] func.func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> { %0 = vector.insert %s, %v [0] : vector<1x4xf32> into vector<1x1x4xf32> return %0: vector<1x1x4xf32> @@ -634,8 +583,9 @@ func.func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vect // CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank2_scalable( // CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>, // CHECK-SAME: %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> { -// CHECK: %[[SRC_CAST:.*]] = vector.shape_cast %[[S]] : vector<1x[4]xf32> to vector<1x1x[4]xf32> -// CHECK: return %[[SRC_CAST]] : vector<1x1x[4]xf32> +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : vector<[4]xf32> to vector<1x1x[4]xf32> +// CHECK: return %[[BCAST]] : vector<1x1x[4]xf32> func.func @cast_away_insert_leading_one_dims_rank2_scalable(%s: vector<1x[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> { %0 = vector.insert %s, %v [0] : vector<1x[4]xf32> into vector<1x1x[4]xf32> return %0: vector<1x1x[4]xf32> @@ -645,11 +595,11 @@ func.func @cast_away_insert_leading_one_dims_rank2_scalable(%s: vector<1x[4]xf32 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2_one_dest // CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x2x1x4xf32>) -// CHECK: %[[SRC_CAST:.+]] = vector.shape_cast %[[S]] : vector<1x4xf32> to vector<4xf32> -// CHECK: %[[DST_CAST:.+]] = vector.shape_cast %[[V]] : vector<1x2x1x4xf32> to vector<2x1x4xf32> -// CHECK: %[[INSERT:.+]] = vector.insert %[[SRC_CAST]], %[[DST_CAST]] [1, 0] : vector<4xf32> into vector<2x1x4xf32> -// CHECK: %[[RESULT_CAST:.+]] = vector.shape_cast %[[INSERT]] : vector<2x1x4xf32> to vector<1x2x1x4xf32> -// CHECK: return %[[RESULT_CAST]] +// CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32> +// CHECK: %[[EXTRACTV:.+]] = vector.extract %[[V]][0] : vector<2x1x4xf32> from vector<1x2x1x4xf32> +// CHECK: %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<4xf32> into vector<2x1x4xf32> +// CHECK: %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<2x1x4xf32> to vector<1x2x1x4xf32> +// CHECK: return %[[BCAST]] func.func @cast_away_insert_leading_one_dims_rank2_one_dest(%s: vector<1x4xf32>, %v: vector<1x2x1x4xf32>) -> vector<1x2x1x4xf32> { %0 = vector.insert %s, %v [0, 1] : vector<1x4xf32> into vector<1x2x1x4xf32> return %0: vector<1x2x1x4xf32> @@ -660,11 +610,11 @@ func.func @cast_away_insert_leading_one_dims_rank2_one_dest(%s: vector<1x4xf32>, // CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable( // CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>, // CHECK-SAME: %[[V:.*]]: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> { -// CHECK: %[[SRC_CAST:.*]] = vector.shape_cast %[[S]] : vector<1x[4]xf32> to vector<[4]xf32> -// CHECK: %[[DST_CAST:.*]] = vector.shape_cast %[[V]] : vector<1x2x1x[4]xf32> to vector<2x1x[4]xf32> -// CHECK: %[[INSERT:.*]] = vector.insert %[[SRC_CAST]], %[[DST_CAST]] [1, 0] : vector<[4]xf32> into vector<2x1x[4]xf32> -// CHECK: %[[RESULT_CAST:.*]] = vector.shape_cast %[[INSERT]] : vector<2x1x[4]xf32> to vector<1x2x1x[4]xf32> -// CHECK: return %[[RESULT_CAST]] : vector<1x2x1x[4]xf32> +// CHECK: %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32> +// CHECK: %[[EXTRACTV:.*]] = vector.extract %[[V]][0] : vector<2x1x[4]xf32> from vector<1x2x1x[4]xf32> +// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<[4]xf32> into vector<2x1x[4]xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<2x1x[4]xf32> to vector<1x2x1x[4]xf32> +// CHECK: return %[[BCAST]] : vector<1x2x1x[4]xf32> func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> { %0 = vector.insert %s, %v [0, 1] : vector<1x[4]xf32> into vector<1x2x1x[4]xf32> return %0: vector<1x2x1x[4]xf32> @@ -674,8 +624,8 @@ func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(%s: vector< // CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest // CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>) -// CHECK: %[[SRC_CAST:.+]] = vector.shape_cast %[[S]] : vector<1x4xf32> to vector<4xf32> -// CHECK: %[[INSERT:.+]] = vector.insert %[[SRC_CAST]], %[[V]] [5, 0] : vector<4xf32> into vector<8x1x4xf32> +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32> +// CHECK: %[[INSERT:.+]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<4xf32> into vector<8x1x4xf32> // CHECK: return %[[INSERT]] func.func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %v: vector<8x1x4xf32>) -> vector<8x1x4xf32> { %0 = vector.insert %s, %v [5] : vector<1x4xf32> into vector<8x1x4xf32> @@ -687,8 +637,8 @@ func.func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, % // CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable( // CHECK-SAME: %[[S:.*]]: vector<1x[4]xf32>, // CHECK-SAME: %[[V:.*]]: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> { -// CHECK: %[[SRC_CAST:.*]] = vector.shape_cast %[[S]] : vector<1x[4]xf32> to vector<[4]xf32> -// CHECK: %[[INSERT:.*]] = vector.insert %[[SRC_CAST]], %[[V]] [5, 0] : vector<[4]xf32> into vector<8x1x[4]xf32> +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32> +// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<[4]xf32> into vector<8x1x[4]xf32> // CHECK: return %[[INSERT]] : vector<8x1x[4]xf32> func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> { %0 = vector.insert %s, %v [5] : vector<1x[4]xf32> into vector<8x1x[4]xf32> @@ -699,11 +649,11 @@ func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(%s: vector<1x // CHECK-LABEL: func @cast_away_insert_leading_one_dims_one_two_dest // CHECK-SAME: (%[[S:.+]]: vector<1x8xi1>, %[[V:.+]]: vector<1x1x8x1x8xi1>) -// CHECK: %[[SRC_CAST:.+]] = vector.shape_cast %[[S]] : vector<1x8xi1> to vector<8xi1> -// CHECK: %[[DST_CAST:.+]] = vector.shape_cast %[[V]] : vector<1x1x8x1x8xi1> to vector<8x1x8xi1> -// CHECK: %[[INSERT:.+]] = vector.insert %[[SRC_CAST]], %[[DST_CAST]] [7, 0] : vector<8xi1> into vector<8x1x8xi1> -// CHECK: %[[RESULT_CAST:.+]] = vector.shape_cast %[[INSERT]] : vector<8x1x8xi1> to vector<1x1x8x1x8xi1> -// CHECK: return %[[RESULT_CAST]] +// CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<8xi1> from vector<1x8xi1> +// CHECK: %[[EXTRACTV:.+]] = vector.extract %[[V]][0, 0] : vector<8x1x8xi1> from vector<1x1x8x1x8xi1> +// CHECK: %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<8xi1> into vector<8x1x8xi1> +// CHECK: %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<8x1x8xi1> to vector<1x1x8x1x8xi1> +// CHECK: return %[[BCAST]] func.func @cast_away_insert_leading_one_dims_one_two_dest(%s: vector<1x8xi1>, %v: vector<1x1x8x1x8xi1>) -> vector<1x1x8x1x8xi1> { %0 = vector.insert %s, %v [0, 0, 7] : vector<1x8xi1> into vector<1x1x8x1x8xi1> return %0: vector<1x1x8x1x8xi1> @@ -714,11 +664,11 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest(%s: vector<1x8xi1>, %v // CHECK-LABEL: func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable( // CHECK-SAME: %[[S:.*]]: vector<1x[8]xi1>, // CHECK-SAME: %[[V:.*]]: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> { -// CHECK: %[[SRC_CAST:.*]] = vector.shape_cast %[[S]] : vector<1x[8]xi1> to vector<[8]xi1> -// CHECK: %[[DST_CAST:.*]] = vector.shape_cast %[[V]] : vector<1x1x8x1x[8]xi1> to vector<8x1x[8]xi1> -// CHECK: %[[INSERT:.*]] = vector.insert %[[SRC_CAST]], %[[DST_CAST]] [7, 0] : vector<[8]xi1> into vector<8x1x[8]xi1> -// CHECK: %[[RESULT_CAST:.*]] = vector.shape_cast %[[INSERT]] : vector<8x1x[8]xi1> to vector<1x1x8x1x[8]xi1> -// CHECK: return %[[RESULT_CAST]] : vector<1x1x8x1x[8]xi1> +// CHECK: %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<[8]xi1> from vector<1x[8]xi1> +// CHECK: %[[EXTRACTV:.*]] = vector.extract %[[V]][0, 0] : vector<8x1x[8]xi1> from vector<1x1x8x1x[8]xi1> +// CHECK: %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<[8]xi1> into vector<8x1x[8]xi1> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<8x1x[8]xi1> to vector<1x1x8x1x[8]xi1> +// CHECK: return %[[BCAST]] : vector<1x1x8x1x[8]xi1> func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x[8]xi1>, %v: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> { %0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1> return %0: vector<1x1x8x1x[8]xi1> @@ -728,8 +678,8 @@ func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x // CHECK-LABEL: func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> { // CHECK: %[[MASK:.*]] = vector.constant_mask [6, 1, 1] : vector<8x2x1xi1> -// CHECK: %[[MASK_CAST:.*]] = vector.shape_cast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1> -// CHECK: return %[[MASK_CAST]] : vector<1x1x8x2x1xi1> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1> +// CHECK: return %[[BCAST]] : vector<1x1x8x2x1xi1> func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> { %0 = vector.constant_mask [1, 1, 6, 1, 1] : vector<1x1x8x2x1xi1> return %0: vector<1x1x8x2x1xi1> @@ -737,16 +687,6 @@ func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> { // ----- -// CHECK-LABEL: func.func @cast_away_constant_mask_all_unit_dims() -> vector<1x1xi1> { -// CHECK: %[[MASK:.*]] = arith.constant dense : vector<1x1xi1> -// CHECK: return %[[MASK]] : vector<1x1xi1> -func.func @cast_away_constant_mask_all_unit_dims() -> vector<1x1xi1> { - %0 = vector.constant_mask [1, 1] : vector<1x1xi1> - return %0: vector<1x1xi1> -} - -// ----- - // CHECK-LABEL: func.func @drop_unit_dims_scalar_cond_select( // CHECK: arith.select {{.*}} : vector<16xi1> func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>, %arg1: vector<1x16xi1>) -> vector<1x16xi1> { @@ -758,7 +698,7 @@ func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>, // CHECK-LABEL: func.func @cast_away_load_leading_one_dims // CHECK: %[[L:.+]] = vector.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x16xf32>, vector<4xf32> -// CHECK: %[[B:.+]] = vector.shape_cast %[[L]] : vector<4xf32> to vector<1x4xf32> +// CHECK: %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32> // CHECK: return %[[B]] : vector<1x4xf32> func.func @cast_away_load_leading_one_dims(%base: memref<8x16xf32>, %i: index, %j: index) -> vector<1x4xf32> { %0 = vector.load %base[%i, %j] : memref<8x16xf32>, vector<1x4xf32> @@ -767,33 +707,11 @@ func.func @cast_away_load_leading_one_dims(%base: memref<8x16xf32>, %i: index, % // ----- -// CHECK-LABEL: func.func @cast_away_load_all_unit_dims -// CHECK: %[[L:.+]] = vector.load %{{.*}}[%{{.*}}] : memref<1xf32>, vector -// CHECK: %[[B:.+]] = vector.shape_cast %[[L]] : vector to vector<1xf32> -// CHECK: return %[[B]] : vector<1xf32> -func.func @cast_away_load_all_unit_dims(%base: memref<1xf32>, %i: index) -> vector<1xf32> { - %0 = vector.load %base[%i] : memref<1xf32>, vector<1xf32> - return %0 : vector<1xf32> -} - -// ----- - -// CHECK-LABEL: func.func @cast_away_load_leading_one_dims_scalable -// CHECK: %[[L:.+]] = vector.load %{{.*}}[%{{.*}}, %{{.*}}] : memref, vector<[4]xf32> -// CHECK: %[[B:.+]] = vector.shape_cast %[[L]] : vector<[4]xf32> to vector<1x[4]xf32> -// CHECK: return %[[B]] : vector<1x[4]xf32> -func.func @cast_away_load_leading_one_dims_scalable(%base: memref, %i: index, %j: index) -> vector<1x[4]xf32> { - %0 = vector.load %base[%i, %j] : memref, vector<1x[4]xf32> - return %0 : vector<1x[4]xf32> -} - -// ----- - // CHECK-LABEL: func.func @cast_away_maskedload_leading_one_dims -// CHECK: %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1> -// CHECK: %[[P:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> +// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1> +// CHECK: %[[P:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> // CHECK: %[[L:.+]] = vector.maskedload %{{.*}}[%{{.*}}], %[[M]], %[[P]] : memref<16xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32> -// CHECK: %[[B:.+]] = vector.shape_cast %[[L]] : vector<4xf32> to vector<1x4xf32> +// CHECK: %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32> // CHECK: return %[[B]] : vector<1x4xf32> func.func @cast_away_maskedload_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> { %0 = vector.maskedload %base[%i], %mask, %pass : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> @@ -803,10 +721,10 @@ func.func @cast_away_maskedload_leading_one_dims(%base: memref<16xf32>, %i: inde // ----- // CHECK-LABEL: func.func @cast_away_expandload_leading_one_dims -// CHECK: %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1> -// CHECK: %[[P:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> +// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1> +// CHECK: %[[P:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> // CHECK: %[[L:.+]] = vector.expandload %{{.*}}[%{{.*}}], %[[M]], %[[P]] : memref<16xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32> -// CHECK: %[[B:.+]] = vector.shape_cast %[[L]] : vector<4xf32> to vector<1x4xf32> +// CHECK: %[[B:.+]] = vector.broadcast %[[L]] : vector<4xf32> to vector<1x4xf32> // CHECK: return %[[B]] : vector<1x4xf32> func.func @cast_away_expandload_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> { %0 = vector.expandload %base[%i], %mask, %pass : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> @@ -816,11 +734,11 @@ func.func @cast_away_expandload_leading_one_dims(%base: memref<16xf32>, %i: inde // ----- // CHECK-LABEL: func.func @cast_away_gather_leading_one_dims -// CHECK: %[[I:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi32> to vector<4xi32> -// CHECK: %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1> -// CHECK: %[[P:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> +// CHECK: %[[I:.+]] = vector.extract %{{.*}}[0] : vector<4xi32> from vector<1x4xi32> +// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1> +// CHECK: %[[P:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> // CHECK: %[[G:.+]] = vector.gather %{{.*}}[%{{.*}}] [%[[I]]], %[[M]], %[[P]] : memref<16xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32> -// CHECK: %[[B:.+]] = vector.shape_cast %[[G]] : vector<4xf32> to vector<1x4xf32> +// CHECK: %[[B:.+]] = vector.broadcast %[[G]] : vector<4xf32> to vector<1x4xf32> // CHECK: return %[[B]] : vector<1x4xf32> func.func @cast_away_gather_leading_one_dims(%base: memref<16xf32>, %i: index, %idx: vector<1x4xi32>, %mask: vector<1x4xi1>, %pass: vector<1x4xf32>) -> vector<1x4xf32> { %0 = vector.gather %base[%i] [%idx], %mask, %pass : memref<16xf32>, vector<1x4xi32>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> @@ -830,7 +748,7 @@ func.func @cast_away_gather_leading_one_dims(%base: memref<16xf32>, %i: index, % // ----- // CHECK-LABEL: func.func @cast_away_store_leading_one_dims -// CHECK: %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> +// CHECK: %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> // CHECK: vector.store %[[V]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x16xf32>, vector<4xf32> func.func @cast_away_store_leading_one_dims(%val: vector<1x4xf32>, %base: memref<8x16xf32>, %i: index, %j: index) { vector.store %val, %base[%i, %j] : memref<8x16xf32>, vector<1x4xf32> @@ -839,29 +757,9 @@ func.func @cast_away_store_leading_one_dims(%val: vector<1x4xf32>, %base: memref // ----- -// CHECK-LABEL: func.func @cast_away_store_all_unit_dims -// CHECK: %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1xf32> to vector -// CHECK: vector.store %[[V]], %{{.*}}[%{{.*}}] : memref<1xf32>, vector -func.func @cast_away_store_all_unit_dims(%val: vector<1xf32>, %base: memref<1xf32>, %i: index) { - vector.store %val, %base[%i] : memref<1xf32>, vector<1xf32> - return -} - -// ----- - -// CHECK-LABEL: func.func @cast_away_store_leading_one_dims_scalable -// CHECK: %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]xf32> -// CHECK: vector.store %[[V]], %{{.*}}[%{{.*}}, %{{.*}}] : memref, vector<[4]xf32> -func.func @cast_away_store_leading_one_dims_scalable(%val: vector<1x[4]xf32>, %base: memref, %i: index, %j: index) { - vector.store %val, %base[%i, %j] : memref, vector<1x[4]xf32> - return -} - -// ----- - // CHECK-LABEL: func.func @cast_away_maskedstore_leading_one_dims -// CHECK: %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1> -// CHECK: %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> +// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1> +// CHECK: %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> // CHECK: vector.maskedstore %{{.*}}[%{{.*}}], %[[M]], %[[V]] : memref<16xf32>, vector<4xi1>, vector<4xf32> func.func @cast_away_maskedstore_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) { vector.maskedstore %base[%i], %mask, %val : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32> @@ -871,8 +769,8 @@ func.func @cast_away_maskedstore_leading_one_dims(%base: memref<16xf32>, %i: ind // ----- // CHECK-LABEL: func.func @cast_away_compressstore_leading_one_dims -// CHECK: %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1> -// CHECK: %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> +// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1> +// CHECK: %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> // CHECK: vector.compressstore %{{.*}}[%{{.*}}], %[[M]], %[[V]] : memref<16xf32>, vector<4xi1>, vector<4xf32> func.func @cast_away_compressstore_leading_one_dims(%base: memref<16xf32>, %i: index, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) { vector.compressstore %base[%i], %mask, %val : memref<16xf32>, vector<1x4xi1>, vector<1x4xf32> @@ -882,41 +780,11 @@ func.func @cast_away_compressstore_leading_one_dims(%base: memref<16xf32>, %i: i // ----- // CHECK-LABEL: func.func @cast_away_scatter_leading_one_dims -// CHECK: %[[I:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi32> to vector<4xi32> -// CHECK: %[[M:.+]] = vector.shape_cast %{{.*}} : vector<1x4xi1> to vector<4xi1> -// CHECK: %[[V:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> +// CHECK: %[[I:.+]] = vector.extract %{{.*}}[0] : vector<4xi32> from vector<1x4xi32> +// CHECK: %[[M:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1> +// CHECK: %[[V:.+]] = vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32> // CHECK: vector.scatter %{{.*}}[%{{.*}}] [%[[I]]], %[[M]], %[[V]] : memref<16xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> func.func @cast_away_scatter_leading_one_dims(%base: memref<16xf32>, %i: index, %idx: vector<1x4xi32>, %mask: vector<1x4xi1>, %val: vector<1x4xf32>) { vector.scatter %base[%i] [%idx], %mask, %val : memref<16xf32>, vector<1x4xi32>, vector<1x4xi1>, vector<1x4xf32> return } - -// ----- - -// CHECK-LABEL: func.func @negative_cast_memory_ops_to_0d -// CHECK-NOT: vector.shape_cast -// CHECK: vector.maskedload {{.*}} : memref<16xf32>, vector<1xi1>, vector<1xf32> into vector<1xf32> -// CHECK-NOT: vector.shape_cast -// CHECK: vector.expandload {{.*}} : memref<16xf32>, vector<1xi1>, vector<1xf32> into vector<1xf32> -// CHECK-NOT: vector.shape_cast -// CHECK: vector.gather {{.*}} : memref<16xf32>, vector<1xi32>, vector<1xi1>, vector<1xf32> into vector<1xf32> -// CHECK-NOT: vector.shape_cast -// CHECK: vector.maskedstore {{.*}} : memref<16xf32>, vector<1xi1>, vector<1xf32> -// CHECK-NOT: vector.shape_cast -// CHECK: vector.compressstore {{.*}} : memref<16xf32>, vector<1xi1>, vector<1xf32> -// CHECK-NOT: vector.shape_cast -// CHECK: vector.scatter {{.*}} : memref<16xf32>, vector<1xi32>, vector<1xi1>, vector<1xf32> -// CHECK-NOT: vector.shape_cast -// CHECK: return -func.func @negative_cast_memory_ops_to_0d( - %base: memref<16xf32>, %i: index, %idx: vector<1xi32>, - %mask: vector<1xi1>, %pass: vector<1xf32>, %val: vector<1xf32>) - -> (vector<1xf32>, vector<1xf32>, vector<1xf32>) { - %0 = vector.maskedload %base[%i], %mask, %pass : memref<16xf32>, vector<1xi1>, vector<1xf32> into vector<1xf32> - %1 = vector.expandload %base[%i], %mask, %pass : memref<16xf32>, vector<1xi1>, vector<1xf32> into vector<1xf32> - %2 = vector.gather %base[%i] [%idx], %mask, %pass : memref<16xf32>, vector<1xi32>, vector<1xi1>, vector<1xf32> into vector<1xf32> - vector.maskedstore %base[%i], %mask, %val : memref<16xf32>, vector<1xi1>, vector<1xf32> - vector.compressstore %base[%i], %mask, %val : memref<16xf32>, vector<1xi1>, vector<1xf32> - vector.scatter %base[%i] [%idx], %mask, %val : memref<16xf32>, vector<1xi32>, vector<1xi1>, vector<1xf32> - return %0, %1, %2 : vector<1xf32>, vector<1xf32>, vector<1xf32> -} diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir index d0d3a6c0bb976..de12a87253a67 100644 --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -36,7 +36,7 @@ func.func @no_change(%arg0: vector<2x[4]x1xf32>, %arg1: vector<2x[4]x1xf32>) -> // CHECK-LABEL: func.func @cast_away_leading_one_dim( // CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<4x1xf32> -// CHECK: vector.shape_cast %[[MUL]] : vector<4x1xf32> to vector<1x4x1xf32> +// CHECK: vector.broadcast %[[MUL]] : vector<4x1xf32> to vector<1x4x1xf32> func.func @cast_away_leading_one_dim(%arg0: vector<1x4x1xf32>, %arg1: vector<1x4x1xf32>) -> vector<1x4x1xf32> { %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32> return %1: vector<1x4x1xf32> @@ -44,7 +44,7 @@ func.func @cast_away_leading_one_dim(%arg0: vector<1x4x1xf32>, %arg1: vector<1x4 // CHECK-LABEL: func.func @cast_away_leading_one_dim_scalable( // CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} : vector<[4]x1xf32> -// CHECK: vector.shape_cast %[[MUL]] : vector<[4]x1xf32> to vector<1x[4]x1xf32> +// CHECK: vector.broadcast %[[MUL]] : vector<[4]x1xf32> to vector<1x[4]x1xf32> func.func @cast_away_leading_one_dim_scalable(%arg0: vector<1x[4]x1xf32>, %arg1: vector<1x[4]x1xf32>) -> vector<1x[4]x1xf32> { %1 = arith.mulf %arg0, %arg1 : vector<1x[4]x1xf32> return %1: vector<1x[4]x1xf32> @@ -277,15 +277,13 @@ func.func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>, func.func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) { %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16> // CHECK: %[[EXTRACT1:.+]] = vector.extract %[[SRC]][1] : f32 from vector<4xf32> - // CHECK: %[[INSERT1:.+]] = vector.insert %[[EXTRACT1]], %{{.+}} [] : f32 into vector - // CHECK: %[[SHAPE_CAST1:.+]] = vector.shape_cast %[[INSERT1]] : vector to vector<1xf32> - // CHECK: %[[CAST1:.+]] = vector.bitcast %[[SHAPE_CAST1]] : vector<1xf32> to vector<2xf16> + // CHECK: %[[INSERT1:.+]] = vector.insert %[[EXTRACT1]], %{{.+}} [0] : f32 into vector<1xf32> + // CHECK: %[[CAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<1xf32> to vector<2xf16> // CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CAST1]][1] : f16 from vector<2xf16> %1 = vector.extract %0[3] : f16 from vector<8xf16> // CHECK: %[[EXTRACT3:.+]] = vector.extract %[[SRC]][2] : f32 from vector<4xf32> - // CHECK: %[[INSERT3:.+]] = vector.insert %[[EXTRACT3]], %{{.+}} [] : f32 into vector - // CHECK: %[[SHAPE_CAST2:.+]] = vector.shape_cast %[[INSERT3]] : vector to vector<1xf32> - // CHECK: %[[CAST2:.+]] = vector.bitcast %[[SHAPE_CAST2]] : vector<1xf32> to vector<2xf16> + // CHECK: %[[INSERT3:.+]] = vector.insert %[[EXTRACT3]], %{{.+}} [0] : f32 into vector<1xf32> + // CHECK: %[[CAST2:.+]] = vector.bitcast %[[INSERT3]] : vector<1xf32> to vector<2xf16> // CHECK: %[[EXTRACT4:.+]] = vector.extract %[[CAST2]][0] : f16 from vector<2xf16> %2 = vector.extract %0[4] : f16 from vector<8xf16> // CHECK: return %[[EXTRACT2]], %[[EXTRACT4]]