Skip to content

Commit

Permalink
[mlir] Fix a use after free when loading dependent dialects
Browse files Browse the repository at this point in the history
The way dependent dialects are implemented is by recursively calling
loadDialect in the constructor. This means we have to reload from the
dialect table because the constructor might have rehashed that table.

The steps for loading a dialect are
  1. Insert a nullptr into loadedDialects. This indicates the dialect is
     loading
  2. Call ctor(). This recursively loads dependent dialects
  3. Insert the new dialect into the table.

We had a conflict between steps 2 and 3 here. You have to be extremely
unlucky though as rehashing is rare and operator[] does no generation
checking on DenseMap. Changing that to an iterator would've uncovered
this issue immediately.
  • Loading branch information
d0k committed Apr 5, 2023
1 parent b7f9bb7 commit 74a8a1e
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions mlir/lib/IR/MLIRContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,9 +438,9 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
function_ref<std::unique_ptr<Dialect>()> ctor) {
auto &impl = getImpl();
// Get the correct insertion position sorted by namespace.
auto dialectIt = impl.loadedDialects.find(dialectNamespace);
auto dialectIt = impl.loadedDialects.try_emplace(dialectNamespace, nullptr);

if (dialectIt == impl.loadedDialects.end()) {
if (dialectIt.second) {
LLVM_DEBUG(llvm::dbgs()
<< "Load new dialect in Context " << dialectNamespace << "\n");
#ifndef NDEBUG
Expand All @@ -452,9 +452,11 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
"missing `dependentDialects` in a pass for example.");
#endif // NDEBUG
// loadedDialects entry is initialized to nullptr, indicating that the
// dialect is currently being loaded.
std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
dialect = ctor();
// dialect is currently being loaded. Re-lookup the address in
// loadedDialects because the table might have been rehashed by recursive
// dialect loading in ctor().
std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace] =
ctor();
assert(dialect && "dialect ctor failed");

// Refresh all the identifiers dialect field, this catches cases where a
Expand All @@ -473,15 +475,15 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
}

#ifndef NDEBUG
if (dialectIt->second == nullptr)
if (dialectIt.first->second == nullptr)
llvm::report_fatal_error(
"Loading (and getting) a dialect (" + dialectNamespace +
") while the same dialect is still loading: use loadDialect instead "
"of getOrLoadDialect.");
#endif // NDEBUG

// Abort if dialect with namespace has already been registered.
std::unique_ptr<Dialect> &dialect = dialectIt->second;
std::unique_ptr<Dialect> &dialect = dialectIt.first->second;
if (dialect->getTypeID() != dialectID)
llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
"' has already been registered");
Expand Down

0 comments on commit 74a8a1e

Please sign in to comment.