Skip to content

Commit

Permalink
[mlir] Add a ThreadPool to MLIRContext and refactor MLIR threading usage
Browse files Browse the repository at this point in the history
This revision refactors the usage of multithreaded utilities in MLIR to use a common
thread pool within the MLIR context, in addition to a new utility that makes writing
multi-threaded code in MLIR less error prone. Using a unified thread pool brings about
several advantages:

* Better thread usage and more control
We currently use the static llvm threading utilities, which do not allow multiple
levels of asynchronous scheduling (even if there are open threads). This is due to
how the current TaskGroup structure works, which only allows one truly multithreaded
instance at a time. By having our own ThreadPool we gain more control and flexibility
over our job/thread scheduling, and in a followup can enable threading more parts of
the compiler.

* The static nature of TaskGroup causes issues in certain configurations
Due to the static nature of TaskGroup, there have been quite a few problems related to
destruction that have caused several downstream projects to disable threading. See
D104207 for discussion on some related fallout. By having a ThreadPool scoped to
the context, we don't have to worry about destruction and can ensure that any
additional MLIR thread usage ends when the context is destroyed.

Differential Revision: https://reviews.llvm.org/D104516
  • Loading branch information
River707 committed Jun 23, 2021
1 parent 18465bc commit 6569cf2
Show file tree
Hide file tree
Showing 13 changed files with 299 additions and 143 deletions.
3 changes: 3 additions & 0 deletions llvm/include/llvm/Support/ThreadPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ class ThreadPool {

unsigned getThreadCount() const { return ThreadCount; }

/// Returns true if the current thread is a worker thread of this thread pool.
bool isWorkerThread() const;

private:
bool workCompletedUnlocked() { return !ActiveThreads && Tasks.empty(); }

Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Support/ThreadPool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ void ThreadPool::wait() {
CompletionCondition.wait(LockGuard, [&] { return workCompletedUnlocked(); });
}

bool ThreadPool::isWorkerThread() const {
std::thread::id CurrentThreadId = std::this_thread::get_id();
for (const std::thread &Thread : Threads)
if (CurrentThreadId == Thread.get_id())
return true;
return false;
}

std::shared_future<void> ThreadPool::asyncImpl(TaskTy Task) {
/// Wrap the Task in a packaged_task to return a future object.
PackagedTaskTy PackagedTask(std::move(Task));
Expand Down
10 changes: 10 additions & 0 deletions mlir/include/mlir/IR/MLIRContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
#include <memory>
#include <vector>

namespace llvm {
class ThreadPool;
} // end namespace llvm

namespace mlir {
class AbstractOperation;
class DebugActionManager;
Expand Down Expand Up @@ -114,6 +118,12 @@ class MLIRContext {
disableMultithreading(!enable);
}

/// Return the thread pool owned by this context. This method requires that
/// multithreading be enabled within the context, and should generally not be
/// used directly. Users should instead prefer the threading utilities within
/// Threading.h.
llvm::ThreadPool &getThreadPool();

/// Return true if we should attach the operation to diagnostics emitted via
/// Operation::emit.
bool shouldPrintOpOnDiagnostic();
Expand Down
153 changes: 153 additions & 0 deletions mlir/include/mlir/IR/Threading.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
//===- Threading.h - MLIR Threading Utilities -------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines various utilies for multithreaded processing within MLIR.
// These utilities automatically handle many of the necessary threading
// conditions, such as properly ordering diagnostics, observing if threading is
// disabled, etc. These utilities should be used over other threading utilities
// whenever feasible.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_THREADING_H
#define MLIR_IR_THREADING_H

#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/Support/ThreadPool.h"
#include <atomic>

namespace mlir {

/// Invoke the given function on the elements between [begin, end)
/// asynchronously. If the given function returns a failure when processing any
/// of the elements, execution is stopped and a failure is returned from this
/// function. This means that in the case of failure, not all elements of the
/// range will be processed. Diagnostics emitted during processing are ordered
/// relative to the element's position within [begin, end). If the provided
/// context does not have multi-threading enabled, this function always
/// processes elements sequentially.
template <typename IteratorT, typename FuncT>
LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
IteratorT end, FuncT &&func) {
unsigned numElements = static_cast<unsigned>(std::distance(begin, end));
if (numElements == 0)
return success();

// If multithreading is disabled or there is a small number of elements,
// process the elements directly on this thread.
// FIXME: ThreadPool should allow work stealing to avoid deadlocks when
// scheduling work within a worker thread.
if (!context->isMultithreadingEnabled() || numElements <= 1 ||
context->getThreadPool().isWorkerThread()) {
for (; begin != end; ++begin)
if (failed(func(*begin)))
return failure();
return success();
}

// Build a wrapper processing function that properly initializes a parallel
// diagnostic handler.
ParallelDiagnosticHandler handler(context);
std::atomic<unsigned> curIndex(0);
std::atomic<bool> processingFailed(false);
auto processFn = [&] {
while (!processingFailed) {
unsigned index = curIndex++;
if (index >= numElements)
break;
handler.setOrderIDForThread(index);
if (failed(func(*std::next(begin, index))))
processingFailed = true;
handler.eraseOrderIDForThread();
}
};

// Otherwise, process the elements in parallel.
llvm::ThreadPool &threadPool = context->getThreadPool();
size_t numActions = std::min(numElements, threadPool.getThreadCount());
SmallVector<std::shared_future<void>> threadFutures;
threadFutures.reserve(numActions - 1);
for (unsigned i = 1; i < numActions; ++i)
threadFutures.emplace_back(threadPool.async(processFn));
processFn();

// Wait for all of the threads to finish.
for (std::shared_future<void> &future : threadFutures)
future.wait();
return failure(processingFailed);
}

/// Invoke the given function on the elements in the provided range
/// asynchronously. If the given function returns a failure when processing any
/// of the elements, execution is stopped and a failure is returned from this
/// function. This means that in the case of failure, not all elements of the
/// range will be processed. Diagnostics emitted during processing are ordered
/// relative to the element's position within the range. If the provided context
/// does not have multi-threading enabled, this function always processes
/// elements sequentially.
template <typename RangeT, typename FuncT>
LogicalResult failableParallelForEach(MLIRContext *context, RangeT &&range,
FuncT &&func) {
return failableParallelForEach(context, std::begin(range), std::end(range),
std::forward<FuncT>(func));
}

/// Invoke the given function on the elements between [begin, end)
/// asynchronously. If the given function returns a failure when processing any
/// of the elements, execution is stopped and a failure is returned from this
/// function. This means that in the case of failure, not all elements of the
/// range will be processed. Diagnostics emitted during processing are ordered
/// relative to the element's position within [begin, end). If the provided
/// context does not have multi-threading enabled, this function always
/// processes elements sequentially.
template <typename FuncT>
LogicalResult failableParallelForEachN(MLIRContext *context, size_t begin,
size_t end, FuncT &&func) {
return failableParallelForEach(context, llvm::seq(begin, end),
std::forward<FuncT>(func));
}

/// Invoke the given function on the elements between [begin, end)
/// asynchronously. Diagnostics emitted during processing are ordered relative
/// to the element's position within [begin, end). If the provided context does
/// not have multi-threading enabled, this function always processes elements
/// sequentially.
template <typename IteratorT, typename FuncT>
void parallelForEach(MLIRContext *context, IteratorT begin, IteratorT end,
FuncT &&func) {
(void)failableParallelForEach(context, begin, end, [&](auto &&value) {
return func(std::forward<decltype(value)>(value)), success();
});
}

/// Invoke the given function on the elements in the provided range
/// asynchronously. Diagnostics emitted during processing are ordered relative
/// to the element's position within the range. If the provided context does not
/// have multi-threading enabled, this function always processes elements
/// sequentially.
template <typename RangeT, typename FuncT>
void parallelForEach(MLIRContext *context, RangeT &&range, FuncT &&func) {
parallelForEach(context, std::begin(range), std::end(range),
std::forward<FuncT>(func));
}

/// Invoke the given function on the elements between [begin, end)
/// asynchronously. Diagnostics emitted during processing are ordered relative
/// to the element's position within [begin, end). If the provided context does
/// not have multi-threading enabled, this function always processes elements
/// sequentially.
template <typename FuncT>
void parallelForEachN(MLIRContext *context, size_t begin, size_t end,
FuncT &&func) {
parallelForEach(context, llvm::seq(begin, end), std::forward<FuncT>(func));
}

} // end namespace mlir

#endif // MLIR_IR_THREADING_H
10 changes: 10 additions & 0 deletions mlir/lib/IR/MLIRContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/RWMutex.h"
#include "llvm/Support/ThreadPool.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>

Expand Down Expand Up @@ -260,6 +261,9 @@ class MLIRContextImpl {
// Other
//===--------------------------------------------------------------------===//

/// The thread pool to use when processing MLIR tasks in parallel.
llvm::ThreadPool threadPool;

/// This is a list of dialects that are created referring to this context.
/// The MLIRContext owns the objects.
DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
Expand Down Expand Up @@ -571,6 +575,12 @@ void MLIRContext::disableMultithreading(bool disable) {
impl->typeUniquer.disableMultithreading(disable);
}

llvm::ThreadPool &MLIRContext::getThreadPool() {
assert(isMultithreadingEnabled() &&
"expected multi-threading to be enabled within the context");
return impl->threadPool;
}

void MLIRContext::enterMultiThreadedExecution() {
#ifndef NDEBUG
++impl->multiThreadedExecutionContext;
Expand Down
36 changes: 5 additions & 31 deletions mlir/lib/IR/Verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/IR/Threading.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Parallel.h"
Expand All @@ -43,11 +44,6 @@ namespace {
/// This class encapsulates all the state used to verify an operation region.
class OperationVerifier {
public:
explicit OperationVerifier(MLIRContext *context)
// TODO: Re-enable parallelism once deadlocks found in D104207 are
// resolved.
: parallelismEnabled(false) {}

/// Verify the given operation.
LogicalResult verifyOpAndDominance(Operation &op);

Expand All @@ -66,9 +62,6 @@ class OperationVerifier {
/// Operation.
LogicalResult verifyDominanceOfContainedRegions(Operation &op,
DominanceInfo &domInfo);

/// This is true if parallelism is enabled on the MLIRContext.
const bool parallelismEnabled;
};
} // end anonymous namespace

Expand All @@ -91,28 +84,9 @@ LogicalResult OperationVerifier::verifyOpAndDominance(Operation &op) {

// Check the dominance properties and invariants of any operations in the
// regions contained by the 'opsWithIsolatedRegions' operations.
if (!parallelismEnabled || opsWithIsolatedRegions.size() <= 1) {
// If parallelism is disabled or if there is only 0/1 operation to do, use
// a simple non-parallel loop.
for (Operation *op : opsWithIsolatedRegions) {
if (failed(verifyOpAndDominance(*op)))
return failure();
}
} else {
// Otherwise, verify the operations and their bodies in parallel.
ParallelDiagnosticHandler handler(op.getContext());
std::atomic<bool> passFailed(false);
llvm::parallelForEachN(0, opsWithIsolatedRegions.size(), [&](size_t opIdx) {
handler.setOrderIDForThread(opIdx);
if (failed(verifyOpAndDominance(*opsWithIsolatedRegions[opIdx])))
passFailed = true;
handler.eraseOrderIDForThread();
});
if (passFailed)
return failure();
}

return success();
return failableParallelForEach(
op.getContext(), opsWithIsolatedRegions,
[&](Operation *op) { return verifyOpAndDominance(*op); });
}

/// Returns true if this block may be valid without terminator. That is if:
Expand Down Expand Up @@ -378,5 +352,5 @@ OperationVerifier::verifyDominanceOfContainedRegions(Operation &op,
/// compiler bugs. On error, this reports the error through the MLIRContext and
/// returns failure.
LogicalResult mlir::verify(Operation *op) {
return OperationVerifier(op->getContext()).verifyOpAndDominance(*op);
return OperationVerifier().verifyOpAndDominance(*op);
}
74 changes: 27 additions & 47 deletions mlir/lib/Pass/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "PassDetail.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Threading.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Support/FileUtilities.h"
#include "llvm/ADT/STLExtras.h"
Expand Down Expand Up @@ -580,61 +581,40 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
}
}

// A parallel diagnostic handler that provides deterministic diagnostic
// ordering.
ParallelDiagnosticHandler diagHandler(&getContext());

// An index for the current operation/analysis manager pair.
std::atomic<unsigned> opIt(0);

// Get the current thread for this adaptor.
PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
this};
auto *instrumentor = am.getPassInstrumentor();

// An atomic failure variable for the async executors.
std::atomic<bool> passFailed(false);
llvm::parallelForEach(
asyncExecutors.begin(),
std::next(asyncExecutors.begin(),
std::min(asyncExecutors.size(), opAMPairs.size())),
[&](MutableArrayRef<OpPassManager> pms) {
for (auto e = opAMPairs.size(); !passFailed && opIt < e;) {
// Get the next available operation index.
unsigned nextID = opIt++;
if (nextID >= e)
break;

// Set the order id for this thread in the diagnostic handler.
diagHandler.setOrderIDForThread(nextID);

// Get the pass manager for this operation and execute it.
auto &it = opAMPairs[nextID];
auto *pm = findPassManagerFor(
pms, it.first->getName().getIdentifier(), getContext());
assert(pm && "expected valid pass manager for operation");

unsigned initGeneration = pm->impl->initializationGeneration;
LogicalResult pipelineResult =
runPipeline(pm->getPasses(), it.first, it.second, verifyPasses,
initGeneration, instrumentor, &parentInfo);

// Drop this thread from being tracked by the diagnostic handler.
// After this task has finished, the thread may be used outside of
// this pass manager context meaning that we don't want to track
// diagnostics from it anymore.
diagHandler.eraseOrderIDForThread();

// Handle a failed pipeline result.
if (failed(pipelineResult)) {
passFailed = true;
break;
}
}
});
std::vector<std::atomic<bool>> activePMs(asyncExecutors.size());
std::fill(activePMs.begin(), activePMs.end(), false);
auto processFn = [&](auto &opPMPair) {
// Find a pass manager for this operation.
auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
bool expectedInactive = false;
return isActive.compare_exchange_strong(expectedInactive, true);
});
unsigned pmIndex = it - activePMs.begin();

// Get the pass manager for this operation and execute it.
auto *pm = findPassManagerFor(asyncExecutors[pmIndex],
opPMPair.first->getName().getIdentifier(),
getContext());
assert(pm && "expected valid pass manager for operation");

unsigned initGeneration = pm->impl->initializationGeneration;
LogicalResult pipelineResult =
runPipeline(pm->getPasses(), opPMPair.first, opPMPair.second,
verifyPasses, initGeneration, instrumentor, &parentInfo);

// Reset the active bit for this pass manager.
activePMs[pmIndex].store(false);
return pipelineResult;
};

// Signal a failure if any of the executors failed.
if (passFailed)
if (failed(failableParallelForEach(&getContext(), opAMPairs, processFn)))
signalPassFailure();
}

Expand Down
Loading

0 comments on commit 6569cf2

Please sign in to comment.