diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index d16c7c3d6fdbe..c788d4ccb4a08 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -184,9 +184,9 @@ static ParseResult parseContractionOp(OpAsmParser &parser, auto lhsType = types[0].cast(); auto rhsType = types[1].cast(); auto maskElementType = parser.getBuilder().getI1Type(); - SmallVector maskTypes; - maskTypes.push_back(VectorType::get(lhsType.getShape(), maskElementType)); - maskTypes.push_back(VectorType::get(rhsType.getShape(), maskElementType)); + std::array maskTypes = { + VectorType::get(lhsType.getShape(), maskElementType), + VectorType::get(rhsType.getShape(), maskElementType)}; if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands)) return failure(); return success(); @@ -462,12 +462,10 @@ std::vector> ContractionOp::getBatchDimMap() { } SmallVector ContractionOp::getIndexingMaps() { - SmallVector res; - auto mapAttrs = indexing_maps().getValue(); - res.reserve(mapAttrs.size()); - for (auto mapAttr : mapAttrs) - res.push_back(mapAttr.cast().getValue()); - return res; + return llvm::to_vector<4>( + llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) { + return mapAttr.cast().getValue(); + })); } Optional> ContractionOp::getShapeForUnroll() { @@ -1854,8 +1852,7 @@ LogicalResult TransferWriteOp::fold(ArrayRef, } Optional> TransferWriteOp::getShapeForUnroll() { - auto s = getVectorType().getShape(); - return SmallVector{s.begin(), s.end()}; + return llvm::to_vector<4>(getVectorType().getShape()); } //===----------------------------------------------------------------------===// @@ -2014,11 +2011,8 @@ static SmallVector extractShape(MemRefType memRefType) { auto vectorType = memRefType.getElementType().dyn_cast(); SmallVector res(memRefType.getShape().begin(), memRefType.getShape().end()); - if (vectorType) { - res.reserve(memRefType.getRank() + vectorType.getRank()); - for (auto s : vectorType.getShape()) - res.push_back(s); - } + if (vectorType) + res.append(vectorType.getShape().begin(), vectorType.getShape().end()); return res; } diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp index ab93ef406024e..197b1c62274b2 100644 --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1707,7 +1707,7 @@ void ContractionOpToOuterProductOpLowering::rewrite( auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; AffineExpr m, n, k; bindDims(rewriter.getContext(), m, n, k); - SmallVector perm{1, 0}; + static constexpr std::array perm = {1, 0}; auto iteratorTypes = op.iterator_types().getValue(); SmallVector maps = op.getIndexingMaps(); if (isParallelIterator(iteratorTypes[0]) && @@ -1911,10 +1911,10 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op, assert(lookup.hasValue() && "parallel index not listed in reduction"); int64_t resIndex = lookup.getValue(); // Construct new iterator types and affine map array attribute. - SmallVector lowIndexingMaps; - lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter)); - lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter)); - lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter)); + std::array lowIndexingMaps = { + adjustMap(iMap[0], iterIndex, rewriter), + adjustMap(iMap[1], iterIndex, rewriter), + adjustMap(iMap[2], iterIndex, rewriter)}; auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); auto lowIter = rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); @@ -1962,10 +1962,10 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op, op.acc()); } // Construct new iterator types and affine map array attribute. - SmallVector lowIndexingMaps; - lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter)); - lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter)); - lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter)); + std::array lowIndexingMaps = { + adjustMap(iMap[0], iterIndex, rewriter), + adjustMap(iMap[1], iterIndex, rewriter), + adjustMap(iMap[2], iterIndex, rewriter)}; auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); auto lowIter = rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));