Skip to content

Commit

Permalink
Make MLIR Pass Timing output configurable through injection
Browse files Browse the repository at this point in the history
This makes it possible for the client to control where the pass timings will
be printed.

Differential Revision: https://reviews.llvm.org/D78891
  • Loading branch information
joker-eph committed Apr 28, 2020
1 parent cd84bfb commit f65a3f7
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 28 deletions.
28 changes: 27 additions & 1 deletion mlir/include/mlir/Pass/PassManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,38 @@ class PassManager : public OpPassManager {
//===--------------------------------------------------------------------===//
// Pass Timing

/// A configuration struct provided to the pass timing feature.
class PassTimingConfig {
public:
using PrintCallbackFn = function_ref<void(raw_ostream &)>;

/// Initialize the configuration.
/// * 'displayMode' switch between list or pipeline display (see the
/// `PassDisplayMode` enum documentation).
explicit PassTimingConfig(
PassDisplayMode displayMode = PassDisplayMode::Pipeline)
: displayMode(displayMode) {}

virtual ~PassTimingConfig();

/// A hook that may be overridden by a derived config to control the
/// printing. The callback is supplied by the framework and the config is
/// responsible to call it back with a stream for the output.
virtual void printTiming(PrintCallbackFn printCallback);

/// Return the `PassDisplayMode` this config was created with.
PassDisplayMode getDisplayMode() { return displayMode; }

private:
PassDisplayMode displayMode;
};

/// Add an instrumentation to time the execution of passes and the computation
/// of analyses.
/// Note: Timing should be enabled after all other instrumentations to avoid
/// any potential "ghost" timing from other instrumentations being
/// unintentionally included in the timing results.
void enableTiming(PassDisplayMode displayMode = PassDisplayMode::Pipeline);
void enableTiming(std::unique_ptr<PassTimingConfig> config = nullptr);

/// Prompts the pass manager to print the statistics collected for each of the
/// held passes after each call to 'run'.
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Pass/PassManagerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ void PassManagerOptions::addPrinterInstrumentation(PassManager &pm) {
/// Add a pass timing instrumentation if enabled by 'pass-timing' flags.
void PassManagerOptions::addTimingInstrumentation(PassManager &pm) {
if (passTiming)
pm.enableTiming(passTimingDisplayMode);
pm.enableTiming(
std::make_unique<PassManager::PassTimingConfig>(passTimingDisplayMode));
}

void mlir::registerPassManagerCLOptions() {
Expand Down
67 changes: 41 additions & 26 deletions mlir/lib/Pass/PassTiming.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ struct Timer {
};

struct PassTiming : public PassInstrumentation {
PassTiming(PassDisplayMode displayMode) : displayMode(displayMode) {}
PassTiming(std::unique_ptr<PassManager::PassTimingConfig> config)
: config(std::move(config)) {}
~PassTiming() override { print(); }

/// Setup the instrumentation hooks.
Expand Down Expand Up @@ -231,8 +232,8 @@ struct PassTiming : public PassInstrumentation {
/// A stack of the currently active pass timers per thread.
DenseMap<uint64_t, SmallVector<Timer *, 4>> activeThreadTimers;

/// The display mode to use when printing the timing results.
PassDisplayMode displayMode;
/// The configuration object to use when printing the timing results.
std::unique_ptr<PassManager::PassTimingConfig> config;

/// A mapping of pipeline timers that need to be merged into the parent
/// collection. The timers are mapped to the parent info to merge into.
Expand Down Expand Up @@ -353,28 +354,37 @@ void PassTiming::print() {
return;

assert(rootTimers.size() == 1 && "expected one remaining root timer");
auto &rootTimer = rootTimers.begin()->second;
auto os = llvm::CreateInfoOutputFile();

// Print the timer header.
TimeRecord totalTime = rootTimer->getTotalTime();
printTimerHeader(*os, totalTime);

// Defer to a specialized printer for each display mode.
switch (displayMode) {
case PassDisplayMode::List:
printResultsAsList(*os, rootTimer.get(), totalTime);
break;
case PassDisplayMode::Pipeline:
printResultsAsPipeline(*os, rootTimer.get(), totalTime);
break;
}
printTimeEntry(*os, 0, "Total", totalTime, totalTime);
os->flush();

// Reset root timers.
rootTimers.clear();
activeThreadTimers.clear();
auto printCallback = [&](raw_ostream &os) {
auto &rootTimer = rootTimers.begin()->second;
// Print the timer header.
TimeRecord totalTime = rootTimer->getTotalTime();
printTimerHeader(os, totalTime);
// Defer to a specialized printer for each display mode.
switch (config->getDisplayMode()) {
case PassDisplayMode::List:
printResultsAsList(os, rootTimer.get(), totalTime);
break;
case PassDisplayMode::Pipeline:
printResultsAsPipeline(os, rootTimer.get(), totalTime);
break;
}
printTimeEntry(os, 0, "Total", totalTime, totalTime);
os.flush();

// Reset root timers.
rootTimers.clear();
activeThreadTimers.clear();
};

config->printTiming(printCallback);
}

// The default implementation for printTiming uses
// `llvm::CreateInfoOutputFile()` as stream, it can be overridden by clients
// to customize the output.
void PassManager::PassTimingConfig::printTiming(PrintCallbackFn printCallback) {
printCallback(*llvm::CreateInfoOutputFile());
}

/// Print the timing result in list mode.
Expand Down Expand Up @@ -449,16 +459,21 @@ void PassTiming::printResultsAsPipeline(raw_ostream &os, Timer *root,
printTimer(0, topLevelTimer.second.get());
}

// Out-of-line as key function.
PassManager::PassTimingConfig::~PassTimingConfig() {}

//===----------------------------------------------------------------------===//
// PassManager
//===----------------------------------------------------------------------===//

/// Add an instrumentation to time the execution of passes and the computation
/// of analyses.
void PassManager::enableTiming(PassDisplayMode displayMode) {
void PassManager::enableTiming(std::unique_ptr<PassTimingConfig> config) {
// Check if pass timing is already enabled.
if (passTiming)
return;
addInstrumentation(std::make_unique<PassTiming>(displayMode));
if (!config)
config = std::make_unique<PassManager::PassTimingConfig>();
addInstrumentation(std::make_unique<PassTiming>(std::move(config)));
passTiming = true;
}

0 comments on commit f65a3f7

Please sign in to comment.