2020#include " mlir/IR/DialectInterface.h"
2121#include " mlir/IR/OpImplementation.h"
2222#include " mlir/Support/LogicalResult.h"
23+ #include " llvm/ADT/STLExtras.h"
2324#include " llvm/ADT/Twine.h"
2425
2526namespace mlir {
@@ -39,6 +40,9 @@ class DialectBytecodeReader {
3940 // / Emit an error to the reader.
4041 virtual InFlightDiagnostic emitError (const Twine &msg = {}) = 0;
4142
43+ // / Return the bytecode version being read.
44+ virtual uint64_t getBytecodeVersion () const = 0;
45+
4246 // / Read out a list of elements, invoking the provided callback for each
4347 // / element. The callback function may be in any of the following forms:
4448 // / * LogicalResult(T &)
@@ -148,6 +152,76 @@ class DialectBytecodeReader {
148152 [this ](int64_t &value) { return readSignedVarInt (value); });
149153 }
150154
155+ // / Parse a variable length encoded integer whose low bit is used to encode an
156+ // / unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
157+ LogicalResult readVarIntWithFlag (uint64_t &result, bool &flag) {
158+ if (failed (readVarInt (result)))
159+ return failure ();
160+ flag = result & 1 ;
161+ result >>= 1 ;
162+ return success ();
163+ }
164+
165+ // / Read a "small" sparse array of integer <= 32 bits elements, where
166+ // / index/value pairs can be compressed when the array is small.
167+ // / Note that only some position of the array will be read and the ones
168+ // / not stored in the bytecode are gonne be left untouched.
169+ // / If the provided array is too small for the stored indices, an error
170+ // / will be returned.
171+ template <typename T>
172+ LogicalResult readSparseArray (MutableArrayRef<T> array) {
173+ static_assert (sizeof (T) < sizeof (uint64_t ), " expect integer < 64 bits" );
174+ static_assert (std::is_integral<T>::value, " expects integer" );
175+ uint64_t nonZeroesCount;
176+ bool useSparseEncoding;
177+ if (failed (readVarIntWithFlag (nonZeroesCount, useSparseEncoding)))
178+ return failure ();
179+ if (nonZeroesCount == 0 )
180+ return success ();
181+ if (!useSparseEncoding) {
182+ // This is a simple dense array.
183+ if (nonZeroesCount > array.size ()) {
184+ emitError (" trying to read an array of " )
185+ << nonZeroesCount << " but only " << array.size ()
186+ << " storage available." ;
187+ return failure ();
188+ }
189+ for (int64_t index : llvm::seq<int64_t >(0 , nonZeroesCount)) {
190+ uint64_t value;
191+ if (failed (readVarInt (value)))
192+ return failure ();
193+ array[index] = value;
194+ }
195+ return success ();
196+ }
197+ // Read sparse encoding
198+ // This is the number of bits used for packing the index with the value.
199+ uint64_t indexBitSize;
200+ if (failed (readVarInt (indexBitSize)))
201+ return failure ();
202+ constexpr uint64_t maxIndexBitSize = 8 ;
203+ if (indexBitSize > maxIndexBitSize) {
204+ emitError (" reading sparse array with indexing above 8 bits: " )
205+ << indexBitSize;
206+ return failure ();
207+ }
208+ for (uint32_t count : llvm::seq<uint32_t >(0 , nonZeroesCount)) {
209+ (void )count;
210+ uint64_t indexValuePair;
211+ if (failed (readVarInt (indexValuePair)))
212+ return failure ();
213+ uint64_t index = indexValuePair & ~(uint64_t (-1 ) << (indexBitSize));
214+ uint64_t value = indexValuePair >> indexBitSize;
215+ if (index >= array.size ()) {
216+ emitError (" reading a sparse array found index " )
217+ << index << " but only " << array.size () << " storage available." ;
218+ return failure ();
219+ }
220+ array[index] = value;
221+ }
222+ return success ();
223+ }
224+
151225 // / Read an APInt that is known to have been encoded with the given width.
152226 virtual FailureOr<APInt> readAPIntWithKnownWidth (unsigned bitWidth) = 0;
153227
@@ -230,6 +304,55 @@ class DialectBytecodeWriter {
230304 writeList (value, [this ](int64_t value) { writeSignedVarInt (value); });
231305 }
232306
307+ // / Write a VarInt and a flag packed together.
308+ void writeVarIntWithFlag (uint64_t value, bool flag) {
309+ writeVarInt ((value << 1 ) | (flag ? 1 : 0 ));
310+ }
311+
312+ // / Write out a "small" sparse array of integer <= 32 bits elements, where
313+ // / index/value pairs can be compressed when the array is small. This method
314+ // / will scan the array multiple times and should not be used for large
315+ // / arrays. The optional provided "zero" can be used to adjust for the
316+ // / expected repeated value. We assume here that the array size fits in a 32
317+ // / bits integer.
318+ template <typename T>
319+ void writeSparseArray (ArrayRef<T> array) {
320+ static_assert (sizeof (T) < sizeof (uint64_t ), " expect integer < 64 bits" );
321+ static_assert (std::is_integral<T>::value, " expects integer" );
322+ uint32_t size = array.size ();
323+ uint32_t nonZeroesCount = 0 , lastIndex = 0 ;
324+ for (uint32_t index : llvm::seq<uint32_t >(0 , size)) {
325+ if (!array[index])
326+ continue ;
327+ nonZeroesCount++;
328+ lastIndex = index;
329+ }
330+ // If the last position is too large, or the array isn't at least 50%
331+ // sparse, emit it with a dense encoding.
332+ if (lastIndex > 256 || nonZeroesCount > size / 2 ) {
333+ // Emit the array size and a flag which indicates whether it is sparse.
334+ writeVarIntWithFlag (size, false );
335+ for (const T &elt : array)
336+ writeVarInt (elt);
337+ return ;
338+ }
339+ // Emit sparse: first the number of elements we'll write and a flag
340+ // indicating it is a sparse encoding.
341+ writeVarIntWithFlag (nonZeroesCount, true );
342+ if (nonZeroesCount == 0 )
343+ return ;
344+ // This is the number of bits used for packing the index with the value.
345+ int indexBitSize = llvm::Log2_32_Ceil (lastIndex + 1 );
346+ writeVarInt (indexBitSize);
347+ for (uint32_t index : llvm::seq<uint32_t >(0 , lastIndex + 1 )) {
348+ T value = array[index];
349+ if (!value)
350+ continue ;
351+ uint64_t indexValuePair = (value << indexBitSize) | (index);
352+ writeVarInt (indexValuePair);
353+ }
354+ }
355+
233356 // / Write an APInt to the bytecode stream whose bitwidth will be known
234357 // / externally at read time. This method is useful for encoding APInt values
235358 // / when the width is known via external means, such as via a type. This
0 commit comments