-
Notifications
You must be signed in to change notification settings - Fork 10.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Add first-class support for scalability in VectorType dims #74251
base: main
Are you sure you want to change the base?
Conversation
c8e828a
to
037c665
Compare
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir-core Author: Benjamin Maxwell (MacDue) ChangesCurrently, the shape of a VectorType is stored in two separate lists. The 'shape' which comes from ShapedType, which does not have a way to represent scalability, and the 'scalableDims', an additional list of bools attached to VectorType. This can be somewhat cumbersome to work with, and easy to ignore the scalability of a dim, producing incorrect results. For example, to correctly trim leading unit dims of a VectorType, currently, you need to do something like: while (!newShape.empty() && newShape.front() == 1 &&
!newScalableDims.front()) {
newShape = newShape.drop_front(1);
newScalableDims = newScalableDims.drop_front(1);
} Which would be wrong if you (more naturally) wrote it as: auto newShape = vectorType.getShape().drop_while([](int64_t dim) {
return dim == 1;
}); As this would trim scalable one dims ( This patch does not change the storage of the VectorType, but instead adds new scalability-safe accessors and iterators. Two new methods are added to VectorType: /// Returns the value of the specified dimension (including scalability)
VectorDim VectorType::getDim(unsigned idx);
/// Returns the dimensions of this vector type (including scalability)
VectorDims VectorType::getDims(); These are backed by two new classes:
There are also new builders to construct VectorTypes from both the With these changes the previous example becomes: auto newDims = vectorType.getDims().dropWhile([](VectorDim dim) {
return dim == VectorDim::getFixed(1);
}); Which (to me) is easier to read, and safer as it is not possible to forget check the scalability of the dim. Just comparing with Full diff: https://github.com/llvm/llvm-project/pull/74251.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 92ce053ad5c82..b468fd42f374e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -12,6 +12,7 @@
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Support/ADTExtras.h"
+#include "llvm/ADT/STLExtras.h"
namespace llvm {
class BitVector;
@@ -181,6 +182,197 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};
+//===----------------------------------------------------------------------===//
+// VectorDim
+//===----------------------------------------------------------------------===//
+
+/// This class represents a dimension of a vector type. Unlike other ShapedTypes
+/// vector dimensions can have scalable quantities, which means the dimension
+/// has a known minimum size, which is scaled by a constant that is only
+/// known at runtime.
+class VectorDim {
+public:
+ explicit constexpr VectorDim(int64_t quantity, bool scalable)
+ : quantity(quantity), scalable(scalable){};
+
+ /// Constructs a new fixed dimension.
+ constexpr static VectorDim getFixed(int64_t quantity) {
+ return VectorDim(quantity, false);
+ }
+
+ /// Constructs a new scalable dimension.
+ constexpr static VectorDim getScalable(int64_t quantity) {
+ return VectorDim(quantity, true);
+ }
+
+ /// Returns true if this dimension is scalable;
+ constexpr bool isScalable() const { return scalable; }
+
+ /// Returns true if this dimension is fixed.
+ constexpr bool isFixed() const { return !isScalable(); }
+
+ /// Returns the minimum number of elements this dimension can contain.
+ constexpr int64_t getMinSize() const { return quantity; }
+
+ /// If this dimension is fixed returns the number of elements, otherwise
+ /// aborts.
+ constexpr int64_t getFixedSize() const {
+ assert(isFixed());
+ return quantity;
+ }
+
+ constexpr bool operator==(VectorDim const &dim) const {
+ return quantity == dim.quantity && scalable == dim.scalable;
+ }
+
+ constexpr bool operator!=(VectorDim const &dim) const {
+ return !(*this == dim);
+ }
+
+ /// Helper class for indexing into a list of sizes (and possibly empty) list
+ /// of scalable dimensions, extracting VectorDim elements.
+ struct Indexer {
+ explicit Indexer(ArrayRef<int64_t> sizes, ArrayRef<bool> scalableDims)
+ : sizes(sizes), scalableDims(scalableDims) {
+ assert(
+ scalableDims.empty() ||
+ sizes.size() == scalableDims.size() &&
+ "expected `scalableDims` to be empty or match `sizes` in length");
+ }
+
+ VectorDim operator[](size_t idx) const {
+ int64_t size = sizes[idx];
+ bool scalable = scalableDims.empty() ? false : scalableDims[idx];
+ return VectorDim(size, scalable);
+ }
+
+ ArrayRef<int64_t> sizes;
+ ArrayRef<bool> scalableDims;
+ };
+
+private:
+ int64_t quantity;
+ bool scalable;
+};
+
+//===----------------------------------------------------------------------===//
+// VectorDims
+//===----------------------------------------------------------------------===//
+
+/// Represents a non-owning list of vector dimensions. The underlying dimension
+/// sizes and scalability flags are stored a two seperate lists to match the
+/// storage of a VectorType.
+class VectorDims : public VectorDim::Indexer {
+public:
+ using VectorDim::Indexer::Indexer;
+
+ class Iterator : public llvm::iterator_facade_base<
+ Iterator, std::random_access_iterator_tag, VectorDim,
+ std::ptrdiff_t, VectorDim, VectorDim> {
+ public:
+ Iterator(VectorDim::Indexer indexer, size_t index)
+ : indexer(indexer), index(index){};
+
+ // Iterator boilerplate.
+ ptrdiff_t operator-(const Iterator &rhs) const { return index - rhs.index; }
+ bool operator==(const Iterator &rhs) const { return index == rhs.index; }
+ bool operator<(const Iterator &rhs) const { return index < rhs.index; }
+ Iterator &operator+=(ptrdiff_t offset) {
+ index += offset;
+ return *this;
+ }
+ Iterator &operator-=(ptrdiff_t offset) {
+ index -= offset;
+ return *this;
+ }
+ VectorDim operator*() const { return indexer[index]; }
+
+ VectorDim::Indexer getIndexer() const { return indexer; }
+ ptrdiff_t getIndex() const { return index; }
+
+ private:
+ VectorDim::Indexer indexer;
+ ptrdiff_t index;
+ };
+
+ /// Construct from iterator pair.
+ VectorDims(Iterator begin, Iterator end)
+ : VectorDims(VectorDims(begin.getIndexer())
+ .slice(begin.getIndex(), end - begin)) {}
+
+ VectorDims(VectorDim::Indexer indexer) : VectorDim::Indexer(indexer){};
+
+ Iterator begin() const { return Iterator(*this, 0); }
+ Iterator end() const { return Iterator(*this, size()); }
+
+ /// Check if the dims are empty.
+ bool empty() const { return sizes.empty(); }
+
+ /// Get the number of dims.
+ size_t size() const { return sizes.size(); }
+
+ /// Return the first dim.
+ VectorDim front() const { return (*this)[0]; }
+
+ /// Return the last dim.
+ VectorDim back() const { return (*this)[size() - 1]; }
+
+ /// Chop of thie first \p n dims, and keep the remaining \p m
+ /// dims.
+ VectorDims slice(size_t n, size_t m) const {
+ ArrayRef<int64_t> newSizes = sizes.slice(n, m);
+ ArrayRef<bool> newScalableDims =
+ scalableDims.empty() ? ArrayRef<bool>{} : scalableDims.slice(n, m);
+ return VectorDims(newSizes, newScalableDims);
+ }
+
+ /// Drop the first \p n dims.
+ VectorDims dropFront(size_t n = 1) const { return slice(n, size() - n); }
+
+ /// Drop the last \p n dims.
+ VectorDims dropBack(size_t n = 1) const { return slice(0, size() - n); }
+
+ /// Return copy of *this with the first n dims matching the predicate removed.
+ template <class PredicateT>
+ VectorDims dropWhile(PredicateT predicate) const {
+ return VectorDims(llvm::find_if_not(*this, predicate), end());
+ }
+
+ /// Return the underlying sizes.
+ ArrayRef<int64_t> getSizes() const { return sizes; }
+
+ /// Return the underlying scalable dims.
+ ArrayRef<bool> getScalableDims() const { return scalableDims; }
+
+ /// Check for dim equality.
+ bool equals(VectorDims rhs) const {
+ if (size() != rhs.size())
+ return false;
+ return std::equal(begin(), end(), rhs.begin());
+ }
+
+ /// Check for dim equality.
+ bool equals(ArrayRef<VectorDim> rhs) const {
+ if (size() != rhs.size())
+ return false;
+ return std::equal(begin(), end(), rhs.begin());
+ }
+};
+
+inline bool operator==(VectorDims lhs, VectorDims rhs) {
+ return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, VectorDims rhs) { return !(lhs == rhs); }
+
+inline bool operator==(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+ return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+ return !(lhs == rhs);
+}
+
} // namespace mlir
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 1d7772810ae6e..d6cd14079fab8 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1089,6 +1089,18 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
scalableDims = isScalableVec;
}
return $_get(elementType.getContext(), shape, elementType, scalableDims);
+ }]>,
+ TypeBuilderWithInferredContext<(ins "Type":$elementType, "ArrayRef<VectorDim>": $shape), [{
+ SmallVector<int64_t> sizes;
+ SmallVector<bool> scalableDims;
+ for (VectorDim dim : shape) {
+ sizes.push_back(dim.getMinSize());
+ scalableDims.push_back(dim.isScalable());
+ }
+ return get(sizes, elementType, scalableDims);
+ }]>,
+ TypeBuilderWithInferredContext<(ins "Type":$elementType, "VectorDims": $shape), [{
+ return get(shape.getSizes(), elementType, shape.getScalableDims());
}]>
];
let extraClassDeclaration = [{
@@ -1096,6 +1108,17 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
/// Arguments that are passed into the builder must outlive the builder.
class Builder;
+ /// Returns the value of the specified dimension (including scalability).
+ VectorDim getDim(unsigned idx) const {
+ assert(idx < getRank() && "invalid dim index for vector type");
+ return getDims()[idx];
+ }
+
+ /// Returns the dimensions of this vector type (including scalability).
+ VectorDims getDims() const {
+ return VectorDims(getShape(), getScalableDims());
+ }
+
/// Returns true if the given type can be used as an element of a vector
/// type. In particular, vectors can consist of integer, index, or float
/// primitives.
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 61264bc523648..07625da6ee889 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -226,4 +226,105 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
}
}
+TEST(ShapedTypeTest, VectorDims) {
+ MLIRContext context;
+ Type f32 = FloatType::getF32(&context);
+
+ SmallVector<VectorDim> dims{VectorDim::getFixed(2), VectorDim::getScalable(4),
+ VectorDim::getFixed(8), VectorDim::getScalable(9),
+ VectorDim::getFixed(1)};
+ VectorType vectorType = VectorType::get(f32, dims);
+
+ // Directly check values
+ {
+ auto dim0 = vectorType.getDim(0);
+ ASSERT_EQ(dim0.getMinSize(), 2);
+ ASSERT_TRUE(dim0.isFixed());
+
+ auto dim1 = vectorType.getDim(1);
+ ASSERT_EQ(dim1.getMinSize(), 4);
+ ASSERT_TRUE(dim1.isScalable());
+
+ auto dim2 = vectorType.getDim(2);
+ ASSERT_EQ(dim2.getMinSize(), 8);
+ ASSERT_TRUE(dim2.isFixed());
+
+ auto dim3 = vectorType.getDim(3);
+ ASSERT_EQ(dim3.getMinSize(), 9);
+ ASSERT_TRUE(dim3.isScalable());
+
+ auto dim4 = vectorType.getDim(4);
+ ASSERT_EQ(dim4.getMinSize(), 1);
+ ASSERT_TRUE(dim4.isFixed());
+ }
+
+ // Test indexing via getDim(idx)
+ {
+ for (unsigned i = 0; i < dims.size(); i++)
+ ASSERT_EQ(vectorType.getDim(i), dims[i]);
+ }
+
+ // Test using VectorDims::Iterator in for-each loop
+ {
+ unsigned i = 0;
+ for (VectorDim dim : vectorType.getDims())
+ ASSERT_EQ(dim, dims[i++]);
+ ASSERT_EQ(i, vectorType.getRank());
+ }
+
+ // Test using VectorDims::Iterator in LLVM iterator helper
+ {
+ for (auto [dim, expectedDim] :
+ llvm::zip_equal(vectorType.getDims(), dims)) {
+ ASSERT_EQ(dim, expectedDim);
+ }
+ }
+
+ // Test dropFront()
+ {
+ auto vectorDims = vectorType.getDims();
+ auto newDims = vectorDims.dropFront();
+
+ ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+ for (unsigned i = 0; i < newDims.size(); i++)
+ ASSERT_EQ(newDims[i], vectorDims[i + 1]);
+ }
+
+ // Test dropBack()
+ {
+ auto vectorDims = vectorType.getDims();
+ auto newDims = vectorDims.dropBack();
+
+ ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+ for (unsigned i = 0; i < newDims.size(); i++)
+ ASSERT_EQ(newDims[i], vectorDims[i]);
+ }
+
+ // Test front()
+ { ASSERT_EQ(vectorType.getDims().front(), VectorDim::getFixed(2)); }
+
+ // Test back()
+ { ASSERT_EQ(vectorType.getDims().back(), VectorDim::getFixed(1)); }
+
+ // Test dropWhile.
+ {
+ SmallVector<VectorDim> dims{
+ VectorDim::getFixed(1), VectorDim::getFixed(1), VectorDim::getFixed(1),
+ VectorDim::getScalable(1), VectorDim::getScalable(4)};
+
+ VectorType vectorTypeWithLeadingUnitDims = VectorType::get(f32, dims);
+ ASSERT_EQ(vectorTypeWithLeadingUnitDims.getDims().size(),
+ unsigned(vectorTypeWithLeadingUnitDims.getRank()));
+
+ // Drop leading unit dims.
+ auto withoutLeadingUnitDims =
+ vectorTypeWithLeadingUnitDims.getDims().dropWhile(
+ [](VectorDim dim) { return dim == VectorDim::getFixed(1); });
+
+ SmallVector<VectorDim> expectedDims{VectorDim::getScalable(1),
+ VectorDim::getScalable(4)};
+ ASSERT_EQ(withoutLeadingUnitDims, expectedDims);
+ }
+}
+
} // namespace
|
@llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue) ChangesCurrently, the shape of a VectorType is stored in two separate lists. The 'shape' which comes from ShapedType, which does not have a way to represent scalability, and the 'scalableDims', an additional list of bools attached to VectorType. This can be somewhat cumbersome to work with, and easy to ignore the scalability of a dim, producing incorrect results. For example, to correctly trim leading unit dims of a VectorType, currently, you need to do something like: while (!newShape.empty() && newShape.front() == 1 &&
!newScalableDims.front()) {
newShape = newShape.drop_front(1);
newScalableDims = newScalableDims.drop_front(1);
} Which would be wrong if you (more naturally) wrote it as: auto newShape = vectorType.getShape().drop_while([](int64_t dim) {
return dim == 1;
}); As this would trim scalable one dims ( This patch does not change the storage of the VectorType, but instead adds new scalability-safe accessors and iterators. Two new methods are added to VectorType: /// Returns the value of the specified dimension (including scalability)
VectorDim VectorType::getDim(unsigned idx);
/// Returns the dimensions of this vector type (including scalability)
VectorDims VectorType::getDims(); These are backed by two new classes:
There are also new builders to construct VectorTypes from both the With these changes the previous example becomes: auto newDims = vectorType.getDims().dropWhile([](VectorDim dim) {
return dim == VectorDim::getFixed(1);
}); Which (to me) is easier to read, and safer as it is not possible to forget check the scalability of the dim. Just comparing with Full diff: https://github.com/llvm/llvm-project/pull/74251.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 92ce053ad5c82..b468fd42f374e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -12,6 +12,7 @@
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Support/ADTExtras.h"
+#include "llvm/ADT/STLExtras.h"
namespace llvm {
class BitVector;
@@ -181,6 +182,197 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
};
+//===----------------------------------------------------------------------===//
+// VectorDim
+//===----------------------------------------------------------------------===//
+
+/// This class represents a dimension of a vector type. Unlike other ShapedTypes
+/// vector dimensions can have scalable quantities, which means the dimension
+/// has a known minimum size, which is scaled by a constant that is only
+/// known at runtime.
+class VectorDim {
+public:
+ explicit constexpr VectorDim(int64_t quantity, bool scalable)
+ : quantity(quantity), scalable(scalable){};
+
+ /// Constructs a new fixed dimension.
+ constexpr static VectorDim getFixed(int64_t quantity) {
+ return VectorDim(quantity, false);
+ }
+
+ /// Constructs a new scalable dimension.
+ constexpr static VectorDim getScalable(int64_t quantity) {
+ return VectorDim(quantity, true);
+ }
+
+ /// Returns true if this dimension is scalable;
+ constexpr bool isScalable() const { return scalable; }
+
+ /// Returns true if this dimension is fixed.
+ constexpr bool isFixed() const { return !isScalable(); }
+
+ /// Returns the minimum number of elements this dimension can contain.
+ constexpr int64_t getMinSize() const { return quantity; }
+
+ /// If this dimension is fixed returns the number of elements, otherwise
+ /// aborts.
+ constexpr int64_t getFixedSize() const {
+ assert(isFixed());
+ return quantity;
+ }
+
+ constexpr bool operator==(VectorDim const &dim) const {
+ return quantity == dim.quantity && scalable == dim.scalable;
+ }
+
+ constexpr bool operator!=(VectorDim const &dim) const {
+ return !(*this == dim);
+ }
+
+ /// Helper class for indexing into a list of sizes (and possibly empty) list
+ /// of scalable dimensions, extracting VectorDim elements.
+ struct Indexer {
+ explicit Indexer(ArrayRef<int64_t> sizes, ArrayRef<bool> scalableDims)
+ : sizes(sizes), scalableDims(scalableDims) {
+ assert(
+ scalableDims.empty() ||
+ sizes.size() == scalableDims.size() &&
+ "expected `scalableDims` to be empty or match `sizes` in length");
+ }
+
+ VectorDim operator[](size_t idx) const {
+ int64_t size = sizes[idx];
+ bool scalable = scalableDims.empty() ? false : scalableDims[idx];
+ return VectorDim(size, scalable);
+ }
+
+ ArrayRef<int64_t> sizes;
+ ArrayRef<bool> scalableDims;
+ };
+
+private:
+ int64_t quantity;
+ bool scalable;
+};
+
+//===----------------------------------------------------------------------===//
+// VectorDims
+//===----------------------------------------------------------------------===//
+
+/// Represents a non-owning list of vector dimensions. The underlying dimension
+/// sizes and scalability flags are stored a two seperate lists to match the
+/// storage of a VectorType.
+class VectorDims : public VectorDim::Indexer {
+public:
+ using VectorDim::Indexer::Indexer;
+
+ class Iterator : public llvm::iterator_facade_base<
+ Iterator, std::random_access_iterator_tag, VectorDim,
+ std::ptrdiff_t, VectorDim, VectorDim> {
+ public:
+ Iterator(VectorDim::Indexer indexer, size_t index)
+ : indexer(indexer), index(index){};
+
+ // Iterator boilerplate.
+ ptrdiff_t operator-(const Iterator &rhs) const { return index - rhs.index; }
+ bool operator==(const Iterator &rhs) const { return index == rhs.index; }
+ bool operator<(const Iterator &rhs) const { return index < rhs.index; }
+ Iterator &operator+=(ptrdiff_t offset) {
+ index += offset;
+ return *this;
+ }
+ Iterator &operator-=(ptrdiff_t offset) {
+ index -= offset;
+ return *this;
+ }
+ VectorDim operator*() const { return indexer[index]; }
+
+ VectorDim::Indexer getIndexer() const { return indexer; }
+ ptrdiff_t getIndex() const { return index; }
+
+ private:
+ VectorDim::Indexer indexer;
+ ptrdiff_t index;
+ };
+
+ /// Construct from iterator pair.
+ VectorDims(Iterator begin, Iterator end)
+ : VectorDims(VectorDims(begin.getIndexer())
+ .slice(begin.getIndex(), end - begin)) {}
+
+ VectorDims(VectorDim::Indexer indexer) : VectorDim::Indexer(indexer){};
+
+ Iterator begin() const { return Iterator(*this, 0); }
+ Iterator end() const { return Iterator(*this, size()); }
+
+ /// Check if the dims are empty.
+ bool empty() const { return sizes.empty(); }
+
+ /// Get the number of dims.
+ size_t size() const { return sizes.size(); }
+
+ /// Return the first dim.
+ VectorDim front() const { return (*this)[0]; }
+
+ /// Return the last dim.
+ VectorDim back() const { return (*this)[size() - 1]; }
+
+ /// Chop of thie first \p n dims, and keep the remaining \p m
+ /// dims.
+ VectorDims slice(size_t n, size_t m) const {
+ ArrayRef<int64_t> newSizes = sizes.slice(n, m);
+ ArrayRef<bool> newScalableDims =
+ scalableDims.empty() ? ArrayRef<bool>{} : scalableDims.slice(n, m);
+ return VectorDims(newSizes, newScalableDims);
+ }
+
+ /// Drop the first \p n dims.
+ VectorDims dropFront(size_t n = 1) const { return slice(n, size() - n); }
+
+ /// Drop the last \p n dims.
+ VectorDims dropBack(size_t n = 1) const { return slice(0, size() - n); }
+
+ /// Return copy of *this with the first n dims matching the predicate removed.
+ template <class PredicateT>
+ VectorDims dropWhile(PredicateT predicate) const {
+ return VectorDims(llvm::find_if_not(*this, predicate), end());
+ }
+
+ /// Return the underlying sizes.
+ ArrayRef<int64_t> getSizes() const { return sizes; }
+
+ /// Return the underlying scalable dims.
+ ArrayRef<bool> getScalableDims() const { return scalableDims; }
+
+ /// Check for dim equality.
+ bool equals(VectorDims rhs) const {
+ if (size() != rhs.size())
+ return false;
+ return std::equal(begin(), end(), rhs.begin());
+ }
+
+ /// Check for dim equality.
+ bool equals(ArrayRef<VectorDim> rhs) const {
+ if (size() != rhs.size())
+ return false;
+ return std::equal(begin(), end(), rhs.begin());
+ }
+};
+
+inline bool operator==(VectorDims lhs, VectorDims rhs) {
+ return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, VectorDims rhs) { return !(lhs == rhs); }
+
+inline bool operator==(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+ return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+ return !(lhs == rhs);
+}
+
} // namespace mlir
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 1d7772810ae6e..d6cd14079fab8 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1089,6 +1089,18 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
scalableDims = isScalableVec;
}
return $_get(elementType.getContext(), shape, elementType, scalableDims);
+ }]>,
+ TypeBuilderWithInferredContext<(ins "Type":$elementType, "ArrayRef<VectorDim>": $shape), [{
+ SmallVector<int64_t> sizes;
+ SmallVector<bool> scalableDims;
+ for (VectorDim dim : shape) {
+ sizes.push_back(dim.getMinSize());
+ scalableDims.push_back(dim.isScalable());
+ }
+ return get(sizes, elementType, scalableDims);
+ }]>,
+ TypeBuilderWithInferredContext<(ins "Type":$elementType, "VectorDims": $shape), [{
+ return get(shape.getSizes(), elementType, shape.getScalableDims());
}]>
];
let extraClassDeclaration = [{
@@ -1096,6 +1108,17 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
/// Arguments that are passed into the builder must outlive the builder.
class Builder;
+ /// Returns the value of the specified dimension (including scalability).
+ VectorDim getDim(unsigned idx) const {
+ assert(idx < getRank() && "invalid dim index for vector type");
+ return getDims()[idx];
+ }
+
+ /// Returns the dimensions of this vector type (including scalability).
+ VectorDims getDims() const {
+ return VectorDims(getShape(), getScalableDims());
+ }
+
/// Returns true if the given type can be used as an element of a vector
/// type. In particular, vectors can consist of integer, index, or float
/// primitives.
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 61264bc523648..07625da6ee889 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -226,4 +226,105 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
}
}
+TEST(ShapedTypeTest, VectorDims) {
+ MLIRContext context;
+ Type f32 = FloatType::getF32(&context);
+
+ SmallVector<VectorDim> dims{VectorDim::getFixed(2), VectorDim::getScalable(4),
+ VectorDim::getFixed(8), VectorDim::getScalable(9),
+ VectorDim::getFixed(1)};
+ VectorType vectorType = VectorType::get(f32, dims);
+
+ // Directly check values
+ {
+ auto dim0 = vectorType.getDim(0);
+ ASSERT_EQ(dim0.getMinSize(), 2);
+ ASSERT_TRUE(dim0.isFixed());
+
+ auto dim1 = vectorType.getDim(1);
+ ASSERT_EQ(dim1.getMinSize(), 4);
+ ASSERT_TRUE(dim1.isScalable());
+
+ auto dim2 = vectorType.getDim(2);
+ ASSERT_EQ(dim2.getMinSize(), 8);
+ ASSERT_TRUE(dim2.isFixed());
+
+ auto dim3 = vectorType.getDim(3);
+ ASSERT_EQ(dim3.getMinSize(), 9);
+ ASSERT_TRUE(dim3.isScalable());
+
+ auto dim4 = vectorType.getDim(4);
+ ASSERT_EQ(dim4.getMinSize(), 1);
+ ASSERT_TRUE(dim4.isFixed());
+ }
+
+ // Test indexing via getDim(idx)
+ {
+ for (unsigned i = 0; i < dims.size(); i++)
+ ASSERT_EQ(vectorType.getDim(i), dims[i]);
+ }
+
+ // Test using VectorDims::Iterator in for-each loop
+ {
+ unsigned i = 0;
+ for (VectorDim dim : vectorType.getDims())
+ ASSERT_EQ(dim, dims[i++]);
+ ASSERT_EQ(i, vectorType.getRank());
+ }
+
+ // Test using VectorDims::Iterator in LLVM iterator helper
+ {
+ for (auto [dim, expectedDim] :
+ llvm::zip_equal(vectorType.getDims(), dims)) {
+ ASSERT_EQ(dim, expectedDim);
+ }
+ }
+
+ // Test dropFront()
+ {
+ auto vectorDims = vectorType.getDims();
+ auto newDims = vectorDims.dropFront();
+
+ ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+ for (unsigned i = 0; i < newDims.size(); i++)
+ ASSERT_EQ(newDims[i], vectorDims[i + 1]);
+ }
+
+ // Test dropBack()
+ {
+ auto vectorDims = vectorType.getDims();
+ auto newDims = vectorDims.dropBack();
+
+ ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+ for (unsigned i = 0; i < newDims.size(); i++)
+ ASSERT_EQ(newDims[i], vectorDims[i]);
+ }
+
+ // Test front()
+ { ASSERT_EQ(vectorType.getDims().front(), VectorDim::getFixed(2)); }
+
+ // Test back()
+ { ASSERT_EQ(vectorType.getDims().back(), VectorDim::getFixed(1)); }
+
+ // Test dropWhile.
+ {
+ SmallVector<VectorDim> dims{
+ VectorDim::getFixed(1), VectorDim::getFixed(1), VectorDim::getFixed(1),
+ VectorDim::getScalable(1), VectorDim::getScalable(4)};
+
+ VectorType vectorTypeWithLeadingUnitDims = VectorType::get(f32, dims);
+ ASSERT_EQ(vectorTypeWithLeadingUnitDims.getDims().size(),
+ unsigned(vectorTypeWithLeadingUnitDims.getRank()));
+
+ // Drop leading unit dims.
+ auto withoutLeadingUnitDims =
+ vectorTypeWithLeadingUnitDims.getDims().dropWhile(
+ [](VectorDim dim) { return dim == VectorDim::getFixed(1); });
+
+ SmallVector<VectorDim> expectedDims{VectorDim::getScalable(1),
+ VectorDim::getScalable(4)};
+ ASSERT_EQ(withoutLeadingUnitDims, expectedDims);
+ }
+}
+
} // namespace
|
Currently, the shape of a VectorType is stored in two separate lists. The 'shape' which comes from ShapedType, which does not have a way to represent scalability, and the 'scalableDims', an additional list of bools attached to VectorType. This can be somewhat cumbersome to work with, and easy to ignore the scalability of a dim, producing incorrect results. For example, to correctly trim leading unit dims of a VectorType, currently, you need to do something like: ```c++ while (!newShape.empty() && newShape.front() == 1 && !newScalableDims.front()) { newShape = newShape.drop_front(1); newScalableDims = newScalableDims.drop_front(1); } ``` Which would be wrong if you (more naturally) wrote it as: ```c++ auto newShape = vectorType.getShape().drop_while([](int64_t dim) { return dim == 1; }); ``` As this would trim scalable one dims (`[1]`), which are not unit dims like their fixed counterpart. This patch does not change the storage of the VectorType, but instead adds new scalability-safe accessors and iterators. Two new methods are added to VectorType: ``` /// Returns the value of the specified dimension (including scalability) VectorDim VectorType::getDim(unsigned idx) /// Returns the dimensions of this vector type (including scalability) VectorDims VectorType::getDims() ``` These are backed by two new classes: `VectorDim` and `VectorDims`. `VectorDim` represents a single dimension of a VectorType. It can be a fixed or scalable quantity. It cannot be implicitly converted to/from an integer, so you must specify the kind of quantity you expect in comparisons. `VectorDims` represents a non-owning list of vector dimensions, backed by separate size and scalability lists (matching the storage of VectorType). This class has an iterator, and a few common helper methods (similar to that of ArrayRef). There are also new builders to construct VectorTypes from both the `VectorDims` class and an `ArrayRef<VectorDim>`. With these changes the previous example becomes: ```c++ auto newDims = vectorType.getDims().dropWhile([](VectorDim dim) { return dim == VectorDim::getFixed(1); }); ``` Which (to me) is easier to read, and safer as it is not possible to forget check the scalability of the dim. Just comparing with `1`, would fail to build.
037c665
to
e95d21f
Compare
This is not a complete change, this just updates a few examples found by grepping for getScalableDims().
e95d21f
to
f69fad1
Compare
f69fad1 contains a few (randomly) updated examples. But I think it shows that with APIs like this, you can write safer and more idiomatic code where scalability is a bit less of a burden :) |
As a little summary this is what I'd like to solve:
I don't think this is a complete solution, but I think it's a fairly non-invasive step towards something nicer :) |
Currently, the shape of a VectorType is stored in two separate lists. The 'shape' which comes from ShapedType, which does not have a way to represent scalability, and the 'scalableDims', an additional list of bools attached to VectorType. This can be somewhat cumbersome to work with, and easy to ignore the scalability of a dim, producing incorrect results.
For example, to correctly trim leading unit dims of a VectorType, currently, you need to do something like:
Which would be wrong if you (more naturally) wrote it as:
As this would trim scalable one dims (
[1]
), which are not unit dims like their fixed counterparts.This patch does not change the storage of the VectorType, but instead adds new scalability-safe accessors and iterators.
Two new methods are added to VectorType:
These are backed by two new classes:
VectorDim
andVectorDims
.VectorDim
represents a single dimension of a VectorType. It can be a fixed or scalable quantity. It cannot be implicitly converted to/from an integer, so you must specify the kind of quantity you expect in comparisons.VectorDims
represents a non-owning list of vector dimensions, backed by separate size and scalability lists (matching the storage of VectorType). This class has an iterator, and a few common helper methods (similar to that of ArrayRef).There are also new builders to construct VectorTypes from both the
VectorDims
class and anArrayRef<VectorDim>
.With these changes the previous example becomes:
Which (to me) is easier to read, and safer as it is not possible to forget check the scalability of the dim. Just comparing with
1
, would fail to build.