Skip to content

Commit 20b93ab

Browse files
committed
Update ODS variadic segments "magic" attributes to use native Properties
The operand_segment_sizes and result_segment_sizes Attributes are now inlined in the operation as native propertie. We continue to support building an Attribute on the fly for `getAttr("operand_segment_sizes")` and setting the property from an attribute with `setAttr("operand_segment_sizes", attr)`. A new bytecode version is introduced to support backward compatibility and backdeployments. Differential Revision: https://reviews.llvm.org/D155919
1 parent 705fb08 commit 20b93ab

File tree

20 files changed

+542
-173
lines changed

20 files changed

+542
-173
lines changed

mlir/include/mlir/Bytecode/BytecodeImplementation.h

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/IR/DialectInterface.h"
2121
#include "mlir/IR/OpImplementation.h"
2222
#include "mlir/Support/LogicalResult.h"
23+
#include "llvm/ADT/STLExtras.h"
2324
#include "llvm/ADT/Twine.h"
2425

2526
namespace mlir {
@@ -39,6 +40,9 @@ class DialectBytecodeReader {
3940
/// Emit an error to the reader.
4041
virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0;
4142

43+
/// Return the bytecode version being read.
44+
virtual uint64_t getBytecodeVersion() const = 0;
45+
4246
/// Read out a list of elements, invoking the provided callback for each
4347
/// element. The callback function may be in any of the following forms:
4448
/// * LogicalResult(T &)
@@ -148,6 +152,76 @@ class DialectBytecodeReader {
148152
[this](int64_t &value) { return readSignedVarInt(value); });
149153
}
150154

155+
/// Parse a variable length encoded integer whose low bit is used to encode an
156+
/// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
157+
LogicalResult readVarIntWithFlag(uint64_t &result, bool &flag) {
158+
if (failed(readVarInt(result)))
159+
return failure();
160+
flag = result & 1;
161+
result >>= 1;
162+
return success();
163+
}
164+
165+
/// Read a "small" sparse array of integer <= 32 bits elements, where
166+
/// index/value pairs can be compressed when the array is small.
167+
/// Note that only some position of the array will be read and the ones
168+
/// not stored in the bytecode are gonne be left untouched.
169+
/// If the provided array is too small for the stored indices, an error
170+
/// will be returned.
171+
template <typename T>
172+
LogicalResult readSparseArray(MutableArrayRef<T> array) {
173+
static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits");
174+
static_assert(std::is_integral<T>::value, "expects integer");
175+
uint64_t nonZeroesCount;
176+
bool useSparseEncoding;
177+
if (failed(readVarIntWithFlag(nonZeroesCount, useSparseEncoding)))
178+
return failure();
179+
if (nonZeroesCount == 0)
180+
return success();
181+
if (!useSparseEncoding) {
182+
// This is a simple dense array.
183+
if (nonZeroesCount > array.size()) {
184+
emitError("trying to read an array of ")
185+
<< nonZeroesCount << " but only " << array.size()
186+
<< " storage available.";
187+
return failure();
188+
}
189+
for (int64_t index : llvm::seq<int64_t>(0, nonZeroesCount)) {
190+
uint64_t value;
191+
if (failed(readVarInt(value)))
192+
return failure();
193+
array[index] = value;
194+
}
195+
return success();
196+
}
197+
// Read sparse encoding
198+
// This is the number of bits used for packing the index with the value.
199+
uint64_t indexBitSize;
200+
if (failed(readVarInt(indexBitSize)))
201+
return failure();
202+
constexpr uint64_t maxIndexBitSize = 8;
203+
if (indexBitSize > maxIndexBitSize) {
204+
emitError("reading sparse array with indexing above 8 bits: ")
205+
<< indexBitSize;
206+
return failure();
207+
}
208+
for (uint32_t count : llvm::seq<uint32_t>(0, nonZeroesCount)) {
209+
(void)count;
210+
uint64_t indexValuePair;
211+
if (failed(readVarInt(indexValuePair)))
212+
return failure();
213+
uint64_t index = indexValuePair & ~(uint64_t(-1) << (indexBitSize));
214+
uint64_t value = indexValuePair >> indexBitSize;
215+
if (index >= array.size()) {
216+
emitError("reading a sparse array found index ")
217+
<< index << " but only " << array.size() << " storage available.";
218+
return failure();
219+
}
220+
array[index] = value;
221+
}
222+
return success();
223+
}
224+
151225
/// Read an APInt that is known to have been encoded with the given width.
152226
virtual FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) = 0;
153227

@@ -230,6 +304,55 @@ class DialectBytecodeWriter {
230304
writeList(value, [this](int64_t value) { writeSignedVarInt(value); });
231305
}
232306

307+
/// Write a VarInt and a flag packed together.
308+
void writeVarIntWithFlag(uint64_t value, bool flag) {
309+
writeVarInt((value << 1) | (flag ? 1 : 0));
310+
}
311+
312+
/// Write out a "small" sparse array of integer <= 32 bits elements, where
313+
/// index/value pairs can be compressed when the array is small. This method
314+
/// will scan the array multiple times and should not be used for large
315+
/// arrays. The optional provided "zero" can be used to adjust for the
316+
/// expected repeated value. We assume here that the array size fits in a 32
317+
/// bits integer.
318+
template <typename T>
319+
void writeSparseArray(ArrayRef<T> array) {
320+
static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits");
321+
static_assert(std::is_integral<T>::value, "expects integer");
322+
uint32_t size = array.size();
323+
uint32_t nonZeroesCount = 0, lastIndex = 0;
324+
for (uint32_t index : llvm::seq<uint32_t>(0, size)) {
325+
if (!array[index])
326+
continue;
327+
nonZeroesCount++;
328+
lastIndex = index;
329+
}
330+
// If the last position is too large, or the array isn't at least 50%
331+
// sparse, emit it with a dense encoding.
332+
if (lastIndex > 256 || nonZeroesCount > size / 2) {
333+
// Emit the array size and a flag which indicates whether it is sparse.
334+
writeVarIntWithFlag(size, false);
335+
for (const T &elt : array)
336+
writeVarInt(elt);
337+
return;
338+
}
339+
// Emit sparse: first the number of elements we'll write and a flag
340+
// indicating it is a sparse encoding.
341+
writeVarIntWithFlag(nonZeroesCount, true);
342+
if (nonZeroesCount == 0)
343+
return;
344+
// This is the number of bits used for packing the index with the value.
345+
int indexBitSize = llvm::Log2_32_Ceil(lastIndex + 1);
346+
writeVarInt(indexBitSize);
347+
for (uint32_t index : llvm::seq<uint32_t>(0, lastIndex + 1)) {
348+
T value = array[index];
349+
if (!value)
350+
continue;
351+
uint64_t indexValuePair = (value << indexBitSize) | (index);
352+
writeVarInt(indexValuePair);
353+
}
354+
}
355+
233356
/// Write an APInt to the bytecode stream whose bitwidth will be known
234357
/// externally at read time. This method is useful for encoding APInt values
235358
/// when the width is known via external means, such as via a type. This

mlir/include/mlir/Bytecode/Encoding.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,12 @@ enum BytecodeVersion {
4545
/// with the discardable attributes.
4646
kNativePropertiesEncoding = 5,
4747

48+
/// ODS emits operand/result segment_size as native properties instead of
49+
/// an attribute.
50+
kNativePropertiesODSSegmentSize = 6,
51+
4852
/// The current bytecode version.
49-
kVersion = 5,
53+
kVersion = 6,
5054

5155
/// An arbitrary value used to fill alignment padding.
5256
kAlignmentByte = 0xCB,

mlir/include/mlir/IR/ODSSupport.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@ Attribute convertToAttribute(MLIRContext *ctx, int64_t storage);
3737
LogicalResult convertFromAttribute(MutableArrayRef<int64_t> storage,
3838
Attribute attr, InFlightDiagnostic *diag);
3939

40+
/// Convert a DenseI32ArrayAttr to the provided storage. It is expected that the
41+
/// storage has the same size as the array. An error is returned if the
42+
/// attribute isn't a DenseI32ArrayAttr or it does not have the same size. If
43+
/// the optional diagnostic is provided an error message is also emitted.
44+
LogicalResult convertFromAttribute(MutableArrayRef<int32_t> storage,
45+
Attribute attr, InFlightDiagnostic *diag);
46+
4047
/// Convert the provided ArrayRef<int64_t> to a DenseI64ArrayAttr attribute.
4148
Attribute convertToAttribute(MLIRContext *ctx, ArrayRef<int64_t> storage);
4249

mlir/include/mlir/IR/OpBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,6 +1241,7 @@ class ArrayProperty<string storageTypeParam = "", int n, string desc = ""> :
12411241
let interfaceType = "::llvm::ArrayRef<" # storageTypeParam # ">";
12421242
let convertFromStorage = "$_storage";
12431243
let assignToStorage = "::llvm::copy($_value, $_storage)";
1244+
let hashProperty = "llvm::hash_combine_range(std::begin($_storage), std::end($_storage));";
12441245
}
12451246

12461247
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/OpDefinition.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#define MLIR_IR_OPDEFINITION_H
2121

2222
#include "mlir/IR/Dialect.h"
23+
#include "mlir/IR/ODSSupport.h"
2324
#include "mlir/IR/Operation.h"
2425
#include "llvm/Support/PointerLikeTypeTraits.h"
2526

mlir/include/mlir/IR/OperationSupport.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,8 @@ class RegisteredOperationName : public OperationName {
555555
StringRef name) final {
556556
if constexpr (hasProperties) {
557557
auto concreteOp = cast<ConcreteOp>(op);
558-
return ConcreteOp::getInherentAttr(concreteOp.getProperties(), name);
558+
return ConcreteOp::getInherentAttr(concreteOp.getContext(),
559+
concreteOp.getProperties(), name);
559560
}
560561
// If the op does not have support for properties, we dispatch back to the
561562
// dictionnary of discardable attributes for now.
@@ -575,7 +576,8 @@ class RegisteredOperationName : public OperationName {
575576
void populateInherentAttrs(Operation *op, NamedAttrList &attrs) final {
576577
if constexpr (hasProperties) {
577578
auto concreteOp = cast<ConcreteOp>(op);
578-
ConcreteOp::populateInherentAttrs(concreteOp.getProperties(), attrs);
579+
ConcreteOp::populateInherentAttrs(concreteOp.getContext(),
580+
concreteOp.getProperties(), attrs);
579581
}
580582
}
581583
LogicalResult

mlir/include/mlir/TableGen/Property.h

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,51 +35,76 @@ class Property {
3535
public:
3636
explicit Property(const llvm::Record *record);
3737
explicit Property(const llvm::DefInit *init);
38+
Property(StringRef storageType, StringRef interfaceType,
39+
StringRef convertFromStorageCall, StringRef assignToStorageCall,
40+
StringRef convertToAttributeCall, StringRef convertFromAttributeCall,
41+
StringRef readFromMlirBytecodeCall,
42+
StringRef writeToMlirBytecodeCall, StringRef hashPropertyCall,
43+
StringRef defaultValue);
3844

3945
// Returns the storage type.
40-
StringRef getStorageType() const;
46+
StringRef getStorageType() const { return storageType; }
4147

4248
// Returns the interface type for this property.
43-
StringRef getInterfaceType() const;
49+
StringRef getInterfaceType() const { return interfaceType; }
4450

4551
// Returns the template getter method call which reads this property's
4652
// storage and returns the value as of the desired return type.
47-
StringRef getConvertFromStorageCall() const;
53+
StringRef getConvertFromStorageCall() const { return convertFromStorageCall; }
4854

4955
// Returns the template setter method call which reads this property's
5056
// in the provided interface type and assign it to the storage.
51-
StringRef getAssignToStorageCall() const;
57+
StringRef getAssignToStorageCall() const { return assignToStorageCall; }
5258

5359
// Returns the conversion method call which reads this property's
5460
// in the storage type and builds an attribute.
55-
StringRef getConvertToAttributeCall() const;
61+
StringRef getConvertToAttributeCall() const { return convertToAttributeCall; }
5662

5763
// Returns the setter method call which reads this property's
5864
// in the provided interface type and assign it to the storage.
59-
StringRef getConvertFromAttributeCall() const;
65+
StringRef getConvertFromAttributeCall() const {
66+
return convertFromAttributeCall;
67+
}
6068

6169
// Returns the method call which reads this property from
6270
// bytecode and assign it to the storage.
63-
StringRef getReadFromMlirBytecodeCall() const;
71+
StringRef getReadFromMlirBytecodeCall() const {
72+
return readFromMlirBytecodeCall;
73+
}
6474

6575
// Returns the method call which write this property's
6676
// to the the bytecode.
67-
StringRef getWriteToMlirBytecodeCall() const;
77+
StringRef getWriteToMlirBytecodeCall() const {
78+
return writeToMlirBytecodeCall;
79+
}
6880

6981
// Returns the code to compute the hash for this property.
70-
StringRef getHashPropertyCall() const;
82+
StringRef getHashPropertyCall() const { return hashPropertyCall; }
7183

7284
// Returns whether this Property has a default value.
73-
bool hasDefaultValue() const;
85+
bool hasDefaultValue() const { return !defaultValue.empty(); }
86+
7487
// Returns the default value for this Property.
75-
StringRef getDefaultValue() const;
88+
StringRef getDefaultValue() const { return defaultValue; }
7689

7790
// Returns the TableGen definition this Property was constructed from.
78-
const llvm::Record &getDef() const;
91+
const llvm::Record &getDef() const { return *def; }
7992

8093
private:
8194
// The TableGen definition of this constraint.
8295
const llvm::Record *def;
96+
97+
// Elements describing a Property, in general fetched from the record.
98+
StringRef storageType;
99+
StringRef interfaceType;
100+
StringRef convertFromStorageCall;
101+
StringRef assignToStorageCall;
102+
StringRef convertToAttributeCall;
103+
StringRef convertFromAttributeCall;
104+
StringRef readFromMlirBytecodeCall;
105+
StringRef writeToMlirBytecodeCall;
106+
StringRef hashPropertyCall;
107+
StringRef defaultValue;
83108
};
84109

85110
// A struct wrapping an op property and its name together

0 commit comments

Comments
 (0)