-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ifTransform] Added naive if-transformation implementation
- Loading branch information
1 parent
3a31428
commit ad9690e
Showing
5 changed files
with
167 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "IfTransform", | ||
srcs = ["IfTransform.cpp"], | ||
hdrs = ["IfTransform.h"], | ||
deps = [ | ||
":pass_inc_gen", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:Pass", | ||
"@llvm-project//mlir:SCFDialect", | ||
"@llvm-project//mlir:Support", | ||
"@llvm-project//mlir:Transforms", | ||
], | ||
) | ||
|
||
gentbl_cc_library( | ||
name = "pass_inc_gen", | ||
tbl_outs = [ | ||
( | ||
[ | ||
"-gen-pass-decls", | ||
"-name=IfTransform", | ||
], | ||
"IfTransform.h.inc", | ||
), | ||
( | ||
["-gen-pass-doc"], | ||
"IfTransformPasses.md", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "IfTransform.td", | ||
deps = [ | ||
"@llvm-project//mlir:OpBaseTdFiles", | ||
"@llvm-project//mlir:PassBaseTdFiles", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
#include "lib/Transforms/IfTransform/IfTransform.h" | ||
|
||
#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project | ||
|
||
using namespace mlir; | ||
using namespace mlir::scf; | ||
|
||
namespace mlir { | ||
namespace heir { | ||
|
||
#define GEN_PASS_DEF_IFTRANSFORM | ||
#include "lib/Transforms/IfTransform/IfTransform.h.inc" | ||
|
||
struct ConvertIfToSelect : OpRewritePattern<IfOp> { | ||
using OpRewritePattern<IfOp>::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(IfOp ifOp, | ||
PatternRewriter &rewriter) const override { | ||
auto cond = ifOp.getCondition(); | ||
auto thenYieldArgs = ifOp.thenYield().getOperands(); | ||
auto elseYieldArgs = ifOp.elseYield().getOperands(); | ||
|
||
SmallVector<Value> newIfResults(ifOp->getNumResults()); | ||
// 1. Hoist instructions in the 'then' block | ||
for (auto &operation : ifOp.getThenRegion().getOps()) { | ||
auto distance = std::distance(ifOp.getThenRegion().getOps().begin(), | ||
ifOp.getThenRegion().getOps().end()); | ||
if (distance == 1) { | ||
break; | ||
} | ||
rewriter.setInsertionPointToStart(ifOp.thenBlock()); | ||
rewriter.moveOpBefore(&operation, ifOp); | ||
} | ||
// 2. Hoist instructions in the 'else' block | ||
for (auto &operation : ifOp.getElseRegion().getOps()) { | ||
auto distance = std::distance(ifOp.getElseRegion().getOps().begin(), | ||
ifOp.getElseRegion().getOps().end()); | ||
if (distance == 1) { | ||
break; | ||
} | ||
rewriter.setInsertionPointToStart(ifOp.thenBlock()); | ||
rewriter.moveOpBefore(&operation, ifOp); | ||
} | ||
|
||
// 3. Translate YieldOp into a SelectOp | ||
auto thenDistance = std::distance(ifOp.getThenRegion().getOps().begin(), | ||
ifOp.getThenRegion().getOps().end()); | ||
auto elseDistance = std::distance(ifOp.getElseRegion().getOps().begin(), | ||
ifOp.getElseRegion().getOps().end()); | ||
|
||
if (ifOp.getNumResults() > 0 && thenDistance == elseDistance) { | ||
rewriter.setInsertionPoint(ifOp); | ||
|
||
for (const auto &it : | ||
llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) { | ||
Value trueVal = std::get<0>(it.value()); | ||
Value falseVal = std::get<1>(it.value()); | ||
newIfResults[it.index()] = rewriter.create<arith::SelectOp>( | ||
ifOp.getLoc(), cond, trueVal, falseVal); | ||
} | ||
rewriter.replaceOp(ifOp, newIfResults); | ||
|
||
return success(); | ||
} | ||
|
||
return failure(); | ||
} | ||
}; | ||
|
||
struct IfTransform : impl::IfTransformBase<IfTransform> { | ||
using IfTransformBase::IfTransformBase; | ||
|
||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
RewritePatternSet patterns(context); | ||
|
||
patterns.add<ConvertIfToSelect>(context); | ||
|
||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); | ||
} | ||
}; | ||
|
||
} // namespace heir | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef LIB_TRANSFORMS_IFTRANSFORM_IFTRANSFORM_H_ | ||
#define LIB_TRANSFORMS_IFTRANSFORM_IFTRANSFORM_H_ | ||
|
||
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
|
||
#define GEN_PASS_DECL | ||
#include "lib/Transforms/IfTransform/IfTransform.h.inc" | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "lib/Transforms/IfTransform/IfTransform.h.inc" | ||
|
||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // LIB_TRANSFORMS_IFTRANSFORM_IFTRANSFORM_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#ifndef LIB_TRANSFORMS_IFTRANSFORM_IFTRANSFORM_TD_ | ||
#define LIB_TRANSFORMS_IFTRANSFORM_IFTRANSFORM_TD_ | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def IfTransform : Pass<"if-transform"> { | ||
let summary = "Transform If-operations on secret conditions"; | ||
let description = [{ Transformation for If-operations that evaluate secret condition(s). | ||
}]; | ||
let dependentDialects = [ | ||
"mlir::scf::SCFDialect", | ||
]; | ||
} | ||
|
||
#endif // LIB_TRANSFORMS_IFTRANSFORM_IFTRANSFORM_TD_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# !/bin/bash | ||
bazel run //tools:heir-opt -- --if-transform ~/heir/tests/if_transform/secret_condition_with_secret_int.mlir | ||
# bazel run //tools:heir-opt -- --if-transform ~/heir/tests/if_transform/secret_condition_with_secret_tensor.mlir |