diff --git a/.gitignore b/.gitignore index 0944affe3..5cee65586 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,9 @@ hugo_stats.json venv # for rust codegen tests -Cargo.lock +**/Cargo.lock +tests/**/**/target/ +tests/tfhe_rust_bool/end_to_end_fpga/ + +# vscode +.vscode/** diff --git a/include/Conversion/CGGIToTfheRustBool/BUILD b/include/Conversion/CGGIToTfheRustBool/BUILD new file mode 100644 index 000000000..536ca0652 --- /dev/null +++ b/include/Conversion/CGGIToTfheRustBool/BUILD @@ -0,0 +1,37 @@ +# CGGIToTfheRustBool tablegen and headers. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files( + [ + "CGGIToTfheRustBool.h", + ], +) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=CGGIToTfheRustBool", + ], + "CGGIToTfheRustBool.h.inc", + ), + ( + ["-gen-pass-doc"], + "CGGIToTfheRustBool.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "CGGIToTfheRustBool.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.h b/include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.h new file mode 100644 index 000000000..ec9a8abb6 --- /dev/null +++ b/include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.h @@ -0,0 +1,16 @@ +#ifndef INCLUDE_CONVERSION_CGGITOTFHERUSTBOOL_CGGITOTFHERUSTBOOL_H_ +#define INCLUDE_CONVERSION_CGGITOTFHERUSTBOOL_CGGITOTFHERUSTBOOL_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir::heir { + +#define GEN_PASS_DECL +#include "include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.h.inc" + +#define GEN_PASS_REGISTRATION +#include "include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.h.inc" + +} // namespace mlir::heir + +#endif // INCLUDE_CONVERSION_CGGITOTFHERUSTBOOL_CGGITOTFHERUSTBOOL_H_ diff --git a/include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.td b/include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.td new file mode 100644 index 000000000..3aad9673a --- /dev/null +++ b/include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.td @@ -0,0 +1,16 @@ +#ifndef INCLUDE_CONVERSION_CGGITOTFHERUSTBOOL_CGGITOTFHERUSTBOOL_TD_ +#define INCLUDE_CONVERSION_CGGITOTFHERUSTBOOL_CGGITOTFHERUSTBOOL_TD_ + +include "mlir/Pass/PassBase.td" + +def CGGIToTfheRustBool : Pass<"cggi-to-tfhe-rust-bool"> { + let summary = "Lower `cggi` to `tfhe_rust_bool` dialect."; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::heir::cggi::CGGIDialect", + "mlir::heir::lwe::LWEDialect", + "mlir::heir::tfhe_rust_bool::TfheRustBoolDialect", + ]; +} + +#endif // INCLUDE_CONVERSION_CGGITOTFHERUSTBOOL_CGGITOTFHERUSTBOOL_TD_ diff --git a/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td b/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td index e03aba502..c711f3d93 100644 --- a/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td +++ b/include/Dialect/TfheRustBool/IR/TfheRustBoolOps.td @@ -8,6 +8,7 @@ include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/CommonTypeConstraints.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/BuiltinTypes.td" class TfheRustBool_Op traits = []> : @@ -43,6 +44,19 @@ def XorOp : TfheRustBool_BinaryGateOp<"xor"> { let summary = "Logical XOR of two def XnorOp : TfheRustBool_BinaryGateOp<"xnor"> { let summary = "Logical XNOR of two TFHE-rs Bool ciphertexts."; } +def AndPackedOp : TfheRustBool_Op<"and_packed", [ + Pure, + AllTypesMatch<["lhs", "rhs", "output"]> +]> { + let arguments = (ins + TfheRustBool_ServerKey:$serverKey, + TensorOf<[TfheRustBool_Encrypted]>:$lhs, + TensorOf<[TfheRustBool_Encrypted]>:$rhs + ); + let results = (outs TensorOf<[TfheRustBool_Encrypted]>:$output); +} + + def NotOp : TfheRustBool_Op<"not", [ Pure, AllTypesMatch<["input", "output"]> diff --git a/include/Target/TfheRustBool/TfheRustBoolEmitter.h b/include/Target/TfheRustBool/TfheRustBoolEmitter.h index 99f1288f7..fca9fe291 100644 --- a/include/Target/TfheRustBool/TfheRustBoolEmitter.h +++ b/include/Target/TfheRustBool/TfheRustBoolEmitter.h @@ -59,6 +59,8 @@ class TfheRustBoolEmitter { LogicalResult printOperation(XorOp op); LogicalResult printOperation(XnorOp op); + LogicalResult printOperation(AndPackedOp op); + // Helpers for above LogicalResult printSksMethod(::mlir::Value result, ::mlir::Value sks, ::mlir::ValueRange nonSksOperands, diff --git a/lib/Conversion/CGGIToTfheRustBool/BUILD b/lib/Conversion/CGGIToTfheRustBool/BUILD new file mode 100644 index 000000000..7ad4fcc56 --- /dev/null +++ b/lib/Conversion/CGGIToTfheRustBool/BUILD @@ -0,0 +1,28 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "CGGIToTfheRustBool", + srcs = ["CGGIToTfheRustBool.cpp"], + hdrs = [ + "@heir//include/Conversion/CGGIToTfheRustBool:CGGIToTfheRustBool.h", + ], + deps = [ + "@heir//include/Conversion/CGGIToTfheRustBool:pass_inc_gen", + "@heir//lib/Conversion:Utils", + "@heir//lib/Dialect/CGGI/IR:Dialect", + "@heir//lib/Dialect/LWE/IR:Dialect", + "@heir//lib/Dialect/TfheRustBool/IR:Dialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/lib/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.cpp b/lib/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.cpp new file mode 100644 index 000000000..3a4716fb4 --- /dev/null +++ b/lib/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.cpp @@ -0,0 +1,318 @@ +#include "include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.h" + +#include + +#include "include/Dialect/CGGI/IR/CGGIDialect.h" +#include "include/Dialect/CGGI/IR/CGGIOps.h" +#include "include/Dialect/LWE/IR/LWEAttributes.h" +#include "include/Dialect/LWE/IR/LWEDialect.h" +#include "include/Dialect/LWE/IR/LWEOps.h" +#include "include/Dialect/LWE/IR/LWETypes.h" +#include "include/Dialect/TfheRustBool/IR/TfheRustBoolDialect.h" +#include "include/Dialect/TfheRustBool/IR/TfheRustBoolOps.h" +#include "include/Dialect/TfheRustBool/IR/TfheRustBoolTypes.h" +#include "lib/Conversion/Utils.h" +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project +#include "llvm/include/llvm/Support/ErrorHandling.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::heir { + +#define GEN_PASS_DEF_CGGITOTFHERUSTBOOL +#include "include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.h.inc" + +class BoolPassTypeConverter : public TypeConverter { + public: + BoolPassTypeConverter(MLIRContext *ctx) { + addConversion([](Type type) { return type; }); + addConversion([ctx](lwe::LWECiphertextType type) -> Type { + return tfhe_rust_bool::EncryptedBoolType::get(ctx); + }); + addConversion([this](ShapedType type) -> Type { + return type.cloneWith(type.getShape(), + this->convertType(type.getElementType())); + }); + } +}; + +// /// Returns true if the func's body contains any CGGI ops. +bool containsCGGIOpsBool(func::FuncOp func) { + auto walkResult = func.walk([&](Operation *op) { + if (llvm::isa(op->getDialect())) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return walkResult.wasInterrupted(); +} + +/// Returns the Value corresponding to a server key in the FuncOp containing +/// this op. +FailureOr getContextualBoolServerKey(Operation *op) { + Value serverKey = op->getParentOfType() + .getBody() + .getBlocks() + .front() + .getArguments() + .front(); + if (!serverKey.getType().isa()) { + return op->emitOpError() + << "Found CGGI op in a function without a server " + "key argument. Did the AddBoolServerKeyArg pattern fail to run?"; + } + return serverKey; +} + +template +struct GenericOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector retTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + retTypes))) + return failure(); + rewriter.replaceOpWithNewOp(op, retTypes, adaptor.getOperands(), + op->getAttrs()); + + return success(); + } +}; + +/// Convert a func by adding a server key argument. Converted ops in other +/// patterns need a server key SSA value available, so this pattern needs a +/// higher benefit. +struct AddBoolServerKeyArg : public OpConversionPattern { + AddBoolServerKeyArg(mlir::MLIRContext *context) + : OpConversionPattern(context, /* benefit= */ 2) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + func::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!containsCGGIOpsBool(op)) { + return failure(); + } + + auto serverKeyType = tfhe_rust_bool::ServerKeyType::get(getContext()); + FunctionType originalType = op.getFunctionType(); + llvm::SmallVector newTypes; + newTypes.reserve(originalType.getNumInputs() + 1); + newTypes.push_back(serverKeyType); + for (auto t : originalType.getInputs()) { + newTypes.push_back(t); + } + auto newFuncType = + FunctionType::get(getContext(), newTypes, originalType.getResults()); + rewriter.modifyOpInPlace(op, [&] { + op.setType(newFuncType); + + // In addition to updating the type signature, we need to update the + // entry block's arguments to match the type signature + Block &block = op.getBody().getBlocks().front(); + block.insertArgument(&block.getArguments().front(), serverKeyType, + op.getLoc()); + }); + + return success(); + } +}; + +struct ConvertBoolAndOp : public OpConversionPattern { + ConvertBoolAndOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + cggi::AndOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + FailureOr result = getContextualBoolServerKey(op); + if (failed(result)) return result; + + Value serverKey = result.value(); + + rewriter.replaceOp(op, b.create( + serverKey, adaptor.getLhs(), adaptor.getRhs())); + return success(); + } +}; + +struct ConvertBoolOrOp : public OpConversionPattern { + ConvertBoolOrOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + cggi::OrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + FailureOr result = getContextualBoolServerKey(op); + if (failed(result)) return result; + + Value serverKey = result.value(); + + rewriter.replaceOp(op, b.create( + serverKey, adaptor.getLhs(), adaptor.getRhs())); + return success(); + } +}; + +struct ConvertBoolXorOp : public OpConversionPattern { + ConvertBoolXorOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + cggi::XorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + FailureOr result = getContextualBoolServerKey(op); + if (failed(result)) return result; + + Value serverKey = result.value(); + + rewriter.replaceOp(op, b.create( + serverKey, adaptor.getLhs(), adaptor.getRhs())); + return success(); + } +}; + +struct ConvertBoolNotOp : public OpConversionPattern { + ConvertBoolNotOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + cggi::NotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + FailureOr result = getContextualBoolServerKey(op); + if (failed(result)) return result; + + Value serverKey = result.value(); + + rewriter.replaceOp( + op, b.create(serverKey, adaptor.getInput())); + return success(); + } +}; + +struct ConvertBoolTrivialEncryptOp + : public OpConversionPattern { + ConvertBoolTrivialEncryptOp(mlir::MLIRContext *context) + : OpConversionPattern(context, /*benefit=*/2) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + lwe::TrivialEncryptOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr result = getContextualBoolServerKey(op.getOperation()); + if (failed(result)) return result; + + Value serverKey = result.value(); + lwe::EncodeOp encodeOp = op.getInput().getDefiningOp(); + if (!encodeOp) { + return op.emitError() << "Expected input to TrivialEncrypt to be the " + "result of an EncodeOp, but it was " + << op.getInput().getDefiningOp()->getName(); + } + auto outputType = tfhe_rust_bool::EncryptedBoolType::get(getContext()); + ; + auto createTrivialOp = rewriter.create( + op.getLoc(), outputType, serverKey, encodeOp.getPlaintext()); + rewriter.replaceOp(op, createTrivialOp); + return success(); + } +}; + +struct ConvertBoolEncodeOp : public OpConversionPattern { + ConvertBoolEncodeOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + lwe::EncodeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +class CGGIToTfheRustBool + : public impl::CGGIToTfheRustBoolBase { + void runOnOperation() override { + MLIRContext *context = &getContext(); + auto *op = getOperation(); + + BoolPassTypeConverter typeConverter(context); + RewritePatternSet patterns(context); + ConversionTarget target(*context); + addStructuralConversionPatterns(typeConverter, patterns, target); + + target.addLegalDialect(); + target.addIllegalDialect(); + target.addIllegalDialect(); + + // FuncOp is marked legal by the default structural conversion patterns + // helper, just based on type conversion. We need more, but because the + // addDynamicallyLegalOp is a set-based method, we can add this after + // calling addStructuralConversionPatterns and it will overwrite the + // legality condition set in that function. + target.addDynamicallyLegalOp([&](func::FuncOp op) { + bool hasServerKeyArg = op.getFunctionType().getNumInputs() > 0 && + op.getFunctionType() + .getInputs() + .begin() + ->isa(); + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()) && + (!containsCGGIOpsBool(op) || hasServerKeyArg); + }); + target.addDynamicallyLegalOp( + [&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); + + // FIXME: still need to update callers to insert the new server key arg, if + // needed and possible. + patterns.add< + AddBoolServerKeyArg, ConvertBoolAndOp, ConvertBoolEncodeOp, + ConvertBoolOrOp, ConvertBoolTrivialEncryptOp, ConvertBoolXorOp, + ConvertBoolNotOp, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern>(typeConverter, context); + + if (failed(applyPartialConversion(op, target, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace mlir::heir diff --git a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp index 48259cfa9..17975f349 100644 --- a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp +++ b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp @@ -65,7 +65,7 @@ LogicalResult TfheRustBoolEmitter::translate(Operation &op) { // Arith ops .Case([&](auto op) { return printOperation(op); }) // TfheRustBool ops - .Case( + .Case( [&](auto op) { return printOperation(op); }) // Tensor ops .Case( @@ -174,7 +174,18 @@ LogicalResult TfheRustBoolEmitter::printSksMethod( auto operandTypesIt = operandTypes.begin(); os << variableNames->getNameForValue(sks) << "." << op << "("; os << commaSeparatedValues(nonSksOperands, [&](Value value) { - return variableNames->getNameForValue(value) + + auto *prefix = value.getType().hasTrait() ? "&" : ""; + // First check if a DefiningOp exists + // if not: comes from function definition + mlir::Operation *op = value.getDefiningOp(); + if (op) { + prefix = isa(op) ? "" : prefix; + } + else{ + prefix = ""; + } + + return prefix + variableNames->getNameForValue(value) + (!operandTypes.empty() ? " as " + *operandTypesIt++ : ""); }); os << ");\n"; @@ -183,7 +194,7 @@ LogicalResult TfheRustBoolEmitter::printSksMethod( LogicalResult TfheRustBoolEmitter::printOperation(CreateTrivialOp op) { return printSksMethod(op.getResult(), op.getServerKey(), {op.getValue()}, - "create_trivial", {"i1"}); + "create_trivial", {"bool"}); } LogicalResult TfheRustBoolEmitter::printOperation(arith::ConstantOp op) { @@ -260,6 +271,11 @@ LogicalResult TfheRustBoolEmitter::printOperation(XnorOp op) { {op.getLhs(), op.getRhs()}, "xnor"); } +LogicalResult TfheRustBoolEmitter::printOperation(AndPackedOp op) { + return printSksMethod(op.getResult(), op.getServerKey(), + {op.getLhs(), op.getRhs()}, "and_packed"); +} + FailureOr TfheRustBoolEmitter::convertType(Type type) { // Note: these are probably not the right type names to use exactly, and they // will need to chance to the right values once we try to compile it against diff --git a/tests/cggi_to_tfhe_rust_bool/BUILD b/tests/cggi_to_tfhe_rust_bool/BUILD new file mode 100644 index 000000000..c571e6fc6 --- /dev/null +++ b/tests/cggi_to_tfhe_rust_bool/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/cggi_to_tfhe_rust_bool/add_bool.mlir b/tests/cggi_to_tfhe_rust_bool/add_bool.mlir new file mode 100644 index 000000000..646c68e81 --- /dev/null +++ b/tests/cggi_to_tfhe_rust_bool/add_bool.mlir @@ -0,0 +1,78 @@ +// RUN: heir-opt --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | FileCheck %s + + +#encoding = #lwe.unspecified_bit_field_encoding +!ct_ty = !lwe.lwe_ciphertext +!pt_ty = !lwe.lwe_plaintext + + +// CHECK-LABEL: add_bool +// CHECK-NOT: cggi +// CHECK-NOT: lwe +func.func @add_bool(%arg0: tensor<8x!ct_ty>, %arg1: tensor<8x!ct_ty>) -> tensor<8x!ct_ty> { + %true = arith.constant true + %false = arith.constant false + %c7 = arith.constant 7 : index + %c6 = arith.constant 6 : index + %c5 = arith.constant 5 : index + %c4 = arith.constant 4 : index + %c3 = arith.constant 3 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %extracted_00 = tensor.extract %arg0[%c0] : tensor<8x!ct_ty> + %extracted_01 = tensor.extract %arg0[%c1] : tensor<8x!ct_ty> + %extracted_02 = tensor.extract %arg0[%c2] : tensor<8x!ct_ty> + %extracted_03 = tensor.extract %arg0[%c3] : tensor<8x!ct_ty> + %extracted_04 = tensor.extract %arg0[%c4] : tensor<8x!ct_ty> + %extracted_05 = tensor.extract %arg0[%c5] : tensor<8x!ct_ty> + %extracted_06 = tensor.extract %arg0[%c6] : tensor<8x!ct_ty> + %extracted_07 = tensor.extract %arg0[%c7] : tensor<8x!ct_ty> + %extracted_10 = tensor.extract %arg1[%c0] : tensor<8x!ct_ty> + %extracted_11 = tensor.extract %arg1[%c1] : tensor<8x!ct_ty> + %extracted_12 = tensor.extract %arg1[%c2] : tensor<8x!ct_ty> + %extracted_13 = tensor.extract %arg1[%c3] : tensor<8x!ct_ty> + %extracted_14 = tensor.extract %arg1[%c4] : tensor<8x!ct_ty> + %extracted_15 = tensor.extract %arg1[%c5] : tensor<8x!ct_ty> + %extracted_16 = tensor.extract %arg1[%c6] : tensor<8x!ct_ty> + %extracted_17 = tensor.extract %arg1[%c7] : tensor<8x!ct_ty> + %ha_s = cggi.xor %extracted_00, %extracted_10 : !ct_ty + %ha_c = cggi.and %extracted_00, %extracted_10 : !ct_ty + %fa0_1 = cggi.xor %extracted_01, %extracted_11 : !ct_ty + %fa0_2 = cggi.and %extracted_01, %extracted_11 : !ct_ty + %fa0_3 = cggi.and %fa0_1, %ha_c : !ct_ty + %fa0_s = cggi.xor %fa0_1, %ha_c : !ct_ty + %fa0_c = cggi.xor %fa0_2, %fa0_3 : !ct_ty + %fa1_1 = cggi.xor %extracted_02, %extracted_12 : !ct_ty + %fa1_2 = cggi.and %extracted_02, %extracted_12 : !ct_ty + %fa1_3 = cggi.and %fa1_1, %fa0_c : !ct_ty + %fa1_s = cggi.xor %fa1_1, %fa0_c : !ct_ty + %fa1_c = cggi.xor %fa1_2, %fa1_3 : !ct_ty + %fa2_1 = cggi.xor %extracted_03, %extracted_13 : !ct_ty + %fa2_2 = cggi.and %extracted_03, %extracted_13 : !ct_ty + %fa2_3 = cggi.and %fa2_1, %fa1_c : !ct_ty + %fa2_s = cggi.xor %fa2_1, %fa1_c : !ct_ty + %fa2_c = cggi.xor %fa2_2, %fa2_3 : !ct_ty + %fa3_1 = cggi.xor %extracted_04, %extracted_14 : !ct_ty + %fa3_2 = cggi.and %extracted_04, %extracted_14 : !ct_ty + %fa3_3 = cggi.and %fa3_1, %fa2_c : !ct_ty + %fa3_s = cggi.xor %fa3_1, %fa2_c : !ct_ty + %fa3_c = cggi.xor %fa3_2, %fa3_3 : !ct_ty + %fa4_1 = cggi.xor %extracted_05, %extracted_15 : !ct_ty + %fa4_2 = cggi.and %extracted_05, %extracted_15 : !ct_ty + %fa4_3 = cggi.and %fa4_1, %fa3_c : !ct_ty + %fa4_s = cggi.xor %fa4_1, %fa3_c : !ct_ty + %fa4_c = cggi.xor %fa4_2, %fa4_3 : !ct_ty + %fa5_1 = cggi.xor %extracted_06, %extracted_16 : !ct_ty + %fa5_2 = cggi.and %extracted_06, %extracted_16 : !ct_ty + %fa5_3 = cggi.and %fa5_1, %fa4_c : !ct_ty + %fa5_s = cggi.xor %fa5_1, %fa4_c : !ct_ty + %fa5_c = cggi.xor %fa5_2, %fa5_3 : !ct_ty + %fa6_1 = cggi.xor %extracted_07, %extracted_17 : !ct_ty + %fa6_2 = cggi.and %extracted_07, %extracted_17 : !ct_ty + %fa6_3 = cggi.and %fa6_1, %fa5_c : !ct_ty + %fa6_s = cggi.xor %fa6_1, %fa5_c : !ct_ty + %fa6_c = cggi.xor %fa6_2, %fa6_3 : !ct_ty + %from_elements = tensor.from_elements %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s : tensor<8x!ct_ty> + return %from_elements : tensor<8x!ct_ty> +} diff --git a/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir b/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir new file mode 100644 index 000000000..8a6e62cd8 --- /dev/null +++ b/tests/cggi_to_tfhe_rust_bool/add_one_bool.mlir @@ -0,0 +1,76 @@ +// RUN: heir-opt --cggi-to-tfhe-rust-bool -cse -remove-dead-values %s | FileCheck %s + + +#encoding = #lwe.unspecified_bit_field_encoding +!ct_ty = !lwe.lwe_ciphertext +!pt_ty = !lwe.lwe_plaintext + + +// CHECK-LABEL: add_one_bool +// CHECK-NOT: cggi +// CHECK-NOT: lwe +func.func @add_one_bool(%arg0: tensor<8x!ct_ty>, %arg1: tensor<8x!ct_ty>) -> tensor<8x!ct_ty> { + %true = arith.constant true + %false = arith.constant false + %c7 = arith.constant 7 : index + %c6 = arith.constant 6 : index + %c5 = arith.constant 5 : index + %c4 = arith.constant 4 : index + %c3 = arith.constant 3 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %i1 = arith.constant 1 : i1 + %i0 = arith.constant 0 : i1 + %extracted_00 = tensor.extract %arg0[%c0] : tensor<8x!ct_ty> + %extracted_01 = tensor.extract %arg0[%c1] : tensor<8x!ct_ty> + %extracted_02 = tensor.extract %arg0[%c2] : tensor<8x!ct_ty> + %extracted_03 = tensor.extract %arg0[%c3] : tensor<8x!ct_ty> + %extracted_04 = tensor.extract %arg0[%c4] : tensor<8x!ct_ty> + %extracted_05 = tensor.extract %arg0[%c5] : tensor<8x!ct_ty> + %extracted_06 = tensor.extract %arg0[%c6] : tensor<8x!ct_ty> + %extracted_07 = tensor.extract %arg0[%c7] : tensor<8x!ct_ty> + %1 = lwe.encode %i0 {encoding = #encoding} : i1 to !pt_ty + %ec0 = lwe.trivial_encrypt %1 : !pt_ty to !ct_ty + %2 = lwe.encode %i1 {encoding = #encoding} : i1 to !pt_ty + %ec1 = lwe.trivial_encrypt %2 : !pt_ty to !ct_ty + %ha_s = cggi.xor %extracted_00, %ec1 : !ct_ty + %ha_c = cggi.and %extracted_00, %ec1 : !ct_ty + %fa0_1 = cggi.xor %extracted_01, %ec0 : !ct_ty + %fa0_2 = cggi.and %extracted_01, %ec0 : !ct_ty + %fa0_3 = cggi.and %fa0_1, %ha_c : !ct_ty + %fa0_s = cggi.xor %fa0_1, %ha_c : !ct_ty + %fa0_c = cggi.xor %fa0_2, %fa0_3 : !ct_ty + %fa1_1 = cggi.xor %extracted_02, %ec0 : !ct_ty + %fa1_2 = cggi.and %extracted_02, %ec0 : !ct_ty + %fa1_3 = cggi.and %fa1_1, %fa0_c : !ct_ty + %fa1_s = cggi.xor %fa1_1, %fa0_c : !ct_ty + %fa1_c = cggi.xor %fa1_2, %fa1_3 : !ct_ty + %fa2_1 = cggi.xor %extracted_03, %ec0 : !ct_ty + %fa2_2 = cggi.and %extracted_03, %ec0 : !ct_ty + %fa2_3 = cggi.and %fa2_1, %fa1_c : !ct_ty + %fa2_s = cggi.xor %fa2_1, %fa1_c : !ct_ty + %fa2_c = cggi.xor %fa2_2, %fa2_3 : !ct_ty + %fa3_1 = cggi.xor %extracted_04, %ec0 : !ct_ty + %fa3_2 = cggi.and %extracted_04, %ec0 : !ct_ty + %fa3_3 = cggi.and %fa3_1, %fa2_c : !ct_ty + %fa3_s = cggi.xor %fa3_1, %fa2_c : !ct_ty + %fa3_c = cggi.xor %fa3_2, %fa3_3 : !ct_ty + %fa4_1 = cggi.xor %extracted_05, %ec0 : !ct_ty + %fa4_2 = cggi.and %extracted_05, %ec0 : !ct_ty + %fa4_3 = cggi.and %fa4_1, %fa3_c : !ct_ty + %fa4_s = cggi.xor %fa4_1, %fa3_c : !ct_ty + %fa4_c = cggi.xor %fa4_2, %fa4_3 : !ct_ty + %fa5_1 = cggi.xor %extracted_06, %ec0 : !ct_ty + %fa5_2 = cggi.and %extracted_06, %ec0 : !ct_ty + %fa5_3 = cggi.and %fa5_1, %fa4_c : !ct_ty + %fa5_s = cggi.xor %fa5_1, %fa4_c : !ct_ty + %fa5_c = cggi.xor %fa5_2, %fa5_3 : !ct_ty + %fa6_1 = cggi.xor %extracted_07, %ec0 : !ct_ty + %fa6_2 = cggi.and %extracted_07, %ec0 : !ct_ty + %fa6_3 = cggi.and %fa6_1, %fa5_c : !ct_ty + %fa6_s = cggi.xor %fa6_1, %fa5_c : !ct_ty + %fa6_c = cggi.xor %fa6_2, %fa6_3 : !ct_ty + %from_elements = tensor.from_elements %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s : tensor<8x!ct_ty> + return %from_elements : tensor<8x!ct_ty> +} diff --git a/tests/tfhe_rust_bool/add_one_bool.mlir b/tests/tfhe_rust_bool/add_one_bool.mlir new file mode 100644 index 000000000..e9018007d --- /dev/null +++ b/tests/tfhe_rust_bool/add_one_bool.mlir @@ -0,0 +1,72 @@ +// RUN: heir-translate %s --emit-tfhe-rust-bool | FileCheck %s + +!bsks = !tfhe_rust_bool.server_key +!eb = !tfhe_rust_bool.eb + +// CHECK-LABEL: pub fn fn_under_test( +// CHECK-NEXT: [[bsks:v[0-9]+]]: &ServerKey, +// CHECK-NEXT: [[input1:v[0-9]+]]: &Vec, +// CHECK-NEXT: [[input2:v[0-9]+]]: &Vec, +// CHECK-NEXT: ) -> Vec { +func.func @fn_under_test(%bsks : !bsks, %arg0: tensor<8x!eb>, %arg1: tensor<8x!eb>) -> tensor<8x!eb> { + %c7 = arith.constant 7 : index + %c6 = arith.constant 6 : index + %c5 = arith.constant 5 : index + %c4 = arith.constant 4 : index + %c3 = arith.constant 3 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %c0 = arith.constant 0 : index + %extracted_00 = tensor.extract %arg0[%c0] : tensor<8x!eb> + %extracted_01 = tensor.extract %arg0[%c1] : tensor<8x!eb> + %extracted_02 = tensor.extract %arg0[%c2] : tensor<8x!eb> + %extracted_03 = tensor.extract %arg0[%c3] : tensor<8x!eb> + %extracted_04 = tensor.extract %arg0[%c4] : tensor<8x!eb> + %extracted_05 = tensor.extract %arg0[%c5] : tensor<8x!eb> + %extracted_06 = tensor.extract %arg0[%c6] : tensor<8x!eb> + %extracted_07 = tensor.extract %arg0[%c7] : tensor<8x!eb> + %extracted_10 = tensor.extract %arg1[%c0] : tensor<8x!eb> + %extracted_11 = tensor.extract %arg1[%c1] : tensor<8x!eb> + %extracted_12 = tensor.extract %arg1[%c2] : tensor<8x!eb> + %extracted_13 = tensor.extract %arg1[%c3] : tensor<8x!eb> + %extracted_14 = tensor.extract %arg1[%c4] : tensor<8x!eb> + %extracted_15 = tensor.extract %arg1[%c5] : tensor<8x!eb> + %extracted_16 = tensor.extract %arg1[%c6] : tensor<8x!eb> + %extracted_17 = tensor.extract %arg1[%c7] : tensor<8x!eb> + %ha_s = tfhe_rust_bool.xor %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb + %ha_c = tfhe_rust_bool.and %bsks, %extracted_00, %extracted_10 : (!bsks, !eb, !eb) -> !eb + %fa0_1 = tfhe_rust_bool.xor %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb + %fa0_2 = tfhe_rust_bool.and %bsks, %extracted_01, %extracted_11 : (!bsks, !eb, !eb) -> !eb + %fa0_3 = tfhe_rust_bool.and %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb + %fa0_s = tfhe_rust_bool.xor %bsks, %fa0_1, %ha_c : (!bsks, !eb, !eb) -> !eb + %fa0_c = tfhe_rust_bool.xor %bsks, %fa0_2, %fa0_3 : (!bsks, !eb, !eb) -> !eb + %fa1_1 = tfhe_rust_bool.xor %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb + %fa1_2 = tfhe_rust_bool.and %bsks, %extracted_02, %extracted_12 : (!bsks, !eb, !eb) -> !eb + %fa1_3 = tfhe_rust_bool.and %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb + %fa1_s = tfhe_rust_bool.xor %bsks, %fa1_1, %fa0_c : (!bsks, !eb, !eb) -> !eb + %fa1_c = tfhe_rust_bool.xor %bsks, %fa1_2, %fa1_3 : (!bsks, !eb, !eb) -> !eb + %fa2_1 = tfhe_rust_bool.xor %bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb + %fa2_2 = tfhe_rust_bool.and %bsks, %extracted_03, %extracted_13 : (!bsks, !eb, !eb) -> !eb + %fa2_3 = tfhe_rust_bool.and %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb + %fa2_s = tfhe_rust_bool.xor %bsks, %fa2_1, %fa1_c : (!bsks, !eb, !eb) -> !eb + %fa2_c = tfhe_rust_bool.xor %bsks, %fa2_2, %fa2_3 : (!bsks, !eb, !eb) -> !eb + %fa3_1 = tfhe_rust_bool.xor %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb + %fa3_2 = tfhe_rust_bool.and %bsks, %extracted_04, %extracted_14 : (!bsks, !eb, !eb) -> !eb + %fa3_3 = tfhe_rust_bool.and %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb + %fa3_s = tfhe_rust_bool.xor %bsks, %fa3_1, %fa2_c : (!bsks, !eb, !eb) -> !eb + %fa3_c = tfhe_rust_bool.xor %bsks, %fa3_2, %fa3_3 : (!bsks, !eb, !eb) -> !eb + %fa4_1 = tfhe_rust_bool.xor %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb + %fa4_2 = tfhe_rust_bool.and %bsks, %extracted_05, %extracted_15 : (!bsks, !eb, !eb) -> !eb + %fa4_3 = tfhe_rust_bool.and %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb + %fa4_s = tfhe_rust_bool.xor %bsks, %fa4_1, %fa3_c : (!bsks, !eb, !eb) -> !eb + %fa4_c = tfhe_rust_bool.xor %bsks, %fa4_2, %fa4_3 : (!bsks, !eb, !eb) -> !eb + %fa5_1 = tfhe_rust_bool.xor %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb + %fa5_2 = tfhe_rust_bool.and %bsks, %extracted_06, %extracted_16 : (!bsks, !eb, !eb) -> !eb + %fa5_3 = tfhe_rust_bool.and %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb + %fa5_s = tfhe_rust_bool.xor %bsks, %fa5_1, %fa4_c : (!bsks, !eb, !eb) -> !eb + %fa5_c = tfhe_rust_bool.xor %bsks, %fa5_2, %fa5_3 : (!bsks, !eb, !eb) -> !eb + %fa6_1 = tfhe_rust_bool.xor %bsks, %extracted_07, %extracted_17 : (!bsks, !eb, !eb) -> !eb + %fa6_s = tfhe_rust_bool.xor %bsks, %fa6_1, %fa5_c : (!bsks, !eb, !eb) -> !eb + %from_elements = tensor.from_elements %fa6_s, %fa5_s, %fa4_s, %fa3_s, %fa2_s, %fa1_s, %fa0_s, %ha_s : tensor<8x!eb> + return %from_elements : tensor<8x!eb> +} diff --git a/tests/tfhe_rust_bool/emit_tfhe_rust.mlir b/tests/tfhe_rust_bool/emit_tfhe_rust_bool.mlir similarity index 88% rename from tests/tfhe_rust_bool/emit_tfhe_rust.mlir rename to tests/tfhe_rust_bool/emit_tfhe_rust_bool.mlir index 1a96d3b25..e6dd85760 100644 --- a/tests/tfhe_rust_bool/emit_tfhe_rust.mlir +++ b/tests/tfhe_rust_bool/emit_tfhe_rust_bool.mlir @@ -8,7 +8,7 @@ // CHECK-NEXT: [[input1:v[0-9]+]]: &Ciphertext, // CHECK-NEXT: [[input2:v[0-9]+]]: &Ciphertext, // CHECK-NEXT: ) -> Ciphertext { -// CHECK-NEXT: let [[v0:.*]] = [[bsks]].and([[input1]], [[input2]]); +// CHECK-NEXT: let [[v0:.*]] = [[bsks]].and(&[[input1]], &[[input2]]); // CHECK-NEXT: [[v0]] // CHECK-NEXT: } func.func @test_and(%bsks : !bsks, %input1 : !eb, %input2 : !eb) -> !eb { diff --git a/tests/tfhe_rust_bool/end_to_end/src/main.rs b/tests/tfhe_rust_bool/end_to_end/src/main.rs index 3faa63949..33ab2c378 100644 --- a/tests/tfhe_rust_bool/end_to_end/src/main.rs +++ b/tests/tfhe_rust_bool/end_to_end/src/main.rs @@ -17,15 +17,10 @@ fn main() { let flags = Args::parse(); let (client_key, server_key) = tfhe::boolean::gen_keys(); - let pt_1: bool = flags.input1 == 1u8; - let pt_2: bool = flags.input2 == 1u8; - - let ct_1 = client_key.encrypt(pt_1); - let ct_2 = client_key.encrypt(pt_2); + let ct_1 = client_key.encrypt(flags.input1 == 1u8); + let ct_2 = client_key.encrypt(flags.input2 == 1u8); let result = fn_under_test::fn_under_test(&server_key, &ct_1, &ct_2); let output: bool = client_key.decrypt(&result); - print!("{:?} ", pt_1 as u8); - print!("{:?} ", pt_2 as u8); - print!("{:?} ", output as u8); + println!("{:?}", output as u8); } diff --git a/tests/tfhe_rust_bool/end_to_end/test_and.mlir b/tests/tfhe_rust_bool/end_to_end/test_and.mlir index f80f92c03..694a7ef2d 100644 --- a/tests/tfhe_rust_bool/end_to_end/test_and.mlir +++ b/tests/tfhe_rust_bool/end_to_end/test_and.mlir @@ -6,7 +6,7 @@ !bsks = !tfhe_rust_bool.server_key !eb = !tfhe_rust_bool.eb -// CHECK: 1 1 1 +// CHECK: 1 func.func @fn_under_test(%bsks : !bsks, %a: !eb, %b: !eb) -> !eb { %res = tfhe_rust_bool.and %bsks, %a, %b: (!bsks, !eb, !eb) -> !eb return %res : !eb diff --git a/tests/tfhe_rust_bool/ops.mlir b/tests/tfhe_rust_bool/ops.mlir index 784b0727b..e97d705f7 100644 --- a/tests/tfhe_rust_bool/ops.mlir +++ b/tests/tfhe_rust_bool/ops.mlir @@ -3,6 +3,7 @@ // This simply tests for syntax. !bsks = !tfhe_rust_bool.server_key +!eb = !tfhe_rust_bool.eb module { // CHECK-LABEL: func @test_create_trivial_bool @@ -26,4 +27,24 @@ module { return } + // CHECK-LABEL: func @test_packed_and + func.func @test_packed_and(%bsks : !bsks, %lhs : tensor<4x!eb>, %rhs : tensor<4x!eb>) { + %0 = arith.constant 0 : index + %1 = arith.constant 1 : index + %4 = arith.constant 4 : index + + %c0 = arith.constant 0 : i1 + %c1 = arith.constant 1 : i1 + + scf.for %i = %0 to %4 step %1 { + %tmp1 = tfhe_rust_bool.create_trivial %bsks, %c0 : (!bsks, i1) -> !eb + %tmp2 = tfhe_rust_bool.create_trivial %bsks, %c1 : (!bsks, i1) -> !eb + + tensor.insert %tmp1 into %lhs[%i] : tensor<4x!eb> + tensor.insert %tmp2 into %rhs[%i] : tensor<4x!eb> + } + + %out = tfhe_rust_bool.and_packed %bsks, %lhs, %rhs: (!bsks, tensor<4x!eb>, tensor<4x!eb>) -> tensor<4x!eb> + return + } } diff --git a/tools/BUILD b/tools/BUILD index 0c9bc44a6..1fab41a85 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -36,6 +36,7 @@ cc_binary( "@heir//lib/Conversion/BGVToOpenfhe", "@heir//lib/Conversion/BGVToPolynomial", "@heir//lib/Conversion/CGGIToTfheRust", + "@heir//lib/Conversion/CGGIToTfheRustBool", "@heir//lib/Conversion/CombToCGGI", "@heir//lib/Conversion/MemrefToArith:ExpandCopy", "@heir//lib/Conversion/MemrefToArith:MemrefToArithRegistration", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 8eb0ed41e..a1ae9f5e3 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -4,6 +4,7 @@ #include "include/Conversion/BGVToOpenfhe/BGVToOpenfhe.h" #include "include/Conversion/BGVToPolynomial/BGVToPolynomial.h" #include "include/Conversion/CGGIToTfheRust/CGGIToTfheRust.h" +#include "include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.h" #include "include/Conversion/CombToCGGI/CombToCGGI.h" #include "include/Conversion/MemrefToArith/MemrefToArith.h" #include "include/Conversion/PolynomialToStandard/PolynomialToStandard.h" @@ -311,6 +312,7 @@ int main(int argc, char **argv) { comb::registerCombToCGGIPasses(); polynomial::registerPolynomialToStandardPasses(); registerCGGIToTfheRustPasses(); + registerCGGIToTfheRustBoolPasses(); PassPipelineRegistration<>( "heir-tosa-to-arith",