diff --git a/.gitignore b/.gitignore index ba35aa3..0b2ce60 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -/target +target Cargo.lock - +*.sqlite .env diff --git a/Cargo.toml b/Cargo.toml index 89cda77..e117ba8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,34 +1,24 @@ [package] -name = "bdk-sqlx" +name = "bdk_sqlx" version = "0.1.0" edition = "2021" -[lib] -name = "bdk_sqlx" -path = "src/lib.rs" - -[[bin]] -name = "async_wallet_bdk_sqlx" -path = "src/main.rs" - [dependencies] -sqlx = { version = "0.8.1", default-features = false, features = ["runtime-tokio", "tls-rustls-ring","derive", "postgres", "json", "chrono", "uuid", "sqlx-macros", "migrate"] } -thiserror = "1" -tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread"] } +bdk_wallet = { version = "1.0.0-beta.5" } serde = { version = "1.0.208", features = ["derive"] } serde_json = "1.0.125" -better-panic = "0.3.0" -rustls = "0.23.12" +sqlx = { version = "0.8.1", default-features = false, features = ["runtime-tokio", "tls-rustls-ring","derive", "postgres", "sqlite", "json", "chrono", "uuid", "sqlx-macros", "migrate"] } +thiserror = "1" +tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter", "serde_json", "json"] } -anyhow = "1.0.86" -rand = "0.8.5" -uuid = "1.10.0" -assert_matches = "1.5.0" -pg-embed = { version = "0.7.1", features = ["default"] } +[dev-dependencies] +assert_matches = "1.5.0" +anyhow = "1.0.89" +bdk_electrum = { version = "0.19.0"} +rustls = "0.23.14" -bdk_wallet = { git = "https://github.com/bitcoindevkit/bdk", tag = "v1.0.0-beta.2", features = ["std"], default-features = false } -bdk_chain = { git = "https://github.com/bitcoindevkit/bdk", tag = "v1.0.0-beta.2" } -bdk_electrum = { git = "https://github.com/bitcoindevkit/bdk", tag = "v1.0.0-beta.2" } -bdk_testenv = { git = "https://github.com/bitcoindevkit/bdk", tag = "v1.0.0-beta.2" } +[[example]] +name = "bdk_sqlx_postgres" +path = "examples/bdk_sqlx_postgres.rs" \ No newline at end of file diff --git a/README.md b/README.md index b9da0dc..6ce4bdb 100644 --- a/README.md +++ b/README.md @@ -1 +1,43 @@ # bdk-sqlx + +## Status + +This crate is still **EXPERIMENTAL** do not use with mainnet wallets. + +## Testing + +1. Install postgresql with `psql` tool. For example (macos): + ``` + brew update + brew install postgresql + ``` +2. Create empty test database: + ``` + psql postgres + postgres=# create database test_bdk_wallet; + ``` +3. Set DATABASE_URL to test database: + ``` + export DATABASE_TEST_URL=postgresql://localhost/test_bdk_wallet + ``` +4. Run tests, must use a single test thread since we reuse the postgres db: + ``` + cargo test -- --test-threads=1 + ``` + +## Example + +1. Create empty test database: + ``` + psql postgres + postgres=# create database example_bdk_wallet; + postgres=# \q + ``` +2. Set DATABASE_URL to test database: + ``` + export DATABASE_URL=postgresql://localhost/example_bdk_wallet + ``` +3. Run example: + ``` + cargo run --example bdk_sqlx_postgres + ``` \ No newline at end of file diff --git a/src/main.rs b/examples/bdk_sqlx_postgres.rs similarity index 94% rename from src/main.rs rename to examples/bdk_sqlx_postgres.rs index 6288b90..084b2aa 100644 --- a/src/main.rs +++ b/examples/bdk_sqlx_postgres.rs @@ -3,11 +3,11 @@ use std::collections::HashSet; use std::io::Write; use bdk_electrum::{electrum_client, BdkElectrumClient}; +use bdk_sqlx::sqlx::Postgres; use bdk_sqlx::Store; use bdk_wallet::bitcoin::secp256k1::Secp256k1; use bdk_wallet::bitcoin::Network; use bdk_wallet::{KeychainKind, PersistedWallet, Wallet}; -use better_panic::Settings; use rustls::crypto::ring::default_provider; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; @@ -34,15 +34,12 @@ async fn main() -> anyhow::Result<()> { default_provider() .install_default() .expect("Failed to install rustls default crypto provider"); - Settings::debug() - .most_recent_first(false) - .lineno_suffix(true) - .install(); + tracing_subscriber::registry() .with(EnvFilter::new(std::env::var("RUST_LOG").unwrap_or_else( |_| { "sqlx=warn,\ - bdk_sqlx=info" + bdk_sqlx=debug" .into() }, ))) @@ -59,7 +56,8 @@ async fn main() -> anyhow::Result<()> { NETWORK, &secp, )?; - let mut store = bdk_sqlx::Store::new_with_url(url.clone(), Some(wallet_name)).await?; + let mut store = + bdk_sqlx::Store::::new_with_url(url.clone(), wallet_name, true).await?; let mut wallet = match Wallet::load().load_wallet_async(&mut store).await? { Some(wallet) => wallet, @@ -86,7 +84,8 @@ async fn main() -> anyhow::Result<()> { let wallet_name = bdk_wallet::wallet_name_from_descriptor(VAULT_DESC, Some(CHANGE_DESC), NETWORK, &secp)?; - let mut store = bdk_sqlx::Store::new_with_url(url.clone(), Some(wallet_name)).await?; + let mut store = + bdk_sqlx::Store::::new_with_url(url.clone(), wallet_name, true).await?; let mut wallet = match Wallet::load().load_wallet_async(&mut store).await? { Some(wallet) => wallet, @@ -114,10 +113,10 @@ async fn main() -> anyhow::Result<()> { Ok(()) } -fn electrum(wallet: &mut PersistedWallet) -> anyhow::Result<()> { +fn electrum(wallet: &mut PersistedWallet>) -> anyhow::Result<()> { let client = BdkElectrumClient::new(electrum_client::Client::new(ELECTRUM_URL)?); - // Populate the electrum client's transaction cache so it doesn't redownload transaction we + // Populate the electrum client's transaction cache so it doesn't re-download transaction we // already have. client.populate_tx_cache(wallet.tx_graph().full_txs().map(|tx_node| tx_node.tx)); diff --git a/db/migrations/01_bdk_wallet.sql b/migrations/postgres/01_bdk_wallet.sql similarity index 94% rename from db/migrations/01_bdk_wallet.sql rename to migrations/postgres/01_bdk_wallet.sql index 2637800..d7d806a 100644 --- a/db/migrations/01_bdk_wallet.sql +++ b/migrations/postgres/01_bdk_wallet.sql @@ -26,7 +26,7 @@ CREATE TABLE IF NOT EXISTS keychain ( -- Hash is block hash hex string, -- Block height is a u32 CREATE TABLE IF NOT EXISTS block ( - wallet_name TEXT, + wallet_name TEXT NOT NULL, hash TEXT NOT NULL, height INTEGER NOT NULL, PRIMARY KEY (wallet_name, hash) @@ -37,7 +37,7 @@ CREATE INDEX idx_block_height ON block (height); -- Whole_tx is a consensus encoded transaction, -- Last seen is a u64 unix epoch seconds CREATE TABLE IF NOT EXISTS tx ( - wallet_name TEXT, + wallet_name TEXT NOT NULL, txid TEXT NOT NULL, whole_tx BYTEA, last_seen BIGINT, @@ -49,7 +49,7 @@ CREATE TABLE IF NOT EXISTS tx ( -- TxOut value as SATs -- TxOut script consensus encoded CREATE TABLE IF NOT EXISTS txout ( - wallet_name TEXT, + wallet_name TEXT NOT NULL, txid TEXT NOT NULL, vout INTEGER NOT NULL, value BIGINT NOT NULL, @@ -62,7 +62,7 @@ CREATE TABLE IF NOT EXISTS txout ( -- Anchor is a json serialized Anchor structure as JSONB, -- Txid is transaction hash hex string (reversed) CREATE TABLE IF NOT EXISTS anchor_tx ( - wallet_name TEXT, + wallet_name TEXT NOT NULL, block_hash TEXT NOT NULL, anchor JSONB NOT NULL, txid TEXT NOT NULL, diff --git a/migrations/sqlite/01_bdk_wallet.sql b/migrations/sqlite/01_bdk_wallet.sql new file mode 100644 index 0000000..f6cd973 --- /dev/null +++ b/migrations/sqlite/01_bdk_wallet.sql @@ -0,0 +1,73 @@ +-- Schema version control +CREATE TABLE IF NOT EXISTS version ( + version INTEGER PRIMARY KEY +); + +-- Network is the valid network for all other table data +CREATE TABLE IF NOT EXISTS network ( + wallet_name TEXT PRIMARY KEY, + name TEXT NOT NULL +); + +-- Keychain is the json serialized keychain structure as JSONB, +-- descriptor is the complete descriptor string, +-- descriptor_id is a sha256::Hash id of the descriptor string w/o the checksum, +-- last revealed index is a u32 +CREATE TABLE IF NOT EXISTS keychain ( + wallet_name TEXT NOT NULL, + keychainkind TEXT NOT NULL, + descriptor TEXT NOT NULL, + descriptor_id BLOB NOT NULL, + last_revealed INTEGER DEFAULT 0, + PRIMARY KEY (wallet_name, keychainkind) + +); + +-- Hash is block hash hex string, +-- Block height is a u32 +CREATE TABLE IF NOT EXISTS block ( + wallet_name TEXT NOT NULL, + hash TEXT NOT NULL, + height INTEGER NOT NULL, + PRIMARY KEY (wallet_name, hash) +); +CREATE INDEX idx_block_height ON block (height); + +-- Txid is transaction hash hex string (reversed) +-- Whole_tx is a consensus encoded transaction, +-- Last seen is a u64 unix epoch seconds +CREATE TABLE IF NOT EXISTS tx ( + wallet_name TEXT NOT NULL, + txid TEXT NOT NULL, + whole_tx BLOB, + last_seen INTEGER, + PRIMARY KEY (wallet_name, txid) +); + +-- Outpoint txid hash hex string (reversed) +-- Outpoint vout +-- TxOut value as SATs +-- TxOut script consensus encoded +CREATE TABLE IF NOT EXISTS txout ( + wallet_name TEXT NOT NULL, + txid TEXT NOT NULL, + vout INTEGER NOT NULL, + value INTEGER NOT NULL, + script BLOB NOT NULL, + PRIMARY KEY (wallet_name, txid, vout) +); + +-- Join table between anchor and tx +-- Block hash hex string +-- Anchor is a json serialized Anchor structure as JSONB, +-- Txid is transaction hash hex string (reversed) +CREATE TABLE IF NOT EXISTS anchor_tx ( + wallet_name TEXT NOT NULL, + block_hash TEXT NOT NULL, + anchor BLOB NOT NULL, + txid TEXT NOT NULL, + PRIMARY KEY (wallet_name, block_hash, txid), + FOREIGN KEY (wallet_name, block_hash) REFERENCES block(wallet_name, hash), + FOREIGN KEY (wallet_name, txid) REFERENCES tx(wallet_name, txid) +); +CREATE INDEX idx_anchor_tx_txid ON anchor_tx (txid); diff --git a/src/lib.rs b/src/lib.rs index 72ce009..ecf8ac6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,38 +2,21 @@ #![warn(missing_docs)] -use std::future::Future; -use std::pin::Pin; -use std::str::FromStr; -use std::sync::Arc; - -use bdk_chain::{ - local_chain, miniscript, tx_graph, Anchor, ConfirmationBlockTime, DescriptorExt, DescriptorId, - Merge, -}; -use bdk_wallet::bitcoin::{ - self, - consensus::{self, Decodable}, - hashes::Hash, - Amount, BlockHash, Network, OutPoint, ScriptBuf, TxOut, Txid, -}; -use bdk_wallet::chain as bdk_chain; -use bdk_wallet::descriptor::{Descriptor, DescriptorPublicKey, ExtendedDescriptor}; -use bdk_wallet::KeychainKind::{External, Internal}; -use bdk_wallet::{AsyncWalletPersister, ChangeSet, KeychainKind}; -use serde_json::json; -use sqlx::migrate::Migrator; -use sqlx::postgres::PgRow; -use sqlx::{ - postgres::{PgPool, Postgres}, - FromRow, Pool, Row, Transaction, -}; -use tokio::sync::Mutex; -use tracing::info; +mod postgres; +mod sqlite; #[cfg(test)] mod test; +use std::future::Future; +use std::pin::Pin; + +use bdk_wallet::bitcoin; +use bdk_wallet::chain::miniscript; +pub use sqlx; +use sqlx::Database; +use sqlx::Pool; + /// Crate error #[derive(Debug, thiserror::Error)] pub enum BdkSqlxError { @@ -49,523 +32,16 @@ pub enum BdkSqlxError { /// sqlx error #[error("sqlx error: {0}")] Sqlx(#[from] sqlx::Error), + /// migrate error + #[error("migrate error: {0}")] + Migrate(#[from] sqlx::migrate::MigrateError), } /// Manages a pool of database connections. -#[derive(Debug)] -pub struct Store { - pub(crate) pool: Arc>>, +#[derive(Debug, Clone)] +pub struct Store { + pub(crate) pool: Pool, wallet_name: String, - migration: bool, -} - -impl Store { - /// Construct a new [`Store`] with an existing pg connection. - #[tracing::instrument] - pub async fn new( - pool: Arc>>, - wallet_name: Option, - migration: bool, - ) -> Result { - info!("new store"); - - let wallet_name = wallet_name.unwrap_or_else(|| "bdk_pg_wallet".to_string()); - - Ok(Self { - pool, - wallet_name, - migration, - }) - } - - /// Construct a new [`Store`] without an existing pg connection. - #[tracing::instrument] - pub async fn new_with_url( - url: String, - wallet_name: Option, - ) -> Result { - info!("new store with url"); - - let pool = PgPool::connect(url.as_str()).await?; - let pool = Arc::new(Mutex::new(pool)); - let wallet_name = wallet_name.unwrap_or_else(|| "bdk_pg_wallet".to_string()); - - Ok(Self { - pool, - wallet_name, - migration: true, - }) - } -} - -impl AsyncWalletPersister for Store { - type Error = BdkSqlxError; - - #[tracing::instrument] - fn initialize<'a>(store: &'a mut Self) -> FutureResult<'a, ChangeSet, Self::Error> - where - Self: 'a, - { - info!("initialize store"); - Box::pin(store.migrate_and_read()) - } - - #[tracing::instrument] - fn persist<'a>( - store: &'a mut Self, - changeset: &'a ChangeSet, - ) -> FutureResult<'a, (), Self::Error> - where - Self: 'a, - { - info!("persist store"); - Box::pin(store.write(changeset)) - } } type FutureResult<'a, T, E> = Pin> + Send + 'a>>; - -impl Store { - #[tracing::instrument] - async fn migrate_and_read(&self) -> Result { - info!("migrate and read"); - let pool = self.pool.lock().await; - if self.migration { - let migrator = Migrator::new(std::path::Path::new("./db/migrations/")) - .await - .unwrap(); - migrator.run(&*pool).await.unwrap(); - } - - let mut tx = pool.begin().await?; - - let mut changeset = ChangeSet::default(); - - let sql = - "SELECT n.name as network, - k_int.descriptor as internal_descriptor, k_int.last_revealed as internal_last_revealed, - k_ext.descriptor as external_descriptor, k_ext.last_revealed as external_last_revealed - FROM network n - LEFT JOIN keychain k_int ON n.wallet_name = k_int.wallet_name AND k_int.keychainkind = 'Internal' - LEFT JOIN keychain k_ext ON n.wallet_name = k_ext.wallet_name AND k_ext.keychainkind = 'External' - WHERE n.wallet_name = $1"; - - // Fetch wallet data - let row = sqlx::query(sql) - .bind(&self.wallet_name) - .fetch_optional(&mut *tx) - .await?; - - dbg!(&row); - - if let Some(row) = row { - Self::changeset_from_row(&mut tx, &mut changeset, row, &self.wallet_name).await?; - } - - Ok(changeset) - } - - #[tracing::instrument] - async fn changeset_from_row( - tx: &mut Transaction<'_, Postgres>, - changeset: &mut ChangeSet, - row: PgRow, - wallet_name: &str, - ) -> Result<(), BdkSqlxError> { - info!("changeset from row"); - - let network: String = row.get("network"); - let internal_last_revealed: Option = row.get("internal_last_revealed"); - let external_last_revealed: Option = row.get("external_last_revealed"); - let internal_desc_str: Option = row.get("internal_descriptor"); - let external_desc_str: Option = row.get("external_descriptor"); - - changeset.network = Some(Network::from_str(&network).expect("parse Network")); - - if let Some(desc_str) = external_desc_str { - let descriptor: Descriptor = desc_str.parse()?; - let did = descriptor.descriptor_id(); - changeset.descriptor = Some(descriptor); - if let Some(last_rev) = external_last_revealed { - changeset.indexer.last_revealed.insert(did, last_rev as u32); - } - } - - if let Some(desc_str) = internal_desc_str { - let descriptor: Descriptor = desc_str.parse()?; - let did = descriptor.descriptor_id(); - changeset.change_descriptor = Some(descriptor); - if let Some(last_rev) = internal_last_revealed { - changeset.indexer.last_revealed.insert(did, last_rev as u32); - } - } - - changeset.tx_graph = tx_graph_changeset_from_postgres(tx, wallet_name).await?; - changeset.local_chain = local_chain_changeset_from_postgres(tx, wallet_name).await?; - Ok(()) - } - - #[tracing::instrument] - async fn write(&self, changeset: &ChangeSet) -> Result<(), BdkSqlxError> { - info!("changeset write"); - if changeset.is_empty() { - return Ok(()); - } - - let wallet_name = &self.wallet_name; - let pool = self.pool.lock().await; - let mut tx = pool.begin().await?; - - if let Some(ref descriptor) = changeset.descriptor { - insert_descriptor(&mut tx, wallet_name, descriptor, External).await?; - } - - if let Some(ref change_descriptor) = changeset.change_descriptor { - insert_descriptor(&mut tx, wallet_name, change_descriptor, Internal).await?; - } - - if let Some(network) = changeset.network { - insert_network(&mut tx, wallet_name, network).await?; - } - - let last_revealed_indices = &changeset.indexer.last_revealed; - if !last_revealed_indices.is_empty() { - for (desc_id, index) in last_revealed_indices { - update_last_revealed(&mut tx, wallet_name, *desc_id, *index).await?; - } - } - - local_chain_changeset_persist_to_postgres(&mut tx, wallet_name, &changeset.local_chain) - .await?; - tx_graph_changeset_persist_to_postgres(&mut tx, wallet_name, &changeset.tx_graph).await?; - - tx.commit().await?; - - Ok(()) - } -} - -/// Insert keychain descriptors. -#[tracing::instrument] -async fn insert_descriptor( - tx: &mut Transaction<'_, Postgres>, - wallet_name: &str, - descriptor: &ExtendedDescriptor, - keychain: KeychainKind, -) -> Result<(), BdkSqlxError> { - info!("insert descriptor"); - let descriptor_str = descriptor.to_string(); - - let descriptor_id = descriptor.descriptor_id().to_byte_array(); - let keychain = match keychain { - External => "External", - Internal => "Internal", - }; - - sqlx::query( - "INSERT INTO keychain (wallet_name, keychainkind, descriptor, descriptor_id) VALUES ($1, $2, $3, $4)", - ) - .bind(wallet_name) - .bind(keychain) - .bind(descriptor_str) - .bind(descriptor_id.as_slice()) - .execute(&mut **tx) - .await?; - - Ok(()) -} - -/// Insert network. -#[tracing::instrument] -async fn insert_network( - tx: &mut Transaction<'_, Postgres>, - wallet_name: &str, - network: Network, -) -> Result<(), BdkSqlxError> { - info!("insert network"); - sqlx::query("INSERT INTO network (wallet_name, name) VALUES ($1, $2)") - .bind(wallet_name) - .bind(network.to_string()) - .execute(&mut **tx) - .await?; - - Ok(()) -} - -/// Update keychain last revealed -#[tracing::instrument] -async fn update_last_revealed( - tx: &mut Transaction<'_, Postgres>, - wallet_name: &str, - descriptor_id: DescriptorId, - last_revealed: u32, -) -> Result<(), BdkSqlxError> { - info!("update last revealed"); - - sqlx::query( - "UPDATE keychain SET last_revealed = $1 WHERE wallet_name = $2 AND descriptor_id = $3", - ) - .bind(last_revealed as i32) - .bind(wallet_name) - .bind(descriptor_id.to_byte_array()) - .execute(&mut **tx) - .await?; - - Ok(()) -} - -/// Select transactions, txouts, and anchors. -#[tracing::instrument] -pub async fn tx_graph_changeset_from_postgres( - db_tx: &mut Transaction<'_, Postgres>, - wallet_name: &str, -) -> Result, BdkSqlxError> { - info!("tx graph changeset from postgres"); - let mut changeset = tx_graph::ChangeSet::default(); - - // Fetch transactions - let rows = sqlx::query("SELECT txid, whole_tx, last_seen FROM tx WHERE wallet_name = $1") - .bind(wallet_name) - .fetch_all(&mut **db_tx) - .await?; - - for row in rows { - let txid: String = row.get("txid"); - let txid = Txid::from_str(&txid)?; - let whole_tx: Option> = row.get("whole_tx"); - let last_seen: Option = row.get("last_seen"); - - if let Some(tx_bytes) = whole_tx { - if let Ok(tx) = bitcoin::Transaction::consensus_decode(&mut tx_bytes.as_slice()) { - changeset.txs.insert(Arc::new(tx)); - } - } - if let Some(last_seen) = last_seen { - changeset.last_seen.insert(txid, last_seen as u64); - } - } - - // Fetch txouts - let rows = sqlx::query("SELECT txid, vout, value, script FROM txout WHERE wallet_name = $1") - .bind(wallet_name) - .fetch_all(&mut **db_tx) - .await?; - - for row in rows { - let txid: String = row.get("txid"); - let txid = Txid::from_str(&txid)?; - let vout: i32 = row.get("vout"); - let value: i64 = row.get("value"); - let script: Vec = row.get("script"); - - changeset.txouts.insert( - OutPoint { - txid, - vout: vout as u32, - }, - TxOut { - value: Amount::from_sat(value as u64), - script_pubkey: ScriptBuf::from(script), - }, - ); - } - - // Fetch anchors - let rows = sqlx::query("SELECT anchor, txid FROM anchor_tx WHERE wallet_name = $1") - .bind(wallet_name) - .fetch_all(&mut **db_tx) - .await?; - - for row in rows { - let anchor: serde_json::Value = row.get("anchor"); - let txid: String = row.get("txid"); - let txid = Txid::from_str(&txid)?; - - if let Ok(anchor) = serde_json::from_value::(anchor) { - changeset.anchors.insert((anchor, txid)); - } - } - - Ok(changeset) -} - -/// Insert transactions, txouts, and anchors. -#[tracing::instrument] -pub async fn tx_graph_changeset_persist_to_postgres( - db_tx: &mut Transaction<'_, Postgres>, - wallet_name: &str, - changeset: &tx_graph::ChangeSet, -) -> Result<(), BdkSqlxError> { - info!("tx graph changeset from postgres"); - for tx in &changeset.txs { - sqlx::query( - "INSERT INTO tx (wallet_name, txid, whole_tx) VALUES ($1, $2, $3) - ON CONFLICT (wallet_name, txid) DO UPDATE SET whole_tx = $3", - ) - .bind(wallet_name) - .bind(tx.compute_txid().to_string()) - .bind(consensus::serialize(tx.as_ref())) - .execute(&mut **db_tx) - .await?; - } - - for (&txid, &last_seen) in &changeset.last_seen { - sqlx::query("UPDATE tx SET last_seen = $1 WHERE wallet_name = $2 AND txid = $3") - .bind(last_seen as i64) - .bind(wallet_name) - .bind(txid.to_string()) - .execute(&mut **db_tx) - .await?; - } - - for (op, txo) in &changeset.txouts { - sqlx::query( - "INSERT INTO txout (wallet_name, txid, vout, value, script) VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (wallet_name, txid, vout) DO UPDATE SET value = $4, script = $5", - ) - .bind(wallet_name) - .bind(op.txid.to_string()) - .bind(op.vout as i32) - .bind(txo.value.to_sat() as i64) - .bind(txo.script_pubkey.as_bytes()) - .execute(&mut **db_tx) - .await?; - } - - for (anchor, txid) in &changeset.anchors { - let block_hash = anchor.anchor_block().hash; - let anchor = serde_json::to_value(anchor)?; - sqlx::query( - "INSERT INTO anchor_tx (wallet_name, block_hash, anchor, txid) VALUES ($1, $2, $3, $4) - ON CONFLICT (wallet_name, block_hash, txid) DO UPDATE SET anchor = $3", - ) - .bind(wallet_name) - .bind(block_hash.to_string()) - .bind(anchor) - .bind(txid.to_string()) - .execute(&mut **db_tx) - .await?; - } - - Ok(()) -} - -/// Select blocks. -#[tracing::instrument] -pub async fn local_chain_changeset_from_postgres( - db_tx: &mut Transaction<'_, Postgres>, - wallet_name: &str, -) -> Result { - info!("local chain changeset from postgres"); - let mut changeset = local_chain::ChangeSet::default(); - - let rows = sqlx::query("SELECT hash, height FROM block WHERE wallet_name = $1") - .bind(wallet_name) - .fetch_all(&mut **db_tx) - .await?; - - for row in rows { - let hash: String = row.get("hash"); - let height: i32 = row.get("height"); - let block_hash = BlockHash::from_str(&hash)?; - changeset.blocks.insert(height as u32, Some(block_hash)); - } - - Ok(changeset) -} - -/// Insert blocks. -#[tracing::instrument] -pub async fn local_chain_changeset_persist_to_postgres( - db_tx: &mut Transaction<'_, Postgres>, - wallet_name: &str, - changeset: &local_chain::ChangeSet, -) -> Result<(), BdkSqlxError> { - info!("local chain changeset to postgres"); - for (&height, &hash) in &changeset.blocks { - match hash { - Some(hash) => { - sqlx::query( - "INSERT INTO block (wallet_name, hash, height) VALUES ($1, $2, $3) - ON CONFLICT (wallet_name, hash) DO UPDATE SET height = $3", - ) - .bind(wallet_name) - .bind(hash.to_string()) - .bind(height as i32) - .execute(&mut **db_tx) - .await?; - } - None => { - sqlx::query("DELETE FROM block WHERE wallet_name = $1 AND height = $2") - .bind(wallet_name) - .bind(height as i32) - .execute(&mut **db_tx) - .await?; - } - } - } - - Ok(()) -} - -/// Drops all tables. -#[tracing::instrument] -pub async fn drop_all(db: Pool) -> Result<(), BdkSqlxError> { - info!("Dropping all tables"); - - let drop_statements = vec![ - "DROP TABLE IF EXISTS _sqlx_migrations", - "DROP TABLE IF EXISTS vault_addresses", - "DROP TABLE IF EXISTS used_anchorwatch_keys", - "DROP TABLE IF EXISTS anchorwatch_keys", - "DROP TABLE IF EXISTS psbts", - "DROP TABLE IF EXISTS whitelist_update", - "DROP TABLE IF EXISTS vault_parameters", - "DROP TABLE IF EXISTS users", - "DROP TABLE IF EXISTS version", - "DROP TABLE IF EXISTS anchor_tx", - "DROP TABLE IF EXISTS txout", - "DROP TABLE IF EXISTS tx", - "DROP TABLE IF EXISTS block", - "DROP TABLE IF EXISTS keychain", - "DROP TABLE IF EXISTS network", - ]; - - let mut tx = db.begin().await?; - - for statement in drop_statements { - sqlx::query(statement).execute(&mut *tx).await?; - } - - tx.commit().await?; - - Ok(()) -} - -/// Represents a row in the keychain table. -#[derive(serde::Serialize, FromRow)] -struct KeychainEntry { - wallet_name: String, - keychainkind: String, - descriptor: String, - descriptor_id: Vec, - last_revealed: i32, -} - -/// Collects information on all the wallets in the database and dumps it to stdout. -#[tracing::instrument] -pub async fn easy_backup(db: Pool) -> Result<(), BdkSqlxError> { - info!("Starting easy backup"); - - let statement = "SELECT * FROM keychain"; - - let results = sqlx::query_as::<_, KeychainEntry>(statement) - .fetch_all(&db) - .await?; - - let json_array = json!(results); - println!("{}", serde_json::to_string_pretty(&json_array)?); - - info!("Easy backup completed successfully"); - Ok(()) -} diff --git a/src/postgres.rs b/src/postgres.rs new file mode 100644 index 0000000..906f1ea --- /dev/null +++ b/src/postgres.rs @@ -0,0 +1,480 @@ +//! bdk-sqlx postgres store + +#![warn(missing_docs)] + +use std::str::FromStr; +use std::sync::Arc; + +use super::{BdkSqlxError, FutureResult, Store}; +use bdk_chain::{ + local_chain, tx_graph, Anchor, ConfirmationBlockTime, DescriptorExt, DescriptorId, Merge, +}; +use bdk_wallet::bitcoin::{ + self, + consensus::{self, Decodable}, + hashes::Hash, + Amount, BlockHash, Network, OutPoint, ScriptBuf, TxOut, Txid, +}; +use bdk_wallet::chain as bdk_chain; +use bdk_wallet::descriptor::{Descriptor, DescriptorPublicKey, ExtendedDescriptor}; +use bdk_wallet::KeychainKind::{External, Internal}; +use bdk_wallet::{AsyncWalletPersister, ChangeSet, KeychainKind}; +use serde_json::json; +use sqlx::postgres::PgRow; +use sqlx::sqlx_macros::migrate; +use sqlx::{ + postgres::{PgPool, Postgres}, + FromRow, Pool, Row, Transaction, +}; +use tracing::info; + +impl AsyncWalletPersister for Store { + type Error = BdkSqlxError; + + #[tracing::instrument] + fn initialize<'a>(store: &'a mut Self) -> FutureResult<'a, ChangeSet, Self::Error> + where + Self: 'a, + { + info!("initialize store"); + Box::pin(store.read()) + } + + #[tracing::instrument] + fn persist<'a>( + store: &'a mut Self, + changeset: &'a ChangeSet, + ) -> FutureResult<'a, (), Self::Error> + where + Self: 'a, + { + info!("persist store"); + Box::pin(store.write(changeset)) + } +} + +impl Store { + /// Construct a new [`Store`] with an existing pg connection. + #[tracing::instrument] + pub async fn new( + pool: Pool, + wallet_name: String, + migrate: bool, + ) -> Result { + info!("new postgres store"); + if migrate { + info!("migrate"); + migrate!("./migrations/postgres").run(&pool).await?; + } + Ok(Self { pool, wallet_name }) + } + + /// Construct a new [`Store`] without an existing pg connection. + #[tracing::instrument] + pub async fn new_with_url( + url: String, + wallet_name: String, + migrate: bool, + ) -> Result, BdkSqlxError> { + info!("new store with url"); + let pool = PgPool::connect(url.as_str()).await?; + Self::new(pool, wallet_name, migrate).await + } +} + +impl Store { + #[tracing::instrument] + pub(crate) async fn read(&self) -> Result { + let mut tx = self.pool.begin().await?; + let mut changeset = ChangeSet::default(); + let sql = + "SELECT n.name as network, + k_int.descriptor as internal_descriptor, k_int.last_revealed as internal_last_revealed, + k_ext.descriptor as external_descriptor, k_ext.last_revealed as external_last_revealed + FROM network n + LEFT JOIN keychain k_int ON n.wallet_name = k_int.wallet_name AND k_int.keychainkind = 'Internal' + LEFT JOIN keychain k_ext ON n.wallet_name = k_ext.wallet_name AND k_ext.keychainkind = 'External' + WHERE n.wallet_name = $1"; + + // Fetch wallet data + let row = sqlx::query(sql) + .bind(&self.wallet_name) + .fetch_optional(&mut *tx) + .await?; + + if let Some(row) = row { + Self::changeset_from_row(&mut tx, &mut changeset, row, &self.wallet_name).await?; + } + + Ok(changeset) + } + + #[tracing::instrument] + pub(crate) async fn changeset_from_row( + tx: &mut Transaction<'_, Postgres>, + changeset: &mut ChangeSet, + row: PgRow, + wallet_name: &str, + ) -> Result<(), BdkSqlxError> { + info!("changeset from row"); + + let network: String = row.get("network"); + let internal_last_revealed: Option = row.get("internal_last_revealed"); + let external_last_revealed: Option = row.get("external_last_revealed"); + let internal_desc_str: Option = row.get("internal_descriptor"); + let external_desc_str: Option = row.get("external_descriptor"); + + changeset.network = Some(Network::from_str(&network).expect("parse Network")); + + if let Some(desc_str) = external_desc_str { + let descriptor: Descriptor = desc_str.parse()?; + let did = descriptor.descriptor_id(); + changeset.descriptor = Some(descriptor); + if let Some(last_rev) = external_last_revealed { + changeset.indexer.last_revealed.insert(did, last_rev as u32); + } + } + + if let Some(desc_str) = internal_desc_str { + let descriptor: Descriptor = desc_str.parse()?; + let did = descriptor.descriptor_id(); + changeset.change_descriptor = Some(descriptor); + if let Some(last_rev) = internal_last_revealed { + changeset.indexer.last_revealed.insert(did, last_rev as u32); + } + } + + changeset.tx_graph = tx_graph_changeset_from_postgres(tx, wallet_name).await?; + changeset.local_chain = local_chain_changeset_from_postgres(tx, wallet_name).await?; + Ok(()) + } + + #[tracing::instrument] + pub(crate) async fn write(&self, changeset: &ChangeSet) -> Result<(), BdkSqlxError> { + info!("changeset write"); + if changeset.is_empty() { + return Ok(()); + } + + let wallet_name = &self.wallet_name; + let mut tx = self.pool.begin().await?; + + if let Some(ref descriptor) = changeset.descriptor { + insert_descriptor(&mut tx, wallet_name, descriptor, External).await?; + } + + if let Some(ref change_descriptor) = changeset.change_descriptor { + insert_descriptor(&mut tx, wallet_name, change_descriptor, Internal).await?; + } + + if let Some(network) = changeset.network { + insert_network(&mut tx, wallet_name, network).await?; + } + + let last_revealed_indices = &changeset.indexer.last_revealed; + if !last_revealed_indices.is_empty() { + for (desc_id, index) in last_revealed_indices { + update_last_revealed(&mut tx, wallet_name, *desc_id, *index).await?; + } + } + + local_chain_changeset_persist_to_postgres(&mut tx, wallet_name, &changeset.local_chain) + .await?; + tx_graph_changeset_persist_to_postgres(&mut tx, wallet_name, &changeset.tx_graph).await?; + + tx.commit().await?; + + Ok(()) + } +} + +/// Insert keychain descriptors. +#[tracing::instrument] +async fn insert_descriptor( + tx: &mut Transaction<'_, Postgres>, + wallet_name: &str, + descriptor: &ExtendedDescriptor, + keychain: KeychainKind, +) -> Result<(), BdkSqlxError> { + info!("insert descriptor"); + let descriptor_str = descriptor.to_string(); + + let descriptor_id = descriptor.descriptor_id().to_byte_array(); + let keychain = match keychain { + External => "External", + Internal => "Internal", + }; + + sqlx::query( + "INSERT INTO keychain (wallet_name, keychainkind, descriptor, descriptor_id) VALUES ($1, $2, $3, $4)", + ) + .bind(wallet_name) + .bind(keychain) + .bind(descriptor_str) + .bind(descriptor_id.as_slice()) + .execute(&mut **tx) + .await?; + + Ok(()) +} + +/// Insert network. +#[tracing::instrument] +async fn insert_network( + tx: &mut Transaction<'_, Postgres>, + wallet_name: &str, + network: Network, +) -> Result<(), BdkSqlxError> { + info!("insert network"); + sqlx::query("INSERT INTO network (wallet_name, name) VALUES ($1, $2)") + .bind(wallet_name) + .bind(network.to_string()) + .execute(&mut **tx) + .await?; + + Ok(()) +} + +/// Update keychain last revealed +#[tracing::instrument] +async fn update_last_revealed( + tx: &mut Transaction<'_, Postgres>, + wallet_name: &str, + descriptor_id: DescriptorId, + last_revealed: u32, +) -> Result<(), BdkSqlxError> { + info!("update last revealed"); + + sqlx::query( + "UPDATE keychain SET last_revealed = $1 WHERE wallet_name = $2 AND descriptor_id = $3", + ) + .bind(last_revealed as i32) + .bind(wallet_name) + .bind(descriptor_id.to_byte_array()) + .execute(&mut **tx) + .await?; + + Ok(()) +} + +/// Select transactions, txouts, and anchors. +#[tracing::instrument] +pub async fn tx_graph_changeset_from_postgres( + db_tx: &mut Transaction<'_, Postgres>, + wallet_name: &str, +) -> Result, BdkSqlxError> { + info!("tx graph changeset from postgres"); + let mut changeset = tx_graph::ChangeSet::default(); + + // Fetch transactions + let rows = sqlx::query("SELECT txid, whole_tx, last_seen FROM tx WHERE wallet_name = $1") + .bind(wallet_name) + .fetch_all(&mut **db_tx) + .await?; + + for row in rows { + let txid: String = row.get("txid"); + let txid = Txid::from_str(&txid)?; + let whole_tx: Option> = row.get("whole_tx"); + let last_seen: Option = row.get("last_seen"); + + if let Some(tx_bytes) = whole_tx { + if let Ok(tx) = bitcoin::Transaction::consensus_decode(&mut tx_bytes.as_slice()) { + changeset.txs.insert(Arc::new(tx)); + } + } + if let Some(last_seen) = last_seen { + changeset.last_seen.insert(txid, last_seen as u64); + } + } + + // Fetch txouts + let rows = sqlx::query("SELECT txid, vout, value, script FROM txout WHERE wallet_name = $1") + .bind(wallet_name) + .fetch_all(&mut **db_tx) + .await?; + + for row in rows { + let txid: String = row.get("txid"); + let txid = Txid::from_str(&txid)?; + let vout: i32 = row.get("vout"); + let value: i64 = row.get("value"); + let script: Vec = row.get("script"); + + changeset.txouts.insert( + OutPoint { + txid, + vout: vout as u32, + }, + TxOut { + value: Amount::from_sat(value as u64), + script_pubkey: ScriptBuf::from(script), + }, + ); + } + + // Fetch anchors + let rows = sqlx::query("SELECT anchor, txid FROM anchor_tx WHERE wallet_name = $1") + .bind(wallet_name) + .fetch_all(&mut **db_tx) + .await?; + + for row in rows { + let anchor: serde_json::Value = row.get("anchor"); + let txid: String = row.get("txid"); + let txid = Txid::from_str(&txid)?; + + if let Ok(anchor) = serde_json::from_value::(anchor) { + changeset.anchors.insert((anchor, txid)); + } + } + + Ok(changeset) +} + +/// Insert transactions, txouts, and anchors. +#[tracing::instrument] +pub async fn tx_graph_changeset_persist_to_postgres( + db_tx: &mut Transaction<'_, Postgres>, + wallet_name: &str, + changeset: &tx_graph::ChangeSet, +) -> Result<(), BdkSqlxError> { + info!("tx graph changeset from postgres"); + for tx in &changeset.txs { + sqlx::query( + "INSERT INTO tx (wallet_name, txid, whole_tx) VALUES ($1, $2, $3) + ON CONFLICT (wallet_name, txid) DO UPDATE SET whole_tx = $3", + ) + .bind(wallet_name) + .bind(tx.compute_txid().to_string()) + .bind(consensus::serialize(tx.as_ref())) + .execute(&mut **db_tx) + .await?; + } + + for (&txid, &last_seen) in &changeset.last_seen { + sqlx::query("UPDATE tx SET last_seen = $1 WHERE wallet_name = $2 AND txid = $3") + .bind(last_seen as i64) + .bind(wallet_name) + .bind(txid.to_string()) + .execute(&mut **db_tx) + .await?; + } + + for (op, txo) in &changeset.txouts { + sqlx::query( + "INSERT INTO txout (wallet_name, txid, vout, value, script) VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (wallet_name, txid, vout) DO UPDATE SET value = $4, script = $5", + ) + .bind(wallet_name) + .bind(op.txid.to_string()) + .bind(op.vout as i32) + .bind(txo.value.to_sat() as i64) + .bind(txo.script_pubkey.as_bytes()) + .execute(&mut **db_tx) + .await?; + } + + for (anchor, txid) in &changeset.anchors { + let block_hash = anchor.anchor_block().hash; + let anchor = serde_json::to_value(anchor)?; + sqlx::query( + "INSERT INTO anchor_tx (wallet_name, block_hash, anchor, txid) VALUES ($1, $2, $3, $4) + ON CONFLICT (wallet_name, block_hash, txid) DO UPDATE SET anchor = $3", + ) + .bind(wallet_name) + .bind(block_hash.to_string()) + .bind(anchor) + .bind(txid.to_string()) + .execute(&mut **db_tx) + .await?; + } + + Ok(()) +} + +/// Select blocks. +#[tracing::instrument] +pub async fn local_chain_changeset_from_postgres( + db_tx: &mut Transaction<'_, Postgres>, + wallet_name: &str, +) -> Result { + info!("local chain changeset from postgres"); + let mut changeset = local_chain::ChangeSet::default(); + + let rows = sqlx::query("SELECT hash, height FROM block WHERE wallet_name = $1") + .bind(wallet_name) + .fetch_all(&mut **db_tx) + .await?; + + for row in rows { + let hash: String = row.get("hash"); + let height: i32 = row.get("height"); + let block_hash = BlockHash::from_str(&hash)?; + changeset.blocks.insert(height as u32, Some(block_hash)); + } + + Ok(changeset) +} + +/// Insert blocks. +#[tracing::instrument] +pub async fn local_chain_changeset_persist_to_postgres( + db_tx: &mut Transaction<'_, Postgres>, + wallet_name: &str, + changeset: &local_chain::ChangeSet, +) -> Result<(), BdkSqlxError> { + info!("local chain changeset to postgres"); + for (&height, &hash) in &changeset.blocks { + match hash { + Some(hash) => { + sqlx::query( + "INSERT INTO block (wallet_name, hash, height) VALUES ($1, $2, $3) + ON CONFLICT (wallet_name, hash) DO UPDATE SET height = $3", + ) + .bind(wallet_name) + .bind(hash.to_string()) + .bind(height as i32) + .execute(&mut **db_tx) + .await?; + } + None => { + sqlx::query("DELETE FROM block WHERE wallet_name = $1 AND height = $2") + .bind(wallet_name) + .bind(height as i32) + .execute(&mut **db_tx) + .await?; + } + } + } + + Ok(()) +} + +/// Collects information on all the wallets in the database and dumps it to stdout. +#[tracing::instrument] +pub async fn easy_backup(db: Pool) -> Result<(), BdkSqlxError> { + info!("Starting easy backup"); + + let statement = "SELECT * FROM keychain"; + + let results = sqlx::query_as::<_, KeychainEntry>(statement) + .fetch_all(&db) + .await?; + + let json_array = json!(results); + println!("{}", serde_json::to_string_pretty(&json_array)?); + + info!("Easy backup completed successfully"); + Ok(()) +} + +/// Represents a row in the keychain table. +#[derive(serde::Serialize, FromRow)] +struct KeychainEntry { + wallet_name: String, + keychainkind: String, + descriptor: String, + descriptor_id: Vec, + last_revealed: i32, +} diff --git a/src/sqlite.rs b/src/sqlite.rs new file mode 100644 index 0000000..3b16a91 --- /dev/null +++ b/src/sqlite.rs @@ -0,0 +1,498 @@ +//! bdk-sqlx sqlite store + +#![warn(missing_docs)] + +use std::str::FromStr; +use std::sync::Arc; + +use super::{BdkSqlxError, FutureResult, Store}; +use bdk_chain::{ + local_chain, tx_graph, Anchor, ConfirmationBlockTime, DescriptorExt, DescriptorId, Merge, +}; +use bdk_wallet::bitcoin::{ + self, + consensus::{self, Decodable}, + hashes::Hash, + Amount, BlockHash, Network, OutPoint, ScriptBuf, TxOut, Txid, +}; +use bdk_wallet::chain as bdk_chain; +use bdk_wallet::descriptor::{Descriptor, DescriptorPublicKey, ExtendedDescriptor}; +use bdk_wallet::KeychainKind::{External, Internal}; +use bdk_wallet::{AsyncWalletPersister, ChangeSet, KeychainKind}; +use serde_json::json; +use sqlx::sqlite::SqliteRow; +use sqlx::sqlite::{SqlitePool, SqlitePoolOptions}; +use sqlx::sqlx_macros::migrate; +use sqlx::{sqlite::Sqlite, FromRow, Pool, Row, Transaction}; +use tracing::info; + +impl AsyncWalletPersister for Store { + type Error = BdkSqlxError; + + #[tracing::instrument] + fn initialize<'a>(store: &'a mut Self) -> FutureResult<'a, ChangeSet, Self::Error> + where + Self: 'a, + { + info!("initialize store"); + Box::pin(store.read()) + } + + #[tracing::instrument] + fn persist<'a>( + store: &'a mut Self, + changeset: &'a ChangeSet, + ) -> FutureResult<'a, (), Self::Error> + where + Self: 'a, + { + info!("persist store"); + Box::pin(store.write(changeset)) + } +} + +impl Store { + /// Construct a new [`Store`] with an existing sqlite connection pool. + #[tracing::instrument] + pub async fn new( + pool: Pool, + wallet_name: String, + migrate: bool, + ) -> Result { + info!("new sqlite store"); + if migrate { + info!("migrate"); + migrate!("./migrations/postgres").run(&pool).await?; + } + Ok(Self { pool, wallet_name }) + } + + /// Construct a new [`Store`] without an existing sqlite connection pool. + /// + /// The SQLite DB URL should look like "sqlite://bdk_wallet.sqlite?mode=rwc". + /// + /// If no URL is given a memory DB (non-persisted) will be used. A memory DB + /// is useful for testing. + #[tracing::instrument] + pub async fn new_with_url( + url: Option, + wallet_name: String, + migrate: bool, + ) -> Result, BdkSqlxError> { + info!("new store with url"); + let pool = if let Some(url) = url { + SqlitePool::connect(url.as_str()).await? + } else { + // must limit to one connection and no timeout if using memory DB + SqlitePoolOptions::new() + .max_connections(1) + .min_connections(1) + .idle_timeout(None) + .max_lifetime(None) + .connect(":memory:") + .await? + }; + Self::new(pool, wallet_name, migrate).await + } +} + +impl Store { + #[tracing::instrument] + pub(crate) async fn read(&self) -> Result { + info!("migrate and read"); + let mut tx = self.pool.begin().await?; + let mut changeset = ChangeSet::default(); + let sql = + "SELECT n.name as network, + k_int.descriptor as internal_descriptor, k_int.last_revealed as internal_last_revealed, + k_ext.descriptor as external_descriptor, k_ext.last_revealed as external_last_revealed + FROM network n + LEFT JOIN keychain k_int ON n.wallet_name = k_int.wallet_name AND k_int.keychainkind = 'Internal' + LEFT JOIN keychain k_ext ON n.wallet_name = k_ext.wallet_name AND k_ext.keychainkind = 'External' + WHERE n.wallet_name = $1"; + + // Fetch wallet data + let row = sqlx::query(sql) + .bind(&self.wallet_name) + .fetch_optional(&mut *tx) + .await?; + + //dbg!(&row); + + if let Some(row) = row { + Self::changeset_from_row(&mut tx, &mut changeset, row, &self.wallet_name).await?; + } + + Ok(changeset) + } + + //#[tracing::instrument] + pub(crate) async fn changeset_from_row( + tx: &mut Transaction<'_, Sqlite>, + changeset: &mut ChangeSet, + row: SqliteRow, + wallet_name: &str, + ) -> Result<(), BdkSqlxError> { + info!("changeset from row"); + + let network: String = row.get("network"); + let internal_last_revealed: Option = row.get("internal_last_revealed"); + let external_last_revealed: Option = row.get("external_last_revealed"); + let internal_desc_str: Option = row.get("internal_descriptor"); + let external_desc_str: Option = row.get("external_descriptor"); + + changeset.network = Some(Network::from_str(&network).expect("parse Network")); + + if let Some(desc_str) = external_desc_str { + let descriptor: Descriptor = desc_str.parse()?; + let did = descriptor.descriptor_id(); + changeset.descriptor = Some(descriptor); + if let Some(last_rev) = external_last_revealed { + changeset.indexer.last_revealed.insert(did, last_rev as u32); + } + } + + if let Some(desc_str) = internal_desc_str { + let descriptor: Descriptor = desc_str.parse()?; + let did = descriptor.descriptor_id(); + changeset.change_descriptor = Some(descriptor); + if let Some(last_rev) = internal_last_revealed { + changeset.indexer.last_revealed.insert(did, last_rev as u32); + } + } + + changeset.tx_graph = tx_graph_changeset_from_sqlite(tx, wallet_name).await?; + changeset.local_chain = local_chain_changeset_from_sqlite(tx, wallet_name).await?; + Ok(()) + } + + #[tracing::instrument] + pub(crate) async fn write(&self, changeset: &ChangeSet) -> Result<(), BdkSqlxError> { + info!("changeset write"); + if changeset.is_empty() { + return Ok(()); + } + + let wallet_name = &self.wallet_name; + let mut tx = self.pool.begin().await?; + + if let Some(ref descriptor) = changeset.descriptor { + insert_descriptor(&mut tx, wallet_name, descriptor, External).await?; + } + + if let Some(ref change_descriptor) = changeset.change_descriptor { + insert_descriptor(&mut tx, wallet_name, change_descriptor, Internal).await?; + } + + if let Some(network) = changeset.network { + insert_network(&mut tx, wallet_name, network).await?; + } + + let last_revealed_indices = &changeset.indexer.last_revealed; + if !last_revealed_indices.is_empty() { + for (desc_id, index) in last_revealed_indices { + update_last_revealed(&mut tx, wallet_name, *desc_id, *index).await?; + } + } + + local_chain_changeset_persist_to_sqlite(&mut tx, wallet_name, &changeset.local_chain) + .await?; + tx_graph_changeset_persist_to_sqlite(&mut tx, wallet_name, &changeset.tx_graph).await?; + + tx.commit().await?; + + Ok(()) + } +} + +/// Insert keychain descriptors. +#[tracing::instrument] +async fn insert_descriptor( + tx: &mut Transaction<'_, Sqlite>, + wallet_name: &str, + descriptor: &ExtendedDescriptor, + keychain: KeychainKind, +) -> Result<(), BdkSqlxError> { + info!("insert descriptor"); + let descriptor_str = descriptor.to_string(); + + let descriptor_id = descriptor.descriptor_id().to_byte_array(); + let keychain = match keychain { + External => "External", + Internal => "Internal", + }; + + sqlx::query( + "INSERT INTO keychain (wallet_name, keychainkind, descriptor, descriptor_id) VALUES ($1, $2, $3, $4)", + ) + .bind(wallet_name) + .bind(keychain) + .bind(descriptor_str) + .bind(descriptor_id.as_slice()) + .execute(&mut **tx) + .await?; + + Ok(()) +} + +/// Insert network. +#[tracing::instrument] +async fn insert_network( + tx: &mut Transaction<'_, Sqlite>, + wallet_name: &str, + network: Network, +) -> Result<(), BdkSqlxError> { + info!("insert network"); + sqlx::query("INSERT INTO network (wallet_name, name) VALUES ($1, $2)") + .bind(wallet_name) + .bind(network.to_string()) + .execute(&mut **tx) + .await?; + + Ok(()) +} + +/// Update keychain last revealed +#[tracing::instrument] +async fn update_last_revealed( + tx: &mut Transaction<'_, Sqlite>, + wallet_name: &str, + descriptor_id: DescriptorId, + last_revealed: u32, +) -> Result<(), BdkSqlxError> { + info!("update last revealed"); + + sqlx::query::( + "UPDATE keychain SET last_revealed = $1 WHERE wallet_name = $2 AND descriptor_id = $3", + ) + .bind(last_revealed as i32) + .bind(wallet_name) + .bind(descriptor_id.to_byte_array().as_slice()) + .execute(&mut **tx) + .await?; + + Ok(()) +} + +/// Select transactions, txouts, and anchors. +#[tracing::instrument] +pub async fn tx_graph_changeset_from_sqlite( + db_tx: &mut Transaction<'_, Sqlite>, + wallet_name: &str, +) -> Result, BdkSqlxError> { + info!("tx graph changeset from sqlite"); + let mut changeset = tx_graph::ChangeSet::default(); + + // Fetch transactions + let rows = sqlx::query("SELECT txid, whole_tx, last_seen FROM tx WHERE wallet_name = $1") + .bind(wallet_name) + .fetch_all(&mut **db_tx) + .await?; + + for row in rows { + let txid: String = row.get("txid"); + let txid = Txid::from_str(&txid)?; + let whole_tx: Option> = row.get("whole_tx"); + let last_seen: Option = row.get("last_seen"); + + if let Some(tx_bytes) = whole_tx { + if let Ok(tx) = bitcoin::Transaction::consensus_decode(&mut tx_bytes.as_slice()) { + changeset.txs.insert(Arc::new(tx)); + } + } + if let Some(last_seen) = last_seen { + changeset.last_seen.insert(txid, last_seen as u64); + } + } + + // Fetch txouts + let rows = sqlx::query("SELECT txid, vout, value, script FROM txout WHERE wallet_name = $1") + .bind(wallet_name) + .fetch_all(&mut **db_tx) + .await?; + + for row in rows { + let txid: String = row.get("txid"); + let txid = Txid::from_str(&txid)?; + let vout: i32 = row.get("vout"); + let value: i64 = row.get("value"); + let script: Vec = row.get("script"); + + changeset.txouts.insert( + OutPoint { + txid, + vout: vout as u32, + }, + TxOut { + value: Amount::from_sat(value as u64), + script_pubkey: ScriptBuf::from(script), + }, + ); + } + + // Fetch anchors + let rows = + sqlx::query("SELECT json(anchor) as anchor, txid FROM anchor_tx WHERE wallet_name = $1") + .bind(wallet_name) + .fetch_all(&mut **db_tx) + .await?; + + for row in rows { + let anchor: serde_json::Value = row.get("anchor"); + let txid: String = row.get("txid"); + let txid = Txid::from_str(&txid)?; + + if let Ok(anchor) = serde_json::from_value::(anchor) { + changeset.anchors.insert((anchor, txid)); + } + } + + Ok(changeset) +} + +/// Insert transactions, txouts, and anchors. +#[tracing::instrument] +pub async fn tx_graph_changeset_persist_to_sqlite( + db_tx: &mut Transaction<'_, Sqlite>, + wallet_name: &str, + changeset: &tx_graph::ChangeSet, +) -> Result<(), BdkSqlxError> { + info!("tx graph changeset from sqlite"); + for tx in &changeset.txs { + sqlx::query( + "INSERT INTO tx (wallet_name, txid, whole_tx) VALUES ($1, $2, $3) + ON CONFLICT (wallet_name, txid) DO UPDATE SET whole_tx = $3", + ) + .bind(wallet_name) + .bind(tx.compute_txid().to_string()) + .bind(consensus::serialize(tx.as_ref())) + .execute(&mut **db_tx) + .await?; + } + + for (&txid, &last_seen) in &changeset.last_seen { + sqlx::query("UPDATE tx SET last_seen = $1 WHERE wallet_name = $2 AND txid = $3") + .bind(last_seen as i64) + .bind(wallet_name) + .bind(txid.to_string()) + .execute(&mut **db_tx) + .await?; + } + + for (op, txo) in &changeset.txouts { + sqlx::query( + "INSERT INTO txout (wallet_name, txid, vout, value, script) VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (wallet_name, txid, vout) DO UPDATE SET value = $4, script = $5", + ) + .bind(wallet_name) + .bind(op.txid.to_string()) + .bind(op.vout as i32) + .bind(txo.value.to_sat() as i64) + .bind(txo.script_pubkey.as_bytes()) + .execute(&mut **db_tx) + .await?; + } + + for (anchor, txid) in &changeset.anchors { + let block_hash = anchor.anchor_block().hash; + let anchor = serde_json::to_value(anchor)?; + sqlx::query( + "INSERT INTO anchor_tx (wallet_name, block_hash, anchor, txid) VALUES ($1, $2, jsonb($3), $4) + ON CONFLICT (wallet_name, block_hash, txid) DO UPDATE SET anchor = jsonb($3)", + ) + .bind(wallet_name) + .bind(block_hash.to_string()) + .bind(anchor) + .bind(txid.to_string()) + .execute(&mut **db_tx) + .await?; + } + + Ok(()) +} + +/// Select blocks. +#[tracing::instrument] +pub async fn local_chain_changeset_from_sqlite( + db_tx: &mut Transaction<'_, Sqlite>, + wallet_name: &str, +) -> Result { + info!("local chain changeset from sqlite"); + let mut changeset = local_chain::ChangeSet::default(); + + let rows = sqlx::query("SELECT hash, height FROM block WHERE wallet_name = $1") + .bind(wallet_name) + .fetch_all(&mut **db_tx) + .await?; + + for row in rows { + let hash: String = row.get("hash"); + let height: i32 = row.get("height"); + let block_hash = BlockHash::from_str(&hash)?; + changeset.blocks.insert(height as u32, Some(block_hash)); + } + + Ok(changeset) +} + +/// Insert blocks. +#[tracing::instrument] +pub async fn local_chain_changeset_persist_to_sqlite( + db_tx: &mut Transaction<'_, Sqlite>, + wallet_name: &str, + changeset: &local_chain::ChangeSet, +) -> Result<(), BdkSqlxError> { + info!("local chain changeset to sqlite"); + for (&height, &hash) in &changeset.blocks { + match hash { + Some(hash) => { + sqlx::query( + "INSERT INTO block (wallet_name, hash, height) VALUES ($1, $2, $3) + ON CONFLICT (wallet_name, hash) DO UPDATE SET height = $3", + ) + .bind(wallet_name) + .bind(hash.to_string()) + .bind(height as i32) + .execute(&mut **db_tx) + .await?; + } + None => { + sqlx::query("DELETE FROM block WHERE wallet_name = $1 AND height = $2") + .bind(wallet_name) + .bind(height as i32) + .execute(&mut **db_tx) + .await?; + } + } + } + + Ok(()) +} + +/// Collects information on all the wallets in the database and dumps it to stdout. +#[tracing::instrument] +pub async fn easy_backup(db: Pool) -> Result<(), BdkSqlxError> { + info!("Starting easy backup"); + + let statement = "SELECT * FROM keychain"; + + let results = sqlx::query_as::<_, KeychainEntry>(statement) + .fetch_all(&db) + .await?; + + let json_array = json!(results); + println!("{}", serde_json::to_string_pretty(&json_array)?); + + info!("Easy backup completed successfully"); + Ok(()) +} + +/// Represents a row in the keychain table. +#[derive(serde::Serialize, FromRow)] +struct KeychainEntry { + wallet_name: String, + keychainkind: String, + descriptor: String, + descriptor_id: Vec, + last_revealed: i32, +} diff --git a/src/test.rs b/src/test.rs index f08f44d..d0ae890 100644 --- a/src/test.rs +++ b/src/test.rs @@ -1,27 +1,25 @@ -use crate::{drop_all, Store}; +use crate::{BdkSqlxError, FutureResult, Store}; use assert_matches::assert_matches; -use bdk_chain::bitcoin::constants::ChainHash; -use bdk_chain::bitcoin::hashes::Hash; -use bdk_chain::bitcoin::secp256k1::Secp256k1; -use bdk_chain::bitcoin::Network::Signet; -use bdk_chain::bitcoin::{BlockHash, Network, Txid}; -use bdk_chain::miniscript::{Descriptor, DescriptorPublicKey}; -use bdk_chain::BlockId; -use bdk_electrum::electrum_client::Client; -use bdk_electrum::{electrum_client, BdkElectrumClient}; -use bdk_testenv::bitcoincore_rpc::RpcApi; -use bdk_testenv::TestEnv; +use bdk_wallet::bitcoin::constants::ChainHash; +use bdk_wallet::bitcoin::hashes::Hash; +use bdk_wallet::bitcoin::secp256k1::Secp256k1; +use bdk_wallet::bitcoin::Network::{Regtest, Signet}; +use bdk_wallet::bitcoin::{ + transaction, Address, Amount, BlockHash, Network, OutPoint, Transaction, TxIn, TxOut, Txid, +}; +use bdk_wallet::chain::{tx_graph, BlockId, ConfirmationBlockTime, ConfirmationTime}; +use bdk_wallet::miniscript::{Descriptor, DescriptorPublicKey}; use bdk_wallet::{ - descriptor::ExtendedDescriptor, wallet_name_from_descriptor, KeychainKind, LoadError, - LoadMismatch, LoadWithPersistError, PersistedWallet, Wallet, + bitcoin, descriptor::ExtendedDescriptor, wallet_name_from_descriptor, AsyncWalletPersister, + Balance, ChangeSet, KeychainKind::*, LoadError, LoadMismatch, LoadWithPersistError, Update, + Wallet, }; -use better_panic::Settings; -use rustls::crypto::ring::default_provider; -use sqlx::PgPool; -use std::collections::HashSet; +use sqlx::{Pool, Postgres, Sqlite, SqlitePool}; use std::env; -use std::io::Write; -use std::time::Duration; +use std::ops::Add; +use std::str::FromStr; +use std::sync::Once; +use tracing::info; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::EnvFilter; @@ -45,8 +43,6 @@ pub fn get_test_wpkh() -> &'static str { } const NETWORK: Network = Signet; -const STOP_GAP: usize = 50; -const BATCH_SIZE: usize = 50; fn parse_descriptor(s: &str) -> ExtendedDescriptor { >::parse_descriptor(&Secp256k1::new(), s) @@ -54,274 +50,409 @@ fn parse_descriptor(s: &str) -> ExtendedDescriptor { .0 } -#[tracing::instrument] -#[tokio::test] -async fn wallet_is_persisted() -> anyhow::Result<()> { - Settings::debug() - .most_recent_first(false) - .lineno_suffix(true) - .install(); - - tracing_subscriber::registry() - .with(EnvFilter::new( - env::var("RUST_LOG").unwrap_or_else(|_| "sqlx=warn,bdk_postgres=info".into()), - )) - .with(tracing_subscriber::fmt::layer()) - .try_init()?; - - // Set up the database URL (you might want to use a test-specific database) - let url = env::var("DATABASE_TEST_URL").expect("DATABASE_TEST_URL must be set for tests"); +static INIT: Once = Once::new(); + +// This must only be called once. +fn initialize() { + INIT.call_once(|| { + tracing_subscriber::registry() + .with(EnvFilter::new( + env::var("RUST_LOG").unwrap_or_else(|_| "sqlx=warn,bdk_sqlx=warn".into()), + )) + .with(tracing_subscriber::fmt::layer()) + .try_init() + .expect("setup tracing"); + }); +} - let pg = PgPool::connect(&url.clone()).await?; - match drop_all(pg).await { - Ok(_) => { - dbg!("tables dropped") - } - Err(_) => { - dbg!("Error dropping tables") +trait DropAll { + async fn drop_all(&self) -> anyhow::Result<()>; +} + +impl DropAll for Pool { + /// Drops all tables. + /// + /// Clean up (optional, depending on your test database strategy) + /// You might want to delete the test wallet from the database here. + #[tracing::instrument] + async fn drop_all(&self) -> anyhow::Result<()> { + let drop_statements = vec![ + "DROP TABLE IF EXISTS _sqlx_migrations", + "DROP TABLE IF EXISTS vault_addresses", + "DROP TABLE IF EXISTS used_anchorwatch_keys", + "DROP TABLE IF EXISTS anchorwatch_keys", + "DROP TABLE IF EXISTS psbts", + "DROP TABLE IF EXISTS whitelist_update", + "DROP TABLE IF EXISTS vault_parameters", + "DROP TABLE IF EXISTS users", + "DROP TABLE IF EXISTS version", + "DROP TABLE IF EXISTS anchor_tx", + "DROP TABLE IF EXISTS txout", + "DROP TABLE IF EXISTS tx", + "DROP TABLE IF EXISTS block", + "DROP TABLE IF EXISTS keychain", + "DROP TABLE IF EXISTS network", + ]; + + let mut tx = self.begin().await?; + + for statement in drop_statements { + sqlx::query(statement).execute(&mut *tx).await?; } - }; - // Define descriptors (you may need to adjust these based on your exact requirements) - let (external_desc, internal_desc) = get_test_tr_single_sig_xprv_with_change_desc(); - // Generate a unique name for this test wallet - let wallet_name = wallet_name_from_descriptor( - external_desc, - Some(internal_desc), - NETWORK, - &Secp256k1::new(), - )?; + tx.commit().await?; - // Create a new wallet - let wallet_spk_index = { - let mut store = Store::new_with_url(url.clone(), Some(wallet_name.clone())).await?; - let mut wallet = Wallet::create(external_desc, internal_desc) - .network(NETWORK) - .create_wallet_async(&mut store) - .await?; + Ok(()) + } +} - let deposit_address = wallet.reveal_next_address(KeychainKind::External); - let change_address = wallet.reveal_next_address(KeychainKind::Internal); - dbg!(deposit_address.address); - dbg!(change_address.address); +#[derive(Debug)] +enum TestStore { + Postgres(Store), + Sqlite(Store), +} - assert!(wallet.persist_async(&mut store).await?); - wallet.spk_index().clone() - }; +impl AsyncWalletPersister for TestStore { + type Error = BdkSqlxError; + #[tracing::instrument] + fn initialize<'a>(store: &'a mut Self) -> FutureResult<'a, ChangeSet, Self::Error> + where + Self: 'a, { - // Recover the wallet - let mut store = Store::new_with_url(url.clone(), Some(wallet_name.clone())).await?; - let mut wallet = Wallet::load() - .descriptor(KeychainKind::External, Some(external_desc)) - .descriptor(KeychainKind::Internal, Some(internal_desc)) - .load_wallet_async(&mut store) - .await? - .expect("wallet must exist"); - - assert_eq!(wallet.network(), NETWORK); - assert_eq!( - wallet.spk_index().keychains().collect::>(), - wallet_spk_index.keychains().collect::>() - ); - assert_eq!( - wallet.spk_index().last_revealed_indices(), - wallet_spk_index.last_revealed_indices() - ); - - let recovered_address = wallet.reveal_next_address(KeychainKind::External); - println!("Recovered next address: {}", recovered_address.address); + info!("initialize test store"); + match store { + TestStore::Postgres(store) => Box::pin(store.read()), + TestStore::Sqlite(store) => Box::pin(store.read()), + } + } - assert_eq!( - wallet.public_descriptor(KeychainKind::External).to_string(), - "tr(tpubD6NzVbkrYhZ4WgCeJid2Zds24zATB58r1q1qTLMuApUxZUxzETADNTeP6SvZKSsXs4qhvFAC21GFjXHwgxAcDtZqzzj8JMpsFDgqyjSJHGa/0/*)#celxt6vn".to_string(), - ); + #[tracing::instrument] + fn persist<'a>( + store: &'a mut Self, + changeset: &'a ChangeSet, + ) -> FutureResult<'a, (), Self::Error> + where + Self: 'a, + { + info!("persist test store"); + match store { + TestStore::Postgres(store) => Box::pin(store.write(changeset)), + TestStore::Sqlite(store) => Box::pin(store.write(changeset)), + } } +} - // Clean up (optional, depending on your test database strategy) - // You might want to delete the test wallet from the database here - let db = PgPool::connect(&url).await?; - drop_all(db).await.expect("hope its not mainet"); +async fn create_test_stores(wallet_name: String) -> anyhow::Result> { + let mut stores: Vec = Vec::new(); - Ok(()) + // Set up postgres database URL (you might want to use a test-specific database) + let url = env::var("DATABASE_TEST_URL").expect("DATABASE_TEST_URL must be set for tests"); + let pool = Pool::::connect(&url.clone()).await?; + // Drop all before creating new store for testing + pool.drop_all().await?; + let postgres_store = + Store::::new_with_url(url.clone(), wallet_name.clone(), true).await?; + stores.push(TestStore::Postgres(postgres_store)); + + // Setup sqlite in-memory database + let pool = SqlitePool::connect(":memory:").await?; + let sqlite_store = Store::::new(pool.clone(), wallet_name.clone(), true).await?; + stores.push(TestStore::Sqlite(sqlite_store)); + + Ok(stores) } -async fn setup_database() -> anyhow::Result { - let url = env::var("DATABASE_TEST_URL").expect("DATABASE_TEST_URL must be set for tests"); - let pg = PgPool::connect(&url).await?; - match drop_all(pg).await { - Ok(_) => dbg!("tables dropped"), - Err(_) => dbg!("Error dropping tables"), +/// Add a fake transaction to a wallet for testing. +/// +/// The test wallet must use the `Regtest` network and the added tx will have the given spent, +/// change, and fee amounts. +/// +/// The tx ids for the two created transactions (funding and spending) are returned. +pub fn insert_fake_tx(wallet: &mut Wallet, spent: Amount, change: Amount, fee: Amount) -> Txid { + let receive_address = wallet.reveal_next_address(External).address; + let change_address = wallet.reveal_next_address(Internal).address; + let sendto_address = Address::from_str("bcrt1q3qtze4ys45tgdvguj66zrk4fu6hq3a3v9pfly5") + .expect("address") + .require_network(Network::Regtest) + .unwrap(); + + let tx0 = Transaction { + version: transaction::Version::ONE, + lock_time: bitcoin::absolute::LockTime::ZERO, + input: vec![TxIn { + previous_output: OutPoint { + txid: Txid::all_zeros(), + vout: 0, + }, + script_sig: Default::default(), + sequence: Default::default(), + witness: Default::default(), + }], + output: vec![TxOut { + value: spent.add(change).add(fee), + script_pubkey: receive_address.script_pubkey(), + }], + }; + + let tx1 = Transaction { + version: transaction::Version::ONE, + lock_time: bitcoin::absolute::LockTime::ZERO, + input: vec![TxIn { + previous_output: OutPoint { + txid: tx0.compute_txid(), + vout: 0, + }, + script_sig: Default::default(), + sequence: Default::default(), + witness: Default::default(), + }], + output: vec![ + TxOut { + value: change, + script_pubkey: change_address.script_pubkey(), + }, + TxOut { + value: spent, + script_pubkey: sendto_address.script_pubkey(), + }, + ], }; - Ok(url) + + wallet + .insert_checkpoint(BlockId { + height: 42, + hash: BlockHash::all_zeros(), + }) + .unwrap(); + wallet + .insert_checkpoint(BlockId { + height: 1_000, + hash: BlockHash::all_zeros(), + }) + .unwrap(); + wallet + .insert_checkpoint(BlockId { + height: 2_000, + hash: BlockHash::all_zeros(), + }) + .unwrap(); + + wallet.insert_tx(tx0.clone()); + insert_anchor_from_conf( + wallet, + tx0.compute_txid(), + ConfirmationTime::Confirmed { + height: 1_000, + time: 100, + }, + ); + + wallet.insert_tx(tx1.clone()); + insert_anchor_from_conf( + wallet, + tx1.compute_txid(), + ConfirmationTime::Confirmed { + height: 2_000, + time: 200, + }, + ); + + tx1.compute_txid() } -fn get_wallet_descriptors(wallet_type: u8) -> (&'static str, &'static str) { - match wallet_type { - 1 => get_test_tr_single_sig_xprv_with_change_desc(), - 2 => ("wpkh([bdb9a801/84'/1'/0']tpubDCopxf4CiXF9dicdGrXgZV9f8j3pYbWBVfF8WxjaFHtic4DZsgp1tQ58hZdsSu6M7FFzUyAh9rMn7RZASUkPgZCMdByYKXvVtigzGi8VJs6/0/*)#j8mkwdgr", - "wpkh([bdb9a801/84'/1'/0']tpubDCopxf4CiXF9dicdGrXgZV9f8j3pYbWBVfF8WxjaFHtic4DZsgp1tQ58hZdsSu6M7FFzUyAh9rMn7RZASUkPgZCMdByYKXvVtigzGi8VJs6/1/*)#rn7hnccm"), - 3 => get_test_minisicript_with_change_desc(), - _ => panic!("Invalid wallet type"), +/// Simulates confirming a tx with `txid` at the specified `position` by inserting an anchor +/// at the lowest height in local chain that is greater or equal to `position`'s height, +/// assuming the confirmation time matches `ConfirmationTime::Confirmed`. +pub fn insert_anchor_from_conf(wallet: &mut Wallet, txid: Txid, position: ConfirmationTime) { + if let ConfirmationTime::Confirmed { height, time } = position { + // anchor tx to checkpoint with lowest height that is >= position's height + let anchor = wallet + .local_chain() + .range(height..) + .last() + .map(|anchor_cp| ConfirmationBlockTime { + block_id: anchor_cp.block_id(), + confirmation_time: time, + }) + .expect("confirmation height cannot be greater than tip"); + + wallet + .apply_update(Update { + tx_update: tx_graph::TxUpdate { + anchors: [(anchor, txid)].into(), + ..Default::default() + }, + ..Default::default() + }) + .unwrap(); } } -async fn create_and_scan_wallet( - url: &str, - external_desc: &str, - internal_desc: &str, -) -> anyhow::Result<(Store, String)> { +#[tracing::instrument] +#[tokio::test] +async fn wallet_is_persisted() -> anyhow::Result<()> { + initialize(); + + // Define descriptors (you may need to adjust these based on your exact requirements) + let (external_desc, internal_desc) = get_test_tr_single_sig_xprv_with_change_desc(); + // Generate a unique name for this test wallet let wallet_name = wallet_name_from_descriptor( external_desc, Some(internal_desc), NETWORK, &Secp256k1::new(), )?; - let mut store = Store::new_with_url(url.to_string(), Some(wallet_name.clone())).await?; - let mut wallet = Wallet::create(external_desc.to_string(), internal_desc.to_string()) - .network(NETWORK) - .create_wallet_async(&mut store) - .await?; - let _ = electrum_full_scan(&mut wallet).await?; - assert!(wallet.persist_async(&mut store).await?); - Ok((store, wallet_name)) -} - -async fn load_wallet_and_get_transactions( - store: &mut Store, - external_desc: &str, - internal_desc: &str, -) -> anyhow::Result> { - let wallet = Wallet::load() - .descriptor(KeychainKind::External, Some(external_desc.to_string())) - .descriptor(KeychainKind::Internal, Some(internal_desc.to_string())) - .load_wallet_async(store) - .await? - .expect("wallet must exist"); - Ok(wallet.transactions().map(|tx| tx.tx_node.txid).collect()) -} -async fn recover_wallet_and_get_transactions( - external_desc: &str, - internal_desc: &str, -) -> anyhow::Result> { - let mut wallet = Wallet::create(external_desc.to_string(), internal_desc.to_string()) - .network(NETWORK) - .create_wallet_no_persist()?; - let _ = electrum_full_scan_no_persist(&mut wallet).await?; - Ok(wallet.transactions().map(|tx| tx.tx_node.txid).collect()) -} - -#[tracing::instrument] -#[tokio::test] -async fn test_three_wallets_list_transactions() -> anyhow::Result<()> { - Settings::debug() - .most_recent_first(false) - .lineno_suffix(true) - .install(); - default_provider() - .install_default() - .expect("Failed to install rustls default crypto provider"); - - let url = setup_database().await?; - - let wallet_types = [1, 2, 3]; - let mut stores = Vec::new(); - let mut persisted_txs = Vec::new(); - let mut recovered_txs = Vec::new(); - - for wallet_type in wallet_types.iter() { - let (external_desc, internal_desc) = get_wallet_descriptors(*wallet_type); - let (store, _) = create_and_scan_wallet(&url, external_desc, internal_desc).await?; - stores.push(store); - } + let stores = create_test_stores(wallet_name).await?; + for mut store in stores { + // Create a new wallet + let mut wallet = Wallet::create(external_desc, internal_desc) + .network(NETWORK) + .create_wallet_async(&mut store) + .await?; - for (i, store) in stores.iter_mut().enumerate() { - let (external_desc, internal_desc) = get_wallet_descriptors(wallet_types[i]); - let mut txs = load_wallet_and_get_transactions(store, external_desc, internal_desc).await?; - txs.sort(); - persisted_txs.push(txs); - } + let external_addr0 = wallet.reveal_next_address(External); + for keychain in [External, Internal] { + let _ = wallet.reveal_addresses_to(keychain, 2); + } - for wallet_type in wallet_types.iter() { - let (external_desc, internal_desc) = get_wallet_descriptors(*wallet_type); - let mut txs = recover_wallet_and_get_transactions(external_desc, internal_desc).await?; - txs.sort(); - recovered_txs.push(txs); - } + assert!(wallet.persist_async(&mut store).await?); + let wallet_spk_index = wallet.spk_index(); - for i in 0..3 { - assert_eq!(persisted_txs[i], recovered_txs[i]); + { + // Recover the wallet + let wallet = Wallet::load() + .descriptor(External, Some(external_desc)) + .descriptor(Internal, Some(internal_desc)) + .load_wallet_async(&mut store) + .await? + .expect("wallet must exist"); + + assert_eq!(wallet.network(), NETWORK); + assert_eq!( + wallet.spk_index().keychains().collect::>(), + wallet_spk_index.keychains().collect::>() + ); + assert_eq!( + wallet.spk_index().last_revealed_indices(), + wallet_spk_index.last_revealed_indices() + ); + + let recovered_addr = wallet.peek_address(External, 0); + assert_eq!(recovered_addr, external_addr0, "failed to recover address"); + + assert_eq!( + wallet.public_descriptor(External).to_string(), + "tr(tpubD6NzVbkrYhZ4WgCeJid2Zds24zATB58r1q1qTLMuApUxZUxzETADNTeP6SvZKSsXs4qhvFAC21GFjXHwgxAcDtZqzzj8JMpsFDgqyjSJHGa/0/*)#celxt6vn".to_string(), + ); + } } - // Clean up - let db = PgPool::connect(&url).await?; - drop_all(db).await.expect("hope it's not mainnet"); Ok(()) } -async fn electrum_full_scan(wallet: &mut PersistedWallet) -> anyhow::Result<()> { - let client = BdkElectrumClient::new(Client::new("ssl://mempool.space:60602").unwrap()); - client.populate_tx_cache(wallet.tx_graph().full_txs().map(|tx_node| tx_node.tx)); - - let request = wallet.start_full_scan().inspect({ - let mut stdout = std::io::stdout(); - let mut once = HashSet::::new(); - move |k, spk_i, _| { - if once.insert(k) { - print!("\nScanning keychain [{:?}]", k); - } - print!(" {:<3}", spk_i); - stdout.flush().expect("must flush"); +#[tracing::instrument] +#[tokio::test] +async fn test_three_wallets_list_transactions() -> anyhow::Result<()> { + initialize(); + + struct TestCase { + descriptors: (String, String), + spent: Amount, + change: Amount, + fee: Amount, + store: TestStore, + } + impl TestCase { + async fn new( + descriptors: (&'static str, &'static str), + spent: u64, + change: u64, + fee: u64, + ) -> Vec { + let wallet_name = wallet_name_from_descriptor( + descriptors.0, + Some(descriptors.1), + NETWORK, + &Secp256k1::new(), + ) + .unwrap(); + let stores = create_test_stores(wallet_name.clone()).await.unwrap(); + stores + .into_iter() + .map(|store| Self { + descriptors: (descriptors.0.to_string(), descriptors.1.to_string()), + spent: Amount::from_sat(spent), + change: Amount::from_sat(change), + fee: Amount::from_sat(fee), + store, + }) + .collect() } - }); + } + let mut test_cases = [ + TestCase::new(get_test_tr_single_sig_xprv_with_change_desc(), 20_000, 11_000, 2000).await, + TestCase::new(("wpkh([bdb9a801/84'/1'/0']tpubDCopxf4CiXF9dicdGrXgZV9f8j3pYbWBVfF8WxjaFHtic4DZsgp1tQ58hZdsSu6M7FFzUyAh9rMn7RZASUkPgZCMdByYKXvVtigzGi8VJs6/0/*)#j8mkwdgr", + "wpkh([bdb9a801/84'/1'/0']tpubDCopxf4CiXF9dicdGrXgZV9f8j3pYbWBVfF8WxjaFHtic4DZsgp1tQ58hZdsSu6M7FFzUyAh9rMn7RZASUkPgZCMdByYKXvVtigzGi8VJs6/1/*)#rn7hnccm"), 12_000, 30_000, 1500).await, + TestCase::new(get_test_minisicript_with_change_desc(), 44_444, 20_000, 5000).await + ].into_iter().flatten().collect::>(); + + let mut saved_tx_ids = Vec::::new(); + let mut saved_balances = Vec::::new(); + + // create wallet and save test transaction + for test_case in &mut test_cases { + let mut wallet = Wallet::create( + test_case.descriptors.0.clone(), + test_case.descriptors.1.clone(), + ) + .network(Regtest) + .create_wallet_async(&mut test_case.store) + .await?; + let tx_id = insert_fake_tx( + &mut wallet, + test_case.spent, + test_case.change, + test_case.fee, + ); + saved_tx_ids.push(tx_id); + saved_balances.push(wallet.balance()); + wallet.persist_async(&mut test_case.store).await?; + } - let update = client.full_scan(request, STOP_GAP, BATCH_SIZE, true)?; - wallet.apply_update(update)?; - Ok(()) -} + saved_tx_ids.reverse(); + saved_balances.reverse(); -async fn electrum_full_scan_no_persist(wallet: &mut Wallet) -> anyhow::Result<()> { - let client = BdkElectrumClient::new(Client::new("ssl://mempool.space:60602").unwrap()); - client.populate_tx_cache(wallet.tx_graph().full_txs().map(|tx_node| tx_node.tx)); - - let request = wallet.start_full_scan().inspect({ - let mut stdout = std::io::stdout(); - let mut once = HashSet::::new(); - move |k, spk_i, _| { - if once.insert(k) { - print!("\nScanning keychain [{:?}]", k); - } - print!(" {:<3}", spk_i); - stdout.flush().expect("must flush"); - } - }); + // load wallet and test transaction and verify with saved + for test_case in &mut test_cases { + let wallet = Wallet::load() + .descriptor(External, Some(test_case.descriptors.0.clone())) + .descriptor(Internal, Some(test_case.descriptors.1.clone())) + .check_network(Regtest) + .load_wallet_async(&mut test_case.store) + .await? + .expect("wallet must exist"); + let saved_tx_ids = saved_tx_ids.pop().unwrap(); + let loaded_tx_id = wallet + .transactions() + .map(|tx| tx.tx_node.tx.compute_txid()) + .next() + .expect("txid must exist"); + assert_eq!(saved_tx_ids, loaded_tx_id); + + let saved_balance = saved_balances.pop().unwrap(); + let loaded_balance = wallet.balance(); + assert_eq!(saved_balance, loaded_balance); + } - let update = client.full_scan(request, STOP_GAP, BATCH_SIZE, true)?; - wallet.apply_update(update)?; Ok(()) } #[tracing::instrument] #[tokio::test] async fn wallet_load_checks() -> anyhow::Result<()> { - Settings::debug() - .most_recent_first(false) - .lineno_suffix(true) - .install(); - - // Set up the database URL (you might want to use a test-specific database) - let url = env::var("DATABASE_TEST_URL").expect("DATABASE_TEST_URL must be set for tests"); - - let pg = PgPool::connect(&url.clone()).await?; - match drop_all(pg).await { - Ok(_) => { - dbg!("tables dropped") - } - Err(_) => { - dbg!("Error dropping tables") - } - }; + initialize(); // Define descriptors (you may need to adjust these based on your exact requirements) let (external_desc, internal_desc) = get_test_tr_single_sig_xprv_with_change_desc(); @@ -335,74 +466,56 @@ async fn wallet_load_checks() -> anyhow::Result<()> { &Secp256k1::new(), )?; - // Create a new wallet - let mut store = Store::new_with_url(url.clone(), Some(wallet_name)).await?; - let _wallet = Wallet::create(external_desc, internal_desc) - .network(NETWORK) - .create_wallet_async(&mut store) - .await?; + let stores = create_test_stores(wallet_name).await?; + for mut store in stores { + // Create a new wallet + let _wallet = Wallet::create(external_desc, internal_desc) + .network(NETWORK) + .create_wallet_async(&mut store) + .await?; - { - assert_matches!( - Wallet::load() - .descriptor(KeychainKind::External, Some(internal_desc)) - .load_wallet_async(&mut store) - .await, - Err(LoadWithPersistError::InvalidChangeSet(LoadError::Mismatch( - LoadMismatch::Descriptor { keychain, loaded, expected } - ))) - if keychain == KeychainKind::External && loaded == Some(parsed_ext.clone()) && expected == Some(parsed_int), - "should error on wrong external descriptor" - ); - } - { - assert_matches!( - Wallet::load() - .descriptor(KeychainKind::External, Option::<&str>::None) - .load_wallet_async(&mut store) - .await, - Err(LoadWithPersistError::InvalidChangeSet(LoadError::Mismatch( - LoadMismatch::Descriptor { keychain, loaded, expected } - ))) - if keychain == KeychainKind::External && loaded == Some(parsed_ext) && expected.is_none(), - "external descriptor check should error when expected is none" - ); - } - { - let mainnet_hash = BlockHash::from_byte_array(ChainHash::BITCOIN.to_bytes()); - assert_matches!( - Wallet::load().check_genesis_hash(mainnet_hash).load_wallet_async(&mut store).await - , Err(LoadWithPersistError::InvalidChangeSet(LoadError::Mismatch(LoadMismatch::Genesis { .. }))), - "unexpected genesis hash check result: mainnet hash (check) is not testnet hash (loaded)"); + { + assert_matches!( + Wallet::load() + .descriptor(External, Some(internal_desc)) + .load_wallet_async(&mut store) + .await, + Err(LoadWithPersistError::InvalidChangeSet(LoadError::Mismatch( + LoadMismatch::Descriptor { keychain, loaded, expected } + ))) + if keychain == External && loaded == Some(parsed_ext.clone()) && expected == Some(parsed_int.clone()), + "should error on wrong external descriptor" + ); + } + { + assert_matches!( + Wallet::load() + .descriptor(External, Option::<&str>::None) + .load_wallet_async(&mut store) + .await, + Err(LoadWithPersistError::InvalidChangeSet(LoadError::Mismatch( + LoadMismatch::Descriptor { keychain, loaded, expected } + ))) + if keychain == External && loaded == Some(parsed_ext.clone()) && expected.is_none(), + "external descriptor check should error when expected is none" + ); + } + { + let mainnet_hash = BlockHash::from_byte_array(ChainHash::BITCOIN.to_bytes()); + assert_matches!( + Wallet::load().check_genesis_hash(mainnet_hash).load_wallet_async(&mut store).await + , Err(LoadWithPersistError::InvalidChangeSet(LoadError::Mismatch(LoadMismatch::Genesis { .. }))), + "unexpected genesis hash check result: mainnet hash (check) is not testnet hash (loaded)"); + } } - // Clean up (optional, depending on your test database strategy) - // You might want to delete the test wallet from the database here - let db = PgPool::connect(&url).await?; - drop_all(db).await.expect("hope its not mainnet"); - Ok(()) } #[tracing::instrument] #[tokio::test] async fn single_descriptor_wallet_persist_and_recover() -> anyhow::Result<()> { - Settings::debug() - .most_recent_first(false) - .lineno_suffix(true) - .install(); - // Set up the database URL (you might want to use a test-specific database) - let url = env::var("DATABASE_TEST_URL").expect("DATABASE_TEST_URL must be set for tests"); - - let pg = PgPool::connect(&url.clone()).await?; - match drop_all(pg).await { - Ok(_) => { - dbg!("tables dropped") - } - Err(_) => { - dbg!("Error dropping tables") - } - }; + initialize(); // Define descriptors let desc = get_test_tr_single_sig_xprv(); @@ -410,63 +523,46 @@ async fn single_descriptor_wallet_persist_and_recover() -> anyhow::Result<()> { // Generate a unique name for this test wallet let wallet_name = wallet_name_from_descriptor(desc, Some(desc), NETWORK, &Secp256k1::new())?; - // Create a new wallet - let mut store = Store::new_with_url(url.clone(), Some(wallet_name)).await?; - let mut wallet = Wallet::create_single(desc) - .network(NETWORK) - .create_wallet_async(&mut store) - .await?; + let stores = create_test_stores(wallet_name).await?; + for mut store in stores { + // Create a new wallet + let mut wallet = Wallet::create_single(desc) + .network(NETWORK) + .create_wallet_async(&mut store) + .await?; - let _ = wallet.reveal_addresses_to(KeychainKind::External, 2); - assert!(wallet.persist_async(&mut store).await?); + let _ = wallet.reveal_addresses_to(External, 2); + assert!(wallet.persist_async(&mut store).await?); - { - // Recover the wallet - let wallet = Wallet::load().load_wallet_async(&mut store).await?.unwrap(); - assert_eq!(wallet.derivation_index(KeychainKind::External), Some(2)); - } - { - // should error on wrong internal params - let desc = get_test_wpkh(); - let exp_desc = parse_descriptor(desc); - let err = Wallet::load() - .descriptor(KeychainKind::Internal, Some(desc)) - .load_wallet_async(&mut store) - .await; - assert_matches!( - err, - Err(LoadWithPersistError::InvalidChangeSet(LoadError::Mismatch(LoadMismatch::Descriptor { keychain, loaded, expected }))) - if keychain == KeychainKind::Internal && loaded.is_none() && expected == Some(exp_desc), - "single descriptor wallet should refuse change descriptor param" - ); + { + // Recover the wallet + let wallet = Wallet::load().load_wallet_async(&mut store).await?.unwrap(); + assert_eq!(wallet.derivation_index(External), Some(2)); + } + { + // should error on wrong internal params + let desc = get_test_wpkh(); + let exp_desc = parse_descriptor(desc); + let err = Wallet::load() + .descriptor(Internal, Some(desc)) + .load_wallet_async(&mut store) + .await; + assert_matches!( + err, + Err(LoadWithPersistError::InvalidChangeSet(LoadError::Mismatch(LoadMismatch::Descriptor { keychain, loaded, expected }))) + if keychain == Internal && loaded.is_none() && expected == Some(exp_desc), + "single descriptor wallet should refuse change descriptor param" + ); + } } - // Clean up (optional, depending on your test database strategy) - // You might want to delete the test wallet from the database here - let db = PgPool::connect(&url).await?; - drop_all(db).await.expect("hope its not mainnet"); Ok(()) } #[tracing::instrument] #[tokio::test] async fn two_wallets_load() -> anyhow::Result<()> { - Settings::debug() - .most_recent_first(false) - .lineno_suffix(true) - .install(); - // Set up the database URL (you might want to use a test-specific database) - let url = env::var("DATABASE_TEST_URL").expect("DATABASE_TEST_URL must be set for tests"); - - let pg = PgPool::connect(&url.clone()).await?; - match drop_all(pg).await { - Ok(_) => { - dbg!("tables dropped") - } - Err(_) => { - dbg!("Error dropping tables") - } - }; + initialize(); // Define descriptors let (external_desc_wallet_1, internal_desc_wallet_1) = @@ -488,125 +584,63 @@ async fn two_wallets_load() -> anyhow::Result<()> { &Secp256k1::new(), )?; - // Create wallets - let mut store_1 = Store::new_with_url(url.clone(), Some(wallet_1_name)).await?; - let mut store_2 = Store::new_with_url(url.clone(), Some(wallet_2_name)).await?; - - let mut wallet_1 = Wallet::create(external_desc_wallet_1, internal_desc_wallet_1) - .network(NETWORK) - .create_wallet_async(&mut store_1) - .await?; - let _ = wallet_1.reveal_next_address(KeychainKind::External); - let _ = wallet_1.reveal_next_address(KeychainKind::Internal); - assert!(wallet_1.persist_async(&mut store_1).await?); - - // for wallet 2 we reveal an extra internal address and insert a new checkpoint - // to check that loading returns the correct data for each wallet - let mut wallet_2 = Wallet::create(external_desc_wallet_2, internal_desc_wallet_2) - .network(NETWORK) - .create_wallet_async(&mut store_2) - .await?; - let _ = wallet_2.reveal_next_address(KeychainKind::External); - let _ = wallet_2.reveal_addresses_to(KeychainKind::Internal, 2); - let block = BlockId { - height: 100, - hash: BlockHash::all_zeros(), - }; - let _ = wallet_2.insert_checkpoint(block).unwrap(); - assert!(wallet_2.persist_async(&mut store_2).await?); + let mut stores1 = create_test_stores(wallet_1_name).await?; + let mut stores2 = create_test_stores(wallet_2_name).await?; - // Recover the wallet_1 - let wallet_1 = Wallet::load() - .load_wallet_async(&mut store_1) - .await? - .unwrap(); - - // Recover the wallet_2 - let wallet_2 = Wallet::load() - .load_wallet_async(&mut store_2) - .await? - .unwrap(); - - assert_eq!( - wallet_1.derivation_index(KeychainKind::External), - wallet_2.derivation_index(KeychainKind::External) - ); - // FIXME: see https://github.com/bitcoindevkit/bdk-sqlx/pull/10 - // assert_ne!( - // wallet_1.derivation_index(Internal), - // wallet_2.derivation_index(Internal), - // "different wallets should not have same derivation index" - // ); - // assert_ne!( - // wallet_1.latest_checkpoint(), - // wallet_2.latest_checkpoint(), - // "different wallets should not have same chain tip" - // ); - - // Clean up (optional, depending on your test database strategy) - // You might want to delete the test wallet from the database here - let db = PgPool::connect(&url).await?; - drop_all(db).await.expect("hope its not mainnet"); - Ok(()) -} - -#[tracing::instrument] -#[tokio::test] -async fn sync_with_electrum() -> anyhow::Result<()> { - Settings::debug() - .most_recent_first(false) - .lineno_suffix(true) - .install(); - - // Set up the database URL (you might want to use a test-specific database) - let url = env::var("DATABASE_TEST_URL").expect("DATABASE_TEST_URL must be set for tests"); + for _ in 0..stores1.len() { + let mut store_1 = stores1.pop().unwrap(); + let mut store_2 = stores2.pop().unwrap(); - let pg = PgPool::connect(&url.clone()).await?; - match drop_all(pg).await { - Ok(_) => { - dbg!("tables dropped") - } - Err(_) => { - dbg!("Error dropping tables") - } - }; + let mut wallet_1 = Wallet::create(external_desc_wallet_1, internal_desc_wallet_1) + .network(NETWORK) + .create_wallet_async(&mut store_1) + .await?; + let _ = wallet_1.reveal_next_address(External); + let _ = wallet_1.reveal_next_address(Internal); + assert!(wallet_1.persist_async(&mut store_1).await?); - // Define descriptors (you may need to adjust these based on your exact requirements) - let (external_desc, internal_desc) = get_test_tr_single_sig_xprv_with_change_desc(); - // Generate a unique name for this test wallet - let wallet_name = wallet_name_from_descriptor( - external_desc, - Some(internal_desc), - Network::Regtest, - &Secp256k1::new(), - )?; + // for wallet 2 we reveal an extra internal address and insert a new checkpoint + // to check that loading returns the correct data for each wallet + let mut wallet_2 = Wallet::create(external_desc_wallet_2, internal_desc_wallet_2) + .network(NETWORK) + .create_wallet_async(&mut store_2) + .await?; + let _ = wallet_2.reveal_next_address(External); + let _ = wallet_2.reveal_addresses_to(Internal, 2); + let block = BlockId { + height: 100, + hash: BlockHash::all_zeros(), + }; + let _ = wallet_2.insert_checkpoint(block).unwrap(); + assert!(wallet_2.persist_async(&mut store_2).await?); + + // Recover the wallet_1 + let wallet_1 = Wallet::load() + .load_wallet_async(&mut store_1) + .await? + .unwrap(); - let mut store = Store::new_with_url(url.clone(), Some(wallet_name)).await?; - let mut wallet = Wallet::create(external_desc, internal_desc) - .network(Network::Regtest) - .create_wallet_async(&mut store) - .await?; + // Recover the wallet_2 + let wallet_2 = Wallet::load() + .load_wallet_async(&mut store_2) + .await? + .unwrap(); - // mine blocks and sync with electrum - let env = TestEnv::new()?; - let electrum_client = electrum_client::Client::new(env.electrsd.electrum_url.as_str())?; - let client = BdkElectrumClient::new(electrum_client); - let _hashes = env.mine_blocks(9, None)?; - env.wait_until_electrum_sees_block(Duration::from_secs(10))?; - let new_tip_height: u32 = env.rpc_client().get_block_count()?.try_into()?; - assert_eq!(new_tip_height, 10); - - let request = wallet.start_full_scan(); - let update = client.full_scan(request, STOP_GAP, BATCH_SIZE, false)?; - wallet.apply_update(update)?; - assert!(wallet.persist_async(&mut store).await?); - - // Recover the wallet - let wallet = Wallet::load().load_wallet_async(&mut store).await?.unwrap(); - assert_eq!(wallet.latest_checkpoint().height(), new_tip_height); - - let db = PgPool::connect(&url).await?; - drop_all(db).await.expect("hope its not mainnet"); + assert_eq!( + wallet_1.derivation_index(External), + wallet_2.derivation_index(External) + ); + assert_ne!( + wallet_1.derivation_index(Internal), + wallet_2.derivation_index(Internal), + "different wallets should not have same derivation index" + ); + assert_ne!( + wallet_1.latest_checkpoint(), + wallet_2.latest_checkpoint(), + "different wallets should not have same chain tip" + ); + } Ok(()) }