Skip to content

Commit

Permalink
add arith_ext.mac operation
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderViand-Intel committed Jun 9, 2024
1 parent 6d19de1 commit a0caa5c
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 1 deletion.
29 changes: 29 additions & 0 deletions lib/Conversion/ArithExtToArith/ArithExtToArith.td
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,18 @@ def ConvertMulSimple : Pattern<
(addBenefit 2)
>;

def ConvertMacSimple : Pattern<
(ArithExt_MacOp:$op $x, $y, $acc, $mod),
[
(Arith_MulIOp:$mul $x, $y, DefOverflow),
(Arith_AddIOp:$add $mul, $acc, DefOverflow),
(Arith_RemUIOp $add, (Arith_ConstantOp (CastModulusAttributeMul $mod, $x)))
],
[(HasEnoughSpaceMul:$op $mod)],
[],
(addBenefit 2)
>;

def ConvertAdd : Pattern<
(ArithExt_AddOp $x, $y, $mod),
[
Expand Down Expand Up @@ -121,4 +133,21 @@ def ConvertMul : Pattern<
]
>;

def ConvertMac : Pattern<
(ArithExt_MacOp $x, $y, $acc, $mod),
[
(Arith_ConstantOp:$newmod (CastModulusAttributeMul $mod, $x)),
(Arith_MulIOp:$mul
(Arith_ExtUIOp $x,
(returnType $newmod)),
(Arith_ExtUIOp $y,
(returnType $newmod)),
DefOverflow),
(Arith_AddIOp:$add $mul,
(Arith_ExtUIOp:$extacc $acc, (returnType $newmod)), DefOverflow),
(Arith_TruncIOp:$res
(Arith_RemUIOp $add, $newmod))
]
>;

#endif // LIB_CONVERSION_ARITHEXTTOARITH_ARITHEXTTOARITH_TD_
2 changes: 2 additions & 0 deletions lib/Dialect/ArithExt/IR/ArithExtDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ LogicalResult SubOp::verify() { return verifyArithExtOpMod<SubOp>(*this); }

LogicalResult MulOp::verify() { return verifyArithExtOpMod<MulOp>(*this); }

LogicalResult MacOp::verify() { return verifyArithExtOpMod<MacOp>(*this); }

LogicalResult BarrettReduceOp::verify() {
auto inputType = getInput().getType();
unsigned bitWidth;
Expand Down
7 changes: 7 additions & 0 deletions lib/Dialect/ArithExt/IR/ArithExtOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ def ArithExt_MulOp : ArithExt_BinaryOp<"mul", [Commutative]> {
}];
}

def ArithExt_MacOp : ArithExt_Op<"mac", [SameOperandsAndResultType, Pure, ElementwiseMappable]> {
let summary = "modular multiplication-and-accumulation operation";
let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs, SignlessIntegerLike:$acc, APIntAttr:$modulus);
let results = (outs SignlessIntegerLike:$output);
let hasVerifier = 1;
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` qualified(type($output))";
}

def ArithExt_BarrettReduceOp : ArithExt_Op<"barrett_reduce", [SameOperandsAndResultType]> {
let summary = "Compute the first step of the Barrett reduction.";
Expand Down
63 changes: 63 additions & 0 deletions tests/arith_ext/arith-ext-to-arith.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,69 @@ func.func @test_lower_mul_vec(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tenso
return %res : tensor<4xi8>
}

// CHECK-LABEL: @test_lower_simple_mac
// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]], %[[ACC:.*]]: [[TYPE]]) -> [[TYPE]] {
func.func @test_lower_simple_mac(%lhs : tensor<4xi16>, %rhs : tensor<4xi16>, %acc : tensor<4xi16>) -> tensor<4xi16> {
// CHECK-NOT: arith_ext.mac
// CHECK: %[[MUL:.*]] = arith.muli %[[LHS]], %[[RHS]] : [[TYPE]]
// CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[ACC]] : [[TYPE]]
// CHECK: %[[CMOD:.*]] = arith.constant dense<17> : [[TYPE]]
// CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[TYPE]]
// CHECK: return %[[REM]] : [[TYPE]]
%res = arith_ext.mac %lhs, %rhs, %acc {modulus = 17}: tensor<4xi16>
return %res : tensor<4xi16>
}

// CHECK-LABEL: @test_lower_simple_mac_vec
// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]], %[[ACC:.*]]: [[TYPE]]) -> [[TYPE]] {
func.func @test_lower_simple_mac_vec(%lhs : i16, %rhs : i16, %acc : i16) -> i16 {
// CHECK-NOT: arith_ext.mac
// CHECK: %[[MUL:.*]] = arith.muli %[[LHS]], %[[RHS]] : [[TYPE]]
// CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[ACC]] : [[TYPE]]
// CHECK: %[[CMOD:.*]] = arith.constant 17 : [[TYPE]]
// CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[TYPE]]
// CHECK: return %[[REM]] : [[TYPE]]
%res = arith_ext.mac %lhs, %rhs, %acc{modulus = 17}: i16
return %res : i16
}

// CHECK-LABEL: @test_lower_mac
// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]], %[[ACC:.*]]: [[TYPE]]) -> [[TYPE]] {
func.func @test_lower_mac(%lhs : i8, %rhs : i8, %acc : i8) -> i8 {
// CHECK-NOT: arith_ext.mac
// CHECK: %[[CMOD:.*]] = arith.constant 217 : [[INTERMEDIATE_TYPE:.*]]
// CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[MUL:.*]] = arith.muli %[[EXT0]], %[[EXT1]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[EXT2:.*]] = arith.extui %[[ACC]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[EXT2]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[INTERMEDIATE_TYPE]] to [[TYPE]]
// CHECK: return %[[TRUNC]] : [[TYPE]]
%res = arith_ext.mac %lhs, %rhs, %acc {modulus = 217 }: i8
return %res : i8
}

// CHECK-LABEL: @test_lower_mac_vec
// CHECK-SAME: (%[[LHS:.*]]: [[TYPE:.*]], %[[RHS:.*]]: [[TYPE]], %[[ACC:.*]]: [[TYPE]]) -> [[TYPE]] {
func.func @test_lower_mac_vec(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>, %acc : tensor<4xi8>) -> tensor<4xi8> {
// CHECK-NOT: arith_ext.mac
// CHECK: %[[CMOD:.*]] = arith.constant dense<217> : [[INTERMEDIATE_TYPE:.*]]
// CHECK: %[[EXT0:.*]] = arith.extui %[[LHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[EXT1:.*]] = arith.extui %[[RHS]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[MUL:.*]] = arith.muli %[[EXT0]], %[[EXT1]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[EXT2:.*]] = arith.extui %[[ACC]] : [[TYPE]] to [[INTERMEDIATE_TYPE]]
// CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[EXT2]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[REM:.*]] = arith.remui %[[ADD]], %[[CMOD]] : [[INTERMEDIATE_TYPE]]
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[REM]] : [[INTERMEDIATE_TYPE]] to [[TYPE]]
// CHECK: return %[[TRUNC]] : [[TYPE]]
%res = arith_ext.mac %lhs, %rhs, %acc {modulus = 217 }: tensor<4xi8>
return %res : tensor<4xi8>
}


// -----

// CHECK-LABEL: @test_lower_subifge
// CHECK-SAME: (%[[LHS:.*]]: [[TENSOR_TYPE:.*]], %[[RHS:.*]]: [[TENSOR_TYPE]]) -> [[TENSOR_TYPE]] {
func.func @test_lower_subifge(%lhs : tensor<4xi8>, %rhs : tensor<4xi8>) -> tensor<4xi8> {
Expand Down
9 changes: 8 additions & 1 deletion tests/arith_ext/syntax.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
// CHECK-LABEL: @test_arith_syntax
func.func @test_arith_syntax() {
%zero = arith.constant 1 : i10
%c4 = arith.constant 4 : i10
%c5 = arith.constant 5 : i10
%c6 = arith.constant 6 : i10
%cmod = arith.constant 17 : i10
%c_vec = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi10>
%c_vec2 = arith.constant dense<[4,3,2,1]> : tensor<4xi10>
%c_vec2 = arith.constant dense<[4, 3, 2, 1]> : tensor<4xi10>
%c_vec3 = arith.constant dense<[1, 1, 1, 1]> : tensor<4xi10>
%cmod_vec = arith.constant dense<17> : tensor<4xi10>

// CHECK: arith_ext.add
Expand All @@ -25,6 +27,11 @@ func.func @test_arith_syntax() {
%mul = arith_ext.mul %c5, %c6 { modulus = 17 } : i10
%mul_vec = arith_ext.mul %c_vec, %c_vec2 { modulus = 17 } : tensor<4xi10>

// CHECK: arith_ext.mac
// CHECK: arith_ext.mac
%mac = arith_ext.mac %c5, %c6, %c4 { modulus = 17 } : i10
%mac_vec = arith_ext.mac %c_vec, %c_vec2, %c_vec3 { modulus = 17 } : tensor<4xi10>

// CHECK: arith_ext.barrett_reduce
// CHECK: arith_ext.barrett_reduce
%barrett = arith_ext.barrett_reduce %zero { modulus = 17 } : i10
Expand Down

0 comments on commit a0caa5c

Please sign in to comment.