From 8c533fb271202be1dfc888fc09a7c3bcad4763fc Mon Sep 17 00:00:00 2001 From: bochaco Date: Fri, 12 Nov 2021 16:26:30 -0300 Subject: [PATCH] fix(api): get_proof_chain was returning a subchain with forks in some cases BREAKING CHANGE: the 'merge' public API was renamed to 'join' --- src/error.rs | 2 ++ src/lib.rs | 38 +++++++++++++++++++++++++++++++------- src/tests.rs | 45 +++++++++++++++++++++++++++++++++++---------- 3 files changed, 68 insertions(+), 17 deletions(-) diff --git a/src/error.rs b/src/error.rs index a987141..cb338ef 100644 --- a/src/error.rs +++ b/src/error.rs @@ -16,6 +16,8 @@ pub enum Error { FailedSignature, #[error("key not found in the chain")] KeyNotFound, + #[error("no sub-chain was found in the chain")] + SubChainNotFound, #[error("chain doesn't contain any trusted keys")] Untrusted, #[error("attempted operation is invalid")] diff --git a/src/lib.rs b/src/lib.rs index 3deb44a..4a74601 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -88,7 +88,7 @@ impl SecuredLinkedList { /// /// This succeeds only if the root key of one of the chain is present in the other one. /// Otherwise it returns `Error::InvalidOperation` - pub fn merge(&mut self, mut other: Self) -> Result<(), Error> { + pub fn join(&mut self, mut other: Self) -> Result<(), Error> { let root_index = if let Some(index) = self.index_of(other.root_key()) { index } else if let Some(index) = other.index_of(self.root_key()) { @@ -121,13 +121,37 @@ impl SecuredLinkedList { from_key: &bls::PublicKey, to_key: &bls::PublicKey, ) -> Result { - self.minimize(vec![from_key, to_key]) - } + let from_index = self.index_of(from_key).ok_or(Error::KeyNotFound)?; + let mut chain = Self::new(if from_index == 0 { + self.root + } else { + self.tree[from_index - 1].key + }); + + let mut curr_index = self.index_of(to_key).ok_or(Error::KeyNotFound)?; + while curr_index != 0 && curr_index != from_index { + let block = &self.tree[curr_index - 1]; + chain.tree.insert( + 0, + Block { + key: block.key, + signature: block.signature.clone(), + parent_index: 0, // we'll update it afterwards + }, + ); + curr_index = block.parent_index; + } - /// Creates a sub-chain from a given key to the end. - /// Returns `Error::KeyNotFound` if the given from key is not present in the chain. - pub fn get_proof_chain_to_current(&self, from_key: &bls::PublicKey) -> Result { - self.minimize(vec![from_key, self.last_key()]) + if curr_index != from_index { + // the 'from_key' is not an ancestor in any chain containing 'to_key' + Err(Error::SubChainNotFound) + } else { + for (i, elem) in chain.tree.iter_mut().enumerate() { + elem.parent_index = i; + } + + Ok(chain) + } } /// Creates a minimal sub-chain of `self` that contains all `required_keys`. diff --git a/src/tests.rs b/src/tests.rs index 23285c3..e56c178 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -32,6 +32,11 @@ fn insert_last() { #[test] fn insert_fork() { + // We use a chain with two branches, a and b: + // pk0 -> pk1_a -> pk2_a + // | + // +-> pk1_b + // let (sk0, pk0) = gen_keypair(); let (sk1_a, pk1_a, sig1_a) = gen_signed_keypair(&sk0); let (_, pk2_a, sig2_a) = gen_signed_keypair(&sk1_a); @@ -40,7 +45,11 @@ fn insert_fork() { let mut chain = SecuredLinkedList::new(pk0); assert_eq!(chain.insert(&pk0, pk1_a, sig1_a), Ok(())); assert_eq!(chain.insert(&pk1_a, pk2_a, sig2_a), Ok(())); - assert_eq!(chain.insert(&pk0, pk1_b, sig1_b), Ok(())); + let branch_a_only = chain.clone(); + + assert_eq!(chain.insert(&pk0, pk1_b, sig1_b.clone()), Ok(())); + let mut branch_b_only = SecuredLinkedList::new(pk0); + assert_eq!(branch_b_only.insert(&pk0, pk1_b, sig1_b), Ok(())); let expected_keys = if pk1_a > pk1_b { vec![&pk0, &pk1_b, &pk1_a, &pk2_a] @@ -51,6 +60,22 @@ fn insert_fork() { let actual_keys: Vec<_> = chain.keys().collect(); assert_eq!(actual_keys, expected_keys); + + assert_eq!( + chain.get_proof_chain(&pk0, &pk0), + Ok(SecuredLinkedList::new(pk0)) + ); + assert_eq!(chain.get_proof_chain(&pk0, &pk2_a), Ok(branch_a_only)); + assert_eq!(chain.get_proof_chain(&pk0, &pk1_b), Ok(branch_b_only)); + + assert_eq!( + chain.get_proof_chain(&pk2_a, &pk0), + Err(Error::SubChainNotFound) + ); + assert_eq!( + chain.get_proof_chain(&pk1_a, &pk1_b), + Err(Error::SubChainNotFound) + ); } #[test] @@ -204,7 +229,7 @@ fn invalid_deserialized_chain_wrong_unrelated_block_order() { } #[test] -fn merge() { +fn join() { let (sk0, pk0) = gen_keypair(); let (sk1, pk1, sig1) = gen_signed_keypair(&sk0); let (sk2, pk2, sig2) = gen_signed_keypair(&sk1); @@ -222,21 +247,21 @@ fn merge() { ], ); let rhs = make_chain(pk2, vec![(&pk2, pk3, sig3.clone())]); - assert_eq!(merge_chains(lhs, rhs), Ok(vec![pk0, pk1, pk2, pk3])); + assert_eq!(join_chains(lhs, rhs), Ok(vec![pk0, pk1, pk2, pk3])); // lhs: 1->2->3 // rhs: 0->1 // out: 0->1->2->3 let lhs = make_chain(pk1, vec![(&pk1, pk2, sig2), (&pk2, pk3, sig3.clone())]); let rhs = make_chain(pk0, vec![(&pk0, pk1, sig1.clone())]); - assert_eq!(merge_chains(lhs, rhs), Ok(vec![pk0, pk1, pk2, pk3])); + assert_eq!(join_chains(lhs, rhs), Ok(vec![pk0, pk1, pk2, pk3])); // lhs: 0->1 // rhs: 2->3 // out: Err(Incompatible) let lhs = make_chain(pk0, vec![(&pk0, pk1, sig1)]); let rhs = make_chain(pk2, vec![(&pk2, pk3, sig3)]); - assert_eq!(merge_chains(lhs, rhs), Err(Error::InvalidOperation)); + assert_eq!(join_chains(lhs, rhs), Err(Error::InvalidOperation)); } #[test] @@ -259,7 +284,7 @@ fn merge_fork() { vec![pk0, pk2, pk1] }; - assert_eq!(merge_chains(lhs, rhs), Ok(expected)) + assert_eq!(join_chains(lhs, rhs), Ok(expected)) } #[test] @@ -578,7 +603,7 @@ fn self_verify() { let (_, pk4, sig4) = gen_signed_keypair(&sk3); let fork_chain = make_chain(pk1, vec![(&pk1, pk3, sig3), (&pk3, pk4, sig4)]); - assert_eq!(main_chain.merge(fork_chain), Ok(())); + assert_eq!(main_chain.join(fork_chain), Ok(())); assert!(main_chain.self_verify()); // create another fork (from root key) with valid signatures @@ -590,7 +615,7 @@ fn self_verify() { let (_, pk6, sig6) = gen_signed_keypair(&sk5); let fork_chain = make_chain(pk0, vec![(&pk0, pk5, sig5), (&pk5, pk6, sig6)]); - assert_eq!(main_chain.merge(fork_chain), Ok(())); + assert_eq!(main_chain.join(fork_chain), Ok(())); assert!(main_chain.self_verify()); } @@ -673,11 +698,11 @@ fn make_chain( } // Merge `rhs` into `lhs`, verify the resulting chain is valid and return a vector of its keys. -fn merge_chains( +fn join_chains( mut lhs: SecuredLinkedList, rhs: SecuredLinkedList, ) -> Result, Error> { - lhs.merge(rhs)?; + lhs.join(rhs)?; Ok(lhs.keys().copied().collect()) }