diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h index e2014954e9fb11..659cd712655927 100644 --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -147,6 +147,13 @@ class MLIRContext { /// this call in this case. void setThreadPool(llvm::ThreadPool &pool); + /// Return the number of threads used by the thread pool in this context. The + /// number of computed hardware threads can change over the lifetime of a + /// process based on affinity changes, so users should use the number of + /// threads actually in the thread pool for dispatching work. Returns 1 if + /// multithreading is disabled. + unsigned getNumThreads(); + /// Return the thread pool used by this context. This method requires that /// multithreading be enabled within the context, and should generally not be /// used directly. Users should instead prefer the threading utilities within diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 96fe76eaac28da..7e811316c4e6f0 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -518,6 +518,16 @@ void MLIRContext::setThreadPool(llvm::ThreadPool &pool) { enableMultithreading(); } +unsigned MLIRContext::getNumThreads() { + if (isMultithreadingEnabled()) { + assert(impl->threadPool && + "multi-threading is enabled but threadpool not set"); + return impl->threadPool->getThreadCount(); + } + // No multithreading or active thread pool. Return 1 thread. + return 1; +} + llvm::ThreadPool &MLIRContext::getThreadPool() { assert(isMultithreadingEnabled() && "expected multi-threading to be enabled within the context"); diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp index cbeb94980dccaa..f080b4d112ba5f 100644 --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -679,8 +679,7 @@ InlinerPass::optimizeSCCAsync(MutableArrayRef nodesToVisit, // Note: The number of pass managers here needs to remain constant // to prevent issues with pass instrumentations that rely on having the same // pass manager for the main thread. - llvm::ThreadPool &threadPool = ctx->getThreadPool(); - size_t numThreads = threadPool.getThreadCount(); + size_t numThreads = ctx->getNumThreads(); if (opPipelines.size() < numThreads) { // Reserve before resizing so that we can use a reference to the first // element. diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir index 98692853bf24ef..00a0445f5a88d9 100644 --- a/mlir/test/Transforms/inlining.mlir +++ b/mlir/test/Transforms/inlining.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt %s -inline='default-pipeline=''' | FileCheck %s +// RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline=''' | FileCheck %s // RUN: mlir-opt %s -inline='default-pipeline=''' -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s --check-prefix INLINE-LOC // RUN: mlir-opt %s -inline | FileCheck %s --check-prefix INLINE_SIMPLIFY