Skip to content

Commit

Permalink
[ifTransform] Added naive if-transformation implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
MeronZerihun committed Jun 14, 2024
1 parent 3a31428 commit ad9690e
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 0 deletions.
43 changes: 43 additions & 0 deletions lib/Transforms/IfTransform/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "IfTransform",
srcs = ["IfTransform.cpp"],
hdrs = ["IfTransform.h"],
deps = [
":pass_inc_gen",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=IfTransform",
],
"IfTransform.h.inc",
),
(
["-gen-pass-doc"],
"IfTransformPasses.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "IfTransform.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
88 changes: 88 additions & 0 deletions lib/Transforms/IfTransform/IfTransform.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include "lib/Transforms/IfTransform/IfTransform.h"

#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project

using namespace mlir;
using namespace mlir::scf;

namespace mlir {
namespace heir {

#define GEN_PASS_DEF_IFTRANSFORM
#include "lib/Transforms/IfTransform/IfTransform.h.inc"

struct ConvertIfToSelect : OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;

LogicalResult matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const override {
auto cond = ifOp.getCondition();
auto thenYieldArgs = ifOp.thenYield().getOperands();
auto elseYieldArgs = ifOp.elseYield().getOperands();

SmallVector<Value> newIfResults(ifOp->getNumResults());
// 1. Hoist instructions in the 'then' block
for (auto &operation : ifOp.getThenRegion().getOps()) {
auto distance = std::distance(ifOp.getThenRegion().getOps().begin(),
ifOp.getThenRegion().getOps().end());
if (distance == 1) {
break;
}
rewriter.setInsertionPointToStart(ifOp.thenBlock());
rewriter.moveOpBefore(&operation, ifOp);
}
// 2. Hoist instructions in the 'else' block
for (auto &operation : ifOp.getElseRegion().getOps()) {
auto distance = std::distance(ifOp.getElseRegion().getOps().begin(),
ifOp.getElseRegion().getOps().end());
if (distance == 1) {
break;
}
rewriter.setInsertionPointToStart(ifOp.thenBlock());
rewriter.moveOpBefore(&operation, ifOp);
}

// 3. Translate YieldOp into a SelectOp
auto thenDistance = std::distance(ifOp.getThenRegion().getOps().begin(),
ifOp.getThenRegion().getOps().end());
auto elseDistance = std::distance(ifOp.getElseRegion().getOps().begin(),
ifOp.getElseRegion().getOps().end());

if (ifOp.getNumResults() > 0 && thenDistance == elseDistance) {
rewriter.setInsertionPoint(ifOp);

for (const auto &it :
llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
Value trueVal = std::get<0>(it.value());
Value falseVal = std::get<1>(it.value());
newIfResults[it.index()] = rewriter.create<arith::SelectOp>(
ifOp.getLoc(), cond, trueVal, falseVal);
}
rewriter.replaceOp(ifOp, newIfResults);

return success();
}

return failure();
}
};

struct IfTransform : impl::IfTransformBase<IfTransform> {
using IfTransformBase::IfTransformBase;

void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);

patterns.add<ConvertIfToSelect>(context);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace heir
} // namespace mlir
18 changes: 18 additions & 0 deletions lib/Transforms/IfTransform/IfTransform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef LIB_TRANSFORMS_IFTRANSFORM_IFTRANSFORM_H_
#define LIB_TRANSFORMS_IFTRANSFORM_IFTRANSFORM_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {

#define GEN_PASS_DECL
#include "lib/Transforms/IfTransform/IfTransform.h.inc"

#define GEN_PASS_REGISTRATION
#include "lib/Transforms/IfTransform/IfTransform.h.inc"

} // namespace heir
} // namespace mlir

#endif // LIB_TRANSFORMS_IFTRANSFORM_IFTRANSFORM_H_
15 changes: 15 additions & 0 deletions lib/Transforms/IfTransform/IfTransform.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef LIB_TRANSFORMS_IFTRANSFORM_IFTRANSFORM_TD_
#define LIB_TRANSFORMS_IFTRANSFORM_IFTRANSFORM_TD_

include "mlir/Pass/PassBase.td"

def IfTransform : Pass<"if-transform"> {
let summary = "Transform If-operations on secret conditions";
let description = [{ Transformation for If-operations that evaluate secret condition(s).
}];
let dependentDialects = [
"mlir::scf::SCFDialect",
];
}

#endif // LIB_TRANSFORMS_IFTRANSFORM_IFTRANSFORM_TD_
3 changes: 3 additions & 0 deletions lib/Transforms/IfTransform/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# !/bin/bash
bazel run //tools:heir-opt -- --if-transform ~/heir/tests/if_transform/secret_condition_with_secret_int.mlir
# bazel run //tools:heir-opt -- --if-transform ~/heir/tests/if_transform/secret_condition_with_secret_tensor.mlir

0 comments on commit ad9690e

Please sign in to comment.