Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ version = "0.1.0"
OpenFHE = "77ce9b8e-ecf5-45d1-bd8a-d31f384f2f95"

[compat]
OpenFHE = "0.1"
OpenFHE = "0.1.5"
julia = "1.8"
7 changes: 7 additions & 0 deletions src/arithmetic.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
# Add
Base.:+(sv1::SecureVector, sv2::SecureVector) = add(sv1, sv2)
Base.:+(sv::SecureVector, pv::PlainVector) = add(sv, pv)
Base.:+(pv::PlainVector, sv::SecureVector) = add(sv, pv)

# Subtract
Base.:-(sv1::SecureVector, sv2::SecureVector) = subtract(sv1, sv2)
Base.:-(sv::SecureVector, pv::PlainVector) = subtract(sv, pv)
Base.:-(pv::PlainVector, sv::SecureVector) = subtract(pv, sv)

# Negate
Base.:-(sv::SecureVector) = negate(sv)

# Multiply
Base.:*(sv1::SecureVector, sv2::SecureVector) = multiply(sv1, sv2)
Base.:*(sv::SecureVector, pv::PlainVector) = multiply(sv, pv)
Base.:*(pv::PlainVector, sv::SecureVector) = multiply(sv, pv)
Base.:*(sv::SecureVector, scalar::Real) = multiply(sv, scalar)
Base.:*(scalar::Real, sv::SecureVector) = multiply(sv, scalar)

# Circular shift
Base.circshift(sv::SecureVector, shift::Integer) = rotate(sv, shift)
32 changes: 32 additions & 0 deletions src/openfhe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ function add(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBac
secure_vector
end

function add(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend})
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalAdd(cc, sv.ciphertext, pv.plaintext)
secure_vector = SecureVector(ciphertext, sv.context)

secure_vector
end

function subtract(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend})
cc = get_crypto_context(sv1)
ciphertext = OpenFHE.EvalSub(cc, sv1.ciphertext, sv2.ciphertext)
Expand All @@ -109,6 +117,30 @@ function subtract(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenF
secure_vector
end

function subtract(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend})
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalSub(cc, sv.ciphertext, pv.plaintext)
secure_vector = SecureVector(ciphertext, sv.context)

secure_vector
end

function subtract(pv::PlainVector{<:OpenFHEBackend}, sv::SecureVector{<:OpenFHEBackend})
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalSub(cc, pv.plaintext, sv.ciphertext)
secure_vector = SecureVector(ciphertext, sv.context)

secure_vector
end

function negate(sv::SecureVector{<:OpenFHEBackend})
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalNegate(cc, sv.ciphertext)
secure_vector = SecureVector(ciphertext, sv.context)

secure_vector
end

function multiply(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend})
cc = get_crypto_context(sv1)
ciphertext = OpenFHE.EvalMult(cc, sv1.ciphertext, sv2.ciphertext)
Expand Down
16 changes: 16 additions & 0 deletions src/unencrypted.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,26 @@ function add(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted})
SecureVector(sv1.ciphertext .+ sv2.ciphertext, sv1.context)
end

function add(sv::SecureVector{<:Unencrypted}, pv::PlainVector{<:Unencrypted})
SecureVector(sv.ciphertext .+ pv.plaintext, sv.context)
end

function subtract(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted})
SecureVector(sv1.ciphertext .- sv2.ciphertext, sv1.context)
end

function subtract(sv::SecureVector{<:Unencrypted}, pv::PlainVector{<:Unencrypted})
SecureVector(sv.ciphertext .- pv.plaintext, sv.context)
end

function subtract(pv::PlainVector{<:Unencrypted}, sv::SecureVector{<:Unencrypted})
SecureVector(pv.plaintext .- sv.ciphertext, sv.context)
end

function negate(sv::SecureVector{<:Unencrypted})
SecureVector(-sv.ciphertext, sv.context)
end

function multiply(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted})
SecureVector(sv1.ciphertext .* sv2.ciphertext, sv1.context)
end
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ using Test

@time @testset verbose=true showtiming=true "SecureArithmetic.jl tests" begin
include("test_examples.jl")
include("test_unit.jl")
end

92 changes: 92 additions & 0 deletions test/test_unit.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
module TestUnit

using Test
using SecureArithmetic
using OpenFHE

@testset verbose=true showtiming=true "test_unit.jl" begin

# Set up OpenFHE backend
multiplicative_depth = 1
scaling_modulus = 50
batch_size = 8

parameters = CCParams{CryptoContextCKKSRNS}()
SetMultiplicativeDepth(parameters, multiplicative_depth)
SetScalingModSize(parameters, scaling_modulus)
SetBatchSize(parameters, batch_size)

cc = GenCryptoContext(parameters)
Enable(cc, PKE)
Enable(cc, KEYSWITCH)
Enable(cc, LEVELEDSHE)
context_openfhe = SecureContext(OpenFHEBackend(cc))

# Set up unencrypted backend
context_unencrypted = SecureContext(Unencrypted())

for backend in ((; name = "OpenFHE", BackendT = OpenFHEBackend, context = context_openfhe),
(; name = "Unencrypted", BackendT = Unencrypted, context = context_unencrypted))
(; name, BackendT, context) = backend

@testset verbose=true showtiming=true "$name" begin
@testset verbose=true showtiming=true "generate_keys" begin
@test_nowarn generate_keys(context)
end
public_key, private_key = generate_keys(context)

@testset verbose=true showtiming=true "init_multiplication!" begin
@test_nowarn init_multiplication!(context, private_key)
end

@testset verbose=true showtiming=true "init_rotation!" begin
@test_nowarn init_rotation!(context, private_key, [1, -2])
end

x1 = [0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0, 5.0]
x2 = [5.0, 4.0, 3.0, 2.0, 1.0, 0.75, 0.5, 0.25]

@testset verbose=true showtiming=true "PlainVector" begin
@test PlainVector(x1, context) isa PlainVector
end

pv1 = PlainVector(x1, context)
pv2 = PlainVector(x2, context)

@testset verbose=true showtiming=true "encrypt" begin
@test encrypt(pv1, public_key) isa SecureVector
end

sv1 = encrypt(pv1, public_key)
sv2 = encrypt(pv2, public_key)

@testset verbose=true showtiming=true "add" begin
@test sv1 + sv2 isa SecureVector
@test sv1 + pv1 isa SecureVector
@test pv1 + sv1 isa SecureVector
end

@testset verbose=true showtiming=true "subtract" begin
@test sv1 - sv2 isa SecureVector
@test sv1 - pv1 isa SecureVector
@test pv1 - sv1 isa SecureVector
end

@testset verbose=true showtiming=true "multiply" begin
@test sv1 * sv2 isa SecureVector
@test sv1 * pv1 isa SecureVector
@test pv1 * sv1 isa SecureVector
@test sv1 * 3 isa SecureVector
@test 4 * sv1 isa SecureVector
end

@testset verbose=true showtiming=true "negate" begin
@test -sv2 isa SecureVector
end
end
end

end # @testset "test_unit.jl"

end # module