Skip to content

Commit

Permalink
Added basic --convert-if-to-select pass
Browse files Browse the repository at this point in the history
Basic pass that converts all scf.if operations to arith.select operations.
  • Loading branch information
MeronZerihun committed Jun 28, 2024
1 parent e3075cf commit 79af940
Show file tree
Hide file tree
Showing 11 changed files with 561 additions and 0 deletions.
44 changes: 44 additions & 0 deletions lib/Transforms/ConvertIfToSelect/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "ConvertIfToSelect",
srcs = ["ConvertIfToSelect.cpp"],
hdrs = ["ConvertIfToSelect.h"],
deps = [
":pass_inc_gen",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=ConvertIfToSelect",
],
"ConvertIfToSelect.h.inc",
),
(
["-gen-pass-doc"],
"ConvertIfToSelectPasses.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "ConvertIfToSelect.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
77 changes: 77 additions & 0 deletions lib/Transforms/ConvertIfToSelect/ConvertIfToSelect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#include "lib/Transforms/ConvertIfToSelect/ConvertIfToSelect.h"

#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project

namespace mlir {
namespace heir {

#define GEN_PASS_DEF_CONVERTIFTOSELECT
#include "lib/Transforms/ConvertIfToSelect/ConvertIfToSelect.h.inc"

struct IfToSelectConversion : OpRewritePattern<scf::IfOp> {
using OpRewritePattern<scf::IfOp>::OpRewritePattern;

LogicalResult matchAndRewrite(scf::IfOp ifOp,
PatternRewriter &rewriter) const override {
// Hoist instructions in the 'then' and 'else' regions
auto thenOps = ifOp.getThenRegion().getOps();
auto elseOps = ifOp.getElseRegion().getOps();

rewriter.setInsertionPointToStart(ifOp.thenBlock());
for (auto &operation : llvm::make_early_inc_range(
llvm::concat<Operation>(thenOps, elseOps))) {
if (!isPure(&operation)) {
ifOp->emitError()
<< "Can't convert scf.if to arith.select operation. If-operation "
"contains code that can't be safely hoisted on line "
<< operation.getLoc();
return failure();
}
if (!llvm::isa<scf::YieldOp>(operation)) {
rewriter.moveOpBefore(&operation, ifOp);
}
}

// Translate YieldOp into SelectOp
auto cond = ifOp.getCondition();
auto thenYieldArgs = ifOp.thenYield().getOperands();
auto elseYieldArgs = ifOp.elseYield().getOperands();

SmallVector<Value> newIfResults(ifOp->getNumResults());
if (ifOp->getNumResults() > 0) {
rewriter.setInsertionPoint(ifOp);

for (const auto &it :
llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
Value trueVal = std::get<0>(it.value());
Value falseVal = std::get<1>(it.value());
newIfResults[it.index()] = rewriter.create<arith::SelectOp>(
ifOp.getLoc(), cond, trueVal, falseVal);
}
rewriter.replaceOp(ifOp, newIfResults);
}

return success();
}
};

struct ConvertIfToSelect : impl::ConvertIfToSelectBase<ConvertIfToSelect> {
using ConvertIfToSelectBase::ConvertIfToSelectBase;

void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);

patterns.add<IfToSelectConversion>(context);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace heir
} // namespace mlir
18 changes: 18 additions & 0 deletions lib/Transforms/ConvertIfToSelect/ConvertIfToSelect.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef LIB_TRANSFORMS_CONVERTIFTOSELECT_CONVERTIFTOSELECT_H_
#define LIB_TRANSFORMS_CONVERTIFTOSELECT_CONVERTIFTOSELECT_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {

#define GEN_PASS_DECL
#include "lib/Transforms/ConvertIfToSelect/ConvertIfToSelect.h.inc"

#define GEN_PASS_REGISTRATION
#include "lib/Transforms/ConvertIfToSelect/ConvertIfToSelect.h.inc"

} // namespace heir
} // namespace mlir

#endif // LIB_TRANSFORMS_CONVERTIFTOSELECT_CONVERTIFTOSELECT_H_
15 changes: 15 additions & 0 deletions lib/Transforms/ConvertIfToSelect/ConvertIfToSelect.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef LIB_TRANSFORMS_CONVERTIFTOSELECT_CONVERTIFTOSELECT_TD_
#define LIB_TRANSFORMS_CONVERTIFTOSELECT_CONVERTIFTOSELECT_TD_

include "mlir/Pass/PassBase.td"

def ConvertIfToSelect : Pass<"convert-if-to-select"> {
let summary = "Convert scf.if operations on secret conditions to arith.select operations.";
let description = [{ Conversion for If-operations that evaluate secret condition to alternative select operations. }];
let dependentDialects = [
"mlir::scf::SCFDialect",
"mlir::arith::ArithDialect"
];
}

#endif // LIB_TRANSFORMS_CONVERTIFTOSELECT_CONVERTIFTOSELECT_TD_
10 changes: 10 additions & 0 deletions tests/convert_if_to_select/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
load("//bazel:lit.bzl", "glob_lit_tests")

package(default_applicable_licenses = ["@heir//:license"])

glob_lit_tests(
name = "all_tests",
data = ["@heir//tests:test_utilities"],
driver = "@heir//tests:run_lit.sh",
test_file_exts = ["mlir"],
)
160 changes: 160 additions & 0 deletions tests/convert_if_to_select/expected_outputs.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// RUN: heir-opt %s | FileCheck %s

// CHECK-LABEL: @secret_condition_with_non_secret_int
func.func @secret_condition_with_non_secret_int(%inp: i16, %cond: !secret.secret<i1>) -> !secret.secret<i16> {
// CHECK-NEXT: %[[RESULT:.*]] = secret.generic ins(%[[INP:.*]], %[[COND:.*]] : [[T:.*]], !secret.secret<i1>) {
// CHECK-NEXT: ^[[bb0:.*]](%[[CPY_INP:.*]]: [[T]], %[[SCRT_COND:.*]]: i1):
// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[CPY_INP]], %[[CPY_INP]] : [[T]]
// CHECK-NEXT: %[[SEL:.*]] = arith.select %[[SCRT_COND]], %[[ADD]], %[[CPY_INP]] : [[T]]
// CHECK-NEXT: secret.yield %[[SEL]] : [[T]]
// CHECK-NEXT: } -> !secret.secret<[[T]]>
// CHECK-NEXT: return %[[RESULT]] : !secret.secret<[[T]]>
%0 = secret.generic ins(%inp, %cond : i16, !secret.secret<i1>) {
^bb0(%copy_inp: i16, %secret_cond: i1):
%2 = arith.addi %copy_inp, %copy_inp : i16
%1 = arith.select %secret_cond, %2, %copy_inp : i16
secret.yield %1 : i16
} -> !secret.secret<i16>
return %0 : !secret.secret<i16>
}



// CHECK-LABEL: @secret_condition_with_secret_int
func.func @secret_condition_with_secret_int(%inp: !secret.secret<i16>, %cond: !secret.secret<i1>) -> !secret.secret<i16> {
// CHECK-NEXT: %[[RESULT:.*]] = secret.generic ins(%[[INP:.*]], %[[COND:.*]] : !secret.secret<[[T:.*]]>, !secret.secret<i1>) {
// CHECK-NEXT: ^[[bb0:.*]](%[[SCRT_INP:.*]]: [[T]], %[[SCRT_COND:.*]]: i1):
// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[SCRT_INP]], %[[SCRT_INP]] : [[T]]
// CHECK-NEXT: %[[SEL:.*]] = arith.select %[[SCRT_COND]], %[[ADD]], %[[SCRT_INP]] : [[T]]
// CHECK-NEXT: secret.yield %[[SEL]] : [[T]]
// CHECK-NEXT: } -> !secret.secret<[[T]]>
// CHECK-NEXT: return %[[RESULT]] : !secret.secret<[[T]]>
%0 = secret.generic ins(%inp, %cond : !secret.secret<i16>, !secret.secret<i1>) {
^bb0(%secret_inp: i16, %secret_cond: i1):
%2 = arith.addi %secret_inp, %secret_inp : i16
%1 = arith.select %secret_cond, %2, %secret_inp : i16
secret.yield %1 : i16
} -> !secret.secret<i16>
return %0 : !secret.secret<i16>
}



// CHECK-LABEL: @secret_condition_with_secret_int_and_multiple_yields
func.func @secret_condition_with_secret_int_and_multiple_yields(%inp: !secret.secret<i16>, %cond: !secret.secret<i1>) -> !secret.secret<i16> {
// CHECK-NEXT: %[[RESULT:.*]] = secret.generic ins(%[[INP:.*]], %[[COND:.*]] : !secret.secret<[[T:.*]]>, !secret.secret<i1>) {
// CHECK-NEXT: ^[[bb0:.*]](%[[SCRT_INP:.*]]: [[T]], %[[SCRT_COND:.*]]: i1):
// CHECK-NEXT: %[[ADD1:.*]] = arith.addi %[[SCRT_INP]], %[[SCRT_INP]] : [[T]]
// CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[SCRT_INP]], %[[ADD1]] : [[T]]
// CHECK-NEXT: %[[SEL1:.*]] = arith.select %[[SCRT_COND]], %[[ADD1]], %[[SCRT_INP]] : [[T]]
// CHECK-NEXT: %[[SEL2:.*]] = arith.select %[[SCRT_COND]], %[[MUL]], %[[SCRT_INP]] : [[T]]
// CHECK-NEXT: %[[ADD2:.*]] = arith.addi %[[SEL1]], %[[SEL2]] : [[T]]
// CHECK-NEXT: secret.yield %[[ADD2]] : [[T]]
// CHECK-NEXT: } -> !secret.secret<[[T]]>
// CHECK-NEXT: return %[[RESULT]] : !secret.secret<[[T]]>
%0 = secret.generic ins(%inp, %cond : !secret.secret<i16>, !secret.secret<i1>) {
^bb0(%secret_inp: i16, %secret_cond: i1):
%2 = arith.addi %secret_inp, %secret_inp : i16
%4 = arith.muli %secret_inp, %2 : i16
%1 = arith.select %secret_cond, %2, %secret_inp : i16
%3 = arith.select %secret_cond, %4, %secret_inp : i16
%5 = arith.addi %1, %3 : i16
secret.yield %5 : i16
} -> !secret.secret<i16>
return %0 : !secret.secret<i16>
}



// CHECK-LABEL: @secret_condition_with_secret_tensor
func.func @secret_condition_with_secret_tensor(%inp: !secret.secret<tensor<16xi16>>, %cond: !secret.secret<i1>) -> !secret.secret<tensor<16xi16>> {
// CHECK-NEXT: %[[RESULT:.*]] = secret.generic ins(%[[INP:.*]], %[[COND:.*]] : !secret.secret<tensor<[[T:.*]]>>, !secret.secret<i1>)
// CHECK-NEXT: ^[[bb0:.*]](%[[SCRT_INP:.*]]: tensor<[[T:.*]]>, %[[SCRT_COND:.*]]: i1):
// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[SCRT_INP]], %[[SCRT_INP]] : tensor<[[T]]>
// CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[SCRT_INP]], %[[SCRT_INP]] : tensor<[[T]]>
// CHECK-NEXT: %[[SEL:.*]] = arith.select %[[SCRT_COND]], %[[ADD]], %[[MUL]] : tensor<[[T]]>
// CHECK-NEXT: secret.yield %[[SEL]] : tensor<[[T]]>
// CHECK-NEXT: } -> !secret.secret<tensor<[[T]]>>
// CHECK-NEXT: return %[[RESULT]] : !secret.secret<tensor<[[T]]>>
%0 = secret.generic ins(%inp, %cond : !secret.secret<tensor<16xi16>>, !secret.secret<i1>) {
^bb0(%secret_inp: tensor<16xi16>, %secret_cond: i1):
%2 = arith.addi %secret_inp, %secret_inp : tensor<16xi16>
%3 = arith.muli %secret_inp, %secret_inp : tensor<16xi16>
%1 = arith.select %secret_cond, %2, %3 : tensor<16xi16>
secret.yield %1 : tensor<16xi16>
} -> !secret.secret<tensor<16xi16>>
return %0 : !secret.secret<tensor<16xi16>>
}



// CHECK-LABEL: @non_secret_condition_with_secret_tensor
func.func @non_secret_condition_with_secret_tensor(%inp: !secret.secret<tensor<16xi16>>, %cond: i1) -> !secret.secret<tensor<16xi16>> {
// CHECK-NEXT: %[[RESULT:.*]] = secret.generic ins(%[[INP:.*]], %[[COND:.*]] : !secret.secret<tensor<[[T:.*]]>>, i1) {
// CHECK-NEXT: ^[[bb0:.*]](%[[SCRT_INP:.*]]: tensor<[[T]]>, %[[CPY_COND:.*]]: i1):
// CHECK-NEXT: %[[IF:.*]] = scf.if %[[CPY_COND]] -> (tensor<[[T]]>) {
// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[SCRT_INP]], %[[SCRT_INP]] : tensor<[[T]]>
// CHECK-NEXT: scf.yield %[[ADD]] : tensor<[[T]]>
// CHECK-NEXT: } else {
// CHECK-NEXT: scf.yield %[[SCRT_INP]] : tensor<[[T]]>
// CHECK-NEXT: }
// CHECK-NEXT: secret.yield %[[IF]] : tensor<[[T]]>
// CHECK-NEXT: } -> !secret.secret<tensor<[[T]]>>
// CHECK-NEXT: return %[[RESULT]] : !secret.secret<tensor<[[T]]>>
%0 = secret.generic ins(%inp, %cond : !secret.secret<tensor<16xi16>>, i1) {
^bb0(%secret_inp: tensor<16xi16>, %copy_cond: i1):
%1 = scf.if %copy_cond -> (tensor<16xi16>) {
%2 = arith.addi %secret_inp, %secret_inp : tensor<16xi16>
scf.yield %2 : tensor<16xi16>
} else {
scf.yield %secret_inp : tensor<16xi16>
}
secret.yield %1 : tensor<16xi16>
} -> !secret.secret<tensor<16xi16>>
return %0 : !secret.secret<tensor<16xi16>>
}



// CHECK-LABEL: @secret_condition_with_secret_vector
func.func @secret_condition_with_secret_vector(%inp: !secret.secret<vector<4xf32>>, %cond: !secret.secret<i1>) -> !secret.secret<vector<4xf32>> {
// CHECK-NEXT: %[[RESULT:.*]] = secret.generic ins(%[[INP:.*]], %[[COND:.*]] : !secret.secret<vector<[[T:.*]]>>, !secret.secret<i1>) {
// CHECK-NEXT: ^[[bb0:.*]](%[[SCRT_INP:.*]]: vector<[[T]]>, %[[SCRT_COND:.*]]: i1):
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[SCRT_INP]], %[[SCRT_INP]] : vector<[[T]]>
// CHECK-NEXT: %[[SEL:.*]] = arith.select %[[SCRT_COND]], %[[ADD]], %[[SCRT_INP]] : vector<[[T]]>
// CHECK-NEXT: secret.yield %[[SEL]] : vector<[[T]]>
// CHECK-NEXT: } -> !secret.secret<vector<[[T]]>>
// CHECK-NEXT: return %[[RESULT]] : !secret.secret<vector<[[T]]>>
%0 = secret.generic ins(%inp, %cond : !secret.secret<vector<4xf32>>, !secret.secret<i1>) {
^bb0(%secret_inp: vector<4xf32>, %secret_cond: i1):
%2 = arith.addf %secret_inp, %secret_inp : vector<4xf32>
%1 = arith.select %secret_cond, %2, %secret_inp : vector<4xf32>
secret.yield %1 : vector<4xf32>
} -> !secret.secret<vector<4xf32>>
return %0 : !secret.secret<vector<4xf32>>
}



// CHECK-LABEL: @tainted_condition
func.func @tainted_condition(%inp: !secret.secret<i16>) -> !secret.secret<i16>{
// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : [[T:.*]]
// CHECK-NEXT: %[[RESULT:.*]] = secret.generic ins(%[[INP:.*]] : !secret.secret<[[T]]>) {
// CHECK-NEXT: ^[[bb0:.*]](%[[SCRT_INP:.*]]: [[T]]):
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi eq, %[[SCRT_INP]], %[[ZERO]] : [[T]]
// CHECK-NEXT: %[[ADD:.*]] = arith.addi %[[SCRT_INP]], %[[SCRT_INP]] : [[T]]
// CHECK-NEXT: %[[SEL:.*]] = arith.select %[[CMP]], %[[ADD]], %[[SCRT_INP]] : [[T]]
// CHECK-NEXT: secret.yield %[[SEL]] : [[T]]
// CHECK-NEXT: } -> !secret.secret<[[T]]>
// CHECK-NEXT: return %[[RESULT]] : !secret.secret<[[T]]>
%0 = arith.constant 0 : i16
%1 = secret.generic ins(%inp: !secret.secret<i16>) {
^bb0(%secret_inp: i16):
%2 = arith.cmpi eq, %secret_inp, %0 : i16
%4 = arith.addi %secret_inp, %secret_inp : i16
%3 = arith.select %2, %4, %secret_inp : i16
secret.yield %3 : i16
} -> !secret.secret<i16>

return %1 : !secret.secret<i16>
}
57 changes: 57 additions & 0 deletions tests/convert_if_to_select/invalid_conditionals.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// RUN: heir-opt --convert-if-to-select --split-input-file --verify-diagnostics %s

func.func private @printer(%inp: tensor<16xi16>) -> ()

func.func @impure_operation(%inp: !secret.secret<tensor<16xi16>>, %cond: i1) -> !secret.secret<tensor<16xi16>> {
%0 = secret.generic ins(%inp, %cond : !secret.secret<tensor<16xi16>>, i1) {
^bb0(%secret_inp: tensor<16xi16>, %copy_cond: i1):
// expected-error@below {{Can't convert scf.if to arith.select operation. If-operation contains code that can't be safely hoisted on line }}
%1 = scf.if %copy_cond -> (tensor<16xi16>) {
%2 = arith.addi %secret_inp, %secret_inp : tensor<16xi16>
func.call @printer(%2) : (tensor<16xi16>) -> ()
scf.yield %2 : tensor<16xi16>
} else {
scf.yield %secret_inp : tensor<16xi16>
}
secret.yield %1 : tensor<16xi16>
} -> !secret.secret<tensor<16xi16>>
return %0 : !secret.secret<tensor<16xi16>>
}

// -----

func.func @non_speculative_code(%inp: !secret.secret<i16>, %divisor: !secret.secret<i16>) -> !secret.secret<i16> {
%0 = secret.generic ins(%inp, %divisor : !secret.secret<i16>, !secret.secret<i16>) {
^bb0(%secret_inp: i16, %secret_divisor: i16):
%0 = arith.constant 0 : i16
%secret_cond = arith.cmpi eq, %0, %secret_divisor : i16
// expected-error@below {{Can't convert scf.if to arith.select operation. If-operation contains code that can't be safely hoisted on line }}
%1 = scf.if %secret_cond -> (i16) {
%2 = arith.divui %secret_inp, %secret_divisor : i16 // non-pure
scf.yield %2 : i16
} else {
scf.yield %secret_inp : i16
}
secret.yield %1 : i16
} -> !secret.secret<i16>
return %0 : !secret.secret<i16>
}

// -----

func.func @conditionally_speculative_code(%inp: !secret.secret<i16>, %cond :!secret.secret<i1>) -> !secret.secret<i16> {
%divisor = arith.constant 0 : i16
%0 = secret.generic ins(%inp, %cond : !secret.secret<i16>, !secret.secret<i1>) {
^bb0(%secret_inp: i16, %secret_cond: i1):
%0 = arith.constant 0 : i16
// expected-error@below {{Can't convert scf.if to arith.select operation. If-operation contains code that can't be safely hoisted on line }}
%1 = scf.if %secret_cond -> (i16) {
%2 = arith.divui %secret_inp, %divisor : i16
scf.yield %2 : i16
} else {
scf.yield %secret_inp : i16
}
secret.yield %1 : i16
} -> !secret.secret<i16>
return %0 : !secret.secret<i16>
}
Loading

0 comments on commit 79af940

Please sign in to comment.