Skip to content

Commit

Permalink
Merge pull request #715 from inbelic:inbelic/lower-arith-ext
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 642617614
  • Loading branch information
Copybara-Service committed Jun 12, 2024
2 parents 5335ca8 + ebbbcef commit 9cad9ed
Show file tree
Hide file tree
Showing 8 changed files with 304 additions and 0 deletions.
104 changes: 104 additions & 0 deletions lib/Conversion/ArithExtToArith/ArithExtToArith.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#include "lib/Conversion/ArithExtToArith/ArithExtToArith.h"

#include "lib/Dialect/ArithExt/IR/ArithExtOps.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace arith_ext {

#define GEN_PASS_DEF_ARITHEXTTOARITH
#include "lib/Conversion/ArithExtToArith/ArithExtToArith.h.inc"

namespace rewrites {
// In an inner namespace to avoid conflicts with canonicalization patterns
#include "lib/Conversion/ArithExtToArith/ArithExtToArith.cpp.inc"
} // namespace rewrites

/// Lowers `arith_ext.barrett_reduce` to standard `arith` operations.
///
/// With w = number of active bits of (modulus - 1), the pattern precomputes
/// ratio = floor(4^w / modulus) and emits, on a 3*w-bit integer type,
///
///   x - ((x * ratio) >> 2w) * modulus
///
/// before truncating the result back to the input type. Scalar integers and
/// ranked tensors of integers are both supported; for tensors the constants
/// are splatted over the input shape.
struct ConvertBarrettReduce : public OpConversionPattern<BarrettReduceOp> {
  // The inherited (context, benefit = 1) constructor is all this pattern
  // needs, so no hand-written constructor is declared.
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      BarrettReduceOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto input = adaptor.getInput();
    Type inputType = input.getType();

    // Bail out before touching the IR if the operand has a type we cannot
    // build constants for; otherwise the attributes below would stay null.
    auto tensorType = dyn_cast<RankedTensorType>(inputType);
    if (!tensorType && !isa<IntegerType>(inputType)) {
      return rewriter.notifyMatchFailure(
          op, "expected an integer or ranked tensor input");
    }

    // A modulus of 0 would divide by zero below, and a modulus of 1 would
    // yield a zero-width intermediate type.
    APInt mod(64, op.getModulus());
    if (mod.ule(1)) {
      return rewriter.notifyMatchFailure(op, "modulus must be greater than 1");
    }

    // Compute B = 4^{bitWidth} and ratio = floordiv(B / modulus)
    const unsigned bitWidth = (mod - 1).getActiveBits();
    const unsigned intermediateWidth = 3 * bitWidth;
    // 3 * bitWidth exceeds 64 for moduli wider than 21 bits, so the modulus
    // may need widening rather than truncation; APInt::trunc would assert.
    mod = mod.zextOrTrunc(intermediateWidth);
    const APInt B = APInt(intermediateWidth, 1).shl(2 * bitWidth);
    const APInt barrettRatio = B.udiv(mod);
    const APInt shiftAmount(intermediateWidth, 2 * bitWidth);

    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Type intermediateType = IntegerType::get(b.getContext(), intermediateWidth);

    // Create our pre-computed constants
    TypedAttr ratioAttr, shiftAttr, modAttr;
    if (tensorType) {
      // Same shape as the input, but with the widened element type.
      auto wideTensorType =
          tensorType.clone(tensorType.getShape(), intermediateType);
      ratioAttr = DenseElementsAttr::get(wideTensorType, barrettRatio);
      shiftAttr = DenseElementsAttr::get(wideTensorType, shiftAmount);
      modAttr = DenseElementsAttr::get(wideTensorType, mod);
      intermediateType = wideTensorType;
    } else {
      ratioAttr = IntegerAttr::get(intermediateType, barrettRatio);
      shiftAttr = IntegerAttr::get(intermediateType, shiftAmount);
      modAttr = IntegerAttr::get(intermediateType, mod);
    }

    auto ratioValue = b.create<arith::ConstantOp>(intermediateType, ratioAttr);
    auto shiftValue = b.create<arith::ConstantOp>(intermediateType, shiftAttr);
    auto modValue = b.create<arith::ConstantOp>(intermediateType, modAttr);

    // Intermediate value will be in the range [0,p^3) so we need to extend to
    // 3*bitWidth
    auto extendOp = b.create<arith::ExtUIOp>(intermediateType, input);

    // Compute x - floordiv(x * ratio, B) * mod
    auto mulRatioOp = b.create<arith::MulIOp>(extendOp, ratioValue);
    auto shrOp = b.create<arith::ShRUIOp>(mulRatioOp, shiftValue);
    auto mulModOp = b.create<arith::MulIOp>(shrOp, modValue);
    auto subOp = b.create<arith::SubIOp>(extendOp, mulModOp);

    // Narrow back to the original operand type before replacing the op.
    auto truncOp = b.create<arith::TruncIOp>(inputType, subOp);

    rewriter.replaceOp(op, truncOp);

    return success();
  }
};

// Pass that lowers `arith_ext` ops to `arith`. The base class
// `impl::ArithExtToArithBase` comes from the tablegen-generated
// ArithExtToArith.h.inc included above under GEN_PASS_DEF_ARITHEXTTOARITH.
struct ArithExtToArith : impl::ArithExtToArithBase<ArithExtToArith> {
// Reuse the constructors of the generated base class.
using ArithExtToArithBase::ArithExtToArithBase;

// Defined out of line below.
void runOnOperation() override;
};

void ArithExtToArith::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();

ConversionTarget target(*context);
target.addIllegalDialect<ArithExtDialect>();
target.addLegalDialect<arith::ArithDialect>();

RewritePatternSet patterns(context);
patterns.add<rewrites::ConvertSubIfGE, ConvertBarrettReduce>(context);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
}

} // namespace arith_ext
} // namespace heir
} // namespace mlir
21 changes: 21 additions & 0 deletions lib/Conversion/ArithExtToArith/ArithExtToArith.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Declarations for the `arith-ext-to-arith` conversion pass.
//
// The include guard mirrors this header's location
// (lib/Conversion/ArithExtToArith/ArithExtToArith.h) so it cannot collide
// with a header living under lib/Dialect/ArithExt/Transforms.
#ifndef LIB_CONVERSION_ARITHEXTTOARITH_ARITHEXTTOARITH_H_
#define LIB_CONVERSION_ARITHEXTTOARITH_ARITHEXTTOARITH_H_

#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace arith_ext {

// Tablegen-generated pass declarations.
#define GEN_PASS_DECL
#include "lib/Conversion/ArithExtToArith/ArithExtToArith.h.inc"

// Tablegen-generated registration hooks; heir-opt calls
// arith_ext::registerArithExtToArithPasses().
#define GEN_PASS_REGISTRATION
#include "lib/Conversion/ArithExtToArith/ArithExtToArith.h.inc"

} // namespace arith_ext
} // namespace heir
} // namespace mlir

#endif // LIB_CONVERSION_ARITHEXTTOARITH_ARITHEXTTOARITH_H_
36 changes: 36 additions & 0 deletions lib/Conversion/ArithExtToArith/ArithExtToArith.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#ifndef LIB_CONVERSION_ARITHEXTTOARITH_ARITHEXTTOARITH_TD_
#define LIB_CONVERSION_ARITHEXTTOARITH_ARITHEXTTOARITH_TD_

include "lib/DRR/Utils.td"
include "lib/Dialect/ArithExt/IR/ArithExtOps.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/IR/PatternBase.td"
include "mlir/Pass/PassBase.td"

// Module-level pass, invoked as `arith-ext-to-arith`, that rewrites
// `arith_ext` operations into `arith` operations.
def ArithExtToArith : Pass<"arith-ext-to-arith", "ModuleOp"> {
let summary = "Lower `arith_ext` to standard `arith`.";

let description = [{
This pass lowers the `arith_ext` dialect to their `arith` equivalents.
}];

// Source dialect of the lowering and the dialect whose ops it creates.
let dependentDialects = [
"mlir::arith::ArithDialect",
"mlir::heir::arith_ext::ArithExtDialect",
];
}

// Using DRR to generate the lowering patterns for specific operations

// The `uge` case of the `arith.cmpi` predicate enum, used as a constant
// operand in the pattern below.
defvar DefGE = ConstantEnumCase<Arith_CmpIPredicateAttr, "uge">;

// Rewrites `arith_ext.subifge x, y` into three ops:
//   sub = arith.subi x, y
//   cmp = arith.cmpi uge, x, y
//   res = arith.select cmp, sub, x
// The final result pattern (the select) provides the replacement value; the
// first two are auxiliary ops it consumes.
// NOTE(review): DefOverflow is not defined in this file; presumably it comes
// from lib/DRR/Utils.td and supplies the default overflow flags - confirm.
def ConvertSubIfGE : Pattern<
(ArithExt_SubIfGEOp $x, $y),
[
(Arith_SubIOp:$subOp $x, $y, DefOverflow),
(Arith_CmpIOp:$cmpOp DefGE, $x, $y),
(SelectOp $cmpOp, $subOp, $x)
]
>;

#endif // LIB_CONVERSION_ARITHEXTTOARITH_ARITHEXTTOARITH_TD_
49 changes: 49 additions & 0 deletions lib/Conversion/ArithExtToArith/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

# C++ library containing the arith_ext -> arith conversion pass. It depends on
# the tablegen outputs produced by :pass_inc_gen.
cc_library(
    name = "ArithExtToArith",
    srcs = ["ArithExtToArith.cpp"],
    hdrs = ["ArithExtToArith.h"],
    deps = [
        ":pass_inc_gen",
        "@heir//lib/Dialect/ArithExt/IR:Dialect",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:TransformUtils",
    ],
)

# Runs mlir-tblgen over ArithExtToArith.td twice: once for the pass
# declarations (ArithExtToArith.h.inc) and once for the DRR rewrite patterns
# (ArithExtToArith.cpp.inc).
gentbl_cc_library(
    name = "pass_inc_gen",
    tbl_outs = [
        (["-gen-pass-decls", "-name=ArithExtToArith"], "ArithExtToArith.h.inc"),
        (["-gen-rewriters"], "ArithExtToArith.cpp.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "ArithExtToArith.td",
    deps = [
        "@heir//lib/DRR",
        "@heir//lib/Dialect/ArithExt/IR:ops_inc_gen",
        "@heir//lib/Dialect/ArithExt/IR:td_files",
        "@llvm-project//mlir:ArithOpsTdFiles",
        "@llvm-project//mlir:OpBaseTdFiles",
        "@llvm-project//mlir:PassBaseTdFiles",
    ],
)
72 changes: 72 additions & 0 deletions tests/arith_ext/arith-ext-to-arith.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// RUN: heir-opt -arith-ext-to-arith --split-input-file %s | FileCheck %s

// Tensor case of the subifge lowering: the op must become a subtraction, an
// unsigned >= comparison and a select that falls back to the left operand.
// CHECK-LABEL: @test_lower_subifge
// CHECK-SAME: (%[[LHS:.*]]: [[TENSOR_TYPE:.*]], %[[RHS:.*]]: [[TENSOR_TYPE]]) -> [[TENSOR_TYPE]] {
func.func @test_lower_subifge(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> {

// CHECK: %[[SUB:.*]] = arith.subi %[[LHS]], %[[RHS]] : [[TENSOR_TYPE]]
// CHECK: %[[CMP:.*]] = arith.cmpi uge, %[[LHS]], %[[RHS]] : [[TENSOR_TYPE]]
// CHECK: %[[RES:.*]] = arith.select %[[CMP]], %[[SUB]], %[[LHS]] : tensor<4xi1>, [[TENSOR_TYPE]]
%res = arith_ext.subifge %lhs, %rhs: tensor<4xi8>
return %res : tensor<4xi8>
}

// -----

// Scalar (i8) case of the subifge lowering; same three-op expansion as the
// tensor test, but the select condition is a plain scalar.
// CHECK-LABEL: @test_lower_subifge_int
// CHECK-SAME: (%[[LHS:.*]]: [[INT_TYPE:.*]], %[[RHS:.*]]: [[INT_TYPE]]) -> [[INT_TYPE]] {
func.func @test_lower_subifge_int(%lhs : i8, %rhs : i8) -> i8 {

// CHECK: %[[SUB:.*]] = arith.subi %[[LHS]], %[[RHS]] : [[INT_TYPE]]
// CHECK: %[[CMP:.*]] = arith.cmpi uge, %[[LHS]], %[[RHS]] : [[INT_TYPE]]
// CHECK: %[[RES:.*]] = arith.select %[[CMP]], %[[SUB]], %[[LHS]] : [[INT_TYPE]]
%res = arith_ext.subifge %lhs, %rhs: i8
return %res : i8
}

// -----

// Tensor case of the barrett_reduce lowering with modulus 17.
// (17 - 1) has 5 active bits, so the lowering uses a 15-bit element type,
// a shift of 2 * 5 = 10 and ratio = floor(2^10 / 17) = 60.
// Note: the capture named BITWIDTH below holds that shift amount (10), not
// the 5-bit width of the modulus.
// CHECK-LABEL: @test_lower_barrett_reduce

// CHECK-SAME: (%[[ARG:.*]]: [[TENSOR_TYPE:.*]]) -> [[TENSOR_TYPE]] {
func.func @test_lower_barrett_reduce(%arg : tensor<4xi10>) -> tensor<4xi10> {

// CHECK: %[[RATIO:.*]] = arith.constant dense<60> : [[INTER_TYPE:.*]]
// CHECK: %[[BITWIDTH:.*]] = arith.constant dense<10> : [[INTER_TYPE]]
// CHECK: %[[CMOD:.*]] = arith.constant dense<17> : [[INTER_TYPE]]

// CHECK: %[[EXT:.*]] = arith.extui %[[ARG]] : [[TENSOR_TYPE]] to [[INTER_TYPE]]
// CHECK: %[[MULRATIO:.*]] = arith.muli %[[EXT]], %[[RATIO]] : [[INTER_TYPE]]
// CHECK: %[[SHIFTED:.*]] = arith.shrui %[[MULRATIO]], %[[BITWIDTH]] : [[INTER_TYPE]]
// CHECK: %[[MULCMOD:.*]] = arith.muli %[[SHIFTED]], %[[CMOD]] : [[INTER_TYPE]]
// CHECK: %[[SUB:.*]] = arith.subi %[[EXT]], %[[MULCMOD]] : [[INTER_TYPE]]
// CHECK: %[[RES:.*]] = arith.trunci %[[SUB]] : [[INTER_TYPE]] to [[TENSOR_TYPE]]
%res = arith_ext.barrett_reduce %arg { modulus = 17 } : tensor<4xi10>

// CHECK: return %[[RES]] : [[TENSOR_TYPE]]
return %res : tensor<4xi10>
}

// -----

// Scalar (i10) case of the barrett_reduce lowering with modulus 17. The
// expected constants match the tensor test (ratio 60, shift 10, modulus 17)
// but are scalar integer constants rather than dense splats.
// Note: the capture named BITWIDTH below holds the shift amount (10).
// CHECK-LABEL: @test_lower_barrett_reduce_int
// CHECK-SAME: (%[[ARG:.*]]: [[INT_TYPE:.*]]) -> [[INT_TYPE]] {
func.func @test_lower_barrett_reduce_int(%arg : i10) -> i10 {

// CHECK: %[[RATIO:.*]] = arith.constant 60 : [[INTER_TYPE:.*]]
// CHECK: %[[BITWIDTH:.*]] = arith.constant 10 : [[INTER_TYPE]]
// CHECK: %[[CMOD:.*]] = arith.constant 17 : [[INTER_TYPE]]

// CHECK: %[[EXT:.*]] = arith.extui %[[ARG]] : [[INT_TYPE]] to [[INTER_TYPE]]
// CHECK: %[[MULRATIO:.*]] = arith.muli %[[EXT]], %[[RATIO]] : [[INTER_TYPE]]
// CHECK: %[[SHIFTED:.*]] = arith.shrui %[[MULRATIO]], %[[BITWIDTH]] : [[INTER_TYPE]]
// CHECK: %[[MULCMOD:.*]] = arith.muli %[[SHIFTED]], %[[CMOD]] : [[INTER_TYPE]]
// CHECK: %[[SUB:.*]] = arith.subi %[[EXT]], %[[MULCMOD]] : [[INTER_TYPE]]
// CHECK: %[[RES:.*]] = arith.trunci %[[SUB]] : [[INTER_TYPE]] to [[INT_TYPE]]
%res = arith_ext.barrett_reduce %arg { modulus = 17 } : i10

// CHECK: return %[[RES]] : [[INT_TYPE]]
return %res : i10
}

// -----
19 changes: 19 additions & 0 deletions tests/arith_ext/barrett_reduce_runner.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: heir-opt %s --arith-ext-to-arith --heir-polynomial-to-llvm \
// RUN: | mlir-cpu-runner -e test_lower_barrett_reduce -entry-point-result=void \
// RUN: --shared-libs="%mlir_lib_dir/libmlir_c_runner_utils%shlibext,%mlir_runner_utils" > %t
// RUN: FileCheck %s --check-prefix=CHECK_TEST_BARRETT < %t

func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }

// End-to-end run of the lowering: reduces four constant i26 coefficients with
// modulus 7681 and prints the results.
// The expected output keeps the input 7681 as 7681, so the op's result is not
// necessarily fully reduced below the modulus.
func.func @test_lower_barrett_reduce() {
%coeffs = arith.constant dense<[29498763, 58997760, 17, 7681]> : tensor<4xi26>
%1 = arith_ext.barrett_reduce %coeffs { modulus = 7681 } : tensor<4xi26>

// Widen to i32 and convert to an unranked memref so the values can be handed
// to printMemrefI32.
%2 = arith.extui %1 : tensor<4xi26> to tensor<4xi32>
%3 = bufferization.to_memref %2 : memref<4xi32>
%U = memref.cast %3 : memref<4xi32> to memref<*xi32>
func.call @printMemrefI32(%U) : (memref<*xi32>) -> ()
return
}

// CHECK_TEST_BARRETT: [3723, 7680, 17, 7681]
1 change: 1 addition & 0 deletions tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ cc_binary(
}),
includes = ["include"],
deps = [
"@heir//lib/Conversion/ArithExtToArith",
"@heir//lib/Conversion/BGVToOpenfhe",
"@heir//lib/Conversion/BGVToPolynomial",
"@heir//lib/Conversion/CGGIToTfheRust",
Expand Down
2 changes: 2 additions & 0 deletions tools/heir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <string>
#include <vector>

#include "lib/Conversion/ArithExtToArith/ArithExtToArith.h"
#include "lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.h"
#include "lib/Conversion/BGVToPolynomial/BGVToPolynomial.h"
#include "lib/Conversion/CGGIToTfheRust/CGGIToTfheRust.h"
Expand Down Expand Up @@ -532,6 +533,7 @@ int main(int argc, char **argv) {
#endif

// Dialect conversion passes in HEIR
arith_ext::registerArithExtToArithPasses();
bgv::registerBGVToPolynomialPasses();
bgv::registerBGVToOpenfhePasses();
comb::registerCombToCGGIPasses();
Expand Down

0 comments on commit 9cad9ed

Please sign in to comment.