diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index da061b269daf7..fa4bcb5bce5db 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -58,6 +58,7 @@
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
+#include "mlir/Conversion/SCFToAffine/SCFToAffine.h"
 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
@@ -87,6 +88,9 @@
 namespace mlir {
 
+std::unique_ptr<Pass> replaceAffineCFGPass();
+std::unique_ptr<Pass> createRaiseSCFToAffinePass();
+
 /// Generate the code for registering conversion passes.
 #define GEN_PASS_REGISTRATION
 #include "mlir/Conversion/Passes.h.inc"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1a37d057776e2..85f49448e38da 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1029,6 +1029,24 @@ def ReconcileUnrealizedCastsPass : Pass<"reconcile-unrealized-casts"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// SCFToAffine
+//===----------------------------------------------------------------------===//
+def AffineCFG : Pass<"affine-cfg"> {
+  let summary = "Replace scf.if and similar constructs with affine.if";
+  let constructor = "mlir::replaceAffineCFGPass()";
+}
+
+def RaiseSCFToAffine : Pass<"raise-scf-to-affine"> {
+  let summary = "Raise SCF to affine";
+  let constructor = "mlir::createRaiseSCFToAffinePass()";
+  let dependentDialects = [
+    "affine::AffineDialect",
+    "scf::SCFDialect",
+  ];
+}
+
+
 //===----------------------------------------------------------------------===//
 // SCFToControlFlow
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/SCFToAffine/SCFToAffine.h b/mlir/include/mlir/Conversion/SCFToAffine/SCFToAffine.h
new file mode 100644
index 0000000000000..372d19d60fdb3
--- /dev/null
+++ b/mlir/include/mlir/Conversion/SCFToAffine/SCFToAffine.h
@@ -0,0 +1,14 @@
+#ifndef MLIR_CONVERSION_SCFTOAFFINE_H
+#define MLIR_CONVERSION_SCFTOAFFINE_H
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+
+namespace mlir {
+
+#define GEN_PASS_DECL_RAISESCFTOAFFINE
+#define GEN_PASS_DECL_AFFINECFG
+#include "mlir/Conversion/Passes.h.inc"
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_SCFTOAFFINE_H
\ No newline at end of file
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 71986f83c4870..d9da085378834 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -52,6 +52,7 @@ add_subdirectory(OpenMPToLLVM)
 add_subdirectory(PDLToPDLInterp)
 add_subdirectory(PtrToLLVM)
 add_subdirectory(ReconcileUnrealizedCasts)
+add_subdirectory(SCFToAffine)
 add_subdirectory(SCFToControlFlow)
 add_subdirectory(SCFToEmitC)
 add_subdirectory(SCFToGPU)
diff --git a/mlir/lib/Conversion/SCFToAffine/AffineCFG.cpp b/mlir/lib/Conversion/SCFToAffine/AffineCFG.cpp
new file mode 100644
index 0000000000000..ad33736f0b36a
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToAffine/AffineCFG.cpp
@@ -0,0 +1,1385 @@
+#include "./Ops.h"
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/IntegerSet.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallSet.h" +#include + +#define DEBUG_TYPE "affine-cfg" + +using namespace mlir; +using namespace mlir::arith; +using namespace mlir::affine; + +namespace mlir { + +#define GEN_PASS_DEF_AFFINECFG +#include "mlir/Conversion/Passes.h.inc" + +} // namespace mlir + +bool isValidIndex(Value val); + +bool isReadOnly(Operation *op); + +bool isValidSymbolInt(Value value, bool recur = true); +bool isValidSymbolInt(Operation *defOp, bool recur) { + Attribute operandCst; + if (matchPattern(defOp, m_Constant(&operandCst))) + return true; + + if (recur) { + if (isa(defOp)) + if (llvm::all_of(defOp->getOperands(), + [&](Value v) { return isValidSymbolInt(v, recur); })) + return true; + if (auto ifOp = mlir::dyn_cast(defOp)) { + if (isValidSymbolInt(ifOp.getCondition(), recur)) { + if (llvm::all_of( + ifOp.thenBlock()->without_terminator(), + [&](Operation &o) { return isValidSymbolInt(&o, recur); }) && + llvm::all_of( + ifOp.elseBlock()->without_terminator(), + [&](Operation &o) { return isValidSymbolInt(&o, recur); })) + return true; + } + } + if (auto ifOp = dyn_cast(defOp)) { + if (llvm::all_of(ifOp.getOperands(), + [&](Value o) { return isValidSymbolInt(o, recur); })) + if (llvm::all_of( + ifOp.getThenBlock()->without_terminator(), + [&](Operation &o) { return isValidSymbolInt(&o, recur); }) && + llvm::all_of( + ifOp.getElseBlock()->without_terminator(), + [&](Operation &o) { return isValidSymbolInt(&o, recur); })) + return true; + } + } + return false; +} + +// isValidSymbol, even if not index +bool isValidSymbolInt(Value value, bool recur) { + // Check that the value is a top level value. + if (affine::isTopLevelValue(value)) + return true; + + if (auto *defOp = value.getDefiningOp()) { + if (isValidSymbolInt(defOp, recur)) + return true; + return affine::isValidSymbol(value, affine::getAffineScope(defOp)); + } + + return false; +} + +struct AffineApplyNormalizer { + AffineApplyNormalizer(AffineMap map, ArrayRef operands, + PatternRewriter &rewriter, DominanceInfo &DI); + + /// Returns the AffineMap resulting from normalization. + AffineMap getAffineMap() { return affineMap; } + + SmallVector getOperands() { + SmallVector res(reorderedDims); + res.append(concatenatedSymbols.begin(), concatenatedSymbols.end()); + return res; + } + +private: + /// Helper function to insert `v` into the coordinate system of the current + /// AffineApplyNormalizer. Returns the AffineDimExpr with the corresponding + /// renumbered position. + AffineDimExpr renumberOneDim(Value v); + + /// Maps of Value to position in `affineMap`. + DenseMap dimValueToPosition; + + /// Ordered dims and symbols matching positional dims and symbols in + /// `affineMap`. 
+  SmallVector<Value> reorderedDims;
+  SmallVector<Value> concatenatedSymbols;
+
+  AffineMap affineMap;
+};
+
+static bool isAffineForArg(Value val) {
+  if (!mlir::isa<BlockArgument>(val))
+    return false;
+  Operation *parentOp =
+      mlir::cast<BlockArgument>(val).getOwner()->getParentOp();
+  return (
+      isa_and_nonnull<affine::AffineForOp, affine::AffineParallelOp>(parentOp));
+}
+
+static bool legalCondition(Value en, bool dim = false) {
+  if (en.getDefiningOp<affine::AffineApplyOp>())
+    return true;
+
+  if (!dim && !isValidSymbolInt(en, /*recur*/ false)) {
+    if (isValidIndex(en) || isValidSymbolInt(en, /*recur*/ true)) {
+      return true;
+    }
+  }
+
+  while (auto ic = en.getDefiningOp<IndexCastOp>())
+    en = ic.getIn();
+
+  if ((en.getDefiningOp<AddIOp>() || en.getDefiningOp<SubIOp>() ||
+       en.getDefiningOp<MulIOp>() || en.getDefiningOp<RemUIOp>() ||
+       en.getDefiningOp<RemSIOp>()) &&
+      (en.getDefiningOp()->getOperand(1).getDefiningOp<ConstantIntOp>() ||
+       en.getDefiningOp()->getOperand(1).getDefiningOp<ConstantIndexOp>()))
+    return true;
+  // if (auto IC = dyn_cast_or_null<IndexCastOp>(en.getDefiningOp())) {
+  //   if (!outer || legalCondition(IC.getOperand(), false)) return true;
+  // }
+  if (!dim)
+    if (auto BA = dyn_cast<BlockArgument>(en)) {
+      if (isa<affine::AffineForOp, affine::AffineParallelOp>(
+              BA.getOwner()->getParentOp()))
+        return true;
+    }
+  return false;
+}
+
+/// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to
+/// keep a correspondence between the mathematical `map` and the `operands` of
+/// a given affine::AffineApplyOp. This correspondence is maintained by
+/// iterating over the operands and forming an `auxiliaryMap` that can be
+/// composed mathematically with `map`. To keep this correspondence in cases
+/// where symbols are produced by affine.apply operations, we perform a local
+/// rewrite of symbols as dims.
+///
+/// Rationale for locally rewriting symbols as dims:
+/// ================================================
+/// The mathematical composition of AffineMap must always concatenate symbols
+/// because it does not have enough information to do otherwise. For example,
+/// composing `(d0)[s0] -> (d0 + s0)` with itself must produce
+/// `(d0)[s0, s1] -> (d0 + s0 + s1)`.
+///
+/// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when
+/// applied to the same mlir::Value for both s0 and s1.
+/// As a consequence, the mathematical composition of AffineMap always
+/// concatenates symbols.
+///
+/// When AffineMaps are used in affine::AffineApplyOp however, they may specify
+/// composition via symbols, which is ambiguous mathematically. This corner case
+/// is handled by locally rewriting such symbols that come from
+/// affine::AffineApplyOp into dims and composing through dims.
+/// TODO: Composition via symbols comes at a significant code
+/// complexity. Alternatively we should investigate whether we want to
+/// explicitly disallow symbols coming from affine.apply and instead force the
+/// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2
+/// extra API calls for such uses, which haven't popped up until now) and the
+/// benefit potentially big: simpler and more maintainable code for a
+/// non-trivial, recursive procedure.
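+///
+/// For instance (an illustrative spelling of the example above, using
+/// hypothetical values %i and %n), composing
+///   affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%n]
+/// with itself concatenates the symbols into
+///   affine_map<(d0)[s0, s1] -> (d0 + s0 + s1)>,
+/// which simplifies to `(d0)[s0] -> (d0 + 2 * s0)` only once s0 and s1 are
+/// known to bind the same value %n.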
+AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
+                                             ArrayRef<Value> operands,
+                                             PatternRewriter &rewriter,
+                                             DominanceInfo &DI) {
+  assert(map.getNumInputs() == operands.size() &&
+         "number of operands does not match the number of map inputs");
+
+  LLVM_DEBUG(map.print(llvm::dbgs() << "\nInput map: "));
+
+  SmallVector<Value> addedValues;
+
+  llvm::SmallSet<unsigned, 8> symbolsToPromote;
+
+  unsigned numDims = map.getNumDims();
+
+  SmallVector<AffineExpr> dimReplacements;
+  SmallVector<AffineExpr> symReplacements;
+
+  SmallVector<SmallVector<Value> *> opsTodos;
+  auto replaceOp = [&](Operation *oldOp, Operation *newOp) {
+    for (auto [oldV, newV] :
+         llvm::zip(oldOp->getResults(), newOp->getResults()))
+      for (auto *ops : opsTodos)
+        for (auto &op : *ops)
+          if (op == oldV)
+            op = newV;
+  };
+
+  std::function<Value(Value, bool)> fix = [&](Value v,
+                                              bool index) -> Value /*legal*/ {
+    if (isValidSymbolInt(v, /*recur*/ false))
+      return v;
+    if (index && isAffineForArg(v))
+      return v;
+    auto *op = v.getDefiningOp();
+    if (!op)
+      return nullptr;
+    if (isa<arith::ConstantOp>(op) || isa<ConstantIndexOp>(op))
+      return v;
+    if (!isReadOnly(op)) {
+      return nullptr;
+    }
+    Operation *front = nullptr;
+    SmallVector<Value> ops;
+    opsTodos.push_back(&ops);
+    std::function<void(Operation *)> getAllOps = [&](Operation *todo) {
+      for (auto v : todo->getOperands()) {
+        if (llvm::all_of(op->getRegions(), [&](Region &r) {
+              return !r.isAncestor(v.getParentRegion());
+            }))
+          ops.push_back(v);
+      }
+      for (auto &r : todo->getRegions()) {
+        for (auto &b : r.getBlocks())
+          for (auto &o2 : b.without_terminator())
+            getAllOps(&o2);
+      }
+    };
+    getAllOps(op);
+    for (auto o : ops) {
+      Operation *next;
+      if (auto *op = o.getDefiningOp()) {
+        if (Value nv = fix(o, index)) {
+          op = nv.getDefiningOp();
+        } else {
+          return nullptr;
+        }
+        next = op->getNextNode();
+      } else {
+        auto ba = mlir::cast<BlockArgument>(o);
+        if (index && isAffineForArg(ba)) {
+        } else if (!isValidSymbolInt(o, /*recur*/ false)) {
+          return nullptr;
+        }
+        next = &ba.getOwner()->front();
+      }
+      if (front == nullptr)
+        front = next;
+      else if (DI.dominates(front, next))
+        front = next;
+    }
+    opsTodos.pop_back();
+    if (!front)
+      op->dump();
+    assert(front);
+    PatternRewriter::InsertionGuard B(rewriter);
+    rewriter.setInsertionPoint(front);
+    auto *cloned = rewriter.clone(*op);
+    replaceOp(op, cloned);
+    rewriter.replaceOp(op, cloned->getResults());
+    return cloned->getResult(0);
+  };
+  auto renumberOneSymbol = [&](Value v) {
+    for (auto i : llvm::enumerate(addedValues)) {
+      if (i.value() == v)
+        return getAffineSymbolExpr(i.index(), map.getContext());
+    }
+    auto expr = getAffineSymbolExpr(addedValues.size(), map.getContext());
+    addedValues.push_back(v);
+    return expr;
+  };
+
+  // 2. Compose affine::AffineApplyOps and dispatch dims or symbols.
+  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+    auto t = operands[i];
+    auto decast = t;
+    while (true) {
+      if (auto idx = decast.getDefiningOp<IndexCastOp>()) {
+        decast = idx.getIn();
+        continue;
+      }
+      if (auto idx = decast.getDefiningOp<ExtUIOp>()) {
+        decast = idx.getIn();
+        continue;
+      }
+      if (auto idx = decast.getDefiningOp<ExtSIOp>()) {
+        decast = idx.getIn();
+        continue;
+      }
+      break;
+    }
+
+    if (!isValidSymbolInt(t, /*recur*/ false)) {
+      t = decast;
+    }
+
+    // Only promote one at a time, lest we end up with two dimensions
+    // multiplying each other.
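+    // (For example, if both operands of an arith.muli were promoted to
+    // dims, the composed expression would contain d0 * d1, which is not
+    // affine; keeping one side a symbol avoids that.)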
+
+    if (((!isValidSymbolInt(t, /*recur*/ false) &&
+          (t.getDefiningOp<AddIOp>() || t.getDefiningOp<SubIOp>() ||
+           (t.getDefiningOp<MulIOp>() &&
+            ((isValidIndex(t.getDefiningOp()->getOperand(0)) &&
+              isValidSymbolInt(t.getDefiningOp()->getOperand(1))) ||
+             (isValidIndex(t.getDefiningOp()->getOperand(1)) &&
+              isValidSymbolInt(t.getDefiningOp()->getOperand(0)))) &&
+            !(fix(t.getDefiningOp()->getOperand(0), false) &&
+              fix(t.getDefiningOp()->getOperand(1), false))
+
+            ) ||
+           ((t.getDefiningOp<DivUIOp>() || t.getDefiningOp<DivSIOp>()) &&
+            (isValidIndex(t.getDefiningOp()->getOperand(0)) &&
+             isValidSymbolInt(t.getDefiningOp()->getOperand(1))) &&
+            (!(fix(t.getDefiningOp()->getOperand(0), false) &&
+               fix(t.getDefiningOp()->getOperand(1), false)))) ||
+           (t.getDefiningOp<DivUIOp>() &&
+            (isValidIndex(t.getDefiningOp()->getOperand(0)) &&
+             isValidSymbolInt(t.getDefiningOp()->getOperand(1)))) ||
+           (t.getDefiningOp<RemUIOp>() &&
+            (isValidIndex(t.getDefiningOp()->getOperand(0)) &&
+             isValidSymbolInt(t.getDefiningOp()->getOperand(1)))) ||
+           (t.getDefiningOp<RemSIOp>() &&
+            (isValidIndex(t.getDefiningOp()->getOperand(0)) &&
+             isValidSymbolInt(t.getDefiningOp()->getOperand(1)))) ||
+           t.getDefiningOp<ConstantIntOp>() ||
+           t.getDefiningOp<ConstantIndexOp>())) ||
+         ((decast.getDefiningOp<AddIOp>() || decast.getDefiningOp<SubIOp>() ||
+           decast.getDefiningOp<MulIOp>() || decast.getDefiningOp<RemUIOp>() ||
+           decast.getDefiningOp<RemSIOp>()) &&
+          (decast.getDefiningOp()
+               ->getOperand(1)
+               .getDefiningOp<ConstantIntOp>() ||
+           decast.getDefiningOp()
+               ->getOperand(1)
+               .getDefiningOp<ConstantIndexOp>())))) {
+      t = decast;
+      LLVM_DEBUG(llvm::dbgs() << " Replacing: " << t << "\n");
+
+      AffineMap affineApplyMap;
+      SmallVector<Value> affineApplyOperands;
+
+      // llvm::dbgs() << "\nop to start: " << t << "\n";
+
+      if (auto op = t.getDefiningOp<AddIOp>()) {
+        affineApplyMap =
+            AffineMap::get(0, 2,
+                           getAffineSymbolExpr(0, op.getContext()) +
+                               getAffineSymbolExpr(1, op.getContext()));
+        affineApplyOperands.push_back(op.getLhs());
+        affineApplyOperands.push_back(op.getRhs());
+      } else if (auto op = t.getDefiningOp<SubIOp>()) {
+        affineApplyMap =
+            AffineMap::get(0, 2,
+                           getAffineSymbolExpr(0, op.getContext()) -
+                               getAffineSymbolExpr(1, op.getContext()));
+        affineApplyOperands.push_back(op.getLhs());
+        affineApplyOperands.push_back(op.getRhs());
+      } else if (auto op = t.getDefiningOp<MulIOp>()) {
+        if (auto ci = op.getRhs().getDefiningOp<ConstantIntOp>()) {
+          affineApplyMap = AffineMap::get(
+              0, 1, getAffineSymbolExpr(0, op.getContext()) * ci.value());
+          affineApplyOperands.push_back(op.getLhs());
+        } else if (auto ci = op.getRhs().getDefiningOp<ConstantIndexOp>()) {
+          affineApplyMap = AffineMap::get(
+              0, 1, getAffineSymbolExpr(0, op.getContext()) * ci.value());
+          affineApplyOperands.push_back(op.getLhs());
+        } else {
+          affineApplyMap =
+              AffineMap::get(0, 2,
+                             getAffineSymbolExpr(0, op.getContext()) *
+                                 getAffineSymbolExpr(1, op.getContext()));
+          affineApplyOperands.push_back(op.getLhs());
+          affineApplyOperands.push_back(op.getRhs());
+        }
+      } else if (auto op = t.getDefiningOp<DivUIOp>()) {
+        if (auto ci = op.getRhs().getDefiningOp<ConstantIntOp>()) {
+          affineApplyMap = AffineMap::get(
+              0, 1,
+              getAffineSymbolExpr(0, op.getContext()).floorDiv(ci.value()));
+          affineApplyOperands.push_back(op.getLhs());
+        } else if (auto ci = op.getRhs().getDefiningOp<ConstantIndexOp>()) {
+          affineApplyMap = AffineMap::get(
+              0, 1,
+              getAffineSymbolExpr(0, op.getContext()).floorDiv(ci.value()));
+          affineApplyOperands.push_back(op.getLhs());
+        } else {
+          affineApplyMap = AffineMap::get(
+              0, 2,
+              getAffineSymbolExpr(0, op.getContext())
+                  .floorDiv(getAffineSymbolExpr(1, op.getContext())));
+          affineApplyOperands.push_back(op.getLhs());
+          affineApplyOperands.push_back(op.getRhs());
+        }
+      } else if (auto op = t.getDefiningOp<DivSIOp>()) {
+        if (auto ci = op.getRhs().getDefiningOp<ConstantIntOp>()) {
+          affineApplyMap = AffineMap::get(
+              0, 1,
+              getAffineSymbolExpr(0, op.getContext()).floorDiv(ci.value()));
+          affineApplyOperands.push_back(op.getLhs());
+        } else if (auto ci = op.getRhs().getDefiningOp<ConstantIndexOp>()) {
+          affineApplyMap = AffineMap::get(
+              0, 1,
+              getAffineSymbolExpr(0, op.getContext()).floorDiv(ci.value()));
+          affineApplyOperands.push_back(op.getLhs());
+        } else {
+          affineApplyMap = AffineMap::get(
+              0, 2,
+              getAffineSymbolExpr(0, op.getContext())
+                  .floorDiv(getAffineSymbolExpr(1, op.getContext())));
+          affineApplyOperands.push_back(op.getLhs());
+          affineApplyOperands.push_back(op.getRhs());
+        }
+      } else if (auto op = t.getDefiningOp<RemUIOp>()) {
+        if (auto ci = op.getRhs().getDefiningOp<ConstantIntOp>()) {
+          affineApplyMap = AffineMap::get(
+              0, 1, getAffineSymbolExpr(0, op.getContext()) % ci.value());
+          affineApplyOperands.push_back(op.getLhs());
+        } else if (auto ci = op.getRhs().getDefiningOp<ConstantIndexOp>()) {
+          affineApplyMap = AffineMap::get(
+              0, 1, getAffineSymbolExpr(0, op.getContext()) % ci.value());
+          affineApplyOperands.push_back(op.getLhs());
+        } else {
+          affineApplyMap =
+              AffineMap::get(0, 2,
+                             getAffineSymbolExpr(0, op.getContext()) %
+                                 getAffineSymbolExpr(1, op.getContext()));
+          affineApplyOperands.push_back(op.getLhs());
+          affineApplyOperands.push_back(op.getRhs());
+        }
+      } else if (auto op = t.getDefiningOp<RemSIOp>()) {
+        if (auto ci = op.getRhs().getDefiningOp<ConstantIntOp>()) {
+          affineApplyMap = AffineMap::get(
+              0, 1, getAffineSymbolExpr(0, op.getContext()) % ci.value());
+          affineApplyOperands.push_back(op.getLhs());
+        } else if (auto ci = op.getRhs().getDefiningOp<ConstantIndexOp>()) {
+          affineApplyMap = AffineMap::get(
+              0, 1, getAffineSymbolExpr(0, op.getContext()) % ci.value());
+          affineApplyOperands.push_back(op.getLhs());
+        } else {
+          affineApplyMap =
+              AffineMap::get(0, 2,
+                             getAffineSymbolExpr(0, op.getContext()) %
+                                 getAffineSymbolExpr(1, op.getContext()));
+          affineApplyOperands.push_back(op.getLhs());
+          affineApplyOperands.push_back(op.getRhs());
+        }
+      } else if (auto op = t.getDefiningOp<ConstantIntOp>()) {
+        affineApplyMap = AffineMap::get(
+            0, 0, getAffineConstantExpr(op.value(), op.getContext()));
+      } else if (auto op = t.getDefiningOp<ConstantIndexOp>()) {
+        affineApplyMap = AffineMap::get(
+            0, 0, getAffineConstantExpr(op.value(), op.getContext()));
+      } else {
+        llvm_unreachable("unhandled op during affine promotion");
+      }
+
+      SmallVector<AffineExpr> dimRemapping;
+      unsigned numOtherSymbols = affineApplyOperands.size();
+      SmallVector<AffineExpr> symRemapping(numOtherSymbols);
+      for (unsigned idx = 0; idx < numOtherSymbols; ++idx) {
+        symRemapping[idx] = renumberOneSymbol(affineApplyOperands[idx]);
+      }
+      affineApplyMap = affineApplyMap.replaceDimsAndSymbols(
+          dimRemapping, symRemapping, reorderedDims.size(), addedValues.size());
+
+      LLVM_DEBUG(affineApplyMap.print(
+          llvm::dbgs() << "\nRenumber into current normalizer: "));
+
+      if (i >= numDims)
+        symReplacements.push_back(affineApplyMap.getResult(0));
+      else
+        dimReplacements.push_back(affineApplyMap.getResult(0));
+
+    } else if (isAffineForArg(t)) {
+      if (i >= numDims)
+        symReplacements.push_back(renumberOneDim(t));
+      else
+        dimReplacements.push_back(renumberOneDim(t));
+    } else if (t.getDefiningOp<affine::AffineApplyOp>()) {
+      auto affineApply = t.getDefiningOp<affine::AffineApplyOp>();
+      // a. Compose affine.apply operations.
+      LLVM_DEBUG(affineApply->print(
+          llvm::dbgs() << "\nCompose affine::AffineApplyOp recursively: "));
+      AffineMap affineApplyMap = affineApply.getAffineMap();
+      SmallVector<Value> affineApplyOperands(
+          affineApply.getOperands().begin(), affineApply.getOperands().end());
+
+      SmallVector<AffineExpr> dimRemapping(affineApplyMap.getNumDims());
+
+      for (size_t i = 0; i < affineApplyMap.getNumDims(); ++i) {
+        assert(i < affineApplyOperands.size());
+        dimRemapping[i] = renumberOneDim(affineApplyOperands[i]);
+      }
+      unsigned numOtherSymbols = affineApplyOperands.size();
+      SmallVector<AffineExpr> symRemapping(numOtherSymbols -
+                                           affineApplyMap.getNumDims());
+      for (unsigned idx = 0; idx < symRemapping.size(); ++idx) {
+        symRemapping[idx] = renumberOneSymbol(
+            affineApplyOperands[idx + affineApplyMap.getNumDims()]);
+      }
+      affineApplyMap = affineApplyMap.replaceDimsAndSymbols(
+          dimRemapping, symRemapping, reorderedDims.size(), addedValues.size());
+
+      LLVM_DEBUG(
+          affineApplyMap.print(llvm::dbgs() << "\nAffine apply fixup map: "));
+
+      if (i >= numDims)
+        symReplacements.push_back(affineApplyMap.getResult(0));
+      else
+        dimReplacements.push_back(affineApplyMap.getResult(0));
+    } else {
+      if (!isValidSymbolInt(t, /*recur*/ false)) {
+        if (t.getDefiningOp()) {
+          if ((t = fix(t, false))) {
+            assert(isValidSymbolInt(t, /*recur*/ false));
+          } else
+            llvm_unreachable("cannot move");
+        } else
+          llvm_unreachable("cannot move2");
+      }
+      if (i < numDims) {
+        // b. The mathematical composition of AffineMap composes dims.
+        dimReplacements.push_back(renumberOneDim(t));
+      } else {
+        // c. The mathematical composition of AffineMap concatenates symbols.
+        //    Note that the map composition will put symbols already present
+        //    in the map before any symbols coming from the auxiliary map, so
+        //    we insert them before any symbols that are due to renumbering,
+        //    and after the proper symbols we have seen already.
+        symReplacements.push_back(renumberOneSymbol(t));
+      }
+    }
+  }
+  for (auto v : addedValues)
+    concatenatedSymbols.push_back(v);
+
+  // Create the new map by replacing each symbol at pos by the next new dim.
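+  // At this point every original dim has a replacement expression in
+  // dimReplacements and every original symbol has one in symReplacements;
+  // the asserts below check exactly that.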
+  unsigned numNewDims = reorderedDims.size();
+  unsigned numNewSymbols = addedValues.size();
+  assert(dimReplacements.size() == map.getNumDims());
+  assert(symReplacements.size() == map.getNumSymbols());
+  auto auxiliaryMap = map.replaceDimsAndSymbols(
+      dimReplacements, symReplacements, numNewDims, numNewSymbols);
+  LLVM_DEBUG(auxiliaryMap.print(llvm::dbgs() << "\nRewritten map: "));
+
+  affineMap = auxiliaryMap; // simplifyAffineMap(auxiliaryMap);
+
+  LLVM_DEBUG(affineMap.print(llvm::dbgs() << "\nSimplified result: "));
+  LLVM_DEBUG(llvm::dbgs() << "\n");
+}
+
+AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value v) {
+  DenseMap<Value, unsigned>::iterator iterPos;
+  bool inserted = false;
+  std::tie(iterPos, inserted) =
+      dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size()));
+  if (inserted) {
+    reorderedDims.push_back(v);
+  }
+  return mlir::cast<AffineDimExpr>(
+      getAffineDimExpr(iterPos->second, v.getContext()));
+}
+
+static void composeAffineMapAndOperands(AffineMap *map,
+                                        SmallVectorImpl<Value> *operands,
+                                        PatternRewriter &rewriter,
+                                        DominanceInfo &di) {
+  AffineApplyNormalizer normalizer(*map, *operands, rewriter, di);
+  auto normalizedMap = normalizer.getAffineMap();
+  auto normalizedOperands = normalizer.getOperands();
+  affine::canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands);
+  *map = normalizedMap;
+  *operands = normalizedOperands;
+  assert(*map);
+}
+
+bool need(AffineMap *map, SmallVectorImpl<Value> *operands) {
+  assert(map->getNumInputs() == operands->size());
+  for (size_t i = 0; i < map->getNumInputs(); ++i) {
+    auto v = (*operands)[i];
+    if (legalCondition(v, i < map->getNumDims()))
+      return true;
+  }
+  return false;
+}
+bool need(IntegerSet *map, SmallVectorImpl<Value> *operands) {
+  for (size_t i = 0; i < map->getNumInputs(); ++i) {
+    auto v = (*operands)[i];
+    if (legalCondition(v, i < map->getNumDims()))
+      return true;
+  }
+  return false;
+}
+
+void fully2ComposeAffineMapAndOperands(PatternRewriter &builder, AffineMap *map,
+                                       SmallVectorImpl<Value> *operands,
+                                       DominanceInfo &di) {
+  IRMapping indexMap;
+  for (auto op : *operands) {
+    SmallVector<IndexCastOp> attempt;
+    auto idx0 = op.getDefiningOp<IndexCastOp>();
+    if (!idx0)
+      continue;
+    attempt.push_back(idx0);
+
+    for (auto &u : idx0.getIn().getUses()) {
+      if (auto idx = dyn_cast<IndexCastOp>(u.getOwner()))
+        if (di.dominates((Operation *)idx, &*builder.getInsertionPoint()))
+          attempt.push_back(idx);
+    }
+
+    for (auto idx : attempt) {
+      if (affine::isValidSymbol(idx)) {
+        indexMap.map(idx.getIn(), idx);
+        break;
+      }
+    }
+  }
+  assert(map->getNumInputs() == operands->size());
+  while (need(map, operands)) {
+    composeAffineMapAndOperands(map, operands, builder, di);
+    assert(map->getNumInputs() == operands->size());
+  }
+  *map = simplifyAffineMap(*map);
+  for (auto &op : *operands) {
+    if (!op.getType().isIndex()) {
+      Operation *toInsert;
+      if (auto *o = op.getDefiningOp())
+        toInsert = o->getNextNode();
+      else {
+        auto ba = mlir::cast<BlockArgument>(op);
+        toInsert = &ba.getOwner()->front();
+      }
+
+      if (auto v = indexMap.lookupOrNull(op))
+        op = v;
+      else {
+        PatternRewriter::InsertionGuard b(builder);
+        builder.setInsertionPoint(toInsert);
+        op = builder.create<IndexCastOp>(op.getLoc(), builder.getIndexType(),
+                                         op);
+      }
+    }
+  }
+}
+
+void fully2ComposeIntegerSetAndOperands(PatternRewriter &builder,
+                                        IntegerSet *set,
+                                        SmallVectorImpl<Value> *operands,
+                                        DominanceInfo &DI) {
+  IRMapping indexMap;
+  for (auto op : *operands) {
+    SmallVector<IndexCastOp> attempt;
+    auto idx0 = op.getDefiningOp<IndexCastOp>();
+    if (!idx0)
+      continue;
+    attempt.push_back(idx0);
+
+    for (auto &u : idx0.getIn().getUses()) {
+      if (auto idx = dyn_cast<IndexCastOp>(u.getOwner()))
+        if (DI.dominates((Operation *)idx, &*builder.getInsertionPoint()))
+          attempt.push_back(idx);
+    }
+
+    for (auto idx : attempt) {
+      if (affine::isValidSymbol(idx)) {
+        indexMap.map(idx.getIn(), idx);
+        break;
+      }
+    }
+  }
+  auto map = AffineMap::get(set->getNumDims(), set->getNumSymbols(),
+                            set->getConstraints(), set->getContext());
+  while (need(&map, operands)) {
+    composeAffineMapAndOperands(&map, operands, builder, DI);
+  }
+  map = simplifyAffineMap(map);
+  *set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(),
+                         map.getResults(), set->getEqFlags());
+  for (auto &op : *operands) {
+    if (!op.getType().isIndex()) {
+      Operation *toInsert;
+      if (auto *o = op.getDefiningOp())
+        toInsert = o->getNextNode();
+      else {
+        auto ba = mlir::cast<BlockArgument>(op);
+        toInsert = &ba.getOwner()->front();
+      }
+
+      if (auto v = indexMap.lookupOrNull(op))
+        op = v;
+      else {
+        PatternRewriter::InsertionGuard b(builder);
+        builder.setInsertionPoint(toInsert);
+        op = builder.create<IndexCastOp>(op.getLoc(), builder.getIndexType(),
+                                         op);
+      }
+    }
+  }
+}
+
+namespace {
+struct AffineCFG : public impl::AffineCFGBase<AffineCFG> {
+  void runOnOperation() override;
+};
+} // namespace
+
+struct IndexCastMovement : public OpRewritePattern<IndexCastOp> {
+  using OpRewritePattern<IndexCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IndexCastOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.use_empty()) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    mlir::Value val = op.getOperand();
+    if (auto bop = dyn_cast<BlockArgument>(val)) {
+      if (op.getOperation()->getBlock() != bop.getOwner()) {
+        op.getOperation()->moveBefore(bop.getOwner(), bop.getOwner()->begin());
+        return success();
+      }
+      return failure();
+    }
+
+    if (val.getDefiningOp()) {
+      if (op.getOperation()->getBlock() != val.getDefiningOp()->getBlock()) {
+        auto it = val.getDefiningOp()->getIterator();
+        op.getOperation()->moveAfter(val.getDefiningOp()->getBlock(), it);
+      }
+      return failure();
+    }
+    return failure();
+  }
+};
+
+struct CanonicalizeAffineApply
+    : public OpRewritePattern<affine::AffineApplyOp> {
+  using OpRewritePattern<affine::AffineApplyOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineApplyOp affineOp,
+                                PatternRewriter &rewriter) const override {
+
+    SmallVector<Value> mapOperands(affineOp.getMapOperands());
+    auto map = affineOp.getMap();
+    auto prevMap = map;
+
+    auto *scope = affine::getAffineScope(affineOp)->getParentOp();
+    DominanceInfo di(scope);
+
+    fully2ComposeAffineMapAndOperands(rewriter, &map, &mapOperands, di);
+    affine::canonicalizeMapAndOperands(&map, &mapOperands);
+    map = removeDuplicateExprs(map);
+
+    if (map == prevMap)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(affineOp, map,
+                                                       mapOperands);
+    return success();
+  }
+};
+
+struct CanonicalizeIndexCast : public OpRewritePattern<IndexCastOp> {
+  using OpRewritePattern<IndexCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IndexCastOp indexcastOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Fold IndexCast(IndexCast(x)) -> x
+    auto cast = indexcastOp.getOperand().getDefiningOp<IndexCastOp>();
+    if (cast && cast.getOperand().getType() == indexcastOp.getType()) {
+      mlir::Value vals[] = {cast.getOperand()};
+      rewriter.replaceOp(indexcastOp, vals);
+      return success();
+    }
+
+    // Fold IndexCast(constant) -> constant
+    // A little hack because we go through int. Otherwise, the size
+    // of the constant might need to change.
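+    // For example, arith.index_cast of (arith.constant 42 : i32) becomes
+    // arith.constant 42 : index.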
+    if (auto cst = indexcastOp.getOperand().getDefiningOp<ConstantIntOp>()) {
+      rewriter.replaceOpWithNewOp<ConstantIndexOp>(indexcastOp, cst.value());
+      return success();
+    }
+    return failure();
+  }
+};
+
+bool isValidIndex(Value val) {
+  if (isValidSymbolInt(val))
+    return true;
+
+  if (auto cast = val.getDefiningOp<IndexCastOp>())
+    return isValidIndex(cast.getOperand());
+
+  if (auto cast = val.getDefiningOp<ExtSIOp>())
+    return isValidIndex(cast.getOperand());
+
+  if (auto cast = val.getDefiningOp<ExtUIOp>())
+    return isValidIndex(cast.getOperand());
+
+  if (auto bop = val.getDefiningOp<AddIOp>())
+    return isValidIndex(bop.getOperand(0)) && isValidIndex(bop.getOperand(1));
+
+  if (auto bop = val.getDefiningOp<MulIOp>())
+    return (isValidIndex(bop.getOperand(0)) &&
+            isValidSymbolInt(bop.getOperand(1))) ||
+           (isValidIndex(bop.getOperand(1)) &&
+            isValidSymbolInt(bop.getOperand(0)));
+
+  if (auto bop = val.getDefiningOp<DivSIOp>())
+    return (isValidIndex(bop.getOperand(0)) &&
+            isValidSymbolInt(bop.getOperand(1)));
+
+  if (auto bop = val.getDefiningOp<DivUIOp>())
+    return (isValidIndex(bop.getOperand(0)) &&
+            isValidSymbolInt(bop.getOperand(1)));
+
+  if (auto bop = val.getDefiningOp<RemSIOp>()) {
+    return (isValidIndex(bop.getOperand(0)) &&
+            bop.getOperand(1).getDefiningOp<ConstantIndexOp>());
+  }
+
+  if (auto bop = val.getDefiningOp<RemUIOp>())
+    return (isValidIndex(bop.getOperand(0)) &&
+            bop.getOperand(1).getDefiningOp<ConstantIndexOp>());
+
+  if (auto bop = val.getDefiningOp<SubIOp>())
+    return isValidIndex(bop.getOperand(0)) && isValidIndex(bop.getOperand(1));
+
+  if (val.getDefiningOp<ConstantIndexOp>())
+    return true;
+
+  if (val.getDefiningOp<ConstantIntOp>())
+    return true;
+
+  if (auto ba = dyn_cast<BlockArgument>(val)) {
+    auto *owner = ba.getOwner();
+    assert(owner);
+
+    auto *parentOp = owner->getParentOp();
+    if (!parentOp) {
+      owner->dump();
+      llvm::errs() << " ba: " << ba << "\n";
+    }
+    assert(parentOp);
+    if (isa<FunctionOpInterface>(parentOp))
+      return true;
+    if (auto af = dyn_cast<affine::AffineForOp>(parentOp))
+      return af.getInductionVar() == ba;
+
+    // TODO ensure not a reduced var
+    if (isa<affine::AffineParallelOp>(parentOp))
+      return true;
+
+    if (isa<scf::ForOp, scf::ParallelOp>(parentOp))
+      return true;
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "illegal isValidIndex: " << val << "\n");
+  return false;
+}
+
+// returns legality
+bool handleMinMax(Value start, SmallVectorImpl<Value> &out, bool &min,
+                  bool &max) {
+
+  SmallVector<Value> todo = {start};
+  while (todo.size()) {
+    auto cur = todo.back();
+    todo.pop_back();
+    if (isValidIndex(cur)) {
+      out.push_back(cur);
+      continue;
+    }
+    if (auto selOp = cur.getDefiningOp<SelectOp>()) {
+      // UB only has min of operands
+      if (auto cmp = selOp.getCondition().getDefiningOp<CmpIOp>()) {
+        if (cmp.getLhs() == selOp.getTrueValue() &&
+            cmp.getRhs() == selOp.getFalseValue()) {
+          todo.push_back(cmp.getLhs());
+          todo.push_back(cmp.getRhs());
+          if (cmp.getPredicate() == CmpIPredicate::sle ||
+              cmp.getPredicate() == CmpIPredicate::slt) {
+            min = true;
+            continue;
+          }
+          if (cmp.getPredicate() == CmpIPredicate::sge ||
+              cmp.getPredicate() == CmpIPredicate::sgt) {
+            max = true;
+            continue;
+          }
+        }
+      }
+    }
+    return false;
+  }
+  return !(min && max);
+}
+
+bool handle(PatternRewriter &b, CmpIOp cmpi, SmallVectorImpl<AffineExpr> &exprs,
+            SmallVectorImpl<bool> &eqflags, SmallVectorImpl<Value> &applies) {
+  SmallVector<Value> lhs;
+  bool lhsMin = false;
+  bool lhsMax = false;
+  if (!handleMinMax(cmpi.getLhs(), lhs, lhsMin, lhsMax)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "illegal lhs: " << cmpi.getLhs() << " - " << cmpi << "\n");
+    return false;
+  }
+  assert(lhs.size());
+  SmallVector<Value> rhs;
+  bool rhsMin = false;
+  bool rhsMax = false;
+  if (!handleMinMax(cmpi.getRhs(), rhs, rhsMin, rhsMax)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "illegal rhs: " << cmpi.getRhs() << " - " << cmpi << "\n");
+    return false;
+  }
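+  // Each side may have expanded into several candidate bounds when it was
+  // built from a select encoding a min/max; every (lhs, rhs) pair below
+  // then contributes one constraint to the integer set.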
+  assert(rhs.size());
+  for (auto &lhspack : lhs)
+    if (!mlir::isa<IndexType>(lhspack.getType())) {
+      lhspack = b.create<IndexCastOp>(
+          cmpi.getLoc(), IndexType::get(cmpi.getContext()), lhspack);
+    }
+
+  for (auto &rhspack : rhs)
+    if (!mlir::isa<IndexType>(rhspack.getType())) {
+      rhspack = b.create<IndexCastOp>(
+          cmpi.getLoc(), IndexType::get(cmpi.getContext()), rhspack);
+    }
+
+  switch (cmpi.getPredicate()) {
+  case CmpIPredicate::eq: {
+    if (lhsMin || lhsMax || rhsMin || rhsMax)
+      return false;
+    eqflags.push_back(true);
+
+    applies.push_back(lhs[0]);
+    applies.push_back(rhs[0]);
+    AffineExpr dims[2] = {b.getAffineSymbolExpr(2 * exprs.size() + 0),
+                          b.getAffineSymbolExpr(2 * exprs.size() + 1)};
+    exprs.push_back(dims[0] - dims[1]);
+  } break;
+
+  case CmpIPredicate::ugt:
+  case CmpIPredicate::uge:
+    for (auto lhspack : lhs)
+      if (!valueCmp(Cmp::GE, lhspack, 0)) {
+        LLVM_DEBUG(llvm::dbgs() << "illegal greater lhs icmp: " << cmpi << " - "
+                                << lhspack << "\n");
+        return false;
+      }
+    for (auto rhspack : rhs)
+      if (!valueCmp(Cmp::GE, rhspack, 0)) {
+        LLVM_DEBUG(llvm::dbgs() << "illegal greater rhs icmp: " << cmpi << " - "
+                                << rhspack << "\n");
+        return false;
+      }
+    LLVM_FALLTHROUGH;
+  case CmpIPredicate::sge:
+  case CmpIPredicate::sgt: {
+    // if lhs >=? rhs
+    // if lhs is a min(a, b) both must be true and this is fine
+    // if lhs is a max(a, b) either may be true, and sets require and
+    // similarly if rhs is a max(), both must be true;
+    if (lhsMax || rhsMin)
+      return false;
+    for (auto lhspack : lhs)
+      for (auto rhspack : rhs) {
+        eqflags.push_back(false);
+        applies.push_back(lhspack);
+        applies.push_back(rhspack);
+        AffineExpr dims[2] = {b.getAffineSymbolExpr(2 * exprs.size() + 0),
+                              b.getAffineSymbolExpr(2 * exprs.size() + 1)};
+        auto expr = dims[0] - dims[1];
+        if (cmpi.getPredicate() == CmpIPredicate::sgt ||
+            cmpi.getPredicate() == CmpIPredicate::ugt)
+          expr = expr - 1;
+        exprs.push_back(expr);
+      }
+  } break;
+
+  case CmpIPredicate::ult:
+  case CmpIPredicate::ule:
+    for (auto lhspack : lhs)
+      if (!valueCmp(Cmp::GE, lhspack, 0)) {
+        LLVM_DEBUG(llvm::dbgs() << "illegal less lhs icmp: " << cmpi << " - "
+                                << lhspack << "\n");
+        return false;
+      }
+    for (auto rhspack : rhs)
+      if (!valueCmp(Cmp::GE, rhspack, 0)) {
+        LLVM_DEBUG(llvm::dbgs() << "illegal less rhs icmp: " << cmpi << " - "
+                                << rhspack << "\n");
+        return false;
+      }
+    LLVM_FALLTHROUGH;
+  case CmpIPredicate::slt:
+  case CmpIPredicate::sle: {
+    if (lhsMin || rhsMax)
+      return false;
+    for (auto lhspack : lhs)
+      for (auto rhspack : rhs) {
+        eqflags.push_back(false);
+        applies.push_back(lhspack);
+        applies.push_back(rhspack);
+        AffineExpr dims[2] = {b.getAffineSymbolExpr(2 * exprs.size() + 0),
+                              b.getAffineSymbolExpr(2 * exprs.size() + 1)};
+        auto expr = dims[1] - dims[0];
+        if (cmpi.getPredicate() == CmpIPredicate::slt ||
+            cmpi.getPredicate() == CmpIPredicate::ult)
+          expr = expr - 1;
+        exprs.push_back(expr);
+      }
+  } break;
+
+  case CmpIPredicate::ne:
+    LLVM_DEBUG(llvm::dbgs() << "illegal icmp: " << cmpi << "\n");
+    return false;
+  }
+  return true;
+}
+
+struct MoveLoadToAffine : public OpRewritePattern<memref::LoadOp> {
+  using OpRewritePattern<memref::LoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::LoadOp load,
+                                PatternRewriter &rewriter) const override {
+    if (!llvm::all_of(load.getIndices(), isValidIndex))
+      return failure();
+
+    auto memrefType = mlir::cast<MemRefType>(load.getMemRef().getType());
+    int64_t rank = memrefType.getRank();
+
+    // Create identity map for memrefs with at least one dimension or () -> ()
+    // for zero-dimensional memrefs.
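+    // For a rank-2 memref, for example, this builds ()[s0, s1] -> (s0, s1),
+    // so each load index becomes a symbol operand of the affine.load.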
+    SmallVector<AffineExpr> dimExprs;
+    dimExprs.reserve(rank);
+    for (unsigned i = 0; i < rank; ++i)
+      dimExprs.push_back(rewriter.getAffineSymbolExpr(i));
+    auto map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/rank, dimExprs,
+                              rewriter.getContext());
+
+    SmallVector<Value> operands = load.getIndices();
+
+    if (map.getNumInputs() != operands.size()) {
+      // load->getParentOfType<func::FuncOp>().dump();
+      llvm::errs() << " load: " << load << "\n";
+    }
+    auto *scope = affine::getAffineScope(load)->getParentOp();
+    DominanceInfo di(scope);
+    assert(map.getNumInputs() == operands.size());
+    fully2ComposeAffineMapAndOperands(rewriter, &map, &operands, di);
+    assert(map.getNumInputs() == operands.size());
+    affine::canonicalizeMapAndOperands(&map, &operands);
+    assert(map.getNumInputs() == operands.size());
+
+    affine::AffineLoadOp affineLoad = affine::AffineLoadOp::create(
+        rewriter, load.getLoc(), load.getMemRef(), map, operands);
+    load.getResult().replaceAllUsesWith(affineLoad.getResult());
+    rewriter.eraseOp(load);
+    return success();
+  }
+};
+
+struct MoveStoreToAffine : public OpRewritePattern<memref::StoreOp> {
+  using OpRewritePattern<memref::StoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::StoreOp store,
+                                PatternRewriter &rewriter) const override {
+    if (!llvm::all_of(store.getIndices(), isValidIndex))
+      return failure();
+
+    auto memrefType = mlir::cast<MemRefType>(store.getMemRef().getType());
+    int64_t rank = memrefType.getRank();
+
+    // Create identity map for memrefs with at least one dimension or () -> ()
+    // for zero-dimensional memrefs.
+    SmallVector<AffineExpr> dimExprs;
+    dimExprs.reserve(rank);
+    for (unsigned i = 0; i < rank; ++i)
+      dimExprs.push_back(rewriter.getAffineSymbolExpr(i));
+    auto map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/rank, dimExprs,
+                              rewriter.getContext());
+    SmallVector<Value> operands = store.getIndices();
+
+    auto *scope = affine::getAffineScope(store)->getParentOp();
+    DominanceInfo di(scope);
+
+    fully2ComposeAffineMapAndOperands(rewriter, &map, &operands, di);
+    affine::canonicalizeMapAndOperands(&map, &operands);
+
+    affine::AffineStoreOp::create(rewriter, store.getLoc(),
+                                  store.getValueToStore(), store.getMemRef(),
+                                  map, operands);
+    rewriter.eraseOp(store);
+    return success();
+  }
+};
+
+static bool areChanged(SmallVectorImpl<Value> &afterOperands,
+                       SmallVectorImpl<Value> &beforeOperands) {
+  if (afterOperands.size() != beforeOperands.size())
+    return true;
+  if (!std::equal(afterOperands.begin(), afterOperands.end(),
+                  beforeOperands.begin()))
+    return true;
+  return false;
+}
+
+template <typename T>
+struct AffineFixup : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+
+  /// Replace the affine op with another instance of it with the supplied
+  /// map and mapOperands.
+  void replaceAffineOp(PatternRewriter &rewriter, T affineOp, AffineMap map,
+                       ArrayRef<Value> mapOperands) const;
+
+  LogicalResult matchAndRewrite(T op,
+                                PatternRewriter &rewriter) const override {
+    auto map = op.getAffineMap();
+    SmallVector<Value> operands = op.getMapOperands();
+
+    auto prevMap = map;
+    auto prevOperands = operands;
+
+    auto *scope = affine::getAffineScope(op)->getParentOp();
+    DominanceInfo di(scope);
+
+    assert(map.getNumInputs() == operands.size());
+    fully2ComposeAffineMapAndOperands(rewriter, &map, &operands, di);
+    assert(map.getNumInputs() == operands.size());
+    affine::canonicalizeMapAndOperands(&map, &operands);
+    assert(map.getNumInputs() == operands.size());
+
+    if (map == prevMap && !areChanged(operands, prevOperands))
+      return failure();
+
+    replaceAffineOp(rewriter, op, map, operands);
+    return success();
+  }
+};
+
+// Specialize the template to account for the different build signatures for
+// affine load, store, and apply ops.
+template <>
+void AffineFixup<affine::AffineLoadOp>::replaceAffineOp(
+    PatternRewriter &rewriter, affine::AffineLoadOp load, AffineMap map,
+    ArrayRef<Value> mapOperands) const {
+  rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(load, load.getMemRef(),
+                                                    map, mapOperands);
+}
+template <>
+void AffineFixup<affine::AffinePrefetchOp>::replaceAffineOp(
+    PatternRewriter &rewriter, affine::AffinePrefetchOp prefetch, AffineMap map,
+    ArrayRef<Value> mapOperands) const {
+  rewriter.replaceOpWithNewOp<affine::AffinePrefetchOp>(
+      prefetch, prefetch.getMemref(), map, mapOperands,
+      prefetch.getLocalityHint(), prefetch.getIsWrite(),
+      prefetch.getIsDataCache());
+}
+template <>
+void AffineFixup<affine::AffineStoreOp>::replaceAffineOp(
+    PatternRewriter &rewriter, affine::AffineStoreOp store, AffineMap map,
+    ArrayRef<Value> mapOperands) const {
+  rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
+      store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
+}
+template <>
+void AffineFixup<affine::AffineVectorLoadOp>::replaceAffineOp(
+    PatternRewriter &rewriter, affine::AffineVectorLoadOp vectorload,
+    AffineMap map, ArrayRef<Value> mapOperands) const {
+  rewriter.replaceOpWithNewOp<affine::AffineVectorLoadOp>(
+      vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
+      mapOperands);
+}
+template <>
+void AffineFixup<affine::AffineVectorStoreOp>::replaceAffineOp(
+    PatternRewriter &rewriter, affine::AffineVectorStoreOp vectorstore,
+    AffineMap map, ArrayRef<Value> mapOperands) const {
+  rewriter.replaceOpWithNewOp<affine::AffineVectorStoreOp>(
+      vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
+      mapOperands);
+}
+
+// Generic version for ops that don't have extra operands.
+template <typename AffineOpTy>
+void AffineFixup<AffineOpTy>::replaceAffineOp(
+    PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
+    ArrayRef<Value> mapOperands) const {
+  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
+}
+
+struct CanonicalizeForBounds : public OpRewritePattern<affine::AffineForOp> {
+  using OpRewritePattern<affine::AffineForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> lbOperands(forOp.getLowerBoundOperands());
+    SmallVector<Value> ubOperands(forOp.getUpperBoundOperands());
+    SmallVector<Value> origLbOperands(forOp.getLowerBoundOperands());
+    SmallVector<Value> origUbOperands(forOp.getUpperBoundOperands());
+
+    auto lbMap = forOp.getLowerBoundMap();
+    auto ubMap = forOp.getUpperBoundMap();
+    auto prevLbMap = lbMap;
+    auto prevUbMap = ubMap;
+
+    auto *scope = affine::getAffineScope(forOp)->getParentOp();
+    DominanceInfo di(scope);
+
+    fully2ComposeAffineMapAndOperands(rewriter, &lbMap, &lbOperands, di);
+    affine::canonicalizeMapAndOperands(&lbMap, &lbOperands);
+    lbMap = removeDuplicateExprs(lbMap);
+
+    fully2ComposeAffineMapAndOperands(rewriter, &ubMap, &ubOperands, di);
+    affine::canonicalizeMapAndOperands(&ubMap, &ubOperands);
+    ubMap = removeDuplicateExprs(ubMap);
+
+    // Any canonicalization change in map or operands always leads to updated
+    // map(s).
+    if ((lbMap == prevLbMap && ubMap == prevUbMap) &&
+        (!areChanged(lbOperands, origLbOperands)) &&
+        (!areChanged(ubOperands, origUbOperands)))
+      return failure();
+
+    if ((lbMap != prevLbMap) || areChanged(lbOperands, origLbOperands))
+      forOp.setLowerBound(lbOperands, lbMap);
+    if ((ubMap != prevUbMap) || areChanged(ubOperands, origUbOperands))
+      forOp.setUpperBound(ubOperands, ubMap);
+
+    return success();
+  }
+};
+
+struct CanonicalizeIfBounds : public OpRewritePattern<affine::AffineIfOp> {
+  using OpRewritePattern<affine::AffineIfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineIfOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> operands(op.getOperands());
+    SmallVector<Value> origOperands(operands);
+
+    auto map = op.getIntegerSet();
+    auto prevMap = map;
+
+    auto *scope = affine::getAffineScope(op)->getParentOp();
+    DominanceInfo DI(scope);
+
+    fully2ComposeIntegerSetAndOperands(rewriter, &map, &operands, DI);
+    affine::canonicalizeSetAndOperands(&map, &operands);
+
+    // Any canonicalization change in set or operands always leads to an
+    // updated set.
+    if (map == prevMap && !areChanged(operands, origOperands))
+      return failure();
+
+    op.setConditional(map, operands);
+
+    return success();
+  }
+};
+
+struct MoveIfToAffine : public OpRewritePattern<scf::IfOp> {
+  using OpRewritePattern<scf::IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::IfOp ifOp,
+                                PatternRewriter &rewriter) const override {
+    if (!ifOp->getParentOfType<affine::AffineForOp>() &&
+        !ifOp->getParentOfType<affine::AffineParallelOp>())
+      return failure();
+
+    std::vector<mlir::Type> types;
+    for (auto v : ifOp.getResults()) {
+      types.push_back(v.getType());
+    }
+
+    SmallVector<AffineExpr> exprs;
+    SmallVector<bool> eqflags;
+    SmallVector<Value> applies;
+
+    std::deque<Value> todo = {ifOp.getCondition()};
+    while (todo.size()) {
+      auto cur = todo.front();
+      todo.pop_front();
+      if (auto cmpi = cur.getDefiningOp<CmpIOp>()) {
+        if (!handle(rewriter, cmpi, exprs, eqflags, applies)) {
+          return failure();
+        }
+        continue;
+      }
+      if (auto andi = cur.getDefiningOp<AndIOp>()) {
+        todo.push_back(andi.getOperand(0));
+        todo.push_back(andi.getOperand(1));
+        continue;
+      }
+      return failure();
+    }
+
+    auto *scope = affine::getAffineScope(ifOp)->getParentOp();
+    DominanceInfo di(scope);
+
+    auto iset =
+        IntegerSet::get(/*dim*/ 0, /*symbol*/ 2 * exprs.size(), exprs, eqflags);
+    fully2ComposeIntegerSetAndOperands(rewriter, &iset, &applies, di);
+    affine::canonicalizeSetAndOperands(&iset, &applies);
+    affine::AffineIfOp affineIfOp = affine::AffineIfOp::create(
+        rewriter, ifOp.getLoc(), types, iset, applies,
+        /*elseBlock=*/true);
+
+    rewriter.setInsertionPoint(ifOp.thenYield());
+    rewriter.replaceOpWithNewOp<affine::AffineYieldOp>(
+        ifOp.thenYield(), ifOp.thenYield().getOperands());
+
+    rewriter.eraseBlock(affineIfOp.getThenBlock());
+    rewriter.eraseBlock(affineIfOp.getElseBlock());
+    if (ifOp.getElseRegion().getBlocks().size()) {
+      rewriter.setInsertionPoint(ifOp.elseYield());
+      rewriter.replaceOpWithNewOp<affine::AffineYieldOp>(
+          ifOp.elseYield(), ifOp.elseYield().getOperands());
+    }
+
+    rewriter.inlineRegionBefore(ifOp.getThenRegion(),
+                                affineIfOp.getThenRegion(),
+                                affineIfOp.getThenRegion().begin());
+    rewriter.inlineRegionBefore(ifOp.getElseRegion(),
+                                affineIfOp.getElseRegion(),
+                                affineIfOp.getElseRegion().begin());
+
+    rewriter.replaceOp(ifOp, affineIfOp.getResults());
+    return success();
+  }
+};
+
+void AffineCFG::runOnOperation() {
+  mlir::RewritePatternSet rpl(getOperation()->getContext());
+  rpl.add<CanonicalizeAffineApply, CanonicalizeIndexCast, IndexCastMovement,
+          AffineFixup<affine::AffineLoadOp>,
+          AffineFixup<affine::AffineStoreOp>, CanonicalizeIfBounds,
+          MoveStoreToAffine, MoveIfToAffine, MoveLoadToAffine,
+          CanonicalizeForBounds>(getOperation()->getContext());
+  GreedyRewriteConfig config;
+  (void)applyPatternsGreedily(getOperation(), std::move(rpl), config);
+}
+
+std::unique_ptr<Pass> mlir::replaceAffineCFGPass() {
+  return std::make_unique<AffineCFG>();
+}
diff --git a/mlir/lib/Conversion/SCFToAffine/CMakeLists.txt b/mlir/lib/Conversion/SCFToAffine/CMakeLists.txt
new file mode 100644
index 0000000000000..8bc6d43ff199c
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToAffine/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRSCFToAffine
+  RaiseToAffine.cpp
+  AffineCFG.cpp
+  Ops.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToAffine
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRAffineDialect
+  MLIRLLVMDialect
+  MLIRMemRefDialect
+  MLIRSCFDialect
+  MLIRSCFTransforms
+  MLIRTransforms
+)
diff --git a/mlir/lib/Conversion/SCFToAffine/Ops.cpp b/mlir/lib/Conversion/SCFToAffine/Ops.cpp
new file mode 100644
index 0000000000000..fdd13bbb5f384
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToAffine/Ops.cpp
@@ -0,0 +1,359 @@
+
+#include "./Ops.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
"mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Value.h" + +using namespace mlir; +using namespace mlir::arith; + +bool valueCmp(Cmp cmp, Value bval, ValueOrInt val) { + if (auto icast = bval.getDefiningOp()) { + return valueCmp(cmp, icast.getIn(), val); + } + + IntegerAttr iattr; + if (matchPattern(bval, m_Constant(&iattr))) { + switch (cmp) { + case Cmp::EQ: + return val == iattr.getValue(); + case Cmp::LT: + return val > iattr.getValue(); + case Cmp::LE: + return val >= iattr.getValue(); + case Cmp::GT: + return val < iattr.getValue(); + case Cmp::GE: + return val <= iattr.getValue(); + } + } + + if (auto baval = dyn_cast(bval)) { + if (affine::AffineForOp afFor = + dyn_cast(baval.getOwner()->getParentOp())) { + auto forLb = afFor.getLowerBoundMap().getResults()[baval.getArgNumber()]; + auto forUb = afFor.getUpperBoundMap().getResults()[baval.getArgNumber()]; + switch (cmp) { + // \forall i \in [LB, UB) == k => LB == k and UB == k+1 + case Cmp::EQ: { + if (!valueCmp(Cmp::EQ, forLb, afFor.getLowerBoundMap().getNumDims(), + afFor.getLowerBoundOperands(), val)) + return false; + if (!val.isValue) { + return valueCmp(Cmp::EQ, forUb, afFor.getUpperBoundMap().getNumDims(), + afFor.getUpperBoundOperands(), val.iVal + 1); + } + return false; + } + // \forall i \in [LB, UB) < k => UB <= k + case Cmp::LT: { + return valueCmp(Cmp::LE, forUb, afFor.getUpperBoundMap().getNumDims(), + afFor.getUpperBoundOperands(), val); + } + // \forall i \in [LB, UB) <= k => UB-1 <= k => UB <= k+1 + case Cmp::LE: { + if (!val.isValue) { + return valueCmp(Cmp::LE, forUb, afFor.getUpperBoundMap().getNumDims(), + afFor.getUpperBoundOperands(), val.iVal + 1); + } + return valueCmp(Cmp::LE, forUb, afFor.getUpperBoundMap().getNumDims(), + afFor.getUpperBoundOperands(), val); + } + // \forall i \in [LB, UB) > k => LB > k + case Cmp::GT: { + return valueCmp(Cmp::GT, forLb, afFor.getLowerBoundMap().getNumDims(), + afFor.getLowerBoundOperands(), val); + } + // \forall i \in [LB, UB) >= k => LB >= k + case Cmp::GE: { + return valueCmp(Cmp::GE, forLb, afFor.getLowerBoundMap().getNumDims(), + afFor.getLowerBoundOperands(), val); + } + } + } + if (affine::AffineParallelOp afFor = dyn_cast( + baval.getOwner()->getParentOp())) { + switch (cmp) { + // \forall i \in [max(LB...), min(UB...)) == k => all(LB == k) and + // all(UB == k+1) + case Cmp::EQ: { + for (auto forLb : + afFor.getLowerBoundMap(baval.getArgNumber()).getResults()) + if (!valueCmp(Cmp::EQ, forLb, afFor.getLowerBoundsMap().getNumDims(), + afFor.getLowerBoundsOperands(), val)) + return false; + if (!val.isValue) { + for (auto forUb : + afFor.getUpperBoundMap(baval.getArgNumber()).getResults()) + if (!valueCmp(Cmp::EQ, forUb, + afFor.getUpperBoundsMap().getNumDims(), + afFor.getUpperBoundsOperands(), val.iVal + 1)) + return false; + return true; + } + return false; + } + // \forall i \in [max(LB...), min(UB...)) < k => any(UB <= k) + case Cmp::LT: { + for (auto forUb : + afFor.getUpperBoundMap(baval.getArgNumber()).getResults()) + if (valueCmp(Cmp::LE, forUb, afFor.getUpperBoundsMap().getNumDims(), + afFor.getUpperBoundsOperands(), val)) + return true; + return false; + } + // \forall i \in [max(LB...), min(UB...)) <= k => any(UB-1 <= k) => + // any(UB <= k+1) + case Cmp::LE: { + if (!val.isValue) { + for (auto forUb : + afFor.getUpperBoundMap(baval.getArgNumber()).getResults()) + if (valueCmp(Cmp::LE, forUb, afFor.getUpperBoundsMap().getNumDims(), + afFor.getUpperBoundsOperands(), val.iVal + 1)) + return true; + return false; + 
+        }
+
+        for (auto forUb :
+             afFor.getUpperBoundMap(baval.getArgNumber()).getResults())
+          if (valueCmp(Cmp::LE, forUb, afFor.getUpperBoundsMap().getNumDims(),
+                       afFor.getUpperBoundsOperands(), val))
+            return true;
+        return false;
+      }
+      // \forall i \in [max(LB...), min(UB...)) > k => any(LB > k)
+      case Cmp::GT: {
+        for (auto forLb :
+             afFor.getLowerBoundMap(baval.getArgNumber()).getResults())
+          if (valueCmp(Cmp::GT, forLb, afFor.getLowerBoundsMap().getNumDims(),
+                       afFor.getLowerBoundsOperands(), val))
+            return true;
+        return false;
+      }
+      // \forall i \in [max(LB...), min(UB...)) >= k => any(LB >= k)
+      case Cmp::GE: {
+        for (auto forLb :
+             afFor.getLowerBoundMap(baval.getArgNumber()).getResults())
+          if (valueCmp(Cmp::GE, forLb, afFor.getLowerBoundsMap().getNumDims(),
+                       afFor.getLowerBoundsOperands(), val))
+            return true;
+        return false;
+      }
+      }
+    }
+
+    if (scf::ForOp afFor =
+            dyn_cast<scf::ForOp>(baval.getOwner()->getParentOp())) {
+      if (baval.getArgNumber() == 0) {
+        auto forLb = afFor.getLowerBound();
+        auto forUb = afFor.getUpperBound();
+        switch (cmp) {
+        // \forall i \in [LB, UB) == k => LB == k and UB == k+1
+        case Cmp::EQ: {
+          if (!valueCmp(Cmp::EQ, forLb, val))
+            return false;
+          if (!val.isValue) {
+            return valueCmp(Cmp::EQ, forUb, val.iVal + 1);
+          }
+          return false;
+        }
+        // \forall i \in [LB, UB) < k => UB <= k
+        case Cmp::LT: {
+          return valueCmp(Cmp::LE, forUb, val);
+        }
+        // \forall i \in [LB, UB) <= k => UB-1 <= k => UB <= k+1
+        case Cmp::LE: {
+          if (!val.isValue) {
+            return valueCmp(Cmp::LE, forUb, val.iVal + 1);
+          }
+          return valueCmp(Cmp::LE, forUb, val);
+        }
+        // \forall i \in [LB, UB) > k => LB > k
+        case Cmp::GT: {
+          return valueCmp(Cmp::GT, forLb, val);
+        }
+        // \forall i \in [LB, UB) >= k => LB >= k
+        case Cmp::GE: {
+          return valueCmp(Cmp::GE, forLb, val);
+        }
+        }
+      }
+    }
+
+    if (scf::ParallelOp afFor =
+            dyn_cast<scf::ParallelOp>(baval.getOwner()->getParentOp())) {
+      auto forLb = afFor.getLowerBound()[baval.getArgNumber()];
+      auto forUb = afFor.getUpperBound()[baval.getArgNumber()];
+      switch (cmp) {
+      // \forall i \in [LB, UB) == k => LB == k and UB == k+1
+      case Cmp::EQ: {
+        if (!valueCmp(Cmp::EQ, forLb, val))
+          return false;
+        if (!val.isValue) {
+          return valueCmp(Cmp::EQ, forUb, val.iVal + 1);
+        }
+        return false;
+      }
+      // \forall i \in [LB, UB) < k => UB <= k
+      case Cmp::LT: {
+        return valueCmp(Cmp::LE, forUb, val);
+      }
+      // \forall i \in [LB, UB) <= k => UB-1 <= k => UB <= k+1
+      case Cmp::LE: {
+        if (!val.isValue) {
+          return valueCmp(Cmp::LE, forUb, val.iVal + 1);
+        }
+        return valueCmp(Cmp::LE, forUb, val);
+      }
+      // \forall i \in [LB, UB) > k => LB > k
+      case Cmp::GT: {
+        return valueCmp(Cmp::GT, forLb, val);
+      }
+      // \forall i \in [LB, UB) >= k => LB >= k
+      case Cmp::GE: {
+        return valueCmp(Cmp::GE, forLb, val);
+      }
+      }
+    }
+  }
+  if (val.isValue && val.vVal == bval) {
+    switch (cmp) {
+    case Cmp::EQ:
+      return true;
+    case Cmp::LT:
+      return false;
+    case Cmp::LE:
+      return true;
+    case Cmp::GT:
+      return false;
+    case Cmp::GE:
+      return true;
+    }
+  }
+  return false;
+}
+
+bool valueCmp(Cmp cmp, AffineExpr expr, size_t numDim, ValueRange operands,
+              ValueOrInt val) {
+
+  if (auto opd = mlir::dyn_cast<AffineConstantExpr>(expr)) {
+    switch (cmp) {
+    case Cmp::EQ:
+      return val == opd.getValue();
+    case Cmp::LT:
+      return val > opd.getValue();
+    case Cmp::LE:
+      return val >= opd.getValue();
+    case Cmp::GT:
+      return val < opd.getValue();
+    case Cmp::GE:
+      return val <= opd.getValue();
+    }
+  }
+  if (auto opd = mlir::dyn_cast<AffineDimExpr>(expr)) {
+    return valueCmp(cmp, operands[opd.getPosition()], val);
+  }
+  if (auto opd = mlir::dyn_cast<AffineSymbolExpr>(expr)) {
+    return valueCmp(cmp, operands[opd.getPosition() + numDim], val);
+  }
+
+  if (auto bop = mlir::dyn_cast<AffineBinaryOpExpr>(expr)) {
+    if (bop.getKind() == AffineExprKind::Add) {
+      switch (cmp) {
+      case Cmp::EQ:
+        return (valueCmp(cmp, bop.getLHS(), numDim, operands, val) &&
+                valueCmp(cmp, bop.getRHS(), numDim, operands, 0)) ||
+               (valueCmp(cmp, bop.getLHS(), numDim, operands, 0) &&
+                valueCmp(cmp, bop.getRHS(), numDim, operands, val));
+      case Cmp::LT:
+        return (valueCmp(cmp, bop.getLHS(), numDim, operands, val) &&
+                valueCmp(Cmp::LE, bop.getRHS(), numDim, operands, 0)) ||
+               (valueCmp(Cmp::LE, bop.getLHS(), numDim, operands, 0) &&
+                valueCmp(cmp, bop.getRHS(), numDim, operands, val)) ||
+               (valueCmp(Cmp::LE, bop.getLHS(), numDim, operands, val) &&
+                valueCmp(cmp, bop.getRHS(), numDim, operands, 0)) ||
+               (valueCmp(cmp, bop.getLHS(), numDim, operands, 0) &&
+                valueCmp(Cmp::LE, bop.getRHS(), numDim, operands, val));
+      case Cmp::LE:
+        return (valueCmp(cmp, bop.getLHS(), numDim, operands, val) &&
+                valueCmp(cmp, bop.getRHS(), numDim, operands, 0)) ||
+               (valueCmp(cmp, bop.getLHS(), numDim, operands, 0) &&
+                valueCmp(cmp, bop.getRHS(), numDim, operands, val));
+      case Cmp::GT:
+        return (valueCmp(cmp, bop.getLHS(), numDim, operands, val) &&
+                valueCmp(Cmp::GE, bop.getRHS(), numDim, operands, 0)) ||
+               (valueCmp(Cmp::GE, bop.getLHS(), numDim, operands, 0) &&
+                valueCmp(cmp, bop.getRHS(), numDim, operands, val)) ||
+               (valueCmp(Cmp::GE, bop.getLHS(), numDim, operands, val) &&
+                valueCmp(cmp, bop.getRHS(), numDim, operands, 0)) ||
+               (valueCmp(cmp, bop.getLHS(), numDim, operands, 0) &&
+                valueCmp(Cmp::GE, bop.getRHS(), numDim, operands, val));
+      case Cmp::GE:
+        return (valueCmp(cmp, bop.getLHS(), numDim, operands, val) &&
+                valueCmp(cmp, bop.getRHS(), numDim, operands, 0)) ||
+               (valueCmp(cmp, bop.getLHS(), numDim, operands, 0) &&
+                valueCmp(cmp, bop.getRHS(), numDim, operands, val));
+      }
+    }
+    if (bop.getKind() == AffineExprKind::Mul && val == 0) {
+      switch (cmp) {
+      case Cmp::EQ:
+        return (valueCmp(cmp, bop.getLHS(), numDim, operands, val) ||
+                valueCmp(cmp, bop.getRHS(), numDim, operands, val));
+      case Cmp::LT:
+        return (valueCmp(Cmp::LT, bop.getLHS(), numDim, operands, val) &&
+                valueCmp(Cmp::GT, bop.getRHS(), numDim, operands, 0)) ||
+               (valueCmp(Cmp::GT, bop.getLHS(), numDim, operands, 0) &&
+                valueCmp(Cmp::LT, bop.getRHS(), numDim, operands, val));
+      case Cmp::LE:
+        return valueCmp(Cmp::EQ, bop.getLHS(), numDim, operands, val) ||
+               valueCmp(Cmp::EQ, bop.getRHS(), numDim, operands, val) ||
+               ((valueCmp(Cmp::GE, bop.getLHS(), numDim, operands, 0) &&
+                 valueCmp(Cmp::LE, bop.getRHS(), numDim, operands, val)) ||
+                (valueCmp(Cmp::LE, bop.getLHS(), numDim, operands, 0) &&
+                 valueCmp(Cmp::GE, bop.getRHS(), numDim, operands, val)));
+      case Cmp::GT:
+        return (valueCmp(Cmp::LT, bop.getLHS(), numDim, operands, val) &&
+                valueCmp(Cmp::LT, bop.getRHS(), numDim, operands, 0)) ||
+               (valueCmp(Cmp::GT, bop.getLHS(), numDim, operands, 0) &&
+                valueCmp(Cmp::GT, bop.getRHS(), numDim, operands, val));
+      case Cmp::GE:
+        return valueCmp(Cmp::EQ, bop.getLHS(), numDim, operands, val) ||
+               valueCmp(Cmp::EQ, bop.getRHS(), numDim, operands, val) ||
+               ((valueCmp(Cmp::GE, bop.getLHS(), numDim, operands, 0) &&
+                 valueCmp(Cmp::GE, bop.getRHS(), numDim, operands, val)) ||
+                (valueCmp(Cmp::LE, bop.getLHS(), numDim, operands, 0) &&
+                 valueCmp(Cmp::LE, bop.getRHS(), numDim, operands, val)));
+      }
+    }
+  }
+  return false;
+}
+
+bool isReadOnly(Operation *op) {
+  bool hasRecursiveEffects =
+      op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
+  if (hasRecursiveEffects) {
+
+bool isReadOnly(Operation *op) {
+  bool hasRecursiveEffects =
+      op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
+  if (hasRecursiveEffects) {
+    for (Region &region : op->getRegions()) {
+      for (auto &block : region) {
+        for (auto &nestedOp : block)
+          if (!isReadOnly(&nestedOp))
+            return false;
+      }
+    }
+    return true;
+  }
+
+  // If the op has memory effects, try to characterize them to see if the op
+  // only reads memory.
+  if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
+    // Check to see if this op either has no effects, or only reads memory.
+    SmallVector<MemoryEffects::EffectInstance> effects;
+    effectInterface.getEffects(effects);
+    return llvm::all_of(effects, [](const MemoryEffects::EffectInstance &it) {
+      return isa<MemoryEffects::Read>(it.getEffect());
+    });
+  }
+  return false;
+}
\ No newline at end of file
diff --git a/mlir/lib/Conversion/SCFToAffine/Ops.h b/mlir/lib/Conversion/SCFToAffine/Ops.h
new file mode 100644
index 0000000000000..d8ddae9c42aca
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToAffine/Ops.h
@@ -0,0 +1,114 @@
+#ifndef POLYGEISTOPS_H
+#define POLYGEISTOPS_H
+
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+bool collectEffects(
+    mlir::Operation *op,
+    llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &effects,
+    bool ignoreBarriers);
+
+bool getEffectsBefore(
+    mlir::Operation *op,
+    llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &effects,
+    bool stopAtBarrier);
+
+bool getEffectsAfter(
+    mlir::Operation *op,
+    llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &effects,
+    bool stopAtBarrier);
+
+bool isReadOnly(mlir::Operation *);
+bool isReadNone(mlir::Operation *);
+
+bool mayReadFrom(mlir::Operation *, mlir::Value);
+bool mayWriteTo(mlir::Operation *, mlir::Value, bool ignoreBarrier = false);
+
+bool mayAlias(mlir::MemoryEffects::EffectInstance a,
+              mlir::MemoryEffects::EffectInstance b);
+
+bool mayAlias(mlir::MemoryEffects::EffectInstance a, mlir::Value b);
+
+/// Holds either an SSA Value or, when the value is a known constant, its
+/// compile-time integer.
+struct ValueOrInt {
+  bool isValue;
+  mlir::Value vVal;
+  int64_t iVal;
+  ValueOrInt(mlir::Value v) { initValue(v); }
+  void initValue(mlir::Value v) {
+    using namespace mlir;
+    if (v) {
+      IntegerAttr iattr;
+      if (matchPattern(v, m_Constant(&iattr))) {
+        iVal = iattr.getValue().getSExtValue();
+        vVal = nullptr;
+        isValue = false;
+        return;
+      }
+    }
+    isValue = true;
+    vVal = v;
+  }
+
+  ValueOrInt(size_t i) : isValue(false), vVal(), iVal(i) {}
+
+  bool operator>=(int64_t v) {
+    if (isValue)
+      return false;
+    return iVal >= v;
+  }
+  bool operator>(int64_t v) {
+    if (isValue)
+      return false;
+    return iVal > v;
+  }
+  bool operator==(int64_t v) {
+    if (isValue)
+      return false;
+    return iVal == v;
+  }
+  bool operator<(int64_t v) {
+    if (isValue)
+      return false;
+    return iVal < v;
+  }
+  bool operator<=(int64_t v) {
+    if (isValue)
+      return false;
+    return iVal <= v;
+  }
+  bool operator>=(const llvm::APInt &v) {
+    if (isValue)
+      return false;
+    return iVal >= v.getSExtValue();
+  }
+  bool operator>(const llvm::APInt &v) {
+    if (isValue)
+      return false;
+    return iVal > v.getSExtValue();
+  }
+  bool operator==(const llvm::APInt &v) {
+    if (isValue)
+      return false;
+    return iVal == v.getSExtValue();
+  }
+  bool operator<(const llvm::APInt &v) {
+    if (isValue)
+      return false;
+    return iVal < v.getSExtValue();
+  }
+  bool operator<=(const llvm::APInt &v) {
+    if (isValue)
+      return false;
+    return iVal <= v.getSExtValue();
+  }
+};
+
+enum class Cmp { EQ, LT, LE, GT, GE };
+
+bool valueCmp(Cmp cmp, mlir::AffineExpr expr, size_t numDim,
+              mlir::ValueRange operands, ValueOrInt val);
+
+bool valueCmp(Cmp cmp, mlir::Value bval, ValueOrInt val);
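+
+// Usage sketch (illustrative; `boundExpr`, `map`, and `mapOperands` are
+// hypothetical names, not part of this header):
+//
+//   ValueOrInt zero(size_t(0));
+//   if (valueCmp(Cmp::GT, boundExpr, map.getNumDims(), mapOperands, zero)) {
+//     // boundExpr is structurally provable to be positive.
+//   }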
+#endif
diff --git a/mlir/lib/Conversion/SCFToAffine/RaiseToAffine.cpp b/mlir/lib/Conversion/SCFToAffine/RaiseToAffine.cpp
new file mode 100644
index 0000000000000..fc8fe44b7216d
--- /dev/null
+++ b/mlir/lib/Conversion/SCFToAffine/RaiseToAffine.cpp
@@ -0,0 +1,296 @@
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "raise-to-affine"
+
+using namespace mlir;
+using namespace mlir::arith;
+using namespace affine;
+
+namespace mlir {
+
+#define GEN_PASS_DEF_RAISESCFTOAFFINE
+#include "mlir/Conversion/Passes.h.inc"
+
+} // namespace mlir
+
+bool isValidIndex(Value val);
+void fully2ComposeAffineMapAndOperands(PatternRewriter &builder, AffineMap *map,
+                                       SmallVectorImpl<Value> *operands,
+                                       DominanceInfo &di);
+
+namespace {
+struct RaiseSCFToAffine : public impl::RaiseSCFToAffineBase<RaiseSCFToAffine> {
+  void runOnOperation() override;
+};
+} // namespace
+
+struct ForOpRaising : public OpRewritePattern<scf::ForOp> {
+  using OpRewritePattern<scf::ForOp>::OpRewritePattern;
+
+  // TODO: rename me. The step only needs to be a valid affine symbol here;
+  // requiring a ConstantIndexOp would be more restrictive than necessary.
+  bool isAffine(scf::ForOp loop) const {
+    return affine::isValidSymbol(loop.getStep());
+  }
+
+  // Returns the constant step if known, and 1 as a conservative default.
+  int64_t getStep(mlir::Value value) const {
+    if (auto cstOp = value.getDefiningOp<ConstantIndexOp>())
+      return cstOp.value();
+    return 1;
+  }
+
+  AffineMap getMultiSymbolIdentity(Builder &b, unsigned rank) const {
+    SmallVector<AffineExpr> dimExprs;
+    dimExprs.reserve(rank);
+    for (unsigned i = 0; i < rank; ++i)
+      dimExprs.push_back(b.getAffineSymbolExpr(i));
+    return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/rank, dimExprs,
+                          b.getContext());
+  }
+  LogicalResult matchAndRewrite(scf::ForOp loop,
+                                PatternRewriter &rewriter) const final {
+    if (isAffine(loop)) {
+      OpBuilder builder(loop);
+
+      SmallVector<Value> lbs;
+      {
+        SmallVector<Value> todo = {loop.getLowerBound()};
+        while (!todo.empty()) {
+          auto cur = todo.back();
+          todo.pop_back();
+          if (isValidIndex(cur)) {
+            lbs.push_back(cur);
+            continue;
+          }
+          if (auto selOp = cur.getDefiningOp<SelectOp>()) {
+            // A lower bound may only be a max of operands.
+            if (auto cmp = selOp.getCondition().getDefiningOp<CmpIOp>()) {
+              if (cmp.getLhs() == selOp.getTrueValue() &&
+                  cmp.getRhs() == selOp.getFalseValue() &&
+                  cmp.getPredicate() == CmpIPredicate::sge) {
+                todo.push_back(cmp.getLhs());
+                todo.push_back(cmp.getRhs());
+                continue;
+              }
+            }
+          }
+          return failure();
+        }
+      }
+
+      SmallVector<Value> ubs;
+      {
+        SmallVector<Value> todo = {loop.getUpperBound()};
+        while (!todo.empty()) {
+          auto cur = todo.back();
+          todo.pop_back();
+          if (isValidIndex(cur)) {
+            ubs.push_back(cur);
+            continue;
+          }
+          if (auto selOp = cur.getDefiningOp<SelectOp>()) {
+            // An upper bound may only be a min of operands.
+            if (auto cmp = selOp.getCondition().getDefiningOp<CmpIOp>()) {
+              if (cmp.getLhs() == selOp.getTrueValue() &&
+                  cmp.getRhs() == selOp.getFalseValue() &&
+                  cmp.getPredicate() == CmpIPredicate::sle) {
+                todo.push_back(cmp.getLhs());
+                todo.push_back(cmp.getRhs());
+                continue;
+              }
+            }
+          }
+          return failure();
+        }
+      }
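+
+      // At this point lbs/ubs hold the leaves of any select-based max/min
+      // chains. For example (illustrative IR, not taken from a test):
+      //   %c = arith.cmpi sle, %a, %b : index
+      //   %ub = arith.select %c, %a, %b : index
+      // unpacks to ubs = {%a, %b}; the multi-result upper-bound map built
+      // below then encodes min(%a, %b).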
+
+      bool rewrittenStep = false;
+      if (!loop.getStep().getDefiningOp<ConstantIndexOp>()) {
+        if (ubs.size() != 1 || lbs.size() != 1)
+          return failure();
+        // Normalize to a unit-step loop over [0, ceildiv(ub - lb, step)).
+        ubs[0] = DivUIOp::create(
+            rewriter, loop.getLoc(),
+            AddIOp::create(
+                rewriter, loop.getLoc(),
+                SubIOp::create(
+                    rewriter, loop.getLoc(), loop.getStep(),
+                    ConstantIndexOp::create(rewriter, loop.getLoc(), 1)),
+                SubIOp::create(rewriter, loop.getLoc(), loop.getUpperBound(),
+                               loop.getLowerBound())),
+            loop.getStep());
+        lbs[0] = ConstantIndexOp::create(rewriter, loop.getLoc(), 0);
+        rewrittenStep = true;
+      }
+
+      auto *scope = affine::getAffineScope(loop)->getParentOp();
+      DominanceInfo di(scope);
+
+      AffineMap lbMap = getMultiSymbolIdentity(builder, lbs.size());
+      {
+        fully2ComposeAffineMapAndOperands(rewriter, &lbMap, &lbs, di);
+        affine::canonicalizeMapAndOperands(&lbMap, &lbs);
+        lbMap = removeDuplicateExprs(lbMap);
+      }
+      AffineMap ubMap = getMultiSymbolIdentity(builder, ubs.size());
+      {
+        fully2ComposeAffineMapAndOperands(rewriter, &ubMap, &ubs, di);
+        affine::canonicalizeMapAndOperands(&ubMap, &ubs);
+        ubMap = removeDuplicateExprs(ubMap);
+      }
+
+      affine::AffineForOp affineLoop = affine::AffineForOp::create(
+          rewriter, loop.getLoc(), lbs, lbMap, ubs, ubMap,
+          getStep(loop.getStep()), loop.getInits());
+
+      auto mergedYieldOp =
+          cast<scf::YieldOp>(loop.getRegion().front().getTerminator());
+
+      Block &newBlock = affineLoop.getRegion().front();
+
+      // AffineForOp::build adds a terminator when no iter args are provided;
+      // drop it so the merged block's yield can take its place.
+      if (affineLoop.getNumIterOperands() == 0) {
+        auto *affineYieldOp = newBlock.getTerminator();
+        rewriter.eraseOp(affineYieldOp);
+      }
+
+      SmallVector<Value> vals;
+      rewriter.setInsertionPointToStart(&affineLoop.getRegion().front());
+      for (Value arg : affineLoop.getRegion().front().getArguments()) {
+        if (rewrittenStep && arg == affineLoop.getInductionVar()) {
+          // Map the normalized induction variable back to the original
+          // iteration space: iv_old = lb + iv_new * step.
+          arg = AddIOp::create(
+              rewriter, loop.getLoc(), loop.getLowerBound(),
+              MulIOp::create(rewriter, loop.getLoc(), arg, loop.getStep()));
+        }
+        vals.push_back(arg);
+      }
+      assert(vals.size() == loop.getRegion().front().getNumArguments());
+      rewriter.mergeBlocks(&loop.getRegion().front(),
+                           &affineLoop.getRegion().front(), vals);
+
+      rewriter.setInsertionPoint(mergedYieldOp);
+      affine::AffineYieldOp::create(rewriter, mergedYieldOp.getLoc(),
+                                    mergedYieldOp.getOperands());
+      rewriter.eraseOp(mergedYieldOp);
+
+      rewriter.replaceOp(loop, affineLoop.getResults());
+
+      return success();
+    }
+    return failure();
+  }
+};
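+
+// For reference, a hand-written sketch (not a generated test) of the
+// symbolic-step rewrite performed above:
+//   scf.for %i = %lb to %ub step %s { use(%i) }
+// becomes
+//   affine.for %j = 0 to %n { %i = %lb + %j * %s; use(%i) }
+// where %n = (%s - 1 + (%ub - %lb)) / %s, i.e. ceildiv(%ub - %lb, %s)
+// for a positive step, is computed in the enclosing block.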
+
+struct ParallelOpRaising : public OpRewritePattern<scf::ParallelOp> {
+  using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
+
+  void canonicalizeLoopBounds(PatternRewriter &rewriter,
+                              affine::AffineParallelOp forOp) const {
+    SmallVector<Value> lbOperands(forOp.getLowerBoundsOperands());
+    SmallVector<Value> ubOperands(forOp.getUpperBoundsOperands());
+
+    auto lbMap = forOp.getLowerBoundsMap();
+    auto ubMap = forOp.getUpperBoundsMap();
+
+    auto *scope = affine::getAffineScope(forOp)->getParentOp();
+    DominanceInfo di(scope);
+
+    fully2ComposeAffineMapAndOperands(rewriter, &lbMap, &lbOperands, di);
+    affine::canonicalizeMapAndOperands(&lbMap, &lbOperands);
+
+    fully2ComposeAffineMapAndOperands(rewriter, &ubMap, &ubOperands, di);
+    affine::canonicalizeMapAndOperands(&ubMap, &ubOperands);
+
+    forOp.setLowerBounds(lbOperands, lbMap);
+    forOp.setUpperBounds(ubOperands, ubMap);
+  }
+
+  LogicalResult matchAndRewrite(scf::ParallelOp loop,
+                                PatternRewriter &rewriter) const final {
+    OpBuilder builder(loop);
+
+    // Loops with reduction results are not supported yet.
+    if (!loop.getResults().empty())
+      return failure();
+
+    if (!llvm::all_of(loop.getLowerBound(), isValidIndex))
+      return failure();
+
+    if (!llvm::all_of(loop.getUpperBound(), isValidIndex))
+      return failure();
+
+    SmallVector<int64_t> steps;
+    for (auto step : loop.getStep())
+      if (auto cst = step.getDefiningOp<ConstantIndexOp>())
+        steps.push_back(cst.value());
+      else
+        return failure();
+
+    ArrayRef<arith::AtomicRMWKind> reductions;
+    SmallVector<AffineMap> bounds;
+    for (size_t i = 0; i < loop.getLowerBound().size(); i++)
+      bounds.push_back(AffineMap::get(
+          /*dimCount=*/0, /*symbolCount=*/loop.getLowerBound().size(),
+          builder.getAffineSymbolExpr(i)));
+    affine::AffineParallelOp affineLoop = affine::AffineParallelOp::create(
+        rewriter, loop.getLoc(), loop.getResultTypes(), reductions, bounds,
+        loop.getLowerBound(), bounds, loop.getUpperBound(),
+        steps); //, loop.getInitVals());
+
+    canonicalizeLoopBounds(rewriter, affineLoop);
+
+    auto mergedReduceOp =
+        cast<scf::ReduceOp>(loop.getRegion().front().getTerminator());
+
+    Block &newBlock = affineLoop.getRegion().front();
+
+    // AffineParallelOp::build adds a terminator when there are no results;
+    // drop it so the merged block's terminator can take its place.
+    if (affineLoop.getResults().empty()) {
+      auto *affineYieldOp = newBlock.getTerminator();
+      rewriter.eraseOp(affineYieldOp);
+    }
+
+    SmallVector<Value> vals;
+    for (Value arg : affineLoop.getRegion().front().getArguments())
+      vals.push_back(arg);
+    rewriter.mergeBlocks(&loop.getRegion().front(),
+                         &affineLoop.getRegion().front(), vals);
+
+    rewriter.setInsertionPoint(mergedReduceOp);
+    affine::AffineYieldOp::create(rewriter, mergedReduceOp.getLoc(),
+                                  mergedReduceOp.getOperands());
+    rewriter.eraseOp(mergedReduceOp);
+
+    rewriter.replaceOp(loop, affineLoop.getResults());
+
+    return success();
+  }
+};
+
+void RaiseSCFToAffine::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  patterns.insert<ForOpRaising, ParallelOpRaising>(&getContext());
+
+  GreedyRewriteConfig config;
+  (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
+}
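+
+// Typical use (illustrative): run this pass on its own, or after --affine-cfg
+// as the tests below do, e.g.
+//   mlir-opt input.mlir --affine-cfg --raise-scf-to-affine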
+
+std::unique_ptr<Pass> mlir::createRaiseSCFToAffinePass() {
+  return std::make_unique<RaiseSCFToAffine>();
+}
\ No newline at end of file
diff --git a/mlir/test/Conversion/SCFToAffine/affinecfg.mlir b/mlir/test/Conversion/SCFToAffine/affinecfg.mlir
new file mode 100644
index 0000000000000..2b3dc6f3a34a0
--- /dev/null
+++ b/mlir/test/Conversion/SCFToAffine/affinecfg.mlir
@@ -0,0 +1,166 @@
+// RUN: mlir-opt --affine-cfg --split-input-file %s | FileCheck %s
+
+module {
+  func.func @_Z7runTestiPPc(%arg0: index, %arg2: memref<?xi32>) {
+    %c0_i32 = arith.constant 0 : i32
+    %c1 = arith.constant 1 : index
+    %1 = arith.addi %arg0, %c1 : index
+    affine.for %arg3 = 0 to 2 {
+      %2 = arith.muli %arg3, %1 : index
+      affine.for %arg4 = 0 to 2 {
+        %3 = arith.addi %2, %arg4 : index
+        memref.store %c0_i32, %arg2[%3] : memref<?xi32>
+      }
+    }
+    return
+  }
+}
+
+// CHECK: func.func @_Z7runTestiPPc(%[[arg0:.+]]: index, %[[arg1:.+]]: memref<?xi32>) {
+// CHECK-NEXT: %[[c0_i32:.+]] = arith.constant 0 : i32
+// CHECK-NEXT: affine.for %[[arg2:.+]] = 0 to 2 {
+// CHECK-NEXT: affine.for %[[arg3:.+]] = 0 to 2 {
+// CHECK-NEXT: affine.store %[[c0_i32]], %[[arg1]][%[[arg3]] + %[[arg2]] * (symbol(%[[arg0]]) + 1)] : memref<?xi32>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// -----
+module {
+func.func @kernel_nussinov(%arg0: i32, %arg2: memref<i32>) {
+  %c0 = arith.constant 0 : index
+  %true = arith.constant true
+  %c1_i32 = arith.constant 1 : i32
+  %c59 = arith.constant 59 : index
+  %c100_i32 = arith.constant 100 : i32
+  affine.for %arg3 = 0 to 60 {
+    %0 = arith.subi %c59, %arg3 : index
+    %1 = arith.index_cast %0 : index to i32
+    %2 = arith.cmpi slt, %1, %c100_i32 : i32
+    scf.if %2 {
+      affine.store %arg0, %arg2[] : memref<i32>
+    }
+  }
+  return
+}
+}
+
+// CHECK: #set = affine_set<(d0) : (d0 + 40 >= 0)>
+// CHECK: func.func @kernel_nussinov(%[[arg0:.+]]: i32, %[[arg1:.+]]: memref<i32>) {
+// CHECK-NEXT: affine.for %[[arg2:.+]] = 0 to 60 {
+// CHECK-NEXT: affine.if #set(%[[arg2]]) {
+// CHECK-NEXT: affine.store %[[arg0]], %[[arg1]][] : memref<i32>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// -----
+
+module {
+  func.func private @run()
+
+  func.func @minif(%arg4: i32, %arg5 : i32, %arg10 : index) {
+    %c0_i32 = arith.constant 0 : i32
+
+    affine.for %i = 0 to 10 {
+      %70 = arith.index_cast %arg10 : index to i32
+      %71 = arith.muli %70, %arg5 : i32
+      %73 = arith.divui %71, %arg5 : i32
+      %75 = arith.muli %73, %arg5 : i32
+      %79 = arith.subi %arg4, %75 : i32
+      %81 = arith.cmpi sle, %arg5, %79 : i32
+      %83 = arith.select %81, %arg5, %79 : i32
+      %92 = arith.cmpi slt, %c0_i32, %83 : i32
+      scf.if %92 {
+        func.call @run() : () -> ()
+        scf.yield
+      }
+    }
+    return
+  }
+}
+
+// CHECK: #set = affine_set<()[s0] : (s0 - 1 >= 0)>
+// CHECK: func.func @minif(%[[arg0:.+]]: i32, %[[arg1:.+]]: i32, %[[arg2:.+]]: index) {
+// CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg2]] : index to i32
+// CHECK-NEXT: %[[V1:.+]] = arith.muli %[[V0]], %[[arg1]] : i32
+// CHECK-NEXT: %[[V2:.+]] = arith.divui %[[V1]], %[[arg1]] : i32
+// CHECK-NEXT: %[[V3:.+]] = arith.muli %[[V2]], %[[arg1]] : i32
+// CHECK-NEXT: %[[V4:.+]] = arith.subi %[[arg0]], %[[V3]] : i32
+// CHECK-NEXT: %[[V5:.+]] = arith.cmpi sle, %[[arg1]], %[[V4]] : i32
+// CHECK-NEXT: %[[V6:.+]] = arith.select %[[V5]], %[[arg1]], %[[V4]] : i32
+// CHECK-NEXT: %[[V7:.+]] = arith.index_cast %[[V6]] : i32 to index
+// CHECK-NEXT: affine.for %[[arg3:.+]] = 0 to 10 {
+// CHECK-NEXT: affine.if #set()[%[[V7]]] {
+// CHECK-NEXT: func.call @run() : () -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
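+
+// Note (explanatory): the `arith.select` above computes min(%arg5, %79); the
+// guard `0 < min(...)` is raised to `affine.if #set()[s0]`, i.e. s0 - 1 >= 0,
+// where s0 is the index_cast of the select result.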
+
+// -----
+
+module {
+  llvm.func @atoi(!llvm.ptr) -> i32
+  func.func @_Z7runTestiPPc(%arg0: i32, %39: memref<?xi32>, %arg1: !llvm.ptr) attributes {llvm.linkage = #llvm.linkage<external>} {
+    %c2_i32 = arith.constant 2 : i32
+    %c16_i32 = arith.constant 16 : i32
+    %58 = llvm.call @atoi(%arg1) : (!llvm.ptr) -> i32
+    %40 = arith.divsi %58, %c16_i32 : i32
+    affine.for %arg2 = 1 to 10 {
+      %62 = arith.index_cast %arg2 : index to i32
+      %67 = arith.muli %58, %62 : i32
+      %69 = arith.addi %67, %40 : i32
+      %75 = arith.addi %69, %58 : i32
+      %76 = arith.index_cast %75 : i32 to index
+      memref.store %c2_i32, %39[%76] : memref<?xi32>
+    }
+    return
+  }
+}
+
+// CHECK: func.func @_Z7runTestiPPc(%[[arg0:.+]]: i32, %[[arg1:.+]]: memref<?xi32>, %[[arg2:.+]]: !llvm.ptr) attributes {llvm.linkage = #llvm.linkage<external>} {
+// CHECK-NEXT: %[[c2_i32:.+]] = arith.constant 2 : i32
+// CHECK-NEXT: %[[c16_i32:.+]] = arith.constant 16 : i32
+// CHECK-NEXT: %[[V0:.+]] = llvm.call @atoi(%[[arg2]]) : (!llvm.ptr) -> i32
+// CHECK-NEXT: %[[V1:.+]] = arith.index_cast %[[V0]] : i32 to index
+// CHECK-NEXT: %[[V2:.+]] = arith.divsi %[[V0]], %[[c16_i32]] : i32
+// CHECK-NEXT: %[[V3:.+]] = arith.index_cast %[[V2]] : i32 to index
+// CHECK-NEXT: affine.for %[[arg3:.+]] = 1 to 10 {
+// CHECK-NEXT: affine.store %[[c2_i32]], %[[arg1]][%[[arg3]] * symbol(%[[V1]]) + symbol(%[[V1]]) + symbol(%[[V3]])] : memref<?xi32>
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// -----
+
+module {
+  func.func @c(%71: memref<?xi32>, %39: i64) {
+    affine.parallel (%arg2, %arg3) = (0, 0) to (42, 512) {
+      %262 = arith.index_cast %arg2 : index to i32
+      %a264 = arith.extsi %262 : i32 to i64
+      %268 = arith.cmpi slt, %a264, %39 : i64
+      scf.if %268 {
+        "test.something"() : () -> ()
+      }
+    }
+    return
+  }
+}
+
+// CHECK: #set = affine_set<(d0)[s0] : (-d0 + s0 - 1 >= 0)>
+// CHECK: func.func @c(%[[arg0:.+]]: memref<?xi32>, %[[arg1:.+]]: i64) {
+// CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg1]] : i64 to index
+// CHECK-NEXT: affine.parallel (%[[arg2:.+]], %[[arg3:.+]]) = (0, 0) to (42, 512) {
+// CHECK-NEXT: affine.if #set(%[[arg2]])[%[[V0]]] {
+// CHECK-NEXT: "test.something"() : () -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
diff --git a/mlir/test/Conversion/SCFToAffine/affraise.mlir b/mlir/test/Conversion/SCFToAffine/affraise.mlir
new file mode 100644
index 0000000000000..8c16682a2933b
--- /dev/null
+++ b/mlir/test/Conversion/SCFToAffine/affraise.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt --affine-cfg --raise-scf-to-affine %s | FileCheck %s
+
+module {
+  func.func @withinif(%arg0: memref<?xf32>, %arg1: i32, %arg2: memref<?xf32>, %arg3: i1) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    scf.if %arg3 {
+      %3 = arith.index_cast %arg1 : i32 to index
+      scf.for %arg6 = %c1 to %3 step %c1 {
+        %4 = memref.load %arg0[%arg6] : memref<?xf32>
+        memref.store %4, %arg2[%arg6] : memref<?xf32>
+      }
+    }
+    return
+  }
+  func.func @aff(%c : i1, %arg0: i32) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    scf.if %c {
+      %75 = arith.index_cast %arg0 : i32 to index
+      scf.parallel (%arg5) = (%c0) to (%75) step (%c1) -> () {
+        "test.op"() : () -> ()
+      }
+    }
+    return
+  }
+}
+
+// CHECK: func.func @withinif(%[[arg0:.+]]: memref<?xf32>, %[[arg1:.+]]: i32, %[[arg2:.+]]: memref<?xf32>, %[[arg3:.+]]: i1) {
+// CHECK-DAG: %[[V0:.+]] = arith.index_cast %[[arg1]] : i32 to index
+// CHECK-NEXT: scf.if %[[arg3]] {
+// CHECK-NEXT: affine.for %[[arg4:.+]] = 1 to %[[V0]] {
+// CHECK-NEXT: %[[V1:.+]] = memref.load %[[arg0]][%[[arg4]]] : memref<?xf32>
+// CHECK-NEXT: memref.store %[[V1]], %[[arg2]][%[[arg4]]] : memref<?xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+
+// CHECK-NEXT: func.func @aff(%[[arg0:.+]]: i1, %[[arg1:.+]]: i32) {
+// CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg1]] : i32 to index
+// CHECK-NEXT: scf.if %[[arg0]] {
+// CHECK-NEXT: affine.parallel (%[[arg2:.+]]) = (0) to (symbol(%[[V0]])) {
+// CHECK-NEXT: "test.op"() : () -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
diff --git a/mlir/test/Conversion/SCFToAffine/affraise2.mlir b/mlir/test/Conversion/SCFToAffine/affraise2.mlir
new file mode 100644
index 0000000000000..2da1e3713dd96
--- /dev/null
+++ b/mlir/test/Conversion/SCFToAffine/affraise2.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt --affine-cfg --raise-scf-to-affine %s | FileCheck %s
+
+module {
+  func.func @main(%12 : i1, %14 : i32, %18 : memref<?xf32>, %19 : memref<?xf32>) {
+    %c0 = arith.constant 0 : index
+    %c4 = arith.constant 4 : index
+    %c1 = arith.constant 1 : index
+    scf.if %12 {
+      %15 = arith.index_cast %14 : i32 to index
+      %16 = arith.muli %15, %c4 : index
+      %17 = arith.divui %16, %c4 : index
+      scf.for %arg2 = %c0 to %17 step %c1 {
+        %20 = memref.load %19[%arg2] : memref<?xf32>
+        memref.store %20, %18[%arg2] : memref<?xf32>
+      }
+    }
+    return
+  }
+}
+
+// CHECK: func.func @main(%[[arg0:.+]]: i1, %[[arg1:.+]]: i32, %[[arg2:.+]]: memref<?xf32>, %[[arg3:.+]]: memref<?xf32>) {
+// CHECK-NEXT: %[[c4:.+]] = arith.constant 4 : index
+// CHECK-NEXT: %[[V0:.+]] = arith.index_cast %[[arg1]] : i32 to index
+// CHECK-NEXT: %[[V1:.+]] = arith.muli %[[V0]], %[[c4]] : index
+// CHECK-NEXT: scf.if %[[arg0]] {
+// CHECK-NEXT: %[[V2:.+]] = arith.divui %[[V1]], %[[c4]] : index
+// CHECK-NEXT: affine.for %[[arg4:.+]] = 0 to %[[V2]] {
+// CHECK-NEXT: %[[a:.+]] = memref.load %[[arg3]][%[[arg4]]] : memref<?xf32>
+// CHECK-NEXT: memref.store %[[a]], %[[arg2]][%[[arg4]]] : memref<?xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: }
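+
+// Note (explanatory): this test checks that a bound built from mul/div
+// arithmetic on an index_cast is still raised; the divui result becomes the
+// symbolic upper bound of the affine.for inside the scf.if.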
diff --git a/mlir/test/Conversion/SCFToAffine/affraise3.mlir b/mlir/test/Conversion/SCFToAffine/affraise3.mlir
new file mode 100644
index 0000000000000..f253c384622c3
--- /dev/null
+++ b/mlir/test/Conversion/SCFToAffine/affraise3.mlir
@@ -0,0 +1,95 @@
+// RUN: mlir-opt --affine-cfg --raise-scf-to-affine %s | FileCheck %s
+
+module {
+  func.func @slt(%arg0: index) {
+    affine.for %arg1 = 0 to 10 {
+      %c = arith.cmpi slt, %arg1, %arg0 : index
+      scf.if %c {
+        "test.run"(%arg1) : (index) -> ()
+      }
+    }
+    return
+  }
+  func.func @sle(%arg0: index) {
+    affine.for %arg1 = 0 to 10 {
+      %c = arith.cmpi sle, %arg1, %arg0 : index
+      scf.if %c {
+        "test.run"(%arg1) : (index) -> ()
+      }
+    }
+    return
+  }
+  func.func @sgt(%arg0: index) {
+    affine.for %arg1 = 0 to 10 {
+      %c = arith.cmpi sgt, %arg1, %arg0 : index
+      scf.if %c {
+        "test.run"(%arg1) : (index) -> ()
+      }
+    }
+    return
+  }
+  func.func @sge(%arg0: index) {
+    affine.for %arg1 = 0 to 10 {
+      %c = arith.cmpi sge, %arg1, %arg0 : index
+      scf.if %c {
+        "test.run"(%arg1) : (index) -> ()
+      }
+    }
+    return
+  }
+}
+
+// -d0 + s0 - 1 >= 0 =>
+// -d0 >= 1 - s0 =>
+// d0 <= s0 - 1 =>
+// d0 < s0
+// CHECK: #set = affine_set<(d0)[s0] : (-d0 + s0 - 1 >= 0)>
+
+// -d0 + s0 >= 0 =>
+// -d0 >= -s0 =>
+// d0 <= s0
+// CHECK: #set1 = affine_set<(d0)[s0] : (-d0 + s0 >= 0)>
+
+// d0 - s0 - 1 >= 0 =>
+// d0 >= s0 + 1 =>
+// d0 > s0
+// CHECK: #set2 = affine_set<(d0)[s0] : (d0 - s0 - 1 >= 0)>
+
+// d0 - s0 >= 0 =>
+// d0 >= s0
+// CHECK: #set3 = affine_set<(d0)[s0] : (d0 - s0 >= 0)>
+
+// CHECK: func.func @slt(%[[arg0:.+]]: index) {
+// CHECK-NEXT: affine.for %[[arg1:.+]] = 0 to 10 {
+// CHECK-NEXT: affine.if #set(%[[arg1]])[%[[arg0]]] {
+// CHECK-NEXT: "test.run"(%[[arg1]]) : (index) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK: func.func @sle(%[[arg0:.+]]: index) {
+// CHECK-NEXT: affine.for %[[arg1:.+]] = 0 to 10 {
+// CHECK-NEXT: affine.if #set1(%[[arg1]])[%[[arg0]]] {
+// CHECK-NEXT: "test.run"(%[[arg1]]) : (index) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK: func.func @sgt(%[[arg0:.+]]: index) {
+// CHECK-NEXT: affine.for %[[arg1:.+]] = 0 to 10 {
+// CHECK-NEXT: affine.if #set2(%[[arg1]])[%[[arg0]]] {
+// CHECK-NEXT: "test.run"(%[[arg1]]) : (index) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK: func.func @sge(%[[arg0:.+]]: index) {
+// CHECK-NEXT: affine.for %[[arg1:.+]] = 0 to 10 {
+// CHECK-NEXT: affine.if #set3(%[[arg1]])[%[[arg0]]] {
+// CHECK-NEXT: "test.run"(%[[arg1]]) : (index) -> ()
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+// CHECK-NEXT: }
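+
+// Note (explanatory): over integers, strict comparisons fold into the
+// non-strict affine_set form via a unit offset, e.g. d0 < s0 iff
+// -d0 + s0 - 1 >= 0, which is why #set and #set2 carry the "- 1" terms.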
diff --git a/mlir/test/Conversion/SCFToAffine/raisescffor.mlir b/mlir/test/Conversion/SCFToAffine/raisescffor.mlir
new file mode 100644
index 0000000000000..5eb78ff3079c8
--- /dev/null
+++ b/mlir/test/Conversion/SCFToAffine/raisescffor.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt --raise-scf-to-affine %s | FileCheck %s
+
+module {
+  func.func private @_Z12kernel5_initPc(%0: index, %arg0: memref<index>) {
+    %c10 = arith.constant 10 : index
+    %c0 = arith.constant 0 : index
+    scf.for %arg1 = %c0 to %c10 step %0 {
+      memref.store %c10, %arg0[] : memref<index>
+    }
+    return
+  }
+}
+
+// CHECK-LABEL: func.func private @_Z12kernel5_initPc(
+// CHECK-SAME:    %[[VAL_0:.*]]: index,
+// CHECK-SAME:    %[[VAL_1:.*]]: memref<index>) {
+// CHECK:         %[[VAL_3:.*]] = arith.constant 10 : index
+// CHECK:         %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:         %[[VAL_4:.*]] = arith.subi %[[VAL_0]], %[[VAL_2]] : index
+// CHECK:         %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_3]] : index
+// CHECK:         %[[VAL_6:.*]] = arith.divui %[[VAL_5]], %[[VAL_0]] : index
+// CHECK:         affine.for %[[VAL_7:.*]] = 0 to %[[VAL_6]] {
+// CHECK:           memref.store %[[VAL_3]], %[[VAL_1]][] : memref<index>
+// CHECK:         }
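+
+// Note (explanatory): the emitted prologue computes the trip count
+// (%0 - 1 + (10 - 0)) / %0, i.e. ceildiv(10, %0) for a positive step, and
+// the raised loop then runs from 0 to that symbol with unit step.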