Skip to content

Commit a518299

Browse files
committed
[mlir] Support for mutable types
Introduce support for mutable storage in the StorageUniquer infrastructure. This makes MLIR have key-value storage instead of just uniqued key storage. A storage instance now contains a unique immutable key and a mutable value, both stored in the arena allocator that belongs to the context. This is a preconditio for supporting recursive types that require delayed initialization, in particular LLVM structure types. The functionality is exercised in the test pass with trivial self-recursive type. So far, recursive types can only be printed in parsed in a closed type system. Removing this restriction is left for future work. Differential Revision: https://reviews.llvm.org/D84171
1 parent 1956cf1 commit a518299

File tree

14 files changed

+425
-18
lines changed

14 files changed

+425
-18
lines changed

mlir/docs/Tutorials/DefiningAttributesAndTypes.md

Lines changed: 128 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ namespace MyTypes {
4747
enum Kinds {
4848
// These kinds will be used in the examples below.
4949
Simple = Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE,
50-
Complex
50+
Complex,
51+
Recursive
5152
};
5253
}
5354
```
@@ -58,13 +59,17 @@ As described above, `Type` objects in MLIR are value-typed and rely on having an
5859
implicitly internal storage object that holds the actual data for the type. When
5960
defining a new `Type` it isn't always necessary to define a new storage class.
6061
So before defining the derived `Type`, it's important to know which of the two
61-
classes of `Type` we are defining. Some types are `primitives` meaning they do
62+
classes of `Type` we are defining. Some types are _primitives_ meaning they do
6263
not have any parameters and are singletons uniqued by kind, like the
6364
[`index` type](LangRef.md#index-type). Parametric types on the other hand, have
6465
additional information that differentiates different instances of the same
6566
`Type` kind. For example the [`integer` type](LangRef.md#integer-type) has a
6667
bitwidth, making `i8` and `i16` be different instances of
67-
[`integer` type](LangRef.md#integer-type).
68+
[`integer` type](LangRef.md#integer-type). Types can also have a mutable
69+
component, which can be used, for example, to construct self-referring recursive
70+
types. The mutable component _cannot_ be used to differentiate types within the
71+
same kind, so usually such types are also parametric where the parameters serve
72+
to identify them.
6873

6974
#### Simple non-parametric types
7075

@@ -240,6 +245,126 @@ public:
240245
};
241246
```
242247
248+
#### Types with a mutable component
249+
250+
Types with a mutable component require defining a type storage class regardless
251+
of being parametric. The storage contains both the parameters and the mutable
252+
component and is accessed in a thread-safe way by the type support
253+
infrastructure.
254+
255+
##### Defining a type storage
256+
257+
In addition to the requirements for the type storage class for parametric types,
258+
the storage class for types with a mutable component must additionally obey the
259+
following.
260+
261+
* The mutable component must not participate in the storage key.
262+
* Provide a mutation method that is used to modify an existing instance of the
263+
storage. This method modifies the mutable component based on arguments,
264+
using `allocator` for any new dynamically-allocated storage, and indicates
265+
whether the modification was successful.
266+
- `LogicalResult mutate(StorageAllocator &allocator, Args ...&& args)`
267+
268+
Let's define a simple storage for recursive types, where a type is identified by
269+
its name and can contain another type including itself.
270+
271+
```c++
272+
/// Here we define a storage class for a RecursiveType that is identified by its
273+
/// name and contains another type.
274+
struct RecursiveTypeStorage : public TypeStorage {
275+
/// The type is uniquely identified by its name. Note that the contained type
276+
/// is _not_ a part of the key.
277+
using KeyTy = StringRef;
278+
279+
/// Construct the storage from the type name. Explicitly initialize the
280+
/// containedType to nullptr, which is used as marker for the mutable
281+
/// component being not yet initialized.
282+
RecursiveTypeStorage(StringRef name) : name(name), containedType(nullptr) {}
283+
284+
/// Define the comparison function.
285+
bool operator==(const KeyTy &key) const { return key == name; }
286+
287+
/// Define a construction method for creating a new instance of the storage.
288+
static RecursiveTypeStorage *construct(StorageAllocator &allocator,
289+
const KeyTy &key) {
290+
// Note that the key string is copied into the allocator to ensure it
291+
// remains live as long as the storage itself.
292+
return new (allocator.allocate<RecursiveTypeStorage>())
293+
RecursiveTypeStorage(allocator.copyInto(key));
294+
}
295+
296+
/// Define a mutation method for changing the type after it is created. In
297+
/// many cases, we only want to set the mutable component once and reject
298+
/// any further modification, which can be achieved by returning failure from
299+
/// this function.
300+
LogicalResult mutate(StorageAllocator &, Type body) {
301+
// If the contained type has been initialized already, and the call tries
302+
// to change it, reject the change.
303+
if (containedType && containedType != body)
304+
return failure();
305+
306+
// Change the body successfully.
307+
containedType = body;
308+
return success();
309+
}
310+
311+
StringRef name;
312+
Type containedType;
313+
};
314+
```
315+
316+
##### Type class definition
317+
318+
Having defined the storage class, we can define the type class itself. This is
319+
similar to parametric types. `Type::TypeBase` provides a `mutate` method that
320+
forwards its arguments to the `mutate` method of the storage and ensures the
321+
modification happens under lock.
322+
323+
```c++
324+
class RecursiveType : public Type::TypeBase<RecursiveType, Type,
325+
RecursiveTypeStorage> {
326+
public:
327+
/// Inherit parent constructors.
328+
using Base::Base;
329+
330+
/// This static method is used to support type inquiry through isa, cast,
331+
/// and dyn_cast.
332+
static bool kindof(unsigned kind) { return kind == MyTypes::Recursive; }
333+
334+
/// Creates an instance of the Recursive type. This only takes the type name
335+
/// and returns the type with uninitialized body.
336+
static RecursiveType get(MLIRContext *ctx, StringRef name) {
337+
// Call into the base to get a uniqued instance of this type. The parameter
338+
// (name) is passed after the kind.
339+
return Base::get(ctx, MyTypes::Recursive, name);
340+
}
341+
342+
/// Now we can change the mutable component of the type. This is an instance
343+
/// method callable on an already existing RecursiveType.
344+
void setBody(Type body) {
345+
// Call into the base to mutate the type.
346+
LogicalResult result = Base::mutate(body);
347+
// Most types expect mutation to always succeed, but types can implement
348+
// custom logic for handling mutation failures.
349+
assert(succeeded(result) &&
350+
"attempting to change the body of an already-initialized type");
351+
// Avoid unused-variable warning when building without assertions.
352+
(void) result;
353+
}
354+
355+
/// Returns the contained type, which may be null if it has not been
356+
/// initialized yet.
357+
Type getBody() {
358+
return getImpl()->containedType;
359+
}
360+
361+
/// Returns the name.
362+
StringRef getName() {
363+
return getImpl()->name;
364+
}
365+
};
366+
```
367+
243368
### Registering types with a Dialect
244369
245370
Once the dialect types have been defined, they must then be registered with a

mlir/include/mlir/IR/AttributeSupport.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,13 @@ class AttributeUniquer {
139139
kind, std::forward<Args>(args)...);
140140
}
141141

142+
template <typename ImplType, typename... Args>
143+
static LogicalResult mutate(MLIRContext *ctx, ImplType *impl,
144+
Args &&...args) {
145+
assert(impl && "cannot mutate null attribute");
146+
return ctx->getAttributeUniquer().mutate(impl, std::forward<Args>(args)...);
147+
}
148+
142149
private:
143150
/// Initialize the given attribute storage instance.
144151
static void initializeAttributeStorage(AttributeStorage *storage,

mlir/include/mlir/IR/Attributes.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ struct SparseElementsAttributeStorage;
4848

4949
/// Attributes are known-constant values of operations and functions.
5050
///
51-
/// Instances of the Attribute class are references to immutable, uniqued,
52-
/// and immortal values owned by MLIRContext. As such, an Attribute is a thin
53-
/// wrapper around an underlying storage pointer. Attributes are usually passed
54-
/// by value.
51+
/// Instances of the Attribute class are references to immortal key-value pairs
52+
/// with immutable, uniqued key owned by MLIRContext. As such, an Attribute is a
53+
/// thin wrapper around an underlying storage pointer. Attributes are usually
54+
/// passed by value.
5555
class Attribute {
5656
public:
5757
/// Integer identifier for all the concrete attribute kinds.

mlir/include/mlir/IR/StorageUniquerSupport.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,14 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
105105
return UniquerT::template get<ConcreteT>(loc.getContext(), kind, args...);
106106
}
107107

108+
/// Mutate the current storage instance. This will not change the unique key.
109+
/// The arguments are forwarded to 'ConcreteT::mutate'.
110+
template <typename... Args>
111+
LogicalResult mutate(Args &&...args) {
112+
return UniquerT::mutate(this->getContext(), getImpl(),
113+
std::forward<Args>(args)...);
114+
}
115+
108116
/// Default implementation that just returns success.
109117
template <typename... Args>
110118
static LogicalResult verifyConstructionInvariants(Args... args) {

mlir/include/mlir/IR/TypeSupport.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,15 @@ struct TypeUniquer {
132132
},
133133
kind, std::forward<Args>(args)...);
134134
}
135+
136+
/// Change the mutable component of the given type instance in the provided
137+
/// context.
138+
template <typename ImplType, typename... Args>
139+
static LogicalResult mutate(MLIRContext *ctx, ImplType *impl,
140+
Args &&...args) {
141+
assert(impl && "cannot mutate null type");
142+
return ctx->getTypeUniquer().mutate(impl, std::forward<Args>(args)...);
143+
}
135144
};
136145
} // namespace detail
137146

mlir/include/mlir/IR/Types.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,17 @@ struct FunctionTypeStorage;
2727
struct OpaqueTypeStorage;
2828
} // namespace detail
2929

30-
/// Instances of the Type class are immutable and uniqued. They wrap a pointer
31-
/// to the storage object owned by MLIRContext. Therefore, instances of Type
32-
/// are passed around by value.
30+
/// Instances of the Type class are uniqued, have an immutable identifier and an
31+
/// optional mutable component. They wrap a pointer to the storage object owned
32+
/// by MLIRContext. Therefore, instances of Type are passed around by value.
3333
///
3434
/// Some types are "primitives" meaning they do not have any parameters, for
3535
/// example the Index type. Parametric types have additional information that
3636
/// differentiates the types of the same kind between them, for example the
3737
/// Integer type has bitwidth, making i8 and i16 belong to the same kind by be
38-
/// different instances of the IntegerType.
38+
/// different instances of the IntegerType. Type parameters are part of the
39+
/// unique immutable key. The mutable component of the type can be modified
40+
/// after the type is created, but cannot affect the identity of the type.
3941
///
4042
/// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
4143
///
@@ -62,6 +64,7 @@ struct OpaqueTypeStorage;
6264
/// - The type kind (for LLVM-style RTTI).
6365
/// - The dialect that defined the type.
6466
/// - Any parameters of the type.
67+
/// - An optional mutable component.
6568
/// For non-parametric types, a convenience DefaultTypeStorage is provided.
6669
/// Parametric storage types must derive TypeStorage and respect the following:
6770
/// - Define a type alias, KeyTy, to a type that uniquely identifies the
@@ -75,11 +78,14 @@ struct OpaqueTypeStorage;
7578
/// - Provide a method, 'bool operator==(const KeyTy &) const', to
7679
/// compare the storage instance against an instance of the key type.
7780
///
78-
/// - Provide a construction method:
81+
/// - Provide a static construction method:
7982
/// 'DerivedStorage *construct(TypeStorageAllocator &, const KeyTy &key)'
8083
/// that builds a unique instance of the derived storage. The arguments to
8184
/// this function are an allocator to store any uniqued data within the
8285
/// context and the key type for this storage.
86+
///
87+
/// - If they have a mutable component, this component must not be a part of
88+
// the key.
8389
class Type {
8490
public:
8591
/// Integer identifier for all the concrete type kinds.

mlir/include/mlir/Support/StorageUniquer.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_SUPPORT_STORAGEUNIQUER_H
1111

1212
#include "mlir/Support/LLVM.h"
13+
#include "mlir/Support/LogicalResult.h"
1314
#include "llvm/ADT/DenseSet.h"
1415
#include "llvm/Support/Allocator.h"
1516

@@ -60,6 +61,20 @@ using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval<T>()));
6061
/// that is called when erasing a storage instance. This should cleanup any
6162
/// fields of the storage as necessary and not attempt to free the memory
6263
/// of the storage itself.
64+
///
65+
/// Storage classes may have an optional mutable component, which must not take
66+
/// part in the unique immutable key. In this case, storage classes may be
67+
/// mutated with `mutate` and must additionally respect the following:
68+
/// - Provide a mutation method:
69+
/// 'LogicalResult mutate(StorageAllocator &, <...>)'
70+
/// that is called when mutating a storage instance. The first argument is
71+
/// an allocator to store any mutable data, and the remaining arguments are
72+
/// forwarded from the call site. The storage can be mutated at any time
73+
/// after creation. Care must be taken to avoid excessive mutation since
74+
/// the allocated storage can keep containing previous states. The return
75+
/// value of the function is used to indicate whether the mutation was
76+
/// successful, e.g., to limit the number of mutations or enable deferred
77+
/// one-time assignment of the mutable component.
6378
class StorageUniquer {
6479
public:
6580
StorageUniquer();
@@ -166,6 +181,17 @@ class StorageUniquer {
166181
return static_cast<Storage *>(getImpl(kind, ctorFn));
167182
}
168183

184+
/// Changes the mutable component of 'storage' by forwarding the trailing
185+
/// arguments to the 'mutate' function of the derived class.
186+
template <typename Storage, typename... Args>
187+
LogicalResult mutate(Storage *storage, Args &&...args) {
188+
auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult {
189+
return static_cast<Storage &>(*storage).mutate(
190+
allocator, std::forward<Args>(args)...);
191+
};
192+
return mutateImpl(mutationFn);
193+
}
194+
169195
/// Erases a uniqued instance of 'Storage'. This function is used for derived
170196
/// types that have complex storage or uniquing constraints.
171197
template <typename Storage, typename Arg, typename... Args>
@@ -206,6 +232,10 @@ class StorageUniquer {
206232
function_ref<bool(const BaseStorage *)> isEqual,
207233
function_ref<void(BaseStorage *)> cleanupFn);
208234

235+
/// Implementation for mutating an instance of a derived storage.
236+
LogicalResult
237+
mutateImpl(function_ref<LogicalResult(StorageAllocator &)> mutationFn);
238+
209239
/// The internal implementation class.
210240
std::unique_ptr<detail::StorageUniquerImpl> impl;
211241

mlir/lib/Support/StorageUniquer.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,16 @@ struct StorageUniquerImpl {
124124
storageTypes.erase(existing);
125125
}
126126

127+
/// Mutates an instance of a derived storage in a thread-safe way.
128+
LogicalResult
129+
mutate(function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
130+
if (!threadingIsEnabled)
131+
return mutationFn(allocator);
132+
133+
llvm::sys::SmartScopedWriter<true> lock(mutex);
134+
return mutationFn(allocator);
135+
}
136+
127137
//===--------------------------------------------------------------------===//
128138
// Instance Storage
129139
//===--------------------------------------------------------------------===//
@@ -214,3 +224,9 @@ void StorageUniquer::eraseImpl(unsigned kind, unsigned hashValue,
214224
function_ref<void(BaseStorage *)> cleanupFn) {
215225
impl->erase(kind, hashValue, isEqual, cleanupFn);
216226
}
227+
228+
/// Implementation for mutating an instance of a derived storage.
229+
LogicalResult StorageUniquer::mutateImpl(
230+
function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
231+
return impl->mutate(mutationFn);
232+
}

mlir/test/IR/recursive-type.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: mlir-opt %s -test-recursive-types | FileCheck %s
2+
3+
// CHECK-LABEL: @roundtrip
4+
func @roundtrip() {
5+
// CHECK: !test.test_rec<a, test_rec<b, test_type>>
6+
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<a, test_rec<b, test_type>>
7+
// CHECK: !test.test_rec<c, test_rec<c>>
8+
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec<c, test_rec<c>>
9+
return
10+
}
11+
12+
// CHECK-LABEL: @create
13+
func @create() {
14+
// CHECK: !test.test_rec<some_long_and_unique_name, test_rec<some_long_and_unique_name>>
15+
return
16+
}

0 commit comments

Comments
 (0)