[mlir] Add first-class support for scalability in VectorType dims #74251

Open
wants to merge 2 commits into
base: main

MacDue
Member

@MacDue MacDue commented Dec 3, 2023

Currently, the shape of a VectorType is stored in two separate lists: the 'shape', which comes from ShapedType and has no way to represent scalability, and the 'scalableDims', an additional list of bools attached to VectorType. This is somewhat cumbersome to work with, and makes it easy to ignore the scalability of a dim, producing incorrect results.

For example, to correctly trim the leading unit dims of a VectorType, you currently need to do something like:

while (!newShape.empty() && newShape.front() == 1 &&
        !newScalableDims.front()) {
  newShape = newShape.drop_front(1);
  newScalableDims = newScalableDims.drop_front(1);
}

It would be wrong to write this (more naturally) as:

auto newShape = vectorType.getShape().drop_while([](int64_t dim) {
  return dim == 1;
});

This would also trim scalable dims of size one ([1]), which, unlike their fixed counterparts, are not unit dims.
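The difference between the two loops can be reproduced outside MLIR. The following is a minimal stand-alone sketch, assuming `std::vector` in place of `ArrayRef`; the helper names `trimNaive` and `trimCorrect` are hypothetical and not part of this patch. Each helper returns how many leading dims it would drop.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Parallel lists, mirroring how VectorType stores its shape.
using Shape = std::vector<int64_t>;
using ScalableDims = std::vector<bool>;

// Naive version (hypothetical helper): looks only at the sizes, so a
// scalable [1] is counted as a unit dim and dropped too.
std::size_t trimNaive(const Shape &shape) {
  std::size_t dropped = 0;
  while (dropped < shape.size() && shape[dropped] == 1)
    ++dropped;
  return dropped;
}

// Correct version (hypothetical helper): a dim is a unit dim only if its
// size is 1 and it is not scalable.
std::size_t trimCorrect(const Shape &shape, const ScalableDims &scalable) {
  std::size_t dropped = 0;
  while (dropped < shape.size() && shape[dropped] == 1 && !scalable[dropped])
    ++dropped;
  return dropped;
}
```

For a shape like 1x[1]x4, the naive helper drops two dims while the correct one drops only the leading fixed 1.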


This patch does not change the storage of the VectorType, but instead adds new scalability-safe accessors and iterators.

Two new methods are added to VectorType:

/// Returns the value of the specified dimension (including scalability)
VectorDim VectorType::getDim(unsigned idx);

/// Returns the dimensions of this vector type (including scalability)
VectorDims VectorType::getDims();

These are backed by two new classes: VectorDim and VectorDims.

VectorDim represents a single dimension of a VectorType. It can be a fixed or scalable quantity. It cannot be implicitly converted to/from an integer, so you must specify the kind of quantity you expect in comparisons.
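As a rough illustration of these properties, here is a stand-alone sketch of a class shaped like VectorDim. The name `Dim` is a hypothetical stand-in; the real class is in the diff and may differ in detail. The static assertions check that there is no implicit conversion to or from an integer, and that a fixed 1 and a scalable [1] compare as different.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for the patch's VectorDim.
class Dim {
public:
  constexpr explicit Dim(int64_t quantity, bool scalable)
      : quantity(quantity), scalable(scalable) {}

  static constexpr Dim getFixed(int64_t q) { return Dim(q, false); }
  static constexpr Dim getScalable(int64_t q) { return Dim(q, true); }

  constexpr bool isScalable() const { return scalable; }
  constexpr int64_t getMinSize() const { return quantity; }

  // Two dims are equal only if both size and scalability match.
  constexpr bool operator==(const Dim &rhs) const {
    return quantity == rhs.quantity && scalable == rhs.scalable;
  }
  constexpr bool operator!=(const Dim &rhs) const { return !(*this == rhs); }

private:
  int64_t quantity;
  bool scalable;
};

// No implicit conversion in either direction.
static_assert(!std::is_convertible<int64_t, Dim>::value, "int -> Dim");
static_assert(!std::is_convertible<Dim, int64_t>::value, "Dim -> int");

// A fixed 1 and a scalable [1] are different dims.
static_assert(Dim::getFixed(1) != Dim::getScalable(1), "kinds differ");
static_assert(Dim::getFixed(1) == Dim::getFixed(1), "same kind and size");
```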

VectorDims represents a non-owning list of vector dimensions, backed by separate size and scalability lists (matching the storage of VectorType). This class has an iterator, and a few common helper methods (similar to those of ArrayRef).
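To make the "non-owning view over two lists" idea concrete, here is a small stand-alone sketch. `DimsView` is a hypothetical stand-in that uses raw pointers instead of `ArrayRef` and returns a `{min size, scalable}` pair instead of a VectorDim; it only illustrates that slicing must move both lists together, and that an empty scalability list is treated as "all dims fixed" (as the Indexer in the diff does).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>

// Hypothetical stand-in for VectorDims: pointers into two parallel arrays.
// A null `scalable` pointer models an empty scalability list (all fixed).
struct DimsView {
  const int64_t *sizes;
  const bool *scalable; // may be nullptr
  std::size_t count;

  std::size_t size() const { return count; }
  bool empty() const { return count == 0; }

  // Returns {min size, is scalable} for dimension idx.
  std::pair<int64_t, bool> operator[](std::size_t idx) const {
    assert(idx < count && "dim index out of range");
    return {sizes[idx], scalable ? scalable[idx] : false};
  }

  // Skip the first n dims and keep the next m; both lists move together.
  DimsView slice(std::size_t n, std::size_t m) const {
    assert(n <= count && m <= count - n && "invalid slice");
    return {sizes + n, scalable ? scalable + n : nullptr, m};
  }

  DimsView dropFront(std::size_t n = 1) const { return slice(n, count - n); }
  DimsView dropBack(std::size_t n = 1) const { return slice(0, count - n); }
};
```

Because the view never copies, the arrays it points into must outlive it, which is the same constraint `ArrayRef` has.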

There are also new builders to construct VectorTypes from both the VectorDims class and an ArrayRef<VectorDim>.
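Since the storage is unchanged, a builder taking a list of dims has to split it back into the two parallel lists. The sketch below shows that step in isolation; `Dim` and `splitDims` are hypothetical names using `std::vector`, whereas the builder in the diff does the equivalent with `SmallVector` before calling the existing `get`.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for VectorDim: a minimum size plus a scalable flag.
struct Dim {
  int64_t minSize;
  bool scalable;
};

// Splits a list of dims into the two parallel lists that VectorType stores.
std::pair<std::vector<int64_t>, std::vector<bool>>
splitDims(const std::vector<Dim> &dims) {
  std::vector<int64_t> sizes;
  std::vector<bool> scalableDims;
  sizes.reserve(dims.size());
  scalableDims.reserve(dims.size());
  for (const Dim &dim : dims) {
    sizes.push_back(dim.minSize);
    scalableDims.push_back(dim.scalable);
  }
  return {std::move(sizes), std::move(scalableDims)};
}
```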

With these changes the previous example becomes:

auto newDims = vectorType.getDims().dropWhile([](VectorDim dim) {
  return dim == VectorDim::getFixed(1);
});

This is (to me) easier to read, and safer, as it is not possible to forget to check the scalability of the dim: just comparing with 1 would fail to build.
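The "fails to build" claim can be checked in a stand-alone sketch. Here `Dim`, `eqCompiles`, and the free function `dropWhile` are hypothetical stand-ins over `std::vector`, not the API added by this patch; the static assertions use an expression-SFINAE check to confirm that `dim == dim` is well-formed while `dim == 1` is not.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical stand-in for VectorDim with no implicit integer conversions.
class Dim {
public:
  static Dim getFixed(int64_t q) { return Dim(q, false); }
  static Dim getScalable(int64_t q) { return Dim(q, true); }
  bool operator==(const Dim &rhs) const {
    return quantity == rhs.quantity && scalable == rhs.scalable;
  }

private:
  Dim(int64_t quantity, bool scalable)
      : quantity(quantity), scalable(scalable) {}
  int64_t quantity;
  bool scalable;
};

// Yields std::true_type if `T == U` compiles, std::false_type otherwise.
template <class T, class U>
auto eqCompiles(int)
    -> decltype(std::declval<T>() == std::declval<U>(), std::true_type{});
template <class, class>
std::false_type eqCompiles(...);

static_assert(decltype(eqCompiles<Dim, Dim>(0))::value, "dim == dim builds");
static_assert(!decltype(eqCompiles<Dim, int>(0))::value,
              "dim == 1 does not build");

// Returns a copy of dims without the leading dims matching the predicate.
template <class Pred>
std::vector<Dim> dropWhile(const std::vector<Dim> &dims, Pred pred) {
  return std::vector<Dim>(std::find_if_not(dims.begin(), dims.end(), pred),
                          dims.end());
}
```

Dropping leading `Dim::getFixed(1)` entries from 1x1x[1]x[4] with this helper leaves [1]x[4], matching the dropWhile case in the unit test.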

@MacDue MacDue force-pushed the vector_type_dims branch 5 times, most recently from c8e828a to 037c665 Compare December 4, 2023 10:13
@MacDue MacDue marked this pull request as ready for review December 4, 2023 11:54
@MacDue MacDue requested review from joker-eph, banach-space and nicolasvasilache and removed request for joker-eph and banach-space December 4, 2023 11:54
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:ods labels Dec 4, 2023
@llvmbot
Collaborator

llvmbot commented Dec 4, 2023

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-sve
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-ods

@llvm/pr-subscribers-mlir-core

Author: Benjamin Maxwell (MacDue)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/74251.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+192)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+23)
  • (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+101)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 92ce053ad5c82..b468fd42f374e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/Support/ADTExtras.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace llvm {
 class BitVector;
@@ -181,6 +182,197 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
   operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
 };
 
+//===----------------------------------------------------------------------===//
+// VectorDim
+//===----------------------------------------------------------------------===//
+
+/// This class represents a dimension of a vector type. Unlike other ShapedTypes
+/// vector dimensions can have scalable quantities, which means the dimension
+/// has a known minimum size, which is scaled by a constant that is only
+/// known at runtime.
+class VectorDim {
+public:
+  explicit constexpr VectorDim(int64_t quantity, bool scalable)
+      : quantity(quantity), scalable(scalable){};
+
+  /// Constructs a new fixed dimension.
+  constexpr static VectorDim getFixed(int64_t quantity) {
+    return VectorDim(quantity, false);
+  }
+
+  /// Constructs a new scalable dimension.
+  constexpr static VectorDim getScalable(int64_t quantity) {
+    return VectorDim(quantity, true);
+  }
+
+  /// Returns true if this dimension is scalable.
+  constexpr bool isScalable() const { return scalable; }
+
+  /// Returns true if this dimension is fixed.
+  constexpr bool isFixed() const { return !isScalable(); }
+
+  /// Returns the minimum number of elements this dimension can contain.
+  constexpr int64_t getMinSize() const { return quantity; }
+
+  /// If this dimension is fixed returns the number of elements, otherwise
+  /// aborts.
+  constexpr int64_t getFixedSize() const {
+    assert(isFixed());
+    return quantity;
+  }
+
+  constexpr bool operator==(VectorDim const &dim) const {
+    return quantity == dim.quantity && scalable == dim.scalable;
+  }
+
+  constexpr bool operator!=(VectorDim const &dim) const {
+    return !(*this == dim);
+  }
+
+  /// Helper class for indexing into a list of sizes and a (possibly empty)
+  /// list of scalable dimensions, extracting VectorDim elements.
+  struct Indexer {
+    explicit Indexer(ArrayRef<int64_t> sizes, ArrayRef<bool> scalableDims)
+        : sizes(sizes), scalableDims(scalableDims) {
+      assert(
+          (scalableDims.empty() ||
+           sizes.size() == scalableDims.size()) &&
+          "expected `scalableDims` to be empty or match `sizes` in length");
+    }
+
+    VectorDim operator[](size_t idx) const {
+      int64_t size = sizes[idx];
+      bool scalable = scalableDims.empty() ? false : scalableDims[idx];
+      return VectorDim(size, scalable);
+    }
+
+    ArrayRef<int64_t> sizes;
+    ArrayRef<bool> scalableDims;
+  };
+
+private:
+  int64_t quantity;
+  bool scalable;
+};
+
+//===----------------------------------------------------------------------===//
+// VectorDims
+//===----------------------------------------------------------------------===//
+
+/// Represents a non-owning list of vector dimensions. The underlying dimension
+/// sizes and scalability flags are stored as two separate lists to match the
+/// storage of a VectorType.
+class VectorDims : public VectorDim::Indexer {
+public:
+  using VectorDim::Indexer::Indexer;
+
+  class Iterator : public llvm::iterator_facade_base<
+                       Iterator, std::random_access_iterator_tag, VectorDim,
+                       std::ptrdiff_t, VectorDim, VectorDim> {
+  public:
+    Iterator(VectorDim::Indexer indexer, size_t index)
+        : indexer(indexer), index(index){};
+
+    // Iterator boilerplate.
+    ptrdiff_t operator-(const Iterator &rhs) const { return index - rhs.index; }
+    bool operator==(const Iterator &rhs) const { return index == rhs.index; }
+    bool operator<(const Iterator &rhs) const { return index < rhs.index; }
+    Iterator &operator+=(ptrdiff_t offset) {
+      index += offset;
+      return *this;
+    }
+    Iterator &operator-=(ptrdiff_t offset) {
+      index -= offset;
+      return *this;
+    }
+    VectorDim operator*() const { return indexer[index]; }
+
+    VectorDim::Indexer getIndexer() const { return indexer; }
+    ptrdiff_t getIndex() const { return index; }
+
+  private:
+    VectorDim::Indexer indexer;
+    ptrdiff_t index;
+  };
+
+  /// Construct from iterator pair.
+  VectorDims(Iterator begin, Iterator end)
+      : VectorDims(VectorDims(begin.getIndexer())
+                       .slice(begin.getIndex(), end - begin)) {}
+
+  VectorDims(VectorDim::Indexer indexer) : VectorDim::Indexer(indexer){};
+
+  Iterator begin() const { return Iterator(*this, 0); }
+  Iterator end() const { return Iterator(*this, size()); }
+
+  /// Check if the dims are empty.
+  bool empty() const { return sizes.empty(); }
+
+  /// Get the number of dims.
+  size_t size() const { return sizes.size(); }
+
+  /// Return the first dim.
+  VectorDim front() const { return (*this)[0]; }
+
+  /// Return the last dim.
+  VectorDim back() const { return (*this)[size() - 1]; }
+
+  /// Chop off the first \p n dims, and keep the following \p m
+  /// dims.
+  VectorDims slice(size_t n, size_t m) const {
+    ArrayRef<int64_t> newSizes = sizes.slice(n, m);
+    ArrayRef<bool> newScalableDims =
+        scalableDims.empty() ? ArrayRef<bool>{} : scalableDims.slice(n, m);
+    return VectorDims(newSizes, newScalableDims);
+  }
+
+  /// Drop the first \p n dims.
+  VectorDims dropFront(size_t n = 1) const { return slice(n, size() - n); }
+
+  /// Drop the last \p n dims.
+  VectorDims dropBack(size_t n = 1) const { return slice(0, size() - n); }
+
+  /// Return copy of *this with the first n dims matching the predicate removed.
+  template <class PredicateT>
+  VectorDims dropWhile(PredicateT predicate) const {
+    return VectorDims(llvm::find_if_not(*this, predicate), end());
+  }
+
+  /// Return the underlying sizes.
+  ArrayRef<int64_t> getSizes() const { return sizes; }
+
+  /// Return the underlying scalable dims.
+  ArrayRef<bool> getScalableDims() const { return scalableDims; }
+
+  /// Check for dim equality.
+  bool equals(VectorDims rhs) const {
+    if (size() != rhs.size())
+      return false;
+    return std::equal(begin(), end(), rhs.begin());
+  }
+
+  /// Check for dim equality.
+  bool equals(ArrayRef<VectorDim> rhs) const {
+    if (size() != rhs.size())
+      return false;
+    return std::equal(begin(), end(), rhs.begin());
+  }
+};
+
+inline bool operator==(VectorDims lhs, VectorDims rhs) {
+  return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, VectorDims rhs) { return !(lhs == rhs); }
+
+inline bool operator==(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+  return lhs.equals(rhs);
+}
+
+inline bool operator!=(VectorDims lhs, ArrayRef<VectorDim> rhs) {
+  return !(lhs == rhs);
+}
+
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 1d7772810ae6e..d6cd14079fab8 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1089,6 +1089,18 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
         scalableDims = isScalableVec;
       }
       return $_get(elementType.getContext(), shape, elementType, scalableDims);
+    }]>,
+    TypeBuilderWithInferredContext<(ins "Type":$elementType, "ArrayRef<VectorDim>": $shape), [{
+      SmallVector<int64_t> sizes;
+      SmallVector<bool> scalableDims;
+      for (VectorDim dim : shape) {
+        sizes.push_back(dim.getMinSize());
+        scalableDims.push_back(dim.isScalable());
+      }
+      return get(sizes, elementType, scalableDims);
+    }]>,
+    TypeBuilderWithInferredContext<(ins "Type":$elementType, "VectorDims": $shape), [{
+      return get(shape.getSizes(), elementType, shape.getScalableDims());
     }]>
   ];
   let extraClassDeclaration = [{
@@ -1096,6 +1108,17 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
     /// Arguments that are passed into the builder must outlive the builder.
     class Builder;
 
+    /// Returns the value of the specified dimension (including scalability).
+    VectorDim getDim(unsigned idx) const {
+      assert(idx < getRank() && "invalid dim index for vector type");
+      return getDims()[idx];
+    }
+
+    /// Returns the dimensions of this vector type (including scalability).
+    VectorDims getDims() const {
+      return VectorDims(getShape(), getScalableDims());
+    }
+
     /// Returns true if the given type can be used as an element of a vector
     /// type. In particular, vectors can consist of integer, index, or float
     /// primitives.
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index 61264bc523648..07625da6ee889 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -226,4 +226,105 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
   }
 }
 
+TEST(ShapedTypeTest, VectorDims) {
+  MLIRContext context;
+  Type f32 = FloatType::getF32(&context);
+
+  SmallVector<VectorDim> dims{VectorDim::getFixed(2), VectorDim::getScalable(4),
+                              VectorDim::getFixed(8), VectorDim::getScalable(9),
+                              VectorDim::getFixed(1)};
+  VectorType vectorType = VectorType::get(f32, dims);
+
+  // Directly check values
+  {
+    auto dim0 = vectorType.getDim(0);
+    ASSERT_EQ(dim0.getMinSize(), 2);
+    ASSERT_TRUE(dim0.isFixed());
+
+    auto dim1 = vectorType.getDim(1);
+    ASSERT_EQ(dim1.getMinSize(), 4);
+    ASSERT_TRUE(dim1.isScalable());
+
+    auto dim2 = vectorType.getDim(2);
+    ASSERT_EQ(dim2.getMinSize(), 8);
+    ASSERT_TRUE(dim2.isFixed());
+
+    auto dim3 = vectorType.getDim(3);
+    ASSERT_EQ(dim3.getMinSize(), 9);
+    ASSERT_TRUE(dim3.isScalable());
+
+    auto dim4 = vectorType.getDim(4);
+    ASSERT_EQ(dim4.getMinSize(), 1);
+    ASSERT_TRUE(dim4.isFixed());
+  }
+
+  // Test indexing via getDim(idx)
+  {
+    for (unsigned i = 0; i < dims.size(); i++)
+      ASSERT_EQ(vectorType.getDim(i), dims[i]);
+  }
+
+  // Test using VectorDims::Iterator in for-each loop
+  {
+    unsigned i = 0;
+    for (VectorDim dim : vectorType.getDims())
+      ASSERT_EQ(dim, dims[i++]);
+    ASSERT_EQ(i, vectorType.getRank());
+  }
+
+  // Test using VectorDims::Iterator in LLVM iterator helper
+  {
+    for (auto [dim, expectedDim] :
+         llvm::zip_equal(vectorType.getDims(), dims)) {
+      ASSERT_EQ(dim, expectedDim);
+    }
+  }
+
+  // Test dropFront()
+  {
+    auto vectorDims = vectorType.getDims();
+    auto newDims = vectorDims.dropFront();
+
+    ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+    for (unsigned i = 0; i < newDims.size(); i++)
+      ASSERT_EQ(newDims[i], vectorDims[i + 1]);
+  }
+
+  // Test dropBack()
+  {
+    auto vectorDims = vectorType.getDims();
+    auto newDims = vectorDims.dropBack();
+
+    ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
+    for (unsigned i = 0; i < newDims.size(); i++)
+      ASSERT_EQ(newDims[i], vectorDims[i]);
+  }
+
+  // Test front()
+  { ASSERT_EQ(vectorType.getDims().front(), VectorDim::getFixed(2)); }
+
+  // Test back()
+  { ASSERT_EQ(vectorType.getDims().back(), VectorDim::getFixed(1)); }
+
+  // Test dropWhile.
+  {
+    SmallVector<VectorDim> dims{
+        VectorDim::getFixed(1), VectorDim::getFixed(1), VectorDim::getFixed(1),
+        VectorDim::getScalable(1), VectorDim::getScalable(4)};
+
+    VectorType vectorTypeWithLeadingUnitDims = VectorType::get(f32, dims);
+    ASSERT_EQ(vectorTypeWithLeadingUnitDims.getDims().size(),
+              unsigned(vectorTypeWithLeadingUnitDims.getRank()));
+
+    // Drop leading unit dims.
+    auto withoutLeadingUnitDims =
+        vectorTypeWithLeadingUnitDims.getDims().dropWhile(
+            [](VectorDim dim) { return dim == VectorDim::getFixed(1); });
+
+    SmallVector<VectorDim> expectedDims{VectorDim::getScalable(1),
+                                        VectorDim::getScalable(4)};
+    ASSERT_EQ(withoutLeadingUnitDims, expectedDims);
+  }
+}
+
 } // namespace

@llvmbot
Collaborator

llvmbot commented Dec 4, 2023

@llvm/pr-subscribers-mlir

Currently, the shape of a VectorType is stored in two separate lists:
the 'shape', which comes from ShapedType and has no way to represent
scalability, and the 'scalableDims', an additional list of bools
attached to VectorType. This is somewhat cumbersome to work with, and
makes it easy to ignore the scalability of a dim, producing incorrect
results.

For example, to correctly trim leading unit dims of a VectorType,
currently, you need to do something like:

```c++
  while (!newShape.empty() && newShape.front() == 1 &&
         !newScalableDims.front()) {
    newShape = newShape.drop_front(1);
    newScalableDims = newScalableDims.drop_front(1);
  }
```

Which would be wrong if you (more naturally) wrote it as:

```c++
  auto newShape = vectorType.getShape().drop_while([](int64_t dim) {
    return dim == 1;
  });
```

As this would trim scalable one dims (`[1]`), which are not unit dims
like their fixed counterparts.
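
The difference can be reproduced outside MLIR. The following is a minimal sketch that uses plain `std::vector`s as stand-ins for the two lists; the `Shape` and `ScalableFlags` aliases and both function names are hypothetical and not part of the patch:

```c++
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for the two parallel lists stored by VectorType.
using Shape = std::vector<int64_t>;
using ScalableFlags = std::vector<bool>;

// Naive version: counts leading dims whose size is 1, ignoring scalability.
std::size_t countLeadingOnesNaive(const Shape &shape) {
  std::size_t n = 0;
  while (n < shape.size() && shape[n] == 1)
    ++n;
  return n;
}

// Scalability-aware version: only fixed dims of size 1 count as unit dims.
std::size_t countLeadingUnitDims(const Shape &shape,
                                 const ScalableFlags &scalable) {
  assert(shape.size() == scalable.size() && "lists must stay in sync");
  std::size_t n = 0;
  while (n < shape.size() && shape[n] == 1 && !scalable[n])
    ++n;
  return n;
}
```

For a shape such as `1x[1]x4`, the naive count covers both leading dims, while the scalability-aware count stops at the scalable `[1]`.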

This patch does not change the storage of the VectorType, but instead
adds new scalability-safe accessors and iterators.

Two new methods are added to VectorType:

```c++
/// Returns the value of the specified dimension (including scalability)
VectorDim VectorType::getDim(unsigned idx);

/// Returns the dimensions of this vector type (including scalability)
VectorDims VectorType::getDims();
```

These are backed by two new classes: `VectorDim` and `VectorDims`.

`VectorDim` represents a single dimension of a VectorType. It can be a
fixed or scalable quantity. It cannot be implicitly converted to/from an
integer, so you must specify the kind of quantity you expect in
comparisons.
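
To illustrate the "no implicit conversion" property, here is a simplified, self-contained model of such a class. It reuses the accessor names that appear in the unit test above (`getFixed`, `getScalable`, `isFixed`, `isScalable`, `getMinSize`), but the layout, the positivity assert, and the `static_assert`s are assumptions made for this sketch and not the patch's implementation:

```c++
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical, simplified model of a vector dimension; not the patch's class.
class VectorDim {
public:
  static VectorDim getFixed(int64_t size) { return VectorDim(size, false); }
  static VectorDim getScalable(int64_t minSize) {
    return VectorDim(minSize, true);
  }

  bool isFixed() const { return !scalable; }
  bool isScalable() const { return scalable; }
  int64_t getMinSize() const { return minSize; }

  // Two dims are equal only if both size and scalability match.
  friend bool operator==(VectorDim lhs, VectorDim rhs) {
    return lhs.minSize == rhs.minSize && lhs.scalable == rhs.scalable;
  }
  friend bool operator!=(VectorDim lhs, VectorDim rhs) { return !(lhs == rhs); }

private:
  // Private two-argument constructor: there is no way to build a dim from a
  // bare integer, so callers must go through getFixed() or getScalable().
  VectorDim(int64_t minSize, bool scalable)
      : minSize(minSize), scalable(scalable) {
    assert(minSize > 0 && "assumed in this sketch: sizes are positive");
  }

  int64_t minSize;
  bool scalable;
};

// Neither direction converts implicitly, so `dim == 1` does not compile.
static_assert(!std::is_convertible<int64_t, VectorDim>::value,
              "integers must not convert to VectorDim");
static_assert(!std::is_convertible<VectorDim, int64_t>::value,
              "VectorDim must not convert to an integer");
```

With this shape of class, `VectorDim::getFixed(1)` and `VectorDim::getScalable(1)` compare unequal even though both report a minimum size of 1.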

`VectorDims` represents a non-owning list of vector dimensions, backed
by separate size and scalability lists (matching the storage of
VectorType). This class has an iterator, and a few common helper methods
(similar to those of ArrayRef).
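
As a rough picture of what a non-owning view over two parallel lists could look like, consider the sketch below. `Dim` and `DimsView` are hypothetical names chosen for this example; only the helper names `dropFront`, `dropWhile`, `front`, and `size` mirror those exercised in the unit test, and the real class may be organised differently:

```c++
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for a single dimension (size plus scalability).
struct Dim {
  int64_t minSize;
  bool scalable;
  friend bool operator==(Dim a, Dim b) {
    return a.minSize == b.minSize && a.scalable == b.scalable;
  }
};

// Hypothetical non-owning view over two parallel arrays. It copies nothing,
// so both arrays must outlive the view.
class DimsView {
public:
  DimsView(const int64_t *sizes, const bool *scalable, std::size_t count)
      : sizes(sizes), scalable(scalable), count(count) {}

  std::size_t size() const { return count; }
  bool empty() const { return count == 0; }

  // Pairs up the size and the scalable flag on access.
  Dim operator[](std::size_t idx) const {
    assert(idx < count && "index out of range");
    return Dim{sizes[idx], scalable[idx]};
  }
  Dim front() const { return (*this)[0]; }

  // Advances both lists together, so they cannot get out of sync.
  DimsView dropFront(std::size_t n = 1) const {
    assert(n <= count && "dropping more dims than available");
    return DimsView(sizes + n, scalable + n, count - n);
  }

  // Drops leading dims for as long as the predicate holds.
  template <typename Pred>
  DimsView dropWhile(Pred pred) const {
    std::size_t n = 0;
    while (n < count && pred((*this)[n]))
      ++n;
    return dropFront(n);
  }

private:
  const int64_t *sizes;
  const bool *scalable;
  std::size_t count;
};
```

Because every helper moves both pointers at once, a caller cannot trim the sizes while leaving the scalable flags behind.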

There are also new builders to construct VectorTypes from both
the `VectorDims` class and an `ArrayRef<VectorDim>`.
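
The builders themselves are not shown in this section. Since the storage is unchanged, one plausible approach is for a builder taking a list of dims to split it back into the two parallel lists. The sketch below shows only that splitting step, using a hypothetical `Dim` struct and `splitDims` helper with standard containers:

```c++
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for VectorDim: a size paired with its scalability.
struct Dim {
  int64_t minSize;
  bool scalable;
};

// Splits a list of dims into the two parallel lists (sizes, scalable flags).
std::pair<std::vector<int64_t>, std::vector<bool>>
splitDims(const std::vector<Dim> &dims) {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims;
  shape.reserve(dims.size());
  scalableDims.reserve(dims.size());
  for (const Dim &dim : dims) {
    shape.push_back(dim.minSize);
    scalableDims.push_back(dim.scalable);
  }
  assert(shape.size() == scalableDims.size() && "lists must stay in sync");
  return {std::move(shape), std::move(scalableDims)};
}
```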

With these changes the previous example becomes:

```c++
  auto newDims = vectorType.getDims().dropWhile([](VectorDim dim) {
    return dim == VectorDim::getFixed(1);
  });
```

Which (to me) is easier to read, and safer, as it is not possible to
forget to check the scalability of the dim. Just comparing with `1` would
fail to build.

This is not a complete change; it just updates a few examples found
by grepping for getScalableDims().

MacDue commented Dec 7, 2023

f69fad1 contains a few (randomly) updated examples. But I think it shows that with APIs like this, you can write safer and more idiomatic code where scalability is a bit less of a burden :)

MacDue commented Dec 7, 2023

As a little summary this is what I'd like to solve:

  • Safety
    • You have to specify a type of quantity in dim comparisons (fixed or scalable)
    • Can't forget to check the scalability of a dim (no implicit conversions, & can't forget to look at the scalableDims array)
    • Can't easily drop your scalableDims (they're safely stored in your VectorDims)
  • Convenience & readability
    • It should not be harder for people writing scalable-aware code to do things correctly
      • Standard handy iterators and utilities work with VectorDims
      • Simple helpers for inspecting/updating your vector dims
    • It should be easy to read some code and see if it's doing the correct checks
      • Currently, when you do something like getDimSize(0) == 1 && !getScalableDims()[0], it obscures that it's checking for VectorDim::getFixed(1)

I don't think this is a complete solution, but I think it's a fairly non-invasive step towards something nicer :)
