Skip to content

Commit

Permalink
add lowering of arith_ext.add/sub/mul to arith
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderViand-Intel committed Jun 7, 2024
1 parent 53d71b7 commit 7f7743e
Show file tree
Hide file tree
Showing 3 changed files with 269 additions and 2 deletions.
18 changes: 17 additions & 1 deletion lib/Conversion/ArithExtToArith/ArithExtToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "lib/Dialect/ArithExt/IR/ArithExtOps.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project

Expand All @@ -13,6 +14,20 @@ namespace arith_ext {
#define GEN_PASS_DEF_ARITHEXTTOARITH
#include "lib/Conversion/ArithExtToArith/ArithExtToArith.h.inc"

template <typename ValueOrOpResult>
TypedAttr modulusHelper(IntegerAttr mod, ValueOrOpResult op, bool mul = false) {
auto width = getElementTypeOrSelf(op).getIntOrFloatBitWidth();
auto modWidth = (mod.getValue() - 1).getActiveBits();
width = std::max(width, mul ? 2 * modWidth : modWidth + 1);
auto intType = IntegerType::get(op.getContext(), width);
auto truncmod = mod.getValue().sextOrTrunc(width);
if (auto st = mlir::dyn_cast_or_null<ShapedType>(op.getType())) {
auto containerType = st.cloneWith(st.getShape(), intType);
return DenseElementsAttr::get(containerType, truncmod);
}
return IntegerAttr::get(intType, truncmod);
}

namespace rewrites {
// In an inner namespace to avoid conflicts with canonicalization patterns
#include "lib/Conversion/ArithExtToArith/ArithExtToArith.cpp.inc"
Expand Down Expand Up @@ -92,7 +107,8 @@ void ArithExtToArith::runOnOperation() {
target.addLegalDialect<arith::ArithDialect>();

RewritePatternSet patterns(context);
patterns.add<rewrites::ConvertSubIfGE, ConvertBarrettReduce>(context);
rewrites::populateWithGenerated(patterns);
patterns.add<ConvertBarrettReduce>(context);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
Expand Down
89 changes: 89 additions & 0 deletions lib/Conversion/ArithExtToArith/ArithExtToArith.td
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,93 @@ def ConvertSubIfGE : Pattern<
]
>;


def HasEnoughSpaceAddSub: Constraint<CPred<"llvm::cast<IntegerType>(getElementTypeOrSelf($_self.getType())).getWidth() >= ($0.getValue() - 1).getActiveBits() + 1">,
"underlying type is sufficient for modular add/sub operation without overflow">;

def HasEnoughSpaceMul: Constraint<CPred<"llvm::cast<IntegerType>(getElementTypeOrSelf($_self.getType())).getWidth() >= 2 * ($0.getValue() - 1).getActiveBits()">,
"underlying type is sufficient for modular mul operation without overflow">;

def CastModulusAttributeAddSub : NativeCodeCall<"modulusHelper($0,$1,false)">;
def CastModulusAttributeMul : NativeCodeCall<"modulusHelper($0,$1,true)">;

def ConvertAddSimple : Pattern<
(ArithExt_AddOp:$op $x, $y, $mod),
[
(Arith_AddIOp:$add $x, $y, DefOverflow),
(Arith_RemUIOp $add, (Arith_ConstantOp (CastModulusAttributeAddSub $mod, $x)))
],
[(HasEnoughSpaceAddSub:$op $mod)],
[],
(addBenefit 2)
>;

def ConvertSubSimple : Pattern<
(ArithExt_SubOp:$op $x, $y, $mod),
[
(Arith_SubIOp:$add $x, $y, DefOverflow),
(Arith_RemUIOp $add, (Arith_ConstantOp (CastModulusAttributeAddSub $mod, $x)))
],
[(HasEnoughSpaceAddSub:$op $mod)],
[],
(addBenefit 2)
>;

def ConvertMulSimple : Pattern<
(ArithExt_MulOp:$op $x, $y, $mod),
[
(Arith_MulIOp:$add $x, $y, DefOverflow),
(Arith_RemUIOp $add, (Arith_ConstantOp (CastModulusAttributeMul $mod, $x)))
],
[(HasEnoughSpaceMul:$op $mod)],
[],
(addBenefit 2)
>;


def ConvertAdd : Pattern<
(ArithExt_AddOp $x, $y, $mod),
[
(Arith_ConstantOp:$newmod (CastModulusAttributeAddSub $mod, $x)),
(Arith_AddIOp:$add
(Arith_ExtUIOp $x,
(returnType $newmod)),
(Arith_ExtUIOp $y,
(returnType $newmod)),
DefOverflow),
(Arith_TruncIOp:$res
(Arith_RemUIOp $add, $newmod))
]
>;

def ConvertSub : Pattern<
(ArithExt_SubOp $x, $y, $mod),
[
(Arith_ConstantOp:$newmod (CastModulusAttributeAddSub $mod, $x)),
(Arith_SubIOp:$add
(Arith_ExtUIOp $x,
(returnType $newmod)),
(Arith_ExtUIOp $y,
(returnType $newmod)),
DefOverflow),
(Arith_TruncIOp:$res
(Arith_RemUIOp $add, $newmod))
]
>;

def ConvertMul : Pattern<
(ArithExt_MulOp $x, $y, $mod),
[
(Arith_ConstantOp:$newmod (CastModulusAttributeMul $mod, $x)),
(Arith_MulIOp:$add
(Arith_ExtUIOp $x,
(returnType $newmod)),
(Arith_ExtUIOp $y,
(returnType $newmod)),
DefOverflow),
(Arith_TruncIOp:$res
(Arith_RemUIOp $add, $newmod))
]
>;

#endif // LIB_CONVERSION_ARITHEXTTOARITH_ARITHEXTTOARITH_TD_
164 changes: 163 additions & 1 deletion tests/arith_ext/arith-ext-to-arith.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,166 @@
// RUN: heir-opt -arith-ext-to-arith --split-input-file %s | FileCheck %s
// RUN: heir-opt -arith-ext-to-arith --split-input-file %s | FileCheck %s --enable-var-scope

// CHECK-LABEL: @test_lower_simple_add
// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] {
func.func @test_lower_simple_add(%lhs : i8, %rhs : i8) -> i8 {
// CHECK-NOT: arith_ext.add
// CHECK: %[[ADD:.*]] = arith.addi %[[LHS]], %[[RHS]] : [[TYPE]]
// CHECK: %[[CMOD:.*]] = arith.constant 17 : [[TYPE]]
// CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[TYPE]]
// CHECK: return %[[REM]] : [[TYPE]]
%res = arith_ext.add %lhs, %rhs {modulus = 17 }: i8
return %res : i8
}

// CHECK-LABEL: @test_lower_simple_add_vec
// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] {
func.func @test_lower_simple_add_vec(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> {
// CHECK-NOT: arith_ext.add
// CHECK: %[[ADD:.*]] = arith.addi %[[LHS]], %[[RHS]] : [[TYPE]]
// CHECK: %[[CMOD:.*]] = arith.constant dense<17> : [[TYPE]]
// CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[TYPE]]
// CHECK: return %[[REM]] : [[TYPE]]
%res = arith_ext.add %lhs, %rhs {modulus = 17}: tensor<4xi8>
return %res : tensor<4xi8>
}

// CHECK-LABEL: @test_lower_add
// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] {
func.func @test_lower_add(%lhs : i8, %rhs : i8) -> i8 {
// CHECK-NOT: arith_ext.add
// CHECK: %[[CMOD:.*]] = arith.constant 217 : [[INTERMEDIATE_TYPE:.*]]
// CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[ADD:.*]] = arith.addi %[[EXT0]], %[[EXT1]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[INTERMEDIATE_TYPE]] to [[TYPE]]
// CHECK: return %[[TRUNC]] : [[TYPE]]
%res = arith_ext.add %lhs, %rhs {modulus = 217 }: i8
return %res : i8
}

// CHECK-LABEL: @test_lower_add_vec
// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] {
func.func @test_lower_add_vec(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> {
// CHECK-NOT: arith_ext.add
// CHECK: %[[CMOD:.*]] = arith.constant dense<217> : [[INTERMEDIATE_TYPE:.*]]
// CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[ADD:.*]] = arith.addi %[[EXT0]], %[[EXT1]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[INTERMEDIATE_TYPE]] to [[TYPE]]
// CHECK: return %[[TRUNC]] : [[TYPE]]
%res = arith_ext.add %lhs, %rhs {modulus = 217 }: tensor<4xi8>
return %res : tensor<4xi8>
}

// CHECK-LABEL: @test_lower_simple_sub
// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] {
func.func @test_lower_simple_sub(%lhs : i8, %rhs : i8) -> i8 {
// CHECK-NOT: arith_ext.sub
// CHECK: %[[SUB:.*]] = arith.subi %[[LHS]], %[[RHS]] : [[TYPE]]
// CHECK: %[[CMOD:.*]] = arith.constant 17 : [[TYPE]]
// CHECK: %[[REM:.*]] = arith.remui %[[SUB]], %[[CMOD]] : [[TYPE]]
// CHECK: return %[[REM]] : [[TYPE]]
%res = arith_ext.sub %lhs, %rhs {modulus = 17}: i8
return %res : i8
}

// CHECK-LABEL: @test_lower_simple_sub_vec
// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] {
func.func @test_lower_simple_sub_vec(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> {
// CHECK-NOT: arith_ext.sub
// CHECK: %[[SUB:.*]] = arith.subi %[[LHS]], %[[RHS]] : [[TYPE]]
// CHECK: %[[CMOD:.*]] = arith.constant dense<17> : [[TYPE]]
// CHECK: %[[REM:.*]] = arith.remui %[[SUB]], %[[CMOD]] : [[TYPE]]
// CHECK: return %[[REM]] : [[TYPE]]
%res = arith_ext.sub %lhs, %rhs {modulus = 17}: tensor<4xi8>
return %res : tensor<4xi8>
}

// CHECK-LABEL: @test_lower_sub
// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] {
func.func @test_lower_sub(%lhs : i8, %rhs : i8) -> i8 {
// CHECK-NOT: arith_ext.sub
// CHECK: %[[CMOD:.*]] = arith.constant 217 : [[INTERMEDIATE_TYPE:.*]]
// CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[SUB:.*]] = arith.subi %[[EXT0]], %[[EXT1]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[REM:.*]] = arith.remui %[[SUB]], %[[CMOD]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[INTERMEDIATE_TYPE]] to [[TYPE]]
// CHECK: return %[[TRUNC]] : [[TYPE]]
%res = arith_ext.sub %lhs, %rhs {modulus = 217 }: i8
return %res : i8
}

// CHECK-LABEL: @test_lower_sub_vec
// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] {
func.func @test_lower_sub_vec(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> {
// CHECK-NOT: arith_ext.sub
// CHECK: %[[CMOD:.*]] = arith.constant dense<217> : [[INTERMEDIATE_TYPE:.*]]
// CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[SUB:.*]] = arith.subi %[[EXT0]], %[[EXT1]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[REM:.*]] = arith.remui %[[SUB]], %[[CMOD]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[INTERMEDIATE_TYPE]] to [[TYPE]]
// CHECK: return %[[TRUNC]] : [[TYPE]]
%res = arith_ext.sub %lhs, %rhs {modulus = 217 }: tensor<4xi8>
return %res : tensor<4xi8>
}

// CHECK-LABEL: @test_lower_simple_mul
// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] {
func.func @test_lower_simple_mul(%lhs : i16, %rhs : i16) -> i16 {
// CHECK-NOT: arith_ext.mul
// CHECK: %[[MUL:.*]] = arith.muli %[[LHS]], %[[RHS]] : [[TYPE]]
// CHECK: %[[CMOD:.*]] = arith.constant 17 : [[TYPE]]
// CHECK: %[[REM:.*]] = arith.remui %[[MUL]], %[[CMOD]] : [[TYPE]]
// CHECK: return %[[REM]] : [[TYPE]]
%res = arith_ext.mul %lhs, %rhs {modulus = 17}: i16
return %res : i16
}

// CHECK-LABEL: @test_lower_simple_mul_vec
// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] {
func.func @test_lower_simple_mul_vec(%lhs : tensor<4xi16>, %rhs : tensor<4xi16>) -> tensor<4xi16> {
// CHECK-NOT: arith_ext.mul
// CHECK: %[[MUL:.*]] = arith.muli %[[LHS]], %[[RHS]] : [[TYPE]]
// CHECK: %[[CMOD:.*]] = arith.constant dense<17> : [[TYPE]]
// CHECK: %[[REM:.*]] = arith.remui %[[MUL]], %[[CMOD]] : [[TYPE]]
// CHECK: return %[[REM]] : [[TYPE]]
%res = arith_ext.mul %lhs, %rhs {modulus = 17}: tensor<4xi16>
return %res : tensor<4xi16>
}

// CHECK-LABEL: @test_lower_mul
// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] {
func.func @test_lower_mul(%lhs : i8, %rhs : i8) -> i8 {
// CHECK-NOT: arith_ext.mul
// CHECK: %[[CMOD:.*]] = arith.constant 217 : [[INTERMEDIATE_TYPE:.*]]
// CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[MUL:.*]] = arith.muli %[[EXT0]], %[[EXT1]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[REM:.*]] = arith.remui %[[MUL]], %[[CMOD]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[INTERMEDIATE_TYPE]] to [[TYPE]]
// CHECK: return %[[TRUNC]] : [[TYPE]]
%res = arith_ext.mul %lhs, %rhs {modulus = 217 }: i8
return %res : i8
}

// CHECK-LABEL: @test_lower_mul_vec
// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]]) -> [[TYPE]] {
func.func @test_lower_mul_vec(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> {
// CHECK-NOT: arith_ext.mul
// CHECK: %[[CMOD:.*]] = arith.constant dense<217> : [[INTERMEDIATE_TYPE:.*]]
// CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[MUL:.*]] = arith.muli %[[EXT0]], %[[EXT1]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[REM:.*]] = arith.remui %[[MUL]], %[[CMOD]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[INTERMEDIATE_TYPE]] to [[TYPE]]
// CHECK: return %[[TRUNC]] : [[TYPE]]
%res = arith_ext.mul %lhs, %rhs {modulus = 217 }: tensor<4xi8>
return %res : tensor<4xi8>
}

// CHECK-LABEL: @test_lower_subifge
// CHECK-SAME: (%[[LHS:.*]]: [[TENSOR_TYPE:.*]], %[[RHS:.*]]: [[TENSOR_TYPE]]) -> [[TENSOR_TYPE]] {
Expand Down

0 comments on commit 7f7743e

Please sign in to comment.