Skip to content

Commit

Permalink
bgv/openfhe: update rotation index to be an integerattr
Browse files Browse the repository at this point in the history
This change updates the rotation index / offset to be an attribute rather than an SSA value, reflecting the requirement that we must have a statically known constant rotation for the RLWE rotation operation.

We can also lower uknown tensor_ext rotations like a blind rotate or PIR style operation, but that is future work if we need it.

Fixes #741

I'll rebase this on top of #696

PiperOrigin-RevId: 644401397
  • Loading branch information
asraa authored and Copybara-Service committed Jun 20, 2024
1 parent b10e7f3 commit 62b2671
Show file tree
Hide file tree
Showing 15 changed files with 70 additions and 99 deletions.
28 changes: 4 additions & 24 deletions lib/Conversion/BGVToOpenfhe/BGVToOpenfhe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,29 +180,9 @@ struct ConvertRotateOp : public OpConversionPattern<RotateOp> {
if (failed(result)) return result;

Value cryptoContext = result.value();
Value castOffset =
llvm::TypeSwitch<Type, Value>(adaptor.getOffset().getType())
.Case<IndexType>([&](auto ty) {
return rewriter
.create<arith::IndexCastOp>(
op.getLoc(), rewriter.getI64Type(), adaptor.getOffset())
.getResult();
})
.Case<IntegerType>([&](IntegerType ty) {
if (ty.getWidth() < 64) {
return rewriter
.create<arith::ExtUIOp>(op.getLoc(), rewriter.getI64Type(),
adaptor.getOffset())
.getResult();
}
return rewriter
.create<arith::TruncIOp>(op.getLoc(), rewriter.getI64Type(),
adaptor.getOffset())
.getResult();
});
rewriter.replaceOp(
op, rewriter.create<openfhe::RotOp>(op.getLoc(), cryptoContext,
adaptor.getInput(), castOffset));
rewriter.replaceOp(op, rewriter.create<openfhe::RotOp>(
op.getLoc(), cryptoContext, adaptor.getInput(),
adaptor.getOffset()));
return success();
}
};
Expand Down Expand Up @@ -376,7 +356,7 @@ struct ConvertExtractOp : public OpConversionPattern<ExtractOp> {
auto plainMul =
b.create<bgv::MulPlainOp>(adaptor.getInput(), oneHotPlaintext)
.getResult();
auto rotated = b.create<bgv::RotateOp>(plainMul, adaptor.getOffset());
auto rotated = b.create<bgv::RotateOp>(plainMul, offsetAttr);
// It might make sense to move this op to the add-client-interface pass,
// but it also seems like an implementation detail of OpenFHE, and not part
// of BGV generally.
Expand Down
24 changes: 22 additions & 2 deletions lib/Conversion/SecretToBGV/SecretToBGV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "lib/Dialect/Secret/IR/SecretTypes.h"
#include "lib/Dialect/TensorExt/IR/TensorExtOps.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" // from @llvm-project
Expand Down Expand Up @@ -155,6 +156,25 @@ class SecretGenericOpMulConversion
}
};

class SecretGenericOpRotateConversion
: public SecretGenericOpConversion<tensor_ext::RotateOp, bgv::RotateOp> {
public:
using SecretGenericOpConversion<tensor_ext::RotateOp,
bgv::RotateOp>::SecretGenericOpConversion;

void replaceOp(secret::GenericOp op, TypeRange outputTypes, ValueRange inputs,
ConversionPatternRewriter &rewriter) const override {
// Check that the offset is a constant.
auto offset = inputs[1];
auto constantOffset = dyn_cast<arith::ConstantOp>(offset.getDefiningOp());
if (!constantOffset) {
op.emitError("expected constant offset for rotate");
}
auto offsetAttr = llvm::dyn_cast<IntegerAttr>(constantOffset.getValue());
rewriter.replaceOpWithNewOp<bgv::RotateOp>(op, inputs[0], offsetAttr);
}
};

struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {
using SecretToBGVBase::SecretToBGVBase;

Expand Down Expand Up @@ -201,8 +221,8 @@ struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {
patterns.add<SecretGenericOpConversion<arith::AddIOp, bgv::AddOp>,
SecretGenericOpConversion<arith::SubIOp, bgv::SubOp>,
SecretGenericOpConversion<tensor::ExtractOp, bgv::ExtractOp>,
SecretGenericOpConversion<tensor_ext::RotateOp, bgv::RotateOp>,
SecretGenericOpMulConversion>(typeConverter, context);
SecretGenericOpRotateConversion, SecretGenericOpMulConversion>(
typeConverter, context);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
return signalPassFailure();
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/BGV/IR/BGVOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,15 @@ def BGV_RotateOp : BGV_Op<"rotate", [AllTypesMatch<["input", "output"]>]> {

let arguments = (ins
RLWECiphertext:$input,
SignlessIntegerLike:$offset
Builtin_IntegerAttr:$offset
);

let results = (outs
RLWECiphertext:$output
);

let hasVerifier = 1;
let assemblyFormat = "operands attr-dict `:` qualified(type($input)) `,` type($offset)" ;
let assemblyFormat = "operands attr-dict `:` qualified(type($input))" ;
}

def BGV_ExtractOp : BGV_Op<"extract", [SameOperandsAndResultRings]> {
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Openfhe/IR/OpenfheOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def RotOp : Openfhe_Op<"rot",[
let arguments = (ins
Openfhe_CryptoContext:$cryptoContext,
RLWECiphertext:$ciphertext,
I64:$index
Builtin_IntegerAttr:$index
);
let results = (outs RLWECiphertext:$output);
}
Expand Down
6 changes: 1 addition & 5 deletions lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,7 @@ bool hasMulOp(func::FuncOp op) {
SmallVector<int64_t> findAllRotIndices(func::FuncOp op) {
std::set<int64_t> distinctRotIndices;
op.walk([&](openfhe::RotOp rotOp) {
auto indexAttr =
dyn_cast<arith::ConstantOp>(rotOp.getIndex().getDefiningOp())
.getValue();
int64_t rotIndex = dyn_cast<IntegerAttr>(indexAttr).getInt();
distinctRotIndices.insert(rotIndex);
distinctRotIndices.insert(rotOp.getIndex().getInt());
return WalkResult::advance();
});
SmallVector<int64_t> rotIndicesResult(distinctRotIndices.begin(),
Expand Down
9 changes: 7 additions & 2 deletions lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,13 @@ LogicalResult OpenFhePkeEmitter::printOperation(LevelReduceOp op) {
}

LogicalResult OpenFhePkeEmitter::printOperation(RotOp op) {
return printEvalMethod(op.getResult(), op.getCryptoContext(),
{op.getCiphertext(), op.getIndex()}, "EvalRotate");
emitAutoAssignPrefix(op.getResult());

os << variableNames->getNameForValue(op.getCryptoContext()) << "->"
<< "EvalRotate" << "("
<< variableNames->getNameForValue(op.getCiphertext()) << ", "
<< op.getIndex().getValue() << ");\n";
return success();
}

LogicalResult OpenFhePkeEmitter::printOperation(AutomorphOp op) {
Expand Down
15 changes: 5 additions & 10 deletions tests/bgv/add_client_interface.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,16 @@
!out_ty = !lwe.rlwe_ciphertext<encoding = #lwe.polynomial_evaluation_encoding<cleartext_start = 16, cleartext_bitwidth = 16>, rlwe_params = <ring = <coefficientType = i32, coefficientModulus = 463187969 : i32, polynomialModulus=#polynomial.int_polynomial<1 + x**32>>>, underlying_type = i16>

func.func @simple_sum(%arg0: !in_ty) -> !out_ty {
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c1 = arith.constant 1 : index
%c31 = arith.constant 31 : index
%0 = bgv.rotate %arg0, %c16 : !in_ty, index
%0 = bgv.rotate %arg0 { offset = 16 } : !in_ty
%1 = bgv.add %arg0, %0 : !in_ty
%2 = bgv.rotate %1, %c8 : !in_ty, index
%2 = bgv.rotate %1 { offset = 8 } : !in_ty
%3 = bgv.add %1, %2 : !in_ty
%4 = bgv.rotate %3, %c4 : !in_ty, index
%4 = bgv.rotate %3 { offset = 4 } : !in_ty
%5 = bgv.add %3, %4 : !in_ty
%6 = bgv.rotate %5, %c2 : !in_ty, index
%6 = bgv.rotate %5 { offset = 2 } : !in_ty
%7 = bgv.add %5, %6 : !in_ty
%8 = bgv.rotate %7, %c1 : !in_ty, index
%8 = bgv.rotate %7 { offset = 1 } : !in_ty
%9 = bgv.add %7, %8 : !in_ty
%10 = bgv.extract %9, %c31 : (!in_ty, index) -> !out_ty
return %10 : !out_ty
Expand Down
15 changes: 5 additions & 10 deletions tests/bgv/add_client_interface_public_key.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,16 @@
!out_ty = !lwe.rlwe_ciphertext<encoding = #encoding, rlwe_params = #params, underlying_type = i16>

func.func @simple_sum(%arg0: !in_ty) -> !out_ty {
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c1 = arith.constant 1 : index
%c31 = arith.constant 31 : index
%0 = bgv.rotate %arg0, %c16 : !in_ty, index
%0 = bgv.rotate %arg0 { offset = 16 } : !in_ty
%1 = bgv.add %arg0, %0 : !in_ty
%2 = bgv.rotate %1, %c8 : !in_ty, index
%2 = bgv.rotate %1 { offset = 8 } : !in_ty
%3 = bgv.add %1, %2 : !in_ty
%4 = bgv.rotate %3, %c4 : !in_ty, index
%4 = bgv.rotate %3 { offset = 4 } : !in_ty
%5 = bgv.add %3, %4 : !in_ty
%6 = bgv.rotate %5, %c2 : !in_ty, index
%6 = bgv.rotate %5 { offset = 2 } : !in_ty
%7 = bgv.add %5, %6 : !in_ty
%8 = bgv.rotate %7, %c1 : !in_ty, index
%8 = bgv.rotate %7 { offset = 1 } : !in_ty
%9 = bgv.add %7, %8 : !in_ty
%10 = bgv.extract %9, %c31 : (!in_ty, index) -> !out_ty
return %10 : !out_ty
Expand Down
9 changes: 3 additions & 6 deletions tests/bgv/add_client_interface_split.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,14 @@
!mul_ty = !lwe.rlwe_ciphertext<encoding = #lwe.polynomial_evaluation_encoding<cleartext_start = 16, cleartext_bitwidth = 16>, rlwe_params = <dimension = 3, ring = <coefficientType = i32, coefficientModulus = 463187969 : i32, polynomialModulus=#polynomial.int_polynomial<1 + x**8>>>, underlying_type = tensor<8xi16>>

func.func @dot_product(%arg0: !in_ty, %arg1: !in_ty) -> (!out_ty, !out_ty) {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%c7 = arith.constant 7 : index
%0 = bgv.mul %arg0, %arg1 : (!in_ty, !in_ty) -> !mul_ty
%1 = bgv.relinearize %0 {from_basis = array<i32: 0, 1, 2>, to_basis = array<i32: 0, 1>} : !mul_ty -> !in_ty
%2 = bgv.rotate %1, %c4 : !in_ty, index
%2 = bgv.rotate %1 { offset = 4 } : !in_ty
%3 = bgv.add %1, %2 : !in_ty
%4 = bgv.rotate %3, %c2 : !in_ty, index
%4 = bgv.rotate %3 { offset = 2 } : !in_ty
%5 = bgv.add %3, %4 : !in_ty
%6 = bgv.rotate %5, %c1 : !in_ty, index
%6 = bgv.rotate %5 { offset = 1 } : !in_ty
%7 = bgv.add %5, %6 : !in_ty
%8 = bgv.extract %7, %c7 : (!in_ty, index) -> !out_ty
return %8, %8 : !out_ty, !out_ty
Expand Down
3 changes: 1 addition & 2 deletions tests/bgv/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ module {
// CHECK-LABEL: @test_rotate_extract
func.func @test_rotate_extract(%arg3: !ct_tensor) -> !ct_scalar {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%add = bgv.rotate %arg3, %c1 : !ct_tensor, index
%add = bgv.rotate %arg3 { offset = 1 } : !ct_tensor
%ext = bgv.extract %add, %c0 : (!ct_tensor, index) -> !ct_scalar
// CHECK: rlwe_params = <ring = <coefficientType = i32, coefficientModulus = 161729713 : i32, polynomialModulus = <1 + x**1024>>>
return %ext : !ct_scalar
Expand Down
3 changes: 1 addition & 2 deletions tests/bgv/verifier.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
!ct1 = !lwe.rlwe_ciphertext<encoding=#encoding, rlwe_params=#params1, underlying_type=i3>

func.func @test_input_dimension_error(%input: !ct) {
%offset = arith.constant 4 : index
// expected-error@+1 {{x.dim == 2 does not hold}}
%out = bgv.rotate %input, %offset : !ct, index
%out = bgv.rotate %input { offset = 4 } : !ct
return
}

Expand Down
8 changes: 3 additions & 5 deletions tests/bgv_to_openfhe/bgv_to_openfhe.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,9 @@ module {
%sub = bgv.sub %x, %y : !ct
// CHECK: %[[v4:.*]] = openfhe.mul_no_relin [[C]], %[[x4:.*]], %[[y4:.*]]: ([[S]], [[T]], [[T]]) -> [[T2:.*]]
%mul = bgv.mul %x, %y : (!ct, !ct) -> !ct_level3
// CHECK: %[[c5:.*]] = arith.index_cast
// CHECK-SAME: to i64
%c4 = arith.constant 4 : index
// CHECK: %[[v5:.*]] = openfhe.rot [[C]], %[[x5:.*]], %[[c5:.*]]: ([[S]], [[T]], i64) -> [[T]]
%rot = bgv.rotate %x, %c4 : !ct, index
// CHECK: %[[v5:.*]] = openfhe.rot [[C]], %[[x5:.*]] {index = 4 : i64}
// CHECK-SAME: ([[S]], [[T]]) -> [[T]]
%rot = bgv.rotate %x { offset = 4 } : !ct
return
}

Expand Down
12 changes: 6 additions & 6 deletions tests/openfhe/configure_crypto_context.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@ func.func @simple_sum(%arg0: !ctxt_ty, %arg1: !in_ty) -> !out_ty {
%c8_i64 = arith.constant 8 : i64
%cst = arith.constant dense<[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]> : tensor<32xi16>
%c16_i64 = arith.constant 16 : i64
%0 = openfhe.rot %arg0, %arg1, %c16_i64 : (!ctxt_ty, !in_ty, i64) -> !in_ty
%0 = openfhe.rot %arg0, %arg1 { index = 16 } : (!ctxt_ty, !in_ty) -> !in_ty
%1 = openfhe.add %arg0, %arg1, %0 : (!ctxt_ty, !in_ty, !in_ty) -> !in_ty
%2 = openfhe.rot %arg0, %1, %c8_i64 : (!ctxt_ty, !in_ty, i64) -> !in_ty
%2 = openfhe.rot %arg0, %1 { index = 8 } : (!ctxt_ty, !in_ty) -> !in_ty
%3 = openfhe.add %arg0, %1, %2 : (!ctxt_ty, !in_ty, !in_ty) -> !in_ty
%4 = openfhe.rot %arg0, %3, %c4_i64 : (!ctxt_ty, !in_ty, i64) -> !in_ty
%4 = openfhe.rot %arg0, %3 { index = 4 } : (!ctxt_ty, !in_ty) -> !in_ty
%5 = openfhe.add %arg0, %3, %4 : (!ctxt_ty, !in_ty, !in_ty) -> !in_ty
%6 = openfhe.rot %arg0, %5, %c2_i64 : (!ctxt_ty, !in_ty, i64) -> !in_ty
%6 = openfhe.rot %arg0, %5 { index = 2 } : (!ctxt_ty, !in_ty) -> !in_ty
%7 = openfhe.add %arg0, %5, %6 : (!ctxt_ty, !in_ty, !in_ty) -> !in_ty
%8 = openfhe.rot %arg0, %7, %c1_i64 : (!ctxt_ty, !in_ty, i64) -> !in_ty
%8 = openfhe.rot %arg0, %7 { index = 1 } : (!ctxt_ty, !in_ty) -> !in_ty
%9 = openfhe.add %arg0, %7, %8 : (!ctxt_ty, !in_ty, !in_ty) -> !in_ty
%10 = lwe.rlwe_encode %cst {encoding = #encoding, ring = #ring} : tensor<32xi16> -> !plain_ty
%11 = openfhe.mul_plain %arg0, %9, %10 : (!ctxt_ty, !in_ty, !plain_ty) -> !in_ty
%12 = openfhe.rot %arg0, %11, %c31_i64 : (!ctxt_ty, !in_ty, i64) -> !in_ty
%12 = openfhe.rot %arg0, %11 { index = 31 } : (!ctxt_ty, !in_ty) -> !in_ty
%13 = lwe.reinterpret_underlying_type %12 : !in_ty to !out_ty
return %13 : !out_ty
}
Expand Down
28 changes: 8 additions & 20 deletions tests/openfhe/emit_openfhe_pke.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
// CHECK-NEXT: const auto& [[v10:.*]] = [[CC]]->Relinearize([[v9]]);
// CHECK-NEXT: const auto& [[v11:.*]] = [[CC]]->ModReduce([[v10]]);
// CHECK-NEXT: const auto& [[v12:.*]] = [[CC]]->LevelReduce([[v11]]);
// CHECK-NEXT: const auto& [[v13:.*]] = [[CC]]->EvalRotate([[v12]], [[const]]);
// CHECK-NEXT: const auto& [[v13:.*]] = [[CC]]->EvalRotate([[v12]], 4);
// CHECK-NEXT: std::map<uint32_t, EvalKeyT> [[v14_evalkeymap:.*]] = {{[{][{]}}0, [[ARG4]]{{[}][}]}};
// CHECK-NEXT: const auto& [[v14:.*]] = [[CC]]->EvalAutomorphism([[v13]], 0, [[v14_evalkeymap]]);
// CHECK-NEXT: const auto& [[v15:.*]] = [[CC]]->KeySwitch([[v14]], [[ARG4]]);
Expand All @@ -46,7 +46,7 @@ func.func @test_basic_emitter(%cc : !cc, %input1 : !ct, %input2 : !ct, %input3:
%relin_res = openfhe.relin %cc, %mul_const_res : (!cc, !ct) -> !ct
%mod_reduce_res = openfhe.mod_reduce %cc, %relin_res : (!cc, !ct) -> !ct
%level_reduce_res = openfhe.level_reduce %cc, %mod_reduce_res : (!cc, !ct) -> !ct
%rotate_res = openfhe.rot %cc, %level_reduce_res, %const : (!cc, !ct, i64) -> !ct
%rotate_res = openfhe.rot %cc, %level_reduce_res { index = 4 } : (!cc, !ct) -> !ct
%automorph_res = openfhe.automorph %cc, %rotate_res, %eval_key : (!cc, !ct, !ek) -> !ct
%key_switch_res = openfhe.key_switch %cc, %automorph_res, %eval_key : (!cc, !ct, !ek) -> !ct
return %key_switch_res: !ct
Expand All @@ -72,32 +72,20 @@ func.func @test_basic_emitter(%cc : !cc, %input1 : !ct, %input2 : !ct, %input3:
// CHECK: int16_t
// CHECK-SAME: [0]
func.func @simple_sum(%arg0: !openfhe.crypto_context, %arg1: !tensor_ct_ty) -> !scalar_ct_ty {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c31 = arith.constant 31 : index
%0 = arith.index_cast %c16 : index to i64
%1 = openfhe.rot %arg0, %arg1, %0 : (!openfhe.crypto_context, !tensor_ct_ty, i64) -> !tensor_ct_ty
%1 = openfhe.rot %arg0, %arg1 { index = 16 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty
%2 = openfhe.add %arg0, %arg1, %1 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_ct_ty) -> !tensor_ct_ty
%3 = arith.index_cast %c8 : index to i64
%4 = openfhe.rot %arg0, %2, %3 : (!openfhe.crypto_context, !tensor_ct_ty, i64) -> !tensor_ct_ty
%4 = openfhe.rot %arg0, %2 { index = 8 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty
%5 = openfhe.add %arg0, %2, %4 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_ct_ty) -> !tensor_ct_ty
%6 = arith.index_cast %c4 : index to i64
%7 = openfhe.rot %arg0, %5, %6 : (!openfhe.crypto_context, !tensor_ct_ty, i64) -> !tensor_ct_ty
%7 = openfhe.rot %arg0, %5 { index = 4 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty
%8 = openfhe.add %arg0, %5, %7 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_ct_ty) -> !tensor_ct_ty
%9 = arith.index_cast %c2 : index to i64
%10 = openfhe.rot %arg0, %8, %9 : (!openfhe.crypto_context, !tensor_ct_ty, i64) -> !tensor_ct_ty
%10 = openfhe.rot %arg0, %8 { index = 2 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty
%11 = openfhe.add %arg0, %8, %10 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_ct_ty) -> !tensor_ct_ty
%12 = arith.index_cast %c1 : index to i64
%13 = openfhe.rot %arg0, %11, %12 : (!openfhe.crypto_context, !tensor_ct_ty, i64) -> !tensor_ct_ty
%13 = openfhe.rot %arg0, %11 { index = 1 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty
%14 = openfhe.add %arg0, %11, %13 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_ct_ty) -> !tensor_ct_ty
%cst = arith.constant dense<[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]> : tensor<32xi16>
%15 = lwe.rlwe_encode %cst {encoding = #eval_encoding, ring = #ring2} : tensor<32xi16> -> !tensor_pt_ty
%16 = openfhe.mul_plain %arg0, %14, %15 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_pt_ty) -> !tensor_ct_ty
%17 = arith.index_cast %c31 : index to i64
%18 = openfhe.rot %arg0, %16, %17 : (!openfhe.crypto_context, !tensor_ct_ty, i64) -> !tensor_ct_ty
%18 = openfhe.rot %arg0, %16 { index = 31 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty
%19 = lwe.reinterpret_underlying_type %18 : !tensor_ct_ty to !scalar_ct_ty
return %19 : !scalar_ct_ty
}
Expand Down
3 changes: 1 addition & 2 deletions tests/openfhe/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,8 @@ module {

// CHECK-LABEL: func @test_rot
func.func @test_rot(%cc : !cc, %pt : !pt, %pk: !pk) {
%0 = arith.constant 2 : i64
%ct = openfhe.encrypt %cc, %pt, %pk : (!cc, !pt, !pk) -> !ct
%out = openfhe.rot %cc, %ct, %0: (!cc, !ct, i64) -> !ct
%out = openfhe.rot %cc, %ct { index = 2 }: (!cc, !ct) -> !ct
return
}

Expand Down

0 comments on commit 62b2671

Please sign in to comment.