New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Add optional layout attribute to VectorType #71916
base: main
Are you sure you want to change the base?
Conversation
This patch adds an attribute interface for representing the layout on vector types. This layout could be used to represent the mapping from the vector indices to the indices of the vector fragments held by different threads of a GPU. The interface has a verify function that can be used to validate that the layout accurately represents the vector shape.
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-ods Author: None (harsh-nod) ChangesThis patch adds an attribute interface for representing the layout on vector types. This layout could be used to represent the mapping from the vector indices to the indices of the vector fragments held by different threads of a GPU. The interface has a verify function that can be used to validate that the layout accurately represents the vector shape. Full diff: https://github.com/llvm/llvm-project/pull/71916.diff 8 Files Affected:
diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
index c741db9b47f34e5..9241cac8c3b98a0 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -495,4 +495,28 @@ def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
];
}
+//===----------------------------------------------------------------------===//
+// VectorLayoutAttrInterface
+//===----------------------------------------------------------------------===//
+
+def VectorLayoutAttrInterface : AttrInterface<"VectorLayoutAttrInterface"> {
+ let cppNamespace = "::mlir";
+
+ let description = [{
+ This interface is used for attributes that can represent the Vector type's
+ layout semantics, such as being able to map the vector indices to those
+ of the vector fragments held by individiual threads.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ "Check if the current layout is applicable to the provided shape",
+ "::mlir::LogicalResult", "verifyLayout",
+ (ins "::llvm::ArrayRef<int64_t>":$shape,
+ "::mlir::Type":$elementType,
+ "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError)
+ >
+ ];
+}
+
#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 92ce053ad5c829b..a387390b38e7de4 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -307,12 +307,14 @@ class VectorType::Builder {
/// Build from another VectorType.
explicit Builder(VectorType other)
: elementType(other.getElementType()), shape(other.getShape()),
- scalableDims(other.getScalableDims()) {}
+ scalableDims(other.getScalableDims()), layout(other.getLayout()) {}
/// Build from scratch.
Builder(ArrayRef<int64_t> shape, Type elementType,
- ArrayRef<bool> scalableDims = {})
- : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
+ ArrayRef<bool> scalableDims = {},
+ VectorLayoutAttrInterface layout = {})
+ : elementType(elementType), shape(shape), scalableDims(scalableDims),
+ layout(layout) {}
Builder &setShape(ArrayRef<int64_t> newShape,
ArrayRef<bool> newIsScalableDim = {}) {
@@ -342,6 +344,11 @@ class VectorType::Builder {
return *this;
}
+ Builder &setLayout(VectorLayoutAttrInterface newLayout) {
+ layout = newLayout;
+ return *this;
+ }
+
operator VectorType() {
return VectorType::get(shape, elementType, scalableDims);
}
@@ -350,6 +357,7 @@ class VectorType::Builder {
Type elementType;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
+ VectorLayoutAttrInterface layout;
};
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 5ec986ac26de06b..3a2193d7a1768f7 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1029,11 +1029,13 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
Syntax:
```
- vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
+ vector-type ::= `vector` `<` vector-dim-list vector-element-type
+ (`,` layout-specification)? `>`
vector-element-type ::= float-type | integer-type | index-type
vector-dim-list := (static-dim-list `x`)?
static-dim-list ::= static-dim (`x` static-dim)*
static-dim ::= (decimal-literal | `[` decimal-literal `]`)
+ layout-specification ::= attribute-value
```
The vector type represents a SIMD style vector used by target-specific
@@ -1050,6 +1052,14 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
declarations, `vector<0x42xi32>` is invalid because it is interpreted as a
2D vector with shape `(0, 42)` and zero shapes are not allowed.
+ ##### Layout
+
+ A vector may optionally have a layout that indicates how indices of
+ the vector are transformed to indices of the vector fragments that
+ are held by individual threads in a SIMT execution model. Such layouts
+ are common in a wide variety of GPU matrix multiplication instructions.
+ The layout can be any attribute that implements `VectorLayoutAttrInterface`.
+
Examples:
```mlir
@@ -1068,17 +1078,20 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
// A 3D mixed fixed/scalable vector in which only the inner dimension is
// scalable.
vector<2x[4]x8xf32>
+
```
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType,
- ArrayRefParameter<"bool">:$scalableDims
+ ArrayRefParameter<"bool">:$scalableDims,
+ "VectorLayoutAttrInterface":$layout
);
let builders = [
TypeBuilderWithInferredContext<(ins
"ArrayRef<int64_t>":$shape, "Type":$elementType,
- CArg<"ArrayRef<bool>", "{}">:$scalableDims
+ CArg<"ArrayRef<bool>", "{}">:$scalableDims,
+ CArg<"VectorLayoutAttrInterface", "{}">:$layout
), [{
// While `scalableDims` is optional, its default value should be
// `false` for every dim in `shape`.
@@ -1087,7 +1100,7 @@ def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> {
isScalableVec.resize(shape.size(), false);
scalableDims = isScalableVec;
}
- return $_get(elementType.getContext(), shape, elementType, scalableDims);
+ return $_get(elementType.getContext(), shape, elementType, scalableDims, layout);
}]>
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index adaefb78172c2ea..be8c84ee74e8688 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -459,14 +459,38 @@ VectorType Parser::parseVectorType() {
// Parse the element type.
auto typeLoc = getToken().getLoc();
auto elementType = parseType();
- if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
+ if (!elementType)
return nullptr;
if (!VectorType::isValidElementType(elementType))
return emitError(typeLoc, "vector elements must be int/index/float type"),
nullptr;
- return VectorType::get(dimensions, elementType, scalableDims);
+ VectorLayoutAttrInterface layout;
+ auto parseElt = [&]() -> ParseResult {
+ Attribute attr = parseAttribute();
+ if (!attr)
+ return failure();
+ if (isa<VectorLayoutAttrInterface>(attr)) {
+ layout = cast<VectorLayoutAttrInterface>(attr);
+ }
+ return success();
+ };
+
+ // Parse a list of mappings and address space if present.
+ if (!consumeIf(Token::greater)) {
+ // Parse comma separated list of affine maps, followed by memory space.
+ if (parseToken(Token::comma, "expected ',' or '>' in vector type") ||
+ parseCommaSeparatedListUntil(Token::greater, parseElt,
+ /*allowEmptyList=*/false)) {
+ return nullptr;
+ }
+ }
+
+ if (!layout)
+ return VectorType::get(dimensions, elementType, scalableDims);
+
+ return VectorType::get(dimensions, elementType, scalableDims, layout);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index dae7fdd40b5456c..458140b0de81c50 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2572,6 +2572,11 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
os << 'x';
}
printType(vectorTy.getElementType());
+ VectorLayoutAttrInterface layout = vectorTy.getLayout();
+ if (layout) {
+ os << ", ";
+ printAttribute(vectorTy.getLayout(), AttrTypeElision::May);
+ }
os << '>';
})
.Case<RankedTensorType>([&](RankedTensorType tensorTy) {
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index a9284d5714637bc..b0ebec169e3c841 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -227,7 +227,8 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
- ArrayRef<bool> scalableDims) {
+ ArrayRef<bool> scalableDims,
+ VectorLayoutAttrInterface layout) {
if (!isValidElementType(elementType))
return emitError()
<< "vector elements must be int/index/float type but got "
@@ -242,6 +243,11 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "number of dims must match, got "
<< scalableDims.size() << " and " << shape.size();
+ if (layout) {
+ if (failed(layout.verifyLayout(shape, elementType, emitError)))
+ return emitError() << "layout does not match underlying vector shape";
+ }
+
return success();
}
diff --git a/mlir/unittests/Interfaces/CMakeLists.txt b/mlir/unittests/Interfaces/CMakeLists.txt
index d192b2922d6b9dc..45a518f8549a3c6 100644
--- a/mlir/unittests/Interfaces/CMakeLists.txt
+++ b/mlir/unittests/Interfaces/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_unittest(MLIRInterfacesTests
DataLayoutInterfacesTest.cpp
InferIntRangeInterfaceTest.cpp
InferTypeOpInterfaceTest.cpp
+ VectorLayoutInterfaceTest.cpp
)
target_link_libraries(MLIRInterfacesTests
diff --git a/mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp b/mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp
new file mode 100644
index 000000000000000..0ea9a710138e3a4
--- /dev/null
+++ b/mlir/unittests/Interfaces/VectorLayoutInterfaceTest.cpp
@@ -0,0 +1,158 @@
+//===-- VectorLayoutInterfaceTest.cpp - Unit Tests for Vector Layouts -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Parser/Parser.h"
+
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace mlir::detail;
+
+class NamedStridedLayoutAttrStorage : public AttributeStorage {
+public:
+ using KeyTy =
+ std::tuple<ArrayRef<std::string>, ArrayRef<int64_t>, ArrayRef<int64_t>>;
+
+ NamedStridedLayoutAttrStorage(ArrayRef<std::string> names,
+ ArrayRef<int64_t> strides,
+ ArrayRef<int64_t> vectorShape)
+ : names(names), strides(strides), vectorShape(vectorShape) {}
+
+ bool operator==(const KeyTy &key) const {
+ return (std::get<0>(key) == names) && (std::get<1>(key) == strides) &&
+ (std::get<2>(key) == vectorShape);
+ }
+
+ static NamedStridedLayoutAttrStorage *
+ construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+ ArrayRef<std::string> names = allocator.copyInto(std::get<0>(key));
+ ArrayRef<int64_t> strides = allocator.copyInto(std::get<1>(key));
+ ArrayRef<int64_t> vectorShape = allocator.copyInto(std::get<2>(key));
+ return new (allocator.allocate<NamedStridedLayoutAttrStorage>())
+ NamedStridedLayoutAttrStorage(names, strides, vectorShape);
+ }
+
+ ArrayRef<std::string> names;
+ ArrayRef<int64_t> strides;
+ ArrayRef<int64_t> vectorShape;
+};
+
+struct NamedStridedLayoutAttr
+ : public Attribute::AttrBase<NamedStridedLayoutAttr, Attribute,
+ NamedStridedLayoutAttrStorage,
+ VectorLayoutAttrInterface::Trait> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NamedStridedLayoutAttr)
+ using Base::Base;
+ static NamedStridedLayoutAttr get(MLIRContext *ctx,
+ ArrayRef<std::string> names,
+ ArrayRef<int64_t> strides,
+ ArrayRef<int64_t> vectorShape) {
+ return Base::get(ctx, names, strides, vectorShape);
+ }
+
+ LogicalResult verifyLayout(ArrayRef<int64_t> shape, Type elementType,
+ function_ref<InFlightDiagnostic()> emitError) {
+ if (shape == getVectorShape())
+ return success();
+ return failure();
+ }
+
+ ArrayRef<std::string> getNames() { return getImpl()->names; }
+ ArrayRef<int64_t> getStrides() { return getImpl()->strides; }
+ ArrayRef<int64_t> getVectorShape() { return getImpl()->vectorShape; }
+};
+
+struct VLTestDialect : Dialect {
+ explicit VLTestDialect(MLIRContext *ctx)
+ : Dialect(getDialectNamespace(), ctx, TypeID::get<VLTestDialect>()) {
+ ctx->loadDialect<VLTestDialect>();
+ addAttributes<NamedStridedLayoutAttr>();
+ }
+ static StringRef getDialectNamespace() { return "vltest"; }
+
+ Attribute parseAttribute(DialectAsmParser &parser, Type type) const override {
+ SmallVector<int64_t> strides, vectorShape;
+ SmallVector<std::string> names;
+ if (!succeeded(parser.parseKeyword("named_strided_layout")))
+ return {};
+ if (!succeeded(parser.parseLess()))
+ return {};
+ do {
+ if (!succeeded(parser.parseLSquare()))
+ return {};
+ std::string name;
+ int64_t stride;
+ int64_t shape = 1;
+ do {
+ if (succeeded(parser.parseString(&name)) &&
+ succeeded(parser.parseColon()) &&
+ succeeded(parser.parseInteger(stride))) {
+ names.push_back(name);
+ strides.push_back(stride);
+ shape *= stride;
+ }
+ } while (succeeded(parser.parseOptionalComma()));
+ if (!succeeded(parser.parseRSquare()))
+ return {};
+ vectorShape.push_back(shape);
+ } while (succeeded(parser.parseOptionalComma()));
+ if (!succeeded(parser.parseGreater()))
+ return {};
+ return NamedStridedLayoutAttr::get(parser.getContext(), names, strides,
+ vectorShape);
+ }
+};
+
+TEST(VectorLayoutAttrInterface, NamedStridedLayout) {
+ const char *ir = R"MLIR(
+ #layout = #vltest.named_strided_layout<["BatchX" : 2, "LaneX" : 4, "VectorX" : 2],
+ ["BatchY" : 1, "LaneY" : 8, "VectorY" : 2]>
+ %lhs = "arith.constant"() {value = dense<0.0> : vector<16x16xf16, #layout>}
+ : () -> (vector<16x16xf16, #layout>)
+ )MLIR";
+
+ DialectRegistry registry;
+ registry.insert<VLTestDialect, func::FuncDialect, arith::ArithDialect>();
+ MLIRContext ctx(registry);
+ OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(ir, &ctx);
+
+ arith::ConstantOp op =
+ llvm::cast<arith::ConstantOp>(module->getBody()->getOperations().front());
+ Type type = op.getResult().getType();
+ if (auto vectorType = llvm::cast<VectorType>(type)) {
+ VectorLayoutAttrInterface layout = vectorType.getLayout();
+ auto namedStridedLayout = llvm::cast<NamedStridedLayoutAttr>(layout);
+ ArrayRef<std::string> names = namedStridedLayout.getNames();
+ ArrayRef<int64_t> strides = namedStridedLayout.getStrides();
+ ArrayRef<int64_t> vectorShape = namedStridedLayout.getVectorShape();
+ EXPECT_EQ(vectorShape.size(), 2u);
+ EXPECT_EQ(vectorShape[0], 16u);
+ EXPECT_EQ(vectorShape[1], 16u);
+ EXPECT_EQ(strides.size(), 6u);
+ EXPECT_EQ(strides[0], 2u);
+ EXPECT_EQ(strides[1], 4u);
+ EXPECT_EQ(strides[2], 2u);
+ EXPECT_EQ(strides[3], 1u);
+ EXPECT_EQ(strides[4], 8u);
+ EXPECT_EQ(strides[5], 2u);
+ EXPECT_EQ(names.size(), 6u);
+ EXPECT_EQ(names[0], "BatchX");
+ EXPECT_EQ(names[1], "LaneX");
+ EXPECT_EQ(names[2], "VectorX");
+ EXPECT_EQ(names[3], "BatchY");
+ EXPECT_EQ(names[4], "LaneY");
+ EXPECT_EQ(names[5], "VectorY");
+ }
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code looks ok to me, the abstraction level to keep it implementation defined, too. Should be fine for the hands-off approach we're having now.
I wonder if we want to have some implementation upstream, though. Sure, this would be for another PR, but how many upstream passes are we going to disable because they don't understand the downstream layout encoding?
I'll leave for others to finally approve the PR, to give them time to review.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does that mean that every vector transformation is now responsible for maintaining and propagating this attribute? That seems like a fairly fundamental change, do we need to make sure the community is onboard with that? What's going to be the expected behavior of transformations regarding this new attribute in patterns handling vectors?
mlir/lib/AsmParser/TypeParser.cpp
Outdated
// Parse a list of mappings and address space if present. | ||
if (!consumeIf(Token::greater)) { | ||
// Parse comma separated list of affine maps, followed by memory space. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
those comments look out of date.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, will update.
mlir/lib/IR/BuiltinTypes.cpp
Outdated
@@ -242,6 +243,11 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
return emitError() << "number of dims must match, got " | |||
<< scalableDims.size() << " and " << shape.size(); | |||
|
|||
if (layout) { | |||
if (failed(layout.verifyLayout(shape, elementType, emitError))) | |||
return emitError() << "layout does not match underlying vector shape"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the error message should be more generic as the mismatch may be in the shape or somewhere else.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, will change to something more generic like "Layout verification failed!"
}; | ||
|
||
TEST(VectorLayoutAttrInterface, NamedStridedLayout) { | ||
const char *ir = R"MLIR( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you also add a lit test to test round trip printing/parsing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure will add a roundtrip test. Might have to keep it as a unit test since the attribute and dialect are only defined in this cpp file.
mlir/include/mlir/IR/BuiltinTypes.td
Outdated
A vector may optionally have a layout that indicates how indices of | ||
the vector are transformed to indices of the vector fragments that | ||
are held by individual threads in a SIMT execution model. Such layouts | ||
are common in a wide variety of GPU matrix multiplication instructions. | ||
The layout can be any attribute that implements `VectorLayoutAttrInterface`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the description should probably be more generic than that to allow for more usage of the layout. (not only GPU and mapping to SIMT registers)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, will rewrite this to make it more general. I think the primary motivation for this was GPU/SIMT but people can use this in different ways.
Hey Hash, thanks for the proposal. This is an interesting change! Some initial questions that come to mind:
This looks like a big change so I would give it enough time to settle. Perhaps we should discuss some of these questions around an RFC? |
I don't think every vector transformation should be responsible for maintaining and propagating this attribute. The existing vector transformations and patterns can safely ignore this attribute which hopefully reduces the scope of this change. |
So this PR is un-opinionated in terms of how the vector layout can be represented. The idea behind that is that many people will have a preference on what layout representation best suits them and we would like upstream to be able to support everyone's ideas. So regarding lowering to LLVM, this PR does not prescribe a way to utilize the layout information to lower to LLVM. That will be user dependent. That being said, I am a proponent of a specific form of the layout representation that can be used to lower to LLVM. For more details on how that works as well as IR examples, please see the following slides and videos (the work in these slides and videos is not present in this PR, it simply shows how one could construct a vector layout and use it)- Regarding whether this should be a property of the type or the operations that support these swizzles - Regarding a new vector with layout type, that is an interesting idea. However, my concerns would be the scope of the changes required to have all the vector, arithmetic etc. ops support this new type. Also, this new type would be a superset of the existing vector type so would seem like a lot of duplication. Hope the links above provide you with enough IR examples. If not, happy to get on a GVC to give more info. Happy to continue the conversation anywhere. I am hoping that this is not a big change with the goal that everyone should hopefully be able to continue using vector ops and types as usual with extra work required from those who want to use this layout. |
The way I see it, this is part of the type (like memory space is part of memref type). For most transformation this should be a pass through. Some transformations (in-tree or out-of three) can implement transformations that use these attributes. As Harsh mentioned I dont expect anything to directly work on this Attribute type, but this is the parent class for all derived attribute types for specific use cases.. |
- Fix out of date comments - More generic error message on layout verification failure - Add roundtrip test - More generic description of layout
cdefb68
to
5058923
Compare
Same concern here: we need much more consideration on all this, this deserves an RFC. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This patch extends the Vector type but does not update any of the existing clients to handle it (defensively or otherwise). It'd be great to come to consensus of whether extending this at all is a good idea:
https://discourse.llvm.org/t/rfc-remove-arith-math-ops-on-tensors/74357/44
if it is a good idea, then all the existing clients should be updated.
This needs its own RFC, I too would prefer to see something less intrusive with room for evolution. One possible path forward is to start from something like a FragmentedVectorType that is not automatically a builtin VectorType and interfaces that define the semantics of such fragments. The existing VectorType is a particular limit case with a single fragment. Over time, ops, patterns and transformations can be ported to One big question is whether the number of fragments and the data paths to convert between different such "layouts" are statically known but the system should be designed such that we can ask this question later (e.g. if |
This patch adds an attribute interface for representing the layout on vector types. This layout could be used to represent the mapping from the vector indices to the indices of the vector fragments held by different threads of a GPU.
The interface has a verify function that can be used to validate that the layout accurately represents the vector shape.