Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/SecureArithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ using OpenFHE: OpenFHE
# Basic types
export SecureContext, SecureVector, PlainVector

# Keys
export PrivateKey, PublicKey

# Backends
export Unencrypted, OpenFHEBackend

Expand Down
93 changes: 47 additions & 46 deletions src/openfhe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@ end
function get_crypto_context(context::SecureContext{<:OpenFHEBackend})
context.backend.crypto_context
end
function get_crypto_context(secure_vector::SecureVector{<:OpenFHEBackend})
get_crypto_context(secure_vector.context)
end
function get_crypto_context(plain_vector::PlainVector{<:OpenFHEBackend})
get_crypto_context(plain_vector.context)
function get_crypto_context(v::Union{SecureVector{<:OpenFHEBackend},
PlainVector{<:OpenFHEBackend}})
get_crypto_context(v.context)
end

function generate_keys(context::SecureContext{<:OpenFHEBackend})
Expand All @@ -21,21 +19,24 @@ function generate_keys(context::SecureContext{<:OpenFHEBackend})
public_key, private_key
end

function init_multiplication!(context::SecureContext{<:OpenFHEBackend}, private_key)
function init_multiplication!(context::SecureContext{<:OpenFHEBackend},
private_key::PrivateKey)
cc = get_crypto_context(context)
OpenFHE.EvalMultKeyGen(cc, private_key.private_key)

nothing
end

function init_rotation!(context::SecureContext{<:OpenFHEBackend}, private_key, shifts)
function init_rotation!(context::SecureContext{<:OpenFHEBackend}, private_key::PrivateKey,
shifts)
cc = get_crypto_context(context)
OpenFHE.EvalRotateKeyGen(cc, private_key.private_key, shifts)

nothing
end

function init_bootstrapping!(context::SecureContext{<:OpenFHEBackend}, private_key)
function init_bootstrapping!(context::SecureContext{<:OpenFHEBackend},
private_key::PrivateKey)
cc = get_crypto_context(context)
ring_dimension = OpenFHE.GetRingDimension(cc)
num_slots = div(ring_dimension, 2)
Expand All @@ -47,39 +48,39 @@ end
function PlainVector(data::Vector{<:Real}, context::SecureContext{<:OpenFHEBackend})
cc = get_crypto_context(context)
plaintext = OpenFHE.MakeCKKSPackedPlaintext(cc, data)
plain_vector = PlainVector(plaintext, context)
plain_vector = PlainVector(plaintext, length(data), context)

plain_vector
end

function encrypt(data::Vector{<:Real}, public_key, context::SecureContext{<:OpenFHEBackend})
plain_vector = PlainVector(context, data)
secure_vector = encrypt(context, public_key, plain_vector)
plain_vector = PlainVector(data, context)
secure_vector = encrypt(plain_vector, public_key)

secure_vector
end

function encrypt(plain_vector::PlainVector{<:OpenFHEBackend}, public_key)
function encrypt(plain_vector::PlainVector{<:OpenFHEBackend}, public_key::PublicKey)
context = plain_vector.context
cc = get_crypto_context(context)
ciphertext = OpenFHE.Encrypt(cc, public_key.public_key, plain_vector.plaintext)
secure_vector = SecureVector(ciphertext, context)
ciphertext = OpenFHE.Encrypt(cc, public_key.public_key, plain_vector.data)
secure_vector = SecureVector(ciphertext, length(plain_vector), context)

secure_vector
end

function decrypt!(plain_vector::PlainVector{<:OpenFHEBackend},
secure_vector::SecureVector{<:OpenFHEBackend}, private_key)
secure_vector::SecureVector{<:OpenFHEBackend}, private_key::PrivateKey)
cc = get_crypto_context(secure_vector)
OpenFHE.Decrypt(cc, private_key.private_key, secure_vector.ciphertext,
plain_vector.plaintext)
OpenFHE.Decrypt(cc, private_key.private_key, secure_vector.data,
plain_vector.data)

plain_vector
end

function decrypt(secure_vector::SecureVector{<:OpenFHEBackend}, private_key)
function decrypt(secure_vector::SecureVector{<:OpenFHEBackend}, private_key::PrivateKey)
context = secure_vector.context
plain_vector = PlainVector(OpenFHE.Plaintext(), context)
plain_vector = PlainVector(OpenFHE.Plaintext(), length(secure_vector), context)

decrypt!(plain_vector, secure_vector, private_key)
end
Expand All @@ -88,7 +89,7 @@ end
function bootstrap!(secure_vector::SecureVector{<:OpenFHEBackend})
context = secure_vector.context
cc = get_crypto_context(context)
OpenFHE.EvalBootstrap(cc, secure_vector.ciphertext)
OpenFHE.EvalBootstrap(cc, secure_vector.data)

secure_vector
end
Expand All @@ -100,105 +101,105 @@ end

function add(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend})
cc = get_crypto_context(sv1)
ciphertext = OpenFHE.EvalAdd(cc, sv1.ciphertext, sv2.ciphertext)
secure_vector = SecureVector(ciphertext, sv1.context)
ciphertext = OpenFHE.EvalAdd(cc, sv1.data, sv2.data)
secure_vector = SecureVector(ciphertext, length(sv1), sv1.context)

secure_vector
end

function add(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend})
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalAdd(cc, sv.ciphertext, pv.plaintext)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalAdd(cc, sv.data, pv.data)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function add(sv::SecureVector{<:OpenFHEBackend}, scalar::Real)
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalAdd(cc, sv.ciphertext, scalar)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalAdd(cc, sv.data, scalar)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function subtract(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend})
cc = get_crypto_context(sv1)
ciphertext = OpenFHE.EvalSub(cc, sv1.ciphertext, sv2.ciphertext)
secure_vector = SecureVector(ciphertext, sv1.context)
ciphertext = OpenFHE.EvalSub(cc, sv1.data, sv2.data)
secure_vector = SecureVector(ciphertext, length(sv1), sv1.context)

secure_vector
end

function subtract(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend})
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalSub(cc, sv.ciphertext, pv.plaintext)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalSub(cc, sv.data, pv.data)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function subtract(pv::PlainVector{<:OpenFHEBackend}, sv::SecureVector{<:OpenFHEBackend})
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalSub(cc, pv.plaintext, sv.ciphertext)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalSub(cc, pv.data, sv.data)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function subtract(sv::SecureVector{<:OpenFHEBackend}, scalar::Real)
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalSub(cc, sv.ciphertext, scalar)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalSub(cc, sv.data, scalar)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function subtract(scalar::Real, sv::SecureVector{<:OpenFHEBackend})
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalSub(cc, scalar, sv.ciphertext)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalSub(cc, scalar, sv.data)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function negate(sv::SecureVector{<:OpenFHEBackend})
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalNegate(cc, sv.ciphertext)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalNegate(cc, sv.data)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function multiply(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend})
cc = get_crypto_context(sv1)
ciphertext = OpenFHE.EvalMult(cc, sv1.ciphertext, sv2.ciphertext)
secure_vector = SecureVector(ciphertext, sv1.context)
ciphertext = OpenFHE.EvalMult(cc, sv1.data, sv2.data)
secure_vector = SecureVector(ciphertext, length(sv1), sv1.context)

secure_vector
end

function multiply(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend})
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalMult(cc, sv.ciphertext, pv.plaintext)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalMult(cc, sv.data, pv.data)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function multiply(sv::SecureVector{<:OpenFHEBackend}, scalar::Real)
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalMult(cc, sv.ciphertext, scalar)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalMult(cc, sv.data, scalar)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function rotate(sv::SecureVector{<:OpenFHEBackend}, shift)
cc = get_crypto_context(sv)
# We use `-shift` to match Julia's usual `circshift` direction
ciphertext = OpenFHE.EvalRotate(cc, sv.ciphertext, -shift)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalRotate(cc, sv.data, -shift)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end
52 changes: 43 additions & 9 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,41 @@ struct SecureContext{CryptoBackendT <: AbstractCryptoBackend}
backend::CryptoBackendT
end

struct SecureVector{CryptoBackendT <: AbstractCryptoBackend, CiphertextT}
ciphertext::CiphertextT
function Base.show(io::IO, v::SecureContext)
print("SecureContext{", backend_name(v), "}()")
end

struct SecureVector{CryptoBackendT <: AbstractCryptoBackend, DataT}
data::DataT
length::Int
context::SecureContext{CryptoBackendT}

function SecureVector(ciphertext, context::SecureContext{CryptoBackendT}) where CryptoBackendT
new{CryptoBackendT, typeof(ciphertext)}(ciphertext, context)
function SecureVector(data, length, context::SecureContext{CryptoBackendT}) where CryptoBackendT
new{CryptoBackendT, typeof(data)}(data, length, context)
end
end

struct PlainVector{CryptoBackendT <: AbstractCryptoBackend, PlaintextT}
plaintext::PlaintextT
Base.length(v::SecureVector) = v.length
function Base.show(io::IO, v::SecureVector)
print("SecureVector{", backend_name(v), "}(data=<encrypted>, length=$(v.length))")
end

struct PlainVector{CryptoBackendT <: AbstractCryptoBackend, DataT}
data::DataT
length::Int
context::SecureContext{CryptoBackendT}

function PlainVector(plaintext, context::SecureContext{CryptoBackendT}) where CryptoBackendT
new{CryptoBackendT, typeof(plaintext)}(plaintext, context)
function PlainVector(data, length, context::SecureContext{CryptoBackendT}) where CryptoBackendT
new{CryptoBackendT, typeof(data)}(data, length, context)
end
end

Base.print(io::IO, plain_vector::PlainVector) = print(io, plain_vector.plaintext)
Base.length(v::PlainVector) = v.length
function Base.show(io::IO, v::PlainVector{CryptoBackendT}) where CryptoBackendT
print("PlainVector{", backend_name(v), "}(data=<plain>, length=$(v.length))")
end

Base.print(io::IO, plain_vector::PlainVector) = print(io, plain_vector.data)

struct PrivateKey{CryptoBackendT <: AbstractCryptoBackend, KeyT}
private_key::KeyT
Expand All @@ -33,6 +49,10 @@ struct PrivateKey{CryptoBackendT <: AbstractCryptoBackend, KeyT}
end
end

function Base.show(io::IO, key::PrivateKey{CryptoBackendT}) where CryptoBackendT
print("PrivateKey{", backend_name(key), "}()")
end

struct PublicKey{CryptoBackendT <: AbstractCryptoBackend, KeyT}
public_key::KeyT
context::SecureContext{CryptoBackendT}
Expand All @@ -41,3 +61,17 @@ struct PublicKey{CryptoBackendT <: AbstractCryptoBackend, KeyT}
new{CryptoBackendT, typeof(key)}(key, context)
end
end

function Base.show(io::IO, key::PublicKey{CryptoBackendT}) where CryptoBackendT
print("PublicKey{", backend_name(key), "}()")
end

# Get wrapper name of a potentially parametric type
# Copied from: https://github.com/ClapeyronThermo/Clapeyron.jl/blob/f40c282e2236ff68d91f37c39b5c1e4230ae9ef0/src/utils/core_utils.jl#L17
# Original source: https://github.com/JuliaArrays/ArrayInterface.jl/blob/40d9a87be07ba323cca00f9e59e5285c13f7ee72/src/ArrayInterface.jl#L20
# Note: prefixed by `__` since it is really, really dirty black magic internals we use here!
__parameterless_type(T) = Base.typename(T).wrapper

# Convenience method for getting the human-readable backend name
backend_name(x::Union{SecureContext{T}, SecureVector{T}, PlainVector{T}, PrivateKey{T},
PublicKey{T}}) where T = string(__parameterless_type(T))
Loading