Skip to content

Commit 0f304ef

Browse files
committed
[mlir] Add asserts when changing various MLIRContext configurations
This helps to prevent tsan failures when users inadvertantly mutate the context in a non-safe way. Differential Revision: https://reviews.llvm.org/D112021
1 parent 9d9eddd commit 0f304ef

File tree

5 files changed

+44
-8
lines changed

5 files changed

+44
-8
lines changed

mlir/include/mlir/IR/DialectRegistry.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,10 @@ class DialectRegistry {
212212
addExtension(std::make_unique<Extension>(std::move(extensionFn)));
213213
}
214214

215+
/// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
216+
/// contains all of the components of this registry.
217+
bool isSubsetOf(const DialectRegistry &rhs) const;
218+
215219
private:
216220
MapTy registry;
217221
std::vector<std::unique_ptr<DialectExtensionBase>> extensions;

mlir/lib/IR/Dialect.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,12 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
228228
for (const auto &extension : extensions)
229229
applyExtension(*extension);
230230
}
231+
232+
bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
233+
// Treat any extensions conservatively.
234+
if (!extensions.empty())
235+
return false;
236+
// Check that the current dialects fully overlap with the dialects in 'rhs'.
237+
return llvm::all_of(
238+
registry, [&](const auto &it) { return rhs.registry.count(it.first); });
239+
}

mlir/lib/IR/MLIRContext.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,12 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
355355
//===----------------------------------------------------------------------===//
356356

357357
void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
358+
if (registry.isSubsetOf(impl->dialectsRegistry))
359+
return;
360+
361+
assert(impl->multiThreadedExecutionContext == 0 &&
362+
"appending to the MLIRContext dialect registry while in a "
363+
"multi-threaded execution context");
358364
registry.appendTo(impl->dialectsRegistry);
359365

360366
// For the already loaded dialects, apply any possible extensions immediately.
@@ -470,6 +476,9 @@ bool MLIRContext::allowsUnregisteredDialects() {
470476
}
471477

472478
void MLIRContext::allowUnregisteredDialects(bool allowing) {
479+
assert(impl->multiThreadedExecutionContext == 0 &&
480+
"changing MLIRContext `allow-unregistered-dialects` configuration "
481+
"while in a multi-threaded execution context");
473482
impl->allowUnregisteredDialects = allowing;
474483
}
475484

@@ -484,6 +493,9 @@ void MLIRContext::disableMultithreading(bool disable) {
484493
// --mlir-disable-threading
485494
if (isThreadingGloballyDisabled())
486495
return;
496+
assert(impl->multiThreadedExecutionContext == 0 &&
497+
"changing MLIRContext `disable-threading` configuration while "
498+
"in a multi-threaded execution context");
487499

488500
impl->threadingIsEnabled = !disable;
489501

@@ -557,6 +569,9 @@ bool MLIRContext::shouldPrintOpOnDiagnostic() {
557569
/// Set the flag specifying if we should attach the operation to diagnostics
558570
/// emitted via Operation::emit.
559571
void MLIRContext::printOpOnDiagnostic(bool enable) {
572+
assert(impl->multiThreadedExecutionContext == 0 &&
573+
"changing MLIRContext `print-op-on-diagnostic` configuration while in "
574+
"a multi-threaded execution context");
560575
impl->printOpOnDiagnostic = enable;
561576
}
562577

@@ -569,6 +584,9 @@ bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
569584
/// Set the flag specifying if we should attach the current stacktrace when
570585
/// emitting diagnostics.
571586
void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
587+
assert(impl->multiThreadedExecutionContext == 0 &&
588+
"changing MLIRContext `print-stacktrace-on-diagnostic` configuration "
589+
"while in a multi-threaded execution context");
572590
impl->printStackTraceOnDiagnostic = enable;
573591
}
574592

mlir/lib/Reducer/OptReductionPass.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ void OptReductionPass::runOnOperation() {
4242
ModuleOp module = this->getOperation();
4343
ModuleOp moduleVariant = module.clone();
4444

45-
PassManager passManager(module.getContext());
45+
OpPassManager passManager("builtin.module");
4646
if (failed(parsePassPipeline(optPass, passManager))) {
4747
module.emitError() << "\nfailed to parse pass pipeline";
4848
return signalPassFailure();
@@ -54,7 +54,13 @@ void OptReductionPass::runOnOperation() {
5454
return signalPassFailure();
5555
}
5656

57-
if (failed(passManager.run(moduleVariant))) {
57+
// Temporarily push the variant under the main module and execute the pipeline
58+
// on it.
59+
module.getBody()->push_back(moduleVariant);
60+
LogicalResult pipelineResult = runPipeline(passManager, moduleVariant);
61+
moduleVariant->remove();
62+
63+
if (failed(pipelineResult)) {
5864
module.emitError() << "\nfailed to run pass pipeline";
5965
return signalPassFailure();
6066
}

mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -255,14 +255,13 @@ struct TestLinalgGreedyFusion
255255
patterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
256256
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
257257
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
258+
OpPassManager pm(FuncOp::getOperationName());
259+
pm.addPass(createLoopInvariantCodeMotionPass());
260+
pm.addPass(createCanonicalizerPass());
261+
pm.addPass(createCSEPass());
258262
do {
259263
(void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
260-
PassManager pm(context);
261-
pm.addPass(createLoopInvariantCodeMotionPass());
262-
pm.addPass(createCanonicalizerPass());
263-
pm.addPass(createCSEPass());
264-
LogicalResult res = pm.run(getOperation()->getParentOfType<ModuleOp>());
265-
if (failed(res))
264+
if (failed(runPipeline(pm, getOperation())))
266265
this->signalPassFailure();
267266
} while (succeeded(fuseLinalgOpsGreedily(getOperation())));
268267
}

0 commit comments

Comments
 (0)