From 2e5308c675d5d4c007de0a40e100c4e7977d7ccb Mon Sep 17 00:00:00 2001 From: Michael Schlottke-Lakemper Date: Sat, 20 Jan 2024 08:22:13 +0100 Subject: [PATCH] Support scalar operations for addition and subtraction --- src/arithmetic.jl | 4 ++++ src/openfhe.jl | 29 +++++++++++++++++++++++++++++ src/unencrypted.jl | 17 +++++++++++++++++ test/test_unit.jl | 4 ++++ 4 files changed, 54 insertions(+) diff --git a/src/arithmetic.jl b/src/arithmetic.jl index 10d757e..2bb78a5 100644 --- a/src/arithmetic.jl +++ b/src/arithmetic.jl @@ -2,11 +2,15 @@ Base.:+(sv1::SecureVector, sv2::SecureVector) = add(sv1, sv2) Base.:+(sv::SecureVector, pv::PlainVector) = add(sv, pv) Base.:+(pv::PlainVector, sv::SecureVector) = add(sv, pv) +Base.:+(sv::SecureVector, scalar::Real) = add(sv, scalar) +Base.:+(scalar::Real, sv::SecureVector) = add(sv, scalar) # Subtract Base.:-(sv1::SecureVector, sv2::SecureVector) = subtract(sv1, sv2) Base.:-(sv::SecureVector, pv::PlainVector) = subtract(sv, pv) Base.:-(pv::PlainVector, sv::SecureVector) = subtract(pv, sv) +Base.:-(sv::SecureVector, scalar::Real) = subtract(sv, scalar) +Base.:-(scalar::Real, sv::SecureVector) = subtract(scalar, sv) # Negate Base.:-(sv::SecureVector) = negate(sv) diff --git a/src/openfhe.jl b/src/openfhe.jl index a328033..690b00a 100644 --- a/src/openfhe.jl +++ b/src/openfhe.jl @@ -93,6 +93,11 @@ function bootstrap!(secure_vector::SecureVector{<:OpenFHEBackend}) secure_vector end + +############################################################################################ +# Arithmetic operations +############################################################################################ + function add(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv1) ciphertext = OpenFHE.EvalAdd(cc, sv1.ciphertext, sv2.ciphertext) @@ -109,6 +114,14 @@ function add(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBacken secure_vector end +function add(sv::SecureVector{<:OpenFHEBackend}, scalar::Real) + cc = get_crypto_context(sv) + ciphertext = OpenFHE.EvalAdd(cc, sv.ciphertext, scalar) + secure_vector = SecureVector(ciphertext, sv.context) + + secure_vector +end + function subtract(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv1) ciphertext = OpenFHE.EvalSub(cc, sv1.ciphertext, sv2.ciphertext) @@ -133,6 +146,22 @@ function subtract(pv::PlainVector{<:OpenFHEBackend}, sv::SecureVector{<:OpenFHEB secure_vector end +function subtract(sv::SecureVector{<:OpenFHEBackend}, scalar::Real) + cc = get_crypto_context(sv) + ciphertext = OpenFHE.EvalSub(cc, sv.ciphertext, scalar) + secure_vector = SecureVector(ciphertext, sv.context) + + secure_vector +end + +function subtract(scalar::Real, sv::SecureVector{<:OpenFHEBackend}) + cc = get_crypto_context(sv) + ciphertext = OpenFHE.EvalSub(cc, scalar, sv.ciphertext) + secure_vector = SecureVector(ciphertext, sv.context) + + secure_vector +end + function negate(sv::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv) ciphertext = OpenFHE.EvalNegate(cc, sv.ciphertext) diff --git a/src/unencrypted.jl b/src/unencrypted.jl index 9b39d78..7b15470 100644 --- a/src/unencrypted.jl +++ b/src/unencrypted.jl @@ -35,6 +35,11 @@ end bootstrap!(secure_vector::SecureVector{<:Unencrypted}) = secure_vector + +############################################################################################ +# Arithmetic operations +############################################################################################ + function add(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted}) SecureVector(sv1.ciphertext .+ sv2.ciphertext, sv1.context) end @@ -43,6 +48,10 @@ function add(sv::SecureVector{<:Unencrypted}, pv::PlainVector{<:Unencrypted}) SecureVector(sv.ciphertext .+ pv.plaintext, sv.context) end +function add(sv::SecureVector{<:Unencrypted}, scalar::Real) + SecureVector(sv.ciphertext .+ scalar, sv.context) +end + function subtract(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted}) SecureVector(sv1.ciphertext .- sv2.ciphertext, sv1.context) end @@ -55,6 +64,14 @@ function subtract(pv::PlainVector{<:Unencrypted}, sv::SecureVector{<:Unencrypted SecureVector(pv.plaintext .- sv.ciphertext, sv.context) end +function subtract(sv::SecureVector{<:Unencrypted}, scalar::Real) + SecureVector(sv.ciphertext .- scalar, sv.context) +end + +function subtract(scalar::Real, sv::SecureVector{<:Unencrypted}) + SecureVector(scalar .- sv.ciphertext, sv.context) +end + function negate(sv::SecureVector{<:Unencrypted}) SecureVector(-sv.ciphertext, sv.context) end diff --git a/test/test_unit.jl b/test/test_unit.jl index d843cf0..b4be383 100644 --- a/test/test_unit.jl +++ b/test/test_unit.jl @@ -64,12 +64,16 @@ for backend in ((; name = "OpenFHE", BackendT = OpenFHEBackend, context = contex @test sv1 + sv2 isa SecureVector @test sv1 + pv1 isa SecureVector @test pv1 + sv1 isa SecureVector + @test sv1 + 3 isa SecureVector + @test 4 + sv1 isa SecureVector end @testset verbose=true showtiming=true "subtract" begin @test sv1 - sv2 isa SecureVector @test sv1 - pv1 isa SecureVector @test pv1 - sv1 isa SecureVector + @test sv1 - 3 isa SecureVector + @test 4 - sv1 isa SecureVector end @testset verbose=true showtiming=true "multiply" begin