Skip to content

Commit

Permalink
[ADT] Allow specifying the size of resulting SmallVector in `map_to…
Browse files Browse the repository at this point in the history
…_vector`

This patch adds an overload for the `map_to_vector` helper template, exposing a parameter to control the size of the resulting `SmallVector`. A few call sites in mlir are updated to illustrate and test the change.

Differential Revision: https://reviews.llvm.org/D150601
  • Loading branch information
Laszlo Kindrat committed May 25, 2023
1 parent aae8524 commit f3ece29
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 35 deletions.
5 changes: 5 additions & 0 deletions llvm/include/llvm/ADT/SmallVectorExtras.h
Expand Up @@ -20,6 +20,11 @@
namespace llvm {

/// Map a range to a SmallVector with element types deduced from the mapping.
template <unsigned Size, class ContainerTy, class FuncTy>
auto map_to_vector(ContainerTy &&C, FuncTy &&F) {
return to_vector<Size>(
map_range(std::forward<ContainerTy>(C), std::forward<FuncTy>(F)));
}
template <class ContainerTy, class FuncTy>
auto map_to_vector(ContainerTy &&C, FuncTy &&F) {
return to_vector(
Expand Down
17 changes: 9 additions & 8 deletions mlir/include/mlir/IR/AffineMap.h
Expand Up @@ -19,6 +19,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include <optional>

namespace llvm {
Expand Down Expand Up @@ -226,24 +227,24 @@ class AffineMap {
AffineMap shiftDims(unsigned shift, unsigned offset = 0) const {
assert(offset <= getNumDims());
return AffineMap::get(getNumDims() + shift, getNumSymbols(),
llvm::to_vector<4>(llvm::map_range(
llvm::map_to_vector<4>(
getResults(),
[&](AffineExpr e) {
return e.shiftDims(getNumDims(), shift, offset);
})),
}),
getContext());
}

/// Replace symbols[offset ... numSymbols)
/// by symbols[offset + shift ... shift + numSymbols).
AffineMap shiftSymbols(unsigned shift, unsigned offset = 0) const {
return AffineMap::get(getNumDims(), getNumSymbols() + shift,
llvm::to_vector<4>(llvm::map_range(
getResults(),
[&](AffineExpr e) {
return e.shiftSymbols(getNumSymbols(), shift,
offset);
})),
llvm::map_to_vector<4>(getResults(),
[&](AffineExpr e) {
return e.shiftSymbols(
getNumSymbols(), shift,
offset);
}),
getContext());
}

Expand Down
40 changes: 20 additions & 20 deletions mlir/lib/IR/Builders.cpp
Expand Up @@ -15,6 +15,7 @@
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
Expand Down Expand Up @@ -261,57 +262,56 @@ ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
}

ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
auto attrs = llvm::to_vector<8>(llvm::map_range(
values, [this](bool v) -> Attribute { return getBoolAttr(v); }));
auto attrs = llvm::map_to_vector<8>(
values, [this](bool v) -> Attribute { return getBoolAttr(v); });
return getArrayAttr(attrs);
}

ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
auto attrs = llvm::to_vector<8>(llvm::map_range(
values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); }));
auto attrs = llvm::map_to_vector<8>(
values, [this](int32_t v) -> Attribute { return getI32IntegerAttr(v); });
return getArrayAttr(attrs);
}
ArrayAttr Builder::getI64ArrayAttr(ArrayRef<int64_t> values) {
auto attrs = llvm::to_vector<8>(llvm::map_range(
values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); }));
auto attrs = llvm::map_to_vector<8>(
values, [this](int64_t v) -> Attribute { return getI64IntegerAttr(v); });
return getArrayAttr(attrs);
}

ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
auto attrs = llvm::to_vector<8>(
llvm::map_range(values, [this](int64_t v) -> Attribute {
return getIntegerAttr(IndexType::get(getContext()), v);
}));
auto attrs = llvm::map_to_vector<8>(values, [this](int64_t v) -> Attribute {
return getIntegerAttr(IndexType::get(getContext()), v);
});
return getArrayAttr(attrs);
}

ArrayAttr Builder::getF32ArrayAttr(ArrayRef<float> values) {
auto attrs = llvm::to_vector<8>(llvm::map_range(
values, [this](float v) -> Attribute { return getF32FloatAttr(v); }));
auto attrs = llvm::map_to_vector<8>(
values, [this](float v) -> Attribute { return getF32FloatAttr(v); });
return getArrayAttr(attrs);
}

ArrayAttr Builder::getF64ArrayAttr(ArrayRef<double> values) {
auto attrs = llvm::to_vector<8>(llvm::map_range(
values, [this](double v) -> Attribute { return getF64FloatAttr(v); }));
auto attrs = llvm::map_to_vector<8>(
values, [this](double v) -> Attribute { return getF64FloatAttr(v); });
return getArrayAttr(attrs);
}

ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
auto attrs = llvm::to_vector<8>(llvm::map_range(
values, [this](StringRef v) -> Attribute { return getStringAttr(v); }));
auto attrs = llvm::map_to_vector<8>(
values, [this](StringRef v) -> Attribute { return getStringAttr(v); });
return getArrayAttr(attrs);
}

ArrayAttr Builder::getTypeArrayAttr(TypeRange values) {
auto attrs = llvm::to_vector<8>(llvm::map_range(
values, [](Type v) -> Attribute { return TypeAttr::get(v); }));
auto attrs = llvm::map_to_vector<8>(
values, [](Type v) -> Attribute { return TypeAttr::get(v); });
return getArrayAttr(attrs);
}

ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
auto attrs = llvm::to_vector<8>(llvm::map_range(
values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); }));
auto attrs = llvm::map_to_vector<8>(
values, [](AffineMap v) -> Attribute { return AffineMapAttr::get(v); });
return getArrayAttr(attrs);
}

Expand Down
13 changes: 6 additions & 7 deletions mlir/lib/IR/TypeUtilities.cpp
Expand Up @@ -11,13 +11,12 @@
//===----------------------------------------------------------------------===//

#include "mlir/IR/TypeUtilities.h"

#include <numeric>

#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include <numeric>

using namespace mlir;

Expand Down Expand Up @@ -119,8 +118,8 @@ LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
/// have compatible dimensions. Dimensions are compatible if all non-dynamic
/// dims are equal. The element type does not matter.
LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
auto shapedTypes = llvm::to_vector<8>(llvm::map_range(
types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); }));
auto shapedTypes = llvm::map_to_vector<8>(
types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); });
// Return failure if some, but not all are not shaped. Return early if none
// are shaped also.
if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
Expand Down Expand Up @@ -155,10 +154,10 @@ LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {

for (unsigned i = 0; i < firstRank; ++i) {
// Retrieve all ranked dimensions
auto dims = llvm::to_vector<8>(llvm::map_range(
auto dims = llvm::map_to_vector<8>(
llvm::make_filter_range(
shapes, [&](auto shape) { return shape.getRank() >= i; }),
[&](auto shape) { return shape.getDimSize(i); }));
[&](auto shape) { return shape.getDimSize(i); });
if (verifyCompatibleDims(dims).failed())
return failure();
}
Expand Down

0 comments on commit f3ece29

Please sign in to comment.