diff --git a/include/polygeist/BarrierUtils.h b/include/polygeist/BarrierUtils.h index 807ff41406db..f6cf85755063 100644 --- a/include/polygeist/BarrierUtils.h +++ b/include/polygeist/BarrierUtils.h @@ -74,6 +74,23 @@ static T allocateTemporaryBuffer(mlir::OpBuilder &rewriter, mlir::Value value, return rewriter.create(value.getLoc(), type, iterationCounts); } +template <> +mlir::LLVM::AllocaOp allocateTemporaryBuffer( + mlir::OpBuilder &rewriter, mlir::Value value, + mlir::ValueRange iterationCounts, bool alloca, mlir::DataLayout *DLI) { + using namespace mlir; + auto val = value.getDefiningOp(); + auto sz = val.getArraySize(); + assert(DLI); + for (auto iter : iterationCounts) { + sz = + rewriter.create(value.getLoc(), sz, + rewriter.create( + value.getLoc(), sz.getType(), iter)); + } + return rewriter.create(value.getLoc(), val.getType(), sz); +} + template <> mlir::LLVM::CallOp allocateTemporaryBuffer( mlir::OpBuilder &rewriter, mlir::Value value, diff --git a/include/polygeist/Passes/Passes.h b/include/polygeist/Passes/Passes.h index 021b94dad52a..9b3e7c987645 100644 --- a/include/polygeist/Passes/Passes.h +++ b/include/polygeist/Passes/Passes.h @@ -25,7 +25,8 @@ std::unique_ptr createConvertPolygeistToLLVMPass(); } // namespace mlir void fully2ComposeAffineMapAndOperands( - mlir::AffineMap *map, llvm::SmallVectorImpl *operands); + mlir::OpBuilder &, mlir::AffineMap *map, + llvm::SmallVectorImpl *operands); bool isValidIndex(mlir::Value val); namespace mlir { diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td index 63a47b9456a7..fd1da19e119b 100644 --- a/include/polygeist/Passes/Passes.td +++ b/include/polygeist/Passes/Passes.td @@ -8,7 +8,7 @@ def AffineCFG : Pass<"affine-cfg"> { let constructor = "mlir::polygeist::replaceAffineCFGPass()"; } -def Mem2Reg : Pass<"mem2reg", "FuncOp"> { +def Mem2Reg : Pass<"mem2reg"> { let summary = "Replace scf.if and similar with affine.if"; let constructor = "mlir::polygeist::createMem2RegPass()"; } @@ -25,7 +25,7 @@ def AffineReduction : Pass<"detect-reduction"> { let constructor = "mlir::polygeist::detectReductionPass()"; } -def SCFCPUify : Pass<"cpuify", "FuncOp"> { +def SCFCPUify : Pass<"cpuify"> { let summary = "remove scf.barrier"; let constructor = "mlir::polygeist::createCPUifyPass()"; let dependentDialects = @@ -35,7 +35,7 @@ def SCFCPUify : Pass<"cpuify", "FuncOp"> { ]; } -def SCFBarrierRemovalContinuation : Pass<"barrier-removal-continuation", "FuncOp"> { +def SCFBarrierRemovalContinuation : InterfacePass<"barrier-removal-continuation", "FunctionOpInterface"> { let summary = "Remove scf.barrier using continuations"; let constructor = "mlir::polygeist::createBarrierRemovalContinuation()"; let dependentDialects = ["memref::MemRefDialect", "func::FuncDialect"]; diff --git a/include/polygeist/Passes/Utils.h b/include/polygeist/Passes/Utils.h new file mode 100644 index 000000000000..1aeefe494bbc --- /dev/null +++ b/include/polygeist/Passes/Utils.h @@ -0,0 +1,109 @@ +#pragma once + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/IntegerSet.h" + +static inline mlir::scf::IfOp +cloneWithoutResults(mlir::scf::IfOp op, mlir::OpBuilder &rewriter, + mlir::BlockAndValueMapping mapping = {}, + mlir::TypeRange types = {}) { + using namespace mlir; + return rewriter.create( + op.getLoc(), types, mapping.lookupOrDefault(op.getCondition()), true); +} +static inline mlir::AffineIfOp +cloneWithoutResults(mlir::AffineIfOp op, mlir::OpBuilder &rewriter, + mlir::BlockAndValueMapping mapping = {}, + mlir::TypeRange types = {}) { + using namespace mlir; + SmallVector lower; + for (auto o : op.getOperands()) + lower.push_back(mapping.lookupOrDefault(o)); + return rewriter.create(op.getLoc(), types, op.getIntegerSet(), + lower, true); +} + +static inline mlir::scf::ForOp +cloneWithoutResults(mlir::scf::ForOp op, mlir::PatternRewriter &rewriter, + mlir::BlockAndValueMapping mapping = {}) { + using namespace mlir; + return rewriter.create( + op.getLoc(), mapping.lookupOrDefault(op.getLowerBound()), + mapping.lookupOrDefault(op.getUpperBound()), + mapping.lookupOrDefault(op.getStep())); +} +static inline mlir::AffineForOp +cloneWithoutResults(mlir::AffineForOp op, mlir::PatternRewriter &rewriter, + mlir::BlockAndValueMapping mapping = {}) { + using namespace mlir; + SmallVector lower; + for (auto o : op.getLowerBoundOperands()) + lower.push_back(mapping.lookupOrDefault(o)); + SmallVector upper; + for (auto o : op.getUpperBoundOperands()) + upper.push_back(mapping.lookupOrDefault(o)); + return rewriter.create(op.getLoc(), lower, op.getLowerBoundMap(), + upper, op.getUpperBoundMap(), + op.getStep()); +} + +static inline mlir::Block *getThenBlock(mlir::scf::IfOp op) { + return op.thenBlock(); +} +static inline mlir::Block *getThenBlock(mlir::AffineIfOp op) { + return op.getThenBlock(); +} +static inline mlir::Block *getElseBlock(mlir::scf::IfOp op) { + return op.elseBlock(); +} +static inline mlir::Block *getElseBlock(mlir::AffineIfOp op) { + return op.getElseBlock(); +} + +static inline mlir::Region &getThenRegion(mlir::scf::IfOp op) { + return op.getThenRegion(); +} +static inline mlir::Region &getThenRegion(mlir::AffineIfOp op) { + return op.thenRegion(); +} +static inline mlir::Region &getElseRegion(mlir::scf::IfOp op) { + return op.getElseRegion(); +} +static inline mlir::Region &getElseRegion(mlir::AffineIfOp op) { + return op.elseRegion(); +} + +static inline mlir::scf::YieldOp getThenYield(mlir::scf::IfOp op) { + return op.thenYield(); +} +static inline mlir::AffineYieldOp getThenYield(mlir::AffineIfOp op) { + return llvm::cast(op.getThenBlock()->getTerminator()); +} +static inline mlir::scf::YieldOp getElseYield(mlir::scf::IfOp op) { + return op.elseYield(); +} +static inline mlir::AffineYieldOp getElseYield(mlir::AffineIfOp op) { + return llvm::cast(op.getElseBlock()->getTerminator()); +} + +static inline bool inBound(mlir::scf::IfOp op, mlir::Value v) { + return op.getCondition() == v; +} +static inline bool inBound(mlir::AffineIfOp op, mlir::Value v) { + return llvm::any_of(op.getOperands(), [&](mlir::Value e) { return e == v; }); +} +static inline bool inBound(mlir::scf::ForOp op, mlir::Value v) { + return op.getUpperBound() == v; +} +static inline bool inBound(mlir::AffineForOp op, mlir::Value v) { + return llvm::any_of(op.getUpperBoundOperands(), + [&](mlir::Value e) { return e == v; }); +} +static inline bool hasElse(mlir::scf::IfOp op) { + return op.getElseRegion().getBlocks().size() > 0; +} +static inline bool hasElse(mlir::AffineIfOp op) { + return op.elseRegion().getBlocks().size() > 0; +} diff --git a/include/polygeist/PolygeistOps.td b/include/polygeist/PolygeistOps.td index 6f1cf0be8f29..bdcd8a6afa84 100644 --- a/include/polygeist/PolygeistOps.td +++ b/include/polygeist/PolygeistOps.td @@ -24,6 +24,7 @@ def BarrierOp let arguments = (ins Variadic:$indices); let summary = "barrier for parallel loops"; let description = [{}]; + let hasCanonicalizer = true; } //===----------------------------------------------------------------------===// diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index 17062ce73c43..012c29b587cb 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -84,6 +84,62 @@ void BarrierOp::getEffects( // TODO: we need to handle regions in case the parent op isn't an SCF parallel } +class BarrierHoist final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(BarrierOp barrier, + PatternRewriter &rewriter) const override { + if (isa(barrier->getParentOp())) { + + bool below = true; + for (Operation *it = barrier->getNextNode(); it != nullptr; + it = it->getNextNode()) { + auto memInterface = dyn_cast(it); + if (!memInterface) { + below = false; + break; + } + if (!memInterface.hasNoEffect()) { + below = false; + break; + } + } + if (below) { + rewriter.setInsertionPoint(barrier->getParentOp()->getNextNode()); + rewriter.create(barrier.getLoc(), barrier.getOperands()); + rewriter.eraseOp(barrier); + return success(); + } + bool above = true; + for (Operation *it = barrier->getPrevNode(); it != nullptr; + it = it->getPrevNode()) { + auto memInterface = dyn_cast(it); + if (!memInterface) { + above = false; + break; + } + if (!memInterface.hasNoEffect()) { + above = false; + break; + } + } + if (above) { + rewriter.setInsertionPoint(barrier->getParentOp()); + rewriter.create(barrier.getLoc(), barrier.getOperands()); + rewriter.eraseOp(barrier); + return success(); + } + } + return failure(); + } +}; + +void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert(context); +} + /// Replace cast(subindex(x, InterimType), FinalType) with subindex(x, /// FinalType) class CastOfSubIndex final : public OpRewritePattern { @@ -1262,10 +1318,12 @@ struct IfAndLazy : public OpRewritePattern { llvm::zip(prevIf.getResults(), prevIf.elseYield().getOperands(), prevIf.thenYield().getOperands())) { if (std::get<0>(it) == nextIf.getCondition()) { - if (matchPattern(std::get<1>(it), m_Zero())) { + if (matchPattern(std::get<1>(it), m_Zero()) || + std::get<1>(it).getDefiningOp()) { nextIfCondition = std::get<2>(it); thenRegion = true; - } else if (matchPattern(std::get<2>(it), m_Zero())) { + } else if (matchPattern(std::get<2>(it), m_Zero()) || + std::get<2>(it).getDefiningOp()) { nextIfCondition = std::get<1>(it); thenRegion = false; } else @@ -1355,109 +1413,6 @@ struct IfAndLazy : public OpRewritePattern { } }; -struct CombineIfs : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(scf::IfOp nextIf, - PatternRewriter &rewriter) const override { - using namespace scf; - Block *parent = nextIf->getBlock(); - if (nextIf == &parent->front()) - return failure(); - - auto prevIf = dyn_cast(nextIf->getPrevNode()); - if (!prevIf) - return failure(); - - if (nextIf.getCondition() != prevIf.getCondition()) - return failure(); - - //* Changed*// - SmallVector prevElseYielded; - if (!prevIf.getElseRegion().empty()) - prevElseYielded = prevIf.elseYield().getOperands(); - // Replace all uses of return values of op within nextIf with the - // corresponding yields - for (auto it : llvm::zip(prevIf.getResults(), - prevIf.thenYield().getOperands(), prevElseYielded)) - for (OpOperand &use : - llvm::make_early_inc_range(std::get<0>(it).getUses())) { - if (nextIf.getThenRegion().isAncestor( - use.getOwner()->getParentRegion())) { - rewriter.startRootUpdate(use.getOwner()); - use.set(std::get<1>(it)); - rewriter.finalizeRootUpdate(use.getOwner()); - } else if (nextIf.getElseRegion().isAncestor( - use.getOwner()->getParentRegion())) { - rewriter.startRootUpdate(use.getOwner()); - use.set(std::get<2>(it)); - rewriter.finalizeRootUpdate(use.getOwner()); - } - } - //* End Changed*// - - SmallVector mergedTypes(prevIf.getResultTypes()); - llvm::append_range(mergedTypes, nextIf.getResultTypes()); - - //* Changed nextIf cond to nextIf cond*// - scf::IfOp combinedIf = rewriter.create( - nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false); - rewriter.eraseBlock(&combinedIf.getThenRegion().back()); - - scf::YieldOp thenYield = prevIf.thenYield(); - scf::YieldOp thenYield2 = nextIf.thenYield(); - - combinedIf.getThenRegion().getBlocks().splice( - combinedIf.getThenRegion().getBlocks().begin(), - prevIf.getThenRegion().getBlocks()); - - rewriter.mergeBlocks(nextIf.thenBlock(), combinedIf.thenBlock()); - rewriter.setInsertionPointToEnd(combinedIf.thenBlock()); - - SmallVector mergedYields(thenYield.getOperands()); - llvm::append_range(mergedYields, thenYield2.getOperands()); - rewriter.create(thenYield2.getLoc(), mergedYields); - rewriter.eraseOp(thenYield); - rewriter.eraseOp(thenYield2); - - combinedIf.getElseRegion().getBlocks().splice( - combinedIf.getElseRegion().getBlocks().begin(), - prevIf.getElseRegion().getBlocks()); - - if (!nextIf.getElseRegion().empty()) { - if (combinedIf.getElseRegion().empty()) { - combinedIf.getElseRegion().getBlocks().splice( - combinedIf.getElseRegion().getBlocks().begin(), - nextIf.getElseRegion().getBlocks()); - } else { - scf::YieldOp elseYield = combinedIf.elseYield(); - scf::YieldOp elseYield2 = nextIf.elseYield(); - rewriter.mergeBlocks(nextIf.elseBlock(), combinedIf.elseBlock()); - - rewriter.setInsertionPointToEnd(combinedIf.elseBlock()); - - SmallVector mergedElseYields(elseYield.getOperands()); - llvm::append_range(mergedElseYields, elseYield2.getOperands()); - - rewriter.create(elseYield2.getLoc(), mergedElseYields); - rewriter.eraseOp(elseYield); - rewriter.eraseOp(elseYield2); - } - } - - SmallVector prevValues; - SmallVector nextValues; - for (const auto &pair : llvm::enumerate(combinedIf.getResults())) { - if (pair.index() < prevIf.getNumResults()) - prevValues.push_back(pair.value()); - else - nextValues.push_back(pair.value()); - } - rewriter.replaceOp(prevIf, prevValues); - rewriter.replaceOp(nextIf, nextValues); - return success(); - } -}; struct MoveIntoIfs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -1476,6 +1431,18 @@ struct MoveIntoIfs : public OpRewritePattern { if (isa(prevOp)) return failure(); + // Don't attempt to move into if in the case where there are two + // ifs to combine. + auto nestedOps = nextIf.thenBlock()->without_terminator(); + // Nested `if` must be the only op in block. + if (llvm::hasSingleElement(nestedOps)) { + + if (!nextIf.elseBlock() || llvm::hasSingleElement(*nextIf.elseBlock())) { + if (auto nestedIf = dyn_cast(*nestedOps.begin())) + return failure(); + } + } + bool thenUse = false; bool elseUse = false; bool outsideUse = false; @@ -1533,13 +1500,57 @@ struct MoveIntoIfs : public OpRewritePattern { } }; +struct MoveOutOfIfs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::IfOp nextIf, + PatternRewriter &rewriter) const override { + // Don't attempt to move into if in the case where there are two + // ifs to combine. + auto nestedOps = nextIf.thenBlock()->without_terminator(); + // Nested `if` must be the only op in block. + if (nestedOps.empty() || llvm::hasSingleElement(nestedOps)) { + return failure(); + } + + if (nextIf.elseBlock() && !llvm::hasSingleElement(*nextIf.elseBlock())) { + return failure(); + } + + auto nestedIf = dyn_cast(*(--nestedOps.end())); + if (!nestedIf) { + return failure(); + } + SmallVector toMove; + for (auto &o : nestedOps) + if (&o != nestedIf) { + auto memInterface = dyn_cast(&o); + if (!memInterface) { + return failure(); + } + if (!memInterface.hasNoEffect()) { + return failure(); + } + toMove.push_back(&o); + } + + rewriter.setInsertionPoint(nextIf); + for (auto o : toMove) { + auto rep = rewriter.clone(*o); + rewriter.replaceOp(o, rep->getResults()); + } + + return success(); + } +}; + void Pointer2MemrefOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.insert< Pointer2MemrefCast, Pointer2Memref2PointerCast, MetaPointer2Memref, MetaPointer2Memref, MetaPointer2Memref, MetaPointer2Memref, - CombineIfs, MoveIntoIfs, IfAndLazy>(context); + MoveIntoIfs, MoveOutOfIfs, IfAndLazy>(context); } OpFoldResult Pointer2MemrefOp::fold(ArrayRef operands) { @@ -1647,7 +1658,161 @@ struct TypeAlignCanonicalize : public OpRewritePattern { } }; +class OrIExcludedMiddle final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::OrIOp op, + PatternRewriter &rewriter) const override { + auto lhs = op.getLhs().getDefiningOp(); + auto rhs = op.getRhs().getDefiningOp(); + if (!lhs || !rhs) + return failure(); + if (lhs.getLhs() != rhs.getLhs() || lhs.getRhs() != rhs.getRhs() || + lhs.getPredicate() != arith::invertPredicate(rhs.getPredicate())) + return failure(); + rewriter.replaceOpWithNewOp(op, true, 1); + return success(); + } +}; + +class SelectI1Ext final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::SelectOp op, + PatternRewriter &rewriter) const override { + auto ty = op.getType().dyn_cast(); + if (!ty) + return failure(); + if (ty.getWidth() == 1) + return failure(); + IntegerAttr lhs, rhs; + Value lhs_v = nullptr, rhs_v = nullptr; + if (auto ext = op.getTrueValue().getDefiningOp()) { + lhs_v = ext.getIn(); + if (lhs_v.getType().cast().getWidth() != 1) + return failure(); + } else if (matchPattern(op.getTrueValue(), m_Constant(&lhs))) { + } else + return failure(); + + if (auto ext = op.getFalseValue().getDefiningOp()) { + rhs_v = ext.getIn(); + if (rhs_v.getType().cast().getWidth() != 1) + return failure(); + } else if (matchPattern(op.getFalseValue(), m_Constant(&rhs))) { + } else + return failure(); + + if (!lhs_v) + lhs_v = rewriter.create(op.getLoc(), lhs.getInt(), 1); + if (!rhs_v) + rhs_v = rewriter.create(op.getLoc(), rhs.getInt(), 1); + + rewriter.replaceOpWithNewOp( + op, op.getType(), + rewriter.create(op.getLoc(), op.getCondition(), lhs_v, + rhs_v)); + return success(); + } +}; + +template class UndefProp final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(T op, + PatternRewriter &rewriter) const override { + Value v = op->getOperand(0); + Operation *undef; + if (!(undef = v.getDefiningOp())) + return failure(); + rewriter.setInsertionPoint(undef); + rewriter.replaceOpWithNewOp(op, op.getType()); + return success(); + } +}; + +class UndefCmpProp final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CmpIOp op, + PatternRewriter &rewriter) const override { + Value v = op->getOperand(0); + Operation *undef; + if (!(undef = v.getDefiningOp())) + return failure(); + if (!op.getRhs().getDefiningOp()) + return failure(); + rewriter.setInsertionPoint(undef); + rewriter.replaceOpWithNewOp(op, op.getType()); + return success(); + } +}; +class CmpProp final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CmpIOp op, + PatternRewriter &rewriter) const override { + auto ifOp = op.getLhs().getDefiningOp(); + if (!ifOp) + return failure(); + auto rhs = op.getRhs().getDefiningOp(); + if (!rhs) { + return failure(); + } + auto idx = op.getLhs().cast().getResultNumber(); + bool change = false; + for (auto v : + {ifOp.thenYield().getOperand(idx), ifOp.elseYield().getOperand(idx)}) { + change |= + v.getDefiningOp() || v.getDefiningOp(); + if (auto extOp = v.getDefiningOp()) + if (auto it = extOp.getIn().getType().dyn_cast()) + change |= it.getWidth() == 1; + if (auto extOp = v.getDefiningOp()) + if (auto it = extOp.getIn().getType().dyn_cast()) + change |= it.getWidth() == 1; + } + if (!change) { + return failure(); + } + + SmallVector resultTypes; + llvm::append_range(resultTypes, ifOp.getResultTypes()); + resultTypes.push_back(op.getType()); + + auto rhs2 = rewriter.clone(*rhs)->getResult(0); + auto nop = rewriter.create( + ifOp.getLoc(), resultTypes, ifOp.getCondition(), /*hasElse*/ true); + nop.getThenRegion().takeBody(ifOp.getThenRegion()); + nop.getElseRegion().takeBody(ifOp.getElseRegion()); + + SmallVector thenYields; + llvm::append_range(thenYields, nop.thenYield().getOperands()); + rewriter.setInsertionPoint(nop.thenYield()); + thenYields.push_back(rewriter.create(op.getLoc(), op.getPredicate(), + thenYields[idx], rhs2)); + nop.thenYield()->setOperands(thenYields); + + SmallVector elseYields; + llvm::append_range(elseYields, nop.elseYield().getOperands()); + rewriter.setInsertionPoint(nop.elseYield()); + elseYields.push_back(rewriter.create(op.getLoc(), op.getPredicate(), + elseYields[idx], rhs2)); + nop.elseYield()->setOperands(elseYields); + rewriter.replaceOp(ifOp, nop.getResults().take_front(ifOp.getNumResults())); + rewriter.replaceOp(op, nop.getResults().take_back(1)); + return success(); + } +}; + void TypeAlignOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.insert(context); + results.insert, UndefProp, UndefProp, + CmpProp, UndefCmpProp>(context); } diff --git a/lib/polygeist/Passes/AffineCFG.cpp b/lib/polygeist/Passes/AffineCFG.cpp index 688f0850688b..12f87edb68f2 100644 --- a/lib/polygeist/Passes/AffineCFG.cpp +++ b/lib/polygeist/Passes/AffineCFG.cpp @@ -4,10 +4,15 @@ #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "polygeist/Passes/Passes.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/Support/Debug.h" #include #include @@ -18,6 +23,36 @@ using namespace mlir; using namespace mlir::arith; using namespace polygeist; +bool isReadOnly(Operation *op); + +// isValidSymbol, even if not index +bool isValidSymbolInt(Value value, bool recur = true) { + // Check that the value is a top level value. + if (isTopLevelValue(value)) + return true; + + if (auto *defOp = value.getDefiningOp()) { + Attribute operandCst; + if (matchPattern(defOp, m_Constant(&operandCst))) + return true; + + if (recur) { + if (isa(defOp)) + if (llvm::all_of(defOp->getOperands(), [&](Value v) { + bool b = isValidSymbolInt(v, true); + // if (!b) + // LLVM_DEBUG(llvm::dbgs() << "illegal isValidSymbolInt: " + //<< value << " due to " << v << "\n"); + return b; + })) + return true; + } + return isValidSymbol(value, getAffineScope(defOp)); + } + return false; +} + struct AffineApplyNormalizer { AffineApplyNormalizer(AffineMap map, ArrayRef operands); @@ -83,18 +118,14 @@ static bool isAffineForArg(Value val) { return (parentOp && isa(parentOp)); } -static bool legalCondition(Value en, bool outer = true, bool dim = false) { - if (en.getDefiningOp() || en.getDefiningOp()) +static bool legalCondition(Value en, bool dim = false) { + if (en.getDefiningOp()) return true; - if (!isValidSymbol(en)) { - if (en.getDefiningOp() || en.getDefiningOp() || - en.getDefiningOp() || en.getDefiningOp()) { + if (!dim && !isValidSymbolInt(en, /*recur*/ false)) { + if (isValidIndex(en) || isValidSymbolInt(en, /*recur*/ true)) { return true; } - if (auto m = en.getDefiningOp()) { - return m.getRhs().getDefiningOp(); - } } // if (auto IC = dyn_cast_or_null(en.getDefiningOp())) { // if (!outer || legalCondition(IC.getOperand(), false)) return true; @@ -107,55 +138,6 @@ static bool legalCondition(Value en, bool outer = true, bool dim = false) { return false; } -// Gather the positions of the operands that are produced by an AffineApplyOp. -static llvm::SetVector -indicesFromAffineApplyOp(ArrayRef operands) { - llvm::SetVector res; - for (auto en : llvm::enumerate(operands)) { - if (legalCondition(en.value())) - res.insert(en.index()); - } - return res; -} - -static AffineMap promoteComposedSymbolsAsDims(AffineMap map, - ArrayRef symbols) { - if (symbols.empty()) { - return map; - } - - // Sanity check on symbols. - for (auto sym : symbols) { - // assert(isValidSymbol(sym) && "Expected only valid symbols"); - (void)sym; - } - - // Extract the symbol positions that come from an AffineApplyOp and - // needs to be rewritten as dims. - auto symPositions = indicesFromAffineApplyOp(symbols); - if (symPositions.empty()) { - return map; - } - - // Create the new map by replacing each symbol at pos by the next new dim. - unsigned numDims = map.getNumDims(); - unsigned numSymbols = map.getNumSymbols(); - unsigned numNewDims = 0; - unsigned numNewSymbols = 0; - SmallVector symReplacements(numSymbols); - for (unsigned i = 0; i < numSymbols; ++i) { - symReplacements[i] = - symPositions.count(i) > 0 - ? getAffineDimExpr(numDims + numNewDims++, map.getContext()) - : getAffineSymbolExpr(numNewSymbols++, map.getContext()); - } - assert(numSymbols >= numNewDims); - AffineMap newMap = map.replaceDimsAndSymbols( - {}, symReplacements, numDims + numNewDims, numNewSymbols); - - return newMap; -} - /// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to /// keep a correspondence between the mathematical `map` and the `operands` of /// a given AffineApplyOp. This correspondence is maintained by iterating over @@ -196,198 +178,241 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, LLVM_DEBUG(map.print(llvm::dbgs() << "\nInput map: ")); - // Promote symbols that come from an AffineApplyOp to dims by rewriting the - // map to always refer to: - // (dims, symbols coming from AffineApplyOp, other symbols). - // The order of operands can remain unchanged. - // This is a simplification that relies on 2 ordering properties: - // 1. rewritten symbols always appear after the original dims in the map; - // 2. operands are traversed in order and either dispatched to: - // a. auxiliaryExprs (dims and symbols rewritten as dims); - // b. concatenatedSymbols (all other symbols) - // This allows operand order to remain unchanged. - unsigned numDimsBeforeRewrite = map.getNumDims(); - map = promoteComposedSymbolsAsDims(map, - operands.take_back(map.getNumSymbols())); + SmallVector auxiliaryExprs; + SmallVector addedValues; - LLVM_DEBUG(map.print(llvm::dbgs() << "\nRewritten map: ")); + unsigned numDimsBeforeRewrite = map.getNumDims(); + llvm::SmallSet symbolsToPromote; + + // 2. Compose AffineApplyOps and dispatch dims or symbols. + for (unsigned i = 0, e = operands.size(); i < e; ++i) { + auto t = operands[i]; + if (!isValidSymbolInt(t, /*recur*/ false)) { + while (auto idx = t.getDefiningOp()) { + if (idx.getIn().getDefiningOp() || + idx.getIn().getDefiningOp() || + idx.getIn().getDefiningOp() || + idx.getIn().getDefiningOp() || + idx.getIn().getDefiningOp() || + idx.getIn().getDefiningOp() || + idx.getIn().getDefiningOp() || + idx.getIn().getDefiningOp()) + t = idx.getIn(); + else + break; + } + } - SmallVector auxiliaryExprs; - bool furtherCompose = (affineApplyDepth() <= kMaxAffineApplyDepth); - // We fully spell out the 2 cases below. In this particular instance a little - // code duplication greatly improves readability. - // Note that the first branch would disappear if we only supported full - // composition (i.e. infinite kMaxAffineApplyDepth). - if (!furtherCompose) { - // 1. Only dispatch dims or symbols. - for (auto en : llvm::enumerate(operands)) { - auto t = en.value(); - assert(t.getType().isIndex()); - bool isDim = (en.index() < map.getNumDims()); - if (isDim) { - // a. The mathematical composition of AffineMap composes dims. - auxiliaryExprs.push_back(renumberOneDim(t)); + if (!isValidSymbolInt(t, /*recur*/ false) && + (t.getDefiningOp() || t.getDefiningOp() || + t.getDefiningOp() || t.getDefiningOp() || + t.getDefiningOp() || t.getDefiningOp() || + t.getDefiningOp() || t.getDefiningOp() || + t.getDefiningOp())) { + + AffineMap affineApplyMap; + SmallVector affineApplyOperands; + + // llvm::dbgs() << "\nop to start: " << t << "\n"; + + if (auto op = t.getDefiningOp()) { + affineApplyMap = + AffineMap::get(0, 2, + getAffineSymbolExpr(0, op.getContext()) + + getAffineSymbolExpr(1, op.getContext())); + } else if (auto op = t.getDefiningOp()) { + affineApplyMap = + AffineMap::get(0, 2, + getAffineSymbolExpr(0, op.getContext()) - + getAffineSymbolExpr(1, op.getContext())); + } else if (auto op = t.getDefiningOp()) { + affineApplyMap = + AffineMap::get(0, 2, + getAffineSymbolExpr(0, op.getContext()) * + getAffineSymbolExpr(1, op.getContext())); + } else if (auto op = t.getDefiningOp()) { + affineApplyMap = AffineMap::get( + 0, 2, + getAffineSymbolExpr(0, op.getContext()) + .floorDiv(getAffineSymbolExpr(1, op.getContext()))); + } else if (auto op = t.getDefiningOp()) { + affineApplyMap = AffineMap::get( + 0, 2, + getAffineSymbolExpr(0, op.getContext()) + .floorDiv(getAffineSymbolExpr(1, op.getContext()))); + } else if (auto op = t.getDefiningOp()) { + affineApplyMap = + AffineMap::get(0, 2, + getAffineSymbolExpr(0, op.getContext()) % + getAffineSymbolExpr(1, op.getContext())); + } else if (auto op = t.getDefiningOp()) { + affineApplyMap = + AffineMap::get(0, 2, + getAffineSymbolExpr(0, op.getContext()) % + getAffineSymbolExpr(1, op.getContext())); + } else if (auto op = t.getDefiningOp()) { + affineApplyMap = AffineMap::get( + 0, 0, getAffineConstantExpr(op.value(), op.getContext())); + } else if (auto op = t.getDefiningOp()) { + affineApplyMap = AffineMap::get( + 0, 0, getAffineConstantExpr(op.value(), op.getContext())); } else { - // b. The mathematical composition of AffineMap concatenates symbols. - // We do the same for symbol operands. - concatenatedSymbols.push_back(t); + llvm_unreachable(""); } - } - } else { - assert(numDimsBeforeRewrite <= operands.size()); - - SmallVector addedValues; - - // 2. Compose AffineApplyOps and dispatch dims or symbols. - for (unsigned i = 0, e = operands.size(); i < e; ++i) { - auto t = operands[i]; - - if (!isValidSymbol(t) && - (t.getDefiningOp() || t.getDefiningOp() || - t.getDefiningOp() || t.getDefiningOp() || - t.getDefiningOp())) { - - AffineMap affineApplyMap; - SmallVector affineApplyOperands; - - // llvm::dbgs() << "\nop to start: " << t << "\n"; - - if (auto op = t.getDefiningOp()) { - affineApplyMap = - AffineMap::get(0, 2, - getAffineSymbolExpr(0, op.getContext()) + - getAffineSymbolExpr(1, op.getContext())); - affineApplyOperands.append(op.getOperands().begin(), - op.getOperands().end()); - } else if (auto op = t.getDefiningOp()) { - affineApplyMap = - AffineMap::get(0, 2, - getAffineSymbolExpr(0, op.getContext()) - - getAffineSymbolExpr(1, op.getContext())); - affineApplyOperands.append(op.getOperands().begin(), - op.getOperands().end()); - } else if (auto op = t.getDefiningOp()) { - affineApplyMap = - AffineMap::get(0, 2, - getAffineSymbolExpr(0, op.getContext()) * - getAffineSymbolExpr(1, op.getContext())); - affineApplyOperands.append(op.getOperands().begin(), - op.getOperands().end()); - } else if (auto op = t.getDefiningOp()) { - affineApplyMap = AffineMap::get( - 0, 2, - getAffineSymbolExpr(0, op.getContext()) - .floorDiv(getAffineSymbolExpr(1, op.getContext()))); - affineApplyOperands.append(op.getOperands().begin(), - op.getOperands().end()); - } else if (auto op = t.getDefiningOp()) { - affineApplyMap = AffineMap::get( - 0, 2, - getAffineSymbolExpr(0, op.getContext()) - .floorDiv(getAffineSymbolExpr(1, op.getContext()))); - affineApplyOperands.append(op.getOperands().begin(), - op.getOperands().end()); - } else { - llvm_unreachable(""); - } - SmallVector dimRemapping; - unsigned numOtherSymbols = affineApplyOperands.size(); - SmallVector symRemapping(numOtherSymbols); - for (unsigned idx = 0; idx < numOtherSymbols; ++idx) { - symRemapping[idx] = getAffineSymbolExpr(addedValues.size(), - affineApplyMap.getContext()); - addedValues.push_back(affineApplyOperands[idx]); - } - affineApplyMap = affineApplyMap.replaceDimsAndSymbols( - dimRemapping, symRemapping, reorderedDims.size(), - addedValues.size()); - - LLVM_DEBUG(affineApplyMap.print( - llvm::dbgs() << "\nRenumber into current normalizer: ")); - auxiliaryExprs.push_back(affineApplyMap.getResult(0)); - /* - llvm::dbgs() << "\n"; - for(auto op : affineApplyOperands) { - llvm::dbgs() << " + prevop: " << op << "\n"; - } - */ - } else if (isAffineForArg(t)) { - auxiliaryExprs.push_back(renumberOneDim(t)); - /* - } else if (auto op = t.getDefiningOp()) { - // Todo index cast - if (legalCondition(op.getOperand())) { - if (i < numDimsBeforeRewrite) { - auxiliaryExprs.push_back(renumberOneDim(t)); - } else { - auxiliaryExprs.push_back(getAffineSymbolExpr(addedValues.size(), - op.getContext())); addedValues.push_back(op.getOperand()); - } - } else { - auxiliaryExprs.push_back(getAffineSymbolExpr(addedValues.size(), - op.getContext())); addedValues.push_back(op); - } - } else if (auto op = t.getDefiningOp()) { - auxiliaryExprs.push_back(getAffineSymbolExpr(addedValues.size(), - op.getContext())); addedValues.push_back(op.getOperand()); - */ - } else if (auto affineApply = t.getDefiningOp()) { - // a. Compose affine.apply operations. - LLVM_DEBUG(affineApply->print( - llvm::dbgs() << "\nCompose AffineApplyOp recursively: ")); - AffineMap affineApplyMap = affineApply.getAffineMap(); - SmallVector affineApplyOperands( - affineApply.getOperands().begin(), affineApply.getOperands().end()); - - SmallVector dimRemapping(affineApplyMap.getNumDims()); - - for (size_t i = 0; i < affineApplyMap.getNumDims(); ++i) { - assert(i < affineApplyOperands.size()); - dimRemapping[i] = renumberOneDim(affineApplyOperands[i]); - } - unsigned numOtherSymbols = affineApplyOperands.size(); - SmallVector symRemapping(numOtherSymbols - - affineApplyMap.getNumDims()); - for (unsigned idx = 0; idx < symRemapping.size(); ++idx) { - symRemapping[idx] = getAffineSymbolExpr(addedValues.size(), - affineApplyMap.getContext()); - addedValues.push_back( - affineApplyOperands[idx + affineApplyMap.getNumDims()]); - } - affineApplyMap = affineApplyMap.replaceDimsAndSymbols( - dimRemapping, symRemapping, reorderedDims.size(), - addedValues.size()); + for (auto op : t.getDefiningOp()->getOperands()) { + affineApplyOperands.push_back(op); + } - LLVM_DEBUG( - affineApplyMap.print(llvm::dbgs() << "\nAffine apply fixup map: ")); - auxiliaryExprs.push_back(affineApplyMap.getResult(0)); + SmallVector dimRemapping; + unsigned numOtherSymbols = affineApplyOperands.size(); + SmallVector symRemapping(numOtherSymbols); + for (unsigned idx = 0; idx < numOtherSymbols; ++idx) { + symRemapping[idx] = getAffineSymbolExpr(addedValues.size(), + affineApplyMap.getContext()); + addedValues.push_back(affineApplyOperands[idx]); + } + affineApplyMap = affineApplyMap.replaceDimsAndSymbols( + dimRemapping, symRemapping, reorderedDims.size(), addedValues.size()); + + if (i >= numDimsBeforeRewrite) + symbolsToPromote.insert(i - numDimsBeforeRewrite); + + LLVM_DEBUG(affineApplyMap.print( + llvm::dbgs() << "\nRenumber into current normalizer: ")); + auxiliaryExprs.push_back(affineApplyMap.getResult(0)); + /* + llvm::dbgs() << "\n"; + for(auto op : affineApplyOperands) { + llvm::dbgs() << " + prevop: " << op << "\n"; + } + */ + } else if (isAffineForArg(t)) { + auxiliaryExprs.push_back(renumberOneDim(t)); + if (i >= numDimsBeforeRewrite) + symbolsToPromote.insert(i - numDimsBeforeRewrite); + } else if (auto affineApply = t.getDefiningOp()) { + // a. Compose affine.apply operations. + LLVM_DEBUG(affineApply->print( + llvm::dbgs() << "\nCompose AffineApplyOp recursively: ")); + AffineMap affineApplyMap = affineApply.getAffineMap(); + SmallVector affineApplyOperands( + affineApply.getOperands().begin(), affineApply.getOperands().end()); + + SmallVector dimRemapping(affineApplyMap.getNumDims()); + + for (size_t i = 0; i < affineApplyMap.getNumDims(); ++i) { + assert(i < affineApplyOperands.size()); + dimRemapping[i] = renumberOneDim(affineApplyOperands[i]); + } + unsigned numOtherSymbols = affineApplyOperands.size(); + SmallVector symRemapping(numOtherSymbols - + affineApplyMap.getNumDims()); + for (unsigned idx = 0; idx < symRemapping.size(); ++idx) { + symRemapping[idx] = getAffineSymbolExpr(addedValues.size(), + affineApplyMap.getContext()); + addedValues.push_back( + affineApplyOperands[idx + affineApplyMap.getNumDims()]); + } + affineApplyMap = affineApplyMap.replaceDimsAndSymbols( + dimRemapping, symRemapping, reorderedDims.size(), addedValues.size()); + + if (i >= numDimsBeforeRewrite) + symbolsToPromote.insert(i - numDimsBeforeRewrite); + + LLVM_DEBUG( + affineApplyMap.print(llvm::dbgs() << "\nAffine apply fixup map: ")); + auxiliaryExprs.push_back(affineApplyMap.getResult(0)); + } else { + if (!isValidSymbolInt(t, /*recur*/ false)) { + if (auto idx = t.getDefiningOp()) { + auto scope = getAffineScope(idx)->getParentOp(); + DominanceInfo DI(scope); + + std::function fix = [&](Value v) -> bool /*legal*/ { + if (isValidSymbolInt(v, /*recur*/ false)) + return true; + auto op = v.getDefiningOp(); + if (!op) + llvm::errs() << v << "\n"; + assert(op); + if (isa(op) || isa(op)) + return true; + if (!isReadOnly(op)) { + return false; + } + Operation *front = nullptr; + for (auto o : op->getOperands()) { + Operation *next; + if (auto op = o.getDefiningOp()) { + if (!fix(o)) { + return false; + } + next = op; + } else { + auto BA = o.cast(); + if (!isValidSymbolInt(o, /*recur*/ false)) { + return false; + } + next = &BA.getOwner()->front(); + } + if (front == nullptr) + front = next; + else if (DI.dominates(front, next)) + front = next; + } + if (!front) + op->dump(); + assert(front); + op->moveAfter(front); + return true; + }; + if (fix(idx)) + assert(isValidSymbolInt(idx, /*recur*/ false)); + else + t = idx.getIn(); + } else + assert(0 && "cannot move"); + } + if (i < numDimsBeforeRewrite) { + // b. The mathematical composition of AffineMap composes dims. + auxiliaryExprs.push_back(renumberOneDim(t)); } else { - if (i < numDimsBeforeRewrite) { - // b. The mathematical composition of AffineMap composes dims. - auxiliaryExprs.push_back(renumberOneDim(t)); - } else { - // c. The mathematical composition of AffineMap concatenates symbols. - // Note that the map composition will put symbols already present - // in the map before any symbols coming from the auxiliary map, so - // we insert them before any symbols that are due to renumbering, - // and after the proper symbols we have seen already. - concatenatedSymbols.insert( - std::next(concatenatedSymbols.begin(), numProperSymbols++), t); - } + // c. The mathematical composition of AffineMap concatenates symbols. + // Note that the map composition will put symbols already present + // in the map before any symbols coming from the auxiliary map, so + // we insert them before any symbols that are due to renumbering, + // and after the proper symbols we have seen already. + concatenatedSymbols.insert( + std::next(concatenatedSymbols.begin(), numProperSymbols++), t); } } - for (auto val : addedValues) { - concatenatedSymbols.push_back(val); - } } - // Early exit if `map` is already composed. - if (auxiliaryExprs.empty()) { - affineMap = map; - return; + for (auto val : addedValues) { + concatenatedSymbols.push_back(val); + } + + { + // Create the new map by replacing each symbol at pos by the next new dim. + unsigned numDims = map.getNumDims(); + unsigned numSymbols = map.getNumSymbols(); + unsigned numNewDims = 0; + unsigned numNewSymbols = 0; + SmallVector symReplacements(numSymbols); + for (unsigned i = 0; i < numSymbols; ++i) { + symReplacements[i] = + symbolsToPromote.count(i) > 0 + ? getAffineDimExpr(numDims + numNewDims++, map.getContext()) + : getAffineSymbolExpr(numNewSymbols++, map.getContext()); + } + assert(numSymbols >= numNewDims); + map = map.replaceDimsAndSymbols({}, symReplacements, numDims + numNewDims, + numNewSymbols); } + LLVM_DEBUG(map.print(llvm::dbgs() << "\nRewritten map: ")); + assert(concatenatedSymbols.size() >= map.getNumSymbols() && "Unexpected number of concatenated symbols"); auto numDims = dimValueToPosition.size(); @@ -447,22 +472,51 @@ bool need(AffineMap *map, SmallVectorImpl *operands) { assert(map->getNumInputs() == operands->size()); for (size_t i = 0; i < map->getNumInputs(); ++i) { auto v = (*operands)[i]; - if (legalCondition(v, true, i < map->getNumDims())) + if (legalCondition(v, i < map->getNumDims())) { return true; + } } return false; } bool need(IntegerSet *map, SmallVectorImpl *operands) { for (size_t i = 0; i < map->getNumInputs(); ++i) { auto v = (*operands)[i]; - if (legalCondition(v, true, i < map->getNumDims())) + if (legalCondition(v, i < map->getNumDims())) return true; } return false; } -void fully2ComposeAffineMapAndOperands(AffineMap *map, +void fully2ComposeAffineMapAndOperands(OpBuilder &builder, AffineMap *map, SmallVectorImpl *operands) { + BlockAndValueMapping indexMap; + for (auto op : *operands) { + if (auto idx = op.getDefiningOp()) { + Operation *start = idx; + bool immediate = false; + + while (1) { + if (start == idx.getIn().getDefiningOp()) { + immediate = true; + break; + } + if (isa(start)) { + if (start == &start->getBlock()->front()) { + if (auto BA = idx.getIn().dyn_cast()) + if (start->getBlock() == BA.getOwner()) { + immediate = true; + break; + } + break; + } + start = start->getPrevNode(); + } + break; + } + if (immediate) + indexMap.map(idx.getIn(), idx); + } + } assert(map->getNumInputs() == operands->size()); while (need(map, operands)) { // llvm::errs() << "pre: " << *map << "\n"; @@ -476,6 +530,26 @@ void fully2ComposeAffineMapAndOperands(AffineMap *map, // llvm::errs() << " -- operands: " << op << "\n"; //} } + for (auto &op : *operands) { + if (!op.getType().isIndex()) { + Operation *toInsert; + if (auto o = op.getDefiningOp()) + toInsert = o->getNextNode(); + else { + auto BA = op.cast(); + toInsert = &BA.getOwner()->front(); + } + + if (auto v = indexMap.lookupOrNull(op)) + op = v; + else { + OpBuilder::InsertionGuard B(builder); + builder.setInsertionPoint(toInsert); + op = builder.create(op.getLoc(), builder.getIndexType(), + op); + } + } + } } static void composeIntegerSetAndOperands(IntegerSet *set, @@ -492,11 +566,59 @@ static void composeIntegerSetAndOperands(IntegerSet *set, *operands = normalizedOperands; } -void fully2ComposeIntegerSetAndOperands(IntegerSet *set, +void fully2ComposeIntegerSetAndOperands(OpBuilder &builder, IntegerSet *set, SmallVectorImpl *operands) { + BlockAndValueMapping indexMap; + for (auto op : *operands) { + if (auto idx = op.getDefiningOp()) { + Operation *start = idx; + bool immediate = false; + + while (1) { + if (start == idx.getIn().getDefiningOp()) { + immediate = true; + break; + } + if (isa(start)) { + if (start == &start->getBlock()->front()) { + if (auto BA = idx.getIn().dyn_cast()) + if (start->getBlock() == BA.getOwner()) { + immediate = true; + break; + } + break; + } + start = start->getPrevNode(); + } + break; + } + if (immediate) + indexMap.map(idx.getIn(), idx); + } + } while (need(set, operands)) { composeIntegerSetAndOperands(set, operands); } + for (auto &op : *operands) { + if (!op.getType().isIndex()) { + Operation *toInsert; + if (auto o = op.getDefiningOp()) + toInsert = o->getNextNode(); + else { + auto BA = op.cast(); + toInsert = &BA.getOwner()->front(); + } + + if (auto v = indexMap.lookupOrNull(op)) + op = v; + else { + OpBuilder::InsertionGuard B(builder); + builder.setInsertionPoint(toInsert); + op = builder.create(op.getLoc(), builder.getIndexType(), + op); + } + } + } } namespace { @@ -617,6 +739,55 @@ struct SimplfyIntegerCastMath : public OpRewritePattern { iadd.getOperand(1))); return success(); } + if (auto iadd = op.getOperand().getDefiningOp()) { + OpBuilder b(rewriter); + setLocationAfter(b, iadd.getOperand(0)); + OpBuilder b2(rewriter); + setLocationAfter(b2, iadd.getOperand(1)); + rewriter.replaceOpWithNewOp( + op, + b.create(op.getLoc(), op.getType(), + iadd.getOperand(0)), + b2.create(op.getLoc(), op.getType(), + iadd.getOperand(1))); + return success(); + } + if (auto iadd = op.getOperand().getDefiningOp()) { + OpBuilder b(rewriter); + setLocationAfter(b, iadd.getOperand(0)); + OpBuilder b2(rewriter); + setLocationAfter(b2, iadd.getOperand(1)); + rewriter.replaceOpWithNewOp( + op, + b.create(op.getLoc(), op.getType(), + iadd.getOperand(0)), + b2.create(op.getLoc(), op.getType(), + iadd.getOperand(1))); + return success(); + } + if (auto iadd = op.getOperand().getDefiningOp()) { + OpBuilder b(rewriter); + setLocationAfter(b, iadd.getTrueValue()); + OpBuilder b2(rewriter); + setLocationAfter(b2, iadd.getFalseValue()); + auto cond = iadd.getCondition(); + OpBuilder b3(rewriter); + setLocationAfter(b3, cond); + if (auto cmp = iadd.getCondition().getDefiningOp()) { + if (cmp.getLhs() == iadd.getTrueValue() && + cmp.getRhs() == iadd.getFalseValue()) { + + auto truev = b.create(op.getLoc(), op.getType(), + iadd.getTrueValue()); + auto falsev = b2.create(op.getLoc(), op.getType(), + iadd.getFalseValue()); + cond = b3.create(cmp.getLoc(), cmp.getPredicate(), truev, + falsev); + rewriter.replaceOpWithNewOp(op, cond, truev, falsev); + return success(); + } + } + } return failure(); } }; @@ -631,7 +802,7 @@ struct CanonicalizeAffineApply : public OpRewritePattern { auto map = affineOp.map(); auto prevMap = map; - fully2ComposeAffineMapAndOperands(&map, &mapOperands); + fully2ComposeAffineMapAndOperands(rewriter, &map, &mapOperands); canonicalizeMapAndOperands(&map, &mapOperands); map = removeDuplicateExprs(map); @@ -688,7 +859,7 @@ struct CanonicalizeAffineIf : public OpRewritePattern { */ bool isValidIndex(Value val) { - if (mlir::isValidSymbol(val)) + if (isValidSymbolInt(val)) return true; if (auto cast = val.getDefiningOp()) @@ -699,17 +870,26 @@ bool isValidIndex(Value val) { if (auto bop = val.getDefiningOp()) return (isValidIndex(bop.getOperand(0)) && - isValidSymbol(bop.getOperand(1))) || + isValidSymbolInt(bop.getOperand(1))) || (isValidIndex(bop.getOperand(1)) && - isValidSymbol(bop.getOperand(0))); + isValidSymbolInt(bop.getOperand(0))); if (auto bop = val.getDefiningOp()) return (isValidIndex(bop.getOperand(0)) && - isValidSymbol(bop.getOperand(1))); + isValidSymbolInt(bop.getOperand(1))); if (auto bop = val.getDefiningOp()) return (isValidIndex(bop.getOperand(0)) && - isValidSymbol(bop.getOperand(1))); + isValidSymbolInt(bop.getOperand(1))); + + if (auto bop = val.getDefiningOp()) { + return (isValidIndex(bop.getOperand(0)) && + isValidSymbolInt(bop.getOperand(1))); + } + + if (auto bop = val.getDefiningOp()) + return (isValidIndex(bop.getOperand(0)) && + isValidSymbolInt(bop.getOperand(1))); if (auto bop = val.getDefiningOp()) return isValidIndex(bop.getOperand(0)) && isValidIndex(bop.getOperand(1)); @@ -723,8 +903,15 @@ bool isValidIndex(Value val) { if (auto ba = val.dyn_cast()) { auto owner = ba.getOwner(); assert(owner); - auto parentOp = owner->getParentOp(); + auto parentOp = owner->getParentOp(); + if (!parentOp) { + owner->dump(); + llvm::errs() << " ba: " << ba << "\n"; + } + assert(parentOp); + if (isa(parentOp)) + return true; if (auto af = dyn_cast(parentOp)) return af.getInductionVar() == ba; @@ -732,7 +919,7 @@ bool isValidIndex(Value val) { if (isa(parentOp)) return true; - if (isa(parentOp)) + if (isa(parentOp)) return true; } @@ -740,67 +927,126 @@ bool isValidIndex(Value val) { return false; } +// returns legality +bool handleMinMax(Value start, SmallVectorImpl &out, bool &min, + bool &max) { + + SmallVector todo = {start}; + while (todo.size()) { + auto cur = todo.back(); + todo.pop_back(); + if (isValidIndex(cur)) { + out.push_back(cur); + continue; + } else if (auto selOp = cur.getDefiningOp()) { + // UB only has min of operands + if (auto cmp = selOp.getCondition().getDefiningOp()) { + if (cmp.getLhs() == selOp.getTrueValue() && + cmp.getRhs() == selOp.getFalseValue()) { + todo.push_back(cmp.getLhs()); + todo.push_back(cmp.getRhs()); + if (cmp.getPredicate() == CmpIPredicate::sle || + cmp.getPredicate() == CmpIPredicate::slt) { + min = true; + continue; + } + if (cmp.getPredicate() == CmpIPredicate::sge || + cmp.getPredicate() == CmpIPredicate::sgt) { + max = true; + continue; + } + } + } + } + return false; + } + return !(min && max); +} + bool handle(OpBuilder &b, CmpIOp cmpi, SmallVectorImpl &exprs, SmallVectorImpl &eqflags, SmallVectorImpl &applies) { - AffineMap lhsmap = - AffineMap::get(0, 1, getAffineSymbolExpr(0, cmpi.getContext())); - if (!isValidIndex(cmpi.getLhs())) { + SmallVector lhs; + bool lhs_min = false; + bool lhs_max = false; + if (!handleMinMax(cmpi.getLhs(), lhs, lhs_min, lhs_max)) { LLVM_DEBUG(llvm::dbgs() << "illegal lhs: " << cmpi.getLhs() << " - " << cmpi << "\n"); return false; } - if (!isValidIndex(cmpi.getRhs())) { + assert(lhs.size()); + SmallVector rhs; + bool rhs_min = false; + bool rhs_max = false; + if (!handleMinMax(cmpi.getRhs(), rhs, rhs_min, rhs_max)) { LLVM_DEBUG(llvm::dbgs() << "illegal rhs: " << cmpi.getRhs() << " - " << cmpi << "\n"); return false; } - SmallVector lhspack = {cmpi.getLhs()}; - if (!lhspack[0].getType().isa()) { - auto op = b.create( - cmpi.getLoc(), IndexType::get(cmpi.getContext()), lhspack[0]); - lhspack[0] = op; - } + assert(rhs.size()); + for (auto &lhspack : lhs) + if (!lhspack.getType().isa()) { + lhspack = b.create( + cmpi.getLoc(), IndexType::get(cmpi.getContext()), lhspack); + } - AffineMap rhsmap = - AffineMap::get(0, 1, getAffineSymbolExpr(0, cmpi.getContext())); - SmallVector rhspack = {cmpi.getRhs()}; - if (!rhspack[0].getType().isa()) { - auto op = b.create( - cmpi.getLoc(), IndexType::get(cmpi.getContext()), rhspack[0]); - rhspack[0] = op; - } + for (auto &rhspack : rhs) + if (!rhspack.getType().isa()) { + rhspack = b.create( + cmpi.getLoc(), IndexType::get(cmpi.getContext()), rhspack); + } - applies.push_back( - b.create(cmpi.getLoc(), lhsmap, lhspack)); - applies.push_back( - b.create(cmpi.getLoc(), rhsmap, rhspack)); - AffineExpr dims[2] = {b.getAffineDimExpr(2 * exprs.size() + 0), - b.getAffineDimExpr(2 * exprs.size() + 1)}; switch (cmpi.getPredicate()) { - case CmpIPredicate::eq: - exprs.push_back(dims[0] - dims[1]); + case CmpIPredicate::eq: { + if (lhs_min || lhs_max || rhs_min || rhs_max) + return false; eqflags.push_back(true); - break; - case CmpIPredicate::sge: + applies.push_back(lhs[0]); + applies.push_back(rhs[0]); + AffineExpr dims[2] = {b.getAffineSymbolExpr(2 * exprs.size() + 0), + b.getAffineSymbolExpr(2 * exprs.size() + 1)}; exprs.push_back(dims[0] - dims[1]); - eqflags.push_back(false); - break; - - case CmpIPredicate::sle: - exprs.push_back(dims[1] - dims[0]); - eqflags.push_back(false); - break; + } break; - case CmpIPredicate::sgt: - exprs.push_back(dims[0] - dims[1] + 1); - eqflags.push_back(false); - break; + case CmpIPredicate::sge: + case CmpIPredicate::sgt: { + // if lhs >=? rhs + // if lhs is a min(a, b) both must be true and this is fine + // if lhs is a max(a, b) either may be true, and sets require and + // similarly if rhs is a max(), both must be true; + if (lhs_max || rhs_min) + return false; + for (auto lhspack : lhs) + for (auto rhspack : rhs) { + eqflags.push_back(false); + applies.push_back(lhspack); + applies.push_back(rhspack); + AffineExpr dims[2] = {b.getAffineSymbolExpr(2 * exprs.size() + 0), + b.getAffineSymbolExpr(2 * exprs.size() + 1)}; + auto expr = dims[0] - dims[1]; + if (cmpi.getPredicate() == CmpIPredicate::sgt) + expr = expr + 1; + exprs.push_back(expr); + } + } break; case CmpIPredicate::slt: - exprs.push_back(dims[1] - dims[0] - 1); - eqflags.push_back(false); - break; + case CmpIPredicate::sle: { + if (lhs_min || rhs_max) + return false; + for (auto lhspack : lhs) + for (auto rhspack : rhs) { + eqflags.push_back(false); + applies.push_back(lhspack); + applies.push_back(rhspack); + AffineExpr dims[2] = {b.getAffineSymbolExpr(2 * exprs.size() + 0), + b.getAffineSymbolExpr(2 * exprs.size() + 1)}; + auto expr = dims[1] - dims[0]; + if (cmpi.getPredicate() == CmpIPredicate::slt) + expr = expr - 1; + exprs.push_back(expr); + } + } break; case CmpIPredicate::ne: case CmpIPredicate::ult: @@ -857,10 +1103,16 @@ struct MoveLoadToAffine : public OpRewritePattern { auto memrefType = load.getMemRef().getType().cast(); int64_t rank = memrefType.getRank(); + // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. - auto map = rank ? rewriter.getMultiDimIdentityMap(rank) - : rewriter.getEmptyAffineMap(); + SmallVector dimExprs; + dimExprs.reserve(rank); + for (unsigned i = 0; i < rank; ++i) + dimExprs.push_back(rewriter.getAffineSymbolExpr(i)); + auto map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/rank, dimExprs, + rewriter.getContext()); + SmallVector operands = load.getIndices(); if (map.getNumInputs() != operands.size()) { @@ -868,7 +1120,7 @@ struct MoveLoadToAffine : public OpRewritePattern { llvm::errs() << " load: " << load << "\n"; } assert(map.getNumInputs() == operands.size()); - fully2ComposeAffineMapAndOperands(&map, &operands); + fully2ComposeAffineMapAndOperands(rewriter, &map, &operands); assert(map.getNumInputs() == operands.size()); canonicalizeMapAndOperands(&map, &operands); assert(map.getNumInputs() == operands.size()); @@ -891,13 +1143,18 @@ struct MoveStoreToAffine : public OpRewritePattern { auto memrefType = store.getMemRef().getType().cast(); int64_t rank = memrefType.getRank(); + // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. - auto map = rank ? rewriter.getMultiDimIdentityMap(rank) - : rewriter.getEmptyAffineMap(); + SmallVector dimExprs; + dimExprs.reserve(rank); + for (unsigned i = 0; i < rank; ++i) + dimExprs.push_back(rewriter.getAffineSymbolExpr(i)); + auto map = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/rank, dimExprs, + rewriter.getContext()); SmallVector operands = store.getIndices(); - fully2ComposeAffineMapAndOperands(&map, &operands); + fully2ComposeAffineMapAndOperands(rewriter, &map, &operands); canonicalizeMapAndOperands(&map, &operands); rewriter.create(store.getLoc(), store.getValueToStore(), @@ -934,7 +1191,7 @@ template struct AffineFixup : public OpRewritePattern { auto prevOperands = operands; assert(map.getNumInputs() == operands.size()); - fully2ComposeAffineMapAndOperands(&map, &operands); + fully2ComposeAffineMapAndOperands(rewriter, &map, &operands); assert(map.getNumInputs() == operands.size()); canonicalizeMapAndOperands(&map, &operands); assert(map.getNumInputs() == operands.size()); @@ -1014,11 +1271,11 @@ struct CanonicalieForBounds : public OpRewritePattern { // llvm::errs() << "*********\n"; // ubMap.dump(); - fully2ComposeAffineMapAndOperands(&lbMap, &lbOperands); + fully2ComposeAffineMapAndOperands(rewriter, &lbMap, &lbOperands); canonicalizeMapAndOperands(&lbMap, &lbOperands); lbMap = removeDuplicateExprs(lbMap); - fully2ComposeAffineMapAndOperands(&ubMap, &ubOperands); + fully2ComposeAffineMapAndOperands(rewriter, &ubMap, &ubOperands); canonicalizeMapAndOperands(&ubMap, &ubOperands); ubMap = removeDuplicateExprs(ubMap); @@ -1059,7 +1316,7 @@ struct CanonicalizIfBounds : public OpRewritePattern { // llvm::errs() << "*********\n"; // ubMap.dump(); - fully2ComposeIntegerSetAndOperands(&map, &operands); + fully2ComposeIntegerSetAndOperands(rewriter, &map, &operands); canonicalizeSetAndOperands(&map, &operands); // map(s). @@ -1109,8 +1366,8 @@ struct MoveIfToAffine : public OpRewritePattern { } auto iset = - IntegerSet::get(/*dim*/ 2 * exprs.size(), /*symbol*/ 0, exprs, eqflags); - fully2ComposeIntegerSetAndOperands(&iset, &applies); + IntegerSet::get(/*dim*/ 0, /*symbol*/ 2 * exprs.size(), exprs, eqflags); + fully2ComposeIntegerSetAndOperands(rewriter, &iset, &applies); canonicalizeSetAndOperands(&iset, &applies); AffineIfOp affineIfOp = rewriter.create(ifOp.getLoc(), types, iset, applies, diff --git a/lib/polygeist/Passes/BarrierRemovalContinuation.cpp b/lib/polygeist/Passes/BarrierRemovalContinuation.cpp index 075dc533bbe8..be511a5588d9 100644 --- a/lib/polygeist/Passes/BarrierRemovalContinuation.cpp +++ b/lib/polygeist/Passes/BarrierRemovalContinuation.cpp @@ -48,7 +48,7 @@ static bool hasImmediateBarriers(scf::ParallelOp op) { /// Wrap the bodies of all parallel ops with immediate barriers, i.e. the /// parallel ops that will persist after the partial loop-to-cfg conversion, /// into an execute region op. -static void wrapPersistingLoopBodies(FuncOp function) { +static void wrapPersistingLoopBodies(FunctionOpInterface function) { SmallVector loops; function.walk([&](scf::ParallelOp op) { if (hasImmediateBarriers(op)) @@ -71,7 +71,7 @@ static void wrapPersistingLoopBodies(FuncOp function) { } /// Convert SCF constructs except parallel ops with immediate barriers to a CFG. -static LogicalResult applyCFGConversion(FuncOp function) { +static LogicalResult applyCFGConversion(FunctionOpInterface function) { RewritePatternSet patterns(function.getContext()); populateSCFToControlFlowConversionPatterns(patterns); @@ -91,7 +91,7 @@ static LogicalResult applyCFGConversion(FuncOp function) { /// Convert SCF constructs except parallel loops with immediate barriers to a /// CFG after wrapping the bodies of such loops in an execute_region op so as to /// comply with the single-block requirement of the body. -static LogicalResult convertToCFG(FuncOp function) { +static LogicalResult convertToCFG(FunctionOpInterface function) { wrapPersistingLoopBodies(function); return applyCFGConversion(function); } @@ -111,7 +111,7 @@ static void splitBlocksWithBarrier(Region ®ion) { /// Split blocks with barriers into parts in the parallel ops of the given /// function. -static LogicalResult splitBlocksWithBarrier(FuncOp function) { +static LogicalResult splitBlocksWithBarrier(FunctionOpInterface function) { WalkResult result = function.walk([](scf::ParallelOp op) -> WalkResult { if (!hasImmediateBarriers(op)) return success(); @@ -586,11 +586,12 @@ static void createContinuations(scf::ParallelOp parallel, Value storage) { parallel.erase(); } -static void createContinuations(FuncOp func) { - if (func->getNumRegions() == 0 || func.body().empty()) +static void createContinuations(FunctionOpInterface func) { + if (func->getNumRegions() == 0 || func.getBody().empty()) return; - OpBuilder allocaBuilder(&func.body().front(), func.body().front().begin()); + OpBuilder allocaBuilder(&func.getBody().front(), + func.getBody().front().begin()); func.walk([&](scf::ParallelOp parallel) { // Ignore parallel ops with no barriers. if (!hasImmediateBarriers(parallel)) @@ -605,7 +606,7 @@ namespace { struct BarrierRemoval : public SCFBarrierRemovalContinuationBase { void runOnOperation() override { - FuncOp f = getOperation(); + auto f = getOperation(); if (failed(convertToCFG(f))) return; if (failed(splitBlocksWithBarrier(f))) diff --git a/lib/polygeist/Passes/CanonicalizeFor.cpp b/lib/polygeist/Passes/CanonicalizeFor.cpp index 477ea9153c2e..edd9db025c64 100644 --- a/lib/polygeist/Passes/CanonicalizeFor.cpp +++ b/lib/polygeist/Passes/CanonicalizeFor.cpp @@ -215,6 +215,36 @@ struct RemoveUnusedArgs : public OpRewritePattern { } }; +struct ReplaceRedundantArgs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ForOp op, + PatternRewriter &rewriter) const override { + auto yieldOp = cast(op.getBody()->getTerminator()); + bool replaced = false; + unsigned i = 0; + for (auto blockArg : op.getRegionIterArgs()) { + for (unsigned j = 0; j < i; j++) { + if (op.getOperand(op.getNumControlOperands() + i) == + op.getOperand(op.getNumControlOperands() + j) && + yieldOp.getOperand(i) == yieldOp.getOperand(j)) { + + rewriter.updateRootInPlace(op, [&] { + op.getResult(i).replaceAllUsesWith(op.getResult(j)); + blockArg.replaceAllUsesWith(op.getRegionIterArgs()[j]); + }); + replaced = true; + goto skip; + } + } + skip: + i++; + } + + return success(replaced); + } +}; + /* +struct RemoveNotIf : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; @@ -343,79 +373,82 @@ yop2.results()[idx]); +}; */ -bool isWhile(WhileOp wop) { - bool hasCondOp = false; - wop.getBefore().walk([&](Operation *op) { - if (isa(op)) - hasCondOp = true; - }); - return hasCondOp; +bool isTopLevelArgValue(Value value, Region *region) { + if (auto arg = value.dyn_cast()) + return arg.getParentRegion() == region; + return false; } -struct MoveWhileToFor : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - bool isTopLevelArgValue(Value value, Region *region) const { - if (auto arg = value.dyn_cast()) - return arg.getParentRegion() == region; - return false; - } - - bool isBlockArg(Value value) const { - if (auto arg = value.dyn_cast()) - return true; - return false; - } +bool isBlockArg(Value value) { + if (auto arg = value.dyn_cast()) + return true; + return false; +} - bool dominateWhile(Value value, WhileOp loop) const { - Operation *op = value.getDefiningOp(); - assert(op && "expect non-null"); - DominanceInfo dom(loop); - return dom.properlyDominates(op, loop); - } +bool dominateWhile(Value value, WhileOp loop) { + Operation *op = value.getDefiningOp(); + assert(op && "expect non-null"); + DominanceInfo dom(loop); + return dom.properlyDominates(op, loop); +} - bool canMoveOpOutsideWhile(Operation *op, WhileOp loop) const { - DominanceInfo dom(loop); - for (auto operand : op->getOperands()) { - if (!dom.properlyDominates(operand, loop)) - return false; - } - return true; +bool canMoveOpOutsideWhile(Operation *op, WhileOp loop) { + DominanceInfo dom(loop); + for (auto operand : op->getOperands()) { + if (!dom.properlyDominates(operand, loop)) + return false; } + return true; +} - LogicalResult matchAndRewrite(WhileOp loop, - PatternRewriter &rewriter) const override { - Value step = nullptr; - - struct LoopInfo { - Value ub = nullptr; - Value lb = nullptr; - } loopInfo; +struct WhileToForHelper { + WhileOp loop; + CmpIOp cmpIOp; + Value step; + Value lb; + bool lb_addOne; + Value ub; + bool ub_addOne; + bool ub_cloneMove; + bool negativeStep; + AddIOp addIOp; + BlockArgument indVar; + size_t afterArgIdx; + bool computeLegality(bool sizeCheck, Value lookThrough = nullptr) { + step = nullptr; + lb = nullptr; + lb_addOne = false; + ub = nullptr; + ub_addOne = false; + ub_cloneMove = false; + negativeStep = false; auto condOp = loop.getConditionOp(); - SmallVector results = {condOp.getArgs()}; - auto cmpIOp = condOp.getCondition().getDefiningOp(); - if (!cmpIOp) { - return failure(); - } - - BlockArgument indVar = cmpIOp.getLhs().dyn_cast(); + indVar = cmpIOp.getLhs().dyn_cast(); Type extType = nullptr; // todo handle ext if (auto ext = cmpIOp.getLhs().getDefiningOp()) { indVar = ext.getIn().dyn_cast(); extType = ext.getType(); } - if (!indVar) - return failure(); - if (indVar.getOwner() != &loop.getBefore().front()) - return failure(); + // Condition is not the same as an induction variable + { + if (!indVar) { + return false; + } - size_t size = loop.getBefore().front().getOperations().size(); - if (extType) - size--; - if (size != 2) { - return failure(); + if (indVar.getOwner() != &loop.getBefore().front()) + return false; + } + + // Before region contains more than just the comparison + if (sizeCheck) { + size_t size = loop.getBefore().front().getOperations().size(); + if (extType) + size--; + if (size != 2) { + return false; + } } SmallVector afterArgs; @@ -426,34 +459,58 @@ struct MoveWhileToFor : public OpRewritePattern { auto endYield = cast(loop.getAfter().back().getTerminator()); - auto addIOp = + // Check that the block argument is actually an induction var: + // Namely, its next value adds to the previous with an invariant step. + addIOp = endYield.getResults()[indVar.getArgNumber()].getDefiningOp(); - if (!addIOp) - return failure(); + if (!addIOp) { + if (auto ifOp = endYield.getResults()[indVar.getArgNumber()] + .getDefiningOp()) { + if (ifOp.getCondition() == lookThrough) { + for (auto r : llvm::enumerate(ifOp.getResults())) { + if (r.value() == endYield.getResults()[indVar.getArgNumber()]) { + addIOp = ifOp.thenYield() + .getOperand(r.index()) + .getDefiningOp(); + break; + } + } + } + } else if (auto selOp = endYield.getResults()[indVar.getArgNumber()] + .getDefiningOp()) { + if (selOp.getCondition() == lookThrough) + addIOp = selOp.getTrueValue().getDefiningOp(); + } + } + if (!addIOp) { + return false; + } for (auto afterArg : afterArgs) { auto arg = loop.getAfter().getArgument(afterArg); if (addIOp.getOperand(0) == arg) { step = addIOp.getOperand(1); + afterArgIdx = afterArg; break; } if (addIOp.getOperand(1) == arg) { step = addIOp.getOperand(0); + afterArgIdx = afterArg; break; } } if (!step) - return failure(); + return false; // Cannot transform for if step is not loop-invariant if (auto op = step.getDefiningOp()) { if (loop->isAncestor(op)) { - return failure(); + return false; } } - bool negativeStep = false; + negativeStep = false; if (auto cop = step.getDefiningOp()) { if (cop.value() < 0) { negativeStep = true; @@ -464,63 +521,74 @@ struct MoveWhileToFor : public OpRewritePattern { } if (!negativeStep) - loopInfo.lb = loop.getOperand(indVar.getArgNumber()); - else - loopInfo.ub = rewriter.create( - loop.getLoc(), loop.getOperand(indVar.getArgNumber()), - rewriter.create(loop.getLoc(), 1, indVar.getType())); + lb = loop.getOperand(indVar.getArgNumber()); + else { + ub = loop.getOperand(indVar.getArgNumber()); + ub_addOne = true; + } if (isBlockArg(cmpIOp.getRhs()) || dominateWhile(cmpIOp.getRhs(), loop)) { switch (cmpIOp.getPredicate()) { case CmpIPredicate::slt: case CmpIPredicate::ult: { - loopInfo.ub = cmpIOp.getRhs(); + ub = cmpIOp.getRhs(); break; } case CmpIPredicate::ule: case CmpIPredicate::sle: { - // TODO: f32 likely not always true. - auto one = rewriter.create(loop.getLoc(), 1, - cmpIOp.getRhs().getType()); - auto addIOp = - rewriter.create(loop.getLoc(), cmpIOp.getRhs(), one); - loopInfo.ub = addIOp.getResult(); + ub = cmpIOp.getRhs(); + ub_addOne = true; break; } case CmpIPredicate::uge: case CmpIPredicate::sge: { - loopInfo.lb = cmpIOp.getRhs(); + lb = cmpIOp.getRhs(); break; } case CmpIPredicate::ugt: case CmpIPredicate::sgt: { - // TODO: f32 likely not always true. - auto one = rewriter.create(loop.getLoc(), 1, - cmpIOp.getRhs().getType()); - auto addIOp = - rewriter.create(loop.getLoc(), cmpIOp.getRhs(), one); - loopInfo.lb = addIOp.getResult(); + lb = cmpIOp.getRhs(); + lb_addOne = true; break; } case CmpIPredicate::eq: case CmpIPredicate::ne: { - return failure(); + return false; } } } else { if (negativeStep) - return failure(); + return false; auto *op = cmpIOp.getRhs().getDefiningOp(); if (!op || !canMoveOpOutsideWhile(op, loop) || (op->getNumResults() != 1)) - return failure(); - auto newOp = rewriter.clone(*op); - loopInfo.ub = newOp->getResult(0); - cmpIOp.getRhs().replaceAllUsesWith(newOp->getResult(0)); + return false; + ub = cmpIOp.getRhs(); + ub_cloneMove = true; } - if ((!loopInfo.ub) || (!loopInfo.lb) || (!step)) - return failure(); + return lb && ub; + } + + void prepareFor(PatternRewriter &rewriter) { + Value one; + if (lb_addOne) { + Value one = + rewriter.create(loop.getLoc(), 1, lb.getType()); + lb = rewriter.create(loop.getLoc(), lb, one); + } + if (ub_cloneMove) { + auto op = ub.getDefiningOp(); + assert(op); + auto newOp = rewriter.clone(*op); + rewriter.replaceOp(op, newOp->getResults()); + ub = newOp->getResult(0); + } + if (ub_addOne) { + Value one = + rewriter.create(loop.getLoc(), 1, ub.getType()); + ub = rewriter.create(loop.getLoc(), ub, one); + } if (negativeStep) { if (auto cop = step.getDefiningOp()) { @@ -531,12 +599,31 @@ struct MoveWhileToFor : public OpRewritePattern { } } - Value ub = rewriter.create( - loop.getLoc(), IndexType::get(loop.getContext()), loopInfo.ub); - Value lb = rewriter.create( - loop.getLoc(), IndexType::get(loop.getContext()), loopInfo.lb); + ub = rewriter.create(loop.getLoc(), + IndexType::get(loop.getContext()), ub); + lb = rewriter.create(loop.getLoc(), + IndexType::get(loop.getContext()), lb); step = rewriter.create( loop.getLoc(), IndexType::get(loop.getContext()), step); + } +}; + +struct MoveWhileToFor : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(WhileOp loop, + PatternRewriter &rewriter) const override { + auto condOp = loop.getConditionOp(); + SmallVector results = {condOp.getArgs()}; + WhileToForHelper helper; + helper.loop = loop; + helper.cmpIOp = condOp.getCondition().getDefiningOp(); + if (!helper.cmpIOp) { + return failure(); + } + if (!helper.computeLegality(/*sizeCheck*/ true)) + return failure(); + helper.prepareFor(rewriter); // input of the for goes the input of the scf::while plus the output taken // from the conditionOp. @@ -562,8 +649,8 @@ struct MoveWhileToFor : public OpRewritePattern { forArgs.push_back(res); } - auto forloop = - rewriter.create(loop.getLoc(), lb, ub, step, forArgs); + auto forloop = rewriter.create(loop.getLoc(), helper.lb, + helper.ub, helper.step, forArgs); if (!forloop.getBody()->empty()) rewriter.eraseOp(forloop.getBody()->getTerminator()); @@ -612,6 +699,170 @@ struct MoveWhileToFor : public OpRewritePattern { } }; +// If and and with something is preventing creating a for +// move the and into the after body guarded by an if +struct MoveWhileAndDown : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(WhileOp loop, + PatternRewriter &rewriter) const override { + auto condOp = loop.getConditionOp(); + auto andIOp = condOp.getCondition().getDefiningOp(); + if (!andIOp) + return failure(); + for (int i = 0; i < 2; i++) { + WhileToForHelper helper; + helper.loop = loop; + helper.cmpIOp = andIOp->getOperand(i).getDefiningOp(); + if (!helper.cmpIOp) + continue; + + YieldOp oldYield = cast(loop.getAfter().front().getTerminator()); + + Value extraCmp = andIOp->getOperand(1 - i); + Value lookThrough = nullptr; + if (auto BA = extraCmp.dyn_cast()) { + lookThrough = oldYield.getOperand(BA.getArgNumber()); + } + if (!helper.computeLegality(/*sizeCheck*/ false, lookThrough)) + continue; + + SmallVector origBeforeArgs( + loop.getBeforeArguments().begin(), loop.getBeforeArguments().end()); + + SmallVector origAfterArgs( + loop.getAfterArguments().begin(), loop.getAfterArguments().end()); + + BlockAndValueMapping preMap; + for (auto tup : llvm::zip(origBeforeArgs, loop.getInits())) + preMap.map(std::get<0>(tup), std::get<1>(tup)); + for (auto &op : loop.getBefore().front()) { + if (&op == condOp) + break; + preMap.map(op.getResults(), rewriter.clone(op, preMap)->getResults()); + } + IfOp unroll = rewriter.create(loop.getLoc(), loop.getResultTypes(), + preMap.lookup(condOp.getCondition())); + + if (unroll.getThenRegion().getBlocks().size()) + rewriter.eraseBlock(unroll.thenBlock()); + rewriter.createBlock(&unroll.getThenRegion()); + rewriter.createBlock(&unroll.getElseRegion()); + + rewriter.setInsertionPointToEnd(unroll.elseBlock()); + SmallVector unrollYield; + for (auto v : condOp.getArgs()) + unrollYield.push_back(preMap.lookup(v)); + rewriter.create(loop.getLoc(), unrollYield); + rewriter.setInsertionPointToEnd(unroll.thenBlock()); + + SmallVector nextInits(unrollYield.begin(), unrollYield.end()); + Value falsev = + rewriter.create(loop.getLoc(), 0, extraCmp.getType()); + Value truev = + rewriter.create(loop.getLoc(), 1, extraCmp.getType()); + nextInits.push_back(truev); + nextInits.push_back(loop.getInits()[helper.indVar.getArgNumber()]); + + SmallVector resTys; + for (auto a : nextInits) + resTys.push_back(a.getType()); + + auto nop = rewriter.create(loop.getLoc(), resTys, nextInits); + rewriter.createBlock(&nop.getBefore()); + SmallVector newBeforeYieldArgs; + for (auto a : origAfterArgs) { + auto arg = nop.getBefore().addArgument(a.getType(), a.getLoc()); + newBeforeYieldArgs.push_back(arg); + } + Value notExited = nop.getBefore().front().addArgument(extraCmp.getType(), + loop.getLoc()); + newBeforeYieldArgs.push_back(notExited); + + Value trueInd = nop.getBefore().front().addArgument( + helper.indVar.getType(), loop.getLoc()); + newBeforeYieldArgs.push_back(trueInd); + + { + BlockAndValueMapping postMap; + postMap.map(helper.indVar, trueInd); + auto newCmp = cast(rewriter.clone(*helper.cmpIOp, postMap)); + rewriter.create(condOp.getLoc(), newCmp, + newBeforeYieldArgs); + } + + rewriter.createBlock(&nop.getAfter()); + SmallVector postElseYields; + for (auto a : origAfterArgs) { + auto arg = nop.getAfter().front().addArgument(a.getType(), a.getLoc()); + postElseYields.push_back(arg); + a.replaceAllUsesWith(arg); + } + SmallVector resultTypes(loop.getResultTypes()); + resultTypes.push_back(notExited.getType()); + notExited = nop.getAfter().front().addArgument(notExited.getType(), + loop.getLoc()); + + trueInd = + nop.getAfter().front().addArgument(trueInd.getType(), loop.getLoc()); + + IfOp guard = rewriter.create(loop.getLoc(), resultTypes, notExited); + if (guard.getThenRegion().getBlocks().size()) + rewriter.eraseBlock(guard.thenBlock()); + Block *post = rewriter.splitBlock(&loop.getAfter().front(), + loop.getAfter().front().begin()); + rewriter.createBlock(&guard.getThenRegion()); + rewriter.createBlock(&guard.getElseRegion()); + rewriter.mergeBlocks(post, guard.thenBlock()); + + { + BlockAndValueMapping postMap; + for (auto tup : llvm::zip(origBeforeArgs, oldYield.getOperands())) { + postMap.map(std::get<0>(tup), std::get<1>(tup)); + } + rewriter.setInsertionPoint(oldYield); + for (auto &op : loop.getBefore().front()) { + if (&op == condOp) + break; + postMap.map(op.getResults(), + rewriter.clone(op, postMap)->getResults()); + } + SmallVector postIfYields; + for (auto a : condOp.getArgs()) { + postIfYields.push_back(postMap.lookup(a)); + } + postIfYields.push_back(postMap.lookup(extraCmp)); + oldYield->setOperands(postIfYields); + } + + rewriter.setInsertionPointToEnd(guard.elseBlock()); + postElseYields.push_back(falsev); + rewriter.create(loop.getLoc(), postElseYields); + + rewriter.setInsertionPointToEnd(&nop.getAfter().front()); + SmallVector postAfter(guard.getResults()); + BlockAndValueMapping postMap; + postMap.map(helper.indVar, trueInd); + postMap.map(postElseYields[helper.afterArgIdx], trueInd); + assert(helper.addIOp.getLhs() == postElseYields[helper.afterArgIdx] || + helper.addIOp.getRhs() == postElseYields[helper.afterArgIdx]); + postAfter.push_back( + cast(rewriter.clone(*helper.addIOp, postMap))); + rewriter.create(loop.getLoc(), postAfter); + + rewriter.setInsertionPointToEnd(unroll.thenBlock()); + rewriter.create( + loop.getLoc(), nop.getResults().take_front(loop.getResults().size())); + + rewriter.replaceOp(loop, unroll.getResults()); + + return success(); + } + + return failure(); + } +}; + struct MoveWhileDown : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -926,14 +1177,7 @@ struct MoveWhileInvariantIfResult : public OpRewritePattern { if (!std::get<0>(pair).use_empty()) { if (auto ifOp = std::get<1>(pair).getDefiningOp()) { if (ifOp.getCondition() == term.getCondition()) { - ssize_t idx = -1; - for (auto tup : llvm::enumerate(ifOp.getResults())) { - if (tup.value() == std::get<1>(pair)) { - idx = tup.index(); - break; - } - } - assert(idx != -1); + auto idx = std::get<1>(pair).cast().getResultNumber(); Value returnWith = ifOp.elseYield().getResults()[idx]; if (!op.getBefore().isAncestor(returnWith.getParentRegion())) { rewriter.updateRootInPlace(op, [&] { @@ -942,6 +1186,17 @@ struct MoveWhileInvariantIfResult : public OpRewritePattern { changed = true; } } + } else if (auto selOp = + std::get<1>(pair).getDefiningOp()) { + if (selOp.getCondition() == term.getCondition()) { + Value returnWith = selOp.getFalseValue(); + if (!op.getBefore().isAncestor(returnWith.getParentRegion())) { + rewriter.updateRootInPlace(op, [&] { + std::get<0>(pair).replaceAllUsesWith(returnWith); + }); + changed = true; + } + } } } } @@ -1084,6 +1339,69 @@ struct WhileCmpOffset : public OpRewritePattern { } }; +/// Given a while loop which yields a select whose condition is +/// the same as the condition, remove the select. +struct RemoveWhileSelect : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(WhileOp loop, + PatternRewriter &rewriter) const override { + scf::ConditionOp term = + cast(loop.getBefore().front().getTerminator()); + + SmallVector origAfterArgs( + loop.getAfterArguments().begin(), loop.getAfterArguments().end()); + SmallVector newResults; + SmallVector newAfter; + SmallVector newYields; + bool changed = false; + for (auto pair : + llvm::zip(loop.getResults(), term.getArgs(), origAfterArgs)) { + auto selOp = std::get<1>(pair).getDefiningOp(); + if (!selOp || selOp.getCondition() != term.getCondition()) { + newResults.push_back(newYields.size()); + newAfter.push_back(newYields.size()); + newYields.push_back(std::get<1>(pair)); + continue; + } + newResults.push_back(newYields.size()); + newYields.push_back(selOp.getFalseValue()); + newAfter.push_back(newYields.size()); + newYields.push_back(selOp.getTrueValue()); + changed = true; + } + if (!changed) + return failure(); + + SmallVector resultTypes; + for (auto v : newYields) { + resultTypes.push_back(v.getType()); + } + auto nop = + rewriter.create(loop.getLoc(), resultTypes, loop.getInits()); + + nop.getBefore().takeBody(loop.getBefore()); + + auto after = rewriter.createBlock(&nop.getAfter()); + for (auto y : newYields) + after->addArgument(y.getType(), loop.getLoc()); + + SmallVector replacedArgs; + for (auto idx : newAfter) + replacedArgs.push_back(after->getArgument(idx)); + rewriter.mergeBlocks(&loop.getAfter().front(), after, replacedArgs); + + SmallVector replacedReturns; + for (auto idx : newResults) + replacedReturns.push_back(nop.getResult(idx)); + rewriter.replaceOp(loop, replacedReturns); + rewriter.setInsertionPoint(term); + rewriter.replaceOpWithNewOp(term, term.getCondition(), + newYields); + return success(); + } +}; + struct MoveWhileDown3 : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -1408,17 +1726,20 @@ struct ReturnSq : public OpRewritePattern { return success(changed); } }; + void CanonicalizeFor::runOnOperation() { mlir::RewritePatternSet rpl(getOperation()->getContext()); rpl.add(getOperation()->getContext()); + MoveWhileAndDown, MoveWhileDown3, MoveWhileInvariantIfResult, + WhileLogicalNegation, SubToAdd, WhileCmpOffset, WhileLICM, + RemoveUnusedCondVar, ReturnSq, MoveSideEffectFreeWhile>( + getOperation()->getContext()); GreedyRewriteConfig config; config.maxIterations = 47; (void)applyPatternsAndFoldGreedily(getOperation(), std::move(rpl), config); diff --git a/lib/polygeist/Passes/ConvertPolygeistToLLVM.cpp b/lib/polygeist/Passes/ConvertPolygeistToLLVM.cpp index 0289cfb44353..47587f2cd0bf 100644 --- a/lib/polygeist/Passes/ConvertPolygeistToLLVM.cpp +++ b/lib/polygeist/Passes/ConvertPolygeistToLLVM.cpp @@ -14,13 +14,13 @@ #include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -386,7 +386,7 @@ struct ConvertPolygeistToLLVMPass populateSCFToControlFlowConversionPatterns(patterns); cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); populateMemRefToLLVMConversionPatterns(converter, patterns); - populateStdToLLVMConversionPatterns(converter, patterns); + populateFuncToLLVMConversionPatterns(converter, patterns); populateMathToLLVMConversionPatterns(converter, patterns); populateOpenMPToLLVMConversionPatterns(converter, patterns); arith::populateArithmeticToLLVMConversionPatterns(converter, patterns); diff --git a/lib/polygeist/Passes/Mem2Reg.cpp b/lib/polygeist/Passes/Mem2Reg.cpp index 31327a510da3..00dc3bb5d049 100644 --- a/lib/polygeist/Passes/Mem2Reg.cpp +++ b/lib/polygeist/Passes/Mem2Reg.cpp @@ -33,6 +33,7 @@ #include #include "polygeist/Ops.h" +#include "polygeist/Passes/Utils.h" #define DEBUG_TYPE "mem2reg" @@ -174,6 +175,7 @@ class ReplacementHandler { ValueOrPlaceholder *get(Value val); ValueOrPlaceholder *get(Block *val); ValueOrPlaceholder *get(scf::IfOp val, ValueOrPlaceholder *ifVal); + ValueOrPlaceholder *get(AffineIfOp val, ValueOrPlaceholder *ifVal); ValueOrPlaceholder *get(scf::ExecuteRegionOp val); void replaceValue(Value orig, Value post); @@ -191,7 +193,7 @@ class ValueOrPlaceholder { Value val; Block *valueAtStart; scf::ExecuteRegionOp exOp; - scf::IfOp ifOp; + Operation *ifOp; ValueOrPlaceholder(ValueOrPlaceholder &&) = delete; ValueOrPlaceholder(const ValueOrPlaceholder &) = delete; ValueOrPlaceholder(std::nullptr_t, ReplacementHandler &metaMap) @@ -220,6 +222,14 @@ class ValueOrPlaceholder { if (ifLastVal) metaMap.opOperands[ifOp] = ifLastVal; } + ValueOrPlaceholder(AffineIfOp ifOp, ReplaceableUse ifLastVal, + ReplacementHandler &metaMap) + : metaMap(metaMap), overwritten(false), val(nullptr), + valueAtStart(nullptr), exOp(nullptr), ifOp(ifOp) { + assert(ifOp); + if (ifLastVal) + metaMap.opOperands[ifOp] = ifLastVal; + } // Return true if this represents a full expression if all block argsare // defined at start Append the list of blocks requiring definition to block. bool definedWithArg(SmallPtrSetImpl &block) { @@ -237,26 +247,50 @@ class ValueOrPlaceholder { return true; } if (ifOp) { - auto thenFind = metaMap.valueAtEndOfBlock.find(ifOp.thenBlock()); - assert(thenFind != metaMap.valueAtEndOfBlock.end()); - assert(thenFind->second); - if (!thenFind->second->definedWithArg(block)) - return false; - - if (ifOp.getElseRegion().getBlocks().size()) { - auto elseFind = metaMap.valueAtEndOfBlock.find(ifOp.elseBlock()); - assert(elseFind != metaMap.valueAtEndOfBlock.end()); - assert(elseFind->second); - if (!elseFind->second->definedWithArg(block)) + if (auto sifOp = dyn_cast(ifOp)) { + auto thenFind = metaMap.valueAtEndOfBlock.find(getThenBlock(sifOp)); + assert(thenFind != metaMap.valueAtEndOfBlock.end()); + assert(thenFind->second); + if (!thenFind->second->definedWithArg(block)) return false; + + if (hasElse(sifOp)) { + auto elseFind = metaMap.valueAtEndOfBlock.find(getElseBlock(sifOp)); + assert(elseFind != metaMap.valueAtEndOfBlock.end()); + assert(elseFind->second); + if (!elseFind->second->definedWithArg(block)) + return false; + } else { + auto opFound = metaMap.opOperands.find(sifOp); + assert(opFound != metaMap.opOperands.end()); + auto ifLastValue = opFound->second; + if (!ifLastValue->definedWithArg(block)) + return false; + } + return true; } else { - auto opFound = metaMap.opOperands.find(ifOp); - assert(opFound != metaMap.opOperands.end()); - auto ifLastValue = opFound->second; - if (!ifLastValue->definedWithArg(block)) + auto aifOp = cast(ifOp); + auto thenFind = metaMap.valueAtEndOfBlock.find(getThenBlock(aifOp)); + assert(thenFind != metaMap.valueAtEndOfBlock.end()); + assert(thenFind->second); + if (!thenFind->second->definedWithArg(block)) return false; + + if (hasElse(aifOp)) { + auto elseFind = metaMap.valueAtEndOfBlock.find(getElseBlock(aifOp)); + assert(elseFind != metaMap.valueAtEndOfBlock.end()); + assert(elseFind->second); + if (!elseFind->second->definedWithArg(block)) + return false; + } else { + auto opFound = metaMap.opOperands.find(ifOp); + assert(opFound != metaMap.opOperands.end()); + auto ifLastValue = opFound->second; + if (!ifLastValue->definedWithArg(block)) + return false; + } + return true; } - return true; } if (exOp) { for (auto &B : exOp.getRegion()) { @@ -402,8 +436,17 @@ class ValueOrPlaceholder { this->exOp = nullptr; return this->val; } + Value materializeIf(bool full = true) { - auto thenFind = metaMap.valueAtEndOfBlock.find(ifOp.thenBlock()); + if (auto sop = dyn_cast(ifOp)) + return materializeIf(sop, full); + return materializeIf(cast(ifOp), + full); + } + + template + Value materializeIf(IfType ifOp, bool full = true) { + auto thenFind = metaMap.valueAtEndOfBlock.find(getThenBlock(ifOp)); assert(thenFind != metaMap.valueAtEndOfBlock.end()); assert(thenFind->second); Value thenVal = thenFind->second->materialize(full); @@ -423,8 +466,8 @@ class ValueOrPlaceholder { } Value elseVal; - if (ifOp.getElseRegion().getBlocks().size()) { - auto elseFind = metaMap.valueAtEndOfBlock.find(ifOp.elseBlock()); + if (hasElse(ifOp)) { + auto elseFind = metaMap.valueAtEndOfBlock.find(getElseBlock(ifOp)); assert(elseFind != metaMap.valueAtEndOfBlock.end()); assert(elseFind->second); elseVal = elseFind->second->materialize(full); @@ -464,10 +507,10 @@ class ValueOrPlaceholder { return thenVal; } - if (ifOp.getElseRegion().getBlocks().size()) { + if (hasElse(ifOp)) { for (auto tup : llvm::reverse( - llvm::zip(ifOp.getResults(), ifOp.thenYield().getOperands(), - ifOp.elseYield().getOperands()))) { + llvm::zip(ifOp.getResults(), getThenYield(ifOp).getOperands(), + getElseYield(ifOp).getOperands()))) { if (std::get<1>(tup) == thenVal && std::get<2>(tup) == elseVal) { return thenVal; } @@ -479,25 +522,24 @@ class ValueOrPlaceholder { SmallVector tys(ifOp.getResultTypes().begin(), ifOp.getResultTypes().end()); tys.push_back(thenVal.getType()); - auto nextIf = B.create( - ifOp.getLoc(), tys, ifOp.getCondition(), /*hasElse*/ true); + auto nextIf = cloneWithoutResults(ifOp, B, {}, tys); - SmallVector thenVals = ifOp.thenYield().getResults(); + SmallVector thenVals = getThenYield(ifOp).getOperands(); thenVals.push_back(thenVal); - nextIf.getThenRegion().takeBody(ifOp.getThenRegion()); - nextIf.thenYield()->setOperands(thenVals); + getThenRegion(nextIf).takeBody(getThenRegion(ifOp)); + getThenYield(nextIf)->setOperands(thenVals); - if (ifOp.getElseRegion().getBlocks().size()) { - nextIf.getElseRegion().getBlocks().clear(); - SmallVector elseVals = ifOp.elseYield().getResults(); + if (hasElse(ifOp)) { + getElseRegion(nextIf).getBlocks().clear(); + SmallVector elseVals = getElseYield(ifOp).getOperands(); elseVals.push_back(elseVal); - nextIf.getElseRegion().takeBody(ifOp.getElseRegion()); - nextIf.elseYield()->setOperands(elseVals); + getElseRegion(nextIf).takeBody(getElseRegion(ifOp)); + getElseYield(nextIf)->setOperands(elseVals); } else { - B.setInsertionPoint(&nextIf.getElseRegion().back(), - nextIf.getElseRegion().back().begin()); + B.setInsertionPoint(&getElseRegion(nextIf).back(), + getElseRegion(nextIf).back().begin()); SmallVector elseVals = {elseVal}; - B.create(ifOp.getLoc(), elseVals); + B.create(ifOp.getLoc(), elseVals); } SmallVector resvals = nextIf.getResults(); @@ -517,6 +559,7 @@ class ValueOrPlaceholder { return this->val; } }; + static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, ValueOrPlaceholder &PH) { if (PH.overwritten) @@ -559,6 +602,12 @@ ValueOrPlaceholder *ReplacementHandler::get(scf::IfOp val, allocs.emplace_back(PH = new ValueOrPlaceholder(val, ifVal, *this)); return PH; } +ValueOrPlaceholder *ReplacementHandler::get(AffineIfOp val, + ValueOrPlaceholder *ifVal) { + ValueOrPlaceholder *PH; + allocs.emplace_back(PH = new ValueOrPlaceholder(val, ifVal, *this)); + return PH; +} ValueOrPlaceholder *ReplacementHandler::get(scf::ExecuteRegionOp val) { ValueOrPlaceholder *PH; allocs.emplace_back(PH = new ValueOrPlaceholder(val, *this)); @@ -1335,6 +1384,15 @@ bool Mem2Reg::forwardStoreToLoad(mlir::Value AI, std::vector idx, lastVal = metaMap.get(ifOp, lastVal); } continue; + } else if (auto ifOp = dyn_cast(a)) { + handleBlock(*ifOp.thenRegion().begin(), lastVal); + if (ifOp.elseRegion().getBlocks().size()) { + handleBlock(*ifOp.elseRegion().begin(), lastVal); + lastVal = metaMap.get(ifOp, emptyValue); + } else { + lastVal = metaMap.get(ifOp, lastVal); + } + continue; } LLVM_DEBUG(llvm::dbgs() << "erased store due to: " << *a << "\n"); @@ -1784,8 +1842,7 @@ StoreMap getLastStored(mlir::Value AI) { } void Mem2Reg::runOnOperation() { - // Only supports single block functions at the moment. - FuncOp f = getOperation(); + auto f = getOperation(); // Variable indicating that a memref has had a load removed // and or been deleted. Because there can be memrefs of @@ -1803,22 +1860,22 @@ void Mem2Reg::runOnOperation() { // Walk all load's and perform store to load forwarding. SmallVector toPromote; - f.walk([&](mlir::memref::AllocaOp AI) { + f->walk([&](mlir::memref::AllocaOp AI) { if (isPromotable(AI)) { toPromote.push_back(AI); } }); - f.walk([&](mlir::memref::AllocOp AI) { + f->walk([&](mlir::memref::AllocOp AI) { if (isPromotable(AI)) { toPromote.push_back(AI); } }); - f.walk([&](LLVM::AllocaOp AI) { + f->walk([&](LLVM::AllocaOp AI) { if (isPromotable(AI)) { toPromote.push_back(AI); } }); - f.walk([&](memref::GetGlobalOp AI) { + f->walk([&](memref::GetGlobalOp AI) { if (isPromotable(AI)) { toPromote.push_back(AI); } diff --git a/lib/polygeist/Passes/OpenMPOpt.cpp b/lib/polygeist/Passes/OpenMPOpt.cpp index 852ccec8b2cf..9c1bda86112f 100644 --- a/lib/polygeist/Passes/OpenMPOpt.cpp +++ b/lib/polygeist/Passes/OpenMPOpt.cpp @@ -38,6 +38,35 @@ struct OpenMPOpt : public OpenMPOptPassBase { /// omp.barrier /// codeB(); /// } +bool isReadOnly(Operation *op) { + bool hasRecursiveEffects = op->hasTrait(); + if (hasRecursiveEffects) { + for (Region ®ion : op->getRegions()) { + for (auto &block : region) { + for (auto &nestedOp : block) + if (!isReadOnly(&nestedOp)) + return false; + } + } + } + + // If the op has memory effects, try to characterize them to see if the op + // is trivially dead here. + if (auto effectInterface = dyn_cast(op)) { + // Check to see if this op either has no effects, or only allocates/reads + // memory. + SmallVector effects; + effectInterface.getEffects(effects); + if (!llvm::all_of(effects, [op](const MemoryEffects::EffectInstance &it) { + return isa(it.getEffect()); + })) { + return false; + } + return true; + } + return false; +} + struct CombineParallel : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -47,10 +76,47 @@ struct CombineParallel : public OpRewritePattern { if (nextParallel == &parent->front()) return failure(); - auto prevParallel = dyn_cast(nextParallel->getPrevNode()); - if (!prevParallel) + // Only attempt this if there is another parallel within the function, which + // is not contained within this operation. + bool noncontained = false; + nextParallel->getParentOfType()->walk([&](omp::ParallelOp other) { + if (!nextParallel->isAncestor(other)) { + noncontained = true; + } + }); + if (!noncontained) return failure(); + omp::ParallelOp prevParallel; + SmallVector prevOps; + + bool changed = false; + + for (Operation *prevOp = nextParallel->getPrevNode(); 1;) { + if (prevParallel = dyn_cast(prevOp)) { + break; + } + // We can move this into the parallel if it only reads + if (isReadOnly(prevOp) && + llvm::all_of(prevOp->getResults(), [&](Value v) { + return llvm::all_of(v.getUsers(), [&](Operation *user) { + return nextParallel->isAncestor(user); + }); + })) { + auto prevIter = + (prevOp == &parent->front()) ? nullptr : prevOp->getPrevNode(); + rewriter.setInsertionPointToStart(&nextParallel.getRegion().front()); + auto replacement = rewriter.clone(*prevOp); + rewriter.replaceOp(prevOp, replacement->getResults()); + changed = true; + if (!prevIter) + return success(); + prevOp = prevIter; + continue; + } + return success(changed); + } + rewriter.setInsertionPointToEnd(&prevParallel.getRegion().front()); rewriter.replaceOpWithNewOp( prevParallel.getRegion().front().getTerminator(), TypeRange()); @@ -90,9 +156,36 @@ struct ParallelForInterchange : public OpRewritePattern { } }; +struct ParallelIfInterchange : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(omp::ParallelOp nextParallel, + PatternRewriter &rewriter) const override { + Block *parent = nextParallel->getBlock(); + if (parent->getOperations().size() != 2) + return failure(); + + auto prevIf = dyn_cast(nextParallel->getParentOp()); + if (!prevIf || prevIf->getResults().size()) + return failure(); + + nextParallel->moveBefore(prevIf); + auto yield = nextParallel.getRegion().front().getTerminator(); + auto contents = + rewriter.splitBlock(&nextParallel.getRegion().front(), + nextParallel.getRegion().front().begin()); + rewriter.mergeBlockBefore(contents, &prevIf.getBody()->front()); + rewriter.setInsertionPointToEnd(&nextParallel.getRegion().front()); + auto newYield = rewriter.clone(*yield); + rewriter.eraseOp(yield); + prevIf->moveBefore(newYield); + return success(); + } +}; + void OpenMPOpt::runOnOperation() { mlir::RewritePatternSet rpl(getOperation()->getContext()); - rpl.add( + rpl.add( getOperation()->getContext()); GreedyRewriteConfig config; config.maxIterations = 47; diff --git a/lib/polygeist/Passes/ParallelLoopDistribute.cpp b/lib/polygeist/Passes/ParallelLoopDistribute.cpp index 6d3c983b3cd3..a2a847546ef3 100644 --- a/lib/polygeist/Passes/ParallelLoopDistribute.cpp +++ b/lib/polygeist/Passes/ParallelLoopDistribute.cpp @@ -26,6 +26,7 @@ #include "polygeist/BarrierUtils.h" #include "polygeist/Ops.h" #include "polygeist/Passes/Passes.h" +#include "polygeist/Passes/Utils.h" #include #define DEBUG_TYPE "cpuify" @@ -35,58 +36,6 @@ using namespace mlir; using namespace mlir::arith; using namespace polygeist; -scf::IfOp cloneWithoutResults(scf::IfOp op, PatternRewriter &rewriter, - BlockAndValueMapping mapping = {}) { - return rewriter.create(op.getLoc(), TypeRange(), - mapping.lookupOrDefault(op.getCondition()), - true); -} - -AffineIfOp cloneWithoutResults(AffineIfOp op, PatternRewriter &rewriter, - BlockAndValueMapping mapping = {}) { - SmallVector lower; - for (auto o : op.getOperands()) - lower.push_back(mapping.lookupOrDefault(o)); - return rewriter.create(op.getLoc(), TypeRange(), - op.getIntegerSet(), lower, true); -} - -scf::ForOp cloneWithoutResults(scf::ForOp op, PatternRewriter &rewriter, - BlockAndValueMapping mapping = {}) { - return rewriter.create( - op.getLoc(), mapping.lookupOrDefault(op.getLowerBound()), - mapping.lookupOrDefault(op.getUpperBound()), - mapping.lookupOrDefault(op.getStep())); -} -AffineForOp cloneWithoutResults(AffineForOp op, PatternRewriter &rewriter, - BlockAndValueMapping mapping = {}) { - SmallVector lower; - for (auto o : op.getLowerBoundOperands()) - lower.push_back(mapping.lookupOrDefault(o)); - SmallVector upper; - for (auto o : op.getUpperBoundOperands()) - upper.push_back(mapping.lookupOrDefault(o)); - return rewriter.create(op.getLoc(), lower, op.getLowerBoundMap(), - upper, op.getUpperBoundMap(), - op.getStep()); -} - -Block *getThenBlock(scf::IfOp op) { return op.thenBlock(); } -Block *getThenBlock(AffineIfOp op) { return op.getThenBlock(); } -Block *getElseBlock(scf::IfOp op) { return op.elseBlock(); } -Block *getElseBlock(AffineIfOp op) { return op.getElseBlock(); } -bool inBound(scf::IfOp op, Value v) { return op.getCondition() == v; } -bool inBound(AffineIfOp op, Value v) { - return llvm::any_of(op.getOperands(), [&](Value e) { return e == v; }); -} -bool inBound(scf::ForOp op, Value v) { return op.getUpperBound() == v; } -bool inBound(AffineForOp op, Value v) { - return llvm::any_of(op.getUpperBoundOperands(), - [&](Value e) { return e == v; }); -} -bool hasElse(scf::IfOp op) { return op.getElseRegion().getBlocks().size() > 0; } -bool hasElse(AffineIfOp op) { return op.elseRegion().getBlocks().size() > 0; } - static bool couldWrite(Operation *op) { if (auto iface = dyn_cast(op)) { SmallVector localEffects; @@ -397,9 +346,10 @@ static LogicalResult wrapWithBarriers( /// Puts a barrier before and/or after an "if" operation if there isn't already /// one, potentially with a single load that supplies the upper bound of a /// (normalized) loop. -struct WrapIfWithBarrier : public OpRewritePattern { - WrapIfWithBarrier(MLIRContext *ctx) : OpRewritePattern(ctx) {} - LogicalResult matchAndRewrite(scf::IfOp op, +template +struct WrapIfWithBarrier : public OpRewritePattern { + WrapIfWithBarrier(MLIRContext *ctx) : OpRewritePattern(ctx) {} + LogicalResult matchAndRewrite(IfType op, PatternRewriter &rewriter) const override { SmallVector vals; if (failed(canWrapWithBarriers(op, vals))) @@ -419,7 +369,7 @@ struct WrapIfWithBarrier : public OpRewritePattern { return wrapWithBarriers(op, rewriter, vals, [&](Operation *prevOp) { if (auto loadOp = dyn_cast_or_null(prevOp)) { - if (loadOp.result() == op.getCondition() && + if (inBound(op, loadOp.result()) && llvm::all_of(loadOp.indices(), [&](Value v) { return indVars.contains(v); })) { prevOp = prevOp->getPrevNode(); @@ -698,9 +648,9 @@ static void moveBodiesFor(PatternRewriter &rewriter, T op, ForType forLoop, rewriter.mergeBlockBefore(op.getBody(), &newParallel.getBody()->back(), newParallel.getBody()->getArguments()); rewriter.eraseOp(&newParallel.getBody()->back()); + rewriter.eraseOp(&forLoop.getBody()->back()); rewriter.mergeBlockBefore(forLoop.getBody(), &newParallel.getBody()->back(), newForLoop.getBody()->getArguments()); - rewriter.eraseOp(&newParallel.getBody()->back()); rewriter.eraseOp(op); rewriter.eraseOp(forLoop); } @@ -1032,7 +982,7 @@ struct DistributeAroundBarrier : public OpRewritePattern { LogicalResult splitSubLoop(T op, PatternRewriter &rewriter, BarrierOp barrier, SmallVector &iterCounts, T &preLoop, T &postLoop, Block *&outerBlock, T &outerLoop, - scf::ExecuteRegionOp &outerEx) const; + memref::AllocaScopeOp &outerEx) const; LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const override { @@ -1069,7 +1019,7 @@ struct DistributeAroundBarrier : public OpRewritePattern { Block *outerBlock; T outerLoop = nullptr; - scf::ExecuteRegionOp outerEx = nullptr; + memref::AllocaScopeOp outerEx = nullptr; if (splitSubLoop(op, rewriter, barrier, iterCounts, preLoop, postLoop, outerBlock, outerLoop, outerEx) @@ -1105,12 +1055,11 @@ struct DistributeAroundBarrier : public OpRewritePattern { DataLayout DLI(mod); for (Value v : crossing) { if (auto ao = v.getDefiningOp()) { - allocations.push_back(allocateTemporaryBuffer( - rewriter, v, iterCounts, true, &DLI) - .getResult(0)); + allocations.push_back(allocateTemporaryBuffer( + rewriter, v, iterCounts, true, &DLI)); } else { allocations.push_back( - allocateTemporaryBuffer(rewriter, v, iterCounts)); + allocateTemporaryBuffer(rewriter, v, iterCounts)); } } @@ -1196,14 +1145,6 @@ struct DistributeAroundBarrier : public OpRewritePattern { // Create the second loop. rewriter.setInsertionPointToEnd(outerBlock); - auto freefn = GetOrCreateFreeFunction(mod); - for (auto alloc : allocations) { - if (alloc.getType().isa()) { - Value args[1] = {alloc}; - rewriter.create(alloc.getLoc(), freefn, args); - } else - rewriter.create(alloc.getLoc(), alloc); - } if (outerLoop) { if (isa(outerLoop)) rewriter.create(op.getLoc()); @@ -1211,6 +1152,8 @@ struct DistributeAroundBarrier : public OpRewritePattern { assert(isa(outerLoop)); rewriter.create(op.getLoc()); } + } else { + rewriter.create(op.getLoc()); } // Recreate the operations in the new loop with new values. @@ -1228,15 +1171,6 @@ struct DistributeAroundBarrier : public OpRewritePattern { for (Operation *o : llvm::reverse(toDelete)) rewriter.eraseOp(o); - for (auto ao : allocations) - if (ao.getDefiningOp() || - ao.getDefiningOp()) - rewriter.eraseOp(ao.getDefiningOp()); - - if (!outerLoop) { - rewriter.mergeBlockBefore(outerBlock, op); - rewriter.eraseOp(outerEx); - } rewriter.eraseOp(op); LLVM_DEBUG(DBGS() << "[distribute] distributed around a barrier\n"); @@ -1248,7 +1182,7 @@ LogicalResult DistributeAroundBarrier::splitSubLoop( scf::ParallelOp op, PatternRewriter &rewriter, BarrierOp barrier, SmallVector &iterCounts, scf::ParallelOp &preLoop, scf::ParallelOp &postLoop, Block *&outerBlock, scf::ParallelOp &outerLoop, - scf::ExecuteRegionOp &outerEx) const { + memref::AllocaScopeOp &outerEx) const { SmallVector outerLower; SmallVector outerUpper; @@ -1280,7 +1214,7 @@ LogicalResult DistributeAroundBarrier::splitSubLoop( rewriter.eraseOp(&outerLoop.getBody()->back()); outerBlock = outerLoop.getBody(); } else { - outerEx = rewriter.create(op.getLoc(), TypeRange()); + outerEx = rewriter.create(op.getLoc(), TypeRange()); outerBlock = new Block(); outerEx.getRegion().push_back(outerBlock); } @@ -1307,7 +1241,7 @@ LogicalResult DistributeAroundBarrier::splitSubLoop( AffineParallelOp op, PatternRewriter &rewriter, BarrierOp barrier, SmallVector &iterCounts, AffineParallelOp &preLoop, AffineParallelOp &postLoop, Block *&outerBlock, AffineParallelOp &outerLoop, - scf::ExecuteRegionOp &outerEx) const { + memref::AllocaScopeOp &outerEx) const { SmallVector outerLower; SmallVector outerUpper; @@ -1343,7 +1277,7 @@ LogicalResult DistributeAroundBarrier::splitSubLoop( rewriter.eraseOp(&outerLoop.getBody()->back()); outerBlock = outerLoop.getBody(); } else { - outerEx = rewriter.create(op.getLoc(), TypeRange()); + outerEx = rewriter.create(op.getLoc(), TypeRange()); outerBlock = new Block(); outerEx.getRegion().push_back(outerBlock); } @@ -1431,13 +1365,22 @@ template struct Reg2MemFor : public OpRewritePattern { auto oldTerminator = op.getBody()->getTerminator(); rewriter.mergeBlockBefore(op.getBody(), newOp.getBody()->getTerminator(), newRegionArguments); + SmallVector oldOps; + llvm::append_range(oldOps, oldTerminator->getOperands()); + rewriter.eraseOp(oldTerminator); - rewriter.setInsertionPoint(newOp.getBody()->getTerminator()); - for (auto en : llvm::enumerate(oldTerminator->getResults())) { + Operation *IP = newOp.getBody()->getTerminator(); + while (IP != &IP->getBlock()->front()) { + if (isa(IP->getPrevNode())) { + IP = IP->getPrevNode(); + } + break; + } + rewriter.setInsertionPoint(IP); + for (auto en : llvm::enumerate(oldOps)) { rewriter.create(op.getLoc(), en.value(), allocated[en.index()], ValueRange()); } - rewriter.eraseOp(oldTerminator); rewriter.setInsertionPointAfter(op); SmallVector loaded; @@ -1593,8 +1536,8 @@ struct CPUifyPass : public SCFCPUifyBase { patterns.insert, Reg2MemFor, Reg2MemWhile, Reg2MemIf, Reg2MemIf, WrapForWithBarrier, WrapAffineForWithBarrier, - WrapIfWithBarrier, WrapWhileWithBarrier, - + WrapIfWithBarrier, + WrapIfWithBarrier, WrapWhileWithBarrier, InterchangeForPFor, InterchangeForPFor, InterchangeForPForLoad, @@ -1627,7 +1570,7 @@ struct CPUifyPass : public SCFCPUifyBase { signalPassFailure(); } else if (method == "omp") { SmallVector toReplace; - getOperation().walk( + getOperation()->walk( [&](polygeist::BarrierOp b) { toReplace.push_back(b); }); for (auto b : toReplace) { OpBuilder Builder(b); diff --git a/lib/polygeist/Passes/PassDetails.h b/lib/polygeist/Passes/PassDetails.h index 6b1f6484240b..b61dc810465c 100644 --- a/lib/polygeist/Passes/PassDetails.h +++ b/lib/polygeist/Passes/PassDetails.h @@ -22,6 +22,7 @@ #include "polygeist/Passes/Passes.h" namespace mlir { +class FunctionOpInterface; // Forward declaration from Dialect.h template void registerDialect(DialectRegistry ®istry); diff --git a/lib/polygeist/Passes/RaiseToAffine.cpp b/lib/polygeist/Passes/RaiseToAffine.cpp index 8d614f0780f4..0334017d2674 100644 --- a/lib/polygeist/Passes/RaiseToAffine.cpp +++ b/lib/polygeist/Passes/RaiseToAffine.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/SCF/SCF.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "polygeist/Passes/Passes.h" #include "llvm/Support/Debug.h" @@ -30,10 +31,11 @@ struct ForOpRaising : public OpRewritePattern { bool isAffine(scf::ForOp loop) const { // return true; // enforce step to be a ConstantIndexOp (maybe too restrictive). - return isa_and_nonnull(loop.getStep().getDefiningOp()); + return isValidSymbol(loop.getStep()); } - void canonicalizeLoopBounds(AffineForOp forOp) const { + void canonicalizeLoopBounds(PatternRewriter &rewriter, + AffineForOp forOp) const { SmallVector lbOperands(forOp.getLowerBoundOperands()); SmallVector ubOperands(forOp.getUpperBoundOperands()); @@ -42,11 +44,11 @@ struct ForOpRaising : public OpRewritePattern { auto prevLbMap = lbMap; auto prevUbMap = ubMap; - fully2ComposeAffineMapAndOperands(&lbMap, &lbOperands); + fully2ComposeAffineMapAndOperands(rewriter, &lbMap, &lbOperands); canonicalizeMapAndOperands(&lbMap, &lbOperands); lbMap = removeDuplicateExprs(lbMap); - fully2ComposeAffineMapAndOperands(&ubMap, &ubOperands); + fully2ComposeAffineMapAndOperands(rewriter, &ubMap, &ubOperands); canonicalizeMapAndOperands(&ubMap, &ubOperands); ubMap = removeDuplicateExprs(ubMap); @@ -58,29 +60,94 @@ struct ForOpRaising : public OpRewritePattern { int64_t getStep(mlir::Value value) const { ConstantIndexOp cstOp = value.getDefiningOp(); - assert(cstOp && "expect non-null operation"); - return cstOp.value(); + if (cstOp) + return cstOp.value(); + else + return 1; } + AffineMap getMultiSymbolIdentity(Builder &B, unsigned rank) const { + SmallVector dimExprs; + dimExprs.reserve(rank); + for (unsigned i = 0; i < rank; ++i) + dimExprs.push_back(B.getAffineSymbolExpr(i)); + return AffineMap::get(/*dimCount=*/0, /*symbolCount=*/rank, dimExprs, + B.getContext()); + } LogicalResult matchAndRewrite(scf::ForOp loop, PatternRewriter &rewriter) const final { if (isAffine(loop)) { OpBuilder builder(loop); - if (!isValidIndex(loop.getLowerBound())) { - return failure(); + SmallVector lbs; + { + SmallVector todo = {loop.getLowerBound()}; + while (todo.size()) { + auto cur = todo.back(); + todo.pop_back(); + if (isValidIndex(cur)) { + lbs.push_back(cur); + continue; + } else if (auto selOp = cur.getDefiningOp()) { + // LB only has max of operands + if (auto cmp = selOp.getCondition().getDefiningOp()) { + if (cmp.getLhs() == selOp.getTrueValue() && + cmp.getRhs() == selOp.getFalseValue() && + cmp.getPredicate() == CmpIPredicate::sge) { + todo.push_back(cmp.getLhs()); + todo.push_back(cmp.getRhs()); + continue; + } + } + } + return failure(); + } } - if (!isValidIndex(loop.getUpperBound())) { - return failure(); + SmallVector ubs; + { + SmallVector todo = {loop.getUpperBound()}; + while (todo.size()) { + auto cur = todo.back(); + todo.pop_back(); + if (isValidIndex(cur)) { + ubs.push_back(cur); + continue; + } else if (auto selOp = cur.getDefiningOp()) { + // UB only has min of operands + if (auto cmp = selOp.getCondition().getDefiningOp()) { + if (cmp.getLhs() == selOp.getTrueValue() && + cmp.getRhs() == selOp.getFalseValue() && + cmp.getPredicate() == CmpIPredicate::sle) { + todo.push_back(cmp.getLhs()); + todo.push_back(cmp.getRhs()); + continue; + } + } + } + return failure(); + } + } + + bool rewrittenStep = false; + if (!loop.getStep().getDefiningOp()) { + if (ubs.size() != 1 || lbs.size() != 1) + return failure(); + ubs[0] = rewriter.create( + loop.getLoc(), + rewriter.create(loop.getLoc(), loop.getUpperBound(), + loop.getLowerBound()), + loop.getStep()); + lbs[0] = rewriter.create(loop.getLoc(), 0); + rewrittenStep = true; } AffineForOp affineLoop = rewriter.create( - loop.getLoc(), loop.getLowerBound(), builder.getSymbolIdentityMap(), - loop.getUpperBound(), builder.getSymbolIdentityMap(), - getStep(loop.getStep()), loop.getIterOperands()); + loop.getLoc(), lbs, getMultiSymbolIdentity(builder, lbs.size()), ubs, + getMultiSymbolIdentity(builder, ubs.size()), getStep(loop.getStep()), + loop.getIterOperands()); - canonicalizeLoopBounds(affineLoop); + canonicalizeLoopBounds(rewriter, affineLoop); auto mergedYieldOp = cast(loop.getRegion().front().getTerminator()); @@ -94,16 +161,19 @@ struct ForOpRaising : public OpRewritePattern { rewriter.eraseOp(affineYieldOp); } - rewriter.updateRootInPlace(loop, [&] { - affineLoop.region().front().getOperations().splice( - affineLoop.region().front().getOperations().begin(), - loop.getRegion().front().getOperations()); - - for (auto pair : llvm::zip(affineLoop.region().front().getArguments(), - loop.getRegion().front().getArguments())) { - std::get<1>(pair).replaceAllUsesWith(std::get<0>(pair)); + SmallVector vals; + rewriter.setInsertionPointToStart(&affineLoop.region().front()); + for (Value arg : affineLoop.region().front().getArguments()) { + if (rewrittenStep && arg == affineLoop.getInductionVar()) { + arg = rewriter.create( + loop.getLoc(), loop.getLowerBound(), + rewriter.create(loop.getLoc(), arg, loop.getStep())); } - }); + vals.push_back(arg); + } + assert(vals.size() == loop.getRegion().front().getNumArguments()); + rewriter.mergeBlocks(&loop.getRegion().front(), + &affineLoop.region().front(), vals); rewriter.setInsertionPoint(mergedYieldOp); rewriter.create(mergedYieldOp.getLoc(), @@ -129,7 +199,8 @@ struct ParallelOpRaising : public OpRewritePattern { return true; } - void canonicalizeLoopBounds(AffineParallelOp forOp) const { + void canonicalizeLoopBounds(PatternRewriter &rewriter, + AffineParallelOp forOp) const { SmallVector lbOperands(forOp.getLowerBoundsOperands()); SmallVector ubOperands(forOp.getUpperBoundsOperands()); @@ -138,10 +209,10 @@ struct ParallelOpRaising : public OpRewritePattern { auto prevLbMap = lbMap; auto prevUbMap = ubMap; - fully2ComposeAffineMapAndOperands(&lbMap, &lbOperands); + fully2ComposeAffineMapAndOperands(rewriter, &lbMap, &lbOperands); canonicalizeMapAndOperands(&lbMap, &lbOperands); - fully2ComposeAffineMapAndOperands(&ubMap, &ubOperands); + fully2ComposeAffineMapAndOperands(rewriter, &ubMap, &ubOperands); canonicalizeMapAndOperands(&ubMap, &ubOperands); if (lbMap != prevLbMap) @@ -183,7 +254,7 @@ struct ParallelOpRaising : public OpRewritePattern { loop.getLowerBound(), bounds, loop.getUpperBound(), steps); //, loop.getInitVals()); - canonicalizeLoopBounds(affineLoop); + canonicalizeLoopBounds(rewriter, affineLoop); auto mergedYieldOp = cast(loop.getRegion().front().getTerminator()); @@ -197,16 +268,12 @@ struct ParallelOpRaising : public OpRewritePattern { rewriter.eraseOp(affineYieldOp); } - rewriter.updateRootInPlace(loop, [&] { - affineLoop.region().front().getOperations().splice( - affineLoop.region().front().getOperations().begin(), - loop.getRegion().front().getOperations()); - - for (auto pair : llvm::zip(affineLoop.region().front().getArguments(), - loop.getRegion().front().getArguments())) { - std::get<1>(pair).replaceAllUsesWith(std::get<0>(pair)); - } - }); + SmallVector vals; + for (Value arg : affineLoop.region().front().getArguments()) { + vals.push_back(arg); + } + rewriter.mergeBlocks(&loop.getRegion().front(), + &affineLoop.region().front(), vals); rewriter.setInsertionPoint(mergedYieldOp); rewriter.create(mergedYieldOp.getLoc(), @@ -220,15 +287,12 @@ struct ParallelOpRaising : public OpRewritePattern { }; void RaiseSCFToAffine::runOnOperation() { - ConversionTarget target(getContext()); - target.addLegalDialect(); - RewritePatternSet patterns(&getContext()); patterns.insert(&getContext()); - if (failed( - applyPartialConversion(getOperation(), target, std::move(patterns)))) - signalPassFailure(); + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); } namespace mlir { diff --git a/llvm-project b/llvm-project index 87ec6f41bba6..d8a6a696bfa1 160000 --- a/llvm-project +++ b/llvm-project @@ -1 +1 @@ -Subproject commit 87ec6f41bba6d72a3408e71cf19ae56feff523bc +Subproject commit d8a6a696bfa14f9d209cee2aa0442d7cbc11679c diff --git a/test/polygeist-opt/affinecfg.mlir b/test/polygeist-opt/affinecfg.mlir index 6ebd8310e3a0..39a6beb25199 100644 --- a/test/polygeist-opt/affinecfg.mlir +++ b/test/polygeist-opt/affinecfg.mlir @@ -13,8 +13,10 @@ module { } return } + } + // CHECK: func @_Z7runTestiPPc(%arg0: index, %arg1: memref) { // CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 // CHECK-NEXT: %c1 = arith.constant 1 : index @@ -26,3 +28,81 @@ module { // CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT: } + +// ----- +module { +func @kernel_nussinov(%arg0: i32, %arg2: memref) { + %c0 = arith.constant 0 : index + %true = arith.constant true + %c1_i32 = arith.constant 1 : i32 + %c59 = arith.constant 59 : index + %c100_i32 = arith.constant 100 : i32 + affine.for %arg3 = 0 to 60 { + %0 = arith.subi %c59, %arg3 : index + %1 = arith.index_cast %0 : index to i32 + %2 = arith.cmpi slt, %1, %c100_i32 : i32 + scf.if %2 { + affine.store %arg0, %arg2[] : memref + } + } + return +} +} + +// CHECK: #set = affine_set<(d0) : (d0 + 40 >= 0)> +// CHECK: func @kernel_nussinov(%arg0: i32, %arg1: memref) { +// CHECK-NEXT: affine.for %arg2 = 0 to 60 { +// CHECK-NEXT: affine.if #set(%arg2) { +// CHECK-NEXT: affine.store %arg0, %arg1[] : memref +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } + + +// ----- + +module { + func private @run() + + func @minif(%arg4: i32, %arg5 : i32, %arg10 : index) { + %c0_i32 = arith.constant 0 : i32 + + affine.for %i = 0 to 10 { + %70 = arith.index_cast %arg10 : index to i32 + %71 = arith.muli %70, %arg5 : i32 + %73 = arith.divui %71, %arg5 : i32 + %75 = arith.muli %73, %arg5 : i32 + %79 = arith.subi %arg4, %75 : i32 + %81 = arith.cmpi sle, %arg5, %79 : i32 + %83 = arith.select %81, %arg5, %79 : i32 + %92 = arith.cmpi slt, %c0_i32, %83 : i32 + scf.if %92 { + call @run() : () -> () + scf.yield + } + } + return + } +} + +// CHECK: #set = affine_set<()[s0] : (s0 - 1 >= 0)> +// CHECK: func @minif(%arg0: i32, %arg1: i32, %arg2: index) { +// CHECK-NEXT: %0 = arith.index_cast %arg1 : i32 to index +// CHECK-NEXT: %1 = arith.index_cast %arg1 : i32 to index +// CHECK-NEXT: %2 = arith.index_cast %arg1 : i32 to index +// CHECK-NEXT: %3 = arith.index_cast %arg0 : i32 to index +// CHECK-NEXT: %4 = arith.index_cast %arg1 : i32 to index +// CHECK-NEXT: %5 = arith.muli %arg2, %0 : index +// CHECK-NEXT: %6 = arith.divui %5, %1 : index +// CHECK-NEXT: %7 = arith.muli %6, %2 : index +// CHECK-NEXT: %8 = arith.subi %3, %7 : index +// CHECK-NEXT: %9 = arith.cmpi sle, %4, %8 : index +// CHECK-NEXT: %10 = arith.select %9, %4, %8 : index +// CHECK-NEXT: affine.for %arg3 = 0 to 10 { +// CHECK-NEXT: affine.if #set()[%10] { +// CHECK-NEXT: call @run() : () -> () +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } diff --git a/test/polygeist-opt/affraise.mlir b/test/polygeist-opt/affraise.mlir new file mode 100644 index 000000000000..457f77815562 --- /dev/null +++ b/test/polygeist-opt/affraise.mlir @@ -0,0 +1,27 @@ +// RUN: polygeist-opt --raise-scf-to-affine --split-input-file %s | FileCheck %s + +module { + func @withinif(%arg0: memref, %arg1: i32, %arg2: memref, %arg3: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + scf.if %arg3 { + %3 = arith.index_cast %arg1 : i32 to index + scf.for %arg6 = %c1 to %3 step %c1 { + %4 = memref.load %arg0[%arg6] : memref + memref.store %4, %arg2[%arg6] : memref + } + } + return + } +} + +// CHECK: func @withinif(%arg0: memref, %arg1: i32, %arg2: memref, %arg3: i1) { +// CHECK-DAG: %0 = arith.index_cast %arg1 : i32 to index +// CHECK-NEXT: scf.if %arg3 { +// CHECK-NEXT: affine.for %arg4 = 1 to %0 { +// CHECK-NEXT: %1 = memref.load %arg0[%arg4] : memref +// CHECK-NEXT: memref.store %1, %arg2[%arg4] : memref +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } diff --git a/test/polygeist-opt/allocdist.mlir b/test/polygeist-opt/allocdist.mlir index 3ceec3bd37cb..d363959b2399 100644 --- a/test/polygeist-opt/allocdist.mlir +++ b/test/polygeist-opt/allocdist.mlir @@ -35,13 +35,14 @@ module { // CHECK-NEXT: %c0 = arith.constant 0 : index // CHECK-NEXT: %c1 = arith.constant 1 : index // CHECK-NEXT: %c5 = arith.constant 5 : index -// CHECK-NEXT: %0 = memref.alloc(%c5) : memref -// CHECK-NEXT: %1 = memref.alloc(%c5) : memref> -// CHECK-NEXT: %2 = memref.alloc(%c5) : memref -// CHECK-NEXT: %3 = memref.alloc(%c5) : memref> -// CHECK-NEXT: %4 = memref.alloc(%c5) : memref -// CHECK-NEXT: %5 = memref.alloc(%c5) : memref -// CHECK-NEXT: %6 = memref.alloc(%c5) : memref +// CHECK-NEXT: memref.alloca_scope { +// CHECK-NEXT: %0 = memref.alloca(%c5) : memref +// CHECK-NEXT: %1 = memref.alloca(%c5) : memref> +// CHECK-NEXT: %2 = memref.alloca(%c5) : memref +// CHECK-NEXT: %3 = memref.alloca(%c5) : memref> +// CHECK-NEXT: %4 = memref.alloca(%c5) : memref +// CHECK-NEXT: %5 = memref.alloca(%c5) : memref +// CHECK-NEXT: %6 = memref.alloca(%c5) : memref // CHECK-NEXT: scf.parallel (%arg0) = (%c0) to (%c5) step (%c1) { // CHECK-NEXT: %7 = "polygeist.subindex"(%4, %arg0) : (memref, index) -> memref<2xi32> // CHECK-NEXT: %8 = memref.cast %7 : memref<2xi32> to memref @@ -68,12 +69,6 @@ module { // CHECK-DAG: call @use(%[[i9]], %[[i10]], %[[i8]], %[[i11]]) : (memref, f32, i32, f32) -> () // CHECK-NEXT: scf.yield // CHECK-NEXT: } -// CHECK-NEXT: memref.dealloc %0 : memref -// CHECK-NEXT: memref.dealloc %1 : memref> -// CHECK-NEXT: memref.dealloc %2 : memref -// CHECK-NEXT: memref.dealloc %3 : memref> -// CHECK-NEXT: memref.dealloc %4 : memref -// CHECK-NEXT: memref.dealloc %5 : memref -// CHECK-NEXT: memref.dealloc %6 : memref +// CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT: } diff --git a/test/polygeist-opt/ifcomb.mlir b/test/polygeist-opt/ifcomb.mlir new file mode 100644 index 000000000000..595a49d21421 --- /dev/null +++ b/test/polygeist-opt/ifcomb.mlir @@ -0,0 +1,36 @@ +// RUN: polygeist-opt --canonicalize --split-input-file %s | FileCheck %s + +module { + func @_Z17compute_tran_tempPfPS_iiiiiiii(%arg0: memref, %arg1: i32, %arg2: i32, %arg3: i32) -> i8 { + %c1_i8 = arith.constant 1 : i8 + %c0_i8 = arith.constant 0 : i8 + %cst = arith.constant 0.000000e+00 : f32 + %0 = arith.cmpi sge, %arg3, %arg1 : i32 + %1 = scf.if %0 -> (i8) { + %2 = arith.cmpi sle, %arg3, %arg2 : i32 + %3 = scf.if %2 -> (i8) { + affine.store %cst, %arg0[] : memref + scf.yield %c1_i8 : i8 + } else { + scf.yield %c0_i8 : i8 + } + scf.yield %3 : i8 + } else { + scf.yield %c0_i8 : i8 + } + return %1 : i8 + } +} + +// CHECK: func @_Z17compute_tran_tempPfPS_iiiiiiii(%arg0: memref, %arg1: i32, %arg2: i32, %arg3: i32) -> i8 { +// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: %0 = arith.cmpi sge, %arg3, %arg1 : i32 +// CHECK-NEXT: %1 = arith.cmpi sle, %arg3, %arg2 : i32 +// CHECK-NEXT: %2 = arith.andi %0, %1 : i1 +// CHECK-NEXT: %3 = arith.andi %2, %1 : i1 +// CHECK-NEXT: %4 = arith.extui %3 : i1 to i8 +// CHECK-NEXT: scf.if %2 { +// CHECK-NEXT: affine.store %cst, %arg0[] : memref +// CHECK-NEXT: } +// CHECK-NEXT: return %4 : i8 +// CHECK-NEXT: } diff --git a/test/polygeist-opt/paralleldistribute.mlir b/test/polygeist-opt/paralleldistribute.mlir new file mode 100644 index 000000000000..60640ab00209 --- /dev/null +++ b/test/polygeist-opt/paralleldistribute.mlir @@ -0,0 +1,113 @@ +// RUN: polygeist-opt --cpuify="method=distribute" --canonicalize --split-input-file %s | FileCheck %s + +module { + func private @print() + func @main() { + %c0_i8 = arith.constant 0 : i8 + %c1_i8 = arith.constant 1 : i8 + %c1_i64 = arith.constant 1 : i64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c0_i32 = arith.constant 0 : i32 + %c5 = arith.constant 5 : index + %c2 = arith.constant 2 : index + scf.parallel (%arg2) = (%c0) to (%c5) step (%c1) { + %0 = llvm.alloca %c1_i64 x i8 : (i64) -> !llvm.ptr + scf.parallel (%arg3) = (%c0) to (%c2) step (%c1) { + %4 = scf.while (%arg4 = %c1_i8) : (i8) -> i8 { + %6 = arith.cmpi ne, %arg4, %c0_i8 : i8 + scf.condition(%6) %arg4 : i8 + } do { + ^bb0(%arg4: i8): // no predecessors + llvm.store %c0_i8, %0 : !llvm.ptr + "polygeist.barrier"(%arg3) : (index) -> () + scf.yield %c0_i8 : i8 + } + %5 = arith.cmpi ne, %4, %c0_i8 : i8 + scf.if %5 { + call @print() : () -> () + } + scf.yield + } + scf.yield + } + return + } + func @_Z17compute_tran_tempPfPS_iiiiiiii(%arg0: memref, %len : index, %f : f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + affine.parallel (%arg15, %arg16) = (0, 0) to (16, 16) { + scf.for %arg17 = %c0 to %len step %c1 { + affine.store %f, %arg0[%arg15] : memref + "polygeist.barrier"(%arg15, %arg16, %c0) : (index, index, index) -> () + } + } + return + } +} + + +// CHECK: func @main() { +// CHECK-DAG: %c0_i8 = arith.constant 0 : i8 +// CHECK-DAG: %c1_i8 = arith.constant 1 : i8 +// CHECK-DAG: %c1_i64 = arith.constant 1 : i64 +// CHECK-DAG: %c0 = arith.constant 0 : index +// CHECK-DAG: %c1 = arith.constant 1 : index +// CHECK-DAG: %c5 = arith.constant 5 : index +// CHECK-DAG: %c2 = arith.constant 2 : index +// CHECK-DAG: scf.parallel (%arg0) = (%c0) to (%c5) step (%c1) { +// CHECK-NEXT: %0 = llvm.alloca %c1_i64 x i8 : (i64) -> !llvm.ptr +// CHECK-NEXT: %1 = memref.alloca() : memref<2xi8> +// CHECK-NEXT: %2 = memref.alloca() : memref<2xi8> +// CHECK-NEXT: scf.parallel (%arg1) = (%c0) to (%c2) step (%c1) { +// CHECK-NEXT: memref.store %c1_i8, %2[%arg1] : memref<2xi8> +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: scf.while : () -> () { +// CHECK-NEXT: %3 = memref.alloca() : memref +// CHECK-NEXT: scf.parallel (%arg1) = (%c0) to (%c2) step (%c1) { +// CHECK-NEXT: %5 = memref.load %2[%arg1] : memref<2xi8> +// CHECK-NEXT: %6 = arith.cmpi ne, %5, %c0_i8 : i8 +// CHECK-NEXT: %7 = arith.cmpi eq, %c0, %arg1 : index +// CHECK-NEXT: scf.if %7 { +// CHECK-NEXT: memref.store %6, %3[] : memref +// CHECK-NEXT: } +// CHECK-NEXT: memref.store %5, %1[%arg1] : memref<2xi8> +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: %4 = memref.load %3[] : memref +// CHECK-NEXT: scf.condition(%4) +// CHECK-NEXT: } do { +// CHECK-NEXT: scf.parallel (%arg1) = (%c0) to (%c2) step (%c1) { +// CHECK-NEXT: llvm.store %c0_i8, %0 : !llvm.ptr +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: scf.parallel (%arg1) = (%c0) to (%c2) step (%c1) { +// CHECK-NEXT: memref.store %c0_i8, %2[%arg1] : memref<2xi8> +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: scf.parallel (%arg1) = (%c0) to (%c2) step (%c1) { +// CHECK-NEXT: %3 = memref.load %1[%arg1] : memref<2xi8> +// CHECK-NEXT: %4 = arith.cmpi ne, %3, %c0_i8 : i8 +// CHECK-NEXT: scf.if %4 { +// CHECK-NEXT: call @print() : () -> () +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } + +// CHECK: func @_Z17compute_tran_tempPfPS_iiiiiiii(%arg0: memref, %arg1: index, %arg2: f32) { +// CHECK-DAG: %c0 = arith.constant 0 : index +// CHECK-DAG: %c1 = arith.constant 1 : index +// CHECK-NEXT: scf.for %arg3 = %c0 to %arg1 step %c1 { +// CHECK-NEXT: affine.parallel (%arg4, %arg5) = (0, 0) to (16, 16) { +// CHECK-NEXT: affine.store %arg2, %arg0[%arg4] : memref +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } diff --git a/test/polygeist-opt/paralleldistributefor.mlir b/test/polygeist-opt/paralleldistributefor.mlir new file mode 100644 index 000000000000..b4de9be99a36 --- /dev/null +++ b/test/polygeist-opt/paralleldistributefor.mlir @@ -0,0 +1,38 @@ +// RUN: polygeist-opt --cpuify="method=distribute" --canonicalize --split-input-file %s | FileCheck %s + +module { + func private @use(%arg : i1) + func @_Z17compute_tran_tempPfPS_iiiiiiii(%arg0: memref, %len : index, %f : f32, %start : i1, %end : i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + affine.parallel (%arg15, %arg16) = (0, 0) to (16, 16) { + %r = scf.for %arg17 = %c0 to %len step %c1 iter_args(%mid = %start) -> (i1) { + affine.store %f, %arg0[%arg15] : memref + "polygeist.barrier"(%arg15, %arg16, %c0) : (index, index, index) -> () + scf.yield %end : i1 + } + call @use(%r) : (i1) -> () + } + return + } +} + +// CHECK: func @_Z17compute_tran_tempPfPS_iiiiiiii(%arg0: memref, %arg1: index, %arg2: f32, %arg3: i1, %arg4: i1) { +// CHECK-DAG: %c1 = arith.constant 1 : index +// CHECK-DAG: %c0 = arith.constant 0 : index +// CHECK-NEXT: %0 = memref.alloca() : memref<16x16xi1> +// CHECK-NEXT: affine.parallel (%arg5, %arg6) = (0, 0) to (16, 16) { +// CHECK-NEXT: memref.store %arg3, %0[%arg5, %arg6] : memref<16x16xi1> +// CHECK-NEXT: } +// CHECK-NEXT: scf.for %arg5 = %c0 to %arg1 step %c1 { +// CHECK-NEXT: affine.parallel (%arg6, %arg7) = (0, 0) to (16, 16) { +// CHECK-NEXT: affine.store %arg2, %arg0[%arg6] : memref +// CHECK-NEXT: memref.store %arg4, %0[%arg6, %arg7] : memref<16x16xi1> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: affine.parallel (%arg5, %arg6) = (0, 0) to (16, 16) { +// CHECK-NEXT: %1 = memref.load %0[%arg5, %arg6] : memref<16x16xi1> +// CHECK-NEXT: call @use(%1) : (i1) -> () +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } diff --git a/test/polygeist-opt/pathfinder.mlir b/test/polygeist-opt/pathfinder.mlir index da5ebf97a0bf..f051623ee456 100644 --- a/test/polygeist-opt/pathfinder.mlir +++ b/test/polygeist-opt/pathfinder.mlir @@ -31,8 +31,10 @@ module { // CHECK-NEXT: %c9 = arith.constant 9 : index // CHECK-NEXT: %true = arith.constant true // CHECK-NEXT: %0 = memref.alloca() : memref<256xi32> -// CHECK-NEXT: %1 = memref.alloc(%c9) : memref +// CHECK-NEXT: memref.alloca_scope { +// CHECK-NEXT: %1 = memref.alloca(%c9) : memref // CHECK-NEXT: scf.if %arg1 { +// CHECK-NEXT: memref.alloca_scope { // CHECK-NEXT: scf.parallel (%arg2) = (%c0) to (%c9) step (%c1) { // CHECK-NEXT: memref.store %c0_i32, %0[%c0] : memref<256xi32> // CHECK-NEXT: scf.yield @@ -42,6 +44,7 @@ module { // CHECK-NEXT: memref.store %true, %2[] : memref // CHECK-NEXT: scf.yield // CHECK-NEXT: } +// CHECK-NEXT: } // CHECK-NEXT: } else { // CHECK-NEXT: scf.parallel (%arg2) = (%c0) to (%c9) step (%c1) { // CHECK-NEXT: %2 = "polygeist.subindex"(%1, %arg2) : (memref, index) -> memref @@ -49,6 +52,6 @@ module { // CHECK-NEXT: scf.yield // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: memref.dealloc %1 : memref +// CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT: } diff --git a/test/polygeist-opt/whiletofor2.mlir b/test/polygeist-opt/whiletofor2.mlir new file mode 100644 index 000000000000..841e0b184291 --- /dev/null +++ b/test/polygeist-opt/whiletofor2.mlir @@ -0,0 +1,171 @@ +// RUN: polygeist-opt -allow-unregistered-dialect --canonicalize-scf-for --split-input-file %s | FileCheck %s + +module { + func @w2f(%ub : i32) -> (i32, f32) { + %cst = arith.constant 0.000000e+00 : f32 + %cst1 = arith.constant 1.000000e+00 : f32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %true = arith.constant true + %2:2 = scf.while (%arg10 = %c0_i32, %arg12 = %cst, %ac = %true) : (i32, f32, i1) -> (i32, f32) { + %3 = arith.cmpi ult, %arg10, %ub : i32 + %a = arith.andi %3, %ac : i1 + scf.condition(%a) %arg10, %arg12 : i32, f32 + } do { + ^bb0(%arg10: i32, %arg12: f32): + %c = "test.something"() : () -> (i1) + %3 = arith.addf %arg12, %cst1 : f32 + %p = arith.addi %arg10, %c1_i32 : i32 + scf.yield %p, %3, %c : i32, f32, i1 + } + return %2#0, %2#1 : i32, f32 + } + + func @w2f_inner(%ub : i32) -> (i32, f32) { + %cst = arith.constant 0.000000e+00 : f32 + %cst1 = arith.constant 1.000000e+00 : f32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %true = arith.constant true + %2:2 = scf.while (%arg10 = %c0_i32, %arg12 = %cst, %ac = %true) : (i32, f32, i1) -> (i32, f32) { + %3 = arith.cmpi ult, %arg10, %ub : i32 + %a = arith.andi %3, %ac : i1 + scf.condition(%a) %arg10, %arg12 : i32, f32 + } do { + ^bb0(%arg10: i32, %arg12: f32): + %c = "test.something"() : () -> (i1) + %r:2 = scf.if %c -> (i32, f32) { + %3 = arith.addf %arg12, %cst1 : f32 + %p = arith.addi %arg10, %c1_i32 : i32 + scf.yield %p, %3 : i32, f32 + } else { + scf.yield %arg10, %arg12 : i32, f32 + } + scf.yield %r#0, %r#1, %c : i32, f32, i1 + } + return %2#0, %2#1 : i32, f32 + } + + func @_Z17compute_tran_tempPfPS_iiiiiiii(%arg0: i8, %arg1: index, %arg2: i32, %arg3: i32, %arg4: i32) -> i32 { + %c1_i8 = arith.constant 1 : i8 + %c0_i8 = arith.constant 0 : i8 + %true = arith.constant true + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %0:2 = scf.while (%arg5 = %c0_i32, %arg6 = %arg0, %arg7 = %true) : (i32, i8, i1) -> (i8, i32) { + %1 = arith.cmpi slt, %arg5, %arg2 : i32 + %2 = arith.andi %1, %arg7 : i1 + scf.condition(%2) %arg6, %arg5 : i8, i32 + } do { + ^bb0(%arg5: i8, %arg6: i32): + %1 = arith.addi %arg6, %c1_i32 : i32 + %2 = arith.cmpi ne, %arg6, %arg4 : i32 + %3 = scf.if %2 -> (i32) { + scf.yield %1 : i32 + } else { + scf.yield %arg6 : i32 + } + scf.yield %3, %c0_i8, %2 : i32, i8, i1 + } + return %0#1 : i32 + } +} + +// CHECK: func @w2f(%arg0: i32) -> (i32, f32) { +// CHECK-DAG: %c1_i32 = arith.constant 1 : i32 +// CHECK-DAG: %c0_i32 = arith.constant 0 : i32 +// CHECK-DAG: %[[cst:.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[cst_0:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %false = arith.constant false +// CHECK-DAG: %true = arith.constant true +// CHECK-DAG: %c0 = arith.constant 0 : index +// CHECK-DAG: %c1 = arith.constant 1 : index +// CHECK-NEXT: %0 = arith.cmpi ult, %c0_i32, %arg0 : i32 +// CHECK-NEXT: %1:2 = scf.if %0 -> (i32, f32) { +// CHECK-NEXT: %2 = arith.index_cast %arg0 : i32 to index +// CHECK-NEXT: %3:3 = scf.for %arg1 = %c0 to %2 step %c1 iter_args(%arg2 = %c0_i32, %arg3 = %[[cst_0]], %arg4 = %true) -> (i32, f32, i1) { +// CHECK-NEXT: %4:3 = scf.if %arg4 -> (i32, f32, i1) { +// CHECK-NEXT: %5 = "test.something"() : () -> i1 +// CHECK-NEXT: %6 = arith.addf %arg3, %[[cst]] : f32 +// CHECK-NEXT: %7 = arith.addi %arg2, %c1_i32 : i32 +// CHECK-NEXT: scf.yield %7, %6, %5 : i32, f32, i1 +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %arg2, %arg3, %false : i32, f32, i1 +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %4#0, %4#1, %4#2 : i32, f32, i1 +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %3#0, %3#1 : i32, f32 +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %c0_i32, %[[cst_0]] : i32, f32 +// CHECK-NEXT: } +// CHECK-NEXT: return %1#0, %1#1 : i32, f32 +// CHECK-NEXT: } + +// CHECK: func @w2f_inner(%arg0: i32) -> (i32, f32) { +// CHECK-DAG: %c1_i32 = arith.constant 1 : i32 +// CHECK-DAG: %c0_i32 = arith.constant 0 : i32 +// CHECK-DAG: %[[cst:.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[cst_0:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %false = arith.constant false +// CHECK-DAG: %true = arith.constant true +// CHECK-DAG: %c0 = arith.constant 0 : index +// CHECK-DAG: %c1 = arith.constant 1 : index +// CHECK-NEXT: %0 = arith.cmpi ult, %c0_i32, %arg0 : i32 +// CHECK-NEXT: %1:2 = scf.if %0 -> (i32, f32) { +// CHECK-NEXT: %2 = arith.index_cast %arg0 : i32 to index +// CHECK-NEXT: %3:3 = scf.for %arg1 = %c0 to %2 step %c1 iter_args(%arg2 = %c0_i32, %arg3 = %[[cst_0]], %arg4 = %true) -> (i32, f32, i1) { +// CHECK-NEXT: %4:3 = scf.if %arg4 -> (i32, f32, i1) { + +// CHECK-NEXT: %5 = "test.something"() : () -> i1 +// CHECK-NEXT: %6:2 = scf.if %5 -> (i32, f32) { +// CHECK-NEXT: %7 = arith.addf %arg3, %cst_0 : f32 +// CHECK-NEXT: %8 = arith.addi %arg2, %c1_i32 : i32 +// CHECK-NEXT: scf.yield %8, %7 : i32, f32 +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %arg2, %arg3 : i32, f32 +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %6#0, %6#1, %5 : i32, f32, i1 +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %arg2, %arg3, %false : i32, f32, i1 +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %4#0, %4#1, %4#2 : i32, f32, i1 +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %3#0, %3#1 : i32, f32 +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %c0_i32, %[[cst_0]] : i32, f32 +// CHECK-NEXT: } +// CHECK-NEXT: return %1#0, %1#1 : i32, f32 +// CHECK-NEXT: } + +// CHECK: func @_Z17compute_tran_tempPfPS_iiiiiiii(%arg0: i8, %arg1: index, %arg2: i32, %arg3: i32, %arg4: i32) -> i32 { +// CHECK-DAG: %c0_i8 = arith.constant 0 : i8 +// CHECK-DAG: %c1_i32 = arith.constant 1 : i32 +// CHECK-DAG: %c0_i32 = arith.constant 0 : i32 +// CHECK-DAG: %false = arith.constant false +// CHECK-DAG: %true = arith.constant true +// CHECK-DAG: %c0 = arith.constant 0 : index +// CHECK-DAG: %c1 = arith.constant 1 : index +// CHECK-DAG: %0 = arith.cmpi slt, %c0_i32, %arg2 : i32 +// CHECK-DAG: %1:2 = scf.if %0 -> (i8, i32) { +// CHECK-DAG: %2 = arith.index_cast %arg2 : i32 to index +// CHECK-NEXT: %3:3 = scf.for %arg5 = %c0 to %2 step %c1 iter_args(%arg6 = %arg0, %arg7 = %c0_i32, %arg8 = %true) -> (i8, i32, i1) { +// CHECK-NEXT: %4:3 = scf.if %arg8 -> (i8, i32, i1) { +// CHECK-NEXT: %5 = arith.addi %arg7, %c1_i32 : i32 +// CHECK-NEXT: %6 = arith.cmpi ne, %arg7, %arg4 : i32 +// CHECK-NEXT: %7 = scf.if %6 -> (i32) { +// CHECK-NEXT: scf.yield %5 : i32 +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %arg7 : i32 +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %c0_i8, %7, %6 : i8, i32, i1 +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %arg6, %arg7, %false : i8, i32, i1 +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %4#0, %4#1, %4#2 : i8, i32, i1 +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %3#0, %3#1 : i8, i32 +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %arg0, %c0_i32 : i8, i32 +// CHECK-NEXT: } +// CHECK-NEXT: return %1#1 : i32 +// CHECK-NEXT: } diff --git a/tools/mlir-clang/CMakeLists.txt b/tools/mlir-clang/CMakeLists.txt index b54c1436f2f2..d522642c0efe 100644 --- a/tools/mlir-clang/CMakeLists.txt +++ b/tools/mlir-clang/CMakeLists.txt @@ -51,7 +51,7 @@ target_link_libraries(mlir-clang PRIVATE MLIRGPUOps MLIRTransforms MLIRSCFToControlFlow - MLIRStandardToLLVM + MLIRFuncToLLVM MLIRAffineTransforms MLIRAffineToStandard MLIRMathToLLVM diff --git a/tools/mlir-clang/Lib/CGCall.cc b/tools/mlir-clang/Lib/CGCall.cc index 9f9689bc7938..8b3e46993a80 100644 --- a/tools/mlir-clang/Lib/CGCall.cc +++ b/tools/mlir-clang/Lib/CGCall.cc @@ -60,7 +60,7 @@ static mlir::Value castCallerMemRefArg(mlir::Value callerArg, static void castCallerArgs(mlir::FuncOp callee, llvm::SmallVectorImpl &args, mlir::OpBuilder &b) { - mlir::FunctionType funcTy = callee.getType().cast(); + mlir::FunctionType funcTy = callee.getFunctionType(); assert(args.size() == funcTy.getNumInputs() && "The caller arguments should have the same size as the number of " "callee arguments as the interface."); @@ -82,7 +82,7 @@ ValueCategory MLIRScanner::CallHelper( ArrayRef> arguments, QualType retType, bool retReference, clang::Expr *expr) { SmallVector args; - auto fnType = tocall.getType(); + auto fnType = tocall.getFunctionType(); size_t i = 0; // map from declaration name to mlir::value diff --git a/tools/mlir-clang/Lib/clang-mlir.cc b/tools/mlir-clang/Lib/clang-mlir.cc index 11ada90ad98a..5cae58ebaffe 100644 --- a/tools/mlir-clang/Lib/clang-mlir.cc +++ b/tools/mlir-clang/Lib/clang-mlir.cc @@ -270,9 +270,9 @@ void MLIRScanner::init(mlir::FuncOp function, const FunctionDecl *fd) { builder.create(loc, type)}); builder.create(loc, truev, loops.back().noBreak); builder.create(loc, truev, loops.back().keepRunning); - if (function.getType().getResults().size()) { - auto type = - mlir::MemRefType::get({}, function.getType().getResult(0), {}, 0); + if (function.getFunctionType().getResults().size()) { + auto type = mlir::MemRefType::get( + {}, function.getFunctionType().getResult(0), {}, 0); returnVal = builder.create(loc, type); if (type.getElementType().isa()) { builder.create( @@ -282,7 +282,7 @@ void MLIRScanner::init(mlir::FuncOp function, const FunctionDecl *fd) { } Visit(stmt); - if (function.getType().getResults().size()) { + if (function.getFunctionType().getResults().size()) { mlir::Value vals[1] = { builder.create(loc, returnVal)}; builder.create(loc, vals); diff --git a/tools/mlir-clang/Lib/utils.cc b/tools/mlir-clang/Lib/utils.cc index 1595333d6633..d03ec9797826 100644 --- a/tools/mlir-clang/Lib/utils.cc +++ b/tools/mlir-clang/Lib/utils.cc @@ -41,10 +41,9 @@ Operation *buildLinalgOp(StringRef name, OpBuilder &b, } } -Operation * -mlirclang::replaceFuncByOperation(FuncOp f, StringRef opName, OpBuilder &b, - SmallVectorImpl &input, - SmallVectorImpl &output) { +Operation *mlirclang::replaceFuncByOperation( + func::FuncOp f, StringRef opName, OpBuilder &b, + SmallVectorImpl &input, SmallVectorImpl &output) { MLIRContext *ctx = f->getContext(); assert(ctx->isOperationRegistered(opName) && "Provided lower_to opName should be registered."); diff --git a/tools/mlir-clang/Lib/utils.h b/tools/mlir-clang/Lib/utils.h index 4fc951db9a7b..bf95692480ef 100644 --- a/tools/mlir-clang/Lib/utils.h +++ b/tools/mlir-clang/Lib/utils.h @@ -13,7 +13,9 @@ namespace mlir { class Operation; +namespace func { class FuncOp; +} class Value; class OpBuilder; class AbstractOperation; @@ -38,7 +40,7 @@ namespace mlirclang { /// operands %a and %b. The new op will be inserted at where the insertion point /// of the provided OpBuilder is. mlir::Operation * -replaceFuncByOperation(mlir::FuncOp f, llvm::StringRef opName, +replaceFuncByOperation(mlir::func::FuncOp f, llvm::StringRef opName, mlir::OpBuilder &b, llvm::SmallVectorImpl &input, llvm::SmallVectorImpl &output); diff --git a/tools/mlir-clang/Test/Verification/label.c b/tools/mlir-clang/Test/Verification/label.c index 3faed11e6c7c..1e06b78fd2a3 100644 --- a/tools/mlir-clang/Test/Verification/label.c +++ b/tools/mlir-clang/Test/Verification/label.c @@ -18,13 +18,13 @@ int fir (int d_i[1000], int idx[1000] ) { // CHECK-DAG: %c1000 = arith.constant 1000 : index // CHECK-DAG: %c999 = arith.constant 999 : index // CHECK-DAG: %c0_i32 = arith.constant 0 : i32 -// CHECK-NEXT: %0:2 = scf.for %arg2 = %c0 to %c1000 step %c1 iter_args(%arg3 = %c0_i32, %arg4 = %c0_i32) -> (i32, i32) { +// CHECK-NEXT: %0 = scf.for %arg2 = %c0 to %c1000 step %c1 iter_args(%arg3 = %c0_i32) -> (i32) { // CHECK-NEXT: %1 = memref.load %arg1[%arg2] : memref // CHECK-NEXT: %2 = arith.subi %c999, %arg2 : index // CHECK-NEXT: %3 = memref.load %arg0[%2] : memref // CHECK-NEXT: %4 = arith.muli %1, %3 : i32 // CHECK-NEXT: %5 = arith.addi %arg3, %4 : i32 -// CHECK-NEXT: scf.yield %5, %5 : i32, i32 +// CHECK-NEXT: scf.yield %5 : i32 // CHECK-NEXT: } -// CHECK-NEXT: return %0#1 : i32 +// CHECK-NEXT: return %0 : i32 // CHECK-NEXT: } diff --git a/tools/mlir-clang/Test/Verification/loopinc.c b/tools/mlir-clang/Test/Verification/loopinc.c index 28c209bf1029..c7b58d438deb 100644 --- a/tools/mlir-clang/Test/Verification/loopinc.c +++ b/tools/mlir-clang/Test/Verification/loopinc.c @@ -13,21 +13,14 @@ unsigned int test() { // CHECK: func @test() -> i32 attributes {llvm.linkage = #llvm.linkage} { // CHECK-DAG: %c0_i32 = arith.constant 0 : i32 // CHECK-DAG: %c1_i32 = arith.constant 1 : i32 -// CHECK-DAG: %true = arith.constant true -// CHECK-DAG: %0 = scf.while (%arg0 = %c0_i32, %arg1 = %true) : (i32, i1) -> i32 { -// CHECK-NEXT: scf.condition(%arg1) %arg0 : i32 +// CHECK-NEXT: %0 = scf.while (%arg0 = %c0_i32) : (i32) -> i32 { +// CHECK-NEXT: %1 = arith.shli %c1_i32, %arg0 : i32 +// CHECK-NEXT: %2 = arith.cmpi ult, %1, %c1_i32 : i32 +// CHECK-NEXT: scf.condition(%2) %arg0 : i32 // CHECK-NEXT: } do { // CHECK-NEXT: ^bb0(%arg0: i32): -// CHECK-NEXT: %1 = arith.shli %c1_i32, %arg0 : i32 -// CHECK-NEXT: %2 = arith.cmpi uge, %1, %c1_i32 : i32 -// CHECK-NEXT: %3 = arith.cmpi ult, %1, %c1_i32 : i32 -// CHECK-NEXT: %4 = scf.if %2 -> (i32) { -// CHECK-NEXT: scf.yield %arg0 : i32 -// CHECK-NEXT: } else { -// CHECK-NEXT: %5 = arith.addi %arg0, %c1_i32 : i32 -// CHECK-NEXT: scf.yield %5 : i32 -// CHECK-NEXT: } -// CHECK-NEXT: scf.yield %4, %3 : i32, i1 +// CHECK-NEXT: %1 = arith.addi %arg0, %c1_i32 : i32 +// CHECK-NEXT: scf.yield %1 : i32 // CHECK-NEXT: } // CHECK-NEXT: return %0 : i32 // CHECK-NEXT: } diff --git a/tools/mlir-clang/Test/Verification/min.c b/tools/mlir-clang/Test/Verification/min.c index d7580915bdb9..662d97f2080f 100644 --- a/tools/mlir-clang/Test/Verification/min.c +++ b/tools/mlir-clang/Test/Verification/min.c @@ -1,6 +1,7 @@ // RUN: mlir-clang %s --function=min -S | FileCheck %s -// XFAIL: * +// TODO combine selects + int min(int a, int b) { if (a < b) return a; return b; diff --git a/tools/mlir-clang/Test/Verification/whiletofor.c b/tools/mlir-clang/Test/Verification/whiletofor.c index 443e9e7b42df..41ca203c2f2f 100644 --- a/tools/mlir-clang/Test/Verification/whiletofor.c +++ b/tools/mlir-clang/Test/Verification/whiletofor.c @@ -30,26 +30,25 @@ void whiletofor() { // CHECK-DAG: %c20_i32 = arith.constant 20 : i32 // CHECK-DAG: %c2_i32 = arith.constant 2 : i32 // CHECK-DAG: %c3_i32 = arith.constant 3 : i32 -// CHECK-DAG: %c1_i32 = arith.constant 1 : i32 // CHECK-DAG: %0 = memref.alloca() : memref<100x100xi32> -// CHECK-NEXT: %1 = scf.for %arg0 = %c0 to %c100 step %c1 iter_args(%arg1 = %c7_i32) -> (i32) { -// CHECK-NEXT: %3 = scf.for %arg2 = %c0 to %c100 step %c1 iter_args(%arg3 = %arg1) -> (i32) { -// CHECK-NEXT: %4 = arith.index_cast %arg2 : index to i32 -// CHECK-NEXT: %5 = arith.addi %arg1, %4 : i32 -// CHECK-NEXT: %[[i4:.+]] = arith.remsi %5, %c20_i32 : i32 +// CHECK-NEXT: scf.for %arg0 = %c0 to %c100 step %c1 { +// CHECK-NEXT: %2 = arith.index_cast %arg0 : index to i32 +// CHECK-NEXT: %3 = arith.muli %2, %c100_i32 : i32 +// CHECK-NEXT: %4 = arith.addi %3, %c7_i32 : i32 +// CHECK-NEXT: scf.for %[[arg2:.+]] = %c0 to %c100 step %c1 { +// CHECK-NEXT: %[[a4:.+]] = arith.index_cast %[[arg2]] : index to i32 +// CHECK-NEXT: %[[a5:.+]] = arith.addi %4, %[[a4]] : i32 +// CHECK-NEXT: %[[i4:.+]] = arith.remsi %[[a5]], %c20_i32 : i32 // CHECK-NEXT: %[[i5:.+]] = arith.cmpi eq, %[[i4]], %c0_i32 : i32 // CHECK-NEXT: scf.if %[[i5]] { -// CHECK-NEXT: memref.store %c2_i32, %0[%arg0, %arg2] : memref<100x100xi32> +// CHECK-NEXT: memref.store %c2_i32, %0[%arg0, %[[arg2]]] : memref<100x100xi32> // CHECK-NEXT: } else { -// CHECK-NEXT: memref.store %c3_i32, %0[%arg0, %arg2] : memref<100x100xi32> +// CHECK-NEXT: memref.store %c3_i32, %0[%arg0, %[[arg2]]] : memref<100x100xi32> // CHECK-NEXT: } -// CHECK-NEXT: %[[i6:.+]] = arith.addi %5, %c1_i32 : i32 -// CHECK-NEXT: scf.yield %[[i6]] : i32 // CHECK-NEXT: } -// CHECK-NEXT: scf.yield %3 : i32 // CHECK-NEXT: } -// CHECK-NEXT: %2 = memref.cast %0 : memref<100x100xi32> to memref -// CHECK-NEXT: call @use(%2) : (memref) -> () +// CHECK-NEXT: %[[k2:.+]] = memref.cast %0 : memref<100x100xi32> to memref +// CHECK-NEXT: call @use(%[[k2]]) : (memref) -> () // CHECK-NEXT: return // CHECK-NEXT: } diff --git a/tools/mlir-clang/mlir-clang.cc b/tools/mlir-clang/mlir-clang.cc index f0dff7c5b238..1cc07bfb8011 100644 --- a/tools/mlir-clang/mlir-clang.cc +++ b/tools/mlir-clang/mlir-clang.cc @@ -25,7 +25,6 @@ #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/GPU/GPUDialect.h" @@ -543,7 +542,9 @@ int main(int argc, char **argv) { mlir::OpPassManager &optPM = pm.nest(); if (CudaLower) { optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); optPM.addPass(polygeist::createMem2RegPass()); + optPM.addPass(mlir::createCanonicalizerPass()); optPM.addPass(mlir::createCSEPass()); optPM.addPass(mlir::createCanonicalizerPass()); optPM.addPass(polygeist::createCanonicalizeForPass()); @@ -567,6 +568,21 @@ int main(int argc, char **argv) { optPM.addPass(polygeist::createCPUifyPass(ToCPU)); } optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + optPM.addPass(polygeist::createMem2RegPass()); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createCSEPass()); + if (RaiseToAffine) { + optPM.addPass(polygeist::createCanonicalizeForPass()); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(mlir::createLoopInvariantCodeMotionPass()); + optPM.addPass(polygeist::createRaiseSCFToAffinePass()); + optPM.addPass(mlir::createCanonicalizerPass()); + optPM.addPass(polygeist::replaceAffineCFGPass()); + optPM.addPass(mlir::createCanonicalizerPass()); + if (ScalarReplacement) + optPM.addPass(mlir::createAffineScalarReplacementPass()); + } } pm.addPass(mlir::createSymbolDCEPass()); @@ -599,7 +615,7 @@ int main(int argc, char **argv) { // invalid for gemm.c init array // options.useBarePtrCallConv = true; pm3.addPass(polygeist::createConvertPolygeistToLLVMPass(options)); - pm3.addPass(mlir::createLowerToLLVMPass(options)); + // pm3.addPass(mlir::createLowerFuncToLLVMPass(options)); pm3.addPass(mlir::createCanonicalizerPass()); if (mlir::failed(pm3.run(module.get()))) { module->dump(); diff --git a/tools/polygeist-opt/polygeist-opt.cpp b/tools/polygeist-opt/polygeist-opt.cpp index 457b98efbf50..dcc888df9e52 100644 --- a/tools/polygeist-opt/polygeist-opt.cpp +++ b/tools/polygeist-opt/polygeist-opt.cpp @@ -23,8 +23,9 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/SCF.h" +#include "mlir/InitAllPasses.h" #include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/MlirOptMain.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "mlir/Transforms/Passes.h" #include "polygeist/Dialect.h" @@ -69,13 +70,20 @@ int main(int argc, char **argv) { mlir::registerSymbolDCEPass(); mlir::registerLoopInvariantCodeMotionPass(); mlir::registerConvertSCFToOpenMPPass(); + mlir::registerAffinePasses(); - registry.addTypeInterface(); - registry.addTypeInterface(); - registry.addTypeInterface>(); + registry.addExtension( + +[](MLIRContext *ctx, polygeist::PolygeistDialect *dialect) { + LLVM::LLVMPointerType::attachInterface(*ctx); + }); + registry.addExtension( + +[](MLIRContext *ctx, polygeist::PolygeistDialect *dialect) { + LLVM::LLVMStructType::attachInterface(*ctx); + }); + registry.addExtension( + +[](MLIRContext *ctx, polygeist::PolygeistDialect *dialect) { + MemRefType::attachInterface>(*ctx); + }); return mlir::failed(mlir::MlirOptMain( argc, argv, "Polygeist modular optimizer driver", registry,