From 7b6751ce85b5cc688e950dc79759d3612d4839ca Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 29 Feb 2024 17:20:46 -0800 Subject: [PATCH] add a full loop unroll pass --- include/Transforms/FullLoopUnroll/BUILD | 35 ++++++++++++++++++ .../FullLoopUnroll/FullLoopUnroll.h | 18 ++++++++++ .../FullLoopUnroll/FullLoopUnroll.td | 16 +++++++++ lib/Transforms/FullLoopUnroll/BUILD | 22 ++++++++++++ .../FullLoopUnroll/FullLoopUnroll.cpp | 36 +++++++++++++++++++ tools/BUILD | 1 + tools/heir-opt.cpp | 2 ++ 7 files changed, 130 insertions(+) create mode 100644 include/Transforms/FullLoopUnroll/BUILD create mode 100644 include/Transforms/FullLoopUnroll/FullLoopUnroll.h create mode 100644 include/Transforms/FullLoopUnroll/FullLoopUnroll.td create mode 100644 lib/Transforms/FullLoopUnroll/BUILD create mode 100644 lib/Transforms/FullLoopUnroll/FullLoopUnroll.cpp diff --git a/include/Transforms/FullLoopUnroll/BUILD b/include/Transforms/FullLoopUnroll/BUILD new file mode 100644 index 000000000..a51a8eb14 --- /dev/null +++ b/include/Transforms/FullLoopUnroll/BUILD @@ -0,0 +1,35 @@ +# FullLoopUnroll tablegen and headers. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files([ + "FullLoopUnroll.h", +]) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=FullLoopUnroll", + ], + "FullLoopUnroll.h.inc", + ), + ( + ["-gen-pass-doc"], + "FullLoopUnrollPasses.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "FullLoopUnroll.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/include/Transforms/FullLoopUnroll/FullLoopUnroll.h b/include/Transforms/FullLoopUnroll/FullLoopUnroll.h new file mode 100644 index 000000000..f7ed05963 --- /dev/null +++ b/include/Transforms/FullLoopUnroll/FullLoopUnroll.h @@ -0,0 +1,18 @@ +#ifndef INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_ +#define INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DECL +#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc" + +#define GEN_PASS_REGISTRATION +#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc" + +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_H_ diff --git a/include/Transforms/FullLoopUnroll/FullLoopUnroll.td b/include/Transforms/FullLoopUnroll/FullLoopUnroll.td new file mode 100644 index 000000000..e800ee384 --- /dev/null +++ b/include/Transforms/FullLoopUnroll/FullLoopUnroll.td @@ -0,0 +1,16 @@ +#ifndef INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_ +#define INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_ + +include "mlir/Pass/PassBase.td" + +def FullLoopUnroll : Pass<"full-loop-unroll"> { + let summary = "Fully unroll all loops"; + let description = [{ + Scan the IR for affine.for loops and unroll them all. + }]; + let dependentDialects = [ + + ]; +} + +#endif // INCLUDE_TRANSFORMS_FULLLOOPUNROLL_FULLLOOPUNROLL_TD_ diff --git a/lib/Transforms/FullLoopUnroll/BUILD b/lib/Transforms/FullLoopUnroll/BUILD new file mode 100644 index 000000000..8b40f31d9 --- /dev/null +++ b/lib/Transforms/FullLoopUnroll/BUILD @@ -0,0 +1,22 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "FullLoopUnroll", + srcs = ["FullLoopUnroll.cpp"], + hdrs = [ + "@heir//include/Transforms/FullLoopUnroll:FullLoopUnroll.h", + ], + deps = [ + "@heir//include/Transforms/FullLoopUnroll:pass_inc_gen", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineUtils", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/lib/Transforms/FullLoopUnroll/FullLoopUnroll.cpp b/lib/Transforms/FullLoopUnroll/FullLoopUnroll.cpp new file mode 100644 index 000000000..b48b3f7b7 --- /dev/null +++ b/lib/Transforms/FullLoopUnroll/FullLoopUnroll.cpp @@ -0,0 +1,36 @@ +#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h" + +#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Affine/LoopUtils.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DEF_FULLLOOPUNROLL +#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h.inc" + +struct AffineFullUnrollPattern : public OpRewritePattern { + AffineFullUnrollPattern(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + LogicalResult matchAndRewrite(affine::AffineForOp op, + PatternRewriter &rewriter) const override { + return mlir::affine::loopUnrollFull(op); + } +}; + +struct FullLoopUnroll : impl::FullLoopUnrollBase { + using FullLoopUnrollBase::FullLoopUnrollBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.add(context); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace heir +} // namespace mlir diff --git a/tools/BUILD b/tools/BUILD index 823635c91..0c9bc44a6 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -55,6 +55,7 @@ cc_binary( "@heir//lib/Dialect/TfheRust/IR:Dialect", "@heir//lib/Dialect/TfheRustBool/IR:Dialect", "@heir//lib/Transforms/ForwardStoreToLoad", + "@heir//lib/Transforms/FullLoopUnroll", "@heir//lib/Transforms/Secretize", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 1e9d54b42..8eb0ed41e 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -22,6 +22,7 @@ #include "include/Dialect/TfheRust/IR/TfheRustDialect.h" #include "include/Dialect/TfheRustBool/IR/TfheRustBoolDialect.h" #include "include/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h" +#include "include/Transforms/FullLoopUnroll/FullLoopUnroll.h" #include "include/Transforms/Secretize/Passes.h" #include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project @@ -288,6 +289,7 @@ int main(int argc, char **argv) { lwe::registerLWEPasses(); secret::registerSecretPasses(); registerSecretizePasses(); + registerFullLoopUnrollPasses(); registerForwardStoreToLoadPasses(); // Register yosys optimizer pipeline if configured. #ifndef HEIR_NO_YOSYS