diff --git a/Cargo.lock b/Cargo.lock index 344e61b9..6d481097 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -789,6 +789,7 @@ dependencies = [ "derive_more", "ferveo", "ferveo-common", + "itertools", "pyo3", "pyo3-build-config", "rand 0.8.5", diff --git a/ferveo-common/src/serialization.rs b/ferveo-common/src/serialization.rs index 32b5b984..f9507731 100644 --- a/ferveo-common/src/serialization.rs +++ b/ferveo-common/src/serialization.rs @@ -29,7 +29,7 @@ pub mod ser { S: serde::Serializer, { let mut bytes = vec![]; - val.serialize_uncompressed(&mut bytes) + val.serialize_compressed(&mut bytes) .map_err(serde::ser::Error::custom)?; Bytes::serialize_as(&bytes, serializer) @@ -43,7 +43,7 @@ pub mod ser { D: serde::Deserializer<'de>, { let bytes: Vec = Bytes::deserialize_as(deserializer)?; - T::deserialize_uncompressed(&mut &bytes[..]) + T::deserialize_compressed(&mut &bytes[..]) .map_err(serde::de::Error::custom) } } @@ -67,7 +67,7 @@ where S: serde::Serializer, { let mut bytes = vec![]; - val.serialize_uncompressed(&mut bytes) + val.serialize_compressed(&mut bytes) .map_err(serde::ser::Error::custom)?; Bytes::serialize_as(&bytes, serializer) @@ -83,7 +83,7 @@ where D: serde::Deserializer<'de>, { let bytes: Vec = Bytes::deserialize_as(deserializer)?; - T::deserialize_uncompressed(&mut &bytes[..]) + T::deserialize_compressed(&mut &bytes[..]) .map_err(serde::de::Error::custom) } } diff --git a/ferveo-python/Cargo.toml b/ferveo-python/Cargo.toml index 2b9de923..34f176bc 100644 --- a/ferveo-python/Cargo.toml +++ b/ferveo-python/Cargo.toml @@ -11,9 +11,18 @@ name = "ferveo_py" [dependencies] ferveo = { path = "../ferveo" } ferveo-common = { path = "../ferveo-common" } -pyo3 = { version = "0.18.2", features = ["macros", "extension-module"] } +pyo3 = { version = "0.18.2", features = ["macros"] } derive_more = { version = "0.99", default-features = false, features = ["from", "as_ref"] } rand = "0.8" +itertools = "0.10.5" + +# We avoid declaring "pyo3/extension-module" in `dependencies` since it causes compile-time issues: +# https://github.com/PyO3/pyo3/issues/340 +# Instead, we expose it in certain cases: +# https://github.com/PyO3/maturin/issues/325 +# TODO: Verify whether this actually works +#[tool.maturin] +#features = ["pyo3/extension-module"] [build-dependencies] pyo3-build-config = "*" diff --git a/ferveo-python/examples/server_api_precomputed.py b/ferveo-python/examples/server_api_precomputed.py index 03986618..d9e49180 100644 --- a/ferveo-python/examples/server_api_precomputed.py +++ b/ferveo-python/examples/server_api_precomputed.py @@ -4,11 +4,10 @@ decrypt_with_shared_secret, Keypair, PublicKey, - ExternalValidator, + Validator, Transcript, Dkg, Ciphertext, - UnblindingKey, DecryptionSharePrecomputed, AggregatedTranscript, DkgPublicKey, @@ -25,9 +24,10 @@ def gen_eth_addr(i: int) -> str: shares_num = 4 # In precomputed variant, security threshold must be equal to shares_num security_threshold = shares_num + validator_keypairs = [Keypair.random() for _ in range(0, shares_num)] validators = [ - ExternalValidator(gen_eth_addr(i), keypair.public_key()) + Validator(gen_eth_addr(i), keypair.public_key()) for i, keypair in enumerate(validator_keypairs) ] @@ -47,11 +47,6 @@ def gen_eth_addr(i: int) -> str: ) messages.append((sender, dkg.generate_transcript())) - -# Let's say that we've only received `security_threshold` transcripts -messages = messages[:security_threshold] -transcripts = [transcript for _, transcript in messages] - # Every validator can aggregate the transcripts dkg = Dkg( tau=tau, @@ -61,27 +56,17 @@ def gen_eth_addr(i: int) -> str: me=validators[0], ) -server_aggregate = dkg.aggregate_transcripts(transcripts) -assert server_aggregate.verify(shares_num, transcripts) - -# Clients can also create aggregates and verify them -client_aggregate = AggregatedTranscript.from_transcripts(transcripts) -assert client_aggregate.verify(shares_num, transcripts) +# Let's say that we've only received `security_threshold` transcripts +messages = messages[:security_threshold] -# We can persist transcripts and the aggregated transcript -transcripts_ser = [bytes(transcript) for _, transcript in messages] -_transcripts_deser = [Transcript.from_bytes(t) for t in transcripts_ser] -agg_transcript_ser = bytes(server_aggregate) -_agg_transcript_deser = AggregatedTranscript.from_bytes(agg_transcript_ser) +server_aggregate = dkg.aggregate_transcripts(messages) +assert server_aggregate.verify(shares_num, messages) # In the meantime, the client creates a ciphertext and decryption request msg = "abc".encode() aad = "my-aad".encode() ciphertext = encrypt(msg, aad, dkg.final_key) -# The client can serialize/deserialize ciphertext for transport -ciphertext_ser = bytes(ciphertext) - # Having aggregated the transcripts, the validators can now create decryption shares decryption_shares = [] for validator, validator_keypair in zip(validators, validator_keypairs): @@ -92,15 +77,15 @@ def gen_eth_addr(i: int) -> str: validators=validators, me=validator, ) - # Assume the aggregated transcript is obtained through deserialization from a side-channel - agg_transcript_deser = AggregatedTranscript.from_bytes(agg_transcript_ser) - agg_transcript_deser.verify(shares_num, transcripts) + + # We can also obtain the aggregated transcript from the side-channel (deserialize) + aggregate = AggregatedTranscript(messages) + assert aggregate.verify(shares_num, messages) # The ciphertext is obtained from the client - ciphertext_deser = Ciphertext.from_bytes(ciphertext_ser) # Create a decryption share for the ciphertext - decryption_share = agg_transcript_deser.create_decryption_share_precomputed( + decryption_share = aggregate.create_decryption_share_precomputed( dkg, ciphertext, aad, validator_keypair ) decryption_shares.append(decryption_share) @@ -108,13 +93,11 @@ def gen_eth_addr(i: int) -> str: # Now, the decryption share can be used to decrypt the ciphertext # This part is in the client API -# The client should have access to the public parameters of the DKG -dkg_public_params_ser = bytes(dkg.public_params) -dkg_public_params_deser = DkgPublicParameters.from_bytes(dkg_public_params_ser) - shared_secret = combine_decryption_shares_precomputed(decryption_shares) -plaintext = decrypt_with_shared_secret(ciphertext, aad, shared_secret, dkg_public_params_deser) +# The client should have access to the public parameters of the DKG + +plaintext = decrypt_with_shared_secret(ciphertext, aad, shared_secret, dkg.public_params) assert bytes(plaintext) == msg print("Success!") diff --git a/ferveo-python/examples/server_api_simple.py b/ferveo-python/examples/server_api_simple.py index e6e45965..1a4331b7 100644 --- a/ferveo-python/examples/server_api_simple.py +++ b/ferveo-python/examples/server_api_simple.py @@ -4,11 +4,10 @@ decrypt_with_shared_secret, Keypair, PublicKey, - ExternalValidator, + Validator, Transcript, Dkg, Ciphertext, - UnblindingKey, DecryptionShareSimple, AggregatedTranscript, DkgPublicKey, @@ -26,7 +25,7 @@ def gen_eth_addr(i: int) -> str: shares_num = 4 validator_keypairs = [Keypair.random() for _ in range(0, shares_num)] validators = [ - ExternalValidator(gen_eth_addr(i), keypair.public_key()) + Validator(gen_eth_addr(i), keypair.public_key()) for i, keypair in enumerate(validator_keypairs) ] @@ -56,16 +55,13 @@ def gen_eth_addr(i: int) -> str: validators=validators, me=me, ) + # Let's say that we've only received `security_threshold` transcripts messages = messages[:security_threshold] + pvss_aggregated = dkg.aggregate_transcripts(messages) assert pvss_aggregated.verify(shares_num, messages) -# Server can persist transcripts and the aggregated transcript -transcripts_ser = [bytes(transcript) for _, transcript in messages] -_transcripts_deser = [Transcript.from_bytes(t) for t in transcripts_ser] -agg_transcript_ser = bytes(pvss_aggregated) - # In the meantime, the client creates a ciphertext and decryption request msg = "abc".encode() aad = "my-aad".encode() @@ -84,15 +80,15 @@ def gen_eth_addr(i: int) -> str: validators=validators, me=validator, ) - # Assume the aggregated transcript is obtained through deserialization from a side-channel - agg_transcript_deser = AggregatedTranscript.from_bytes(agg_transcript_ser) - agg_transcript_deser.verify(dkg) + + # We can also obtain the aggregated transcript from the side-channel (deserialize) + aggregate = AggregatedTranscript(messages) + assert aggregate.verify(shares_num, messages) # The ciphertext is obtained from the client - ciphertext_deser = Ciphertext.from_bytes(ciphertext_ser) # Create a decryption share for the ciphertext - decryption_share = agg_transcript_deser.create_decryption_share_simple( + decryption_share = aggregate.create_decryption_share_simple( dkg, ciphertext, aad, validator_keypair ) decryption_shares.append(decryption_share) @@ -100,13 +96,11 @@ def gen_eth_addr(i: int) -> str: # Now, the decryption share can be used to decrypt the ciphertext # This part is in the client API -# The client should have access to the public parameters of the DKG -dkg_public_params_ser = bytes(dkg.public_params) -dkg_public_params_deser = DkgPublicParameters.from_bytes(dkg_public_params_ser) +shared_secret = combine_decryption_shares_simple(decryption_shares, dkg.public_params) -shared_secret = combine_decryption_shares_simple(decryption_shares, dkg_public_params_deser) +# The client should have access to the public parameters of the DKG -plaintext = decrypt_with_shared_secret(ciphertext, aad, shared_secret, dkg_public_params_deser) +plaintext = decrypt_with_shared_secret(ciphertext, aad, shared_secret, dkg.public_params) assert bytes(plaintext) == msg -print("Success!") \ No newline at end of file +print("Success!") diff --git a/ferveo-python/ferveo/__init__.py b/ferveo-python/ferveo/__init__.py index 4694382a..d5f99372 100644 --- a/ferveo-python/ferveo/__init__.py +++ b/ferveo-python/ferveo/__init__.py @@ -5,11 +5,10 @@ decrypt_with_shared_secret, Keypair, PublicKey, - ExternalValidator, + Validator, Transcript, Dkg, Ciphertext, - UnblindingKey, DecryptionShareSimple, DecryptionSharePrecomputed, AggregatedTranscript, diff --git a/ferveo-python/ferveo/__init__.pyi b/ferveo-python/ferveo/__init__.pyi index 66421ab0..265a077e 100644 --- a/ferveo-python/ferveo/__init__.pyi +++ b/ferveo-python/ferveo/__init__.pyi @@ -10,6 +10,10 @@ class Keypair: def from_secure_randomness(data: bytes) -> Keypair: ... + @staticmethod + def secure_randomness_size(data: bytes) -> int: + ... + @staticmethod def from_bytes(data: bytes) -> Keypair: ... @@ -17,7 +21,8 @@ class Keypair: def __bytes__(self) -> bytes: ... - public_key: PublicKey + def public_key(self) -> PublicKey: + ... class PublicKey: @@ -29,7 +34,7 @@ class PublicKey: ... -class ExternalValidator: +class Validator: def __init__(self, address: str, public_key: PublicKey): ... @@ -64,8 +69,8 @@ class Dkg: tau: int, shares_num: int, security_threshold: int, - validators: Sequence[ExternalValidator], - me: ExternalValidator, + validators: Sequence[Validator], + me: Validator, ): ... @@ -76,7 +81,7 @@ class Dkg: def generate_transcript(self) -> Transcript: ... - def aggregate_transcripts(self, transcripts: Sequence[(ExternalValidator, Transcript)]) -> Transcript: + def aggregate_transcripts(self, messages: Sequence[(Validator, Transcript)]) -> AggregatedTranscript: ... @@ -89,16 +94,6 @@ class Ciphertext: ... -class UnblindingKey: - - @staticmethod - def from_bytes(data: bytes) -> Keypair: - ... - - def __bytes__(self) -> bytes: - ... - - class DecryptionShareSimple: @staticmethod def from_bytes(data: bytes) -> DecryptionShareSimple: @@ -128,6 +123,12 @@ class DkgPublicParameters: class AggregatedTranscript: + def __init__(self, messages: Sequence[(Validator, Transcript)]): + ... + + def verify(self, shares_num: int, messages: Sequence[(Validator, Transcript)]) -> bool: + ... + def create_decryption_share_simple( self, dkg: Dkg, @@ -146,13 +147,8 @@ class AggregatedTranscript: ) -> DecryptionSharePrecomputed: ... - def validate(self, dkg: Dkg) -> bool: - ... @staticmethod - def from_transcripts(transcripts: Sequence[Transcript]) -> AggregatedTranscript: - ... - @staticmethod def from_bytes(data: bytes) -> AggregatedTranscript: ... @@ -160,16 +156,6 @@ class AggregatedTranscript: ... -class LagrangeCoefficient: - - @staticmethod - def from_bytes(data: bytes) -> LagrangeCoefficient: - ... - - def __bytes__(self) -> bytes: - ... - - class SharedSecret: @staticmethod @@ -186,7 +172,7 @@ def encrypt(message: bytes, add: bytes, dkg_public_key: DkgPublicKey) -> Ciphert def combine_decryption_shares_simple( decryption_shares: Sequence[DecryptionShareSimple], - lagrange_coefficients: LagrangeCoefficient, + dkg_public_params: DkgPublicParameters, ) -> bytes: ... diff --git a/ferveo-python/src/lib.rs b/ferveo-python/src/lib.rs index 4bbb553f..be79a5ac 100644 --- a/ferveo-python/src/lib.rs +++ b/ferveo-python/src/lib.rs @@ -1,4 +1,5 @@ extern crate alloc; +extern crate core; use std::fmt::{self}; @@ -216,7 +217,7 @@ pub struct Validator(ferveo::api::Validator); #[pymethods] impl Validator { #[new] - pub fn new(address: String, public_key: PublicKey) -> PyResult { + pub fn new(address: String, public_key: &PublicKey) -> PyResult { let validator = ferveo::api::Validator::new(address, public_key.0) .map_err(map_py_error)?; Ok(Self(validator)) @@ -265,7 +266,7 @@ impl DkgPublicKey { } } -#[derive(FromPyObject)] +#[derive(FromPyObject, Clone)] pub struct ValidatorMessage(Validator, Transcript); #[pyclass(module = "ferveo")] @@ -280,7 +281,7 @@ impl Dkg { shares_num: u32, security_threshold: u32, validators: Vec, - me: Validator, + me: &Validator, ) -> PyResult { let validators: Vec<_> = validators.into_iter().map(|v| v.0).collect(); let dkg = ferveo::api::Dkg::new( @@ -311,11 +312,13 @@ impl Dkg { messages: Vec, ) -> PyResult { let messages: Vec<_> = messages - .into_iter() - .map(|ValidatorMessage(v, t)| (v.0, t.0)) + .iter() + .map(|m| ((m.0).0.clone(), (m.1).0.clone())) .collect(); - let aggregated_transcript = - ferveo::api::AggregatedTranscript::new(&messages); + let aggregated_transcript = self + .0 + .aggregate_transcripts(&messages) + .map_err(map_py_error)?; Ok(AggregatedTranscript(aggregated_transcript)) } @@ -341,22 +344,6 @@ impl Ciphertext { } } -#[pyclass(module = "ferveo")] -#[derive(derive_more::From, derive_more::AsRef)] -pub struct UnblindingKey(ferveo::api::UnblindingKey); - -#[pymethods] -impl UnblindingKey { - #[staticmethod] - pub fn from_bytes(bytes: &[u8]) -> PyResult { - from_py_bytes(bytes).map(Self) - } - - fn __bytes__(&self) -> PyResult { - to_py_bytes(self.0) - } -} - #[pyclass(module = "ferveo")] #[derive(Clone, derive_more::AsRef, derive_more::From)] pub struct DecryptionShareSimple(ferveo::api::DecryptionShareSimple); @@ -482,7 +469,6 @@ fn ferveo_py(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -491,3 +477,217 @@ fn ferveo_py(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; Ok(()) } + +// TODO: Consider adding remaining ferveo/api.rs tests here +#[cfg(test)] +mod test_ferveo_python { + use itertools::izip; + + use crate::*; + + type TestInputs = (Vec, Vec, Vec); + + fn make_test_inputs( + tau: u64, + security_threshold: u32, + shares_num: u32, + ) -> TestInputs { + let validator_keypairs = (0..shares_num) + .map(|_| Keypair::random()) + .collect::>(); + let validators: Vec<_> = validator_keypairs + .iter() + .enumerate() + .map(|(i, keypair)| { + Validator::new(format!("0x{:040}", i), &keypair.public_key()) + .unwrap() + }) + .collect(); + + // Each validator holds their own DKG instance and generates a transcript every + // every validator, including themselves + let messages: Vec<_> = validators + .iter() + .cloned() + .map(|sender| { + let dkg = Dkg::new( + tau, + shares_num, + security_threshold, + validators.clone(), + &sender, + ) + .unwrap(); + ValidatorMessage(sender, dkg.generate_transcript().unwrap()) + }) + .collect(); + (messages, validators, validator_keypairs) + } + + #[test] + fn test_server_api_tdec_precomputed() { + let tau = 1; + let shares_num = 4; + // In precomputed variant, the security threshold is equal to the number of shares + let security_threshold = shares_num; + + let (messages, validators, validator_keypairs) = + make_test_inputs(tau, security_threshold, shares_num); + + // Now that every validator holds a dkg instance and a transcript for every other validator, + // every validator can aggregate the transcripts + + let me = validators[0].clone(); + let mut dkg = Dkg::new( + tau, + shares_num, + security_threshold, + validators.clone(), + &me, + ) + .unwrap(); + + // Lets say that we've only receives `security_threshold` transcripts + let messages = messages[..security_threshold as usize].to_vec(); + let pvss_aggregated = + dkg.aggregate_transcripts(messages.clone()).unwrap(); + assert!(pvss_aggregated + .verify(shares_num, messages.clone()) + .unwrap()); + + // At this point, any given validator should be able to provide a DKG public key + let dkg_public_key = dkg.final_key(); + + // In the meantime, the client creates a ciphertext and decryption request + let msg: &[u8] = "abc".as_bytes(); + let aad: &[u8] = "my-aad".as_bytes(); + let ciphertext = encrypt(msg, aad, &dkg_public_key).unwrap(); + + // Having aggregated the transcripts, the validators can now create decryption shares + let decryption_shares: Vec<_> = izip!(&validators, &validator_keypairs) + .map(|(validator, validator_keypair)| { + // Each validator holds their own instance of DKG and creates their own aggregate + let mut dkg = Dkg::new( + tau, + shares_num, + security_threshold, + validators.clone(), + validator, + ) + .unwrap(); + let aggregate = + dkg.aggregate_transcripts(messages.clone()).unwrap(); + assert!(pvss_aggregated + .verify(shares_num, messages.clone()) + .is_ok()); + aggregate + .create_decryption_share_precomputed( + &dkg, + &ciphertext, + aad, + validator_keypair, + ) + .unwrap() + }) + .collect(); + + // Now, the decryption share can be used to decrypt the ciphertext + // This part is part of the client API + + let shared_secret = + combine_decryption_shares_precomputed(decryption_shares); + + let plaintext = decrypt_with_shared_secret( + &ciphertext, + aad, + &shared_secret, + &dkg.public_params(), + ) + .unwrap(); + assert_eq!(plaintext, msg); + } + + #[test] + fn test_server_api_tdec_simple() { + let tau = 1; + let shares_num = 4; + let security_threshold = 3; + + let (messages, validators, validator_keypairs) = + make_test_inputs(tau, security_threshold, shares_num); + + // Now that every validator holds a dkg instance and a transcript for every other validator, + // every validator can aggregate the transcripts + let me = validators[0].clone(); + let mut dkg = Dkg::new( + tau, + shares_num, + security_threshold, + validators.clone(), + &me, + ) + .unwrap(); + + // Lets say that we've only receives `security_threshold` transcripts + let messages = messages[..security_threshold as usize].to_vec(); + let pvss_aggregated = + dkg.aggregate_transcripts(messages.clone()).unwrap(); + assert!(pvss_aggregated + .verify(shares_num, messages.clone()) + .unwrap()); + + // At this point, any given validator should be able to provide a DKG public key + let dkg_public_key = dkg.final_key(); + + // In the meantime, the client creates a ciphertext and decryption request + let msg: &[u8] = "abc".as_bytes(); + let aad: &[u8] = "my-aad".as_bytes(); + let ciphertext = encrypt(msg, aad, &dkg_public_key).unwrap(); + + // Having aggregated the transcripts, the validators can now create decryption shares + let decryption_shares: Vec<_> = izip!(&validators, &validator_keypairs) + .map(|(validator, validator_keypair)| { + // Each validator holds their own instance of DKG and creates their own aggregate + let mut dkg = Dkg::new( + tau, + shares_num, + security_threshold, + validators.clone(), + validator, + ) + .unwrap(); + let aggregate = + dkg.aggregate_transcripts(messages.clone()).unwrap(); + assert!(aggregate + .verify(shares_num, messages.clone()) + .unwrap()); + aggregate + .create_decryption_share_simple( + &dkg, + &ciphertext, + aad, + validator_keypair, + ) + .unwrap() + }) + .collect(); + + // Now, the decryption share can be used to decrypt the ciphertext + // This part is part of the client API + + let shared_secret = combine_decryption_shares_simple( + decryption_shares, + &dkg.public_params(), + ); + + // TODO: Fails because of a bad shared secret + let plaintext = decrypt_with_shared_secret( + &ciphertext, + aad, + &shared_secret, + &dkg.public_params(), + ) + .unwrap(); + assert_eq!(plaintext, msg); + } +} diff --git a/ferveo-python/test/test_ferveo.py b/ferveo-python/test/test_ferveo.py index 34fb4cd2..713b7372 100644 --- a/ferveo-python/test/test_ferveo.py +++ b/ferveo-python/test/test_ferveo.py @@ -7,11 +7,10 @@ decrypt_with_shared_secret, Keypair, PublicKey, - ExternalValidator, + Validator, Transcript, Dkg, Ciphertext, - UnblindingKey, DecryptionShareSimple, DecryptionSharePrecomputed, AggregatedTranscript, @@ -50,10 +49,11 @@ def scenario_for_variant(variant, shares_num=4, security_threshold=3): tau = 1 validator_keypairs = [Keypair.random() for _ in range(0, shares_num)] validators = [ - ExternalValidator(gen_eth_addr(i), keypair.public_key()) + Validator(gen_eth_addr(i), keypair.public_key()) for i, keypair in enumerate(validator_keypairs) ] validators.sort(key=lambda v: v.address) + messages = [] for sender in validators: dkg = Dkg( @@ -64,6 +64,7 @@ def scenario_for_variant(variant, shares_num=4, security_threshold=3): me=sender, ) messages.append((sender, dkg.generate_transcript())) + me = validators[0] dkg = Dkg( tau=tau, @@ -73,10 +74,12 @@ def scenario_for_variant(variant, shares_num=4, security_threshold=3): me=me, ) pvss_aggregated = dkg.aggregate_transcripts(messages) - assert pvss_aggregated.validate(dkg) + assert pvss_aggregated.verify(shares_num, messages) + msg = "abc".encode() aad = "my-aad".encode() ciphertext = encrypt(msg, aad, dkg.final_key) + decryption_shares = [] for validator, validator_keypair in zip(validators, validator_keypairs): dkg = Dkg( @@ -87,13 +90,15 @@ def scenario_for_variant(variant, shares_num=4, security_threshold=3): me=validator, ) agg_transcript_deser = AggregatedTranscript.from_bytes(bytes(pvss_aggregated)) - agg_transcript_deser.validate(dkg) + agg_transcript_deser.verify(shares_num, messages) decryption_share = decryption_share_for_variant('simple', agg_transcript_deser)( dkg, ciphertext, aad, validator_keypair ) decryption_shares.append(decryption_share) + shared_secret = combine_shares_for_variant('simple')(decryption_shares, dkg.public_params) + plaintext = decrypt_with_shared_secret(ciphertext, aad, shared_secret, dkg.public_params) assert bytes(plaintext) == msg diff --git a/ferveo-python/test/test_serialization.py b/ferveo-python/test/test_serialization.py index 715e7246..bb3a48f2 100644 --- a/ferveo-python/test/test_serialization.py +++ b/ferveo-python/test/test_serialization.py @@ -1,7 +1,7 @@ from ferveo_py import ( Keypair, PublicKey, - ExternalValidator, + Validator, Transcript, Dkg, AggregatedTranscript, @@ -20,7 +20,7 @@ def gen_eth_addr(i: int) -> str: shares_num = 4 validator_keypairs = [Keypair.random() for _ in range(shares_num)] validators = [ - ExternalValidator(gen_eth_addr(i), keypair.public_key()) + Validator(gen_eth_addr(i), keypair.public_key()) for i, keypair in enumerate(validator_keypairs) ] validators.sort(key=lambda v: v.address) diff --git a/ferveo/Cargo.toml b/ferveo/Cargo.toml index 7e57ea21..ef48daeb 100644 --- a/ferveo/Cargo.toml +++ b/ferveo/Cargo.toml @@ -24,7 +24,7 @@ rand = "0.8" rand_old = { package = "rand", version = "0.7" } # used by benchmarks/pairing.rs serde = { version = "1.0", features = ["derive"] } bincode = "1.3" -itertools = "0.10.1" +itertools = "0.10.5" measure_time = "0.8" rand_core = "0.6.4" serde_with = "2.0.1" diff --git a/ferveo/src/api.rs b/ferveo/src/api.rs index 2cc3cef0..b2d57046 100644 --- a/ferveo/src/api.rs +++ b/ferveo/src/api.rs @@ -239,8 +239,46 @@ mod test_ferveo_api { type E = ark_bls12_381::Bls12_381; + type TestInputs = + (Vec, Vec>, Vec>); + + fn make_test_inputs( + rng: &mut StdRng, + tau: u64, + security_threshold: u32, + shares_num: u32, + ) -> TestInputs { + let validator_keypairs = gen_keypairs(shares_num); + let validators = validator_keypairs + .iter() + .enumerate() + .map(|(i, keypair)| Validator { + address: gen_address(i), + public_key: keypair.public(), + }) + .collect::>(); + + // Each validator holds their own DKG instance and generates a transcript every + // every validator, including themselves + let messages: Vec<_> = validators + .iter() + .map(|sender| { + let dkg = Dkg::new( + tau, + shares_num, + security_threshold, + &validators, + sender, + ) + .unwrap(); + (sender.clone(), dkg.generate_transcript(rng).unwrap()) + }) + .collect(); + (messages, validators, validator_keypairs) + } + #[test] - fn test_dkg_public_serialization() { + fn test_dkg_public_key_serialization() { let shares_num = 4; let validator_keypairs = gen_keypairs(shares_num); let validators = validator_keypairs @@ -273,32 +311,8 @@ mod test_ferveo_api { // Or figure out a different way to simplify the precomputed variant API. let security_threshold = shares_num; - let validator_keypairs = gen_keypairs(shares_num); - let validators = validator_keypairs - .iter() - .enumerate() - .map(|(i, keypair)| Validator { - address: gen_address(i), - public_key: keypair.public(), - }) - .collect::>(); - - // Each validator holds their own DKG instance and generates a transcript every - // every validator, including themselves - let messages: Vec<_> = validators - .iter() - .map(|sender| { - let dkg = Dkg::new( - tau, - shares_num, - security_threshold, - &validators, - sender, - ) - .unwrap(); - (sender.clone(), dkg.generate_transcript(rng).unwrap()) - }) - .collect(); + let (messages, validators, validator_keypairs) = + make_test_inputs(rng, tau, security_threshold, shares_num); // Now that every validator holds a dkg instance and a transcript for every other validator, // every validator can aggregate the transcripts @@ -309,22 +323,17 @@ mod test_ferveo_api { // Lets say that we've only receives `security_threshold` transcripts let messages = messages[..security_threshold as usize].to_vec(); - let _transcripts: Vec<_> = messages - .iter() - .map(|(_, transcript)| transcript) - .cloned() - .collect(); let pvss_aggregated = dkg.aggregate_transcripts(&messages).unwrap(); assert!(pvss_aggregated.verify(shares_num, &messages).unwrap()); // At this point, any given validator should be able to provide a DKG public key - let public_key = dkg.final_key(); + let dkg_public_key = dkg.final_key(); // In the meantime, the client creates a ciphertext and decryption request let msg: &[u8] = "abc".as_bytes(); let aad: &[u8] = "my-aad".as_bytes(); let rng = &mut thread_rng(); - let ciphertext = encrypt(msg, aad, &public_key.0, rng).unwrap(); + let ciphertext = encrypt(msg, aad, &dkg_public_key.0, rng).unwrap(); // Having aggregated the transcripts, the validators can now create decryption shares let decryption_shares: Vec<_> = izip!(&validators, &validator_keypairs) @@ -339,7 +348,7 @@ mod test_ferveo_api { ) .unwrap(); let aggregate = dkg.aggregate_transcripts(&messages).unwrap(); - assert!(pvss_aggregated.verify(shares_num, &messages).is_ok()); + assert!(pvss_aggregated.verify(shares_num, &messages).unwrap()); aggregate .create_decryption_share_precomputed( &dkg, @@ -374,32 +383,8 @@ mod test_ferveo_api { let shares_num = 4; let security_threshold = 3; - let validator_keypairs = gen_keypairs(shares_num); - let validators = validator_keypairs - .iter() - .enumerate() - .map(|(i, keypair)| Validator { - address: gen_address(i), - public_key: keypair.public(), - }) - .collect::>(); - - // Each validator holds their own DKG instance and generates a transcript every - // every validator, including themselves - let messages: Vec<_> = validators - .iter() - .map(|sender| { - let dkg = Dkg::new( - tau, - shares_num, - security_threshold, - &validators, - sender, - ) - .unwrap(); - (sender.clone(), dkg.generate_transcript(rng).unwrap()) - }) - .collect(); + let (messages, validators, validator_keypairs) = + make_test_inputs(rng, tau, security_threshold, shares_num); // Now that every validator holds a dkg instance and a transcript for every other validator, // every validator can aggregate the transcripts @@ -437,7 +422,7 @@ mod test_ferveo_api { let aggregate = dkg.aggregate_transcripts(&messages).unwrap(); assert!(aggregate.verify(shares_num, &messages).unwrap()); aggregate - .create_decryption_share_precomputed( + .create_decryption_share_simple( &dkg, &ciphertext, aad, @@ -450,7 +435,10 @@ mod test_ferveo_api { // Now, the decryption share can be used to decrypt the ciphertext // This part is part of the client API - let shared_secret = share_combine_precomputed(&decryption_shares); + let lagrange_coeffs = + prepare_combine_simple::(&dkg.public_params().domain_points); + let shared_secret = + share_combine_simple(&decryption_shares, &lagrange_coeffs); let plaintext = decrypt_with_shared_secret( &ciphertext, @@ -470,7 +458,7 @@ mod test_ferveo_api { let security_threshold = 3; let shares_num = 4; - let (messages, validators, _validator_keypairs) = + let (messages, validators, _) = make_test_inputs(rng, tau, security_threshold, shares_num); // Now that every validator holds a dkg instance and a transcript for every other validator, @@ -530,42 +518,4 @@ mod test_ferveo_api { let result = bad_aggregate.verify(shares_num, messages); assert!(result.is_err()); } - - type TestInputs = - (Vec, Vec>, Vec>); - - fn make_test_inputs( - rng: &mut StdRng, - tau: u64, - security_threshold: u32, - shares_num: u32, - ) -> TestInputs { - let validator_keypairs = gen_keypairs(shares_num); - let validators = validator_keypairs - .iter() - .enumerate() - .map(|(i, keypair)| Validator { - address: gen_address(i), - public_key: keypair.public(), - }) - .collect::>(); - - // Each validator holds their own DKG instance and generates a transcript every - // every validator, including themselves - let messages: Vec<_> = validators - .iter() - .map(|sender| { - let dkg = Dkg::new( - tau, - shares_num, - security_threshold, - &validators, - sender, - ) - .unwrap(); - (sender.clone(), dkg.generate_transcript(rng).unwrap()) - }) - .collect(); - (messages, validators, validator_keypairs) - } } diff --git a/ferveo/src/dkg.rs b/ferveo/src/dkg.rs index 3fff0398..acf7a295 100644 --- a/ferveo/src/dkg.rs +++ b/ferveo/src/dkg.rs @@ -337,11 +337,7 @@ pub(crate) mod test_common { pub fn gen_keypairs(n: u32) -> Vec> { let rng = &mut ark_std::test_rng(); - let mut keypair: Vec<_> = - (0..n).map(|_| Keypair::::new(rng)).collect(); - keypair.sort_by_key(|a| a.public()); - // keypair.sort(); - keypair + (0..n).map(|_| Keypair::::new(rng)).collect() } pub fn gen_address(i: usize) -> EthereumAddress { @@ -440,6 +436,10 @@ mod test_dkg_init { let shares_num = 4; let known_keypairs = gen_keypairs(shares_num); let unknown_keypair = ferveo_common::Keypair::::new(rng); + let unknown_validator = Validator:: { + address: gen_address((shares_num + 1) as usize), + public_key: unknown_keypair.public(), + }; let err = PubliclyVerifiableDkg::::new( &gen_validators(&known_keypairs), &DkgParams { @@ -447,14 +447,11 @@ mod test_dkg_init { security_threshold: shares_num / 2, shares_num, }, - &Validator:: { - address: gen_address(0), - public_key: unknown_keypair.public(), - }, + &unknown_validator, ) - .expect_err("Test failed"); + .unwrap_err(); - assert_eq!(err.to_string(), "Validator public key mismatch") + assert_eq!(err.to_string(), "Expected validator to be a part of the DKG validator set: 0x0000000000000000000000000000000000000005") } } diff --git a/ferveo/src/pvss.rs b/ferveo/src/pvss.rs index d19c592c..3c6221cd 100644 --- a/ferveo/src/pvss.rs +++ b/ferveo/src/pvss.rs @@ -9,7 +9,6 @@ use ark_poly::{ use ferveo_common::is_sorted; use group_threshold_cryptography as tpke; use itertools::Itertools; -use measure_time::print_time; use rand::RngCore; use serde::{Deserialize, Serialize}; use serde_with::serde_as; @@ -193,15 +192,13 @@ pub fn do_verify_full( validators: &[ferveo_common::Validator], domain: &ark_poly::Radix2EvaluationDomain, ) -> bool { - // compute the commitment let mut commitment = batch_to_projective_g1::(pvss_coefficients); - print_time!("commitment fft"); domain.fft_in_place(&mut commitment); // At this point, validators must be sorted assert!(is_sorted(validators)); - //Each validator checks that their share is correct + // Each validator checks that their share is correct validators .iter() .zip(pvss_encrypted_shares.iter()) diff --git a/tpke-wasm/src/lib.rs b/tpke-wasm/src/lib.rs index a614518d..097ecb1f 100644 --- a/tpke-wasm/src/lib.rs +++ b/tpke-wasm/src/lib.rs @@ -88,7 +88,7 @@ impl PrivateKey { #[wasm_bindgen(js_name = "fromBytes")] pub fn from_bytes(bytes: &[u8]) -> Result { let mut reader = bytes; - let pk = tpke::api::PrivateKey::deserialize_uncompressed(&mut reader) + let pk = tpke::api::PrivateKey::deserialize_compressed(&mut reader) .map_err(map_js_err)?; Ok(PrivateKey(pk)) } @@ -97,7 +97,7 @@ impl PrivateKey { pub fn to_bytes(&self) -> Result> { let mut bytes = Vec::new(); self.0 - .serialize_uncompressed(&mut bytes) + .serialize_compressed(&mut bytes) .map_err(map_js_err)?; Ok(bytes) } diff --git a/tpke/src/ciphertext.rs b/tpke/src/ciphertext.rs index aa2dc8de..42e1e17c 100644 --- a/tpke/src/ciphertext.rs +++ b/tpke/src/ciphertext.rs @@ -43,7 +43,7 @@ impl Ciphertext { fn construct_tag_hash(&self) -> Result { let mut hash_input = Vec::::new(); - self.commitment.serialize_uncompressed(&mut hash_input)?; + self.commitment.serialize_compressed(&mut hash_input)?; hash_input.extend_from_slice(&self.ciphertext); hash_to_g2(&hash_input) @@ -173,7 +173,7 @@ pub fn shared_secret_to_chacha( s: &E::TargetField, ) -> Result { let mut prf_key = Vec::new(); - s.serialize_uncompressed(&mut prf_key)?; + s.serialize_compressed(&mut prf_key)?; let prf_key_32 = sha256(&prf_key); Ok(ChaCha20Poly1305::new(GenericArray::from_slice(&prf_key_32))) @@ -181,7 +181,7 @@ pub fn shared_secret_to_chacha( fn nonce_from_commitment(commitment: E::G1Affine) -> Result { let mut commitment_bytes = Vec::new(); - commitment.serialize_uncompressed(&mut commitment_bytes)?; + commitment.serialize_compressed(&mut commitment_bytes)?; let commitment_hash = sha256(&commitment_bytes); Ok(*Nonce::from_slice(&commitment_hash[..12])) } @@ -191,8 +191,8 @@ fn hash_to_g2( ) -> Result { let point = htp_bls12381_g2(message); let mut point_ser: Vec = Vec::new(); - point.serialize_uncompressed(&mut point_ser)?; - T::deserialize_uncompressed(&point_ser[..]) + point.serialize_compressed(&mut point_ser)?; + T::deserialize_compressed(&point_ser[..]) .map_err(Error::ArkworksSerializationError) } @@ -202,7 +202,7 @@ fn construct_tag_hash( aad: &[u8], ) -> Result { let mut hash_input = Vec::::new(); - commitment.serialize_uncompressed(&mut hash_input)?; + commitment.serialize_compressed(&mut hash_input)?; hash_input.extend_from_slice(stream_ciphertext); hash_input.extend_from_slice(aad); hash_to_g2(&hash_input)