diff --git a/Project.toml b/Project.toml index 0e4d6eb..56bc5ff 100644 --- a/Project.toml +++ b/Project.toml @@ -7,5 +7,5 @@ version = "0.1.0" OpenFHE = "77ce9b8e-ecf5-45d1-bd8a-d31f384f2f95" [compat] -OpenFHE = "0.1" +OpenFHE = "0.1.5" julia = "1.8" diff --git a/src/arithmetic.jl b/src/arithmetic.jl index 0662f4c..10d757e 100644 --- a/src/arithmetic.jl +++ b/src/arithmetic.jl @@ -1,15 +1,22 @@ +# Add Base.:+(sv1::SecureVector, sv2::SecureVector) = add(sv1, sv2) Base.:+(sv::SecureVector, pv::PlainVector) = add(sv, pv) Base.:+(pv::PlainVector, sv::SecureVector) = add(sv, pv) +# Subtract Base.:-(sv1::SecureVector, sv2::SecureVector) = subtract(sv1, sv2) Base.:-(sv::SecureVector, pv::PlainVector) = subtract(sv, pv) Base.:-(pv::PlainVector, sv::SecureVector) = subtract(pv, sv) +# Negate +Base.:-(sv::SecureVector) = negate(sv) + +# Multiply Base.:*(sv1::SecureVector, sv2::SecureVector) = multiply(sv1, sv2) Base.:*(sv::SecureVector, pv::PlainVector) = multiply(sv, pv) Base.:*(pv::PlainVector, sv::SecureVector) = multiply(sv, pv) Base.:*(sv::SecureVector, scalar::Real) = multiply(sv, scalar) Base.:*(scalar::Real, sv::SecureVector) = multiply(sv, scalar) +# Circular shift Base.circshift(sv::SecureVector, shift::Integer) = rotate(sv, shift) diff --git a/src/openfhe.jl b/src/openfhe.jl index 3842802..a328033 100644 --- a/src/openfhe.jl +++ b/src/openfhe.jl @@ -101,6 +101,14 @@ function add(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBac secure_vector end +function add(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend}) + cc = get_crypto_context(sv) + ciphertext = OpenFHE.EvalAdd(cc, sv.ciphertext, pv.plaintext) + secure_vector = SecureVector(ciphertext, sv.context) + + secure_vector +end + function subtract(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv1) ciphertext = OpenFHE.EvalSub(cc, sv1.ciphertext, sv2.ciphertext) @@ -109,6 +117,30 @@ function subtract(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenF secure_vector end +function subtract(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend}) + cc = get_crypto_context(sv) + ciphertext = OpenFHE.EvalSub(cc, sv.ciphertext, pv.plaintext) + secure_vector = SecureVector(ciphertext, sv.context) + + secure_vector +end + +function subtract(pv::PlainVector{<:OpenFHEBackend}, sv::SecureVector{<:OpenFHEBackend}) + cc = get_crypto_context(sv) + ciphertext = OpenFHE.EvalSub(cc, pv.plaintext, sv.ciphertext) + secure_vector = SecureVector(ciphertext, sv.context) + + secure_vector +end + +function negate(sv::SecureVector{<:OpenFHEBackend}) + cc = get_crypto_context(sv) + ciphertext = OpenFHE.EvalNegate(cc, sv.ciphertext) + secure_vector = SecureVector(ciphertext, sv.context) + + secure_vector +end + function multiply(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend}) cc = get_crypto_context(sv1) ciphertext = OpenFHE.EvalMult(cc, sv1.ciphertext, sv2.ciphertext) diff --git a/src/unencrypted.jl b/src/unencrypted.jl index ae6cff9..9b39d78 100644 --- a/src/unencrypted.jl +++ b/src/unencrypted.jl @@ -39,10 +39,26 @@ function add(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted}) SecureVector(sv1.ciphertext .+ sv2.ciphertext, sv1.context) end +function add(sv::SecureVector{<:Unencrypted}, pv::PlainVector{<:Unencrypted}) + SecureVector(sv.ciphertext .+ pv.plaintext, sv.context) +end + function subtract(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted}) SecureVector(sv1.ciphertext .- sv2.ciphertext, sv1.context) end +function subtract(sv::SecureVector{<:Unencrypted}, pv::PlainVector{<:Unencrypted}) + SecureVector(sv.ciphertext .- pv.plaintext, sv.context) +end + +function subtract(pv::PlainVector{<:Unencrypted}, sv::SecureVector{<:Unencrypted}) + SecureVector(pv.plaintext .- sv.ciphertext, sv.context) +end + +function negate(sv::SecureVector{<:Unencrypted}) + SecureVector(-sv.ciphertext, sv.context) +end + function multiply(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted}) SecureVector(sv1.ciphertext .* sv2.ciphertext, sv1.context) end diff --git a/test/runtests.jl b/test/runtests.jl index c8217d2..7c0c6ef 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,5 +2,6 @@ using Test @time @testset verbose=true showtiming=true "SecureArithmetic.jl tests" begin include("test_examples.jl") + include("test_unit.jl") end diff --git a/test/test_unit.jl b/test/test_unit.jl new file mode 100644 index 0000000..d843cf0 --- /dev/null +++ b/test/test_unit.jl @@ -0,0 +1,92 @@ +module TestUnit + +using Test +using SecureArithmetic +using OpenFHE + +@testset verbose=true showtiming=true "test_unit.jl" begin + +# Set up OpenFHE backend +multiplicative_depth = 1 +scaling_modulus = 50 +batch_size = 8 + +parameters = CCParams{CryptoContextCKKSRNS}() +SetMultiplicativeDepth(parameters, multiplicative_depth) +SetScalingModSize(parameters, scaling_modulus) +SetBatchSize(parameters, batch_size) + +cc = GenCryptoContext(parameters) +Enable(cc, PKE) +Enable(cc, KEYSWITCH) +Enable(cc, LEVELEDSHE) +context_openfhe = SecureContext(OpenFHEBackend(cc)) + +# Set up unencrypted backend +context_unencrypted = SecureContext(Unencrypted()) + +for backend in ((; name = "OpenFHE", BackendT = OpenFHEBackend, context = context_openfhe), + (; name = "Unencrypted", BackendT = Unencrypted, context = context_unencrypted)) + (; name, BackendT, context) = backend + + @testset verbose=true showtiming=true "$name" begin + @testset verbose=true showtiming=true "generate_keys" begin + @test_nowarn generate_keys(context) + end + public_key, private_key = generate_keys(context) + + @testset verbose=true showtiming=true "init_multiplication!" begin + @test_nowarn init_multiplication!(context, private_key) + end + + @testset verbose=true showtiming=true "init_rotation!" begin + @test_nowarn init_rotation!(context, private_key, [1, -2]) + end + + x1 = [0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0, 5.0] + x2 = [5.0, 4.0, 3.0, 2.0, 1.0, 0.75, 0.5, 0.25] + + @testset verbose=true showtiming=true "PlainVector" begin + @test PlainVector(x1, context) isa PlainVector + end + + pv1 = PlainVector(x1, context) + pv2 = PlainVector(x2, context) + + @testset verbose=true showtiming=true "encrypt" begin + @test encrypt(pv1, public_key) isa SecureVector + end + + sv1 = encrypt(pv1, public_key) + sv2 = encrypt(pv2, public_key) + + @testset verbose=true showtiming=true "add" begin + @test sv1 + sv2 isa SecureVector + @test sv1 + pv1 isa SecureVector + @test pv1 + sv1 isa SecureVector + end + + @testset verbose=true showtiming=true "subtract" begin + @test sv1 - sv2 isa SecureVector + @test sv1 - pv1 isa SecureVector + @test pv1 - sv1 isa SecureVector + end + + @testset verbose=true showtiming=true "multiply" begin + @test sv1 * sv2 isa SecureVector + @test sv1 * pv1 isa SecureVector + @test pv1 * sv1 isa SecureVector + @test sv1 * 3 isa SecureVector + @test 4 * sv1 isa SecureVector + end + + @testset verbose=true showtiming=true "negate" begin + @test -sv2 isa SecureVector + end + end +end + +end # @testset "test_unit.jl" + +end # module +