102 changes: 77 additions & 25 deletions mlir/include/mlir/IR/Visitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,15 @@ class Operation;
class Block;
class Region;

/// A utility result that is used to signal if a walk method should be
/// interrupted or advance.
/// A utility result that is used to signal how to proceed with an ongoing walk:
/// * Interrupt: the walk will be interrupted and no more operations, regions
/// or blocks will be visited.
/// * Advance: the walk will continue.
/// * Skip: the walk of the current operation, region or block and their
/// nested elements that haven't been visited already will be skipped and will
/// continue with the next operation, region or block.
class WalkResult {
enum ResultEnum { Interrupt, Advance } result;
enum ResultEnum { Interrupt, Advance, Skip } result;

public:
WalkResult(ResultEnum result) : result(result) {}
Expand All @@ -44,11 +49,18 @@ class WalkResult {

static WalkResult interrupt() { return {Interrupt}; }
static WalkResult advance() { return {Advance}; }
static WalkResult skip() { return {Skip}; }

/// Returns true if the walk was interrupted.
bool wasInterrupted() const { return result == Interrupt; }

/// Returns true if the walk was skipped.
bool wasSkipped() const { return result == Skip; }
};

/// Traversal order for region, block and operation walk utilities.
enum class WalkOrder { PreOrder, PostOrder };

namespace detail {
/// Helper templates to deduce the first argument of a callback parameter.
template <typename Ret, typename Arg> Arg first_argument_type(Ret (*)(Arg));
Expand All @@ -64,48 +76,78 @@ template <typename T>
using first_argument = decltype(first_argument_type(std::declval<T>()));

/// Walk all of the regions, blocks, or operations nested under (and including)
/// the given operation.
void walk(Operation *op, function_ref<void(Region *)> callback);
void walk(Operation *op, function_ref<void(Block *)> callback);
void walk(Operation *op, function_ref<void(Operation *)> callback);

/// the given operation. Regions, blocks and operations at the same nesting
/// level are visited in lexicographical order. The walk order for enclosing
/// regions, blocks and operations with respect to their nested ones is
/// specified by 'order'. These methods are invoked for void-returning
/// callbacks. A callback on a block or operation is allowed to erase that block
/// or operation only if the walk is in post-order. See non-void method for
/// pre-order erasure.
void walk(Operation *op, function_ref<void(Region *)> callback,
WalkOrder order);
void walk(Operation *op, function_ref<void(Block *)> callback, WalkOrder order);
void walk(Operation *op, function_ref<void(Operation *)> callback,
WalkOrder order);
/// Walk all of the regions, blocks, or operations nested under (and including)
/// the given operation. These functions walk until an interrupt result is
/// returned by the callback.
WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback);
WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback);
WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback);
/// the given operation. Regions, blocks and operations at the same nesting
/// level are visited in lexicographical order. The walk order for enclosing
/// regions, blocks and operations with respect to their nested ones is
/// specified by 'order'. This method is invoked for skippable or interruptible
/// callbacks. A callback on a block or operation is allowed to erase that block
/// or operation if either:
/// * the walk is in post-order, or
/// * the walk is in pre-order and the walk is skipped after the erasure.
WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback,
WalkOrder order);
WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback,
WalkOrder order);
WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback,
WalkOrder order);

// Below are a set of functions to walk nested operations. Users should favor
// the direct `walk` methods on the IR classes(Operation/Block/etc) over these
// methods. They are also templated to allow for statically dispatching based
// upon the type of the callback function.

/// Walk all of the regions, blocks, or operations nested under (and including)
/// the given operation. This method is selected for callbacks that operate on
/// Region*, Block*, and Operation*.
/// the given operation. Regions, blocks and operations at the same nesting
/// level are visited in lexicographical order. The walk order for enclosing
/// regions, blocks and operations with respect to their nested ones is
/// specified by 'Order' (post-order by default). A callback on a block or
/// operation is allowed to erase that block or operation if either:
/// * the walk is in post-order, or
/// * the walk is in pre-order and the walk is skipped after the erasure.
/// This method is selected for callbacks that operate on Region*, Block*, and
/// Operation*.
///
/// Example:
/// op->walk([](Region *r) { ... });
/// op->walk([](Block *b) { ... });
/// op->walk([](Operation *op) { ... });
template <
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
typename ArgT = detail::first_argument<FuncTy>,
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
typename std::enable_if<
llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value, RetT>::type
walk(Operation *op, FuncTy &&callback) {
return walk(op, function_ref<RetT(ArgT)>(callback));
return detail::walk(op, function_ref<RetT(ArgT)>(callback), Order);
}

/// Walk all of the operations of type 'ArgT' nested under and including the
/// given operation. This method is selected for void returning callbacks that
/// operate on a specific derived operation type.
/// given operation. Regions, blocks and operations at the same nesting
/// level are visited in lexicographical order. The walk order for enclosing
/// regions, blocks and operations with respect to their nested ones is
/// specified by 'order' (post-order by default). This method is selected for
/// void-returning callbacks that operate on a specific derived operation type.
/// A callback on an operation is allowed to erase that operation only if the
/// walk is in post-order. See non-void method for pre-order erasure.
///
/// Example:
/// op->walk([](ReturnOp op) { ... });
template <
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
typename ArgT = detail::first_argument<FuncTy>,
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
typename std::enable_if<
!llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
Expand All @@ -116,21 +158,31 @@ walk(Operation *op, FuncTy &&callback) {
if (auto derivedOp = dyn_cast<ArgT>(op))
callback(derivedOp);
};
return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn));
return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn), Order);
}

/// Walk all of the operations of type 'ArgT' nested under and including the
/// given operation. This method is selected for WalkReturn returning
/// interruptible callbacks that operate on a specific derived operation type.
/// given operation. Regions, blocks and operations at the same nesting level
/// are visited in lexicographical order. The walk order for enclosing regions,
/// blocks and operations with respect to their nested ones is specified by
/// 'Order' (post-order by default). This method is selected for WalkReturn
/// returning skippable or interruptible callbacks that operate on a specific
/// derived operation type. A callback on an operation is allowed to erase that
/// operation if either:
/// * the walk is in post-order, or
/// * the walk is in pre-order and the walk is skipped after the erasure.
///
/// Example:
/// op->walk([](ReturnOp op) {
/// if (some_invariant)
/// return WalkResult::skip();
/// if (another_invariant)
/// return WalkResult::interrupt();
/// return WalkResult::advance();
/// });
template <
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
typename ArgT = detail::first_argument<FuncTy>,
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
typename std::enable_if<
!llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
Expand All @@ -142,7 +194,7 @@ walk(Operation *op, FuncTy &&callback) {
return callback(derivedOp);
return WalkResult::advance();
};
return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn));
return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn), Order);
}

/// Utility to provide the return type of a templated walk method.
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Analysis/Liveness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ static void buildBlockMapping(Operation *operation,
DenseMap<Block *, BlockInfoBuilder> &builders) {
llvm::SetVector<Block *> toProcess;

operation->walk([&](Block *block) {
operation->walk<WalkOrder::PreOrder>([&](Block *block) {
BlockInfoBuilder &builder =
builders.try_emplace(block, block).first->second;

Expand Down Expand Up @@ -270,7 +270,7 @@ void Liveness::print(raw_ostream &os) const {
DenseMap<Block *, size_t> blockIds;
DenseMap<Operation *, size_t> operationIds;
DenseMap<Value, size_t> valueIds;
operation->walk([&](Block *block) {
operation->walk<WalkOrder::PreOrder>([&](Block *block) {
blockIds.insert({block, blockIds.size()});
for (BlockArgument argument : block->getArguments())
valueIds.insert({argument, valueIds.size()});
Expand Down Expand Up @@ -304,7 +304,7 @@ void Liveness::print(raw_ostream &os) const {
};

// Dump information about in and out values.
operation->walk([&](Block *block) {
operation->walk<WalkOrder::PreOrder>([&](Block *block) {
os << "// - Block: " << blockIds[block] << "\n";
const auto *liveness = getLiveness(block);
os << "// --- LiveIn: ";
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Analysis/NumberOfExecutions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ static void computeRegionBlockNumberOfExecutions(
/// Creates a new NumberOfExecutions analysis that computes how many times a
/// block within a region is executed for all associated regions.
NumberOfExecutions::NumberOfExecutions(Operation *op) : operation(op) {
operation->walk([&](Region *region) {
operation->walk<WalkOrder::PreOrder>([&](Region *region) {
computeRegionBlockNumberOfExecutions(*region, blockNumbersOfExecution);
});
}
Expand Down Expand Up @@ -191,7 +191,7 @@ void NumberOfExecutions::printBlockExecutions(
raw_ostream &os, Region *perEntryOfThisRegion) const {
unsigned blockId = 0;

operation->walk([&](Block *block) {
operation->walk<WalkOrder::PreOrder>([&](Block *block) {
llvm::errs() << "Block: " << blockId++ << "\n";
llvm::errs() << "Number of executions: ";
if (auto n = getNumberOfExecutions(block, perEntryOfThisRegion))
Expand All @@ -203,7 +203,7 @@ void NumberOfExecutions::printBlockExecutions(

void NumberOfExecutions::printOperationExecutions(
raw_ostream &os, Region *perEntryOfThisRegion) const {
operation->walk([&](Block *block) {
operation->walk<WalkOrder::PreOrder>([&](Block *block) {
block->walk([&](Operation *operation) {
// Skip the operation that was used to build the analysis.
if (operation == this->operation)
Expand Down
117 changes: 93 additions & 24 deletions mlir/lib/IR/Visitors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,79 +12,148 @@
using namespace mlir;

/// Walk all of the regions/blocks/operations nested under and including the
/// given operation.
void detail::walk(Operation *op, function_ref<void(Region *)> callback) {
/// given operation. Regions, blocks and operations at the same nesting level
/// are visited in lexicographical order. The walk order for enclosing regions,
/// blocks and operations with respect to their nested ones is specified by
/// 'order'. These methods are invoked for void-returning callbacks. A callback
/// on a block or operation is allowed to erase that block or operation only if
/// the walk is in post-order. See non-void method for pre-order erasure.
void detail::walk(Operation *op, function_ref<void(Region *)> callback,
WalkOrder order) {
// We don't use early increment for regions because they can't be erased from
// a callback.
for (auto &region : op->getRegions()) {
callback(&region);
if (order == WalkOrder::PreOrder)
callback(&region);
for (auto &block : region) {
for (auto &nestedOp : block)
walk(&nestedOp, callback);
walk(&nestedOp, callback, order);
}
if (order == WalkOrder::PostOrder)
callback(&region);
}
}

void detail::walk(Operation *op, function_ref<void(Block *)> callback) {
void detail::walk(Operation *op, function_ref<void(Block *)> callback,
WalkOrder order) {
for (auto &region : op->getRegions()) {
for (auto &block : region) {
callback(&block);
// Early increment here in the case where the block is erased.
for (auto &block : llvm::make_early_inc_range(region)) {
if (order == WalkOrder::PreOrder)
callback(&block);
for (auto &nestedOp : block)
walk(&nestedOp, callback);
walk(&nestedOp, callback, order);
if (order == WalkOrder::PostOrder)
callback(&block);
}
}
}

void detail::walk(Operation *op, function_ref<void(Operation *op)> callback) {
void detail::walk(Operation *op, function_ref<void(Operation *)> callback,
WalkOrder order) {
if (order == WalkOrder::PreOrder)
callback(op);

// TODO: This walk should be iterative over the operations.
for (auto &region : op->getRegions()) {
for (auto &block : region) {
// Early increment here in the case where the operation is erased.
for (auto &nestedOp : llvm::make_early_inc_range(block))
walk(&nestedOp, callback);
walk(&nestedOp, callback, order);
}
}
callback(op);

if (order == WalkOrder::PostOrder)
callback(op);
}

/// Walk all of the regions/blocks/operations nested under and including the
/// given operation. These functions walk operations until an interrupt result
/// is returned by the callback.
/// is returned by the callback. Walks on regions, blocks and operations may
/// also be skipped if the callback returns a skip result. Regions, blocks and
/// operations at the same nesting level are visited in lexicographical order.
/// The walk order for enclosing regions, blocks and operations with respect to
/// their nested ones is specified by 'order'. A callback on a block or
/// operation is allowed to erase that block or operation if either:
/// * the walk is in post-order, or
/// * the walk is in pre-order and the walk is skipped after the erasure.
WalkResult detail::walk(Operation *op,
function_ref<WalkResult(Region *op)> callback) {
function_ref<WalkResult(Region *)> callback,
WalkOrder order) {
// We don't use early increment for regions because they can't be erased from
// a callback.
for (auto &region : op->getRegions()) {
if (callback(&region).wasInterrupted())
return WalkResult::interrupt();
if (order == WalkOrder::PreOrder) {
WalkResult result = callback(&region);
if (result.wasSkipped())
continue;
if (result.wasInterrupted())
return WalkResult::interrupt();
}
for (auto &block : region) {
for (auto &nestedOp : block)
walk(&nestedOp, callback);
walk(&nestedOp, callback, order);
}
if (order == WalkOrder::PostOrder) {
if (callback(&region).wasInterrupted())
return WalkResult::interrupt();
// We don't check if this region was skipped because its walk already
// finished and the walk will continue with the next region.
}
}
return WalkResult::advance();
}

WalkResult detail::walk(Operation *op,
function_ref<WalkResult(Block *op)> callback) {
function_ref<WalkResult(Block *)> callback,
WalkOrder order) {
for (auto &region : op->getRegions()) {
for (auto &block : region) {
if (callback(&block).wasInterrupted())
return WalkResult::interrupt();
// Early increment here in the case where the block is erased.
for (auto &block : llvm::make_early_inc_range(region)) {
if (order == WalkOrder::PreOrder) {
WalkResult result = callback(&block);
if (result.wasSkipped())
continue;
if (result.wasInterrupted())
return WalkResult::interrupt();
}
for (auto &nestedOp : block)
walk(&nestedOp, callback);
walk(&nestedOp, callback, order);
if (order == WalkOrder::PostOrder) {
if (callback(&block).wasInterrupted())
return WalkResult::interrupt();
// We don't check if this block was skipped because its walk already
// finished and the walk will continue with the next block.
}
}
}
return WalkResult::advance();
}

WalkResult detail::walk(Operation *op,
function_ref<WalkResult(Operation *op)> callback) {
function_ref<WalkResult(Operation *)> callback,
WalkOrder order) {
if (order == WalkOrder::PreOrder) {
WalkResult result = callback(op);
// If skipped, caller will continue the walk on the next operation.
if (result.wasSkipped())
return WalkResult::advance();
if (result.wasInterrupted())
return WalkResult::interrupt();
}

// TODO: This walk should be iterative over the operations.
for (auto &region : op->getRegions()) {
for (auto &block : region) {
// Early increment here in the case where the operation is erased.
for (auto &nestedOp : llvm::make_early_inc_range(block)) {
if (walk(&nestedOp, callback).wasInterrupted())
if (walk(&nestedOp, callback, order).wasInterrupted())
return WalkResult::interrupt();
}
}
}
return callback(op);

if (order == WalkOrder::PostOrder)
return callback(op);
return WalkResult::advance();
}
212 changes: 212 additions & 0 deletions mlir/test/IR/visitors.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// RUN: mlir-opt -test-ir-visitors -allow-unregistered-dialect -split-input-file %s | FileCheck %s

// Verify the different configurations of IR visitors.
// Constant, yield and other terminator ops are not matched for simplicity.
// Module and function op and their immediately nested blocks are not erased in
// callbacks with return so that the output includes more cases in pre-order.

func @structured_cfg() {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c10 = constant 10 : index
scf.for %i = %c1 to %c10 step %c1 {
%cond = "use0"(%i) : (index) -> (i1)
scf.if %cond {
"use1"(%i) : (index) -> ()
} else {
"use2"(%i) : (index) -> ()
}
"use3"(%i) : (index) -> ()
}
return
}

// CHECK-LABEL: Op pre-order visit
// CHECK: Visiting op 'module'
// CHECK: Visiting op 'func'
// CHECK: Visiting op 'scf.for'
// CHECK: Visiting op 'use0'
// CHECK: Visiting op 'scf.if'
// CHECK: Visiting op 'use1'
// CHECK: Visiting op 'use2'
// CHECK: Visiting op 'use3'
// CHECK: Visiting op 'std.return'

// CHECK-LABEL: Block pre-order visits
// CHECK: Visiting block ^bb0 from region 0 from operation 'module'
// CHECK: Visiting block ^bb0 from region 0 from operation 'func'
// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for'
// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if'
// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if'

// CHECK-LABEL: Region pre-order visits
// CHECK: Visiting region 0 from operation 'module'
// CHECK: Visiting region 0 from operation 'func'
// CHECK: Visiting region 0 from operation 'scf.for'
// CHECK: Visiting region 0 from operation 'scf.if'
// CHECK: Visiting region 1 from operation 'scf.if'

// CHECK-LABEL: Op post-order visits
// CHECK: Visiting op 'use0'
// CHECK: Visiting op 'use1'
// CHECK: Visiting op 'use2'
// CHECK: Visiting op 'scf.if'
// CHECK: Visiting op 'use3'
// CHECK: Visiting op 'scf.for'
// CHECK: Visiting op 'std.return'
// CHECK: Visiting op 'func'
// CHECK: Visiting op 'module'

// CHECK-LABEL: Block post-order visits
// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if'
// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if'
// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for'
// CHECK: Visiting block ^bb0 from region 0 from operation 'func'
// CHECK: Visiting block ^bb0 from region 0 from operation 'module'

// CHECK-LABEL: Region post-order visits
// CHECK: Visiting region 0 from operation 'scf.if'
// CHECK: Visiting region 1 from operation 'scf.if'
// CHECK: Visiting region 0 from operation 'scf.for'
// CHECK: Visiting region 0 from operation 'func'
// CHECK: Visiting region 0 from operation 'module'

// CHECK-LABEL: Op pre-order erasures
// CHECK: Erasing op 'scf.for'
// CHECK: Erasing op 'std.return'

// CHECK-LABEL: Block pre-order erasures
// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.for'

// CHECK-LABEL: Op post-order erasures (skip)
// CHECK: Erasing op 'use0'
// CHECK: Erasing op 'use1'
// CHECK: Erasing op 'use2'
// CHECK: Erasing op 'scf.if'
// CHECK: Erasing op 'use3'
// CHECK: Erasing op 'scf.for'
// CHECK: Erasing op 'std.return'

// CHECK-LABEL: Block post-order erasures (skip)
// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.if'
// CHECK: Erasing block ^bb0 from region 1 from operation 'scf.if'
// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.for'

// CHECK-LABEL: Op post-order erasures (no skip)
// CHECK: Erasing op 'use0'
// CHECK: Erasing op 'use1'
// CHECK: Erasing op 'use2'
// CHECK: Erasing op 'scf.if'
// CHECK: Erasing op 'use3'
// CHECK: Erasing op 'scf.for'
// CHECK: Erasing op 'std.return'
// CHECK: Erasing op 'func'
// CHECK: Erasing op 'module'

// CHECK-LABEL: Block post-order erasures (no skip)
// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.if'
// CHECK: Erasing block ^bb0 from region 1 from operation 'scf.if'
// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.for'
// CHECK: Erasing block ^bb0 from region 0 from operation 'func'
// CHECK: Erasing block ^bb0 from region 0 from operation 'module'

// -----

func @unstructured_cfg() {
"regionOp0"() ({
^bb0:
"op0"() : () -> ()
br ^bb2
^bb1:
"op1"() : () -> ()
br ^bb2
^bb2:
"op2"() : () -> ()
}) : () -> ()
return
}

// CHECK-LABEL: Op pre-order visits
// CHECK: Visiting op 'module'
// CHECK: Visiting op 'func'
// CHECK: Visiting op 'regionOp0'
// CHECK: Visiting op 'op0'
// CHECK: Visiting op 'std.br'
// CHECK: Visiting op 'op1'
// CHECK: Visiting op 'std.br'
// CHECK: Visiting op 'op2'
// CHECK: Visiting op 'std.return'

// CHECK-LABEL: Block pre-order visits
// CHECK: Visiting block ^bb0 from region 0 from operation 'module'
// CHECK: Visiting block ^bb0 from region 0 from operation 'func'
// CHECK: Visiting block ^bb0 from region 0 from operation 'regionOp0'
// CHECK: Visiting block ^bb1 from region 0 from operation 'regionOp0'
// CHECK: Visiting block ^bb2 from region 0 from operation 'regionOp0'

// CHECK-LABEL: Region pre-order visits
// CHECK: Visiting region 0 from operation 'module'
// CHECK: Visiting region 0 from operation 'func'
// CHECK: Visiting region 0 from operation 'regionOp0'

// CHECK-LABEL: Op post-order visits
// CHECK: Visiting op 'op0'
// CHECK: Visiting op 'std.br'
// CHECK: Visiting op 'op1'
// CHECK: Visiting op 'std.br'
// CHECK: Visiting op 'op2'
// CHECK: Visiting op 'regionOp0'
// CHECK: Visiting op 'std.return'
// CHECK: Visiting op 'func'
// CHECK: Visiting op 'module'

// CHECK-LABEL: Block post-order visits
// CHECK: Visiting block ^bb0 from region 0 from operation 'regionOp0'
// CHECK: Visiting block ^bb1 from region 0 from operation 'regionOp0'
// CHECK: Visiting block ^bb2 from region 0 from operation 'regionOp0'
// CHECK: Visiting block ^bb0 from region 0 from operation 'func'
// CHECK: Visiting block ^bb0 from region 0 from operation 'module'

// CHECK-LABEL: Region post-order visits
// CHECK: Visiting region 0 from operation 'regionOp0'
// CHECK: Visiting region 0 from operation 'func'
// CHECK: Visiting region 0 from operation 'module'

// CHECK-LABEL: Op pre-order erasures (skip)
// CHECK: Erasing op 'regionOp0'
// CHECK: Erasing op 'std.return'

// CHECK-LABEL: Block pre-order erasures (skip)
// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'

// CHECK-LABEL: Op post-order erasures (skip)
// CHECK: Erasing op 'op0'
// CHECK: Erasing op 'std.br'
// CHECK: Erasing op 'op1'
// CHECK: Erasing op 'std.br'
// CHECK: Erasing op 'op2'
// CHECK: Erasing op 'regionOp0'
// CHECK: Erasing op 'std.return'

// CHECK-LABEL: Block post-order erasures (skip)
// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'

// CHECK-LABEL: Op post-order erasures (no skip)
// CHECK: Erasing op 'op0'
// CHECK: Erasing op 'std.br'
// CHECK: Erasing op 'op1'
// CHECK: Erasing op 'std.br'
// CHECK: Erasing op 'op2'
// CHECK: Erasing op 'regionOp0'
// CHECK: Erasing op 'std.return'

// CHECK-LABEL: Block post-order erasures (no skip)
// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
// CHECK: Erasing block ^bb0 from region 0 from operation 'func'
// CHECK: Erasing block ^bb0 from region 0 from operation 'module'
1 change: 1 addition & 0 deletions mlir/test/lib/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_mlir_library(MLIRTestIR
TestSlicing.cpp
TestSymbolUses.cpp
TestTypes.cpp
TestVisitors.cpp

EXCLUDE_FROM_LIBMLIR

Expand Down
171 changes: 171 additions & 0 deletions mlir/test/lib/IR/TestVisitors.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
//===- TestIRVisitors.cpp - Pass to test the IR visitors ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Pass/Pass.h"

using namespace mlir;

static void printRegion(Region *region) {
llvm::outs() << "region " << region->getRegionNumber() << " from operation '"
<< region->getParentOp()->getName() << "'";
}

static void printBlock(Block *block) {
llvm::outs() << "block ";
block->printAsOperand(llvm::outs(), /*printType=*/false);
llvm::outs() << " from ";
printRegion(block->getParent());
}

static void printOperation(Operation *op) {
llvm::outs() << "op '" << op->getName() << "'";
}

/// Tests pure callbacks.
static void testPureCallbacks(Operation *op) {
auto opPure = [](Operation *op) {
llvm::outs() << "Visiting ";
printOperation(op);
llvm::outs() << "\n";
};
auto blockPure = [](Block *block) {
llvm::outs() << "Visiting ";
printBlock(block);
llvm::outs() << "\n";
};
auto regionPure = [](Region *region) {
llvm::outs() << "Visiting ";
printRegion(region);
llvm::outs() << "\n";
};

llvm::outs() << "Op pre-order visits"
<< "\n";
op->walk<WalkOrder::PreOrder>(opPure);
llvm::outs() << "Block pre-order visits"
<< "\n";
op->walk<WalkOrder::PreOrder>(blockPure);
llvm::outs() << "Region pre-order visits"
<< "\n";
op->walk<WalkOrder::PreOrder>(regionPure);

llvm::outs() << "Op post-order visits"
<< "\n";
op->walk<WalkOrder::PostOrder>(opPure);
llvm::outs() << "Block post-order visits"
<< "\n";
op->walk<WalkOrder::PostOrder>(blockPure);
llvm::outs() << "Region post-order visits"
<< "\n";
op->walk<WalkOrder::PostOrder>(regionPure);
}

/// Tests erasure callbacks that skip the walk.
static void testSkipErasureCallbacks(Operation *op) {
auto skipOpErasure = [](Operation *op) {
// Do not erase module and function op. Otherwise there wouldn't be too
// much to test in pre-order.
if (isa<ModuleOp>(op) || isa<FuncOp>(op))
return WalkResult::advance();

llvm::outs() << "Erasing ";
printOperation(op);
llvm::outs() << "\n";
op->dropAllUses();
op->erase();
return WalkResult::skip();
};
auto skipBlockErasure = [](Block *block) {
// Do not erase module and function blocks. Otherwise there wouldn't be
// too much to test in pre-order.
Operation *parentOp = block->getParentOp();
if (isa<ModuleOp>(parentOp) || isa<FuncOp>(parentOp))
return WalkResult::advance();

llvm::outs() << "Erasing ";
printBlock(block);
llvm::outs() << "\n";
block->erase();
return WalkResult::skip();
};

llvm::outs() << "Op pre-order erasures (skip)"
<< "\n";
Operation *cloned = op->clone();
cloned->walk<WalkOrder::PreOrder>(skipOpErasure);
cloned->erase();

llvm::outs() << "Block pre-order erasures (skip)"
<< "\n";
cloned = op->clone();
cloned->walk<WalkOrder::PreOrder>(skipBlockErasure);
cloned->erase();

llvm::outs() << "Op post-order erasures (skip)"
<< "\n";
cloned = op->clone();
cloned->walk<WalkOrder::PostOrder>(skipOpErasure);
cloned->erase();

llvm::outs() << "Block post-order erasures (skip)"
<< "\n";
cloned = op->clone();
cloned->walk<WalkOrder::PostOrder>(skipBlockErasure);
cloned->erase();
}

/// Tests callbacks that erase the op or block but don't return 'Skip'. This
/// callbacks are only valid in post-order.
static void testNoSkipErasureCallbacks(Operation *op) {
auto noSkipOpErasure = [](Operation *op) {
llvm::outs() << "Erasing ";
printOperation(op);
llvm::outs() << "\n";
op->dropAllUses();
op->erase();
};
auto noSkipBlockErasure = [](Block *block) {
llvm::outs() << "Erasing ";
printBlock(block);
llvm::outs() << "\n";
block->erase();
};

llvm::outs() << "Op post-order erasures (no skip)"
<< "\n";
Operation *cloned = op->clone();
cloned->walk<WalkOrder::PostOrder>(noSkipOpErasure);

llvm::outs() << "Block post-order erasures (no skip)"
<< "\n";
cloned = op->clone();
cloned->walk<WalkOrder::PostOrder>(noSkipBlockErasure);
cloned->erase();
}

namespace {
/// This pass exercises the different configurations of the IR visitors.
struct TestIRVisitorsPass
: public PassWrapper<TestIRVisitorsPass, OperationPass<>> {
void runOnOperation() override {
Operation *op = getOperation();
testPureCallbacks(op);
testSkipErasureCallbacks(op);
testNoSkipErasureCallbacks(op);
}
};
} // end anonymous namespace

namespace mlir {
namespace test {
void registerTestIRVisitorsPass() {
PassRegistration<TestIRVisitorsPass>("test-ir-visitors",
"Test various visitors.");
}
} // namespace test
} // namespace mlir
2 changes: 2 additions & 0 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ void registerTestDominancePass();
void registerTestDynamicPipelinePass();
void registerTestExpandTanhPass();
void registerTestGpuParallelLoopMappingPass();
void registerTestIRVisitorsPass();
void registerTestInterfaces();
void registerTestLinalgCodegenStrategy();
void registerTestLinalgFusionTransforms();
Expand Down Expand Up @@ -146,6 +147,7 @@ void registerTestPasses() {
test::registerTestDynamicPipelinePass();
test::registerTestExpandTanhPass();
test::registerTestGpuParallelLoopMappingPass();
test::registerTestIRVisitorsPass();
test::registerTestInterfaces();
test::registerTestLinalgCodegenStrategy();
test::registerTestLinalgFusionTransforms();
Expand Down