diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 724da009e70f1..690e9e88a87b8 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -37,6 +37,7 @@ class GreedyRewriteConfig; #define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS #define GEN_PASS_DECL_CONTROLFLOWSINK #define GEN_PASS_DECL_GENERATERUNTIMEVERIFICATION +#define GEN_PASS_DECL_HOISTPUREOPS #define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION #define GEN_PASS_DECL_INLINER #define GEN_PASS_DECL_MEM2REG diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 28b4a01cf0ecd..c74ce1946cb03 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -636,4 +636,8 @@ def BubbleDownMemorySpaceCasts : }]; } +def HoistPureOps : + Pass<"hoist-pure-ops"> { +} + #endif // MLIR_TRANSFORMS_PASSES diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 54b67f5c7a91e..b32865ed01b82 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -21,6 +21,7 @@ add_mlir_library(MLIRTransforms SymbolPrivatize.cpp TopologicalSort.cpp ViewOpGraph.cpp + HoistPureOps.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms diff --git a/mlir/lib/Transforms/HoistPureOps.cpp b/mlir/lib/Transforms/HoistPureOps.cpp new file mode 100644 index 0000000000000..b35743f8ffd40 --- /dev/null +++ b/mlir/lib/Transforms/HoistPureOps.cpp @@ -0,0 +1,136 @@ +//===- HoistPureOps.cpp - Hoist Pure ops ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the function of hoist the pure op based on SSA +// dominance. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/Passes.h" +#include "llvm/Support/DebugLog.h" + +namespace mlir { +#define GEN_PASS_DEF_HOISTPUREOPS +#include "mlir/Transforms/Passes.h.inc" +} // namespace mlir + +#define DEBUG_TYPE "hoist-pure-ops" + +using namespace mlir; + +namespace { + +/// Return the dominated Value. +static Value getDomaincedValue(DominanceInfo &dominanceInfo, Value a, Value b) { + Block *aB = a.getParentBlock(); + Block *bB = b.getParentBlock(); + if (isa(a) && isa(b)) { + return dominanceInfo.dominates(aB, bB) ? b : a; + } else if (isa(a) || isa(b)) { + if (aB != bB) + return dominanceInfo.dominates(aB, bB) ? b : a; + if (auto aArg = dyn_cast(a)) { + Operation *aFrontOp = &aArg.getOwner()->front(); + if (aFrontOp == b.getDefiningOp()) + return b; + return dominanceInfo.dominates(aFrontOp, b.getDefiningOp()) ? b : a; + } + auto bArg = cast(b); + Operation *bFrontOp = &bArg.getOwner()->front(); + if (bFrontOp == a.getDefiningOp()) + return a; + return dominanceInfo.dominates(a.getDefiningOp(), bFrontOp) ? b : a; + } else { + Operation *aDefineOp = a.getDefiningOp(); + Operation *bDefineOp = b.getDefiningOp(); + return dominanceInfo.dominates(aDefineOp, bDefineOp) ? b : a; + } +} + +static bool isOpContainBlock(Operation *op, Block *block) { + Operation *parentOp = block->getParentOp(); + while (parentOp && parentOp != op) { + parentOp = parentOp->getParentOp(); + } + return parentOp == op ? true : false; +} + +/// Find the hoisting position for the pure op. +static Value getDestPos(Operation *op) { + DominanceInfo dominanceInfo(op); + SmallVector operands(op->getOperands()); + if (op->getNumRegions()) { + op->walk([&](Operation *operation) { + for (auto operand : operation->getOperands()) { + Operation *defineOp = operand.getDefiningOp(); + if (!defineOp) { + BlockArgument argument = cast(operand); + if (!isOpContainBlock(op, argument.getOwner())) + operands.push_back(operand); + continue; + } + if (!isOpContainBlock(op, defineOp->getBlock())) { + operands.push_back(operand); + } + } + }); + } + if (operands.empty()) + return {}; + Value ret = operands[0]; + for (int i = 1, e = operands.size(); i < e; ++i) { + ret = getDomaincedValue(dominanceInfo, ret, operands[i]); + } + return ret; +} + +/// Hoist single pure op. +static void hoistPureOp(RewriterBase &rewriter, Operation *op) { + LDBG() << "hoistPureOp: " << OpWithFlags(op, OpPrintingFlags().skipRegions()); + Value pos = getDestPos(op); + if (!pos) + return; + + if (Operation *defineOp = pos.getDefiningOp()) { + if (op == defineOp) + return; + + LDBG() << "move " << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " after " + << OpWithFlags(defineOp, OpPrintingFlags().skipRegions()); + rewriter.moveOpAfter(op, defineOp); + return; + } + auto argument = cast(pos); + LDBG() << "move " << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " before " + << OpWithFlags(&argument.getOwner()->front(), + OpPrintingFlags().skipRegions()); + rewriter.moveOpBefore(op, &argument.getOwner()->front()); +} + +struct HoistPureOps : public impl::HoistPureOpsBase { + void runOnOperation() override; +}; +} // namespace + +void HoistPureOps::runOnOperation() { + Operation *module = getOperation(); + IRRewriter rewriter(module->getContext()); + module->walk([&](Operation *op) { + if (op->hasTrait()) + return; + if (isPure(op)) { + hoistPureOp(rewriter, op); + } + }); +} diff --git a/mlir/test/Transforms/hoist-pure-ops.mlir b/mlir/test/Transforms/hoist-pure-ops.mlir new file mode 100644 index 0000000000000..d719e84862134 --- /dev/null +++ b/mlir/test/Transforms/hoist-pure-ops.mlir @@ -0,0 +1,67 @@ +// RUN: mlir-opt %s -hoist-pure-ops -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func @hoist_cast_pos +// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>, +// CHECK-SAME: %[[ARG1:.*]]: i1 +func.func @hoist_cast_pos(%arg: memref<10xf32>, %arg1: i1) -> (memref) { + // CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] + // CHECK: %[[CAST_1:.*]] = memref.cast %[[ARG0]] + // CHECK-NEXT: cf.cond_br %[[ARG1]] + cf.cond_br %arg1, ^bb1, ^bb2 +^bb1: + %cast = memref.cast %arg : memref<10xf32> to memref + // CHECK: return %[[CAST_1]] + return %cast : memref +^bb2: + %cast1 = memref.cast %arg : memref<10xf32> to memref + // CHECK: return %[[CAST_0]] + return %cast1 : memref +} + +// ----- + +// CHECK-LABEL: func.func @hoist_cast_pos_alloc +// CHECK-SAME: %[[ARG0:.*]]: i1 +func.func @hoist_cast_pos_alloc(%arg: i1) -> (memref) { + // CHECK: %[[ALLOC_0:.*]] = memref.alloc() + // CHECK: %[[CAST_0:.*]] = memref.cast %[[ALLOC_0]] + // CHECK: %[[CAST_1:.*]] = memref.cast %[[ALLOC_0]] + // CHECK-NEXT: cf.cond_br %[[ARG0]] + %alloc = memref.alloc() : memref<10xf32> + cf.cond_br %arg, ^bb1, ^bb2 +^bb1: + %cast = memref.cast %alloc : memref<10xf32> to memref + // CHECK: return %[[CAST_1]] + return %cast : memref +^bb2: + %cast1 = memref.cast %alloc : memref<10xf32> to memref + // CHECK: return %[[CAST_0]] + return %cast1 : memref +} + +// ----- + +// CHECK-LABEL: func @mult_scf_sum( +// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index +func.func @mult_scf_sum(%arg0: index, %arg1: index, %arg2: index) -> index { + %c0 = arith.constant 0 : index + %res0 = scf.for %iv0 = %arg0 to %arg1 step %arg2 iter_args(%sum0 = %c0) -> index { + %res1 = scf.for %iv1 = %arg0 to %arg1 step %arg2 iter_args(%sum1 = %sum0) -> index { + %res2 = scf.for %iv2 = %arg0 to %arg1 step %arg2 iter_args(%sum2 = %sum1) -> index { + %add0 = arith.addi %iv0, %iv1 : index + %add1 = arith.addi %add0, %iv2 : index + %add2 = arith.addi %add1, %sum2 : index + scf.yield %add1 : index + } + scf.yield %res2 : index + } + scf.yield %res1 : index + } + // CHECK: %[[FOR_0:.*]] = scf.for %[[IV_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]] + // CHECK-NEXT: %[[FOR_1:.*]] = scf.for %[[IV_1:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]] + // CHECK-NEXT: %[[ADDI_0:.*]] = arith.addi %[[IV_0]], %[[IV_1]] : index + // CHECK-NEXT: %[[FOR_2:.*]] = scf.for %[[IV_3:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]] iter_args(%[[ITER:.*]] = %{{.*}}) + // CHECK-NEXT: %[[ADDI_1:.*]] = arith.addi %[[ADDI_0]], %[[IV_3]] : index + // CHECK-NEXT: %[[ADDI_2:.*]] = arith.addi %[[ADDI_1]], %[[ITER]] : index + return %res0 : index +} \ No newline at end of file