Skip to content

Commit

Permalink
update python bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
piotr-roslaniec committed Apr 26, 2023
1 parent dd1eccf commit a77fc7a
Show file tree
Hide file tree
Showing 16 changed files with 367 additions and 246 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions ferveo-common/src/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub mod ser {
S: serde::Serializer,
{
let mut bytes = vec![];
val.serialize_uncompressed(&mut bytes)
val.serialize_compressed(&mut bytes)
.map_err(serde::ser::Error::custom)?;

Bytes::serialize_as(&bytes, serializer)
Expand All @@ -43,7 +43,7 @@ pub mod ser {
D: serde::Deserializer<'de>,
{
let bytes: Vec<u8> = Bytes::deserialize_as(deserializer)?;
T::deserialize_uncompressed(&mut &bytes[..])
T::deserialize_compressed(&mut &bytes[..])
.map_err(serde::de::Error::custom)
}
}
Expand All @@ -67,7 +67,7 @@ where
S: serde::Serializer,
{
let mut bytes = vec![];
val.serialize_uncompressed(&mut bytes)
val.serialize_compressed(&mut bytes)
.map_err(serde::ser::Error::custom)?;

Bytes::serialize_as(&bytes, serializer)
Expand All @@ -83,7 +83,7 @@ where
D: serde::Deserializer<'de>,
{
let bytes: Vec<u8> = Bytes::deserialize_as(deserializer)?;
T::deserialize_uncompressed(&mut &bytes[..])
T::deserialize_compressed(&mut &bytes[..])
.map_err(serde::de::Error::custom)
}
}
Expand Down
11 changes: 10 additions & 1 deletion ferveo-python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,18 @@ name = "ferveo_py"
[dependencies]
ferveo = { path = "../ferveo" }
ferveo-common = { path = "../ferveo-common" }
pyo3 = { version = "0.18.2", features = ["macros", "extension-module"] }
pyo3 = { version = "0.18.2", features = ["macros"] }
derive_more = { version = "0.99", default-features = false, features = ["from", "as_ref"] }
rand = "0.8"
itertools = "0.10.5"

# We avoid declaring "pyo3/extension-module" in `dependencies` since it causes compile-time issues:
# https://github.com/PyO3/pyo3/issues/340
# Instead, we expose it in certain cases:
# https://github.com/PyO3/maturin/issues/325
# TODO: Verify whether this actually works
#[tool.maturin]
#features = ["pyo3/extension-module"]

[build-dependencies]
pyo3-build-config = "*"
47 changes: 15 additions & 32 deletions ferveo-python/examples/server_api_precomputed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
decrypt_with_shared_secret,
Keypair,
PublicKey,
ExternalValidator,
Validator,
Transcript,
Dkg,
Ciphertext,
UnblindingKey,
DecryptionSharePrecomputed,
AggregatedTranscript,
DkgPublicKey,
Expand All @@ -25,9 +24,10 @@ def gen_eth_addr(i: int) -> str:
shares_num = 4
# In precomputed variant, security threshold must be equal to shares_num
security_threshold = shares_num

validator_keypairs = [Keypair.random() for _ in range(0, shares_num)]
validators = [
ExternalValidator(gen_eth_addr(i), keypair.public_key())
Validator(gen_eth_addr(i), keypair.public_key())
for i, keypair in enumerate(validator_keypairs)
]

Expand All @@ -47,11 +47,6 @@ def gen_eth_addr(i: int) -> str:
)
messages.append((sender, dkg.generate_transcript()))


# Let's say that we've only received `security_threshold` transcripts
messages = messages[:security_threshold]
transcripts = [transcript for _, transcript in messages]

# Every validator can aggregate the transcripts
dkg = Dkg(
tau=tau,
Expand All @@ -61,27 +56,17 @@ def gen_eth_addr(i: int) -> str:
me=validators[0],
)

server_aggregate = dkg.aggregate_transcripts(transcripts)
assert server_aggregate.verify(shares_num, transcripts)

# Clients can also create aggregates and verify them
client_aggregate = AggregatedTranscript.from_transcripts(transcripts)
assert client_aggregate.verify(shares_num, transcripts)
# Let's say that we've only received `security_threshold` transcripts
messages = messages[:security_threshold]

# We can persist transcripts and the aggregated transcript
transcripts_ser = [bytes(transcript) for _, transcript in messages]
_transcripts_deser = [Transcript.from_bytes(t) for t in transcripts_ser]
agg_transcript_ser = bytes(server_aggregate)
_agg_transcript_deser = AggregatedTranscript.from_bytes(agg_transcript_ser)
server_aggregate = dkg.aggregate_transcripts(messages)
assert server_aggregate.verify(shares_num, messages)

# In the meantime, the client creates a ciphertext and decryption request
msg = "abc".encode()
aad = "my-aad".encode()
ciphertext = encrypt(msg, aad, dkg.final_key)

# The client can serialize/deserialize ciphertext for transport
ciphertext_ser = bytes(ciphertext)

# Having aggregated the transcripts, the validators can now create decryption shares
decryption_shares = []
for validator, validator_keypair in zip(validators, validator_keypairs):
Expand All @@ -92,29 +77,27 @@ def gen_eth_addr(i: int) -> str:
validators=validators,
me=validator,
)
# Assume the aggregated transcript is obtained through deserialization from a side-channel
agg_transcript_deser = AggregatedTranscript.from_bytes(agg_transcript_ser)
agg_transcript_deser.verify(shares_num, transcripts)

# We can also obtain the aggregated transcript from the side-channel (deserialize)
aggregate = AggregatedTranscript(messages)
assert aggregate.verify(shares_num, messages)

# The ciphertext is obtained from the client
ciphertext_deser = Ciphertext.from_bytes(ciphertext_ser)

# Create a decryption share for the ciphertext
decryption_share = agg_transcript_deser.create_decryption_share_precomputed(
decryption_share = aggregate.create_decryption_share_precomputed(
dkg, ciphertext, aad, validator_keypair
)
decryption_shares.append(decryption_share)

# Now, the decryption share can be used to decrypt the ciphertext
# This part is in the client API

# The client should have access to the public parameters of the DKG
dkg_public_params_ser = bytes(dkg.public_params)
dkg_public_params_deser = DkgPublicParameters.from_bytes(dkg_public_params_ser)

shared_secret = combine_decryption_shares_precomputed(decryption_shares)

plaintext = decrypt_with_shared_secret(ciphertext, aad, shared_secret, dkg_public_params_deser)
# The client should have access to the public parameters of the DKG

plaintext = decrypt_with_shared_secret(ciphertext, aad, shared_secret, dkg.public_params)
assert bytes(plaintext) == msg

print("Success!")
32 changes: 13 additions & 19 deletions ferveo-python/examples/server_api_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
decrypt_with_shared_secret,
Keypair,
PublicKey,
ExternalValidator,
Validator,
Transcript,
Dkg,
Ciphertext,
UnblindingKey,
DecryptionShareSimple,
AggregatedTranscript,
DkgPublicKey,
Expand All @@ -26,7 +25,7 @@ def gen_eth_addr(i: int) -> str:
shares_num = 4
validator_keypairs = [Keypair.random() for _ in range(0, shares_num)]
validators = [
ExternalValidator(gen_eth_addr(i), keypair.public_key())
Validator(gen_eth_addr(i), keypair.public_key())
for i, keypair in enumerate(validator_keypairs)
]

Expand Down Expand Up @@ -56,16 +55,13 @@ def gen_eth_addr(i: int) -> str:
validators=validators,
me=me,
)

# Let's say that we've only received `security_threshold` transcripts
messages = messages[:security_threshold]

pvss_aggregated = dkg.aggregate_transcripts(messages)
assert pvss_aggregated.verify(shares_num, messages)

# Server can persist transcripts and the aggregated transcript
transcripts_ser = [bytes(transcript) for _, transcript in messages]
_transcripts_deser = [Transcript.from_bytes(t) for t in transcripts_ser]
agg_transcript_ser = bytes(pvss_aggregated)

# In the meantime, the client creates a ciphertext and decryption request
msg = "abc".encode()
aad = "my-aad".encode()
Expand All @@ -84,29 +80,27 @@ def gen_eth_addr(i: int) -> str:
validators=validators,
me=validator,
)
# Assume the aggregated transcript is obtained through deserialization from a side-channel
agg_transcript_deser = AggregatedTranscript.from_bytes(agg_transcript_ser)
agg_transcript_deser.verify(dkg)

# We can also obtain the aggregated transcript from the side-channel (deserialize)
aggregate = AggregatedTranscript(messages)
assert aggregate.verify(shares_num, messages)

# The ciphertext is obtained from the client
ciphertext_deser = Ciphertext.from_bytes(ciphertext_ser)

# Create a decryption share for the ciphertext
decryption_share = agg_transcript_deser.create_decryption_share_simple(
decryption_share = aggregate.create_decryption_share_simple(
dkg, ciphertext, aad, validator_keypair
)
decryption_shares.append(decryption_share)

# Now, the decryption share can be used to decrypt the ciphertext
# This part is in the client API

# The client should have access to the public parameters of the DKG
dkg_public_params_ser = bytes(dkg.public_params)
dkg_public_params_deser = DkgPublicParameters.from_bytes(dkg_public_params_ser)
shared_secret = combine_decryption_shares_simple(decryption_shares, dkg.public_params)

shared_secret = combine_decryption_shares_simple(decryption_shares, dkg_public_params_deser)
# The client should have access to the public parameters of the DKG

plaintext = decrypt_with_shared_secret(ciphertext, aad, shared_secret, dkg_public_params_deser)
plaintext = decrypt_with_shared_secret(ciphertext, aad, shared_secret, dkg.public_params)
assert bytes(plaintext) == msg

print("Success!")
print("Success!")
3 changes: 1 addition & 2 deletions ferveo-python/ferveo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
decrypt_with_shared_secret,
Keypair,
PublicKey,
ExternalValidator,
Validator,
Transcript,
Dkg,
Ciphertext,
UnblindingKey,
DecryptionShareSimple,
DecryptionSharePrecomputed,
AggregatedTranscript,
Expand Down
48 changes: 17 additions & 31 deletions ferveo-python/ferveo/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,19 @@ class Keypair:
def from_secure_randomness(data: bytes) -> Keypair:
...

@staticmethod
def secure_randomness_size(data: bytes) -> int:
...

@staticmethod
def from_bytes(data: bytes) -> Keypair:
...

def __bytes__(self) -> bytes:
...

public_key: PublicKey
def public_key(self) -> PublicKey:
...


class PublicKey:
Expand All @@ -29,7 +34,7 @@ class PublicKey:
...


class ExternalValidator:
class Validator:

def __init__(self, address: str, public_key: PublicKey):
...
Expand Down Expand Up @@ -64,8 +69,8 @@ class Dkg:
tau: int,
shares_num: int,
security_threshold: int,
validators: Sequence[ExternalValidator],
me: ExternalValidator,
validators: Sequence[Validator],
me: Validator,
):
...

Expand All @@ -76,7 +81,7 @@ class Dkg:
def generate_transcript(self) -> Transcript:
...

def aggregate_transcripts(self, transcripts: Sequence[(ExternalValidator, Transcript)]) -> Transcript:
def aggregate_transcripts(self, messages: Sequence[(Validator, Transcript)]) -> AggregatedTranscript:
...


Expand All @@ -89,16 +94,6 @@ class Ciphertext:
...


class UnblindingKey:

@staticmethod
def from_bytes(data: bytes) -> Keypair:
...

def __bytes__(self) -> bytes:
...


class DecryptionShareSimple:
@staticmethod
def from_bytes(data: bytes) -> DecryptionShareSimple:
Expand Down Expand Up @@ -128,6 +123,12 @@ class DkgPublicParameters:

class AggregatedTranscript:

def __init__(self, messages: Sequence[(Validator, Transcript)]):
...

def verify(self, shares_num: int, messages: Sequence[(Validator, Transcript)]) -> bool:
...

def create_decryption_share_simple(
self,
dkg: Dkg,
Expand All @@ -146,30 +147,15 @@ class AggregatedTranscript:
) -> DecryptionSharePrecomputed:
...

def validate(self, dkg: Dkg) -> bool:
...

@staticmethod
def from_transcripts(transcripts: Sequence[Transcript]) -> AggregatedTranscript:
...
@staticmethod
def from_bytes(data: bytes) -> AggregatedTranscript:
...

def __bytes__(self) -> bytes:
...


class LagrangeCoefficient:

@staticmethod
def from_bytes(data: bytes) -> LagrangeCoefficient:
...

def __bytes__(self) -> bytes:
...


class SharedSecret:

@staticmethod
Expand All @@ -186,7 +172,7 @@ def encrypt(message: bytes, add: bytes, dkg_public_key: DkgPublicKey) -> Ciphert

def combine_decryption_shares_simple(
decryption_shares: Sequence[DecryptionShareSimple],
lagrange_coefficients: LagrangeCoefficient,
dkg_public_params: DkgPublicParameters,
) -> bytes:
...

Expand Down
Loading

0 comments on commit a77fc7a

Please sign in to comment.