diff --git a/examples/limit.rs b/examples/limit.rs new file mode 100644 index 000000000..6aaa2921f --- /dev/null +++ b/examples/limit.rs @@ -0,0 +1,361 @@ +/// Example of how to limit blob requests by hash and node id, and how to add +/// throttling or a limit on the maximum number of connections. +/// +/// Limiting is done via a fn that returns an EventSender and internally +/// makes liberal use of spawn to run background tasks. +/// +/// This is fine, since the tasks will terminate as soon as the [BlobsProtocol] +/// instance holding the [EventSender] is dropped. But for production-grade +/// code you might nevertheless put the tasks into a [tokio::task::JoinSet] or +/// [n0_future::FuturesUnordered]. +mod common; +use std::{ + collections::{HashMap, HashSet}, + path::PathBuf, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + +use anyhow::Result; +use clap::Parser; +use common::setup_logging; +use iroh::{protocol::Router, NodeAddr, NodeId, SecretKey, Watcher}; +use iroh_blobs::{ + provider::events::{ + AbortReason, ConnectMode, EventMask, EventSender, ProviderMessage, RequestMode, + ThrottleMode, + }, + store::mem::MemStore, + ticket::BlobTicket, + BlobFormat, BlobsProtocol, Hash, +}; +use rand::thread_rng; + +use crate::common::get_or_generate_secret_key; + +#[derive(Debug, Parser)] +#[command(version, about)] +pub enum Args { + /// Limit requests by node id + ByNodeId { + /// Paths for files to add. + paths: Vec<PathBuf>, + #[clap(long("allow"))] + /// Nodes that are allowed to download content. + allowed_nodes: Vec<NodeId>, + /// Number of secrets to generate for allowed node ids. + #[clap(long, default_value_t = 1)] + secrets: usize, + }, + /// Limit requests by hash; only the first hash is allowed. + ByHash { + /// Paths for files to add. + paths: Vec<PathBuf>, + }, + /// Throttle requests + Throttle { + /// Paths for files to add. + paths: Vec<PathBuf>, + /// Delay in milliseconds after sending a chunk group of 16 KiB. + #[clap(long, default_value = "100")] + delay_ms: u64, + }, + /// Limit the maximum number of connections. + MaxConnections { + /// Paths for files to add. + paths: Vec<PathBuf>, + /// Maximum number of concurrent get requests. + #[clap(long, default_value = "1")] + max_connections: usize, + }, + /// Get a blob. Just for completeness' sake. + Get { + /// Ticket for the blob to download + ticket: BlobTicket, + }, +} + +fn limit_by_node_id(allowed_nodes: HashSet<NodeId>) -> EventSender { + let mask = EventMask { + // We want a request for each incoming connection so we can accept + // or reject it. We don't need any other events. + connected: ConnectMode::Intercept, + ..EventMask::DEFAULT + }; + let (tx, mut rx) = EventSender::channel(32, mask); + n0_future::task::spawn(async move { + while let Some(msg) = rx.recv().await { + if let ProviderMessage::ClientConnected(msg) = msg { + let node_id = msg.node_id; + let res = if allowed_nodes.contains(&node_id) { + println!("Client connected: {node_id}"); + Ok(()) + } else { + println!("Client rejected: {node_id}"); + Err(AbortReason::Permission) + }; + msg.tx.send(res).await.ok(); + } + } + }); + tx +} + +fn limit_by_hash(allowed_hashes: HashSet<Hash>) -> EventSender { + let mask = EventMask { + // We want to intercept each get request so we can answer with OK or + // not OK depending on the hash. We do not want detailed events once + // it has been decided to handle a request.
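+ // RequestMode::Intercept delivers each GetRequestReceived message with a + // reply channel; answering Err(AbortReason::Permission) rejects the + // request before any data is sent.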
+ get: RequestMode::Intercept, + ..EventMask::DEFAULT + }; + let (tx, mut rx) = EventSender::channel(32, mask); + n0_future::task::spawn(async move { + while let Some(msg) = rx.recv().await { + if let ProviderMessage::GetRequestReceived(msg) = msg { + let res = if !msg.request.ranges.is_blob() { + println!("HashSeq request not allowed"); + Err(AbortReason::Permission) + } else if !allowed_hashes.contains(&msg.request.hash) { + println!("Request for hash {} not allowed", msg.request.hash); + Err(AbortReason::Permission) + } else { + println!("Request for hash {} allowed", msg.request.hash); + Ok(()) + }; + msg.tx.send(res).await.ok(); + } + } + }); + tx +} + +fn throttle(delay_ms: u64) -> EventSender { + let mask = EventMask { + // We want to get a request for each sent user data blob, so we can add a delay. + // Other than that, we don't need any events. + throttle: ThrottleMode::Intercept, + ..EventMask::DEFAULT + }; + let (tx, mut rx) = EventSender::channel(32, mask); + n0_future::task::spawn(async move { + while let Some(msg) = rx.recv().await { + if let ProviderMessage::Throttle(msg) = msg { + n0_future::task::spawn(async move { + println!( + "Throttling {} {}, {}ms", + msg.connection_id, msg.request_id, delay_ms + ); + // We could compute the delay from the size of the data to achieve a + // fixed rate, but the size is almost always 16 KiB (16 chunks). + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + msg.tx.send(Ok(())).await.ok(); + }); + } + } + }); + tx +} + +fn limit_max_connections(max_connections: usize) -> EventSender { + #[derive(Default, Debug, Clone)] + struct ConnectionCounter(Arc<(AtomicUsize, usize)>); + + impl ConnectionCounter { + fn new(max: usize) -> Self { + Self(Arc::new((Default::default(), max))) + } + + fn inc(&self) -> Result<usize, usize> { + let (c, max) = &*self.0; + c.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| { + if n >= *max { + None + } else { + Some(n + 1) + } + }) + } + + fn dec(&self) { + let (c, _) = &*self.0; + c.fetch_sub(1, Ordering::SeqCst); + } + } + + let mask = EventMask { + // We want to intercept each incoming connection so we can accept or + // reject it based on the current connection count, and we want to be + // notified when a connection closes so we can decrement the counter.
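+ // ConnectMode::Intercept gives us a ClientConnected message with a reply + // channel for every new connection, plus ConnectionClosed notifications, + // which is all the state we need to maintain the counter.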
+ connected: ConnectMode::Intercept, + ..EventMask::DEFAULT + }; + let (tx, mut rx) = EventSender::channel(32, mask); + n0_future::task::spawn(async move { + let requests = ConnectionCounter::new(max_connections); + while let Some(msg) = rx.recv().await { + match msg { + ProviderMessage::ClientConnected(msg) => { + let connection_id = msg.connection_id; + let node_id = msg.node_id; + let res = if let Ok(n) = requests.inc() { + println!("Accepting connection {n}, node_id {node_id}, connection_id {connection_id}"); + Ok(()) + } else { + Err(AbortReason::RateLimited) + }; + msg.tx.send(res).await.ok(); + } + ProviderMessage::ConnectionClosed(msg) => { + requests.dec(); + println!("Connection closed, connection_id {}", msg.connection_id); + } + _ => {} + } + } + }); + tx +} + +#[tokio::main] +async fn main() -> Result<()> { + setup_logging(); + let args = Args::parse(); + let secret = get_or_generate_secret_key()?; + let endpoint = iroh::Endpoint::builder() + .secret_key(secret) + .discovery_n0() + .bind() + .await?; + match args { + Args::Get { ticket } => { + let connection = endpoint + .connect(ticket.node_addr().clone(), iroh_blobs::ALPN) + .await?; + let (data, stats) = iroh_blobs::get::request::get_blob(connection, ticket.hash()) + .bytes_and_stats() + .await?; + println!("Downloaded {} bytes", data.len()); + println!("Stats: {stats:?}"); + } + Args::ByNodeId { + paths, + allowed_nodes, + secrets, + } => { + let mut allowed_nodes = allowed_nodes.into_iter().collect::<HashSet<_>>(); + if secrets > 0 { + println!("Generating {secrets} new secret keys for allowed nodes:"); + let mut rand = thread_rng(); + for _ in 0..secrets { + let secret = SecretKey::generate(&mut rand); + let public = secret.public(); + allowed_nodes.insert(public); + println!("IROH_SECRET={}", hex::encode(secret.to_bytes())); + } + } + + let store = MemStore::new(); + let hashes = add_paths(&store, paths).await?; + let events = limit_by_node_id(allowed_nodes.clone()); + let (router, addr) = setup(store, events).await?; + + for (path, hash) in hashes { + let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); + println!("{}: {ticket}", path.display()); + } + println!(); + println!("Node id: {}\n", router.endpoint().node_id()); + for id in &allowed_nodes { + println!("Allowed node: {id}"); + } + + tokio::signal::ctrl_c().await?; + router.shutdown().await?; + } + Args::ByHash { paths } => { + let store = MemStore::new(); + + let mut hashes = HashMap::new(); + let mut allowed_hashes = HashSet::new(); + for (i, path) in paths.into_iter().enumerate() { + let tag = store.add_path(&path).await?; + hashes.insert(path, tag.hash); + if i == 0 { + allowed_hashes.insert(tag.hash); + } + } + + let events = limit_by_hash(allowed_hashes.clone()); + let (router, addr) = setup(store, events).await?; + + for (path, hash) in hashes.iter() { + let ticket = BlobTicket::new(addr.clone(), *hash, BlobFormat::Raw); + let permitted = if allowed_hashes.contains(hash) { + "allowed" + } else { + "forbidden" + }; + println!("{}: {ticket} ({permitted})", path.display()); + } + tokio::signal::ctrl_c().await?; + router.shutdown().await?; + } + Args::Throttle { paths, delay_ms } => { + let store = MemStore::new(); + let hashes = add_paths(&store, paths).await?; + let events = throttle(delay_ms); + let (router, addr) = setup(store, events).await?; + for (path, hash) in hashes { + let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); + println!("{}: {ticket}", path.display()); + } + tokio::signal::ctrl_c().await?; + router.shutdown().await?;
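+ // Dropping the router also drops the EventSender held by the + // BlobsProtocol instance, which terminates the background task + // spawned in throttle().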
+ } + Args::MaxConnections { + paths, + max_connections, + } => { + let store = MemStore::new(); + let hashes = add_paths(&store, paths).await?; + let events = limit_max_connections(max_connections); + let (router, addr) = setup(store, events).await?; + for (path, hash) in hashes { + let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); + println!("{}: {ticket}", path.display()); + } + tokio::signal::ctrl_c().await?; + router.shutdown().await?; + } + } + Ok(()) +} + +async fn add_paths(store: &MemStore, paths: Vec<PathBuf>) -> Result<HashMap<PathBuf, Hash>> { + let mut hashes = HashMap::new(); + for path in paths { + let tag = store.add_path(&path).await?; + hashes.insert(path, tag.hash); + } + Ok(hashes) +} + +async fn setup(store: MemStore, events: EventSender) -> Result<(Router, NodeAddr)> { + let secret = get_or_generate_secret_key()?; + let endpoint = iroh::Endpoint::builder() + .discovery_n0() + .secret_key(secret) + .bind() + .await?; + let _ = endpoint.home_relay().initialized().await; + let addr = endpoint.node_addr().initialized().await; + let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); + let router = Router::builder(endpoint) + .accept(iroh_blobs::ALPN, blobs) + .spawn(); + Ok((router, addr)) +} diff --git a/examples/random_store.rs b/examples/random_store.rs index ffdd9b826..d3f9a0fc4 100644 --- a/examples/random_store.rs +++ b/examples/random_store.rs @@ -6,14 +6,15 @@ use iroh::{SecretKey, Watcher}; use iroh_base::ticket::NodeTicket; use iroh_blobs::{ api::downloader::Shuffled, - provider::Event, + provider::events::{AbortReason, EventMask, EventSender, ProviderMessage}, store::fs::FsStore, test::{add_hash_sequences, create_random_blobs}, HashAndFormat, }; +use irpc::RpcMessage; use n0_future::StreamExt; use rand::{rngs::StdRng, Rng, SeedableRng}; -use tokio::{signal::ctrl_c, sync::mpsc}; +use tokio::signal::ctrl_c; use tracing::info; #[derive(Parser, Debug)] @@ -100,77 +101,77 @@ pub fn get_or_generate_secret_key() -> Result<SecretKey> { } } -pub fn dump_provider_events( - allow_push: bool, -) -> ( - tokio::task::JoinHandle<()>, - mpsc::Sender<Event>, -) { - let (tx, mut rx) = mpsc::channel(100); +pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, EventSender) { + let (tx, mut rx) = EventSender::channel(100, EventMask::ALL_READONLY); + fn dump_updates<T: RpcMessage>(mut rx: irpc::channel::mpsc::Receiver<T>) { + tokio::spawn(async move { + while let Ok(Some(update)) = rx.recv().await { + println!("{update:?}"); + } + }); + } let dump_task = tokio::spawn(async move { while let Some(event) = rx.recv().await { match event { - Event::ClientConnected { - node_id, - connection_id, - permitted, - } => { - permitted.send(true).await.ok(); - println!("Client connected: {node_id} {connection_id}"); + ProviderMessage::ClientConnected(msg) => { + println!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + } + ProviderMessage::ClientConnectedNotify(msg) => { + println!("{:?}", msg.inner); + } + ProviderMessage::ConnectionClosed(msg) => { + println!("{:?}", msg.inner); } - Event::GetRequestReceived { - connection_id, - request_id, - hash, - ranges, - } => { - println!( - "Get request received: {connection_id} {request_id} {hash} {ranges:?}" - ); + ProviderMessage::GetRequestReceived(msg) => { + println!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + dump_updates(msg.rx); } - Event::TransferCompleted { - connection_id, - request_id, - stats, - } => { - println!("Transfer completed: {connection_id} {request_id} {stats:?}"); + ProviderMessage::GetRequestReceivedNotify(msg) => {
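+ // Notify variants carry no reply channel, so we can observe the + // request but not reject it; just log it and its progress updates.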
println!("{:?}", msg.inner); + dump_updates(msg.rx); } - Event::TransferAborted { - connection_id, - request_id, - stats, - } => { - println!("Transfer aborted: {connection_id} {request_id} {stats:?}"); + ProviderMessage::GetManyRequestReceived(msg) => { + println!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + dump_updates(msg.rx); } - Event::TransferProgress { - connection_id, - request_id, - index, - end_offset, - } => { - info!("Transfer progress: {connection_id} {request_id} {index} {end_offset}"); + ProviderMessage::GetManyRequestReceivedNotify(msg) => { + println!("{:?}", msg.inner); + dump_updates(msg.rx); } - Event::PushRequestReceived { - connection_id, - request_id, - hash, - ranges, - permitted, - } => { - if allow_push { - permitted.send(true).await.ok(); - println!( - "Push request received: {connection_id} {request_id} {hash} {ranges:?}" - ); + ProviderMessage::PushRequestReceived(msg) => { + println!("{:?}", msg.inner); + let res = if allow_push { + Ok(()) } else { - permitted.send(false).await.ok(); - println!( - "Push request denied: {connection_id} {request_id} {hash} {ranges:?}" - ); - } + Err(AbortReason::Permission) + }; + msg.tx.send(res).await.ok(); + dump_updates(msg.rx); + } + ProviderMessage::PushRequestReceivedNotify(msg) => { + println!("{:?}", msg.inner); + dump_updates(msg.rx); + } + ProviderMessage::ObserveRequestReceived(msg) => { + println!("{:?}", msg.inner); + let res = if allow_push { + Ok(()) + } else { + Err(AbortReason::Permission) + }; + msg.tx.send(res).await.ok(); + dump_updates(msg.rx); + } + ProviderMessage::ObserveRequestReceivedNotify(msg) => { + println!("{:?}", msg.inner); + dump_updates(msg.rx); } - _ => { - info!("Received event: {:?}", event); + ProviderMessage::Throttle(msg) => { + println!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); } } } diff --git a/proptest-regressions/store/fs/util/entity_manager.txt b/proptest-regressions/store/fs/util/entity_manager.txt new file mode 100644 index 000000000..94b6aa63c --- /dev/null +++ b/proptest-regressions/store/fs/util/entity_manager.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. 
+cc 0f2ebc49ab2f84e112f08407bb94654fbcb1f19050a4a8a6196383557696438a # shrinks to input = _TestCountersManagerProptestFsArgs { entries: [(15313427648878534792, 264348813928009031854006459208395772047), (1642534478798447378, 15989109311941500072752977306696275871), (8755041673862065815, 172763711808688570294350362332402629716), (4993597758667891804, 114145440157220458287429360639759690928), (15031383154962489250, 63217081714858286463391060323168548783), (17668469631267503333, 11878544422669770587175118199598836678), (10507570291819955314, 126584081645379643144412921692654648228), (3979008599365278329, 283717221942996985486273080647433218905), (8316838360288996639, 334043288511621783152802090833905919408), (15673798930962474157, 77551315511802713260542200115027244708), (12058791254144360414, 56638044274259821850511200885092637649), (8191628769638031337, 314181956273420400069887649110740549194), (6290369460137232066, 255779791286732775990301011955519176773), (11919824746661852269, 319400891587146831511371932480749645441), (12491631698789073154, 271279849791970841069522263758329847554), (53891048909263304, 12061234604041487609497959407391945555), (9486366498650667097, 311383186592430597410801882015456718030), (15696332331789302593, 306911490707714340526403119780178604150), (8699088947997536151, 312272624973367009520183311568498652066), (1144772544750976199, 200591877747619565555594857038887015), (5907208586200645081, 299942008952473970881666769409865744975), (3384528743842518913, 26230956866762934113564101494944411446), (13877357832690956494, 229457597607752760006918374695475345151), (2965687966026226090, 306489188264741716662410004273408761623), (13624286905717143613, 232801392956394366686194314010536008033), (3622356130274722018, 162030840677521022192355139208505458492), (17807768575470996347, 264107246314713159406963697924105744409), (5103434150074147746, 331686166459964582006209321975587627262), (5962771466034321974, 300961804728115777587520888809168362574), (2930645694242691907, 127752709774252686733969795258447263979), (16197574560597474644, 245410120683069493317132088266217906749), (12478835478062365617, 103838791113879912161511798836229961653), (5503595333662805357, 92368472243854403026472376408708548349), (18122734335129614364, 288955542597300001147753560885976966029), (12688080215989274550, 85237436689682348751672119832134138932), (4148468277722853958, 297778117327421209654837771300216669574), (8749445804640085302, 79595866493078234154562014325793780126), (12442730869682574563, 196176786402808588883611974143577417817), (6110644747049355904, 26592587989877021920275416199052685135), (5851164380497779369, 158876888501825038083692899057819261957), (9497384378514985275, 15279835675313542048650599472403150097), (10661092311826161857, 250089949043892591422587928179995867509), (10046856000675345423, 231369150063141386398059701278066296663)] } diff --git a/src/api.rs b/src/api.rs index a2a34a2db..117c59e25 100644 --- a/src/api.rs +++ b/src/api.rs @@ -30,7 +30,7 @@ pub mod downloader; pub mod proto; pub mod remote; pub mod tags; -use crate::api::proto::WaitIdleRequest; +use crate::{api::proto::WaitIdleRequest, provider::events::ProgressError}; pub use crate::{store::util::Tag, util::temp_tag::TempTag}; pub(crate) type ApiClient = irpc::Client; @@ -97,6 +97,8 @@ pub enum ExportBaoError { ExportBaoIo { source: io::Error }, #[snafu(display("encode error: {source}"))] ExportBaoInner { source: bao_tree::io::EncodeError }, + #[snafu(display("client error: {source}"))] + ClientError { source: 
ProgressError }, } impl From for Error { @@ -107,6 +109,7 @@ impl From for Error { ExportBaoError::Request { source, .. } => Self::Io(source.into()), ExportBaoError::ExportBaoIo { source, .. } => Self::Io(source), ExportBaoError::ExportBaoInner { source, .. } => Self::Io(source.into()), + ExportBaoError::ClientError { source, .. } => Self::Io(source.into()), } } } @@ -152,6 +155,12 @@ impl From for ExportBaoError { } } +impl From for ExportBaoError { + fn from(value: ProgressError) -> Self { + ClientSnafu.into_error(value) + } +} + pub type ExportBaoResult = std::result::Result; #[derive(Debug, derive_more::Display, derive_more::From, Serialize, Deserialize)] diff --git a/src/api/blobs.rs b/src/api/blobs.rs index cbd27bbac..897e0371c 100644 --- a/src/api/blobs.rs +++ b/src/api/blobs.rs @@ -57,7 +57,7 @@ use super::{ }; use crate::{ api::proto::{BatchRequest, ImportByteStreamUpdate}, - provider::StreamContext, + provider::events::ClientResult, store::IROH_BLOCK_SIZE, util::temp_tag::TempTag, BlobFormat, Hash, HashAndFormat, @@ -1116,7 +1116,9 @@ impl ExportBaoProgress { .write_chunk(leaf.data) .await .map_err(io::Error::from)?; - progress.notify_payload_write(index, leaf.offset, len).await; + progress + .notify_payload_write(index, leaf.offset, len) + .await?; } EncodedItem::Done => break, EncodedItem::Error(cause) => return Err(cause.into()), @@ -1162,7 +1164,7 @@ impl ExportBaoProgress { pub(crate) trait WriteProgress { /// Notify the progress writer that a payload write has happened. - async fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize); + async fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize) -> ClientResult; /// Log a write of some other data. fn log_other_write(&mut self, len: usize); @@ -1170,17 +1172,3 @@ pub(crate) trait WriteProgress { /// Notify the progress writer that a transfer has started. 
async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64); } - -impl WriteProgress for StreamContext { - async fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize) { - StreamContext::notify_payload_write(self, index, offset, len); - } - - fn log_other_write(&mut self, len: usize) { - StreamContext::log_other_write(self, len); - } - - async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) { - StreamContext::send_transfer_started(self, index, hash, size).await - } -} diff --git a/src/api/downloader.rs b/src/api/downloader.rs index ffdfd2782..50db0fc2f 100644 --- a/src/api/downloader.rs +++ b/src/api/downloader.rs @@ -3,7 +3,6 @@ use std::{ collections::{HashMap, HashSet}, fmt::Debug, future::{Future, IntoFuture}, - io, sync::Arc, }; @@ -113,7 +112,7 @@ async fn handle_download_impl( SplitStrategy::Split => handle_download_split_impl(store, pool, request, tx).await?, SplitStrategy::None => match request.request { FiniteRequest::Get(get) => { - let sink = IrpcSenderRefSink(tx).with_map_err(io::Error::other); + let sink = IrpcSenderRefSink(tx); execute_get(&pool, Arc::new(get), &request.providers, &store, sink).await?; } FiniteRequest::GetMany(_) => { @@ -143,9 +142,7 @@ async fn handle_download_split_impl( let hash = request.hash; let (tx, rx) = tokio::sync::mpsc::channel::<(usize, DownloadProgessItem)>(16); progress_tx.send(rx).await.ok(); - let sink = TokioMpscSenderSink(tx) - .with_map_err(io::Error::other) - .with_map(move |x| (id, x)); + let sink = TokioMpscSenderSink(tx).with_map(move |x| (id, x)); let res = execute_get(&pool, Arc::new(request), &providers, &store, sink).await; (hash, res) } @@ -375,7 +372,7 @@ async fn split_request<'a>( providers: &Arc, pool: &ConnectionPool, store: &Store, - progress: impl Sink, + progress: impl Sink, ) -> anyhow::Result + Send + 'a>> { Ok(match request { FiniteRequest::Get(req) => { @@ -431,7 +428,7 @@ async fn execute_get( request: Arc, providers: &Arc, store: &Store, - mut progress: impl Sink, + mut progress: impl Sink, ) -> anyhow::Result<()> { let remote = store.remote(); let mut providers = providers.find_providers(request.content()); diff --git a/src/api/remote.rs b/src/api/remote.rs index 5eb64c24b..dcfbc4fb4 100644 --- a/src/api/remote.rs +++ b/src/api/remote.rs @@ -18,6 +18,7 @@ use crate::{ GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType, MAX_MESSAGE_SIZE, }, + provider::events::{ClientResult, ProgressError}, util::sink::{Sink, TokioMpscSenderSink}, }; @@ -478,9 +479,7 @@ impl Remote { let content = content.into(); let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); - let sink = TokioMpscSenderSink(tx) - .with_map(GetProgressItem::Progress) - .with_map_err(io::Error::other); + let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress); let this = self.clone(); let fut = async move { let res = this.fetch_sink(conn, content, sink).await.into(); @@ -503,7 +502,7 @@ impl Remote { &self, mut conn: impl GetConnection, content: impl Into, - progress: impl Sink, + progress: impl Sink, ) -> GetResult { let content = content.into(); let local = self @@ -556,9 +555,7 @@ impl Remote { pub fn execute_push(&self, conn: Connection, request: PushRequest) -> PushProgress { let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); - let sink = TokioMpscSenderSink(tx) - .with_map(PushProgressItem::Progress) - .with_map_err(io::Error::other); + let sink = TokioMpscSenderSink(tx).with_map(PushProgressItem::Progress); let this = 
self.clone(); let fut = async move { let res = this.execute_push_sink(conn, request, sink).await.into(); @@ -577,7 +574,7 @@ impl Remote { &self, conn: Connection, request: PushRequest, - progress: impl Sink, + progress: impl Sink, ) -> anyhow::Result { let hash = request.hash; debug!(%hash, "pushing"); @@ -632,9 +629,7 @@ impl Remote { pub fn execute_get_with_opts(&self, conn: Connection, request: GetRequest) -> GetProgress { let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); - let sink = TokioMpscSenderSink(tx) - .with_map(GetProgressItem::Progress) - .with_map_err(io::Error::other); + let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress); let this = self.clone(); let fut = async move { let res = this.execute_get_sink(&conn, request, sink).await.into(); @@ -658,7 +653,7 @@ impl Remote { &self, conn: &Connection, request: GetRequest, - mut progress: impl Sink, + mut progress: impl Sink, ) -> GetResult { let store = self.store(); let root = request.hash; @@ -721,9 +716,7 @@ impl Remote { pub fn execute_get_many(&self, conn: Connection, request: GetManyRequest) -> GetProgress { let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); - let sink = TokioMpscSenderSink(tx) - .with_map(GetProgressItem::Progress) - .with_map_err(io::Error::other); + let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress); let this = self.clone(); let fut = async move { let res = this.execute_get_many_sink(conn, request, sink).await.into(); @@ -747,7 +740,7 @@ impl Remote { &self, conn: Connection, request: GetManyRequest, - mut progress: impl Sink, + mut progress: impl Sink, ) -> GetResult { let store = self.store(); let hash_seq = request.hashes.iter().copied().collect::(); @@ -884,7 +877,7 @@ async fn get_blob_ranges_impl( header: AtBlobHeader, hash: Hash, store: &Store, - mut progress: impl Sink, + mut progress: impl Sink, ) -> GetResult { let (mut content, size) = header.next().await?; let Some(size) = NonZeroU64::new(size) else { @@ -1048,11 +1041,20 @@ struct StreamContext { impl WriteProgress for StreamContext where - S: Sink, + S: Sink, { - async fn notify_payload_write(&mut self, _index: u64, _offset: u64, len: usize) { + async fn notify_payload_write( + &mut self, + _index: u64, + _offset: u64, + len: usize, + ) -> ClientResult { self.payload_bytes_sent += len as u64; - self.sender.send(self.payload_bytes_sent).await.ok(); + self.sender + .send(self.payload_bytes_sent) + .await + .map_err(|e| ProgressError::Internal { source: e.into() })?; + Ok(()) } fn log_other_write(&mut self, _len: usize) {} diff --git a/src/net_protocol.rs b/src/net_protocol.rs index 3e7d9582e..47cda5344 100644 --- a/src/net_protocol.rs +++ b/src/net_protocol.rs @@ -36,22 +36,16 @@ //! # } //! 
``` -use std::{fmt::Debug, future::Future, ops::Deref, sync::Arc}; +use std::{fmt::Debug, ops::Deref, sync::Arc}; use iroh::{ endpoint::Connection, protocol::{AcceptError, ProtocolHandler}, Endpoint, Watcher, }; -use tokio::sync::mpsc; use tracing::error; -use crate::{ - api::Store, - provider::{Event, EventSender}, - ticket::BlobTicket, - HashAndFormat, -}; +use crate::{api::Store, provider::events::EventSender, ticket::BlobTicket, HashAndFormat}; #[derive(Debug)] pub(crate) struct BlobsInner { @@ -75,12 +69,12 @@ impl Deref for BlobsProtocol { } impl BlobsProtocol { - pub fn new(store: &Store, endpoint: Endpoint, events: Option>) -> Self { + pub fn new(store: &Store, endpoint: Endpoint, events: Option) -> Self { Self { inner: Arc::new(BlobsInner { store: store.clone(), endpoint, - events: EventSender::new(events), + events: events.unwrap_or(EventSender::DEFAULT), }), } } @@ -106,25 +100,16 @@ impl BlobsProtocol { } impl ProtocolHandler for BlobsProtocol { - fn accept( - &self, - conn: Connection, - ) -> impl Future> + Send { + async fn accept(&self, conn: Connection) -> std::result::Result<(), AcceptError> { let store = self.store().clone(); let events = self.inner.events.clone(); - - Box::pin(async move { - crate::provider::handle_connection(conn, store, events).await; - Ok(()) - }) + crate::provider::handle_connection(conn, store, events).await; + Ok(()) } - fn shutdown(&self) -> impl Future + Send { - let store = self.store().clone(); - Box::pin(async move { - if let Err(cause) = store.shutdown().await { - error!("error shutting down store: {:?}", cause); - } - }) + async fn shutdown(&self) { + if let Err(cause) = self.store().shutdown().await { + error!("error shutting down store: {:?}", cause); + } } } diff --git a/src/protocol.rs b/src/protocol.rs index 74e0f986d..ce10865a5 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -392,11 +392,18 @@ pub use range_spec::{ChunkRangesSeq, NonEmptyRequestRangeSpecIter, RangeSpec}; use snafu::{GenerateImplicitData, Snafu}; use tokio::io::AsyncReadExt; -use crate::{api::blobs::Bitfield, provider::CountingReader, BlobFormat, Hash, HashAndFormat}; +use crate::{api::blobs::Bitfield, provider::RecvStreamExt, BlobFormat, Hash, HashAndFormat}; /// Maximum message size is limited to 100MiB for now. pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024; +/// Error code for a permission error +pub const ERR_PERMISSION: VarInt = VarInt::from_u32(1u32); +/// Error code for when a request is aborted due to a rate limit +pub const ERR_LIMIT: VarInt = VarInt::from_u32(2u32); +/// Error code for when a request is aborted due to internal error +pub const ERR_INTERNAL: VarInt = VarInt::from_u32(3u32); + /// The ALPN used with quic for the iroh blobs protocol. pub const ALPN: &[u8] = b"/iroh-bytes/4"; @@ -441,9 +448,7 @@ pub enum RequestType { } impl Request { - pub async fn read_async( - reader: &mut CountingReader<&mut iroh::endpoint::RecvStream>, - ) -> io::Result { + pub async fn read_async(reader: &mut iroh::endpoint::RecvStream) -> io::Result<(Self, usize)> { let request_type = reader.read_u8().await?; let request_type: RequestType = postcard::from_bytes(std::slice::from_ref(&request_type)) .map_err(|_| { @@ -453,22 +458,31 @@ impl Request { ) })?; Ok(match request_type { - RequestType::Get => reader - .read_to_end_as::(MAX_MESSAGE_SIZE) - .await? - .into(), - RequestType::GetMany => reader - .read_to_end_as::(MAX_MESSAGE_SIZE) - .await? - .into(), - RequestType::Observe => reader - .read_to_end_as::(MAX_MESSAGE_SIZE) - .await? 
- .into(), - RequestType::Push => reader - .read_length_prefixed::(MAX_MESSAGE_SIZE) - .await? - .into(), + RequestType::Get => { + let (r, size) = reader + .read_to_end_as::(MAX_MESSAGE_SIZE) + .await?; + (r.into(), size) + } + RequestType::GetMany => { + let (r, size) = reader + .read_to_end_as::(MAX_MESSAGE_SIZE) + .await?; + (r.into(), size) + } + RequestType::Observe => { + let (r, size) = reader + .read_to_end_as::(MAX_MESSAGE_SIZE) + .await?; + (r.into(), size) + } + RequestType::Push => { + let r = reader + .read_length_prefixed::(MAX_MESSAGE_SIZE) + .await?; + let size = postcard::experimental::serialized_size(&r).unwrap(); + (r.into(), size) + } _ => { return Err(io::Error::new( io::ErrorKind::InvalidData, diff --git a/src/provider.rs b/src/provider.rs index 61af8f6e1..0134169c6 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -6,130 +6,33 @@ use std::{ fmt::Debug, io, - ops::{Deref, DerefMut}, - pin::Pin, - task::Poll, - time::Duration, + time::{Duration, Instant}, }; use anyhow::{Context, Result}; use bao_tree::ChunkRanges; -use iroh::{ - endpoint::{self, RecvStream, SendStream}, - NodeId, -}; -use irpc::channel::oneshot; +use iroh::endpoint::{self, RecvStream, SendStream}; use n0_future::StreamExt; -use serde::de::DeserializeOwned; -use tokio::{io::AsyncRead, select, sync::mpsc}; -use tracing::{debug, debug_span, error, warn, Instrument}; +use quinn::{ClosedStream, ConnectionError, ReadToEndError}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use tokio::select; +use tracing::{debug, debug_span, warn, Instrument}; use crate::{ - api::{self, blobs::Bitfield, Store}, - hashseq::HashSeq, - protocol::{ - ChunkRangesSeq, GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, - Request, + api::{ + blobs::{Bitfield, WriteProgress}, + ExportBaoResult, Store, }, + hashseq::HashSeq, + protocol::{GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request}, + provider::events::{ClientConnected, ClientResult, ConnectionClosed, RequestTracker}, Hash, }; - -/// Provider progress events, to keep track of what the provider is doing. -/// -/// ClientConnected -> -/// (GetRequestReceived -> (TransferStarted -> TransferProgress*n)*n -> (TransferCompleted | TransferAborted))*n -> -/// ConnectionClosed -#[derive(Debug)] -pub enum Event { - /// A new client connected to the provider. - ClientConnected { - connection_id: u64, - node_id: NodeId, - permitted: oneshot::Sender, - }, - /// Connection closed. - ConnectionClosed { connection_id: u64 }, - /// A new get request was received from the provider. - GetRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The root hash of the request. - hash: Hash, - /// The exact query ranges of the request. - ranges: ChunkRangesSeq, - }, - /// A new get request was received from the provider. - GetManyRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The root hash of the request. - hashes: Vec, - /// The exact query ranges of the request. - ranges: ChunkRangesSeq, - }, - /// A new get request was received from the provider. - PushRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. 
- request_id: u64, - /// The root hash of the request. - hash: Hash, - /// The exact query ranges of the request. - ranges: ChunkRangesSeq, - /// Complete this to permit the request. - permitted: oneshot::Sender, - }, - /// Transfer for the nth blob started. - TransferStarted { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The index of the blob in the request. 0 for the first blob or for raw blob requests. - index: u64, - /// The hash of the blob. This is the hash of the request for the first blob, the child hash (index-1) for subsequent blobs. - hash: Hash, - /// The size of the blob. This is the full size of the blob, not the size we are sending. - size: u64, - }, - /// Progress of the transfer. - TransferProgress { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The index of the blob in the request. 0 for the first blob or for raw blob requests. - index: u64, - /// The end offset of the chunk that was sent. - end_offset: u64, - }, - /// Entire transfer completed. - TransferCompleted { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// Statistics about the transfer. - stats: Box, - }, - /// Entire transfer aborted - TransferAborted { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// Statistics about the part of the transfer that was aborted. - stats: Option>, - }, -} +pub mod events; +use events::EventSender; /// Statistics about a successful or failed transfer. -#[derive(Debug)] +#[derive(Debug, Serialize, Deserialize)] pub struct TransferStats { /// The number of bytes sent that are part of the payload. pub payload_bytes_sent: u64, @@ -139,191 +42,271 @@ pub struct TransferStats { pub other_bytes_sent: u64, /// The number of bytes read from the stream. /// - /// This is the size of the request. - pub bytes_read: u64, + /// In most cases this is just the request, for push requests this is + /// request, size header and hash pairs. + pub other_bytes_read: u64, /// Total duration from reading the request to transfer completed. pub duration: Duration, } -/// Read the request from the getter. -/// -/// Will fail if there is an error while reading, or if no valid request is sent. -/// -/// This will read exactly the number of bytes needed for the request, and -/// leave the rest of the stream for the caller to read. -/// -/// It is up to the caller do decide if there should be more data. 
-pub async fn read_request(reader: &mut ProgressReader) -> Result { - let mut counting = CountingReader::new(&mut reader.inner); - let res = Request::read_async(&mut counting).await?; - reader.bytes_read += counting.read(); - Ok(res) -} - -#[derive(Debug)] -pub struct StreamContext { - /// The connection ID from the connection - pub connection_id: u64, - /// The request ID from the recv stream - pub request_id: u64, - /// The number of bytes written that are part of the payload - pub payload_bytes_sent: u64, - /// The number of bytes written that are not part of the payload - pub other_bytes_sent: u64, - /// The number of bytes read from the stream - pub bytes_read: u64, - /// The progress sender to send events to - pub progress: EventSender, -} - -/// Wrapper for a [`quinn::SendStream`] with additional per request information. +/// A pair of [`SendStream`] and [`RecvStream`] with additional context data. #[derive(Debug)] -pub struct ProgressWriter { - /// The quinn::SendStream to write to - pub inner: SendStream, - pub(crate) context: StreamContext, +pub struct StreamPair { + t0: Instant, + connection_id: u64, + request_id: u64, + reader: RecvStream, + writer: SendStream, + other_bytes_read: u64, + events: EventSender, } -impl Deref for ProgressWriter { - type Target = StreamContext; +impl StreamPair { + pub async fn accept( + conn: &endpoint::Connection, + events: &EventSender, + ) -> Result { + let (writer, reader) = conn.accept_bi().await?; + Ok(Self { + t0: Instant::now(), + connection_id: conn.stable_id() as u64, + request_id: reader.id().into(), + reader, + writer, + other_bytes_read: 0, + events: events.clone(), + }) + } - fn deref(&self) -> &Self::Target { - &self.context + /// Read the request. + /// + /// Will fail if there is an error while reading, or if no valid request is sent. + /// + /// This will read exactly the number of bytes needed for the request, and + /// leave the rest of the stream for the caller to read. + /// + /// It is up to the caller do decide if there should be more data. + pub async fn read_request(&mut self) -> Result { + let (res, size) = Request::read_async(&mut self.reader).await?; + self.other_bytes_read += size as u64; + Ok(res) } -} -impl DerefMut for ProgressWriter { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.context + /// We are done with reading. Return a ProgressWriter that contains the read stats and connection id + async fn into_writer( + mut self, + tracker: RequestTracker, + ) -> Result { + let res = self.reader.read_to_end(0).await; + if let Err(e) = res { + tracker + .transfer_aborted(|| Box::new(self.stats())) + .await + .ok(); + return Err(e); + }; + Ok(ProgressWriter::new( + self.writer, + WriterContext { + t0: self.t0, + other_bytes_read: self.other_bytes_read, + payload_bytes_written: 0, + other_bytes_written: 0, + tracker, + }, + )) } -} -impl StreamContext { - /// Increase the write count due to a non-payload write. 
- pub fn log_other_write(&mut self, len: usize) { - self.other_bytes_sent += len as u64; + async fn into_reader( + mut self, + tracker: RequestTracker, + ) -> Result { + let res = self.writer.finish(); + if let Err(e) = res { + tracker + .transfer_aborted(|| Box::new(self.stats())) + .await + .ok(); + return Err(e); + }; + Ok(ProgressReader { + inner: self.reader, + context: ReaderContext { + t0: self.t0, + other_bytes_read: self.other_bytes_read, + tracker, + }, + }) } - pub async fn send_transfer_completed(&mut self) { - self.progress - .send(|| Event::TransferCompleted { - connection_id: self.connection_id, - request_id: self.request_id, - stats: Box::new(TransferStats { - payload_bytes_sent: self.payload_bytes_sent, - other_bytes_sent: self.other_bytes_sent, - bytes_read: self.bytes_read, - duration: Duration::ZERO, - }), - }) + pub async fn get_request( + mut self, + f: impl FnOnce() -> GetRequest, + ) -> anyhow::Result { + let res = self + .events + .request(f, self.connection_id, self.request_id) .await; + match res { + Err(e) => { + self.writer.reset(e.code()).ok(); + Err(e.into()) + } + Ok(tracker) => Ok(self.into_writer(tracker).await?), + } } - pub async fn send_transfer_aborted(&mut self) { - self.progress - .send(|| Event::TransferAborted { - connection_id: self.connection_id, - request_id: self.request_id, - stats: Some(Box::new(TransferStats { - payload_bytes_sent: self.payload_bytes_sent, - other_bytes_sent: self.other_bytes_sent, - bytes_read: self.bytes_read, - duration: Duration::ZERO, - })), - }) + pub async fn get_many_request( + mut self, + f: impl FnOnce() -> GetManyRequest, + ) -> anyhow::Result { + let res = self + .events + .request(f, self.connection_id, self.request_id) .await; + match res { + Err(e) => { + self.writer.reset(e.code()).ok(); + Err(e.into()) + } + Ok(tracker) => Ok(self.into_writer(tracker).await?), + } } - /// Increase the write count due to a payload write, and notify the progress sender. - /// - /// `index` is the index of the blob in the request. - /// `offset` is the offset in the blob where the write started. - /// `len` is the length of the write. - pub fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize) { - self.payload_bytes_sent += len as u64; - self.progress.try_send(|| Event::TransferProgress { - connection_id: self.connection_id, - request_id: self.request_id, - index, - end_offset: offset + len as u64, - }); + pub async fn push_request( + mut self, + f: impl FnOnce() -> PushRequest, + ) -> anyhow::Result { + let res = self + .events + .request(f, self.connection_id, self.request_id) + .await; + match res { + Err(e) => { + self.writer.reset(e.code()).ok(); + Err(e.into()) + } + Ok(tracker) => Ok(self.into_reader(tracker).await?), + } } - /// Send a get request received event. - /// - /// This sends all the required information to make sense of subsequent events such as - /// [`Event::TransferStarted`] and [`Event::TransferProgress`]. 
- pub async fn send_get_request_received(&self, hash: &Hash, ranges: &ChunkRangesSeq) { - self.progress - .send(|| Event::GetRequestReceived { - connection_id: self.connection_id, - request_id: self.request_id, - hash: *hash, - ranges: ranges.clone(), - }) + pub async fn observe_request( + mut self, + f: impl FnOnce() -> ObserveRequest, + ) -> anyhow::Result { + let res = self + .events + .request(f, self.connection_id, self.request_id) .await; + match res { + Err(e) => { + self.writer.reset(e.code()).ok(); + Err(e.into()) + } + Ok(tracker) => Ok(self.into_writer(tracker).await?), + } } - /// Send a get request received event. - /// - /// This sends all the required information to make sense of subsequent events such as - /// [`Event::TransferStarted`] and [`Event::TransferProgress`]. - pub async fn send_get_many_request_received(&self, hashes: &[Hash], ranges: &ChunkRangesSeq) { - self.progress - .send(|| Event::GetManyRequestReceived { - connection_id: self.connection_id, - request_id: self.request_id, - hashes: hashes.to_vec(), - ranges: ranges.clone(), - }) - .await; + fn stats(&self) -> TransferStats { + TransferStats { + payload_bytes_sent: 0, + other_bytes_sent: 0, + other_bytes_read: self.other_bytes_read, + duration: self.t0.elapsed(), + } } +} - /// Authorize a push request. - /// - /// This will send a request to the event sender, and wait for a response if a - /// progress sender is enabled. If not, it will always fail. - /// - /// We want to make accepting push requests very explicit, since this allows - /// remote nodes to add arbitrary data to our store. - #[must_use = "permit should be checked by the caller"] - pub async fn authorize_push_request(&self, hash: &Hash, ranges: &ChunkRangesSeq) -> bool { - let mut wait_for_permit = None; - // send the request, including the permit channel - self.progress - .send(|| { - let (tx, rx) = oneshot::channel(); - wait_for_permit = Some(rx); - Event::PushRequestReceived { - connection_id: self.connection_id, - request_id: self.request_id, - hash: *hash, - ranges: ranges.clone(), - permitted: tx, - } - }) - .await; - // wait for the permit, if necessary - if let Some(wait_for_permit) = wait_for_permit { - // if somebody does not handle the request, they will drop the channel, - // and this will fail immediately. - wait_for_permit.await.unwrap_or(false) - } else { - false +#[derive(Debug)] +struct ReaderContext { + /// The start time of the transfer + t0: Instant, + /// The number of bytes read from the stream + other_bytes_read: u64, + /// Progress tracking for the request + tracker: RequestTracker, +} + +impl ReaderContext { + fn stats(&self) -> TransferStats { + TransferStats { + payload_bytes_sent: 0, + other_bytes_sent: 0, + other_bytes_read: self.other_bytes_read, + duration: self.t0.elapsed(), } } +} - /// Send a transfer started event. 
- pub async fn send_transfer_started(&self, index: u64, hash: &Hash, size: u64) { - self.progress - .send(|| Event::TransferStarted { - connection_id: self.connection_id, - request_id: self.request_id, - index, - hash: *hash, - size, - }) - .await; +#[derive(Debug)] +pub(crate) struct WriterContext { + /// The start time of the transfer + t0: Instant, + /// The number of bytes read from the stream + other_bytes_read: u64, + /// The number of payload bytes written to the stream + payload_bytes_written: u64, + /// The number of bytes written that are not part of the payload + other_bytes_written: u64, + /// Way to report progress + tracker: RequestTracker, +} + +impl WriterContext { + fn stats(&self) -> TransferStats { + TransferStats { + payload_bytes_sent: self.payload_bytes_written, + other_bytes_sent: self.other_bytes_written, + other_bytes_read: self.other_bytes_read, + duration: self.t0.elapsed(), + } + } +} + +impl WriteProgress for WriterContext { + async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult { + let len = len as u64; + let end_offset = offset + len; + self.payload_bytes_written += len; + self.tracker.transfer_progress(len, end_offset).await + } + + fn log_other_write(&mut self, len: usize) { + self.other_bytes_written += len as u64; + } + + async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) { + self.tracker.transfer_started(index, hash, size).await.ok(); + } +} + +/// Wrapper for a [`quinn::SendStream`] with additional per request information. +#[derive(Debug)] +pub struct ProgressWriter { + /// The quinn::SendStream to write to + pub inner: SendStream, + pub(crate) context: WriterContext, +} + +impl ProgressWriter { + fn new(inner: SendStream, context: WriterContext) -> Self { + Self { inner, context } + } + + async fn transfer_aborted(&self) { + self.context + .tracker + .transfer_aborted(|| Box::new(self.context.stats())) + .await + .ok(); + } + + async fn transfer_completed(&self) { + self.context + .tracker + .transfer_completed(|| Box::new(self.context.stats())) + .await + .ok(); } } @@ -340,106 +323,73 @@ pub async fn handle_connection( warn!("failed to get node id"); return; }; - if !progress - .authorize_client_connection(connection_id, node_id) + if let Err(cause) = progress + .client_connected(|| ClientConnected { + connection_id, + node_id, + }) .await { - debug!("client not authorized to connect"); + connection.close(cause.code(), cause.reason()); + debug!("closing connection: {cause}"); return; } - while let Ok((writer, reader)) = connection.accept_bi().await { - // The stream ID index is used to identify this request. Requests only arrive in - // bi-directional RecvStreams initiated by the client, so this uniquely identifies them. 
- let request_id = reader.id().index(); - let span = debug_span!("stream", stream_id = %request_id); + while let Ok(context) = StreamPair::accept(&connection, &progress).await { + let span = debug_span!("stream", stream_id = %context.request_id); let store = store.clone(); - let mut writer = ProgressWriter { - inner: writer, - context: StreamContext { - connection_id, - request_id, - payload_bytes_sent: 0, - other_bytes_sent: 0, - bytes_read: 0, - progress: progress.clone(), - }, - }; - tokio::spawn( - async move { - match handle_stream(store, reader, &mut writer).await { - Ok(()) => { - writer.send_transfer_completed().await; - } - Err(err) => { - warn!("error: {err:#?}",); - writer.send_transfer_aborted().await; - } - } - } - .instrument(span), - ); + tokio::spawn(handle_stream(store, context).instrument(span)); } progress - .send(Event::ConnectionClosed { connection_id }) - .await; + .connection_closed(|| ConnectionClosed { connection_id }) + .await + .ok(); } .instrument(span) .await } -async fn handle_stream( - store: Store, - reader: RecvStream, - writer: &mut ProgressWriter, -) -> Result<()> { +async fn handle_stream(store: Store, mut context: StreamPair) -> anyhow::Result<()> { // 1. Decode the request. debug!("reading request"); - let mut reader = ProgressReader { - inner: reader, - context: StreamContext { - connection_id: writer.connection_id, - request_id: writer.request_id, - payload_bytes_sent: 0, - other_bytes_sent: 0, - bytes_read: 0, - progress: writer.progress.clone(), - }, - }; - let request = match read_request(&mut reader).await { - Ok(request) => request, - Err(e) => { - // todo: increase invalid requests metric counter - return Err(e); - } - }; + let request = context.read_request().await?; match request { Request::Get(request) => { - // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. - reader.inner.read_to_end(0).await?; - // move the context so we don't lose the bytes read - writer.context = reader.context; - handle_get(store, request, writer).await + let mut writer = context.get_request(|| request.clone()).await?; + let res = handle_get(store, request, &mut writer).await; + if res.is_ok() { + writer.transfer_completed().await; + } else { + writer.transfer_aborted().await; + } } Request::GetMany(request) => { - // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. - reader.inner.read_to_end(0).await?; - // move the context so we don't lose the bytes read - writer.context = reader.context; - handle_get_many(store, request, writer).await + let mut writer = context.get_many_request(|| request.clone()).await?; + if handle_get_many(store, request, &mut writer).await.is_ok() { + writer.transfer_completed().await; + } else { + writer.transfer_aborted().await; + } } Request::Observe(request) => { - // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. 
- reader.inner.read_to_end(0).await?; - handle_observe(store, request, writer).await + let mut writer = context.observe_request(|| request.clone()).await?; + if handle_observe(store, request, &mut writer).await.is_ok() { + writer.transfer_completed().await; + } else { + writer.transfer_aborted().await; + } } Request::Push(request) => { - writer.inner.finish()?; - handle_push(store, request, reader).await + let mut reader = context.push_request(|| request.clone()).await?; + if handle_push(store, request, &mut reader).await.is_ok() { + reader.transfer_completed().await; + } else { + reader.transfer_aborted().await; + } } - _ => anyhow::bail!("unsupported request: {request:?}"), - // Request::Push(request) => handle_push(store, request, writer).await, + _ => {} } + Ok(()) } /// Handle a single get request. @@ -449,13 +399,9 @@ pub async fn handle_get( store: Store, request: GetRequest, writer: &mut ProgressWriter, -) -> Result<()> { +) -> anyhow::Result<()> { let hash = request.hash; debug!(%hash, "get received request"); - - writer - .send_get_request_received(&hash, &request.ranges) - .await; let mut hash_seq = None; for (offset, ranges) in request.ranges.iter_non_empty_infinite() { if offset == 0 { @@ -495,9 +441,6 @@ pub async fn handle_get_many( writer: &mut ProgressWriter, ) -> Result<()> { debug!("get_many received request"); - writer - .send_get_many_request_received(&request.hashes, &request.ranges) - .await; let request_ranges = request.ranges.iter_infinite(); for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() { if !ranges.is_empty() { @@ -513,14 +456,10 @@ pub async fn handle_get_many( pub async fn handle_push( store: Store, request: PushRequest, - mut reader: ProgressReader, + reader: &mut ProgressReader, ) -> Result<()> { let hash = request.hash; debug!(%hash, "push received request"); - if !reader.authorize_push_request(&hash, &request.ranges).await { - debug!("push request not authorized"); - return Ok(()); - }; let mut request_ranges = request.ranges.iter_infinite(); let root_ranges = request_ranges.next().expect("infinite iterator"); if !root_ranges.is_empty() { @@ -554,11 +493,11 @@ pub(crate) async fn send_blob( hash: Hash, ranges: ChunkRanges, writer: &mut ProgressWriter, -) -> api::Result<()> { - Ok(store +) -> ExportBaoResult<()> { + store .export_bao(hash, ranges) .write_quinn_with_progress(&mut writer.inner, &mut writer.context, &hash, index) - .await?) + .await } /// Handle a single push request. @@ -601,162 +540,51 @@ async fn send_observe_item(writer: &mut ProgressWriter, item: &Bitfield) -> Resu use irpc::util::AsyncWriteVarintExt; let item = ObserveItem::from(item); let len = writer.inner.write_length_prefixed(item).await?; - writer.log_other_write(len); + writer.context.log_other_write(len); Ok(()) } -/// Helper to lazyly create an [`Event`], in the case that the event creation -/// is expensive and we want to avoid it if the progress sender is disabled. -pub trait LazyEvent { - fn call(self) -> Event; -} - -impl LazyEvent for T -where - T: FnOnce() -> Event, -{ - fn call(self) -> Event { - self() - } -} - -impl LazyEvent for Event { - fn call(self) -> Event { - self - } -} - -/// A sender for provider events. 
-#[derive(Debug, Clone)] -pub struct EventSender(EventSenderInner); - -#[derive(Debug, Clone)] -enum EventSenderInner { - Disabled, - Enabled(mpsc::Sender), -} - -impl EventSender { - pub fn new(sender: Option>) -> Self { - match sender { - Some(sender) => Self(EventSenderInner::Enabled(sender)), - None => Self(EventSenderInner::Disabled), - } - } - - /// Send a client connected event, if the progress sender is enabled. - /// - /// This will permit the client to connect if the sender is disabled. - #[must_use = "permit should be checked by the caller"] - pub async fn authorize_client_connection(&self, connection_id: u64, node_id: NodeId) -> bool { - let mut wait_for_permit = None; - self.send(|| { - let (tx, rx) = oneshot::channel(); - wait_for_permit = Some(rx); - Event::ClientConnected { - connection_id, - node_id, - permitted: tx, - } - }) - .await; - if let Some(wait_for_permit) = wait_for_permit { - // if we have events configured, and they drop the channel, we consider that as a no! - // todo: this will be confusing and needs to be properly documented. - wait_for_permit.await.unwrap_or(false) - } else { - true - } - } - - /// Send an ephemeral event, if the progress sender is enabled. - /// - /// The event will only be created if the sender is enabled. - fn try_send(&self, event: impl LazyEvent) { - match &self.0 { - EventSenderInner::Enabled(sender) => { - let value = event.call(); - sender.try_send(value).ok(); - } - EventSenderInner::Disabled => {} - } - } - - /// Send a mandatory event, if the progress sender is enabled. - /// - /// The event only be created if the sender is enabled. - async fn send(&self, event: impl LazyEvent) { - match &self.0 { - EventSenderInner::Enabled(sender) => { - let value = event.call(); - if let Err(err) = sender.send(value).await { - error!("failed to send progress event: {:?}", err); - } - } - EventSenderInner::Disabled => {} - } - } -} - pub struct ProgressReader { inner: RecvStream, - context: StreamContext, + context: ReaderContext, } -impl Deref for ProgressReader { - type Target = StreamContext; - - fn deref(&self) -> &Self::Target { - &self.context +impl ProgressReader { + async fn transfer_aborted(&self) { + self.context + .tracker + .transfer_aborted(|| Box::new(self.context.stats())) + .await + .ok(); } -} -impl DerefMut for ProgressReader { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.context + async fn transfer_completed(&self) { + self.context + .tracker + .transfer_completed(|| Box::new(self.context.stats())) + .await + .ok(); } } -pub struct CountingReader { - pub inner: R, - pub read: u64, +pub(crate) trait RecvStreamExt { + async fn read_to_end_as( + &mut self, + max_size: usize, + ) -> io::Result<(T, usize)>; } -impl CountingReader { - pub fn new(inner: R) -> Self { - Self { inner, read: 0 } - } - - pub fn read(&self) -> u64 { - self.read - } -} - -impl CountingReader<&mut iroh::endpoint::RecvStream> { - pub async fn read_to_end_as(&mut self, max_size: usize) -> io::Result { +impl RecvStreamExt for RecvStream { + async fn read_to_end_as( + &mut self, + max_size: usize, + ) -> io::Result<(T, usize)> { let data = self - .inner .read_to_end(max_size) .await .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; let value = postcard::from_bytes(&data) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - self.read += data.len() as u64; - Ok(value) - } -} - -impl AsyncRead for CountingReader { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - 
diff --git a/src/provider/events.rs b/src/provider/events.rs
new file mode 100644
index 000000000..e24e0efbb
--- /dev/null
+++ b/src/provider/events.rs
@@ -0,0 +1,702 @@
+use std::{fmt::Debug, io, ops::Deref};
+
+use irpc::{
+    channel::{mpsc, none::NoSender, oneshot},
+    rpc_requests, Channels, WithChannels,
+};
+use serde::{Deserialize, Serialize};
+use snafu::Snafu;
+
+use crate::{
+    protocol::{
+        GetManyRequest, GetRequest, ObserveRequest, PushRequest, ERR_INTERNAL, ERR_LIMIT,
+        ERR_PERMISSION,
+    },
+    provider::{events::irpc_ext::IrpcClientExt, TransferStats},
+    Hash,
+};
+
+/// Mode for connect events.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
+#[repr(u8)]
+pub enum ConnectMode {
+    /// We don't get notification of connect events at all.
+    #[default]
+    None,
+    /// We get a notification for connect events.
+    Notify,
+    /// We get a request for connect events and can reject incoming connections.
+    Intercept,
+}
+
+/// Request mode for observe requests.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
+#[repr(u8)]
+pub enum ObserveMode {
+    /// We don't get notification of observe requests at all.
+    #[default]
+    None,
+    /// We get a notification for observe requests.
+    Notify,
+    /// We get a request for observe requests and can reject them.
+    Intercept,
+}
+
+/// Request mode for all data related requests.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
+#[repr(u8)]
+pub enum RequestMode {
+    /// We don't get request events at all.
+    #[default]
+    None,
+    /// We get a notification for each request, but no transfer events.
+    Notify,
+    /// We get a request for each request, and can reject incoming requests, but no transfer events.
+    Intercept,
+    /// We get a notification for each request as well as detailed transfer events.
+    NotifyLog,
+    /// We get a request for each request, and can reject incoming requests.
+    /// We also get detailed transfer events.
+    InterceptLog,
+    /// This request type is completely disabled. All requests will be rejected.
+    ///
+    /// This means that requests of this kind will always be rejected, whereas
+    /// None means that we don't get any events, but requests will be processed normally.
+    Disabled,
+}
+
+/// Throttling mode for requests that support throttling.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
+#[repr(u8)]
+pub enum ThrottleMode {
+    /// We don't get these kinds of events at all.
+    #[default]
+    None,
+    /// We call throttle to give the event handler a way to throttle requests.
+    Intercept,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+pub enum AbortReason {
+    /// The request was aborted because a limit was exceeded. It is OK to try again later.
+    RateLimited,
+    /// The request was aborted because the client does not have permission to perform the operation.
+    Permission,
+}
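These mode enums are combined in the `EventMask` defined below. As a sketch, a provider that wants to vet connections and watch get traffic, but skip the per-chunk throttle hook, would opt in selectively (field names per `EventMask`; `DEFAULT` leaves everything off and push disabled):

```rust
use iroh_blobs::provider::events::{ConnectMode, EventMask, RequestMode, ThrottleMode};

fn main() {
    let mask = EventMask {
        // Ask before accepting each incoming connection.
        connected: ConnectMode::Intercept,
        // Be notified of get requests and receive per-transfer events.
        get: RequestMode::NotifyLog,
        // Throttling has a per-chunk cost, so leave it off here.
        throttle: ThrottleMode::None,
        // Everything else (get_many, observe) stays off; push stays Disabled.
        ..EventMask::DEFAULT
    };
    println!("{mask:?}");
}
```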
+/// Errors that can occur when sending progress updates.
+#[derive(Debug, Snafu)]
+pub enum ProgressError {
+    Limit,
+    Permission,
+    #[snafu(transparent)]
+    Internal {
+        source: irpc::Error,
+    },
+}
+
+impl From<ProgressError> for io::Error {
+    fn from(value: ProgressError) -> Self {
+        match value {
+            ProgressError::Limit => io::ErrorKind::QuotaExceeded.into(),
+            ProgressError::Permission => io::ErrorKind::PermissionDenied.into(),
+            ProgressError::Internal { source } => source.into(),
+        }
+    }
+}
+
+impl ProgressError {
+    pub fn code(&self) -> quinn::VarInt {
+        match self {
+            ProgressError::Limit => ERR_LIMIT,
+            ProgressError::Permission => ERR_PERMISSION,
+            ProgressError::Internal { .. } => ERR_INTERNAL,
+        }
+    }
+
+    pub fn reason(&self) -> &'static [u8] {
+        match self {
+            ProgressError::Limit => b"limit",
+            ProgressError::Permission => b"permission",
+            ProgressError::Internal { .. } => b"internal",
+        }
+    }
+}
+
+impl From<AbortReason> for ProgressError {
+    fn from(value: AbortReason) -> Self {
+        match value {
+            AbortReason::RateLimited => ProgressError::Limit,
+            AbortReason::Permission => ProgressError::Permission,
+        }
+    }
+}
+
+impl From<irpc::channel::RecvError> for ProgressError {
+    fn from(value: irpc::channel::RecvError) -> Self {
+        ProgressError::Internal {
+            source: value.into(),
+        }
+    }
+}
+
+impl From<irpc::channel::SendError> for ProgressError {
+    fn from(value: irpc::channel::SendError) -> Self {
+        ProgressError::Internal {
+            source: value.into(),
+        }
+    }
+}
+
+pub type EventResult = Result<(), AbortReason>;
+pub type ClientResult = Result<(), ProgressError>;
+
+/// Event mask to configure which events are sent to the event handler.
+///
+/// This can also be used to completely disable certain request types. E.g.
+/// push requests are disabled by default, as they can write to the local store.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct EventMask {
+    /// Connection event mask
+    pub connected: ConnectMode,
+    /// Get request event mask
+    pub get: RequestMode,
+    /// Get many request event mask
+    pub get_many: RequestMode,
+    /// Push request event mask
+    pub push: RequestMode,
+    /// Observe request event mask
+    pub observe: ObserveMode,
+    /// throttling is somewhat costly, so you can disable it completely
+    pub throttle: ThrottleMode,
+}
+
+impl Default for EventMask {
+    fn default() -> Self {
+        Self::DEFAULT
+    }
+}
+
+impl EventMask {
+    /// All event notifications are fully disabled. Push requests are disabled by default.
+    pub const DEFAULT: Self = Self {
+        connected: ConnectMode::None,
+        get: RequestMode::None,
+        get_many: RequestMode::None,
+        push: RequestMode::Disabled,
+        throttle: ThrottleMode::None,
+        observe: ObserveMode::None,
+    };
+
+    /// All event notifications for read-only requests are fully enabled.
+    ///
+    /// If you want to enable push requests, which can write to the local store, you
+    /// need to do it manually. Providing constants that have push enabled would
+    /// risk misuse.
+    pub const ALL_READONLY: Self = Self {
+        connected: ConnectMode::Intercept,
+        get: RequestMode::InterceptLog,
+        get_many: RequestMode::InterceptLog,
+        push: RequestMode::Disabled,
+        throttle: ThrottleMode::Intercept,
+        observe: ObserveMode::Intercept,
+    };
+}
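To make the error flow above concrete: an `AbortReason` returned by a handler is widened into a `ProgressError`, which in turn maps onto an `io::Error` kind and a wire-level reason. A sketch (assuming `ProgressError` is exported alongside `AbortReason`, and a recent Rust with `ErrorKind::QuotaExceeded`):

```rust
use iroh_blobs::provider::events::{AbortReason, ProgressError};

fn main() {
    // Handler said "try again later": RateLimited -> Limit -> QuotaExceeded.
    let err: ProgressError = AbortReason::RateLimited.into();
    assert_eq!(err.reason(), &b"limit"[..]);
    let io_err: std::io::Error = err.into();
    assert_eq!(io_err.kind(), std::io::ErrorKind::QuotaExceeded);
}
```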
+/// Newtype wrapper that wraps an event so that it is a distinct type for the notify variant.
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Notify<T>(T);
+
+impl<T> Deref for Notify<T> {
+    type Target = T;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+#[derive(Debug, Default, Clone)]
+pub struct EventSender {
+    mask: EventMask,
+    inner: Option<irpc::Client<ProviderProto>>,
+}
+
+#[derive(Debug, Default)]
+enum RequestUpdates {
+    /// Request tracking was not configured, all ops are no-ops.
+    #[default]
+    None,
+    /// Active request tracking, all ops actually send.
+    Active(mpsc::Sender<RequestUpdate>),
+    /// Disabled request tracking, we just hold on to the sender so it drops
+    /// once the request is completed or aborted.
+    Disabled(#[allow(dead_code)] mpsc::Sender<RequestUpdate>),
+}
+
+#[derive(Debug)]
+pub struct RequestTracker {
+    updates: RequestUpdates,
+    throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
+}
+
+impl RequestTracker {
+    fn new(
+        updates: RequestUpdates,
+        throttle: Option<(irpc::Client<ProviderProto>, u64, u64)>,
+    ) -> Self {
+        Self { updates, throttle }
+    }
+
+    /// A request tracker that doesn't track anything.
+    pub const NONE: Self = Self {
+        updates: RequestUpdates::None,
+        throttle: None,
+    };
+
+    /// Transfer for index `index` started, size `size` in bytes.
+    pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> {
+        if let RequestUpdates::Active(tx) = &self.updates {
+            tx.send(
+                TransferStarted {
+                    index,
+                    hash: *hash,
+                    size,
+                }
+                .into(),
+            )
+            .await?;
+        }
+        Ok(())
+    }
+
+    /// Transfer progress for the previously reported blob, end_offset is the new end offset in bytes.
+    pub async fn transfer_progress(&mut self, len: u64, end_offset: u64) -> ClientResult {
+        if let RequestUpdates::Active(tx) = &mut self.updates {
+            tx.try_send(TransferProgress { end_offset }.into()).await?;
+        }
+        if let Some((throttle, connection_id, request_id)) = &self.throttle {
+            throttle
+                .rpc(Throttle {
+                    connection_id: *connection_id,
+                    request_id: *request_id,
+                    size: len,
+                })
+                .await??;
+        }
+        Ok(())
+    }
+
+    /// Transfer completed for the previously reported blob.
+    pub async fn transfer_completed(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
+        if let RequestUpdates::Active(tx) = &self.updates {
+            tx.send(TransferCompleted { stats: f() }.into()).await?;
+        }
+        Ok(())
+    }
+
+    /// Transfer aborted for the previously reported blob.
+    pub async fn transfer_aborted(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
+        if let RequestUpdates::Active(tx) = &self.updates {
+            tx.send(TransferAborted { stats: f() }.into()).await?;
+        }
+        Ok(())
+    }
+}
+
+/// Client for progress notifications.
+///
+/// For most event types, the client can be configured to either send notifications or requests that
+/// can have a response.
+impl EventSender {
+    /// A client that does not send anything.
+    pub const DEFAULT: Self = Self {
+        mask: EventMask::DEFAULT,
+        inner: None,
+    };
+
+    pub fn new(client: tokio::sync::mpsc::Sender<ProviderMessage>, mask: EventMask) -> Self {
+        Self {
+            mask,
+            inner: Some(irpc::Client::from(client)),
+        }
+    }
+
+    pub fn channel(
+        capacity: usize,
+        mask: EventMask,
+    ) -> (Self, tokio::sync::mpsc::Receiver<ProviderMessage>) {
+        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
+        (Self::new(tx, mask), rx)
+    }
+    /// Log request events at trace level.
+    pub fn tracing(&self, mask: EventMask) -> Self {
+        use tracing::trace;
+        let (tx, mut rx) = tokio::sync::mpsc::channel(32);
+        n0_future::task::spawn(async move {
+            fn log_request_events(
+                mut rx: irpc::channel::mpsc::Receiver<RequestUpdate>,
+                connection_id: u64,
+                request_id: u64,
+            ) {
+                n0_future::task::spawn(async move {
+                    while let Ok(Some(update)) = rx.recv().await {
+                        trace!(%connection_id, %request_id, "{update:?}");
+                    }
+                });
+            }
+            while let Some(msg) = rx.recv().await {
+                match msg {
+                    ProviderMessage::ClientConnected(msg) => {
+                        trace!("{:?}", msg.inner);
+                        msg.tx.send(Ok(())).await.ok();
+                    }
+                    ProviderMessage::ClientConnectedNotify(msg) => {
+                        trace!("{:?}", msg.inner);
+                    }
+                    ProviderMessage::ConnectionClosed(msg) => {
+                        trace!("{:?}", msg.inner);
+                    }
+                    ProviderMessage::GetRequestReceived(msg) => {
+                        trace!("{:?}", msg.inner);
+                        msg.tx.send(Ok(())).await.ok();
+                        log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
+                    }
+                    ProviderMessage::GetRequestReceivedNotify(msg) => {
+                        trace!("{:?}", msg.inner);
+                        log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
+                    }
+                    ProviderMessage::GetManyRequestReceived(msg) => {
+                        trace!("{:?}", msg.inner);
+                        msg.tx.send(Ok(())).await.ok();
+                        log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
+                    }
+                    ProviderMessage::GetManyRequestReceivedNotify(msg) => {
+                        trace!("{:?}", msg.inner);
+                        log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
+                    }
+                    ProviderMessage::PushRequestReceived(msg) => {
+                        trace!("{:?}", msg.inner);
+                        msg.tx.send(Ok(())).await.ok();
+                        log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
+                    }
+                    ProviderMessage::PushRequestReceivedNotify(msg) => {
+                        trace!("{:?}", msg.inner);
+                        log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
+                    }
+                    ProviderMessage::ObserveRequestReceived(msg) => {
+                        trace!("{:?}", msg.inner);
+                        msg.tx.send(Ok(())).await.ok();
+                        log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
+                    }
+                    ProviderMessage::ObserveRequestReceivedNotify(msg) => {
+                        trace!("{:?}", msg.inner);
+                        log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id);
+                    }
+                    ProviderMessage::Throttle(msg) => {
+                        trace!("{:?}", msg.inner);
+                        msg.tx.send(Ok(())).await.ok();
+                    }
+                }
+            }
+        });
+        Self {
+            mask,
+            inner: Some(irpc::Client::from(tx)),
+        }
+    }
+
+    /// A new client has been connected.
+    pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult {
+        if let Some(client) = &self.inner {
+            match self.mask.connected {
+                ConnectMode::None => {}
+                ConnectMode::Notify => client.notify(Notify(f())).await?,
+                ConnectMode::Intercept => client.rpc(f()).await??,
+            }
+        };
+        Ok(())
+    }
+
+    /// A connection has been closed.
+    pub async fn connection_closed(&self, f: impl Fn() -> ConnectionClosed) -> ClientResult {
+        if let Some(client) = &self.inner {
+            client.notify(f()).await?;
+        };
+        Ok(())
+    }
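A sketch of wiring the trace logger above into a provider (per the signature in this diff, `tracing` is called on an existing sender and returns a new one; output only appears if a `tracing` subscriber is installed at trace level):

```rust
use iroh_blobs::provider::events::{EventMask, EventSender};

fn main() {
    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::TRACE)
        .init();
    // All intercept-style messages are answered with Ok(()) and logged,
    // which makes this handy for debugging a provider setup.
    let events = EventSender::DEFAULT.tracing(EventMask::ALL_READONLY);
    let _ = events; // pass to the blobs protocol handler in real code
}
```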
+    /// Abstract request, to DRY the 3 to 4 request types.
+    ///
+    /// DRYing stuff with lots of bounds is no fun at all...
+    pub(crate) async fn request<Req>(
+        &self,
+        f: impl FnOnce() -> Req,
+        connection_id: u64,
+        request_id: u64,
+    ) -> Result<RequestTracker, ProgressError>
+    where
+        ProviderProto: From<RequestReceived<Req>>,
+        ProviderMessage: From<WithChannels<RequestReceived<Req>, ProviderProto>>,
+        RequestReceived<Req>: Channels<
+            ProviderProto,
+            Tx = oneshot::Sender<EventResult>,
+            Rx = mpsc::Receiver<RequestUpdate>,
+        >,
+        ProviderProto: From<Notify<RequestReceived<Req>>>,
+        ProviderMessage: From<WithChannels<Notify<RequestReceived<Req>>, ProviderProto>>,
+        Notify<RequestReceived<Req>>:
+            Channels<ProviderProto, Tx = NoSender, Rx = mpsc::Receiver<RequestUpdate>>,
+    {
+        let client = self.inner.as_ref();
+        Ok(self.create_tracker((
+            match self.mask.get {
+                RequestMode::None => RequestUpdates::None,
+                RequestMode::Notify if client.is_some() => {
+                    let msg = RequestReceived {
+                        request: f(),
+                        connection_id,
+                        request_id,
+                    };
+                    RequestUpdates::Disabled(
+                        client.unwrap().notify_streaming(Notify(msg), 32).await?,
+                    )
+                }
+                RequestMode::Intercept if client.is_some() => {
+                    let msg = RequestReceived {
+                        request: f(),
+                        connection_id,
+                        request_id,
+                    };
+                    let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
+                    // bail out if the request is not allowed
+                    rx.await??;
+                    RequestUpdates::Disabled(tx)
+                }
+                RequestMode::NotifyLog if client.is_some() => {
+                    let msg = RequestReceived {
+                        request: f(),
+                        connection_id,
+                        request_id,
+                    };
+                    RequestUpdates::Active(client.unwrap().notify_streaming(Notify(msg), 32).await?)
+                }
+                RequestMode::InterceptLog if client.is_some() => {
+                    let msg = RequestReceived {
+                        request: f(),
+                        connection_id,
+                        request_id,
+                    };
+                    let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?;
+                    // bail out if the request is not allowed
+                    rx.await??;
+                    RequestUpdates::Active(tx)
+                }
+                RequestMode::Disabled => {
+                    return Err(ProgressError::Permission);
+                }
+                _ => RequestUpdates::None,
+            },
+            connection_id,
+            request_id,
+        )))
+    }
+
+    fn create_tracker(
+        &self,
+        (updates, connection_id, request_id): (RequestUpdates, u64, u64),
+    ) -> RequestTracker {
+        let throttle = match self.mask.throttle {
+            ThrottleMode::None => None,
+            ThrottleMode::Intercept => self
+                .inner
+                .clone()
+                .map(|client| (client, connection_id, request_id)),
+        };
+        RequestTracker::new(updates, throttle)
+    }
+}
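The bounds on `request` above simply demand that each request type exists in the protocol enum below in both an interceptable flavor (`tx = oneshot::Sender<EventResult>`) and a notify-only flavor (`tx = NoSender`), and they rely on the macro-derived `From` conversions from each payload type into the enum, which is why the same payload cannot appear in two variants unwrapped and the `Notify` wrapper exists. A toy protocol in the same shape (hypothetical `Ping` service, mirroring the macro syntax used below; whether this exact snippet compiles depends on the irpc version, and the macro generates `PingMessage` just as it generates `ProviderMessage` here):

```rust
use irpc::{
    channel::{none::NoSender, oneshot},
    rpc_requests,
};
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
pub struct Ping {
    pub payload: u64,
}

// Distinct payload type for the fire-and-forget flavor, analogous to Notify<T>.
#[derive(Debug, Serialize, Deserialize)]
pub struct PingNotify(pub Ping);

#[rpc_requests(message = PingMessage)]
#[derive(Debug, Serialize, Deserialize)]
pub enum PingProto {
    /// Interceptable: the handler replies Ok or Err over the oneshot.
    #[rpc(tx = oneshot::Sender<Result<(), String>>)]
    Ping(Ping),
    /// Notify-only: no reply channel.
    #[rpc(tx = NoSender)]
    PingNotify(PingNotify),
}
```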
+#[rpc_requests(message = ProviderMessage)]
+#[derive(Debug, Serialize, Deserialize)]
+pub enum ProviderProto {
+    /// A new client connected to the provider.
+    #[rpc(tx = oneshot::Sender<EventResult>)]
+    ClientConnected(ClientConnected),
+
+    /// A new client connected to the provider. Notify variant.
+    #[rpc(tx = NoSender)]
+    ClientConnectedNotify(Notify<ClientConnected>),
+
+    /// A client disconnected from the provider.
+    #[rpc(tx = NoSender)]
+    ConnectionClosed(ConnectionClosed),
+
+    /// A new get request was received by the provider.
+    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
+    GetRequestReceived(RequestReceived<GetRequest>),
+
+    /// A new get request was received by the provider (notify variant).
+    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
+    GetRequestReceivedNotify(Notify<RequestReceived<GetRequest>>),
+
+    /// A new get many request was received by the provider.
+    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
+    GetManyRequestReceived(RequestReceived<GetManyRequest>),
+
+    /// A new get many request was received by the provider (notify variant).
+    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
+    GetManyRequestReceivedNotify(Notify<RequestReceived<GetManyRequest>>),
+
+    /// A new push request was received by the provider.
+    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
+    PushRequestReceived(RequestReceived<PushRequest>),
+
+    /// A new push request was received by the provider (notify variant).
+    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
+    PushRequestReceivedNotify(Notify<RequestReceived<PushRequest>>),
+
+    /// A new observe request was received by the provider.
+    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = oneshot::Sender<EventResult>)]
+    ObserveRequestReceived(RequestReceived<ObserveRequest>),
+
+    /// A new observe request was received by the provider (notify variant).
+    #[rpc(rx = mpsc::Receiver<RequestUpdate>, tx = NoSender)]
+    ObserveRequestReceivedNotify(Notify<RequestReceived<ObserveRequest>>),
+
+    /// Request to throttle sending for a specific data request.
+    #[rpc(tx = oneshot::Sender<EventResult>)]
+    Throttle(Throttle),
+}
+
+mod proto {
+    use iroh::NodeId;
+    use serde::{Deserialize, Serialize};
+
+    use crate::{provider::TransferStats, Hash};
+
+    #[derive(Debug, Serialize, Deserialize)]
+    pub struct ClientConnected {
+        pub connection_id: u64,
+        pub node_id: NodeId,
+    }
+
+    #[derive(Debug, Serialize, Deserialize)]
+    pub struct ConnectionClosed {
+        pub connection_id: u64,
+    }
+
+    /// A new request was received by the provider.
+    #[derive(Debug, Serialize, Deserialize)]
+    pub struct RequestReceived<R> {
+        /// The connection id. Multiple requests can be sent over the same connection.
+        pub connection_id: u64,
+        /// The request id. There is a new id for each request.
+        pub request_id: u64,
+        /// The request
+        pub request: R,
+    }
+
+    /// Request to throttle sending for a specific request.
+    #[derive(Debug, Serialize, Deserialize)]
+    pub struct Throttle {
+        /// The connection id. Multiple requests can be sent over the same connection.
+        pub connection_id: u64,
+        /// The request id. There is a new id for each request.
+        pub request_id: u64,
+        /// Size of the chunk to be throttled. This will usually be 16 KiB.
+        pub size: u64,
+    }
+
+    #[derive(Debug, Serialize, Deserialize)]
+    pub struct TransferProgress {
+        /// The end offset of the chunk that was sent.
+        pub end_offset: u64,
+    }
+
+    #[derive(Debug, Serialize, Deserialize)]
+    pub struct TransferStarted {
+        pub index: u64,
+        pub hash: Hash,
+        pub size: u64,
+    }
+
+    #[derive(Debug, Serialize, Deserialize)]
+    pub struct TransferCompleted {
+        pub stats: Box<TransferStats>,
+    }
+
+    #[derive(Debug, Serialize, Deserialize)]
+    pub struct TransferAborted {
+        pub stats: Box<TransferStats>,
+    }
+
+    /// Stream of updates for a single request
+    #[derive(Debug, Serialize, Deserialize, derive_more::From)]
+    pub enum RequestUpdate {
+        /// Start of transfer for a blob, mandatory event
+        Started(TransferStarted),
+        /// Progress for a blob - optional event
+        Progress(TransferProgress),
+        /// Successful end of transfer
+        Completed(TransferCompleted),
+        /// Aborted end of transfer
+        Aborted(TransferAborted),
+    }
+}
+pub use proto::*;
+
+mod irpc_ext {
+    use std::future::Future;
+
+    use irpc::{
+        channel::{mpsc, none::NoSender},
+        Channels, RpcMessage, Service, WithChannels,
+    };
+
+    pub trait IrpcClientExt<S: Service> {
+        fn notify_streaming<Req, Update>(
+            &self,
+            msg: Req,
+            local_update_cap: usize,
+        ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
+        where
+            S: From<Req>,
+            S::Message: From<WithChannels<Req, S>>,
+            Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
+            Update: RpcMessage;
+    }
+
+    impl<S: Service> IrpcClientExt<S> for irpc::Client<S> {
+        fn notify_streaming<Req, Update>(
+            &self,
+            msg: Req,
+            local_update_cap: usize,
+        ) -> impl Future<Output = irpc::Result<mpsc::Sender<Update>>>
+        where
+            S: From<Req>,
+            S::Message: From<WithChannels<Req, S>>,
+            Req: Channels<S, Tx = NoSender, Rx = mpsc::Receiver<Update>>,
+            Update: RpcMessage,
+        {
+            let client = self.clone();
+            async move {
+                let request = client.request().await?;
+                match request {
+                    irpc::Request::Local(local) => {
+                        let (req_tx, req_rx) = mpsc::channel(local_update_cap);
+                        local
+                            .send((msg, NoSender, req_rx))
+                            .await
+                            .map_err(irpc::Error::from)?;
+                        Ok(req_tx)
+                    }
+                    irpc::Request::Remote(remote) => {
+                        let (s, _) = remote.write(msg).await?;
+                        Ok(s.into())
+                    }
+                }
+            }
+        }
+    }
+}
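The per-request `RequestUpdate` stream defined above is what the `*Log` request modes feed. A sketch of a consumer that folds one request's stream into a byte count (assuming `rx` was taken from an intercepted request message such as `GetRequestReceived`):

```rust
use irpc::channel::mpsc;
use iroh_blobs::provider::events::RequestUpdate;

// Fold one request's update stream into the last reported end offset.
async fn bytes_sent(mut rx: mpsc::Receiver<RequestUpdate>) -> u64 {
    let mut end_offset = 0;
    while let Ok(Some(update)) = rx.recv().await {
        match update {
            RequestUpdate::Started(s) => {
                println!("sending blob {} ({} bytes), child {}", s.hash, s.size, s.index);
            }
            RequestUpdate::Progress(p) => end_offset = p.end_offset,
            // Both terminal events carry TransferStats; stop either way.
            RequestUpdate::Completed(_) | RequestUpdate::Aborted(_) => break,
        }
    }
    end_offset
}
```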
diff --git a/src/tests.rs b/src/tests.rs
index e7dc823e6..0ef0c027c 100644
--- a/src/tests.rs
+++ b/src/tests.rs
@@ -16,7 +16,7 @@ use crate::{
     hashseq::HashSeq,
     net_protocol::BlobsProtocol,
     protocol::{ChunkRangesSeq, GetManyRequest, ObserveRequest, PushRequest},
-    provider::Event,
+    provider::events::{AbortReason, EventMask, EventSender, ProviderMessage, RequestUpdate},
     store::{
         fs::{
             tests::{create_n0_bao, test_data, INTERESTING_SIZES},
@@ -340,27 +340,31 @@ async fn two_nodes_get_many_mem() -> TestResult<()> {
 
 fn event_handler(
     allowed_nodes: impl IntoIterator<Item = NodeId>,
-) -> (
-    mpsc::Sender<Event>,
-    watch::Receiver<usize>,
-    AbortOnDropHandle<()>,
-) {
+) -> (EventSender, watch::Receiver<usize>, AbortOnDropHandle<()>) {
     let (count_tx, count_rx) = tokio::sync::watch::channel(0usize);
-    let (events_tx, mut events_rx) = mpsc::channel::<Event>(16);
+    let (events_tx, mut events_rx) = EventSender::channel(16, EventMask::ALL_READONLY);
     let allowed_nodes = allowed_nodes.into_iter().collect::<HashSet<_>>();
     let task = AbortOnDropHandle::new(tokio::task::spawn(async move {
         while let Some(event) = events_rx.recv().await {
             match event {
-                Event::ClientConnected {
-                    node_id, permitted, ..
-                } => {
-                    permitted.send(allowed_nodes.contains(&node_id)).await.ok();
+                ProviderMessage::ClientConnected(msg) => {
+                    let res = if allowed_nodes.contains(&msg.inner.node_id) {
+                        Ok(())
+                    } else {
+                        Err(AbortReason::Permission)
+                    };
+                    msg.tx.send(res).await.ok();
                 }
-                Event::PushRequestReceived { permitted, .. } => {
-                    permitted.send(true).await.ok();
-                }
-                Event::TransferCompleted { .. } => {
-                    count_tx.send_modify(|count| *count += 1);
+                ProviderMessage::PushRequestReceived(mut msg) => {
+                    msg.tx.send(Ok(())).await.ok();
+                    let count_tx = count_tx.clone();
+                    tokio::task::spawn(async move {
+                        while let Ok(Some(update)) = msg.rx.recv().await {
+                            if let RequestUpdate::Completed(_) = update {
+                                count_tx.send_modify(|x| *x += 1);
+                            }
+                        }
+                    });
                 }
                 _ => {}
             }
@@ -409,7 +413,7 @@ async fn two_nodes_push_blobs_fs() -> TestResult<()> {
     let (r1, store1, _) = node_test_setup_fs(testdir.path().join("a")).await?;
     let (events_tx, count_rx, _task) = event_handler([r1.endpoint().node_id()]);
     let (r2, store2, _) =
-        node_test_setup_with_events_fs(testdir.path().join("b"), Some(events_tx)).await?;
+        node_test_setup_with_events_fs(testdir.path().join("b"), events_tx).await?;
     two_nodes_push_blobs(r1, &store1, r2, &store2, count_rx).await
 }
 
@@ -418,7 +422,7 @@ async fn two_nodes_push_blobs_mem() -> TestResult<()> {
     tracing_subscriber::fmt::try_init().ok();
     let (r1, store1) = node_test_setup_mem().await?;
     let (events_tx, count_rx, _task) = event_handler([r1.endpoint().node_id()]);
-    let (r2, store2) = node_test_setup_with_events_mem(Some(events_tx)).await?;
+    let (r2, store2) = node_test_setup_with_events_mem(events_tx).await?;
     two_nodes_push_blobs(r1, &store1, r2, &store2, count_rx).await
 }
 
@@ -481,30 +485,30 @@ async fn check_presence(store: &Store, sizes: &[usize]) -> TestResult<()> {
 }
 
 pub async fn node_test_setup_fs(db_path: PathBuf) -> TestResult<(Router, FsStore, PathBuf)> {
-    node_test_setup_with_events_fs(db_path, None).await
+    node_test_setup_with_events_fs(db_path, EventSender::DEFAULT).await
 }
 
 pub async fn node_test_setup_with_events_fs(
     db_path: PathBuf,
-    events: Option<mpsc::Sender<Event>>,
+    events: EventSender,
 ) -> TestResult<(Router, FsStore, PathBuf)> {
     let store = crate::store::fs::FsStore::load(&db_path).await?;
     let ep = Endpoint::builder().bind().await?;
-    let blobs = BlobsProtocol::new(&store, ep.clone(), events);
+    let blobs = BlobsProtocol::new(&store, ep.clone(), Some(events));
     let router = Router::builder(ep).accept(crate::ALPN, blobs).spawn();
     Ok((router, store, db_path))
 }
 pub async fn node_test_setup_mem() -> TestResult<(Router, MemStore)> {
-    node_test_setup_with_events_mem(None).await
+    node_test_setup_with_events_mem(EventSender::DEFAULT).await
 }
 
 pub async fn node_test_setup_with_events_mem(
-    events: Option<mpsc::Sender<Event>>,
+    events: EventSender,
 ) -> TestResult<(Router, MemStore)> {
     let store = MemStore::new();
     let ep = Endpoint::builder().bind().await?;
-    let blobs = BlobsProtocol::new(&store, ep.clone(), events);
+    let blobs = BlobsProtocol::new(&store, ep.clone(), Some(events));
     let router = Router::builder(ep).accept(crate::ALPN, blobs).spawn();
     Ok((router, store))
 }
diff --git a/src/util.rs b/src/util.rs
index 59e366d81..bc9c25694 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -364,7 +364,7 @@ pub(crate) mod outboard_with_progress {
 }
 
 pub(crate) mod sink {
-    use std::{future::Future, io};
+    use std::future::Future;
 
     use irpc::RpcMessage;
 
@@ -434,10 +434,13 @@ pub(crate) mod sink {
     pub struct TokioMpscSenderSink<T>(pub tokio::sync::mpsc::Sender<T>);
 
     impl<T> Sink<T> for TokioMpscSenderSink<T> {
-        type Error = tokio::sync::mpsc::error::SendError<T>;
+        type Error = irpc::channel::SendError;
 
         async fn send(&mut self, value: T) -> std::result::Result<(), Self::Error> {
-            self.0.send(value).await
+            self.0
+                .send(value)
+                .await
+                .map_err(|_| irpc::channel::SendError::ReceiverClosed)
         }
     }
 
@@ -484,10 +487,10 @@ pub(crate) mod sink {
     pub struct Drain;
 
     impl<T> Sink<T> for Drain {
-        type Error = io::Error;
+        type Error = irpc::channel::SendError;
 
         async fn send(&mut self, _offset: T) -> std::result::Result<(), Self::Error> {
-            io::Result::Ok(())
+            Ok(())
         }
     }
 }
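The util.rs change above unifies all sink error types on `irpc::channel::SendError`, so heterogeneous sinks compose behind a single error type. A standalone sketch of the pattern (the `Sink` trait here is a stand-in for the crate-private one in `src/util.rs`):

```rust
use irpc::channel::SendError;

// Stand-in for the crate-private Sink trait in src/util.rs.
trait Sink<T> {
    type Error;
    async fn send(&mut self, value: T) -> Result<(), Self::Error>;
}

// Discards everything; can never fail, but shares the common error type.
struct Drain;

impl<T> Sink<T> for Drain {
    type Error = SendError;
    async fn send(&mut self, _value: T) -> Result<(), Self::Error> {
        Ok(())
    }
}

// Forwards into a tokio channel, translating tokio's typed SendError<T>
// into the unified irpc error.
struct TokioMpscSenderSink<T>(tokio::sync::mpsc::Sender<T>);

impl<T> Sink<T> for TokioMpscSenderSink<T> {
    type Error = SendError;
    async fn send(&mut self, value: T) -> Result<(), Self::Error> {
        self.0
            .send(value)
            .await
            .map_err(|_| SendError::ReceiverClosed)
    }
}
```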