diff --git a/examples/simple_ckks_bootstrapping.jl b/examples/simple_ckks_bootstrapping.jl index 5a316db..017fcbf 100644 --- a/examples/simple_ckks_bootstrapping.jl +++ b/examples/simple_ckks_bootstrapping.jl @@ -12,18 +12,17 @@ function simple_ckks_bootstrapping(context) init_bootstrapping!(context, private_key) x = [0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0, 5.0] - encoded_length = length(x) - pv = PlainVector(context, x) + pv = PlainVector(x, context) println("Input: ", pv) - sv = encrypt(context, public_key, pv) + sv = encrypt(pv, public_key) # Perform the bootstrapping operation. The goal is to increase the number of levels # remaining for HE computation. - sv_after = bootstrap!(context, sv) + sv_after = bootstrap!(sv) - result = decrypt(context, private_key, sv) + result = decrypt(sv, private_key) println("Output after bootstrapping \n\t", result) end diff --git a/examples/simple_real_numbers.jl b/examples/simple_real_numbers.jl index 655ca2a..61c0df5 100644 --- a/examples/simple_real_numbers.jl +++ b/examples/simple_real_numbers.jl @@ -11,14 +11,14 @@ function simple_real_numbers(context) x1 = [0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0, 5.0] x2 = [5.0, 4.0, 3.0, 2.0, 1.0, 0.75, 0.5, 0.25] - pv1 = PlainVector(context, x1) - pv2 = PlainVector(context, x2) + pv1 = PlainVector(x1, context) + pv2 = PlainVector(x2, context) println("Input x1: ", pv1) println("Input x2: ", pv2) - sv1 = encrypt(context, public_key, pv1) - sv2 = encrypt(context, public_key, pv2) + sv1 = encrypt(pv1, public_key) + sv2 = encrypt(pv2, public_key) sv_add = sv1 + sv2 @@ -35,25 +35,25 @@ function simple_real_numbers(context) println() println("Results of homomorphic computations: ") - result_sv1 = decrypt(context, private_key, sv1) + result_sv1 = decrypt(sv1, private_key) println("x1 = ", result_sv1) - result_sv_add = decrypt(context, private_key, sv_add) + result_sv_add = decrypt(sv_add, private_key) println("x1 + x2 = ", result_sv_add) - result_sv_sub = decrypt(context, private_key, sv_sub) + result_sv_sub = decrypt(sv_sub, private_key) println("x1 - x2 = ", result_sv_sub) - result_sv_scalar = decrypt(context, private_key, sv_scalar) + result_sv_scalar = decrypt(sv_scalar, private_key) println("4 * x1 = ", result_sv_scalar) - result_sv_mult = decrypt(context, private_key, sv_mult) + result_sv_mult = decrypt(sv_mult, private_key) println("x1 * x2 = ", result_sv_mult) - result_sv_shift1 = decrypt(context, private_key, sv_shift1) + result_sv_shift1 = decrypt(sv_shift1, private_key) println("x1 shifted circularly by -1 = ", result_sv_shift1) - result_sv_shift2 = decrypt(context, private_key, sv_shift2) + result_sv_shift2 = decrypt(sv_shift2, private_key) println("x1 shifted circularly by 2 = ", result_sv_shift2) end diff --git a/src/openfhe.jl b/src/openfhe.jl index 172b7aa..3842802 100644 --- a/src/openfhe.jl +++ b/src/openfhe.jl @@ -44,7 +44,7 @@ function init_bootstrapping!(context::SecureContext{<:OpenFHEBackend}, private_k nothing end -function PlainVector(context::SecureContext{<:OpenFHEBackend}, data::Vector{<:Real}) +function PlainVector(data::Vector{<:Real}, context::SecureContext{<:OpenFHEBackend}) cc = get_crypto_context(context) plaintext = OpenFHE.MakeCKKSPackedPlaintext(cc, data) plain_vector = PlainVector(plaintext, context) @@ -52,15 +52,15 @@ function PlainVector(context::SecureContext{<:OpenFHEBackend}, data::Vector{<:Re plain_vector end -function encrypt(context::SecureContext{<:OpenFHEBackend}, public_key, data::Vector{<:Real}) +function encrypt(data::Vector{<:Real}, public_key, context::SecureContext{<:OpenFHEBackend}) plain_vector = PlainVector(context, data) secure_vector = encrypt(context, public_key, plain_vector) secure_vector end -function encrypt(context::SecureContext{<:OpenFHEBackend}, public_key, - plain_vector::PlainVector) +function encrypt(plain_vector::PlainVector{<:OpenFHEBackend}, public_key) + context = plain_vector.context cc = get_crypto_context(context) ciphertext = OpenFHE.Encrypt(cc, public_key.public_key, plain_vector.plaintext) secure_vector = SecureVector(ciphertext, context) @@ -68,29 +68,25 @@ function encrypt(context::SecureContext{<:OpenFHEBackend}, public_key, secure_vector end -function decrypt!(plain_vector, context::SecureContext{<:OpenFHEBackend}, private_key, - secure_vector) - cc = get_crypto_context(context) +function decrypt!(plain_vector::PlainVector{<:OpenFHEBackend}, + secure_vector::SecureVector{<:OpenFHEBackend}, private_key) + cc = get_crypto_context(secure_vector) OpenFHE.Decrypt(cc, private_key.private_key, secure_vector.ciphertext, plain_vector.plaintext) plain_vector end -function decrypt(context::SecureContext{<:OpenFHEBackend}, private_key, secure_vector) +function decrypt(secure_vector::SecureVector{<:OpenFHEBackend}, private_key) + context = secure_vector.context plain_vector = PlainVector(OpenFHE.Plaintext(), context) - decrypt!(plain_vector, context, private_key, secure_vector) + decrypt!(plain_vector, secure_vector, private_key) end -function bootstrap!(context::SecureContext{<:OpenFHEBackend}, secure_vector) - cc = get_crypto_context(context) - OpenFHE.EvalBootstrap(cc, secure_vector.ciphertext) - - secure_vector -end -function bootstrap!(context::SecureContext{<:OpenFHEBackend}, secure_vector) +function bootstrap!(secure_vector::SecureVector{<:OpenFHEBackend}) + context = secure_vector.context cc = get_crypto_context(context) OpenFHE.EvalBootstrap(cc, secure_vector.ciphertext) diff --git a/src/unencrypted.jl b/src/unencrypted.jl index 63092a2..ae6cff9 100644 --- a/src/unencrypted.jl +++ b/src/unencrypted.jl @@ -10,33 +10,30 @@ init_multiplication!(context::SecureContext{<:Unencrypted}, private_key) = nothi init_rotation!(context::SecureContext{<:Unencrypted}, private_key, shifts) = nothing init_bootstrapping!(context::SecureContext{<:Unencrypted}, private_key) = nothing -function PlainVector(context::SecureContext{<:Unencrypted}, data::Vector{<:Real}) - plain_vector = PlainVector(data, context) -end +# No constructor for `PlainVector` necessary since we can directly use the inner constructor -function encrypt(context::SecureContext{<:Unencrypted}, public_key, data::Vector{<:Real}) +function encrypt(data::Vector{<:Real}, public_key, context::SecureContext{<:Unencrypted}) SecureVector(data, context) end -function encrypt(context::SecureContext{<:Unencrypted}, public_key, - plain_vector::PlainVector) - SecureVector(plain_vector.plaintext, context) +function encrypt(plain_vector::PlainVector{<:Unencrypted}, public_key) + SecureVector(plain_vector.plaintext, plain_vector.context) end -function decrypt!(plain_vector, context::SecureContext{<:Unencrypted}, private_key, - secure_vector) +function decrypt!(plain_vector::PlainVector{<:Unencrypted}, + secure_vector::SecureVector{<:Unencrypted}, private_key) plain_vector.plaintext .= secure_vector.ciphertext plain_vector end -function decrypt(context::SecureContext{<:Unencrypted}, private_key, secure_vector) - plain_vector = PlainVector(similar(secure_vector.ciphertext), context) +function decrypt(secure_vector::SecureVector{<:Unencrypted}, private_key) + plain_vector = PlainVector(similar(secure_vector.ciphertext), secure_vector.context) - decrypt!(plain_vector, context, private_key, secure_vector) + decrypt!(plain_vector, secure_vector, private_key) end -bootstrap!(context::SecureContext{<:Unencrypted}, secure_vector) = secure_vector +bootstrap!(secure_vector::SecureVector{<:Unencrypted}) = secure_vector function add(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted}) SecureVector(sv1.ciphertext .+ sv2.ciphertext, sv1.context)