From a00de653a521279f163bff6a456980087dda1392 Mon Sep 17 00:00:00 2001 From: Asra Ali Date: Mon, 11 Mar 2024 11:43:57 -0700 Subject: [PATCH] bgv: add BGV ciphertext-plaintext ops and cleanup traits Fixes #100 Fixes #212 PiperOrigin-RevId: 614740812 --- include/Dialect/BGV/IR/BGVOps.h | 2 +- include/Dialect/BGV/IR/BGVOps.td | 39 ++++++++++++++++--- include/Dialect/BGV/IR/BUILD | 1 - include/Dialect/LWE/IR/BUILD | 1 + .../IR/BGVTraits.h => LWE/IR/LWETraits.h} | 17 ++++---- lib/Dialect/BGV/IR/BGVDialect.cpp | 31 ++++++++++++++- lib/Dialect/BGV/IR/BUILD | 1 - lib/Dialect/LWE/IR/BUILD | 2 + tests/bgv/ops.mlir | 10 +++++ 9 files changed, 87 insertions(+), 17 deletions(-) rename include/Dialect/{BGV/IR/BGVTraits.h => LWE/IR/LWETraits.h} (71%) diff --git a/include/Dialect/BGV/IR/BGVOps.h b/include/Dialect/BGV/IR/BGVOps.h index 73d668ad6..e99eeba8c 100644 --- a/include/Dialect/BGV/IR/BGVOps.h +++ b/include/Dialect/BGV/IR/BGVOps.h @@ -2,7 +2,7 @@ #define HEIR_INCLUDE_DIALECT_BGV_IR_BGVOPS_H_ #include "include/Dialect/BGV/IR/BGVDialect.h" -#include "include/Dialect/BGV/IR/BGVTraits.h" +#include "include/Dialect/LWE/IR/LWETraits.h" #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project diff --git a/include/Dialect/BGV/IR/BGVOps.td b/include/Dialect/BGV/IR/BGVOps.td index 0d885aa62..c7f32fa3f 100644 --- a/include/Dialect/BGV/IR/BGVOps.td +++ b/include/Dialect/BGV/IR/BGVOps.td @@ -10,19 +10,36 @@ include "include/Dialect/LWE/IR/LWETypes.td" include "include/Dialect/Polynomial/IR/PolynomialAttributes.td" def SameOperandsAndResultRings: NativeOpTrait<"SameOperandsAndResultRings"> { - let cppNamespace = "::mlir::heir::bgv"; + let cppNamespace = "::mlir::heir::lwe"; } class BGV_Op traits = []> : Op { + let cppNamespace = "::mlir::heir::bgv"; + let assemblyFormat = [{ `(` operands `)` attr-dict `:` `(` qualified(type(operands)) `)` `->` qualified(type(results)) }]; - let cppNamespace = "::mlir::heir::bgv"; } -// TODO(#100): Add plaintext-ciphertext operations. +class BGV_CiphertextPlaintextOp traits = + [AllTypesMatch<["x", "output"]>, + TypesMatchWith<"type of 'y' matches encoding type of 'x'", + "output", "y", + "lwe::RLWEPlaintextType::get($_ctxt, ::llvm::cast($_self).getEncoding())">]> : + BGV_Op { + let arguments = (ins + RLWECiphertext:$x, + RLWEPlaintext:$y + ); + + let results = (outs + RLWECiphertext:$output + ); + + let assemblyFormat = "`(` operands `)` attr-dict `:` qualified(type($output))" ; +} def BGV_AddOp : BGV_Op<"add", [Commutative, SameOperandsAndResultType]> { let summary = "Addition operation between ciphertexts."; @@ -39,6 +56,10 @@ def BGV_AddOp : BGV_Op<"add", [Commutative, SameOperandsAndResultType]> { let assemblyFormat = "`(` operands `)` attr-dict `:` qualified(type($output))" ; } +def BGV_AddPlainOp : BGV_CiphertextPlaintextOp<"add_plain"> { + let summary = "Addition operation between ciphertext-plaintext."; +} + def BGV_SubOp : BGV_Op<"sub", [SameOperandsAndResultType]> { let summary = "Subtraction operation between ciphertexts."; @@ -54,7 +75,11 @@ def BGV_SubOp : BGV_Op<"sub", [SameOperandsAndResultType]> { let assemblyFormat = "`(` operands `)` attr-dict `:` qualified(type($output))" ; } -def BGV_MulOp : BGV_Op<"mul", [Commutative, SameOperandsAndResultRings, SameTypeOperands]> { +def BGV_SubPlainOp : BGV_CiphertextPlaintextOp<"sub_plain"> { + let summary = "Subtraction operation between ciphertext-plaintext."; +} + +def BGV_MulOp : BGV_Op<"mul", [Commutative, SameOperandsAndResultRings, SameTypeOperands, InferTypeOpAdaptor]> { let summary = "Multiplication operation between ciphertexts."; let arguments = (ins @@ -71,6 +96,10 @@ def BGV_MulOp : BGV_Op<"mul", [Commutative, SameOperandsAndResultRings, SameType let hasVerifier = 1; } +def BGV_MulPlainOp : BGV_CiphertextPlaintextOp<"mul_plain"> { + let summary = "Multiplication operation between ciphertext-plaintext."; +} + def BGV_Rotate : BGV_Op<"rotate", [SameOperandsAndResultRings]> { let summary = "Rotate the coefficients of the ciphertext using a Galois automorphism."; @@ -100,7 +129,7 @@ def BGV_Negate : BGV_Op<"negate", [SameOperandsAndResultType]> { let assemblyFormat = "`(` operands `)` attr-dict `:` qualified(type($output))" ; } -def BGV_Relinearize : BGV_Op<"relinearize", [SameOperandsAndResultRings]> { +def BGV_Relinearize : BGV_Op<"relinearize", [SameOperandsAndResultRings, InferTypeOpAdaptor]> { let summary = "Relinearize the ciphertext."; let description = [{ diff --git a/include/Dialect/BGV/IR/BUILD b/include/Dialect/BGV/IR/BUILD index 99d304734..ccd232525 100644 --- a/include/Dialect/BGV/IR/BUILD +++ b/include/Dialect/BGV/IR/BUILD @@ -11,7 +11,6 @@ exports_files( [ "BGVDialect.h", "BGVOps.h", - "BGVTraits.h", ], ) diff --git a/include/Dialect/LWE/IR/BUILD b/include/Dialect/LWE/IR/BUILD index 978d4bde7..49f7a035a 100644 --- a/include/Dialect/LWE/IR/BUILD +++ b/include/Dialect/LWE/IR/BUILD @@ -13,6 +13,7 @@ exports_files( "LWEAttributes.h", "LWETypes.h", "LWEOps.h", + "LWETraits.h", ], ) diff --git a/include/Dialect/BGV/IR/BGVTraits.h b/include/Dialect/LWE/IR/LWETraits.h similarity index 71% rename from include/Dialect/BGV/IR/BGVTraits.h rename to include/Dialect/LWE/IR/LWETraits.h index 5eee3ce69..328763672 100644 --- a/include/Dialect/BGV/IR/BGVTraits.h +++ b/include/Dialect/LWE/IR/LWETraits.h @@ -1,14 +1,15 @@ -#ifndef HEIR_INCLUDE_DIALECT_BGV_IR_BGVTRAITS_H_ -#define HEIR_INCLUDE_DIALECT_BGV_IR_BGVTRAITS_H_ +#ifndef HEIR_INCLUDE_DIALECT_LWE_IR_LWETRAITS_H_ +#define HEIR_INCLUDE_DIALECT_LWE_IR_LWETRAITS_H_ #include "include/Dialect/LWE/IR/LWETypes.h" #include "include/Dialect/Polynomial/IR/PolynomialAttributes.h" -#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project -namespace mlir::heir::bgv { +namespace mlir::heir::lwe { -// TODO(#212): Move to LWE dialect/namespace // Trait that ensures that all operands and results ciphertext have the same set // of rings. template @@ -48,6 +49,6 @@ class SameOperandsAndResultRings } }; -} // namespace mlir::heir::bgv +} // namespace mlir::heir::lwe -#endif // HEIR_INCLUDE_DIALECT_BGV_IR_BGVTRAITS_H_ +#endif // HEIR_INCLUDE_DIALECT_LWE_IR_LWETRAITS_H_ diff --git a/lib/Dialect/BGV/IR/BGVDialect.cpp b/lib/Dialect/BGV/IR/BGVDialect.cpp index 7e89b7906..183e5b45c 100644 --- a/lib/Dialect/BGV/IR/BGVDialect.cpp +++ b/lib/Dialect/BGV/IR/BGVDialect.cpp @@ -1,9 +1,15 @@ #include "include/Dialect/BGV/IR/BGVDialect.h" +#include + #include "include/Dialect/BGV/IR/BGVOps.h" -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "include/Dialect/LWE/IR/LWEAttributes.h" +#include "include/Dialect/LWE/IR/LWETypes.h" #include "mlir/include/mlir/IR/Builders.h" // from @llvm-project #include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Location.h" // from @llvm-project +#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project // Generated definitions #include "include/Dialect/BGV/IR/BGVDialect.cpp.inc" @@ -85,6 +91,29 @@ LogicalResult ModulusSwitch::verify() { return success(); } +LogicalResult MulOp::inferReturnTypes( + MLIRContext *ctx, std::optional, MulOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + auto x = cast(adaptor.getX().getType()); + auto y = cast(adaptor.getY().getType()); + auto newDim = + x.getRlweParams().getDimension() + y.getRlweParams().getDimension() - 1; + inferredReturnTypes.push_back(lwe::RLWECiphertextType::get( + ctx, x.getEncoding(), + lwe::RLWEParamsAttr::get(ctx, newDim, x.getRlweParams().getRing()))); + return success(); +} + +LogicalResult Relinearize::inferReturnTypes( + MLIRContext *ctx, std::optional, Relinearize::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + auto x = cast(adaptor.getX().getType()); + inferredReturnTypes.push_back(lwe::RLWECiphertextType::get( + ctx, x.getEncoding(), + lwe::RLWEParamsAttr::get(ctx, 2, x.getRlweParams().getRing()))); + return success(); +} + } // namespace bgv } // namespace heir } // namespace mlir diff --git a/lib/Dialect/BGV/IR/BUILD b/lib/Dialect/BGV/IR/BUILD index 4ba153eef..ad70e4690 100644 --- a/lib/Dialect/BGV/IR/BUILD +++ b/lib/Dialect/BGV/IR/BUILD @@ -13,7 +13,6 @@ cc_library( hdrs = [ "@heir//include/Dialect/BGV/IR:BGVDialect.h", "@heir//include/Dialect/BGV/IR:BGVOps.h", - "@heir//include/Dialect/BGV/IR:BGVTraits.h", ], deps = [ "@heir//include/Dialect/BGV/IR:dialect_inc_gen", diff --git a/lib/Dialect/LWE/IR/BUILD b/lib/Dialect/LWE/IR/BUILD index 7c315ee1a..f96fd9c9c 100644 --- a/lib/Dialect/LWE/IR/BUILD +++ b/lib/Dialect/LWE/IR/BUILD @@ -12,6 +12,7 @@ cc_library( "@heir//include/Dialect/LWE/IR:LWEAttributes.h", "@heir//include/Dialect/LWE/IR:LWEDialect.h", "@heir//include/Dialect/LWE/IR:LWEOps.h", + "@heir//include/Dialect/LWE/IR:LWETraits.h", "@heir//include/Dialect/LWE/IR:LWETypes.h", ], deps = [ @@ -21,6 +22,7 @@ cc_library( "@heir//include/Dialect/LWE/IR:types_inc_gen", "@heir//include/Dialect/Polynomial/IR:attributes_inc_gen", "@heir//lib/Dialect/Polynomial/IR:Dialect", + "@heir//lib/Dialect/Polynomial/IR:Polynomial", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", diff --git a/tests/bgv/ops.mlir b/tests/bgv/ops.mlir index 6f4cb3bfc..a2524b3bf 100644 --- a/tests/bgv/ops.mlir +++ b/tests/bgv/ops.mlir @@ -14,6 +14,8 @@ #params1 = #lwe.rlwe_params #params2 = #lwe.rlwe_params +!pt = !lwe.rlwe_plaintext + !ct = !lwe.rlwe_ciphertext !ct1 = !lwe.rlwe_ciphertext !ct2 = !lwe.rlwe_ciphertext @@ -31,4 +33,12 @@ module { // CHECK: rlwe_params = >> return %arg0 : !ct } + + func.func @test_ciphertext_plaintext(%arg0: !pt, %arg1: !pt, %arg2: !pt, %arg3: !ct) -> !ct { + %add = bgv.add_plain(%arg3, %arg0) : !ct + %sub = bgv.sub_plain(%add, %arg1) : !ct + %mul = bgv.mul_plain(%sub, %arg2) : !ct + // CHECK: rlwe_params = >> + return %mul : !ct + } }