[mlir] Add optional layout attribute to VectorType #71916

Open: harsh-nod wants to merge 2 commits into main

harsh-nod
Contributor

This patch adds an attribute interface for representing the layout on vector types. This layout could be used to represent the mapping from the vector indices to the indices of the vector fragments held by different threads of a GPU.

The interface has a verify function that can be used to validate that the layout accurately represents the vector shape.
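
To make the verification step concrete, here is a minimal standalone C++ sketch that does not depend on MLIR. It is loosely modeled on the `NamedStridedLayoutAttr` from this PR's unit test, where each vector dimension is described as a product of named factors, and it checks the layout against a shape the way a `verifyLayout` implementation might. `NamedFactor`, `NamedStridedLayoutModel`, `impliedShape`, and `makeTestLayout` are hypothetical names introduced only for illustration, and the callback merely stands in for `emitError`.

```cpp
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Hypothetical model of one named factor of a vector dimension,
// e.g. {"LaneX", 4}. Not an MLIR type.
struct NamedFactor {
  std::string name;
  int64_t size;
};

// Hypothetical stand-in for an attribute implementing
// VectorLayoutAttrInterface.
struct NamedStridedLayoutModel {
  std::vector<std::vector<NamedFactor>> dims;

  // Shape implied by the layout: per dimension, the product of factor sizes.
  std::vector<int64_t> impliedShape() const {
    std::vector<int64_t> shape;
    for (const std::vector<NamedFactor> &factors : dims) {
      int64_t extent = 1;
      for (const NamedFactor &factor : factors)
        extent *= factor.size;
      shape.push_back(extent);
    }
    return shape;
  }

  // Plays the role of verifyLayout: succeed only if the layout describes
  // exactly the given shape, otherwise report a reason via the callback.
  bool verifyLayout(
      const std::vector<int64_t> &shape,
      const std::function<void(const std::string &)> &emitError) const {
    if (impliedShape() == shape)
      return true;
    emitError("layout does not cover the vector shape");
    return false;
  }
};

// The factors used in the PR's unit test.
inline NamedStridedLayoutModel makeTestLayout() {
  std::vector<NamedFactor> x = {{"BatchX", 2}, {"LaneX", 4}, {"VectorX", 2}};
  std::vector<NamedFactor> y = {{"BatchY", 1}, {"LaneY", 8}, {"VectorY", 2}};
  return NamedStridedLayoutModel{{x, y}};
}
```

With the factors from the unit test, the implied shape is 16x16 (2 * 4 * 2 and 1 * 8 * 2), so verification against a 16x16 vector succeeds and any other shape is rejected.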

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:ods labels Nov 10, 2023
@llvmbot
Collaborator

llvmbot commented Nov 10, 2023

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-ods

Author: None (harsh-nod)

Full diff: https://github.com/llvm/llvm-project/pull/71916.diff

8 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinAttributeInterfaces.td (+24)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+11-3)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+17-4)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+26-2)
  • (modified) mlir/lib/IR/AsmPrinter.cpp (+5)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+7-1)
  • (modified) mlir/unittests/Interfaces/CMakeLists.txt (+1)
  • (added) mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp (+158)
diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
index c741db9b47f34e5..9241cac8c3b98a0 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -495,4 +495,28 @@ def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// VectorLayoutAttrInterface
+//===----------------------------------------------------------------------===//
+
+def VectorLayoutAttrInterface : AttrInterface<"VectorLayoutAttrInterface"> {
+  let cppNamespace = "::mlir";
+
+  let description = [{
+    This interface is used for attributes that can represent the Vector type's
+    layout semantics, such as being able to map the vector indices to those
+    of the vector fragments held by individual threads.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      "Check if the current layout is applicable to the provided shape",
+      "::mlir::LogicalResult", "verifyLayout",
+      (ins "::llvm::ArrayRef<int64_t>":$shape,
+           "::mlir::Type":$elementType,
+           "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
+    >
+  ];
+}
+
 #endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 92ce053ad5c829b..a387390b38e7de4 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -307,12 +307,14 @@ class VectorType::Builder {
   /// Build from another VectorType.
   explicit Builder(VectorType other)
       : elementType(other.getElementType()), shape(other.getShape()),
-        scalableDims(other.getScalableDims()) {}
+        scalableDims(other.getScalableDims()), layout(other.getLayout()) {}
 
   /// Build from scratch.
   Builder(ArrayRef<int64_t> shape, Type elementType,
-          ArrayRef<bool> scalableDims = {})
-      : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
+          ArrayRef<bool> scalableDims = {},
+          VectorLayoutAttrInterface layout = {})
+      : elementType(elementType), shape(shape), scalableDims(scalableDims),
+        layout(layout) {}
 
   Builder &setShape(ArrayRef<int64_t> newShape,
                     ArrayRef<bool> newIsScalableDim = {}) {
@@ -342,6 +344,11 @@ class VectorType::Builder {
     return *this;
   }
 
+  Builder &setLayout(VectorLayoutAttrInterface newLayout) {
+    layout = newLayout;
+    return *this;
+  }
+
   operator VectorType() {
     return VectorType::get(shape, elementType, scalableDims);
   }
@@ -350,6 +357,7 @@ class VectorType::Builder {
   Type elementType;
   CopyOnWriteArrayRef<int64_t> shape;
   CopyOnWriteArrayRef<bool> scalableDims;
+  VectorLayoutAttrInterface layout;
 };
 
 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 5ec986ac26de06b..3a2193d7a1768f7 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1029,11 +1029,13 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
     Syntax:
 
     ```
-    vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
+    vector-type ::= `vector` `<` vector-dim-list vector-element-type
+                    (`,` layout-specification)? `>`
     vector-element-type ::= float-type | integer-type | index-type
     vector-dim-list := (static-dim-list `x`)?
     static-dim-list ::= static-dim (`x` static-dim)*
     static-dim ::= (decimal-literal | `[` decimal-literal `]`)
+    layout-specification ::= attribute-value
     ```
 
     The vector type represents a SIMD style vector used by target-specific
@@ -1050,6 +1052,14 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
     declarations, `vector<0x42xi32>` is invalid because it is interpreted as a
     2D vector with shape `(0, 42)` and zero shapes are not allowed.
 
+    ##### Layout
+
+    A vector may optionally have a layout that indicates how indices of
+    the vector are transformed to indices of the vector fragments that
+    are held by individual threads in a SIMT execution model. Such layouts
+    are common in a wide variety of GPU matrix multiplication instructions.
+    The layout can be any attribute that implements `VectorLayoutAttrInterface`.
+
     Examples:
 
     ```mlir
@@ -1068,17 +1078,20 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
     // A 3D mixed fixed/scalable vector in which only the inner dimension is
     // scalable.
     vector<2x[4]x8xf32>
+
     ```
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
     "Type":$elementType,
-    ArrayRefParameter<"bool">:$scalableDims
+    ArrayRefParameter<"bool">:$scalableDims,
+    "VectorLayoutAttrInterface":$layout
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
       "ArrayRef<int64_t>":$shape, "Type":$elementType,
-      CArg<"ArrayRef<bool>", "{}">:$scalableDims
+      CArg<"ArrayRef<bool>", "{}">:$scalableDims,
+      CArg<"VectorLayoutAttrInterface", "{}">:$layout
     ), [{
       // While `scalableDims` is optional, its default value should be
       // `false` for every dim in `shape`.
@@ -1087,7 +1100,7 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
         isScalableVec.resize(shape.size(), false);
         scalableDims = isScalableVec;
       }
-      return $_get(elementType.getContext(), shape, elementType, scalableDims);
+      return $_get(elementType.getContext(), shape, elementType, scalableDims, layout);
     }]>
   ];
   let extraClassDeclaration = [{
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index adaefb78172c2ea..be8c84ee74e8688 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -459,14 +459,38 @@ VectorType Parser::parseVectorType() {
   // Parse the element type.
   auto typeLoc = getToken().getLoc();
   auto elementType = parseType();
-  if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
+  if (!elementType)
     return nullptr;
 
   if (!VectorType::isValidElementType(elementType))
     return emitError(typeLoc, "vector elements must be int/index/float type"),
            nullptr;
 
-  return VectorType::get(dimensions, elementType, scalableDims);
+  VectorLayoutAttrInterface layout;
+  auto parseElt = [&]() -> ParseResult {
+    Attribute attr = parseAttribute();
+    if (!attr)
+      return failure();
+    if (isa<VectorLayoutAttrInterface>(attr)) {
+      layout = cast<VectorLayoutAttrInterface>(attr);
+    }
+    return success();
+  };
+
+  // Parse a list of mappings and address space if present.
+  if (!consumeIf(Token::greater)) {
+    // Parse comma separated list of affine maps, followed by memory space.
+    if (parseToken(Token::comma, "expected ',' or '>' in vector type") ||
+        parseCommaSeparatedListUntil(Token::greater, parseElt,
+                                     /*allowEmptyList=*/false)) {
+      return nullptr;
+    }
+  }
+
+  if (!layout)
+    return VectorType::get(dimensions, elementType, scalableDims);
+
+  return VectorType::get(dimensions, elementType, scalableDims, layout);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index dae7fdd40b5456c..458140b0de81c50 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2572,6 +2572,11 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
           os << 'x';
         }
         printType(vectorTy.getElementType());
+        VectorLayoutAttrInterface layout = vectorTy.getLayout();
+        if (layout) {
+          os << ", ";
+          printAttribute(vectorTy.getLayout(), AttrTypeElision::May);
+        }
         os << '>';
       })
       .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index a9284d5714637bc..b0ebec169e3c841 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -227,7 +227,8 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
 
 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  ArrayRef<int64_t> shape, Type elementType,
-                                 ArrayRef<bool> scalableDims) {
+                                 ArrayRef<bool> scalableDims,
+                                 VectorLayoutAttrInterface layout) {
   if (!isValidElementType(elementType))
     return emitError()
            << "vector elements must be int/index/float type but got "
@@ -242,6 +243,11 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
     return emitError() << "number of dims must match, got "
                        << scalableDims.size() << " and " << shape.size();
 
+  if (layout) {
+    if (failed(layout.verifyLayout(shape, elementType, emitError)))
+      return emitError() << "layout does not match underlying vector shape";
+  }
+
   return success();
 }
 
diff --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt
index d192b2922d6b9dc..45a518f8549a3c6 100644
--- a/mlir/unittests/Interfaces/CMakeLists.txt
+++ b/mlir/unittests/Interfaces/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_unittest(MLIRInterfacesTests
   DataLayoutInterfacesTest.cpp
   InferIntRangeInterfaceTest.cpp
   InferTypeOpInterfaceTest.cpp
+  VectorLayoutInterfaceTest.cpp
 )
 
 target_link_libraries(MLIRInterfacesTests
diff --git a/mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp b/mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp
new file mode 100644
index 000000000000000..0ea9a710138e3a4
--- /dev/null
+++ b/mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp
@@ -0,0 +1,158 @@
+//===-- VectorLayoutInterfaceTest.cpp - Unit Tests for Vector Layouts -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Parser/Parser.h"
+
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace mlir::detail;
+
+class NamedStridedLayoutAttrStorage : public AttributeStorage {
+public:
+  using KeyTy =
+      std::tuple<ArrayRef<std::string>, ArrayRef<int64_t>, ArrayRef<int64_t>>;
+
+  NamedStridedLayoutAttrStorage(ArrayRef<std::string> names,
+                                ArrayRef<int64_t> strides,
+                                ArrayRef<int64_t> vectorShape)
+      : names(names), strides(strides), vectorShape(vectorShape) {}
+
+  bool operator==(const KeyTy &key) const {
+    return (std::get<0>(key) == names) && (std::get<1>(key) == strides) &&
+           (std::get<2>(key) == vectorShape);
+  }
+
+  static NamedStridedLayoutAttrStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    ArrayRef<std::string> names = allocator.copyInto(std::get<0>(key));
+    ArrayRef<int64_t> strides = allocator.copyInto(std::get<1>(key));
+    ArrayRef<int64_t> vectorShape = allocator.copyInto(std::get<2>(key));
+    return new (allocator.allocate<NamedStridedLayoutAttrStorage>())
+        NamedStridedLayoutAttrStorage(names, strides, vectorShape);
+  }
+
+  ArrayRef<std::string> names;
+  ArrayRef<int64_t> strides;
+  ArrayRef<int64_t> vectorShape;
+};
+
+struct NamedStridedLayoutAttr
+    : public Attribute::AttrBase<NamedStridedLayoutAttr, Attribute,
+                                 NamedStridedLayoutAttrStorage,
+                                 VectorLayoutAttrInterface::Trait> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NamedStridedLayoutAttr)
+  using Base::Base;
+  static NamedStridedLayoutAttr get(MLIRContext *ctx,
+                                    ArrayRef<std::string> names,
+                                    ArrayRef<int64_t> strides,
+                                    ArrayRef<int64_t> vectorShape) {
+    return Base::get(ctx, names, strides, vectorShape);
+  }
+
+  LogicalResult verifyLayout(ArrayRef<int64_t> shape, Type elementType,
+                             function_ref<InFlightDiagnostic()> emitError) {
+    if (shape == getVectorShape())
+      return success();
+    return failure();
+  }
+
+  ArrayRef<std::string> getNames() { return getImpl()->names; }
+  ArrayRef<int64_t> getStrides() { return getImpl()->strides; }
+  ArrayRef<int64_t> getVectorShape() { return getImpl()->vectorShape; }
+};
+
+struct VLTestDialect : Dialect {
+  explicit VLTestDialect(MLIRContext *ctx)
+      : Dialect(getDialectNamespace(), ctx, TypeID::get<VLTestDialect>()) {
+    ctx->loadDialect<VLTestDialect>();
+    addAttributes<NamedStridedLayoutAttr>();
+  }
+  static StringRef getDialectNamespace() { return "vltest"; }
+
+  Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
+    SmallVector<int64_t> strides, vectorShape;
+    SmallVector<std::string> names;
+    if (!succeeded(parser.parseKeyword("named_strided_layout")))
+      return {};
+    if (!succeeded(parser.parseLess()))
+      return {};
+    do {
+      if (!succeeded(parser.parseLSquare()))
+        return {};
+      std::string name;
+      int64_t stride;
+      int64_t shape = 1;
+      do {
+        if (succeeded(parser.parseString(&name)) &&
+            succeeded(parser.parseColon()) &&
+            succeeded(parser.parseInteger(stride))) {
+          names.push_back(name);
+          strides.push_back(stride);
+          shape *= stride;
+        }
+      } while (succeeded(parser.parseOptionalComma()));
+      if (!succeeded(parser.parseRSquare()))
+        return {};
+      vectorShape.push_back(shape);
+    } while (succeeded(parser.parseOptionalComma()));
+    if (!succeeded(parser.parseGreater()))
+      return {};
+    return NamedStridedLayoutAttr::get(parser.getContext(), names, strides,
+                                       vectorShape);
+  }
+};
+
+TEST(VectorLayoutAttrInterface, NamedStridedLayout) {
+  const char *ir = R"MLIR(
+    #layout = #vltest.named_strided_layout<["BatchX" : 2, "LaneX" : 4, "VectorX" : 2],
+                                           ["BatchY" : 1, "LaneY" : 8, "VectorY" : 2]>
+    %lhs = "arith.constant"() {value = dense<0.0> : vector<16x16xf16, #layout>}
+        : () -> (vector<16x16xf16, #layout>)
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<VLTestDialect, func::FuncDialect, arith::ArithDialect>();
+  MLIRContext ctx(registry);
+  OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+
+  arith::ConstantOp op =
+      llvm::cast<arith::ConstantOp>(module->getBody()->getOperations().front());
+  Type type = op.getResult().getType();
+  if (auto vectorType = llvm::cast<VectorType>(type)) {
+    VectorLayoutAttrInterface layout = vectorType.getLayout();
+    auto namedStridedLayout = llvm::cast<NamedStridedLayoutAttr>(layout);
+    ArrayRef<std::string> names = namedStridedLayout.getNames();
+    ArrayRef<int64_t> strides = namedStridedLayout.getStrides();
+    ArrayRef<int64_t> vectorShape = namedStridedLayout.getVectorShape();
+    EXPECT_EQ(vectorShape.size(), 2u);
+    EXPECT_EQ(vectorShape[0], 16u);
+    EXPECT_EQ(vectorShape[1], 16u);
+    EXPECT_EQ(strides.size(), 6u);
+    EXPECT_EQ(strides[0], 2u);
+    EXPECT_EQ(strides[1], 4u);
+    EXPECT_EQ(strides[2], 2u);
+    EXPECT_EQ(strides[3], 1u);
+    EXPECT_EQ(strides[4], 8u);
+    EXPECT_EQ(strides[5], 2u);
+    EXPECT_EQ(names.size(), 6u);
+    EXPECT_EQ(names[0], "BatchX");
+    EXPECT_EQ(names[1], "LaneX");
+    EXPECT_EQ(names[2], "VectorX");
+    EXPECT_EQ(names[3], "BatchY");
+    EXPECT_EQ(names[4], "LaneY");
+    EXPECT_EQ(names[5], "VectorY");
+  }
+}

Member

@rengolin rengolin left a comment

The code looks OK to me, and so does the abstraction level, which keeps the layout implementation-defined. That should be fine for the hands-off approach we have now.

I wonder if we want to have some implementation upstream, though. Sure, this would be for another PR, but how many upstream passes are we going to disable because they don't understand the downstream layout encoding?

I'll leave it to others to give final approval on the PR, so they have time to review.

Contributor

@ThomasRaoux ThomasRaoux left a comment

Does that mean that every vector transformation is now responsible for maintaining and propagating this attribute? That seems like a fairly fundamental change; do we need to make sure the community is on board with that? What's going to be the expected behavior of transformations regarding this new attribute in patterns handling vectors?

Comment on lines 480 to 482
// Parse a list of mappings and address space if present.
if (!consumeIf(Token::greater)) {
// Parse comma separated list of affine maps, followed by memory space.
Contributor

those comments look out of date.

Contributor Author

Thanks, will update.

@@ -242,6 +243,11 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "number of dims must match, got "
<< scalableDims.size() << " and " << shape.size();

if (layout) {
if (failed(layout.verifyLayout(shape, elementType, emitError)))
return emitError() << "layout does not match underlying vector shape";
Contributor

the error message should be more generic as the mismatch may be in the shape or somewhere else.

Contributor Author

Makes sense, will change to something more generic like "Layout verification failed!"

};

TEST(VectorLayoutAttrInterface, NamedStridedLayout) {
const char *ir = R"MLIR(
Contributor

can you also add a lit test to test round trip printing/parsing?

Contributor Author

Sure, will add a roundtrip test. It might have to stay a unit test, since the attribute and dialect are only defined in this cpp file.
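
As a rough illustration of what a print/parse round trip for the extended syntax checks, here is a standalone C++ sketch that does not use the MLIR parser or printer. It treats the layout as opaque text after the first comma, following the grammar `vector<dims x element-type (, layout)?>` from the patch, and it ignores scalable dimensions such as `[4]`. `VectorTypeText`, `printVectorType`, and `parseVectorType` are hypothetical helpers, not MLIR APIs.

```cpp
#include <cctype>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical textual model of a vector type; not an MLIR class.
struct VectorTypeText {
  std::vector<int64_t> shape;
  std::string elementType;
  std::string layout; // empty means "no layout"; kept as opaque text
};

// Print as `vector<AxBxelt>` or `vector<AxBxelt, layout>`.
std::string printVectorType(const VectorTypeText &type) {
  std::string out = "vector<";
  for (int64_t dim : type.shape)
    out += std::to_string(dim) + "x";
  out += type.elementType;
  if (!type.layout.empty())
    out += ", " + type.layout;
  return out + ">";
}

// Parse the form produced by printVectorType. Fixed dimensions only.
std::optional<VectorTypeText> parseVectorType(const std::string &text) {
  const std::string prefix = "vector<";
  if (text.size() <= prefix.size() ||
      text.compare(0, prefix.size(), prefix) != 0 || text.back() != '>')
    return std::nullopt;
  // Strip the leading "vector<" and the final '>'.
  const std::string body =
      text.substr(prefix.size(), text.size() - prefix.size() - 1);

  VectorTypeText result;
  // Everything after the first comma is the optional layout.
  const std::string::size_type comma = body.find(',');
  const std::string dimsAndElement = body.substr(0, comma);
  if (comma != std::string::npos) {
    const std::string::size_type start =
        body.find_first_not_of(' ', comma + 1);
    if (start == std::string::npos)
      return std::nullopt; // a comma must be followed by a layout
    result.layout = body.substr(start);
  }

  // Consume leading "<digits>x" groups as dimensions.
  std::string::size_type pos = 0;
  for (;;) {
    std::string::size_type end = pos;
    while (end < dimsAndElement.size() &&
           std::isdigit(static_cast<unsigned char>(dimsAndElement[end])))
      ++end;
    if (end == pos || end == dimsAndElement.size() ||
        dimsAndElement[end] != 'x')
      break;
    result.shape.push_back(std::stoll(dimsAndElement.substr(pos, end - pos)));
    pos = end + 1;
  }
  result.elementType = dimsAndElement.substr(pos);
  if (result.elementType.empty())
    return std::nullopt;
  return result;
}
```

A round-trip check then amounts to printing a type, parsing the text, printing again, and comparing the two strings, both with and without a layout.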

Comment on lines 1057 to 1061
A vector may optionally have a layout that indicates how indices of
the vector are transformed to indices of the vector fragments that
are held by individual threads in a SIMT execution model. Such layouts
are common in a wide variety of GPU matrix multiplication instructions.
The layout can be any attribute that implements `VectorLayoutAttrInterface`.
Contributor

I think the description should probably be more generic than that to allow for more usage of the layout. (not only GPU and mapping to SIMT registers)

Contributor Author

Sure, will rewrite this to make it more general. I think the primary motivation for this was GPU/SIMT but people can use this in different ways.

@dcaballe
Contributor

Hey Harsh, thanks for the proposal. This is an interesting change! Some initial questions that come to mind:

  • How would we lower something like this to LLVM?
  • Should this be a property of the type or of the operations supporting these swizzles? In the doc you mention that this is mostly for SIMT matmuls, which makes me wonder if we should only make this available for those ops.
  • Assuming this should be a property of the type, should we consider a "LayoutVectorType", where the attribute is not optional but mandatory?
  • It would be very helpful to add a bunch of IR examples. It would help me get a clearer idea of what we are trying to achieve here.

This looks like a big change so I would give it enough time to settle. Perhaps we should discuss some of these questions around an RFC?

@harsh-nod
Contributor Author

> Does that mean that every vector transformation is now responsible for maintaining and propagating this attribute? That seems like a fairly fundamental change; do we need to make sure the community is on board with that? What's going to be the expected behavior of transformations regarding this new attribute in patterns handling vectors?

I don't think every vector transformation should be responsible for maintaining and propagating this attribute. The existing vector transformations and patterns can safely ignore this attribute which hopefully reduces the scope of this change.

@harsh-nod
Contributor Author

harsh-nod commented Nov 10, 2023

> Hey Harsh, thanks for the proposal. This is an interesting change! Some initial questions that come to mind:
>
>   • How would we lower something like this to LLVM?
>   • Should this be a property of the type or of the operations supporting these swizzles? In the doc you mention that this is mostly for SIMT matmuls, which makes me wonder if we should only make this available for those ops.
>   • Assuming this should be a property of the type, should we consider a "LayoutVectorType", where the attribute is not optional but mandatory?
>   • It would be very helpful to add a bunch of IR examples. It would help me get a clearer idea of what we are trying to achieve here.
>
> This looks like a big change so I would give it enough time to settle. Perhaps we should discuss some of these questions around an RFC?

So this PR is un-opinionated in terms of how the vector layout can be represented. The idea is that many people will have a preference about which layout representation best suits them, and we would like upstream to be able to support everyone's ideas. So regarding lowering to LLVM, this PR does not prescribe a way to use the layout information to lower to LLVM; that will be user dependent. That being said, I am a proponent of a specific form of the layout representation that can be used to lower to LLVM. For more details on how that works, as well as IR examples, please see the following slides and videos (the work in them is not present in this PR; they simply show how one could construct a vector layout and use it):
https://github.com/nod-ai/techtalks/blob/main/high_dimensional_layout_flash_attn_harsh_23.pdf
https://www.youtube.com/watch?v=fqeyTCmqO4g
https://www.youtube.com/watch?v=5i7xrBUCD38

Regarding whether this should be a property of the type or the operations that support these swizzles -
I believe this should be a property of the type but it stems from constraints/requirements on the operations. While my work is currently mostly driven by GPU/SIMT requirements, I believe this layout is more general than that and could be used for other execution models as well.

Regarding a new vector-with-layout type, that is an interesting idea. However, my concern would be the scope of the changes required to have all the vector, arithmetic, etc. ops support this new type. Also, this new type would be a superset of the existing vector type, so it would seem like a lot of duplication.

Hope the links above provide you with enough IR examples. If not, happy to get on a GVC to give more info.

Happy to continue the conversation anywhere. I am hoping that this is not a big change: the goal is that everyone can continue using vector ops and types as usual, with extra work required only from those who want to use this layout.

@MaheshRavishankar
Contributor

The way I see it, this is part of the type (like memory space is part of the memref type). For most transformations this should be a pass-through. Some transformations (in-tree or out-of-tree) can implement transformations that use these attributes. As Harsh mentioned, I don't expect anything to work directly on this attribute type, but this is the parent class for all derived attribute types for specific use cases.
There is going to be the issue of making sure transformations don't drop the attribute from the type, but that is probably something worth doing anyway.
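
The pass-through concern can be illustrated with a standalone C++ sketch that does not use MLIR. It contrasts a rebuild that copies the whole type and changes one field with a rebuild from individual fields that silently loses the layout. `VecTy`, `VecTyBuilder`, and both helper functions are hypothetical, and the layout is kept as opaque text; the example changes only the element type, since a shape change could invalidate the layout rather than pass it through.

```cpp
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Hypothetical model of a vector type with an optional layout.
struct VecTy {
  std::vector<int64_t> shape;
  std::string elementType;
  std::string layout; // empty means "no layout"
};

// Hypothetical builder that starts from an existing type, so fields the
// caller does not touch (including the layout) are carried over.
class VecTyBuilder {
public:
  explicit VecTyBuilder(VecTy other) : ty(std::move(other)) {}

  VecTyBuilder &setElementType(std::string elementType) {
    ty.elementType = std::move(elementType);
    return *this;
  }

  VecTy build() const { return ty; }

private:
  VecTy ty;
};

// Layout-agnostic transformation: change the element type and pass the
// remaining fields through unchanged.
VecTy withElementType(const VecTy &ty, const std::string &elementType) {
  return VecTyBuilder(ty).setElementType(elementType).build();
}

// Anti-pattern: rebuilding from individual fields drops the layout,
// because this code was written before the field existed.
VecTy withElementTypeDroppingLayout(const VecTy &ty,
                                    const std::string &elementType) {
  return VecTy{ty.shape, elementType, ""};
}
```

In this model, a transformation written against the builder preserves the layout without knowing about it, while a transformation that enumerates fields must be updated whenever a new field is added.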

- Fix out of date comments
- More generic error message on layout verification failure
- Add roundtrip test
- More generic description of layout
@joker-eph
Collaborator

> Does that mean that every vector transformation is now responsible for maintaining and propagating this attribute? That seems like a fairly fundamental change; do we need to make sure the community is on board with that? What's going to be the expected behavior of transformations regarding this new attribute in patterns handling vectors?

Same concern here: we need much more consideration on all this, this deserves an RFC.

Collaborator

@lattner lattner left a comment

This patch extends the Vector type but does not update any of the existing clients to handle it (defensively or otherwise). It'd be great to come to consensus on whether extending this at all is a good idea:
https://discourse.llvm.org/t/rfc-remove-arith-math-ops-on-tensors/74357/44

If it is a good idea, then all the existing clients should be updated.

@nicolasvasilache
Contributor

This needs its own RFC; I too would prefer to see something less intrusive with room for evolution.

One possible path forward is to start from something like a FragmentedVectorType that is not automatically a builtin VectorType and interfaces that define the semantics of such fragments. The existing VectorType is a particular limit case with a single fragment.

Over time, ops, patterns and transformations can be ported to VectorTypeOrFragmentedVectorType.
At some point in the future, if we get to full convergence, great; if not, that's fine too: it is unclear whether it makes sense to have everything as a fragmented vector.

One big question is whether the number of fragments and the data paths to convert between different such "layouts" are statically known, but the system should be designed such that we can ask this question later (e.g. if ? quantities in vectors ever become a thing).
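
One way to picture the "single fragment as a limit case" idea is the following standalone C++17 sketch. It is only a toy model under stated assumptions: `PlainVector`, `FragmentedVector`, `VectorOrFragmented`, and `numFragments` are hypothetical names, and a fragmented vector is assumed to be split into equally sized fragments whose shape divides the logical shape evenly. Nothing here reflects an actual MLIR design.

```cpp
#include <cstdint>
#include <variant>
#include <vector>

// Hypothetical model of the existing vector type: one fragment only.
struct PlainVector {
  std::vector<int64_t> shape;
};

// Hypothetical fragmented vector: the logical shape plus the shape of the
// piece held by one fragment (assumed to divide the logical shape evenly).
struct FragmentedVector {
  std::vector<int64_t> shape;
  std::vector<int64_t> fragmentShape;
};

// Code ported to accept either kind would take this variant.
using VectorOrFragmented = std::variant<PlainVector, FragmentedVector>;

// Total number of elements described by a shape.
int64_t numElements(const std::vector<int64_t> &shape) {
  int64_t count = 1;
  for (int64_t dim : shape)
    count *= dim;
  return count;
}

// A plain vector is treated as the limit case with exactly one fragment.
int64_t numFragments(const VectorOrFragmented &type) {
  if (const auto *fragmented = std::get_if<FragmentedVector>(&type))
    return numElements(fragmented->shape) /
           numElements(fragmented->fragmentShape);
  return 1;
}
```

Here the fragment count is computed from static shapes, which corresponds to the "statically known" case raised above; a design that allowed dynamic quantities would need a different query.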
