Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/polygeist/Dialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def Polygeist_Dialect : Dialect {
// Base BFV operation definition.
//===----------------------------------------------------------------------===//

class Polygeist_Op<string mnemonic, list<OpTrait> traits = []>
class Polygeist_Op<string mnemonic, list<Trait> traits = []>
: Op<Polygeist_Dialect, mnemonic, traits>;

#endif // POLYGEIST_DIALECT
18 changes: 9 additions & 9 deletions include/polygeist/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

include "mlir/Pass/PassBase.td"

def AffineCFG : FunctionPass<"affine-cfg"> {
def AffineCFG : Pass<"affine-cfg", "FuncOp"> {
let summary = "Replace scf.if and similar with affine.if";
let constructor = "mlir::polygeist::replaceAffineCFGPass()";
}

def Mem2Reg : FunctionPass<"mem2reg"> {
def Mem2Reg : Pass<"mem2reg", "FuncOp"> {
let summary = "Replace scf.if and similar with affine.if";
let constructor = "mlir::polygeist::createMem2RegPass()";
}
Expand All @@ -20,12 +20,12 @@ def ParallelLower : Pass<"parallel-lower", "mlir::ModuleOp"> {
let constructor = "mlir::polygeist::createParallelLowerPass()";
}

def AffineReduction : FunctionPass<"detect-reduction"> {
def AffineReduction : Pass<"detect-reduction", "FuncOp"> {
let summary = "Detect reductions in affine.for";
let constructor = "mlir::polygeist::detectReductionPass()";
}

def SCFCPUify : FunctionPass<"cpuify"> {
def SCFCPUify : Pass<"cpuify", "FuncOp"> {
let summary = "remove scf.barrier";
let constructor = "mlir::polygeist::createCPUifyPass()";
let dependentDialects =
Expand All @@ -35,29 +35,29 @@ def SCFCPUify : FunctionPass<"cpuify"> {
];
}

def SCFBarrierRemovalContinuation : FunctionPass<"barrier-removal-continuation"> {
def SCFBarrierRemovalContinuation : Pass<"barrier-removal-continuation", "FuncOp"> {
let summary = "Remove scf.barrier using continuations";
let constructor = "mlir::polygeist::createBarrierRemovalContinuation()";
let dependentDialects = ["memref::MemRefDialect", "StandardOpsDialect"];
}

def SCFRaiseToAffine : FunctionPass<"raise-scf-to-affine"> {
def SCFRaiseToAffine : Pass<"raise-scf-to-affine", "FuncOp"> {
let summary = "Raise SCF to affine";
let constructor = "mlir::polygeist::createRaiseSCFToAffinePass()";
let dependentDialects = ["AffineDialect"];
}

def SCFCanonicalizeFor : FunctionPass<"canonicalize-scf-for"> {
def SCFCanonicalizeFor : Pass<"canonicalize-scf-for", "FuncOp"> {
let summary = "Run some additional canonicalization for scf::for";
let constructor = "mlir::polygeist::createCanonicalizeForPass()";
}

def LoopRestructure : FunctionPass<"loop-restructure"> {
def LoopRestructure : Pass<"loop-restructure", "FuncOp"> {
let constructor = "mlir::polygeist::createLoopRestructurePass()";
let dependentDialects = ["::mlir::scf::SCFDialect"];
}

def RemoveTrivialUse : FunctionPass<"trivialuse"> {
def RemoveTrivialUse : Pass<"trivialuse", "FuncOp"> {
let constructor = "mlir::polygeist::createRemoveTrivialUsePass()";
}

Expand Down
36 changes: 18 additions & 18 deletions lib/polygeist/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@
//===----------------------------------------------------------------------===//

#include "polygeist/Ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "polygeist/Dialect.h"
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>

#define GET_OP_CLASSES
#include "polygeist/PolygeistOps.cpp.inc"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include <mlir/Dialect/SCF/SCF.h>

using namespace mlir;
using namespace polygeist;
Expand Down Expand Up @@ -161,8 +161,8 @@ class SubToCast final : public OpRewritePattern<SubIndexOp> {
if (cidx.value() != 0)
return failure();

rewriter.replaceOpWithNewOp<memref::CastOp>(subViewOp, subViewOp.source(),
post);
rewriter.replaceOpWithNewOp<memref::CastOp>(subViewOp, post,
subViewOp.source());
return success();
}

Expand Down Expand Up @@ -724,7 +724,7 @@ MutableOperandRange LoadSelect<LLVM::LoadOp>::ptrMutable(LLVM::LoadOp op) {
return op.getAddrMutable();
}

void SubIndexOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
void SubIndexOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<CastOfSubIndex, SubIndex2, SubToCast, SimplifySubViewUsers,
SimplifySubIndexUsers, SelectOfCast, SelectOfSubIndex,
Expand Down Expand Up @@ -870,8 +870,8 @@ class CopySimplification final : public OpRewritePattern<T> {
op.getLoc(), c0,
rewriter.create<arith::DivUIOp>(
op.getLoc(),
rewriter.create<arith::IndexCastOp>(op.getLoc(), op.getLen(),
rewriter.getIndexType()),
rewriter.create<arith::IndexCastOp>(
op.getLoc(), rewriter.getIndexType(), op.getLen()),
rewriter.create<arith::ConstantIndexOp>(op.getLoc(), width)),
c1);

Expand Down Expand Up @@ -918,8 +918,8 @@ OpFoldResult Memref2PointerOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}

void Memref2PointerOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
void Memref2PointerOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<Memref2Pointer2MemrefCast, Memref2PointerIndex,
CopySimplification<LLVM::MemcpyOp>,
CopySimplification<LLVM::MemmoveOp>>(context);
Expand Down Expand Up @@ -1021,16 +1021,16 @@ class MetaPointer2Memref final : public OpRewritePattern<Op> {
auto shape = mt.getShape();
for (size_t i = 0; i < shape.size(); i++) {
auto off = computeIndex(op, i, rewriter);
auto cur =
rewriter.create<IndexCastOp>(op.getLoc(), rewriter.getI32Type(), off);
auto cur = rewriter.create<arith::IndexCastOp>(
op.getLoc(), rewriter.getI32Type(), off);
if (idx == nullptr) {
idx = cur;
} else {
idx = rewriter.create<AddIOp>(
op.getLoc(),
rewriter.create<MulIOp>(
op.getLoc(), idx,
rewriter.create<ConstantIntOp>(op.getLoc(), shape[i], 32)),
rewriter.create<MulIOp>(op.getLoc(), idx,
rewriter.create<arith::ConstantIntOp>(
op.getLoc(), shape[i], 32)),
cur);
}
}
Expand Down Expand Up @@ -1433,8 +1433,8 @@ struct MoveIntoIfs : public OpRewritePattern<scf::IfOp> {
}
};

void Pointer2MemrefOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
void Pointer2MemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<
Pointer2MemrefCast, Pointer2Memref2PointerCast,
MetaPointer2Memref<memref::LoadOp>, MetaPointer2Memref<memref::StoreOp>,
Expand Down Expand Up @@ -1517,7 +1517,7 @@ struct TypeSizeCanonicalize : public OpRewritePattern<TypeSizeOp> {
}
};

void TypeSizeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
void TypeSizeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<TypeSizeCanonicalize>(context);
}
52 changes: 27 additions & 25 deletions lib/polygeist/Passes/AffineCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ void fully2ComposeIntegerSetAndOperands(IntegerSet *set,

namespace {
struct AffineCFGPass : public AffineCFGBase<AffineCFGPass> {
void runOnFunction() override;
void runOnOperation() override;
};
} // namespace

Expand Down Expand Up @@ -561,9 +561,9 @@ struct SimplfyIntegerCastMath : public OpRewritePattern<IndexCastOp> {
setLocationAfter(b2, iadd.getOperand(1));
rewriter.replaceOpWithNewOp<AddIOp>(
op,
b.create<IndexCastOp>(op.getLoc(), iadd.getOperand(0), op.getType()),
b2.create<IndexCastOp>(op.getLoc(), iadd.getOperand(1),
op.getType()));
b.create<IndexCastOp>(op.getLoc(), op.getType(), iadd.getOperand(0)),
b2.create<IndexCastOp>(op.getLoc(), op.getType(),
iadd.getOperand(1)));
return success();
}
if (auto iadd = op.getOperand().getDefiningOp<SubIOp>()) {
Expand All @@ -573,9 +573,10 @@ struct SimplfyIntegerCastMath : public OpRewritePattern<IndexCastOp> {
setLocationAfter(b2, iadd.getOperand(1));
rewriter.replaceOpWithNewOp<SubIOp>(
op,
b.create<IndexCastOp>(op.getLoc(), iadd.getOperand(0), op.getType()),
b2.create<IndexCastOp>(op.getLoc(), iadd.getOperand(1),
op.getType()));
b.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
iadd.getOperand(0)),
b2.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
iadd.getOperand(1)));
return success();
}
if (auto iadd = op.getOperand().getDefiningOp<MulIOp>()) {
Expand All @@ -585,9 +586,9 @@ struct SimplfyIntegerCastMath : public OpRewritePattern<IndexCastOp> {
setLocationAfter(b2, iadd.getOperand(1));
rewriter.replaceOpWithNewOp<MulIOp>(
op,
b.create<IndexCastOp>(op.getLoc(), iadd.getOperand(0), op.getType()),
b2.create<IndexCastOp>(op.getLoc(), iadd.getOperand(1),
op.getType()));
b.create<IndexCastOp>(op.getLoc(), op.getType(), iadd.getOperand(0)),
b2.create<IndexCastOp>(op.getLoc(), op.getType(),
iadd.getOperand(1)));
return success();
}
if (auto iadd = op.getOperand().getDefiningOp<DivUIOp>()) {
Expand All @@ -597,9 +598,10 @@ struct SimplfyIntegerCastMath : public OpRewritePattern<IndexCastOp> {
setLocationAfter(b2, iadd.getOperand(1));
rewriter.replaceOpWithNewOp<DivUIOp>(
op,
b.create<IndexCastOp>(op.getLoc(), iadd.getOperand(0), op.getType()),
b2.create<IndexCastOp>(op.getLoc(), iadd.getOperand(1),
op.getType()));
b.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
iadd.getOperand(0)),
b2.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
iadd.getOperand(1)));
return success();
}
if (auto iadd = op.getOperand().getDefiningOp<DivSIOp>()) {
Expand All @@ -609,9 +611,10 @@ struct SimplfyIntegerCastMath : public OpRewritePattern<IndexCastOp> {
setLocationAfter(b2, iadd.getOperand(1));
rewriter.replaceOpWithNewOp<DivSIOp>(
op,
b.create<IndexCastOp>(op.getLoc(), iadd.getOperand(0), op.getType()),
b2.create<IndexCastOp>(op.getLoc(), iadd.getOperand(1),
op.getType()));
b.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
iadd.getOperand(0)),
b2.create<arith::IndexCastOp>(op.getLoc(), op.getType(),
iadd.getOperand(1)));
return success();
}
return failure();
Expand Down Expand Up @@ -753,17 +756,17 @@ bool handle(OpBuilder &b, CmpIOp cmpi, SmallVectorImpl<AffineExpr> &exprs,
}
SmallVector<Value, 4> lhspack = {cmpi.getLhs()};
if (!lhspack[0].getType().isa<IndexType>()) {
auto op = b.create<IndexCastOp>(cmpi.getLoc(), lhspack[0],
IndexType::get(cmpi.getContext()));
auto op = b.create<arith::IndexCastOp>(
cmpi.getLoc(), IndexType::get(cmpi.getContext()), lhspack[0]);
lhspack[0] = op;
}

AffineMap rhsmap =
AffineMap::get(0, 1, getAffineSymbolExpr(0, cmpi.getContext()));
SmallVector<Value, 4> rhspack = {cmpi.getRhs()};
if (!rhspack[0].getType().isa<IndexType>()) {
auto op = b.create<IndexCastOp>(cmpi.getLoc(), rhspack[0],
IndexType::get(cmpi.getContext()));
auto op = b.create<arith::IndexCastOp>(
cmpi.getLoc(), IndexType::get(cmpi.getContext()), rhspack[0]);
rhspack[0] = op;
}

Expand Down Expand Up @@ -1131,16 +1134,15 @@ struct MoveIfToAffine : public OpRewritePattern<scf::IfOp> {
}
};

void AffineCFGPass::runOnFunction() {
mlir::RewritePatternSet rpl(getFunction().getContext());
void AffineCFGPass::runOnOperation() {
mlir::RewritePatternSet rpl(getOperation().getContext());
rpl.add<SimplfyIntegerCastMath, CanonicalizeAffineApply,
CanonicalizeIndexCast, IndexCastMovement, AffineFixup<AffineLoadOp>,
AffineFixup<AffineStoreOp>, CanonicalizIfBounds, MoveStoreToAffine,
MoveIfToAffine, MoveLoadToAffine, CanonicalieForBounds>(
getFunction().getContext());
getOperation().getContext());
GreedyRewriteConfig config;
(void)applyPatternsAndFoldGreedily(getFunction().getOperation(),
std::move(rpl), config);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(rpl), config);
}

std::unique_ptr<OperationPass<FuncOp>> mlir::polygeist::replaceAffineCFGPass() {
Expand Down
11 changes: 5 additions & 6 deletions lib/polygeist/Passes/AffineReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using namespace polygeist;

namespace {
struct AffineReductionPass : public AffineReductionBase<AffineReductionPass> {
void runOnFunction() override;
void runOnOperation() override;
};
} // end namespace.

Expand Down Expand Up @@ -254,12 +254,11 @@ struct AffineForReductionIter : public OpRewritePattern<AffineForOp> {

} // end namespace.

void AffineReductionPass::runOnFunction() {
mlir::RewritePatternSet rpl(getFunction().getContext());
rpl.add<AffineForReductionIter>(getFunction().getContext());
void AffineReductionPass::runOnOperation() {
mlir::RewritePatternSet rpl(getOperation().getContext());
rpl.add<AffineForReductionIter>(getOperation().getContext());
GreedyRewriteConfig config;
(void)applyPatternsAndFoldGreedily(getFunction().getOperation(),
std::move(rpl), config);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(rpl), config);
}

namespace mlir {
Expand Down
19 changes: 10 additions & 9 deletions lib/polygeist/Passes/BarrierRemovalContinuation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

#include "PassDetails.h"

#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
Expand All @@ -27,7 +29,6 @@
#include "mlir/Transforms/DialectConversion.h"
#include "polygeist/BarrierUtils.h"
#include "polygeist/Passes/Passes.h"
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>

using namespace mlir;
using namespace mlir::arith;
Expand Down Expand Up @@ -72,7 +73,7 @@ static void wrapPersistingLoopBodies(FuncOp function) {
/// Convert SCF constructs except parallel ops with immediate barriers to a CFG.
static LogicalResult applyCFGConversion(FuncOp function) {
RewritePatternSet patterns(function.getContext());
populateLoopToStdConversionPatterns(patterns);
populateSCFToControlFlowConversionPatterns(patterns);

// Configure the target to preserve parallel ops with barriers, unless those
// barriers are nested in deeper parallel ops.
Expand Down Expand Up @@ -104,7 +105,7 @@ static void splitBlocksWithBarrier(Region &region) {
Block *original = op->getBlock();
Block *block = original->splitBlock(op->getNextNode());
auto builder = OpBuilder::atBlockEnd(original);
builder.create<BranchOp>(builder.getUnknownLoc(), block);
builder.create<cf::BranchOp>(builder.getUnknownLoc(), block);
}
}

Expand Down Expand Up @@ -225,8 +226,8 @@ replicateIntoRegion(Region &region, Value storage, ValueRange ivs,

// Branch from the entry block to the first cloned block.
builder.setInsertionPointToEnd(entryBlock);
builder.create<BranchOp>(builder.getUnknownLoc(),
mapping.lookup(blocks.front()));
builder.create<cf::BranchOp>(builder.getUnknownLoc(),
mapping.lookup(blocks.front()));

// Now that the block structure is created, clone the operations and introduce
// the flow between continuations.
Expand All @@ -244,7 +245,7 @@ replicateIntoRegion(Region &region, Value storage, ValueRange ivs,
// blocks are assumed to branch to the entry block of another subgraph.
// They are replaced with storing the correspnding continuation ID and a
// yield.
if (auto branch = dyn_cast<BranchOp>(&op)) {
if (auto branch = dyn_cast<cf::BranchOp>(&op)) {
// if (!blocks.contains(branch.dest())) {
if (isa_and_nonnull<polygeist::BarrierOp>(branch->getPrevNode())) {
auto it = llvm::find(subgraphEntryPoints, branch.getDest());
Expand Down Expand Up @@ -603,8 +604,8 @@ static void createContinuations(FuncOp func) {
namespace {
struct BarrierRemoval
: public SCFBarrierRemovalContinuationBase<BarrierRemoval> {
void runOnFunction() override {
auto f = getFunction();
void runOnOperation() override {
FuncOp f = getOperation();
if (failed(convertToCFG(f)))
return;
if (failed(splitBlocksWithBarrier(f)))
Expand Down
Loading