@jpienaar jpienaar commented Dec 6, 2025

Add the ability to defer parsing of an entry and re-enqueue it. This lets CallSiteLoc parsing avoid recursing as deeply: previously parsing could fail, especially on large inputs in debug mode, where the recursion could overflow the stack. Add a default depth cutoff; this could become a parameter later if needed.

@jpienaar jpienaar requested a review from joker-eph December 6, 2025 18:48
@llvmbot llvmbot added the mlir:core (MLIR Core Infrastructure), mlir, and mlir:ods labels Dec 6, 2025

llvmbot commented Dec 6, 2025

@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Jacques Pienaar (jpienaar)

Changes

Add the ability to defer parsing of an entry and re-enqueue it. This lets CallSiteLoc parsing avoid recursing as deeply: previously parsing could fail, especially on large inputs in debug mode, where the recursion could overflow the stack.

I tried a few different variants of this (peeking, thunks, adding a preload section, and I considered a breaking encoding change). This one felt the cleanest, at the cost of introducing a worklist that mostly holds only one item.
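
For readers who want the worklist idea in isolation, here is a minimal standalone sketch. It is not the code in this PR: `ToyEntry`, `tryParse`, and `resolve` are hypothetical names, a "parse" is modeled as summing dependency values, and dependencies are assumed to be acyclic. It only mirrors the deque strategy described in the diff (pop from the front, push newly discovered dependencies to the front, move the deferring entry to the back).

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <optional>
#include <unordered_set>
#include <vector>

// Hypothetical toy entry: once resolved, its value is 1 plus the sum of the
// values of its dependencies.
struct ToyEntry {
  std::vector<size_t> deps;
  std::optional<long> value;
};

// Tries to "parse" one entry. Unresolved dependencies are recorded in
// `deferred` and the parse reports failure so the caller can retry later.
// Hard errors are reported *before* anything is deferred.
bool tryParse(std::vector<ToyEntry> &entries, size_t index,
              std::vector<size_t> &deferred) {
  assert(index < entries.size() && "caller validates the index");
  ToyEntry &entry = entries[index];
  for (size_t dep : entry.deps)
    if (dep >= entries.size())
      return false; // Invalid reference: fail with nothing deferred.
  long sum = 1;
  for (size_t dep : entry.deps) {
    if (entries[dep].value)
      sum += *entries[dep].value;
    else
      deferred.push_back(dep);
  }
  if (!deferred.empty())
    return false; // Not an error: retry once the dependencies are resolved.
  entry.value = sum;
  return true;
}

// Iteratively resolves `index` without recursion. Assumes the dependency
// graph is acyclic; a cycle would make this loop forever.
std::optional<long> resolve(std::vector<ToyEntry> &entries, size_t index) {
  if (index >= entries.size())
    return std::nullopt;
  std::deque<size_t> worklist{index};
  std::unordered_set<size_t> inWorklist{index};
  std::vector<size_t> deferred;
  while (!worklist.empty()) {
    size_t current = worklist.front();
    worklist.pop_front();
    if (entries[current].value) {
      inWorklist.erase(current);
      continue;
    }
    deferred.clear();
    if (tryParse(entries, current, deferred)) {
      inWorklist.erase(current);
      continue;
    }
    if (deferred.empty())
      return std::nullopt; // Failure without a deferral is a real error.
    // Retry this entry later, after its dependencies (processed first).
    worklist.push_back(current);
    for (auto it = deferred.rbegin(); it != deferred.rend(); ++it)
      if (inWorklist.insert(*it).second)
        worklist.push_front(*it);
  }
  return entries[index].value;
}
```

With a chain where entry `i` depends on entry `i + 1`, calling `resolve(entries, 0)` resolves the whole chain with constant stack depth, which is the property the CallSiteLoc change is after. A failed parse that deferred nothing is treated as an error, which is why the toy validates references before it defers.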


Full diff: https://github.com/llvm/llvm-project/pull/170993.diff

5 Files Affected:

  • (modified) mlir/docs/DefiningDialects/_index.md (+20)
  • (modified) mlir/include/mlir/Bytecode/BytecodeImplementation.h (+21)
  • (modified) mlir/include/mlir/IR/BuiltinDialectBytecode.td (+8-2)
  • (modified) mlir/lib/Bytecode/Reader/BytecodeReader.cpp (+115-24)
  • (modified) mlir/unittests/Bytecode/BytecodeTest.cpp (+37)
diff --git a/mlir/docs/DefiningDialects/_index.md b/mlir/docs/DefiningDialects/_index.md
index 987b51b4ab4ef..9c9f9c93fcf39 100644
--- a/mlir/docs/DefiningDialects/_index.md
+++ b/mlir/docs/DefiningDialects/_index.md
@@ -425,6 +425,26 @@ struct FooDialectBytecodeInterface : public BytecodeDialectInterface {
 along with defining the corresponding build rules to invoke generator
 (`-gen-bytecode -bytecode-dialect="Quant"`).
 
+#### Deferred Parsing for Recursive Dependencies
+
+When parsing attributes or types that reference other attributes or types (e.g.,
+`CallSiteLoc` which contains nested location attributes), the referenced entries
+may not yet be resolved. The `DialectBytecodeReader` provides helpers to handle
+this:
+
+```c++
+Attribute attr = reader.getOrDeferParsingAttribute();
+if (!attr)
+  return failure();  // Will be retried after dependencies are resolved
+```
+
+The `getOrDeferParsingAttribute()` method reads the attribute index from the
+stream and attempts to resolve it. If the referenced attribute hasn't been
+parsed yet, it registers for deferred parsing and returns nullptr. The bytecode
+reader will automatically retry parsing after processing the dependencies.
+
+Note: for error cases, one needs to return failure *before* deferring parsing.
+
 ## Defining an Extensible dialect
 
 This section documents the design and API of the extensible dialects. Extensible
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 0ddc531073e23..65e7d23fa1139 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -103,6 +103,16 @@ class DialectBytecodeReader {
   /// the Attribute isn't present.
   virtual LogicalResult readOptionalAttribute(Attribute &attr) = 0;
 
+  /// Try to get an attribute, deferring parsing if not yet resolved (returning
+  /// nullptr and enqueuing for deferred parsing).
+  virtual Attribute getOrDeferParsingAttribute() = 0;
+
+  /// Typed version of getOrDeferParsingAttribute. Returns the attribute cast
+  /// to the specified type, or nullptr if not resolved or cast fails.
+  template <typename T> T getOrDeferParsingAttribute() {
+    return llvm::dyn_cast_or_null<T>(getOrDeferParsingAttribute());
+  }
+
   template <typename T>
   LogicalResult readAttributes(SmallVectorImpl<T> &attrs) {
     return readList(attrs, [this](T &attr) { return readAttribute(attr); });
@@ -132,6 +142,17 @@ class DialectBytecodeReader {
 
   /// Read a reference to the given type.
   virtual LogicalResult readType(Type &result) = 0;
+
+  /// Try to get a type, deferring parsing if not yet resolved (returning
+  /// nullptr and enqueuing for deferred parsing).
+  virtual Type getOrDeferParsingType() = 0;
+
+  /// Typed version of getOrDeferParsingType. Returns the type cast
+  /// to the specified type, or nullptr if not resolved or cast fails.
+  template <typename T> T getOrDeferParsingType() {
+    return llvm::dyn_cast_or_null<T>(getOrDeferParsingType());
+  }
+
   template <typename T>
   LogicalResult readTypes(SmallVectorImpl<T> &types) {
     return readList(types, [this](T &type) { return readType(type); });
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..4162d6dad3c67 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -25,6 +25,12 @@ def Location : CompositeBytecode {
   let cBuilder = "Location($_args)";
 }
 
+def MaybeDeferredLocationAttr :
+  WithParser <"($_var = $_reader.getOrDeferParsingAttribute<LocationAttr>())",
+  WithBuilder<"$_args",
+  WithPrinter<"$_writer.writeAttribute($_getter)",
+  WithType   <"LocationAttr">>>>;
+
 def String :
   WithParser <"succeeded($_reader.readString($_var))",
   WithBuilder<"$_args",
@@ -91,8 +97,8 @@ def FloatAttr : DialectAttribute<(attr
 }
 
 def CallSiteLoc : DialectAttribute<(attr
-  LocationAttr:$callee,
-  LocationAttr:$caller
+  MaybeDeferredLocationAttr:$callee,
+  MaybeDeferredLocationAttr:$caller
 )>;
 
 let cType = "FileLineColRange" in {
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 1659437e1eb24..9c6de9ed6ebec 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -27,6 +27,7 @@
 
 #include <cstddef>
 #include <cstdint>
+#include <deque>
 #include <list>
 #include <memory>
 #include <numeric>
@@ -925,7 +926,7 @@ class AttrTypeReader {
   /// bytecode format.
   template <typename T>
   LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
-                                 StringRef entryType);
+                                 StringRef entryType, uint64_t index);
 
   /// The string section reader used to resolve string references when parsing
   /// custom encoded attribute/type entries.
@@ -951,6 +952,28 @@ class AttrTypeReader {
 
   /// Reference to the parser configuration.
   const ParserConfig &parserConfig;
+
+  /// Worklist for deferred attribute/type parsing. This is used to
+  /// handle deeply nested structures like CallSiteLoc iteratively.
+  std::vector<uint64_t> deferredWorklist;
+
+public:
+  /// Get the attribute at the given index, returning null if not resolved.
+  Attribute getAttributeOrSentinel(size_t index) {
+    if (index >= attributes.size())
+      return {};
+    return attributes[index].entry;
+  }
+
+  /// Get the type at the given index, returning null if not resolved.
+  Type getTypeOrSentinel(size_t index) {
+    if (index >= types.size())
+      return {};
+    return types[index].entry;
+  }
+
+  /// Add an index to the deferred worklist for re-parsing.
+  void addDeferredParsing(uint64_t index) { deferredWorklist.push_back(index); }
 };
 
 class DialectReader : public DialectBytecodeReader {
@@ -959,10 +982,12 @@ class DialectReader : public DialectBytecodeReader {
                 const StringSectionReader &stringReader,
                 const ResourceSectionReader &resourceReader,
                 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
-                EncodingReader &reader, uint64_t &bytecodeVersion)
+                EncodingReader &reader, uint64_t &bytecodeVersion,
+                uint64_t currentIndex = 0)
       : attrTypeReader(attrTypeReader), stringReader(stringReader),
         resourceReader(resourceReader), dialectsMap(dialectsMap),
-        reader(reader), bytecodeVersion(bytecodeVersion) {}
+        reader(reader), bytecodeVersion(bytecodeVersion),
+        currentIndex(currentIndex) {}
 
   InFlightDiagnostic emitError(const Twine &msg) const override {
     return reader.emitError(msg);
@@ -989,7 +1014,7 @@ class DialectReader : public DialectBytecodeReader {
 
   DialectReader withEncodingReader(EncodingReader &encReader) const {
     return DialectReader(attrTypeReader, stringReader, resourceReader,
-                         dialectsMap, encReader, bytecodeVersion);
+                         dialectsMap, encReader, bytecodeVersion, currentIndex);
   }
 
   Location getLoc() const { return reader.getLoc(); }
@@ -1004,9 +1029,27 @@ class DialectReader : public DialectBytecodeReader {
   LogicalResult readOptionalAttribute(Attribute &result) override {
     return attrTypeReader.parseOptionalAttribute(reader, result);
   }
+  Attribute getOrDeferParsingAttribute() override {
+    uint64_t index;
+    if (failed(reader.parseVarInt(index)))
+      return nullptr;
+    Attribute attr = attrTypeReader.getAttributeOrSentinel(index);
+    if (!attr)
+      attrTypeReader.addDeferredParsing(index);
+    return attr;
+  }
   LogicalResult readType(Type &result) override {
     return attrTypeReader.parseType(reader, result);
   }
+  Type getOrDeferParsingType() override {
+    uint64_t index;
+    if (failed(reader.parseVarInt(index)))
+      return nullptr;
+    Type type = attrTypeReader.getTypeOrSentinel(index);
+    if (!type)
+      attrTypeReader.addDeferredParsing(index);
+    return type;
+  }
 
   FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
     AsmDialectResourceHandle handle;
@@ -1095,6 +1138,7 @@ class DialectReader : public DialectBytecodeReader {
   const llvm::StringMap<BytecodeDialect *> &dialectsMap;
   EncodingReader &reader;
   uint64_t &bytecodeVersion;
+  uint64_t currentIndex;
 };
 
 /// Wraps the properties section and handles reading properties out of it.
@@ -1245,27 +1289,74 @@ T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
     return {};
   }
 
-  // If the entry has already been resolved, there is nothing left to do.
-  Entry<T> &entry = entries[index];
-  if (entry.entry)
-    return entry.entry;
+  // Use a deque to iteratively resolve entries with dependencies.
+  // - Pop from front to process
+  // - Push new dependencies to front (depth-first)
+  // - Move failed entries to back (retry after dependencies)
+  std::deque<size_t> worklist;
+  llvm::DenseSet<size_t> inWorklist;
+  worklist.push_back(index);
+  inWorklist.insert(index);
 
-  // Parse the entry.
-  EncodingReader reader(entry.data, fileLoc);
+  while (!worklist.empty()) {
+    size_t currentIndex = worklist.front();
+    worklist.pop_front();
+
+    if (currentIndex >= entries.size()) {
+      emitError(fileLoc) << "invalid " << entryType
+                         << " index: " << currentIndex;
+      return {};
+    }
 
-  // Parse based on how the entry was encoded.
-  if (entry.hasCustomEncoding) {
-    if (failed(parseCustomEntry(entry, reader, entryType)))
+    Entry<T> &entry = entries[currentIndex];
+
+    // If already resolved, continue.
+    if (entry.entry) {
+      inWorklist.erase(currentIndex);
+      continue;
+    }
+
+    // Clear the deferred worklist before parsing to capture any new entries.
+    deferredWorklist.clear();
+
+    // Parse the entry.
+    EncodingReader reader(entry.data, fileLoc);
+
+    // Parse based on how the entry was encoded.
+    LogicalResult parsed =
+        entry.hasCustomEncoding
+            ? parseCustomEntry(entry, reader, entryType, currentIndex)
+            : parseAsmEntry(entry.entry, reader, entryType);
+    bool parseSucceeded = succeeded(parsed);
+
+    if (parseSucceeded && !reader.empty()) {
+      reader.emitError("unexpected trailing bytes after " + entryType +
+                       " entry");
+      parseSucceeded = false;
+    }
+
+    if (parseSucceeded && entry.entry) {
+      // Successfully parsed, done with this entry.
+      inWorklist.erase(currentIndex);
+    } else if (!deferredWorklist.empty()) {
+      // Check if deferred parsing was requested.
+
+      // Move this entry to the back to retry after dependencies.
+      worklist.push_back(currentIndex);
+
+      // Add dependencies to the front (in reverse so they maintain order).
+      for (uint64_t idx : llvm::reverse(deferredWorklist)) {
+        if (inWorklist.insert(idx).second)
+          worklist.push_front(idx);
+      }
+      deferredWorklist.clear();
+    } else {
+      // Parsing failed with no deferred entries which implies an error.
       return T();
-  } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) {
-    return T();
+    }
   }
 
-  if (!reader.empty()) {
-    reader.emitError("unexpected trailing bytes after " + entryType + " entry");
-    return T();
-  }
-  return entry.entry;
+  return entries[index].entry;
 }
 
 template <typename T>
@@ -1296,11 +1387,11 @@ LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
 }
 
 template <typename T>
-LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
-                                               EncodingReader &reader,
-                                               StringRef entryType) {
+LogicalResult
+AttrTypeReader::parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
+                                 StringRef entryType, uint64_t index) {
   DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
-                              reader, bytecodeVersion);
+                              reader, bytecodeVersion, index);
   if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
     return failure();
 
diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp
index d7b442f6832d0..d5c6f010f5b8a 100644
--- a/mlir/unittests/Bytecode/BytecodeTest.cpp
+++ b/mlir/unittests/Bytecode/BytecodeTest.cpp
@@ -228,3 +228,40 @@ TEST(Bytecode, OpWithoutProperties) {
   EXPECT_TRUE(OperationEquivalence::computeHash(op.get()) ==
               OperationEquivalence::computeHash(roundtripped));
 }
+
+TEST(Bytecode, DeepCallSiteLoc) {
+  MLIRContext context;
+  ParserConfig config(&context);
+
+  // Create a deep CallSiteLoc chain to test iterative parsing.
+  // Use a depth that fits in the stack for writing but is still substantial.
+  Location baseLoc = FileLineColLoc::get(&context, "test.mlir", 1, 1);
+  Location loc = baseLoc;
+  constexpr int kDepth = 1000;
+  for (int i = 0; i < kDepth; ++i) {
+    loc = CallSiteLoc::get(loc, baseLoc);
+  }
+
+  // Create a simple module with the deep location.
+  OwningOpRef<Operation *> module =
+      parseSourceString<Operation *>("module {}", config);
+  ASSERT_TRUE(module);
+  module.get()->setLoc(loc);
+
+  // Write to bytecode.
+  std::string bytecode;
+  llvm::raw_string_ostream os(bytecode);
+  ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), os)));
+
+  // Parse it back using the bytecode reader.
+  std::unique_ptr<Block> block = std::make_unique<Block>();
+  ASSERT_TRUE(succeeded(readBytecodeFile(
+      llvm::MemoryBufferRef(bytecode, "string-buffer"), block.get(), config)));
+
+  // Verify we got the roundtripped module.
+  ASSERT_FALSE(block->empty());
+  Operation *roundTripped = &block->front();
+
+  // Verify the location matches.
+  EXPECT_EQ(module.get()->getLoc(), roundTripped->getLoc());
+}

github-actions bot commented Dec 6, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

github-actions bot commented Dec 8, 2025

🐧 Linux x64 Test Results

  • 7204 tests passed
  • 598 tests skipped

✅ The build succeeded and all tests passed.

void addDeferredParsing(uint64_t index) { deferredWorklist.push_back(index); }

private:
/// Resolve the given entry at `index`.

Can you document "depth" here and elsewhere?

return entry.entry;
// Fast path: Try direct parsing without worklist overhead.
// This handles the common case where there are no deferred dependencies.
deferredWorklist.clear();

Shouldn't this be an `assert(deferredWorklist.empty());` here?

@joker-eph joker-eph left a comment

Seems fine overall, but this could use some more documentation, as it is all non-trivial.
Is there any place where you can write up the high-level process / logic?

jpienaar commented Dec 9, 2025

Seems fine overall, but could deserve some more documentation as this is all non-trivial. Any place you can write the high-level process / logic?

I've been thinking about it ... it's not really related to the file format, so the bytecode doc is not a good place. It is rather opaque to the user (it just happens under the hood and nobody should ever need to know about it), so perhaps the reader file is the best place, possibly as part of the depth documentation.

Add ability to defer parsing and re-enqueueing oneself. This enables
changing CallSiteLoc parsing to not recurse as deeply: previously this
could fail (especially on large inputs in debug mode). Chose an
arbitrary depth for now.
@jpienaar jpienaar enabled auto-merge (squash) December 9, 2025 18:14
@jpienaar jpienaar merged commit 93d2ef1 into llvm:main Dec 9, 2025
8 of 9 checks passed