Skip to content

Commit

Permalink
Add a flag to the IRPrinter instrumentation to only print after a pas…
Browse files Browse the repository at this point in the history
…s if there is a change to the IR.

This adds an additional filtering mode for printing after a pass that checks to see if the pass actually changed the IR before printing it. This "change" detection is implemented using a SHA1 hash of the current operation and its children.

PiperOrigin-RevId: 284291089
  • Loading branch information
River707 authored and tensorflower-gardener committed Dec 7, 2019
1 parent ca23bd7 commit 8904e91
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 17 deletions.
33 changes: 27 additions & 6 deletions mlir/g3doc/WritingAPass.md
Expand Up @@ -624,7 +624,7 @@ pipeline. This display mode is available in mlir-opt via
`-pass-timing-display=list`.
```shell
$ mlir-opt foo.mlir -disable-pass-threading -cse -canonicalize -convert-std-to-llvm -pass-timing -pass-timing-display=list
$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing -pass-timing-display=list
===-------------------------------------------------------------------------===
... Pass execution timing report ...
Expand All @@ -649,7 +649,7 @@ the most time, and can also be used to identify when analyses are being
invalidated and recomputed. This is the default display mode.

```shell
$ mlir-opt foo.mlir -disable-pass-threading -cse -canonicalize -convert-std-to-llvm -pass-timing
$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing

===-------------------------------------------------------------------------===
... Pass execution timing report ...
Expand Down Expand Up @@ -680,7 +680,7 @@ perceived time, or clock time, whereas the `User Time` will display the total
cpu time.

```shell
$ mlir-opt foo.mlir -cse -canonicalize -convert-std-to-llvm -pass-timing
$ mlir-opt foo.mlir -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing

===-------------------------------------------------------------------------===
... Pass execution timing report ...
Expand Down Expand Up @@ -716,7 +716,7 @@ this instrumentation:
* Print the IR before every pass in the pipeline.

```shell
$ mlir-opt foo.mlir -cse -print-ir-before=cse
$ mlir-opt foo.mlir -pass-pipeline='func(cse)' -print-ir-before=cse

*** IR Dump Before CSE ***
func @simple_constant() -> (i32, i32) {
Expand All @@ -732,7 +732,28 @@ func @simple_constant() -> (i32, i32) {
* Print the IR after every pass in the pipeline.

```shell
$ mlir-opt foo.mlir -cse -print-ir-after=cse
$ mlir-opt foo.mlir -pass-pipeline='func(cse)' -print-ir-after=cse

*** IR Dump After CSE ***
func @simple_constant() -> (i32, i32) {
%c1_i32 = constant 1 : i32
return %c1_i32, %c1_i32 : i32, i32
}
```

* `print-ir-after-change`
* Only print the IR after a pass if the pass mutated the IR. This helps to
reduce the number of IR dumps for "uninteresting" passes.
* Note: Changes are detected by comparing a hash of the operation before
and after the pass. This adds additional run-time to compute the hash of
the IR, and in some rare cases may result in false-positives depending
on the collision rate of the hash algorithm used.
* Note: This option should be used in unison with one of the other
'print-ir-after' options above, as this option alone does not enable
printing.

```shell
$ mlir-opt foo.mlir -pass-pipeline='func(cse,cse)' -print-ir-after=cse -print-ir-after-change

*** IR Dump After CSE ***
func @simple_constant() -> (i32, i32) {
Expand All @@ -748,7 +769,7 @@ func @simple_constant() -> (i32, i32) {
is disabled(`-disable-pass-threading`)

```shell
$ mlir-opt foo.mlir -disable-pass-threading -cse -print-ir-after=cse -print-ir-module-scope
$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse)' -print-ir-after=cse -print-ir-module-scope

*** IR Dump After CSE *** ('func' operation: @bar)
func @bar(%arg0: f32, %arg1: f32) -> f32 {
Expand Down
20 changes: 18 additions & 2 deletions mlir/include/mlir/Pass/PassManager.h
Expand Up @@ -172,7 +172,12 @@ class PassManager : public OpPassManager {
/// printed. This should only be set to true when multi-threading is
/// disabled, otherwise we may try to print IR that is being modified
/// asynchronously.
explicit IRPrinterConfig(bool printModuleScope = false);
/// * 'printAfterOnlyOnChange' signals that when printing the IR after a
/// pass, in the case of a non-failure, we should first check if any
/// potential mutations were made. This allows for reducing the number of
/// logs that don't contain meaningful changes.
explicit IRPrinterConfig(bool printModuleScope = false,
bool printAfterOnlyOnChange = false);
virtual ~IRPrinterConfig();

/// A hook that may be overridden by a derived config that checks if the IR
Expand All @@ -192,9 +197,17 @@ class PassManager : public OpPassManager {
/// Returns true if the IR should always be printed at the top-level scope.
bool shouldPrintAtModuleScope() const { return printModuleScope; }

/// Returns true if the IR should only printed after a pass if the IR
/// "changed".
bool shouldPrintAfterOnlyOnChange() const { return printAfterOnlyOnChange; }

private:
/// A flag that indicates if the IR should be printed at module scope.
bool printModuleScope;

/// A flag that indicates that the IR after a pass should only be printed if
/// a change is detected.
bool printAfterOnlyOnChange;
};

/// Add an instrumentation to print the IR before and after pass execution,
Expand All @@ -208,11 +221,14 @@ class PassManager : public OpPassManager {
/// return true if the IR should be printed or not.
/// * 'printModuleScope' signals if the module IR should be printed, even
/// for non module passes.
/// * 'printAfterOnlyOnChange' signals that when printing the IR after a
/// pass, in the case of a non-failure, we should first check if any
/// potential mutations were made.
/// * 'out' corresponds to the stream to output the printed IR to.
void enableIRPrinting(
std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
bool printModuleScope, raw_ostream &out);
bool printModuleScope, bool printAfterOnlyOnChange, raw_ostream &out);

//===--------------------------------------------------------------------===//
// Pass Timing
Expand Down
101 changes: 93 additions & 8 deletions mlir/lib/Pass/IRPrinting.cpp
Expand Up @@ -20,11 +20,70 @@
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SHA1.h"

using namespace mlir;
using namespace mlir::detail;

namespace {
//===----------------------------------------------------------------------===//
// OperationFingerPrint
//===----------------------------------------------------------------------===//

/// A unique fingerprint for a specific operation, and all of it's internal
/// operations.
class OperationFingerPrint {
public:
OperationFingerPrint(Operation *topOp) {
llvm::SHA1 hasher;

// Hash each of the operations based upon their mutable bits:
topOp->walk([&](Operation *op) {
// - Operation pointer
addDataToHash(hasher, op);
// - Attributes
addDataToHash(hasher,
op->getAttrList().getDictionary().getAsOpaquePointer());
// - Blocks in Regions
for (Region &region : op->getRegions()) {
for (Block &block : region) {
addDataToHash(hasher, &block);
for (BlockArgument *arg : block.getArguments())
addDataToHash(hasher, arg);
}
}
// - Location
addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
// - Operands
for (Value *operand : op->getOperands())
addDataToHash(hasher, operand);
// - Successors
for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
addDataToHash(hasher, op->getSuccessor(i));
});
hash = hasher.result();
}

bool operator==(const OperationFingerPrint &other) const {
return hash == other.hash;
}
bool operator!=(const OperationFingerPrint &other) const {
return !(*this == other);
}

private:
template <typename T> void addDataToHash(llvm::SHA1 &hasher, const T &data) {
hasher.update(
ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
}

SmallString<20> hash;
};

//===----------------------------------------------------------------------===//
// IRPrinter
//===----------------------------------------------------------------------===//

class IRPrinterInstrumentation : public PassInstrumentation {
public:
IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config)
Expand All @@ -38,6 +97,11 @@ class IRPrinterInstrumentation : public PassInstrumentation {

/// Configuration to use.
std::unique_ptr<PassManager::IRPrinterConfig> config;

/// The following is a set of fingerprints for operations that are currently
/// being operated on in a pass. This field is only used when the
/// configuration asked for change detection.
DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints;
};
} // end anonymous namespace

Expand Down Expand Up @@ -81,6 +145,10 @@ static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) {
if (isHiddenPass(pass))
return;
// If the config asked to detect changes, record the current fingerprint.
if (config->shouldPrintAfterOnlyOnChange())
beforePassFingerPrints.try_emplace(pass, op);

config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) {
out << formatv("*** IR Dump Before {0} ***", pass->getName());
printIR(op, config->shouldPrintAtModuleScope(), out, OpPrintingFlags());
Expand All @@ -91,6 +159,20 @@ void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) {
void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
if (isHiddenPass(pass))
return;
// If the config asked to detect changes, compare the current fingerprint with
// the previous.
if (config->shouldPrintAfterOnlyOnChange()) {
auto fingerPrintIt = beforePassFingerPrints.find(pass);
assert(fingerPrintIt != beforePassFingerPrints.end() &&
"expected valid fingerprint");
// If the fingerprints are the same, we don't print the IR.
if (fingerPrintIt->second == OperationFingerPrint(op)) {
beforePassFingerPrints.erase(fingerPrintIt);
return;
}
beforePassFingerPrints.erase(fingerPrintIt);
}

config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
out << formatv("*** IR Dump After {0} ***", pass->getName());
printIR(op, config->shouldPrintAtModuleScope(), out, OpPrintingFlags());
Expand All @@ -101,6 +183,9 @@ void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
if (isAdaptorPass(pass))
return;
if (config->shouldPrintAfterOnlyOnChange())
beforePassFingerPrints.erase(pass);

config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
out << formatv("*** IR Dump After {0} Failed ***", pass->getName());
printIR(op, config->shouldPrintAtModuleScope(), out,
Expand All @@ -114,10 +199,10 @@ void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
//===----------------------------------------------------------------------===//

/// Initialize the configuration.
/// * 'printModuleScope' signals if the module IR should be printed, even
/// for non module passes.
PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope)
: printModuleScope(printModuleScope) {}
PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope,
bool printAfterOnlyOnChange)
: printModuleScope(printModuleScope),
printAfterOnlyOnChange(printAfterOnlyOnChange) {}
PassManager::IRPrinterConfig::~IRPrinterConfig() {}

/// A hook that may be overridden by a derived config that checks if the IR
Expand Down Expand Up @@ -148,8 +233,8 @@ struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig {
BasicIRPrinterConfig(
std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
bool printModuleScope, raw_ostream &out)
: IRPrinterConfig(printModuleScope),
bool printModuleScope, bool printAfterOnlyOnChange, raw_ostream &out)
: IRPrinterConfig(printModuleScope, printAfterOnlyOnChange),
shouldPrintBeforePass(shouldPrintBeforePass),
shouldPrintAfterPass(shouldPrintAfterPass), out(out) {
assert((shouldPrintBeforePass || shouldPrintAfterPass) &&
Expand Down Expand Up @@ -188,8 +273,8 @@ void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) {
void PassManager::enableIRPrinting(
std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
bool printModuleScope, raw_ostream &out) {
bool printModuleScope, bool printAfterOnlyOnChange, raw_ostream &out) {
enableIRPrinting(std::make_unique<BasicIRPrinterConfig>(
std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass),
printModuleScope, out));
printModuleScope, printAfterOnlyOnChange, out));
}
7 changes: 6 additions & 1 deletion mlir/lib/Pass/PassManagerOptions.cpp
Expand Up @@ -54,6 +54,11 @@ struct PassManagerOptions {
llvm::cl::opt<bool> printAfterAll{"print-ir-after-all",
llvm::cl::desc("Print IR after each pass"),
llvm::cl::init(false)};
llvm::cl::opt<bool> printAfterChange{
"print-ir-after-change",
llvm::cl::desc(
"When printing the IR after a pass, only print if the IR changed"),
llvm::cl::init(false)};
llvm::cl::opt<bool> printModuleScope{
"print-ir-module-scope",
llvm::cl::desc("When printing IR for print-ir-[before|after]{-all} "
Expand Down Expand Up @@ -139,7 +144,7 @@ void PassManagerOptions::addPrinterInstrumentation(PassManager &pm) {

// Otherwise, add the IR printing instrumentation.
pm.enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
printModuleScope, llvm::errs());
printModuleScope, printAfterChange, llvm::errs());
}

/// Add a pass timing instrumentation if enabled by 'pass-timing' flags.
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Pass/ir-printing.mlir
Expand Up @@ -3,8 +3,10 @@
// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-after=cse -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER %s
// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-after-all -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL %s
// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before=cse -print-ir-module-scope -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE_MODULE %s
// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,cse)' -print-ir-after-all -print-ir-after-change -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL_CHANGE %s

func @foo() {
%0 = constant 0 : i32
return
}

Expand Down Expand Up @@ -52,3 +54,9 @@ func @bar() {
// BEFORE_MODULE: *** IR Dump Before{{.*}}CSE *** ('func' operation: @bar)
// BEFORE_MODULE: func @foo()
// BEFORE_MODULE: func @bar()

// AFTER_ALL_CHANGE: *** IR Dump After{{.*}}CSE ***
// AFTER_ALL_CHANGE-NEXT: func @foo()
// AFTER_ALL_CHANGE-NOT: *** IR Dump After{{.*}}CSE ***
// We expect that only 'foo' changed during CSE, and the second run of CSE did
// nothing.

0 comments on commit 8904e91

Please sign in to comment.