Skip to content

Commit

Permalink
feat!: change load_from_persistence to return an option
Browse files Browse the repository at this point in the history
`PersistBackend::is_empty` is removed. Instead, `load_from_persistence`
returns an option of the changeset. `None` means persistence is empty.
This is a better API than a separate method. We can now differentiate
between a persisted single changeset and nothing persisted.

`Store::aggregate_changeset` is refactored to return a `Result` instead
of a tuple. A new error type (`AggregateChangesetsError`) is introduced
to include the partially-aggregated changeset in the error. This is a
more idiomatic API.
  • Loading branch information
evanlinjin committed Nov 1, 2023
1 parent 7e0cd98 commit 632b806
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 86 deletions.
94 changes: 64 additions & 30 deletions crates/bdk/src/wallet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ pub enum LoadError<L> {
Descriptor(crate::descriptor::DescriptorError),
/// Loading data from the persistence backend failed.
Load(L),
/// Wallet not initialized, persistence backend is empty.
NotInitialized,
/// Data loaded from persistence is missing network type.
MissingNetwork,
/// Data loaded from persistence is missing genesis hash.
Expand All @@ -300,6 +302,9 @@ where
match self {
LoadError::Descriptor(e) => e.fmt(f),
LoadError::Load(e) => e.fmt(f),
LoadError::NotInitialized => {
write!(f, "wallet is not initialized, persistence backend is empty")
}
LoadError::MissingNetwork => write!(f, "loaded data is missing network type"),
LoadError::MissingGenesis => write!(f, "loaded data is missing genesis hash"),
}
Expand All @@ -323,6 +328,8 @@ pub enum NewOrLoadError<W, L> {
Write(W),
/// Loading from the persistence backend failed.
Load(L),
/// Wallet is not initialized, persistence backend is empty.
NotInitialized,
/// The loaded genesis hash does not match what was provided.
LoadedGenesisDoesNotMatch {
/// The expected genesis block hash.
Expand All @@ -349,6 +356,9 @@ where
NewOrLoadError::Descriptor(e) => e.fmt(f),
NewOrLoadError::Write(e) => write!(f, "failed to write to persistence: {}", e),
NewOrLoadError::Load(e) => write!(f, "failed to load from persistence: {}", e),
NewOrLoadError::NotInitialized => {
write!(f, "wallet is not initialized, persistence backend is empty")
}
NewOrLoadError::LoadedGenesisDoesNotMatch { expected, got } => {
write!(f, "loaded genesis hash is not {}, got {:?}", expected, got)
}
Expand Down Expand Up @@ -444,11 +454,26 @@ impl<D> Wallet<D> {
change_descriptor: Option<E>,
mut db: D,
) -> Result<Self, LoadError<D::LoadError>>
where
D: PersistBackend<ChangeSet>,
{
let changeset = db
.load_from_persistence()
.map_err(LoadError::Load)?
.ok_or(LoadError::NotInitialized)?;
Self::load_from_changeset(descriptor, change_descriptor, db, changeset)
}

fn load_from_changeset<E: IntoWalletDescriptor>(
descriptor: E,
change_descriptor: Option<E>,
db: D,
changeset: ChangeSet,
) -> Result<Self, LoadError<D::LoadError>>
where
D: PersistBackend<ChangeSet>,
{
let secp = Secp256k1::new();
let changeset = db.load_from_persistence().map_err(LoadError::Load)?;
let network = changeset.network.ok_or(LoadError::MissingNetwork)?;
let chain =
LocalChain::from_changeset(changeset.chain).map_err(|_| LoadError::MissingGenesis)?;
Expand Down Expand Up @@ -510,8 +535,43 @@ impl<D> Wallet<D> {
where
D: PersistBackend<ChangeSet>,
{
if db.is_empty().map_err(NewOrLoadError::Load)? {
return Self::new_with_genesis_hash(
let changeset = db.load_from_persistence().map_err(NewOrLoadError::Load)?;
match changeset {
Some(changeset) => {
let wallet =
Self::load_from_changeset(descriptor, change_descriptor, db, changeset)
.map_err(|e| match e {
LoadError::Descriptor(e) => NewOrLoadError::Descriptor(e),
LoadError::Load(e) => NewOrLoadError::Load(e),
LoadError::NotInitialized => NewOrLoadError::NotInitialized,
LoadError::MissingNetwork => {
NewOrLoadError::LoadedNetworkDoesNotMatch {
expected: network,
got: None,
}
}
LoadError::MissingGenesis => {
NewOrLoadError::LoadedGenesisDoesNotMatch {
expected: genesis_hash,
got: None,
}
}
})?;
if wallet.network != network {
return Err(NewOrLoadError::LoadedNetworkDoesNotMatch {
expected: network,
got: Some(wallet.network),
});
}
if wallet.chain.genesis_hash() != genesis_hash {
return Err(NewOrLoadError::LoadedGenesisDoesNotMatch {
expected: genesis_hash,
got: Some(wallet.chain.genesis_hash()),
});
}
Ok(wallet)
}
None => Self::new_with_genesis_hash(
descriptor,
change_descriptor,
db,
Expand All @@ -521,34 +581,8 @@ impl<D> Wallet<D> {
.map_err(|e| match e {
NewError::Descriptor(e) => NewOrLoadError::Descriptor(e),
NewError::Write(e) => NewOrLoadError::Write(e),
});
}

let wallet = Self::load(descriptor, change_descriptor, db).map_err(|e| match e {
LoadError::Descriptor(e) => NewOrLoadError::Descriptor(e),
LoadError::Load(e) => NewOrLoadError::Load(e),
LoadError::MissingNetwork => NewOrLoadError::LoadedNetworkDoesNotMatch {
expected: network,
got: None,
},
LoadError::MissingGenesis => NewOrLoadError::LoadedGenesisDoesNotMatch {
expected: genesis_hash,
got: None,
},
})?;
if wallet.network != network {
return Err(NewOrLoadError::LoadedNetworkDoesNotMatch {
expected: network,
got: Some(wallet.network),
});
}
if wallet.chain.genesis_hash() != genesis_hash {
return Err(NewOrLoadError::LoadedGenesisDoesNotMatch {
expected: genesis_hash,
got: Some(wallet.chain.genesis_hash()),
});
}),
}
Ok(wallet)
}

/// Get the Bitcoin network the wallet is using.
Expand Down
21 changes: 4 additions & 17 deletions crates/chain/src/persist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,10 @@ pub trait PersistBackend<C> {
fn write_changes(&mut self, changeset: &C) -> Result<(), Self::WriteError>;

/// Return the aggregate changeset `C` from persistence.
fn load_from_persistence(&mut self) -> Result<C, Self::LoadError>;

/// Returns whether the persistence backend contains no data.
fn is_empty(&mut self) -> Result<bool, Self::LoadError>
where
C: Append,
{
self.load_from_persistence()
.map(|changeset| changeset.is_empty())
}
fn load_from_persistence(&mut self) -> Result<Option<C>, Self::LoadError>;
}

impl<C: Default> PersistBackend<C> for () {
impl<C> PersistBackend<C> for () {
type WriteError = Infallible;

type LoadError = Infallible;
Expand All @@ -100,11 +91,7 @@ impl<C: Default> PersistBackend<C> for () {
Ok(())
}

fn load_from_persistence(&mut self) -> Result<C, Self::LoadError> {
Ok(C::default())
}

fn is_empty(&mut self) -> Result<bool, Self::LoadError> {
Ok(true)
fn load_from_persistence(&mut self) -> Result<Option<C>, Self::LoadError> {
Ok(None)
}
}
79 changes: 41 additions & 38 deletions crates/file_store/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub struct Store<'a, C> {

impl<'a, C> PersistBackend<C> for Store<'a, C>
where
C: Default + Append + serde::Serialize + serde::de::DeserializeOwned,
C: Append + serde::Serialize + serde::de::DeserializeOwned,
{
type WriteError = std::io::Error;

Expand All @@ -33,23 +33,14 @@ where
self.append_changeset(changeset)
}

fn load_from_persistence(&mut self) -> Result<C, Self::LoadError> {
let (changeset, result) = self.aggregate_changesets();
result.map(|_| changeset)
}

fn is_empty(&mut self) -> Result<bool, Self::LoadError> {
let init_pos = self.db_file.stream_position()?;
let stream_len = self.db_file.seek(io::SeekFrom::End(0))?;
let magic_len = self.magic.len() as u64;
self.db_file.seek(io::SeekFrom::Start(init_pos))?;
Ok(stream_len == magic_len)
fn load_from_persistence(&mut self) -> Result<Option<C>, Self::LoadError> {
self.aggregate_changesets().map_err(|e| e.iter_error)
}
}

impl<'a, C> Store<'a, C>
where
C: Default + Append + serde::Serialize + serde::de::DeserializeOwned,
C: Append + serde::Serialize + serde::de::DeserializeOwned,
{
/// Create a new [`Store`] file in write-only mode; error if the file exists.
///
Expand Down Expand Up @@ -156,16 +147,24 @@ where
///
/// **WARNING**: This method changes the write position of the underlying file. The next
/// changeset will be written over the erroring entry (or the end of the file if none existed).
pub fn aggregate_changesets(&mut self) -> (C, Result<(), IterError>) {
let mut changeset = C::default();
let result = (|| {
for next_changeset in self.iter_changesets() {
changeset.append(next_changeset?);
pub fn aggregate_changesets(&mut self) -> Result<Option<C>, AggregateChangesetsError<C>> {
let mut changeset = Option::<C>::None;
for next_changeset in self.iter_changesets() {
let next_changeset = match next_changeset {
Ok(next_changeset) => next_changeset,
Err(iter_error) => {
return Err(AggregateChangesetsError {
changeset,
iter_error,
})
}
};
match &mut changeset {
Some(changeset) => changeset.append(next_changeset),
changeset => *changeset = Some(next_changeset),
}
Ok(())
})();

(changeset, result)
}
Ok(changeset)
}

/// Append a new changeset to the file and truncate the file to the end of the appended
Expand Down Expand Up @@ -196,6 +195,24 @@ where
}
}

/// Error type for [`Store::aggregate_changesets`].
#[derive(Debug)]
pub struct AggregateChangesetsError<C> {
/// The partially-aggregated changeset.
pub changeset: Option<C>,

/// The error returned by [`EntryIter`].
pub iter_error: IterError,
}

impl<C> std::fmt::Display for AggregateChangesetsError<C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(&self.iter_error, f)
}
}

impl<C: std::fmt::Debug> std::error::Error for AggregateChangesetsError<C> {}

#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -248,25 +265,11 @@ mod test {
{
let mut db = Store::<TestChangeSet>::open_or_create_new(&TEST_MAGIC_BYTES, &file_path)
.expect("must recover");
let (recovered_changeset, r) = db.aggregate_changesets();
r.expect("must succeed");
assert_eq!(recovered_changeset, changeset);
let recovered_changeset = db.aggregate_changesets().expect("must succeed");
assert_eq!(recovered_changeset, Some(changeset));
}
}

#[test]
fn is_empty() {
let mut file = NamedTempFile::new().unwrap();
file.write_all(&TEST_MAGIC_BYTES).expect("should write");

let mut db =
Store::<TestChangeSet>::open(&TEST_MAGIC_BYTES, file.path()).expect("must open");
assert!(db.is_empty().expect("must read"));
db.write_changes(&vec!["hello".to_string(), "world".to_string()])
.expect("must write");
assert!(!db.is_empty().expect("must read"));
}

#[test]
fn new_fails_if_file_is_too_short() {
let mut file = NamedTempFile::new().unwrap();
Expand Down
2 changes: 1 addition & 1 deletion example-crates/example_cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ where
Err(err) => return Err(anyhow::anyhow!("failed to init db backend: {:?}", err)),
};

let init_changeset = db_backend.load_from_persistence()?;
let init_changeset = db_backend.load_from_persistence()?.unwrap_or_default();

Ok((
args,
Expand Down

0 comments on commit 632b806

Please sign in to comment.