diff --git a/lib/Conversion/CGGIToTfheRust/CGGIToTfheRust.cpp b/lib/Conversion/CGGIToTfheRust/CGGIToTfheRust.cpp index e0078cba1..4954b3c08 100644 --- a/lib/Conversion/CGGIToTfheRust/CGGIToTfheRust.cpp +++ b/lib/Conversion/CGGIToTfheRust/CGGIToTfheRust.cpp @@ -84,9 +84,9 @@ int widthFromEncodingAttr(Attribute encoding) { }); } -class PassTypeConverter : public TypeConverter { +class CGGIToTfheRustTypeConverter : public TypeConverter { public: - PassTypeConverter(MLIRContext *ctx) { + CGGIToTfheRustTypeConverter(MLIRContext *ctx) { addConversion([](Type type) { return type; }); addConversion([ctx](lwe::LWECiphertextType type) -> Type { int width = widthFromEncodingAttr(type.getEncoding()); @@ -402,7 +402,7 @@ class CGGIToTfheRust : public impl::CGGIToTfheRustBase { MLIRContext *context = &getContext(); auto *op = getOperation(); - PassTypeConverter typeConverter(context); + CGGIToTfheRustTypeConverter typeConverter(context); RewritePatternSet patterns(context); ConversionTarget target(*context); addStructuralConversionPatterns(typeConverter, patterns, target); diff --git a/lib/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.cpp b/lib/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.cpp index 3a4716fb4..af2b6342a 100644 --- a/lib/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.cpp +++ b/lib/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.cpp @@ -32,9 +32,9 @@ namespace mlir::heir { #define GEN_PASS_DEF_CGGITOTFHERUSTBOOL #include "include/Conversion/CGGIToTfheRustBool/CGGIToTfheRustBool.h.inc" -class BoolPassTypeConverter : public TypeConverter { +class CGGIToTfheRustBoolTypeConverter : public TypeConverter { public: - BoolPassTypeConverter(MLIRContext *ctx) { + CGGIToTfheRustBoolTypeConverter(MLIRContext *ctx) { addConversion([](Type type) { return type; }); addConversion([ctx](lwe::LWECiphertextType type) -> Type { return tfhe_rust_bool::EncryptedBoolType::get(ctx); @@ -131,35 +131,12 @@ struct AddBoolServerKeyArg : public OpConversionPattern { } }; -struct ConvertBoolAndOp : public OpConversionPattern { - ConvertBoolAndOp(mlir::MLIRContext *context) - : OpConversionPattern(context) {} - - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - cggi::AndOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - ImplicitLocOpBuilder b(op->getLoc(), rewriter); - FailureOr result = getContextualBoolServerKey(op); - if (failed(result)) return result; - - Value serverKey = result.value(); - - rewriter.replaceOp(op, b.create( - serverKey, adaptor.getLhs(), adaptor.getRhs())); - return success(); - } -}; - -struct ConvertBoolOrOp : public OpConversionPattern { - ConvertBoolOrOp(mlir::MLIRContext *context) - : OpConversionPattern(context) {} - - using OpConversionPattern::OpConversionPattern; +template +struct ConvertBinOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - cggi::OrOp op, OpAdaptor adaptor, + BinOp op, typename BinOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { ImplicitLocOpBuilder b(op->getLoc(), rewriter); FailureOr result = getContextualBoolServerKey(op); @@ -167,32 +144,15 @@ struct ConvertBoolOrOp : public OpConversionPattern { Value serverKey = result.value(); - rewriter.replaceOp(op, b.create( + rewriter.replaceOp(op, b.create( serverKey, adaptor.getLhs(), adaptor.getRhs())); return success(); } }; -struct ConvertBoolXorOp : public OpConversionPattern { - ConvertBoolXorOp(mlir::MLIRContext *context) - : OpConversionPattern(context) {} - - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - cggi::XorOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - ImplicitLocOpBuilder b(op->getLoc(), rewriter); - FailureOr result = getContextualBoolServerKey(op); - if (failed(result)) return result; - - Value serverKey = result.value(); - - rewriter.replaceOp(op, b.create( - serverKey, adaptor.getLhs(), adaptor.getRhs())); - return success(); - } -}; +using ConvertBoolAndOp = ConvertBinOp; +using ConvertBoolOrOp = ConvertBinOp; +using ConvertBoolXorOp = ConvertBinOp; struct ConvertBoolNotOp : public OpConversionPattern { ConvertBoolNotOp(mlir::MLIRContext *context) @@ -218,7 +178,7 @@ struct ConvertBoolNotOp : public OpConversionPattern { struct ConvertBoolTrivialEncryptOp : public OpConversionPattern { ConvertBoolTrivialEncryptOp(mlir::MLIRContext *context) - : OpConversionPattern(context, /*benefit=*/2) {} + : OpConversionPattern(context, /*benefit=*/1) {} using OpConversionPattern::OpConversionPattern; @@ -236,7 +196,7 @@ struct ConvertBoolTrivialEncryptOp << op.getInput().getDefiningOp()->getName(); } auto outputType = tfhe_rust_bool::EncryptedBoolType::get(getContext()); - ; + auto createTrivialOp = rewriter.create( op.getLoc(), outputType, serverKey, encodeOp.getPlaintext()); rewriter.replaceOp(op, createTrivialOp); @@ -264,7 +224,7 @@ class CGGIToTfheRustBool MLIRContext *context = &getContext(); auto *op = getOperation(); - BoolPassTypeConverter typeConverter(context); + CGGIToTfheRustBoolTypeConverter typeConverter(context); RewritePatternSet patterns(context); ConversionTarget target(*context); addStructuralConversionPatterns(typeConverter, patterns, target); @@ -307,7 +267,7 @@ class CGGIToTfheRustBool GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, - GenericOpPattern>(typeConverter, context); + GenericOpPattern >(typeConverter, context); if (failed(applyPartialConversion(op, target, std::move(patterns)))) { return signalPassFailure(); diff --git a/tests/tfhe_rust_bool/ops.mlir b/tests/tfhe_rust_bool/ops.mlir index e97d705f7..1eb7f7353 100644 --- a/tests/tfhe_rust_bool/ops.mlir +++ b/tests/tfhe_rust_bool/ops.mlir @@ -29,21 +29,6 @@ module { // CHECK-LABEL: func @test_packed_and func.func @test_packed_and(%bsks : !bsks, %lhs : tensor<4x!eb>, %rhs : tensor<4x!eb>) { - %0 = arith.constant 0 : index - %1 = arith.constant 1 : index - %4 = arith.constant 4 : index - - %c0 = arith.constant 0 : i1 - %c1 = arith.constant 1 : i1 - - scf.for %i = %0 to %4 step %1 { - %tmp1 = tfhe_rust_bool.create_trivial %bsks, %c0 : (!bsks, i1) -> !eb - %tmp2 = tfhe_rust_bool.create_trivial %bsks, %c1 : (!bsks, i1) -> !eb - - tensor.insert %tmp1 into %lhs[%i] : tensor<4x!eb> - tensor.insert %tmp2 into %rhs[%i] : tensor<4x!eb> - } - %out = tfhe_rust_bool.and_packed %bsks, %lhs, %rhs: (!bsks, tensor<4x!eb>, tensor<4x!eb>) -> tensor<4x!eb> return }