Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 204 additions & 53 deletions mlir/lib/Bytecode/Reader/BytecodeReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <cstddef>
#include <cstdint>
#include <deque>
#include <list>
#include <memory>
#include <numeric>
Expand Down Expand Up @@ -830,6 +831,23 @@ namespace {
/// This class provides support for reading attribute and type entries from the
/// bytecode. Attribute and Type entries are read lazily on demand, so we use
/// this reader to manage when to actually parse them from the bytecode.
///
/// The parsing of attributes & types are generally recursive, this can lead to
/// stack overflows for deeply nested structures, so we track a few extra pieces
/// of information to avoid this:
///
/// - `depth`: The current depth while parsing nested attributes. We defer on
/// parsing deeply nested attributes to avoid potential stack overflows. The
/// deferred parsing is achieved by reporting a failure when parsing a nested
/// attribute/type and registering the index of the encountered attribute/type
/// in the deferred parsing worklist. Hence, a failure with deffered entry
/// does not constitute a failure, it also requires that folks return on
/// first failure rather than attempting additional parses.
/// - `deferredWorklist`: A list of attribute/type indices that we could not
/// parse due to hitting the depth limit. The worklist is used to capture the
/// indices of attributes/types that need to be parsed/reparsed when we hit
/// the depth limit. This enables moving the tracking of what needs to be
/// parsed to the heap.
class AttrTypeReader {
/// This class represents a single attribute or type entry.
template <typename T>
Expand Down Expand Up @@ -863,12 +881,34 @@ class AttrTypeReader {
ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData);

LogicalResult readAttribute(uint64_t index, Attribute &result,
uint64_t depth = 0) {
return readEntry(attributes, index, result, "attribute", depth);
}

LogicalResult readType(uint64_t index, Type &result, uint64_t depth = 0) {
return readEntry(types, index, result, "type", depth);
}

/// Resolve the attribute or type at the given index. Returns nullptr on
/// failure.
Attribute resolveAttribute(size_t index) {
return resolveEntry(attributes, index, "Attribute");
Attribute resolveAttribute(size_t index, uint64_t depth = 0) {
return resolveEntry(attributes, index, "Attribute", depth);
}
Type resolveType(size_t index, uint64_t depth = 0) {
return resolveEntry(types, index, "Type", depth);
}

Attribute getAttributeOrSentinel(size_t index) {
if (index >= attributes.size())
return nullptr;
return attributes[index].entry;
}
Type getTypeOrSentinel(size_t index) {
if (index >= types.size())
return nullptr;
return types[index].entry;
}
Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); }

/// Parse a reference to an attribute or type using the given reader.
LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
Expand Down Expand Up @@ -909,23 +949,33 @@ class AttrTypeReader {
llvm::getTypeName<T>(), ", but got: ", baseResult);
}

/// Add an index to the deferred worklist for re-parsing.
void addDeferredParsing(uint64_t index) { deferredWorklist.push_back(index); }

private:
/// Resolve the given entry at `index`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you document "depth" here and elsewhere?

template <typename T>
T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
StringRef entryType);
T resolveEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
StringRef entryType, uint64_t depth = 0);

/// Parse an entry using the given reader that was encoded using the textual
/// assembly format.
/// Read the entry at the given index, returning failure if the entry is not
/// yet resolved.
template <typename T>
LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
StringRef entryType);
LogicalResult readEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
T &result, StringRef entryType, uint64_t depth);

/// Parse an entry using the given reader that was encoded using a custom
/// bytecode format.
template <typename T>
LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
StringRef entryType);
StringRef entryType, uint64_t index,
uint64_t depth);

/// Parse an entry using the given reader that was encoded using the textual
/// assembly format.
template <typename T>
LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
StringRef entryType);

/// The string section reader used to resolve string references when parsing
/// custom encoded attribute/type entries.
Expand All @@ -951,6 +1001,10 @@ class AttrTypeReader {

/// Reference to the parser configuration.
const ParserConfig &parserConfig;

/// Worklist for deferred attribute/type parsing. This is used to handle
/// deeply nested structures like CallSiteLoc iteratively.
std::vector<uint64_t> deferredWorklist;
};

class DialectReader : public DialectBytecodeReader {
Expand All @@ -959,10 +1013,11 @@ class DialectReader : public DialectBytecodeReader {
const StringSectionReader &stringReader,
const ResourceSectionReader &resourceReader,
const llvm::StringMap<BytecodeDialect *> &dialectsMap,
EncodingReader &reader, uint64_t &bytecodeVersion)
EncodingReader &reader, uint64_t &bytecodeVersion,
uint64_t depth = 0)
: attrTypeReader(attrTypeReader), stringReader(stringReader),
resourceReader(resourceReader), dialectsMap(dialectsMap),
reader(reader), bytecodeVersion(bytecodeVersion) {}
reader(reader), bytecodeVersion(bytecodeVersion), depth(depth) {}

InFlightDiagnostic emitError(const Twine &msg) const override {
return reader.emitError(msg);
Expand Down Expand Up @@ -998,14 +1053,40 @@ class DialectReader : public DialectBytecodeReader {
// IR
//===--------------------------------------------------------------------===//

/// The maximum depth to eagerly parse nested attributes/types before
/// deferring.
static constexpr uint64_t maxAttrTypeDepth = 5;

LogicalResult readAttribute(Attribute &result) override {
return attrTypeReader.parseAttribute(reader, result);
uint64_t index;
if (failed(reader.parseVarInt(index)))
return failure();
if (depth > maxAttrTypeDepth) {
if (Attribute attr = attrTypeReader.getAttributeOrSentinel(index)) {
result = attr;
return success();
}
attrTypeReader.addDeferredParsing(index);
return failure();
}
return attrTypeReader.readAttribute(index, result, depth + 1);
}
LogicalResult readOptionalAttribute(Attribute &result) override {
return attrTypeReader.parseOptionalAttribute(reader, result);
}
LogicalResult readType(Type &result) override {
return attrTypeReader.parseType(reader, result);
uint64_t index;
if (failed(reader.parseVarInt(index)))
return failure();
if (depth > maxAttrTypeDepth) {
if (Type type = attrTypeReader.getTypeOrSentinel(index)) {
result = type;
return success();
}
attrTypeReader.addDeferredParsing(index);
return failure();
}
return attrTypeReader.readType(index, result, depth + 1);
}

FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
Expand Down Expand Up @@ -1095,6 +1176,7 @@ class DialectReader : public DialectBytecodeReader {
const llvm::StringMap<BytecodeDialect *> &dialectsMap;
EncodingReader &reader;
uint64_t &bytecodeVersion;
uint64_t depth;
};

/// Wraps the properties section and handles reading properties out of it.
Expand Down Expand Up @@ -1239,68 +1321,110 @@ LogicalResult AttrTypeReader::initialize(

template <typename T>
T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
StringRef entryType) {
StringRef entryType, uint64_t depth) {
if (index >= entries.size()) {
emitError(fileLoc) << "invalid " << entryType << " index: " << index;
return {};
}

// If the entry has already been resolved, there is nothing left to do.
Entry<T> &entry = entries[index];
if (entry.entry)
return entry.entry;
// Fast path: Try direct parsing without worklist overhead. This handles the
// common case where there are no deferred dependencies.
assert(deferredWorklist.empty());
T result;
if (succeeded(readEntry(entries, index, result, entryType, depth))) {
assert(deferredWorklist.empty());
return result;
}
if (deferredWorklist.empty()) {
// Failed with no deferred entries is error.
return T();
}

// Parse the entry.
EncodingReader reader(entry.data, fileLoc);
// Slow path: Use worklist to handle deferred dependencies. Use a deque to
// iteratively resolve entries with dependencies.
// - Pop from front to process
// - Push new dependencies to front (depth-first)
// - Move failed entries to back (retry after dependencies)
std::deque<size_t> worklist;
llvm::DenseSet<size_t> inWorklist;

// Parse based on how the entry was encoded.
if (entry.hasCustomEncoding) {
if (failed(parseCustomEntry(entry, reader, entryType)))
return T();
} else if (failed(parseAsmEntry(entry.entry, reader, entryType))) {
return T();
// Add the original index and any dependencies from the fast path attempt.
worklist.push_back(index);
inWorklist.insert(index);
for (uint64_t idx : llvm::reverse(deferredWorklist)) {
if (inWorklist.insert(idx).second)
worklist.push_front(idx);
}

if (!reader.empty()) {
reader.emitError("unexpected trailing bytes after " + entryType + " entry");
return T();
while (!worklist.empty()) {
size_t currentIndex = worklist.front();
worklist.pop_front();

// Clear the deferred worklist before parsing to capture any new entries.
deferredWorklist.clear();

T result;
if (succeeded(readEntry(entries, currentIndex, result, entryType, depth))) {
inWorklist.erase(currentIndex);
continue;
}

if (deferredWorklist.empty()) {
// Parsing failed with no deferred entries which implies an error.
return T();
}

// Move this entry to the back to retry after dependencies.
worklist.push_back(currentIndex);

// Add dependencies to the front (in reverse so they maintain order).
for (uint64_t idx : llvm::reverse(deferredWorklist)) {
if (inWorklist.insert(idx).second)
worklist.push_front(idx);
}
deferredWorklist.clear();
}
return entry.entry;
return entries[index].entry;
}

template <typename T>
LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
StringRef entryType) {
StringRef asmStr;
if (failed(reader.parseNullTerminatedString(asmStr)))
return failure();
LogicalResult AttrTypeReader::readEntry(SmallVectorImpl<Entry<T>> &entries,
uint64_t index, T &result,
StringRef entryType, uint64_t depth) {
if (index >= entries.size())
return emitError(fileLoc) << "invalid " << entryType << " index: " << index;

// Invoke the MLIR assembly parser to parse the entry text.
size_t numRead = 0;
MLIRContext *context = fileLoc->getContext();
if constexpr (std::is_same_v<T, Type>)
result =
::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
else
result = ::parseAttribute(asmStr, context, Type(), &numRead,
/*isKnownNullTerminated=*/true);
if (!result)
// If the entry has already been resolved, return it.
Entry<T> &entry = entries[index];
if (entry.entry) {
result = entry.entry;
return success();
}

// If the entry hasn't been resolved, try to parse it.
EncodingReader reader(entry.data, fileLoc);
LogicalResult parseResult =
entry.hasCustomEncoding
? parseCustomEntry(entry, reader, entryType, index, depth)
: parseAsmEntry(entry.entry, reader, entryType);
if (failed(parseResult))
return failure();

// Ensure there weren't dangling characters after the entry.
if (numRead != asmStr.size()) {
return reader.emitError("trailing characters found after ", entryType,
" assembly format: ", asmStr.drop_front(numRead));
}
if (!reader.empty())
return reader.emitError("unexpected trailing bytes after " + entryType +
" entry");

result = entry.entry;
return success();
}

template <typename T>
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
EncodingReader &reader,
StringRef entryType) {
StringRef entryType,
uint64_t index, uint64_t depth) {
DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
reader, bytecodeVersion);
reader, bytecodeVersion, depth);
if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
return failure();

Expand Down Expand Up @@ -1350,6 +1474,33 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
return success(!!entry.entry);
}

template <typename T>
LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
StringRef entryType) {
StringRef asmStr;
if (failed(reader.parseNullTerminatedString(asmStr)))
return failure();

// Invoke the MLIR assembly parser to parse the entry text.
size_t numRead = 0;
MLIRContext *context = fileLoc->getContext();
if constexpr (std::is_same_v<T, Type>)
result =
::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
else
result = ::parseAttribute(asmStr, context, Type(), &numRead,
/*isKnownNullTerminated=*/true);
if (!result)
return failure();

// Ensure there weren't dangling characters after the entry.
if (numRead != asmStr.size()) {
return reader.emitError("trailing characters found after ", entryType,
" assembly format: ", asmStr.drop_front(numRead));
}
return success();
}

//===----------------------------------------------------------------------===//
// Bytecode Reader
//===----------------------------------------------------------------------===//
Expand Down
Loading