[mlir][bytecode] Add support for deferred attribute/type parsing #172901
…vm#170993) Add the ability for an attribute/type entry to defer its own parsing and re-enqueue itself. This lets CallSiteLoc parsing recurse less deeply; previously the recursion could overflow, especially on large inputs in debug mode. Add a default depth cutoff; this could become a parameter later if needed.
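The defer-and-re-enqueue idea described in the commit message above can be sketched outside of MLIR. The following is a minimal, self-contained model under stated assumptions, not the code from this PR: `Entry`, `DeferredResolver`, `read`, and `enqueueDeferred` are hypothetical names, an entry is reduced to a list of dependency indices, and `std::unordered_set` stands in for `llvm::DenseSet`. It also assumes the dependency graph is acyclic.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <deque>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical stand-in for a bytecode attribute/type entry: it names the
// entries it references and becomes resolved once all of them are resolved.
struct Entry {
  std::vector<uint64_t> deps;
  bool resolved = false;
};

class DeferredResolver {
public:
  explicit DeferredResolver(std::vector<Entry> entries, uint64_t maxDepth = 5)
      : entries(std::move(entries)), maxDepth(maxDepth) {}

  // Resolve `index` and everything it transitively references. Returns false
  // only on a real error (here: an out-of-range index).
  bool resolve(uint64_t index) {
    deferred.clear();
    if (read(index, /*depth=*/0))
      return true; // Fast path: nothing had to be deferred.
    if (deferred.empty())
      return false; // Failure without a deferral is a real error.

    std::deque<uint64_t> worklist{index};
    std::unordered_set<uint64_t> inWorklist{index};
    enqueueDeferred(worklist, inWorklist);
    while (!worklist.empty()) {
      uint64_t current = worklist.front();
      worklist.pop_front();
      deferred.clear();
      if (read(current, /*depth=*/0)) {
        inWorklist.erase(current);
        continue;
      }
      if (deferred.empty())
        return false;
      worklist.push_back(current);           // Retry after its dependencies.
      enqueueDeferred(worklist, inWorklist); // Dependencies are tried first.
    }
    return entries[index].resolved;
  }

  bool isResolved(uint64_t index) const { return entries[index].resolved; }
  uint64_t deepestRead() const { return deepest; }

private:
  // Read one entry, recursing into its dependencies until the depth cutoff.
  bool read(uint64_t index, uint64_t depth) {
    if (index >= entries.size())
      return false;
    Entry &entry = entries[index];
    if (entry.resolved)
      return true;
    deepest = std::max(deepest, depth);
    for (uint64_t dep : entry.deps) {
      if (depth > maxDepth) {
        // Past the cutoff: reuse a resolved dependency, otherwise defer it
        // and fail so that every caller unwinds immediately.
        if (dep < entries.size() && entries[dep].resolved)
          continue;
        deferred.push_back(dep);
        return false;
      }
      if (!read(dep, depth + 1))
        return false;
    }
    entry.resolved = true;
    return true;
  }

  // Move newly deferred indices to the front of the worklist, in order.
  void enqueueDeferred(std::deque<uint64_t> &worklist,
                       std::unordered_set<uint64_t> &inWorklist) {
    for (auto it = deferred.rbegin(); it != deferred.rend(); ++it)
      if (inWorklist.insert(*it).second)
        worklist.push_front(*it);
    deferred.clear();
  }

  std::vector<Entry> entries;
  uint64_t maxDepth;
  std::vector<uint64_t> deferred;
  uint64_t deepest = 0;
};
```

In this model, resolving the head of a 1000-entry chain never recurses deeper than `maxDepth + 1` frames, because the pending work lives in the heap-allocated worklist instead of the call stack. The trade-off is that entries near the head may be retried several times before their dependencies are ready.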
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core
Author: Jacques Pienaar (jpienaar)
Changes
Add the ability for an attribute/type entry to defer its own parsing and re-enqueue itself. This lets CallSiteLoc parsing recurse less deeply; previously the recursion could overflow, especially on large inputs in debug mode. Add a default depth cutoff; this could become a parameter later if needed.
Roll-forward of #170993 with a relatively direct change: when an attribute or type is read while the reader is not resolving (that is, while parsing properties), it is resolved eagerly.
Full diff: https://github.com/llvm/llvm-project/pull/172901.diff
2 Files Affected:
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 1659437e1eb24..8ba64096fbb0f 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -27,6 +27,7 @@
#include <cstddef>
#include <cstdint>
+#include <deque>
#include <list>
#include <memory>
#include <numeric>
@@ -830,6 +831,23 @@ namespace {
/// This class provides support for reading attribute and type entries from the
/// bytecode. Attribute and Type entries are read lazily on demand, so we use
/// this reader to manage when to actually parse them from the bytecode.
+///
+/// Parsing attributes and types is generally recursive, which can lead to
+/// stack overflows for deeply nested structures, so we track a few extra
+/// pieces of information to avoid this:
+///
+/// - `depth`: The current depth while parsing nested attributes. We defer
+/// parsing deeply nested attributes to avoid potential stack overflows. The
+/// deferred parsing is achieved by reporting a failure when parsing a nested
+/// attribute/type and registering the index of the encountered attribute/type
+/// in the deferred parsing worklist. Hence, a failure with a deferred entry
+/// does not constitute an error; it also requires that callers return on
+/// the first failure rather than attempting additional parses.
+/// - `deferredWorklist`: A list of attribute/type indices that we could not
+/// parse due to hitting the depth limit. The worklist is used to capture the
+/// indices of attributes/types that need to be parsed/reparsed when we hit
+/// the depth limit. This enables moving the tracking of what needs to be
+/// parsed to the heap.
class AttrTypeReader {
/// This class represents a single attribute or type entry.
template <typename T>
@@ -863,12 +881,34 @@ class AttrTypeReader {
ArrayRef<uint8_t> sectionData,
ArrayRef<uint8_t> offsetSectionData);
+ LogicalResult readAttribute(uint64_t index, Attribute &result,
+ uint64_t depth = 0) {
+ return readEntry(attributes, index, result, "attribute", depth);
+ }
+
+ LogicalResult readType(uint64_t index, Type &result, uint64_t depth = 0) {
+ return readEntry(types, index, result, "type", depth);
+ }
+
/// Resolve the attribute or type at the given index. Returns nullptr on
/// failure.
- Attribute resolveAttribute(size_t index) {
- return resolveEntry(attributes, index, "Attribute");
+ Attribute resolveAttribute(size_t index, uint64_t depth = 0) {
+ return resolveEntry(attributes, index, "Attribute", depth);
+ }
+ Type resolveType(size_t index, uint64_t depth = 0) {
+ return resolveEntry(types, index, "Type", depth);
+ }
+
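+ /// Return the attribute at `index` if it has already been resolved. Returns
+ /// nullptr if the index is out of range or the entry is not yet resolved.
+ Attribute getAttributeOrSentinel(size_t index) {
+ if (index >= attributes.size())
+ return nullptr;
+ return attributes[index].entry;
+ }
+ /// Same as `getAttributeOrSentinel`, but for types.
+ Type getTypeOrSentinel(size_t index) {
+ if (index >= types.size())
+ return nullptr;
+ return types[index].entry;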
}
- Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); }
/// Parse a reference to an attribute or type using the given reader.
LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
@@ -909,23 +949,36 @@ class AttrTypeReader {
llvm::getTypeName<T>(), ", but got: ", baseResult);
}
+ /// Add an index to the deferred worklist for re-parsing.
+ void addDeferredParsing(uint64_t index) { deferredWorklist.push_back(index); }
+
+ /// Whether currently resolving.
+ bool isResolving() const { return resolving; }
+
private:
/// Resolve the given entry at `index`.
template <typename T>
- T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
- StringRef entryType);
+ T resolveEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
+ StringRef entryType, uint64_t depth = 0);
- /// Parse an entry using the given reader that was encoded using the textual
- /// assembly format.
+ /// Read the entry at the given index, parsing it if it has not been
+ /// resolved yet. Returns failure if parsing fails or has to be deferred.
template <typename T>
- LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
- StringRef entryType);
+ LogicalResult readEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
+ T &result, StringRef entryType, uint64_t depth);
/// Parse an entry using the given reader that was encoded using a custom
/// bytecode format.
template <typename T>
LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
- StringRef entryType);
+ StringRef entryType, uint64_t index,
+ uint64_t depth);
+
+ /// Parse an entry using the given reader that was encoded using the textual
+ /// assembly format.
+ template <typename T>
+ LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
+ StringRef entryType);
/// The string section reader used to resolve string references when parsing
/// custom encoded attribute/type entries.
@@ -951,6 +1004,13 @@ class AttrTypeReader {
/// Reference to the parser configuration.
const ParserConfig &parserConfig;
+
+ /// Worklist for deferred attribute/type parsing. This is used to handle
+ /// deeply nested structures like CallSiteLoc iteratively.
+ std::vector<uint64_t> deferredWorklist;
+
+ /// Flag indicating if we are currently resolving an attribute or type.
+ bool resolving = false;
};
class DialectReader : public DialectBytecodeReader {
@@ -959,10 +1019,11 @@ class DialectReader : public DialectBytecodeReader {
const StringSectionReader &stringReader,
const ResourceSectionReader &resourceReader,
const llvm::StringMap<BytecodeDialect *> &dialectsMap,
- EncodingReader &reader, uint64_t &bytecodeVersion)
+ EncodingReader &reader, uint64_t &bytecodeVersion,
+ uint64_t depth = 0)
: attrTypeReader(attrTypeReader), stringReader(stringReader),
resourceReader(resourceReader), dialectsMap(dialectsMap),
- reader(reader), bytecodeVersion(bytecodeVersion) {}
+ reader(reader), bytecodeVersion(bytecodeVersion), depth(depth) {}
InFlightDiagnostic emitError(const Twine &msg) const override {
return reader.emitError(msg);
@@ -998,14 +1059,64 @@ class DialectReader : public DialectBytecodeReader {
// IR
//===--------------------------------------------------------------------===//
+ /// The maximum depth to eagerly parse nested attributes/types before
+ /// deferring.
+ static constexpr uint64_t maxAttrTypeDepth = 5;
+
LogicalResult readAttribute(Attribute &result) override {
- return attrTypeReader.parseAttribute(reader, result);
+ uint64_t index;
+ if (failed(reader.parseVarInt(index)))
+ return failure();
+
+ // If we aren't currently resolving an attribute/type, we resolve this
+ // attribute eagerly. This is the case when we are parsing properties, which
+ // aren't processed via the worklist.
+ if (!attrTypeReader.isResolving()) {
+ if (Attribute attr = attrTypeReader.resolveAttribute(index)) {
+ result = attr;
+ return success();
+ }
+ return failure();
+ }
+
+ if (depth > maxAttrTypeDepth) {
+ if (Attribute attr = attrTypeReader.getAttributeOrSentinel(index)) {
+ result = attr;
+ return success();
+ }
+ attrTypeReader.addDeferredParsing(index);
+ return failure();
+ }
+ return attrTypeReader.readAttribute(index, result, depth + 1);
}
LogicalResult readOptionalAttribute(Attribute &result) override {
return attrTypeReader.parseOptionalAttribute(reader, result);
}
LogicalResult readType(Type &result) override {
- return attrTypeReader.parseType(reader, result);
+ uint64_t index;
+ if (failed(reader.parseVarInt(index)))
+ return failure();
+
+ // If we aren't currently resolving an attribute/type, we resolve this
+ // type eagerly. This is the case when we are parsing properties, which
+ // aren't processed via the worklist.
+ if (!attrTypeReader.isResolving()) {
+ if (Type type = attrTypeReader.resolveType(index)) {
+ result = type;
+ return success();
+ }
+ return failure();
+ }
+
+ if (depth > maxAttrTypeDepth) {
+ if (Type type = attrTypeReader.getTypeOrSentinel(index)) {
+ result = type;
+ return success();
+ }
+ attrTypeReader.addDeferredParsing(index);
+ return failure();
+ }
+ return attrTypeReader.readType(index, result, depth + 1);
}
FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
@@ -1095,6 +1206,7 @@ class DialectReader : public DialectBytecodeReader {
const llvm::StringMap<BytecodeDialect *> &dialectsMap;
EncodingReader &reader;
uint64_t &bytecodeVersion;
+ uint64_t depth;
};
/// Wraps the properties section and handles reading properties out of it.
@@ -1239,68 +1351,115 @@ LogicalResult AttrTypeReader::initialize(
template <typename T>
-T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
- StringRef entryType) {
+T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries,
+ uint64_t index, StringRef entryType,
+ uint64_t depth) {
+ bool oldResolving = resolving;
+ resolving = true;
+ auto restoreResolving =
+ llvm::make_scope_exit([&]() { resolving = oldResolving; });
+
if (index >= entries.size()) {
emitError(fileLoc) << "invalid " << entryType << " index: " << index;
return {};
}
- // If the entry has already been resolved, there is nothing left to do.
- Entry<T> &entry = entries[index];
- if (entry.entry)
- return entry.entry;
+ // Fast path: Try direct parsing without worklist overhead. This handles the
+ // common case where there are no deferred dependencies.
+ assert(deferredWorklist.empty());
+ T result;
+ if (succeeded(readEntry(entries, index, result, entryType, depth))) {
+ assert(deferredWorklist.empty());
+ return result;
+ }
+ if (deferredWorklist.empty()) {
+ // A failure with no deferred entries is an error.
+ return T();
+ }
- // Parse the entry.
- EncodingReader reader(entry.data, fileLoc);
+ // Slow path: Use worklist to handle deferred dependencies. Use a deque to
+ // iteratively resolve entries with dependencies.
+ // - Pop from front to process
+ // - Push new dependencies to front (depth-first)
+ // - Move failed entries to back (retry after dependencies)
+ std::deque<size_t> worklist;
+ llvm::DenseSet<size_t> inWorklist;
- // Parse based on how the entry was encoded.
- if (entry.hasCustomEncoding) {
- if (failed(parseCustomEntry(entry, reader, entryType)))
- return T();
- } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) {
- return T();
+ // Add the original index and any dependencies from the fast path attempt.
+ worklist.push_back(index);
+ inWorklist.insert(index);
+ for (uint64_t idx : llvm::reverse(deferredWorklist)) {
+ if (inWorklist.insert(idx).second)
+ worklist.push_front(idx);
}
- if (!reader.empty()) {
- reader.emitError("unexpected trailing bytes after " + entryType + " entry");
- return T();
+ while (!worklist.empty()) {
+ size_t currentIndex = worklist.front();
+ worklist.pop_front();
+
+ // Clear the deferred worklist before parsing to capture any new entries.
+ deferredWorklist.clear();
+
+ T result;
+ if (succeeded(readEntry(entries, currentIndex, result, entryType, depth))) {
+ inWorklist.erase(currentIndex);
+ continue;
+ }
+
+ if (deferredWorklist.empty()) {
+ // Parsing failed with no deferred entries which implies an error.
+ return T();
+ }
+
+ // Move this entry to the back to retry after dependencies.
+ worklist.push_back(currentIndex);
+
+ // Add dependencies to the front (in reverse so they maintain order).
+ for (uint64_t idx : llvm::reverse(deferredWorklist)) {
+ if (inWorklist.insert(idx).second)
+ worklist.push_front(idx);
+ }
+ deferredWorklist.clear();
}
- return entry.entry;
+ return entries[index].entry;
}
template <typename T>
-LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
- StringRef entryType) {
- StringRef asmStr;
- if (failed(reader.parseNullTerminatedString(asmStr)))
- return failure();
+LogicalResult AttrTypeReader::readEntry(SmallVectorImpl<Entry<T>> &entries,
+ uint64_t index, T &result,
+ StringRef entryType, uint64_t depth) {
+ if (index >= entries.size())
+ return emitError(fileLoc) << "invalid " << entryType << " index: " << index;
- // Invoke the MLIR assembly parser to parse the entry text.
- size_t numRead = 0;
- MLIRContext *context = fileLoc->getContext();
- if constexpr (std::is_same_v<T, Type>)
- result =
- ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
- else
- result = ::parseAttribute(asmStr, context, Type(), &numRead,
- /*isKnownNullTerminated=*/true);
- if (!result)
+ // If the entry has already been resolved, return it.
+ Entry<T> &entry = entries[index];
+ if (entry.entry) {
+ result = entry.entry;
+ return success();
+ }
+
+ // If the entry hasn't been resolved, try to parse it.
+ EncodingReader reader(entry.data, fileLoc);
+ LogicalResult parseResult =
+ entry.hasCustomEncoding
+ ? parseCustomEntry(entry, reader, entryType, index, depth)
+ : parseAsmEntry(entry.entry, reader, entryType);
+ if (failed(parseResult))
return failure();
- // Ensure there weren't dangling characters after the entry.
- if (numRead != asmStr.size()) {
- return reader.emitError("trailing characters found after ", entryType,
- " assembly format: ", asmStr.drop_front(numRead));
- }
+ if (!reader.empty())
+ return reader.emitError("unexpected trailing bytes after " + entryType +
+ " entry");
+
+ result = entry.entry;
return success();
}
template <typename T>
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
EncodingReader &reader,
- StringRef entryType) {
+ StringRef entryType,
+ uint64_t index, uint64_t depth) {
DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
- reader, bytecodeVersion);
+ reader, bytecodeVersion, depth);
if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
return failure();
@@ -1350,6 +1509,33 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
return success(!!entry.entry);
}
+template <typename T>
+LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
+ StringRef entryType) {
+ StringRef asmStr;
+ if (failed(reader.parseNullTerminatedString(asmStr)))
+ return failure();
+
+ // Invoke the MLIR assembly parser to parse the entry text.
+ size_t numRead = 0;
+ MLIRContext *context = fileLoc->getContext();
+ if constexpr (std::is_same_v<T, Type>)
+ result =
+ ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
+ else
+ result = ::parseAttribute(asmStr, context, Type(), &numRead,
+ /*isKnownNullTerminated=*/true);
+ if (!result)
+ return failure();
+
+ // Ensure there weren't dangling characters after the entry.
+ if (numRead != asmStr.size()) {
+ return reader.emitError("trailing characters found after ", entryType,
+ " assembly format: ", asmStr.drop_front(numRead));
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Bytecode Reader
//===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp
index d7b442f6832d0..30e7ed9b6cb7e 100644
--- a/mlir/unittests/Bytecode/BytecodeTest.cpp
+++ b/mlir/unittests/Bytecode/BytecodeTest.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"
+#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/Endian.h"
@@ -228,3 +229,39 @@ TEST(Bytecode, OpWithoutProperties) {
EXPECT_TRUE(OperationEquivalence::computeHash(op.get()) ==
OperationEquivalence::computeHash(roundtripped));
}
+
+TEST(Bytecode, DeepCallSiteLoc) {
+ MLIRContext context;
+ ParserConfig config(&context);
+
+ // Create a deep CallSiteLoc chain to test iterative parsing.
+ Location baseLoc = FileLineColLoc::get(&context, "test.mlir", 1, 1);
+ Location loc = baseLoc;
+ constexpr int kDepth = 1000;
+ for (int i = 0; i < kDepth; ++i) {
+ loc = CallSiteLoc::get(loc, baseLoc);
+ }
+
+ // Create a simple module with the deep location.
+ Builder builder(&context);
+ OwningOpRef<ModuleOp> module =
+ ModuleOp::create(loc, /*name=*/std::nullopt);
+ ASSERT_TRUE(module);
+
+ // Write to bytecode.
+ std::string bytecode;
+ llvm::raw_string_ostream os(bytecode);
+ ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), os)));
+
+ // Parse it back using the bytecode reader.
+ std::unique_ptr<Block> block = std::make_unique<Block>();
+ ASSERT_TRUE(succeeded(readBytecodeFile(
+ llvm::MemoryBufferRef(bytecode, "string-buffer"), block.get(), config)));
+
+ // Verify we got the roundtripped module.
+ ASSERT_FALSE(block->empty());
+ Operation *roundTripped = &block->front();
+
+ // Verify the location matches.
+ EXPECT_EQ(module.get()->getLoc(), roundTripped->getLoc());
+}
Submitting the fix forward; I can refine it after it lands.
`uint64_t` != `size_t` on macOS (use `uint64_t` as elsewhere in this impl)
`uint64_t` != `size_t` on macOS (use `uint64_t` to match other uses in this impl)
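The two commit notes above concern a portability detail that a small example may make clearer. On 64-bit Linux with glibc, `uint64_t` and `size_t` are typically both `unsigned long`, while on macOS `uint64_t` is `unsigned long long` and `size_t` is `unsigned long`, so a declaration using one and an out-of-line definition using the other name two different functions there. The sketch below is illustrative only; `IndexReader` and `next` are hypothetical names, not part of the bytecode reader.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <type_traits>

// Whether the two names denote the same type is platform-dependent, even
// when both are 64 bits wide.
constexpr bool kSameType = std::is_same_v<std::uint64_t, std::size_t>;

// Hypothetical class with a member declared in terms of uint64_t.
struct IndexReader {
  std::uint64_t next(std::uint64_t index);
};

// The out-of-line definition must spell the parameter type the same way as
// the declaration. Writing `std::size_t index` here would only compile on
// platforms where kSameType is true.
std::uint64_t IndexReader::next(std::uint64_t index) { return index + 1; }
```

Using one spelling consistently, as the commits do with `uint64_t`, avoids depending on which platforms happen to alias the two types.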