From ebd4dcf9e75b84b7d7adb575db544238a0c5388c Mon Sep 17 00:00:00 2001 From: Michael Schlottke-Lakemper Date: Sun, 21 Jan 2024 06:22:45 +0100 Subject: [PATCH 1/5] Change data field name to data for plain and secure vector --- src/openfhe.jl | 34 +++++++++++++++++----------------- src/types.jl | 18 +++++++++--------- src/unencrypted.jl | 32 ++++++++++++++++---------------- 3 files changed, 42 insertions(+), 42 deletions(-) diff --git a/src/openfhe.jl b/src/openfhe.jl index 690b00a..e163f28 100644 --- a/src/openfhe.jl +++ b/src/openfhe.jl @@ -62,7 +62,7 @@ end function encrypt(plain_vector::PlainVector{<:OpenFHEBackend}, public_key) context = plain_vector.context cc = get_crypto_context(context) - ciphertext = OpenFHE.Encrypt(cc, public_key.public_key, plain_vector.plaintext) + ciphertext = OpenFHE.Encrypt(cc, public_key.public_key, plain_vector.data) secure_vector = SecureVector(ciphertext, context) secure_vector @@ -71,8 +71,8 @@ end function decrypt!(plain_vector::PlainVector{<:OpenFHEBackend}, secure_vector::SecureVector{<:OpenFHEBackend}, private_key) cc = get_crypto_context(secure_vector) - OpenFHE.Decrypt(cc, private_key.private_key, secure_vector.ciphertext, - plain_vector.plaintext) + OpenFHE.Decrypt(cc, private_key.private_key, secure_vector.data, + plain_vector.data) plain_vector end @@ -88,7 +88,7 @@ end function bootstrap!(secure_vector::SecureVector{<:OpenFHEBackend}) context = secure_vector.context cc = get_crypto_context(context) - OpenFHE.EvalBootstrap(cc, secure_vector.ciphertext) + OpenFHE.EvalBootstrap(cc, secure_vector.data) secure_vector end @@ -100,7 +100,7 @@ end function add(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv1) - ciphertext = OpenFHE.EvalAdd(cc, sv1.ciphertext, sv2.ciphertext) + ciphertext = OpenFHE.EvalAdd(cc, sv1.data, sv2.data) secure_vector = SecureVector(ciphertext, sv1.context) secure_vector @@ -108,7 +108,7 @@ end function add(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalAdd(cc, sv.ciphertext, pv.plaintext) + ciphertext = OpenFHE.EvalAdd(cc, sv.data, pv.data) secure_vector = SecureVector(ciphertext, sv.context) secure_vector @@ -116,7 +116,7 @@ end function add(sv::SecureVector{<:OpenFHEBackend}, scalar::Real) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalAdd(cc, sv.ciphertext, scalar) + ciphertext = OpenFHE.EvalAdd(cc, sv.data, scalar) secure_vector = SecureVector(ciphertext, sv.context) secure_vector @@ -124,7 +124,7 @@ end function subtract(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv1) - ciphertext = OpenFHE.EvalSub(cc, sv1.ciphertext, sv2.ciphertext) + ciphertext = OpenFHE.EvalSub(cc, sv1.data, sv2.data) secure_vector = SecureVector(ciphertext, sv1.context) secure_vector @@ -132,7 +132,7 @@ end function subtract(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalSub(cc, sv.ciphertext, pv.plaintext) + ciphertext = OpenFHE.EvalSub(cc, sv.data, pv.data) secure_vector = SecureVector(ciphertext, sv.context) secure_vector @@ -140,7 +140,7 @@ end function subtract(pv::PlainVector{<:OpenFHEBackend}, sv::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalSub(cc, pv.plaintext, sv.ciphertext) + ciphertext = OpenFHE.EvalSub(cc, pv.data, sv.data) secure_vector = SecureVector(ciphertext, sv.context) secure_vector @@ -148,7 +148,7 @@ end function subtract(sv::SecureVector{<:OpenFHEBackend}, scalar::Real) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalSub(cc, sv.ciphertext, scalar) + ciphertext = OpenFHE.EvalSub(cc, sv.data, scalar) secure_vector = SecureVector(ciphertext, sv.context) secure_vector @@ -156,7 +156,7 @@ end function subtract(scalar::Real, sv::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalSub(cc, scalar, sv.ciphertext) + ciphertext = OpenFHE.EvalSub(cc, scalar, sv.data) secure_vector = SecureVector(ciphertext, sv.context) secure_vector @@ -164,7 +164,7 @@ end function negate(sv::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalNegate(cc, sv.ciphertext) + ciphertext = OpenFHE.EvalNegate(cc, sv.data) secure_vector = SecureVector(ciphertext, sv.context) secure_vector @@ -172,7 +172,7 @@ end function multiply(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv1) - ciphertext = OpenFHE.EvalMult(cc, sv1.ciphertext, sv2.ciphertext) + ciphertext = OpenFHE.EvalMult(cc, sv1.data, sv2.data) secure_vector = SecureVector(ciphertext, sv1.context) secure_vector @@ -180,7 +180,7 @@ end function multiply(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalMult(cc, sv.ciphertext, pv.plaintext) + ciphertext = OpenFHE.EvalMult(cc, sv.data, pv.data) secure_vector = SecureVector(ciphertext, sv.context) secure_vector @@ -188,7 +188,7 @@ end function multiply(sv::SecureVector{<:OpenFHEBackend}, scalar::Real) cc = get_crypto_context(sv) - ciphertext = OpenFHE.EvalMult(cc, sv.ciphertext, scalar) + ciphertext = OpenFHE.EvalMult(cc, sv.data, scalar) secure_vector = SecureVector(ciphertext, sv.context) secure_vector @@ -197,7 +197,7 @@ end function rotate(sv::SecureVector{<:OpenFHEBackend}, shift) cc = get_crypto_context(sv) # We use `-shift` to match Julia's usual `circshift` direction - ciphertext = OpenFHE.EvalRotate(cc, sv.ciphertext, -shift) + ciphertext = OpenFHE.EvalRotate(cc, sv.data, -shift) secure_vector = SecureVector(ciphertext, sv.context) secure_vector diff --git a/src/types.jl b/src/types.jl index 01e536a..65e633b 100644 --- a/src/types.jl +++ b/src/types.jl @@ -4,25 +4,25 @@ struct SecureContext{CryptoBackendT <: AbstractCryptoBackend} backend::CryptoBackendT end -struct SecureVector{CryptoBackendT <: AbstractCryptoBackend, CiphertextT} - ciphertext::CiphertextT +struct SecureVector{CryptoBackendT <: AbstractCryptoBackend, DataT} + data::DataT context::SecureContext{CryptoBackendT} - function SecureVector(ciphertext, context::SecureContext{CryptoBackendT}) where CryptoBackendT - new{CryptoBackendT, typeof(ciphertext)}(ciphertext, context) + function SecureVector(data, context::SecureContext{CryptoBackendT}) where CryptoBackendT + new{CryptoBackendT, typeof(data)}(data, context) end end -struct PlainVector{CryptoBackendT <: AbstractCryptoBackend, PlaintextT} - plaintext::PlaintextT +struct PlainVector{CryptoBackendT <: AbstractCryptoBackend, DataT} + data::DataT context::SecureContext{CryptoBackendT} - function PlainVector(plaintext, context::SecureContext{CryptoBackendT}) where CryptoBackendT - new{CryptoBackendT, typeof(plaintext)}(plaintext, context) + function PlainVector(data, context::SecureContext{CryptoBackendT}) where CryptoBackendT + new{CryptoBackendT, typeof(data)}(data, context) end end -Base.print(io::IO, plain_vector::PlainVector) = print(io, plain_vector.plaintext) +Base.print(io::IO, plain_vector::PlainVector) = print(io, plain_vector.data) struct PrivateKey{CryptoBackendT <: AbstractCryptoBackend, KeyT} private_key::KeyT diff --git a/src/unencrypted.jl b/src/unencrypted.jl index 7b15470..6493988 100644 --- a/src/unencrypted.jl +++ b/src/unencrypted.jl @@ -17,18 +17,18 @@ function encrypt(data::Vector{<:Real}, public_key, context::SecureContext{<:Unen end function encrypt(plain_vector::PlainVector{<:Unencrypted}, public_key) - SecureVector(plain_vector.plaintext, plain_vector.context) + SecureVector(plain_vector.data, plain_vector.context) end function decrypt!(plain_vector::PlainVector{<:Unencrypted}, secure_vector::SecureVector{<:Unencrypted}, private_key) - plain_vector.plaintext .= secure_vector.ciphertext + plain_vector.data .= secure_vector.data plain_vector end function decrypt(secure_vector::SecureVector{<:Unencrypted}, private_key) - plain_vector = PlainVector(similar(secure_vector.ciphertext), secure_vector.context) + plain_vector = PlainVector(similar(secure_vector.data), secure_vector.context) decrypt!(plain_vector, secure_vector, private_key) end @@ -41,53 +41,53 @@ bootstrap!(secure_vector::SecureVector{<:Unencrypted}) = secure_vector ############################################################################################ function add(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted}) - SecureVector(sv1.ciphertext .+ sv2.ciphertext, sv1.context) + SecureVector(sv1.data .+ sv2.data, sv1.context) end function add(sv::SecureVector{<:Unencrypted}, pv::PlainVector{<:Unencrypted}) - SecureVector(sv.ciphertext .+ pv.plaintext, sv.context) + SecureVector(sv.data .+ pv.data, sv.context) end function add(sv::SecureVector{<:Unencrypted}, scalar::Real) - SecureVector(sv.ciphertext .+ scalar, sv.context) + SecureVector(sv.data .+ scalar, sv.context) end function subtract(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted}) - SecureVector(sv1.ciphertext .- sv2.ciphertext, sv1.context) + SecureVector(sv1.data .- sv2.data, sv1.context) end function subtract(sv::SecureVector{<:Unencrypted}, pv::PlainVector{<:Unencrypted}) - SecureVector(sv.ciphertext .- pv.plaintext, sv.context) + SecureVector(sv.data .- pv.data, sv.context) end function subtract(pv::PlainVector{<:Unencrypted}, sv::SecureVector{<:Unencrypted}) - SecureVector(pv.plaintext .- sv.ciphertext, sv.context) + SecureVector(pv.data .- sv.data, sv.context) end function subtract(sv::SecureVector{<:Unencrypted}, scalar::Real) - SecureVector(sv.ciphertext .- scalar, sv.context) + SecureVector(sv.data .- scalar, sv.context) end function subtract(scalar::Real, sv::SecureVector{<:Unencrypted}) - SecureVector(scalar .- sv.ciphertext, sv.context) + SecureVector(scalar .- sv.data, sv.context) end function negate(sv::SecureVector{<:Unencrypted}) - SecureVector(-sv.ciphertext, sv.context) + SecureVector(-sv.data, sv.context) end function multiply(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted}) - SecureVector(sv1.ciphertext .* sv2.ciphertext, sv1.context) + SecureVector(sv1.data .* sv2.data, sv1.context) end function multiply(sv::SecureVector{<:Unencrypted}, pv::PlainVector{<:Unencrypted}) - SecureVector(sv.ciphertext .* pv.plaintext, sv.context) + SecureVector(sv.data .* pv.data, sv.context) end function multiply(sv::SecureVector{<:Unencrypted}, scalar::Real) - SecureVector(sv.ciphertext .* scalar, sv.context) + SecureVector(sv.data .* scalar, sv.context) end function rotate(sv::SecureVector{<:Unencrypted}, shift) - SecureVector(circshift(sv.ciphertext, shift), sv.context) + SecureVector(circshift(sv.data, shift), sv.context) end From 34d717f7d56ad8656897a04f0c36277e9f66c5e5 Mon Sep 17 00:00:00 2001 From: Michael Schlottke-Lakemper Date: Sun, 21 Jan 2024 06:40:52 +0100 Subject: [PATCH 2/5] Run unit tests first --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 7c0c6ef..633d1d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,7 @@ using Test @time @testset verbose=true showtiming=true "SecureArithmetic.jl tests" begin - include("test_examples.jl") include("test_unit.jl") + include("test_examples.jl") end From adfcc22ff59701b44e6c741b78f2ebbc197d8b7e Mon Sep 17 00:00:00 2001 From: Michael Schlottke-Lakemper Date: Sun, 21 Jan 2024 07:09:58 +0100 Subject: [PATCH 3/5] Introduce `length` for Plain/SecureVector --- src/openfhe.jl | 36 ++++++++++++++++++------------------ src/types.jl | 14 ++++++++++---- src/unencrypted.jl | 36 +++++++++++++++++++----------------- 3 files changed, 47 insertions(+), 39 deletions(-) diff --git a/src/openfhe.jl b/src/openfhe.jl index e163f28..3d1c6e3 100644 --- a/src/openfhe.jl +++ b/src/openfhe.jl @@ -47,14 +47,14 @@ end function PlainVector(data::Vector{<:Real}, context::SecureContext{<:OpenFHEBackend}) cc = get_crypto_context(context) plaintext = OpenFHE.MakeCKKSPackedPlaintext(cc, data) - plain_vector = PlainVector(plaintext, context) + plain_vector = PlainVector(plaintext, length(data), context) plain_vector end function encrypt(data::Vector{<:Real}, public_key, context::SecureContext{<:OpenFHEBackend}) - plain_vector = PlainVector(context, data) - secure_vector = encrypt(context, public_key, plain_vector) + plain_vector = PlainVector(data, length(data), context) + secure_vector = encrypt(plain_vector, public_key) secure_vector end @@ -63,7 +63,7 @@ function encrypt(plain_vector::PlainVector{<:OpenFHEBackend}, public_key) context = plain_vector.context cc = get_crypto_context(context) ciphertext = OpenFHE.Encrypt(cc, public_key.public_key, plain_vector.data) - secure_vector = SecureVector(ciphertext, context) + secure_vector = SecureVector(ciphertext, length(plain_vector), context) secure_vector end @@ -79,7 +79,7 @@ end function decrypt(secure_vector::SecureVector{<:OpenFHEBackend}, private_key) context = secure_vector.context - plain_vector = PlainVector(OpenFHE.Plaintext(), context) + plain_vector = PlainVector(OpenFHE.Plaintext(), length(secure_vector), context) decrypt!(plain_vector, secure_vector, private_key) end @@ -101,7 +101,7 @@ end function add(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv1) ciphertext = OpenFHE.EvalAdd(cc, sv1.data, sv2.data) - secure_vector = SecureVector(ciphertext, sv1.context) + secure_vector = SecureVector(ciphertext, length(sv1), sv1.context) secure_vector end @@ -109,7 +109,7 @@ end function add(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) ciphertext = OpenFHE.EvalAdd(cc, sv.data, pv.data) - secure_vector = SecureVector(ciphertext, sv.context) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end @@ -117,7 +117,7 @@ end function add(sv::SecureVector{<:OpenFHEBackend}, scalar::Real) cc = get_crypto_context(sv) ciphertext = OpenFHE.EvalAdd(cc, sv.data, scalar) - secure_vector = SecureVector(ciphertext, sv.context) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end @@ -125,7 +125,7 @@ end function subtract(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv1) ciphertext = OpenFHE.EvalSub(cc, sv1.data, sv2.data) - secure_vector = SecureVector(ciphertext, sv1.context) + secure_vector = SecureVector(ciphertext, length(sv1), sv1.context) secure_vector end @@ -133,7 +133,7 @@ end function subtract(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) ciphertext = OpenFHE.EvalSub(cc, sv.data, pv.data) - secure_vector = SecureVector(ciphertext, sv.context) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end @@ -141,7 +141,7 @@ end function subtract(pv::PlainVector{<:OpenFHEBackend}, sv::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) ciphertext = OpenFHE.EvalSub(cc, pv.data, sv.data) - secure_vector = SecureVector(ciphertext, sv.context) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end @@ -149,7 +149,7 @@ end function subtract(sv::SecureVector{<:OpenFHEBackend}, scalar::Real) cc = get_crypto_context(sv) ciphertext = OpenFHE.EvalSub(cc, sv.data, scalar) - secure_vector = SecureVector(ciphertext, sv.context) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end @@ -157,7 +157,7 @@ end function subtract(scalar::Real, sv::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) ciphertext = OpenFHE.EvalSub(cc, scalar, sv.data) - secure_vector = SecureVector(ciphertext, sv.context) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end @@ -165,7 +165,7 @@ end function negate(sv::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) ciphertext = OpenFHE.EvalNegate(cc, sv.data) - secure_vector = SecureVector(ciphertext, sv.context) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end @@ -173,7 +173,7 @@ end function multiply(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv1) ciphertext = OpenFHE.EvalMult(cc, sv1.data, sv2.data) - secure_vector = SecureVector(ciphertext, sv1.context) + secure_vector = SecureVector(ciphertext, length(sv1), sv1.context) secure_vector end @@ -181,7 +181,7 @@ end function multiply(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) ciphertext = OpenFHE.EvalMult(cc, sv.data, pv.data) - secure_vector = SecureVector(ciphertext, sv.context) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end @@ -189,7 +189,7 @@ end function multiply(sv::SecureVector{<:OpenFHEBackend}, scalar::Real) cc = get_crypto_context(sv) ciphertext = OpenFHE.EvalMult(cc, sv.data, scalar) - secure_vector = SecureVector(ciphertext, sv.context) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end @@ -198,7 +198,7 @@ function rotate(sv::SecureVector{<:OpenFHEBackend}, shift) cc = get_crypto_context(sv) # We use `-shift` to match Julia's usual `circshift` direction ciphertext = OpenFHE.EvalRotate(cc, sv.data, -shift) - secure_vector = SecureVector(ciphertext, sv.context) + secure_vector = SecureVector(ciphertext, length(sv), sv.context) secure_vector end diff --git a/src/types.jl b/src/types.jl index 65e633b..a6192dc 100644 --- a/src/types.jl +++ b/src/types.jl @@ -6,22 +6,28 @@ end struct SecureVector{CryptoBackendT <: AbstractCryptoBackend, DataT} data::DataT + length::Int context::SecureContext{CryptoBackendT} - function SecureVector(data, context::SecureContext{CryptoBackendT}) where CryptoBackendT - new{CryptoBackendT, typeof(data)}(data, context) + function SecureVector(data, length, context::SecureContext{CryptoBackendT}) where CryptoBackendT + new{CryptoBackendT, typeof(data)}(data, length, context) end end +Base.length(v::SecureVector) = v.length + struct PlainVector{CryptoBackendT <: AbstractCryptoBackend, DataT} data::DataT + length::Int context::SecureContext{CryptoBackendT} - function PlainVector(data, context::SecureContext{CryptoBackendT}) where CryptoBackendT - new{CryptoBackendT, typeof(data)}(data, context) + function PlainVector(data, length, context::SecureContext{CryptoBackendT}) where CryptoBackendT + new{CryptoBackendT, typeof(data)}(data, length, context) end end +Base.length(v::PlainVector) = v.length + Base.print(io::IO, plain_vector::PlainVector) = print(io, plain_vector.data) struct PrivateKey{CryptoBackendT <: AbstractCryptoBackend, KeyT} diff --git a/src/unencrypted.jl b/src/unencrypted.jl index 6493988..001ad68 100644 --- a/src/unencrypted.jl +++ b/src/unencrypted.jl @@ -10,14 +10,16 @@ init_multiplication!(context::SecureContext{<:Unencrypted}, private_key) = nothi init_rotation!(context::SecureContext{<:Unencrypted}, private_key, shifts) = nothing init_bootstrapping!(context::SecureContext{<:Unencrypted}, private_key) = nothing -# No constructor for `PlainVector` necessary since we can directly use the inner constructor +function PlainVector(data::Vector{<:Real}, context::SecureContext{<:Unencrypted}) + PlainVector(data, length(data), context) +end function encrypt(data::Vector{<:Real}, public_key, context::SecureContext{<:Unencrypted}) - SecureVector(data, context) + SecureVector(data, length(data), context) end function encrypt(plain_vector::PlainVector{<:Unencrypted}, public_key) - SecureVector(plain_vector.data, plain_vector.context) + SecureVector(plain_vector.data, length(plain_vector), plain_vector.context) end function decrypt!(plain_vector::PlainVector{<:Unencrypted}, @@ -28,7 +30,7 @@ function decrypt!(plain_vector::PlainVector{<:Unencrypted}, end function decrypt(secure_vector::SecureVector{<:Unencrypted}, private_key) - plain_vector = PlainVector(similar(secure_vector.data), secure_vector.context) + plain_vector = PlainVector(similar(secure_vector.data), length(secure_vector), secure_vector.context) decrypt!(plain_vector, secure_vector, private_key) end @@ -41,53 +43,53 @@ bootstrap!(secure_vector::SecureVector{<:Unencrypted}) = secure_vector ############################################################################################ function add(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted}) - SecureVector(sv1.data .+ sv2.data, sv1.context) + SecureVector(sv1.data .+ sv2.data, length(sv1), sv1.context) end function add(sv::SecureVector{<:Unencrypted}, pv::PlainVector{<:Unencrypted}) - SecureVector(sv.data .+ pv.data, sv.context) + SecureVector(sv.data .+ pv.data, length(sv), sv.context) end function add(sv::SecureVector{<:Unencrypted}, scalar::Real) - SecureVector(sv.data .+ scalar, sv.context) + SecureVector(sv.data .+ scalar, length(sv), sv.context) end function subtract(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted}) - SecureVector(sv1.data .- sv2.data, sv1.context) + SecureVector(sv1.data .- sv2.data, length(sv1), sv1.context) end function subtract(sv::SecureVector{<:Unencrypted}, pv::PlainVector{<:Unencrypted}) - SecureVector(sv.data .- pv.data, sv.context) + SecureVector(sv.data .- pv.data, length(sv), sv.context) end function subtract(pv::PlainVector{<:Unencrypted}, sv::SecureVector{<:Unencrypted}) - SecureVector(pv.data .- sv.data, sv.context) + SecureVector(pv.data .- sv.data, length(sv), sv.context) end function subtract(sv::SecureVector{<:Unencrypted}, scalar::Real) - SecureVector(sv.data .- scalar, sv.context) + SecureVector(sv.data .- scalar, length(sv), sv.context) end function subtract(scalar::Real, sv::SecureVector{<:Unencrypted}) - SecureVector(scalar .- sv.data, sv.context) + SecureVector(scalar .- sv.data, length(sv), sv.context) end function negate(sv::SecureVector{<:Unencrypted}) - SecureVector(-sv.data, sv.context) + SecureVector(-sv.data, length(sv), sv.context) end function multiply(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted}) - SecureVector(sv1.data .* sv2.data, sv1.context) + SecureVector(sv1.data .* sv2.data, length(sv1), sv1.context) end function multiply(sv::SecureVector{<:Unencrypted}, pv::PlainVector{<:Unencrypted}) - SecureVector(sv.data .* pv.data, sv.context) + SecureVector(sv.data .* pv.data, length(sv), sv.context) end function multiply(sv::SecureVector{<:Unencrypted}, scalar::Real) - SecureVector(sv.data .* scalar, sv.context) + SecureVector(sv.data .* scalar, length(sv), sv.context) end function rotate(sv::SecureVector{<:Unencrypted}, shift) - SecureVector(circshift(sv.data, shift), sv.context) + SecureVector(circshift(sv.data, shift), length(sv), sv.context) end From e4bb357c637ecca11eb2b8a448cc2917d75d61c2 Mon Sep 17 00:00:00 2001 From: Michael Schlottke-Lakemper Date: Sun, 21 Jan 2024 11:38:48 +0100 Subject: [PATCH 4/5] Enable pretty-printing of types --- src/SecureArithmetic.jl | 3 +++ src/types.jl | 28 ++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/src/SecureArithmetic.jl b/src/SecureArithmetic.jl index 055f6f1..922fc59 100644 --- a/src/SecureArithmetic.jl +++ b/src/SecureArithmetic.jl @@ -5,6 +5,9 @@ using OpenFHE: OpenFHE # Basic types export SecureContext, SecureVector, PlainVector +# Keys +export PrivateKey, PublicKey + # Backends export Unencrypted, OpenFHEBackend diff --git a/src/types.jl b/src/types.jl index a6192dc..d2f2308 100644 --- a/src/types.jl +++ b/src/types.jl @@ -4,6 +4,10 @@ struct SecureContext{CryptoBackendT <: AbstractCryptoBackend} backend::CryptoBackendT end +function Base.show(io::IO, v::SecureContext) + print("SecureContext{", backend_name(v), "}()") +end + struct SecureVector{CryptoBackendT <: AbstractCryptoBackend, DataT} data::DataT length::Int @@ -15,6 +19,9 @@ struct SecureVector{CryptoBackendT <: AbstractCryptoBackend, DataT} end Base.length(v::SecureVector) = v.length +function Base.show(io::IO, v::SecureVector) + print("SecureVector{", backend_name(v), "}(data=, length=$(v.length))") +end struct PlainVector{CryptoBackendT <: AbstractCryptoBackend, DataT} data::DataT @@ -27,6 +34,9 @@ struct PlainVector{CryptoBackendT <: AbstractCryptoBackend, DataT} end Base.length(v::PlainVector) = v.length +function Base.show(io::IO, v::PlainVector{CryptoBackendT}) where CryptoBackendT + print("PlainVector{", backend_name(v), "}(data=, length=$(v.length))") +end Base.print(io::IO, plain_vector::PlainVector) = print(io, plain_vector.data) @@ -39,6 +49,10 @@ struct PrivateKey{CryptoBackendT <: AbstractCryptoBackend, KeyT} end end +function Base.show(io::IO, key::PrivateKey{CryptoBackendT}) where CryptoBackendT + print("PrivateKey{", backend_name(key), "}()") +end + struct PublicKey{CryptoBackendT <: AbstractCryptoBackend, KeyT} public_key::KeyT context::SecureContext{CryptoBackendT} @@ -47,3 +61,17 @@ struct PublicKey{CryptoBackendT <: AbstractCryptoBackend, KeyT} new{CryptoBackendT, typeof(key)}(key, context) end end + +function Base.show(io::IO, key::PublicKey{CryptoBackendT}) where CryptoBackendT + print("PublicKey{", backend_name(key), "}()") +end + +# Get wrapper name of a potentially parametric type +# Copied from: https://github.com/ClapeyronThermo/Clapeyron.jl/blob/f40c282e2236ff68d91f37c39b5c1e4230ae9ef0/src/utils/core_utils.jl#L17 +# Original source: https://github.com/JuliaArrays/ArrayInterface.jl/blob/40d9a87be07ba323cca00f9e59e5285c13f7ee72/src/ArrayInterface.jl#L20 +# Note: prefixed by `__` since it is really, really dirty black magic internals we use here! +__parameterless_type(T) = Base.typename(T).wrapper + +# Convenience method for getting the human-readable backend name +backend_name(x::Union{SecureContext{T}, SecureVector{T}, PlainVector{T}, PrivateKey{T}, + PublicKey{T}}) where T = string(__parameterless_type(T)) From ee582d6aa08b8f5940ec254c6d0e977873ba329e Mon Sep 17 00:00:00 2001 From: Michael Schlottke-Lakemper Date: Sun, 21 Jan 2024 12:36:56 +0100 Subject: [PATCH 5/5] Improve code coverage --- src/openfhe.jl | 25 +++++++++++++------------ src/unencrypted.jl | 15 ++++++++------- test/test_unit.jl | 18 ++++++++++++++++++ 3 files changed, 39 insertions(+), 19 deletions(-) diff --git a/src/openfhe.jl b/src/openfhe.jl index 3d1c6e3..f5b3c03 100644 --- a/src/openfhe.jl +++ b/src/openfhe.jl @@ -5,11 +5,9 @@ end function get_crypto_context(context::SecureContext{<:OpenFHEBackend}) context.backend.crypto_context end -function get_crypto_context(secure_vector::SecureVector{<:OpenFHEBackend}) - get_crypto_context(secure_vector.context) -end -function get_crypto_context(plain_vector::PlainVector{<:OpenFHEBackend}) - get_crypto_context(plain_vector.context) +function get_crypto_context(v::Union{SecureVector{<:OpenFHEBackend}, + PlainVector{<:OpenFHEBackend}}) + get_crypto_context(v.context) end function generate_keys(context::SecureContext{<:OpenFHEBackend}) @@ -21,21 +19,24 @@ function generate_keys(context::SecureContext{<:OpenFHEBackend}) public_key, private_key end -function init_multiplication!(context::SecureContext{<:OpenFHEBackend}, private_key) +function init_multiplication!(context::SecureContext{<:OpenFHEBackend}, + private_key::PrivateKey) cc = get_crypto_context(context) OpenFHE.EvalMultKeyGen(cc, private_key.private_key) nothing end -function init_rotation!(context::SecureContext{<:OpenFHEBackend}, private_key, shifts) +function init_rotation!(context::SecureContext{<:OpenFHEBackend}, private_key::PrivateKey, + shifts) cc = get_crypto_context(context) OpenFHE.EvalRotateKeyGen(cc, private_key.private_key, shifts) nothing end -function init_bootstrapping!(context::SecureContext{<:OpenFHEBackend}, private_key) +function init_bootstrapping!(context::SecureContext{<:OpenFHEBackend}, + private_key::PrivateKey) cc = get_crypto_context(context) ring_dimension = OpenFHE.GetRingDimension(cc) num_slots = div(ring_dimension, 2) @@ -53,13 +54,13 @@ function PlainVector(data::Vector{<:Real}, context::SecureContext{<:OpenFHEBacke end function encrypt(data::Vector{<:Real}, public_key, context::SecureContext{<:OpenFHEBackend}) - plain_vector = PlainVector(data, length(data), context) + plain_vector = PlainVector(data, context) secure_vector = encrypt(plain_vector, public_key) secure_vector end -function encrypt(plain_vector::PlainVector{<:OpenFHEBackend}, public_key) +function encrypt(plain_vector::PlainVector{<:OpenFHEBackend}, public_key::PublicKey) context = plain_vector.context cc = get_crypto_context(context) ciphertext = OpenFHE.Encrypt(cc, public_key.public_key, plain_vector.data) @@ -69,7 +70,7 @@ function encrypt(plain_vector::PlainVector{<:OpenFHEBackend}, public_key) end function decrypt!(plain_vector::PlainVector{<:OpenFHEBackend}, - secure_vector::SecureVector{<:OpenFHEBackend}, private_key) + secure_vector::SecureVector{<:OpenFHEBackend}, private_key::PrivateKey) cc = get_crypto_context(secure_vector) OpenFHE.Decrypt(cc, private_key.private_key, secure_vector.data, plain_vector.data) @@ -77,7 +78,7 @@ function decrypt!(plain_vector::PlainVector{<:OpenFHEBackend}, plain_vector end -function decrypt(secure_vector::SecureVector{<:OpenFHEBackend}, private_key) +function decrypt(secure_vector::SecureVector{<:OpenFHEBackend}, private_key::PrivateKey) context = secure_vector.context plain_vector = PlainVector(OpenFHE.Plaintext(), length(secure_vector), context) diff --git a/src/unencrypted.jl b/src/unencrypted.jl index 001ad68..6152844 100644 --- a/src/unencrypted.jl +++ b/src/unencrypted.jl @@ -6,30 +6,31 @@ function generate_keys(context::SecureContext{<:Unencrypted}) PublicKey(context, nothing), PrivateKey(context, nothing) end -init_multiplication!(context::SecureContext{<:Unencrypted}, private_key) = nothing -init_rotation!(context::SecureContext{<:Unencrypted}, private_key, shifts) = nothing -init_bootstrapping!(context::SecureContext{<:Unencrypted}, private_key) = nothing +init_multiplication!(context::SecureContext{<:Unencrypted}, private_key::PrivateKey) = nothing +init_rotation!(context::SecureContext{<:Unencrypted}, private_key::PrivateKey, shifts) = nothing +init_bootstrapping!(context::SecureContext{<:Unencrypted}, private_key::PrivateKey) = nothing function PlainVector(data::Vector{<:Real}, context::SecureContext{<:Unencrypted}) PlainVector(data, length(data), context) end -function encrypt(data::Vector{<:Real}, public_key, context::SecureContext{<:Unencrypted}) +function encrypt(data::Vector{<:Real}, public_key::PublicKey, + context::SecureContext{<:Unencrypted}) SecureVector(data, length(data), context) end -function encrypt(plain_vector::PlainVector{<:Unencrypted}, public_key) +function encrypt(plain_vector::PlainVector{<:Unencrypted}, public_key::PublicKey) SecureVector(plain_vector.data, length(plain_vector), plain_vector.context) end function decrypt!(plain_vector::PlainVector{<:Unencrypted}, - secure_vector::SecureVector{<:Unencrypted}, private_key) + secure_vector::SecureVector{<:Unencrypted}, private_key::PrivateKey) plain_vector.data .= secure_vector.data plain_vector end -function decrypt(secure_vector::SecureVector{<:Unencrypted}, private_key) +function decrypt(secure_vector::SecureVector{<:Unencrypted}, private_key::PrivateKey) plain_vector = PlainVector(similar(secure_vector.data), length(secure_vector), secure_vector.context) decrypt!(plain_vector, secure_vector, private_key) diff --git a/test/test_unit.jl b/test/test_unit.jl index b4be383..cd0b0b5 100644 --- a/test/test_unit.jl +++ b/test/test_unit.jl @@ -55,6 +55,7 @@ for backend in ((; name = "OpenFHE", BackendT = OpenFHEBackend, context = contex @testset verbose=true showtiming=true "encrypt" begin @test encrypt(pv1, public_key) isa SecureVector + @test encrypt([1.0, 2.0, 3.0], public_key, context) isa SecureVector end sv1 = encrypt(pv1, public_key) @@ -87,6 +88,23 @@ for backend in ((; name = "OpenFHE", BackendT = OpenFHEBackend, context = contex @testset verbose=true showtiming=true "negate" begin @test -sv2 isa SecureVector end + + @testset verbose=true showtiming=true "show" begin + @test_nowarn show(context) + println() + + @test_nowarn show(pv1) + println() + + @test_nowarn show(sv1) + println() + + @test_nowarn show(public_key) + println() + + @test_nowarn show(private_key) + println() + end end end