diff --git a/src/SecureArithmetic.jl b/src/SecureArithmetic.jl index 055f6f1..922fc59 100644 --- a/src/SecureArithmetic.jl +++ b/src/SecureArithmetic.jl @@ -5,6 +5,9 @@ using OpenFHE: OpenFHE # Basic types export SecureContext, SecureVector, PlainVector +# Keys +export PrivateKey, PublicKey + # Backends export Unencrypted, OpenFHEBackend diff --git a/src/openfhe.jl b/src/openfhe.jl index 690b00a..f5b3c03 100644 --- a/src/openfhe.jl +++ b/src/openfhe.jl @@ -5,11 +5,9 @@ end function get_crypto_context(context::SecureContext{<:OpenFHEBackend}) context.backend.crypto_context end -function get_crypto_context(secure_vector::SecureVector{<:OpenFHEBackend}) - get_crypto_context(secure_vector.context) -end -function get_crypto_context(plain_vector::PlainVector{<:OpenFHEBackend}) - get_crypto_context(plain_vector.context) +function get_crypto_context(v::Union{SecureVector{<:OpenFHEBackend}, + PlainVector{<:OpenFHEBackend}}) + get_crypto_context(v.context) end function generate_keys(context::SecureContext{<:OpenFHEBackend}) @@ -21,21 +19,24 @@ function generate_keys(context::SecureContext{<:OpenFHEBackend}) public_key, private_key end -function init_multiplication!(context::SecureContext{<:OpenFHEBackend}, private_key) +function init_multiplication!(context::SecureContext{<:OpenFHEBackend}, + private_key::PrivateKey) cc = get_crypto_context(context) OpenFHE.EvalMultKeyGen(cc, private_key.private_key) nothing end -function init_rotation!(context::SecureContext{<:OpenFHEBackend}, private_key, shifts) +function init_rotation!(context::SecureContext{<:OpenFHEBackend}, private_key::PrivateKey, + shifts) cc = get_crypto_context(context) OpenFHE.EvalRotateKeyGen(cc, private_key.private_key, shifts) nothing end -function init_bootstrapping!(context::SecureContext{<:OpenFHEBackend}, private_key) +function init_bootstrapping!(context::SecureContext{<:OpenFHEBackend}, + private_key::PrivateKey) cc = get_crypto_context(context) ring_dimension = OpenFHE.GetRingDimension(cc) num_slots = div(ring_dimension, 2) @@ -47,39 +48,39 @@ end function PlainVector(data::Vector{<:Real}, context::SecureContext{<:OpenFHEBackend}) cc = get_crypto_context(context) plaintext = OpenFHE.MakeCKKSPackedPlaintext(cc, data) - plain_vector = PlainVector(plaintext, context) + plain_vector = PlainVector(plaintext, length(data), context) plain_vector end function encrypt(data::Vector{<:Real}, public_key, context::SecureContext{<:OpenFHEBackend}) - plain_vector = PlainVector(context, data) - secure_vector = encrypt(context, public_key, plain_vector) + plain_vector = PlainVector(data, context) + secure_vector = encrypt(plain_vector, public_key) secure_vector end -function encrypt(plain_vector::PlainVector{<:OpenFHEBackend}, public_key) +function encrypt(plain_vector::PlainVector{<:OpenFHEBackend}, public_key::PublicKey) context = plain_vector.context cc = get_crypto_context(context) - ciphertext = OpenFHE.Encrypt(cc, public_key.public_key, plain_vector.plaintext) - secure_vector = SecureVector(ciphertext, context) + ciphertext = OpenFHE.Encrypt(cc, public_key.public_key, plain_vector.data) + secure_vector = SecureVector(ciphertext, length(plain_vector), context) secure_vector end function decrypt!(plain_vector::PlainVector{<:OpenFHEBackend}, - secure_vector::SecureVector{<:OpenFHEBackend}, private_key) + secure_vector::SecureVector{<:OpenFHEBackend}, private_key::PrivateKey) cc = get_crypto_context(secure_vector) - OpenFHE.Decrypt(cc, private_key.private_key, secure_vector.ciphertext, - plain_vector.plaintext) + OpenFHE.Decrypt(cc, private_key.private_key, secure_vector.data, + plain_vector.data) plain_vector end -function decrypt(secure_vector::SecureVector{<:OpenFHEBackend}, private_key) +function decrypt(secure_vector::SecureVector{<:OpenFHEBackend}, private_key::PrivateKey) context = secure_vector.context - plain_vector = PlainVector(OpenFHE.Plaintext(), context) + plain_vector = PlainVector(OpenFHE.Plaintext(), length(secure_vector), context) decrypt!(plain_vector, secure_vector, private_key) end @@ -88,7 +89,7 @@ end function bootstrap!(secure_vector::SecureVector{<:OpenFHEBackend}) context = secure_vector.context cc = get_crypto_context(context) - OpenFHE.EvalBootstrap(cc, secure_vector.ciphertext) + OpenFHE.EvalBootstrap(cc, secure_vector.data) secure_vector end @@ -100,96 +101,96 @@ end function add(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv1) - ciphertext = OpenFHE.EvalAdd(cc, sv1.ciphertext, sv2.ciphertext) - secure_vector = SecureVector(ciphertext, sv1.context) + ciphertext = OpenFHE.EvalAdd(cc, sv1.data, sv2.data) + secure_vector = SecureVector(ciphertext, length(sv1), sv1.context) secure_vector end function add(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalAdd(cc, sv.ciphertext, pv.plaintext) - secure_vector = SecureVector(ciphertext, sv.context) + ciphertext = OpenFHE.EvalAdd(cc, sv.data, pv.data) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end function add(sv::SecureVector{<:OpenFHEBackend}, scalar::Real) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalAdd(cc, sv.ciphertext, scalar) - secure_vector = SecureVector(ciphertext, sv.context) + ciphertext = OpenFHE.EvalAdd(cc, sv.data, scalar) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end function subtract(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv1) - ciphertext = OpenFHE.EvalSub(cc, sv1.ciphertext, sv2.ciphertext) - secure_vector = SecureVector(ciphertext, sv1.context) + ciphertext = OpenFHE.EvalSub(cc, sv1.data, sv2.data) + secure_vector = SecureVector(ciphertext, length(sv1), sv1.context) secure_vector end function subtract(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalSub(cc, sv.ciphertext, pv.plaintext) - secure_vector = SecureVector(ciphertext, sv.context) + ciphertext = OpenFHE.EvalSub(cc, sv.data, pv.data) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end function subtract(pv::PlainVector{<:OpenFHEBackend}, sv::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalSub(cc, pv.plaintext, sv.ciphertext) - secure_vector = SecureVector(ciphertext, sv.context) + ciphertext = OpenFHE.EvalSub(cc, pv.data, sv.data) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end function subtract(sv::SecureVector{<:OpenFHEBackend}, scalar::Real) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalSub(cc, sv.ciphertext, scalar) - secure_vector = SecureVector(ciphertext, sv.context) + ciphertext = OpenFHE.EvalSub(cc, sv.data, scalar) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end function subtract(scalar::Real, sv::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalSub(cc, scalar, sv.ciphertext) - secure_vector = SecureVector(ciphertext, sv.context) + ciphertext = OpenFHE.EvalSub(cc, scalar, sv.data) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end function negate(sv::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalNegate(cc, sv.ciphertext) - secure_vector = SecureVector(ciphertext, sv.context) + ciphertext = OpenFHE.EvalNegate(cc, sv.data) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end function multiply(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv1) - ciphertext = OpenFHE.EvalMult(cc, sv1.ciphertext, sv2.ciphertext) - secure_vector = SecureVector(ciphertext, sv1.context) + ciphertext = OpenFHE.EvalMult(cc, sv1.data, sv2.data) + secure_vector = SecureVector(ciphertext, length(sv1), sv1.context) secure_vector end function multiply(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalMult(cc, sv.ciphertext, pv.plaintext) - secure_vector = SecureVector(ciphertext, sv.context) + ciphertext = OpenFHE.EvalMult(cc, sv.data, pv.data) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end function multiply(sv::SecureVector{<:OpenFHEBackend}, scalar::Real) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalMult(cc, sv.ciphertext, scalar) - secure_vector = SecureVector(ciphertext, sv.context) + ciphertext = OpenFHE.EvalMult(cc, sv.data, scalar) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end @@ -197,8 +198,8 @@ end function rotate(sv::SecureVector{<:OpenFHEBackend}, shift) cc = get_crypto_context(sv) # We use `-shift` to match Julia's usual `circshift` direction - ciphertext = OpenFHE.EvalRotate(cc, sv.ciphertext, -shift) - secure_vector = SecureVector(ciphertext, sv.context) + ciphertext = OpenFHE.EvalRotate(cc, sv.data, -shift) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end diff --git a/src/types.jl b/src/types.jl index 01e536a..d2f2308 100644 --- a/src/types.jl +++ b/src/types.jl @@ -4,25 +4,41 @@ struct SecureContext{CryptoBackendT <: AbstractCryptoBackend} backend::CryptoBackendT end -struct SecureVector{CryptoBackendT <: AbstractCryptoBackend, CiphertextT} - ciphertext::CiphertextT +function Base.show(io::IO, v::SecureContext) + print("SecureContext{", backend_name(v), "}()") +end + +struct SecureVector{CryptoBackendT <: AbstractCryptoBackend, DataT} + data::DataT + length::Int context::SecureContext{CryptoBackendT} - function SecureVector(ciphertext, context::SecureContext{CryptoBackendT}) where CryptoBackendT - new{CryptoBackendT, typeof(ciphertext)}(ciphertext, context) + function SecureVector(data, length, context::SecureContext{CryptoBackendT}) where CryptoBackendT + new{CryptoBackendT, typeof(data)}(data, length, context) end end -struct PlainVector{CryptoBackendT <: AbstractCryptoBackend, PlaintextT} - plaintext::PlaintextT +Base.length(v::SecureVector) = v.length +function Base.show(io::IO, v::SecureVector) + print("SecureVector{", backend_name(v), "}(data=, length=$(v.length))") +end + +struct PlainVector{CryptoBackendT <: AbstractCryptoBackend, DataT} + data::DataT + length::Int context::SecureContext{CryptoBackendT} - function PlainVector(plaintext, context::SecureContext{CryptoBackendT}) where CryptoBackendT - new{CryptoBackendT, typeof(plaintext)}(plaintext, context) + function PlainVector(data, length, context::SecureContext{CryptoBackendT}) where CryptoBackendT + new{CryptoBackendT, typeof(data)}(data, length, context) end end -Base.print(io::IO, plain_vector::PlainVector) = print(io, plain_vector.plaintext) +Base.length(v::PlainVector) = v.length +function Base.show(io::IO, v::PlainVector{CryptoBackendT}) where CryptoBackendT + print("PlainVector{", backend_name(v), "}(data=, length=$(v.length))") +end + +Base.print(io::IO, plain_vector::PlainVector) = print(io, plain_vector.data) struct PrivateKey{CryptoBackendT <: AbstractCryptoBackend, KeyT} private_key::KeyT @@ -33,6 +49,10 @@ struct PrivateKey{CryptoBackendT <: AbstractCryptoBackend, KeyT} end end +function Base.show(io::IO, key::PrivateKey{CryptoBackendT}) where CryptoBackendT + print("PrivateKey{", backend_name(key), "}()") +end + struct PublicKey{CryptoBackendT <: AbstractCryptoBackend, KeyT} public_key::KeyT context::SecureContext{CryptoBackendT} @@ -41,3 +61,17 @@ struct PublicKey{CryptoBackendT <: AbstractCryptoBackend, KeyT} new{CryptoBackendT, typeof(key)}(key, context) end end + +function Base.show(io::IO, key::PublicKey{CryptoBackendT}) where CryptoBackendT + print("PublicKey{", backend_name(key), "}()") +end + +# Get wrapper name of a potentially parametric type +# Copied from: https://github.com/ClapeyronThermo/Clapeyron.jl/blob/f40c282e2236ff68d91f37c39b5c1e4230ae9ef0/src/utils/core_utils.jl#L17 +# Original source: https://github.com/JuliaArrays/ArrayInterface.jl/blob/40d9a87be07ba323cca00f9e59e5285c13f7ee72/src/ArrayInterface.jl#L20 +# Note: prefixed by `__` since it is really, really dirty black magic internals we use here! +__parameterless_type(T) = Base.typename(T).wrapper + +# Convenience method for getting the human-readable backend name +backend_name(x::Union{SecureContext{T}, SecureVector{T}, PlainVector{T}, PrivateKey{T}, + PublicKey{T}}) where T = string(__parameterless_type(T)) diff --git a/src/unencrypted.jl b/src/unencrypted.jl index 7b15470..6152844 100644 --- a/src/unencrypted.jl +++ b/src/unencrypted.jl @@ -6,29 +6,32 @@ function generate_keys(context::SecureContext{<:Unencrypted}) PublicKey(context, nothing), PrivateKey(context, nothing) end -init_multiplication!(context::SecureContext{<:Unencrypted}, private_key) = nothing -init_rotation!(context::SecureContext{<:Unencrypted}, private_key, shifts) = nothing -init_bootstrapping!(context::SecureContext{<:Unencrypted}, private_key) = nothing +init_multiplication!(context::SecureContext{<:Unencrypted}, private_key::PrivateKey) = nothing +init_rotation!(context::SecureContext{<:Unencrypted}, private_key::PrivateKey, shifts) = nothing +init_bootstrapping!(context::SecureContext{<:Unencrypted}, private_key::PrivateKey) = nothing -# No constructor for `PlainVector` necessary since we can directly use the inner constructor +function PlainVector(data::Vector{<:Real}, context::SecureContext{<:Unencrypted}) + PlainVector(data, length(data), context) +end -function encrypt(data::Vector{<:Real}, public_key, context::SecureContext{<:Unencrypted}) - SecureVector(data, context) +function encrypt(data::Vector{<:Real}, public_key::PublicKey, + context::SecureContext{<:Unencrypted}) + SecureVector(data, length(data), context) end -function encrypt(plain_vector::PlainVector{<:Unencrypted}, public_key) - SecureVector(plain_vector.plaintext, plain_vector.context) +function encrypt(plain_vector::PlainVector{<:Unencrypted}, public_key::PublicKey) + SecureVector(plain_vector.data, length(plain_vector), plain_vector.context) end function decrypt!(plain_vector::PlainVector{<:Unencrypted}, - secure_vector::SecureVector{<:Unencrypted}, private_key) - plain_vector.plaintext .= secure_vector.ciphertext + secure_vector::SecureVector{<:Unencrypted}, private_key::PrivateKey) + plain_vector.data .= secure_vector.data plain_vector end -function decrypt(secure_vector::SecureVector{<:Unencrypted}, private_key) - plain_vector = PlainVector(similar(secure_vector.ciphertext), secure_vector.context) +function decrypt(secure_vector::SecureVector{<:Unencrypted}, private_key::PrivateKey) + plain_vector = PlainVector(similar(secure_vector.data), length(secure_vector), secure_vector.context) decrypt!(plain_vector, secure_vector, private_key) end @@ -41,53 +44,53 @@ bootstrap!(secure_vector::SecureVector{<:Unencrypted}) = secure_vector ############################################################################################ function add(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted}) - SecureVector(sv1.ciphertext .+ sv2.ciphertext, sv1.context) + SecureVector(sv1.data .+ sv2.data, length(sv1), sv1.context) end function add(sv::SecureVector{<:Unencrypted}, pv::PlainVector{<:Unencrypted}) - SecureVector(sv.ciphertext .+ pv.plaintext, sv.context) + SecureVector(sv.data .+ pv.data, length(sv), sv.context) end function add(sv::SecureVector{<:Unencrypted}, scalar::Real) - SecureVector(sv.ciphertext .+ scalar, sv.context) + SecureVector(sv.data .+ scalar, length(sv), sv.context) end function subtract(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted}) - SecureVector(sv1.ciphertext .- sv2.ciphertext, sv1.context) + SecureVector(sv1.data .- sv2.data, length(sv1), sv1.context) end function subtract(sv::SecureVector{<:Unencrypted}, pv::PlainVector{<:Unencrypted}) - SecureVector(sv.ciphertext .- pv.plaintext, sv.context) + SecureVector(sv.data .- pv.data, length(sv), sv.context) end function subtract(pv::PlainVector{<:Unencrypted}, sv::SecureVector{<:Unencrypted}) - SecureVector(pv.plaintext .- sv.ciphertext, sv.context) + SecureVector(pv.data .- sv.data, length(sv), sv.context) end function subtract(sv::SecureVector{<:Unencrypted}, scalar::Real) - SecureVector(sv.ciphertext .- scalar, sv.context) + SecureVector(sv.data .- scalar, length(sv), sv.context) end function subtract(scalar::Real, sv::SecureVector{<:Unencrypted}) - SecureVector(scalar .- sv.ciphertext, sv.context) + SecureVector(scalar .- sv.data, length(sv), sv.context) end function negate(sv::SecureVector{<:Unencrypted}) - SecureVector(-sv.ciphertext, sv.context) + SecureVector(-sv.data, length(sv), sv.context) end function multiply(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted}) - SecureVector(sv1.ciphertext .* sv2.ciphertext, sv1.context) + SecureVector(sv1.data .* sv2.data, length(sv1), sv1.context) end function multiply(sv::SecureVector{<:Unencrypted}, pv::PlainVector{<:Unencrypted}) - SecureVector(sv.ciphertext .* pv.plaintext, sv.context) + SecureVector(sv.data .* pv.data, length(sv), sv.context) end function multiply(sv::SecureVector{<:Unencrypted}, scalar::Real) - SecureVector(sv.ciphertext .* scalar, sv.context) + SecureVector(sv.data .* scalar, length(sv), sv.context) end function rotate(sv::SecureVector{<:Unencrypted}, shift) - SecureVector(circshift(sv.ciphertext, shift), sv.context) + SecureVector(circshift(sv.data, shift), length(sv), sv.context) end diff --git a/test/runtests.jl b/test/runtests.jl index 7c0c6ef..633d1d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,7 @@ using Test @time @testset verbose=true showtiming=true "SecureArithmetic.jl tests" begin - include("test_examples.jl") include("test_unit.jl") + include("test_examples.jl") end diff --git a/test/test_unit.jl b/test/test_unit.jl index b4be383..cd0b0b5 100644 --- a/test/test_unit.jl +++ b/test/test_unit.jl @@ -55,6 +55,7 @@ for backend in ((; name = "OpenFHE", BackendT = OpenFHEBackend, context = contex @testset verbose=true showtiming=true "encrypt" begin @test encrypt(pv1, public_key) isa SecureVector + @test encrypt([1.0, 2.0, 3.0], public_key, context) isa SecureVector end sv1 = encrypt(pv1, public_key) @@ -87,6 +88,23 @@ for backend in ((; name = "OpenFHE", BackendT = OpenFHEBackend, context = contex @testset verbose=true showtiming=true "negate" begin @test -sv2 isa SecureVector end + + @testset verbose=true showtiming=true "show" begin + @test_nowarn show(context) + println() + + @test_nowarn show(pv1) + println() + + @test_nowarn show(sv1) + println() + + @test_nowarn show(public_key) + println() + + @test_nowarn show(private_key) + println() + end end end