From e5ac6dd432ec4c16c3d2801e8ef0c1f125380794 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 28 Aug 2025 14:54:07 +0200 Subject: [PATCH 01/23] WIP make provider events a proper irpc protocol and allow configuring notifications/requests for each event type. --- Cargo.lock | 4 - Cargo.toml | 2 +- src/provider.rs | 5 +- src/provider/event_proto.rs | 290 ++++++++++++++++++++++++++++++++++++ 4 files changed, 294 insertions(+), 7 deletions(-) create mode 100644 src/provider/event_proto.rs diff --git a/Cargo.lock b/Cargo.lock index 4068354f7..1a4de777e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1943,8 +1943,6 @@ dependencies = [ [[package]] name = "irpc" version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9f8f1d0987ea9da3d74698f921d0a817a214c83b2635a33ed4bc3efa4de1acd" dependencies = [ "anyhow", "futures-buffered", @@ -1966,8 +1964,6 @@ dependencies = [ [[package]] name = "irpc-derive" version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e0b26b834d401a046dd9d47bc236517c746eddbb5d25ff3e1a6075bfa4eebdb" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index bcd5f42d0..3a642632c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ self_cell = "1.1.0" genawaiter = { version = "0.99.1", features = ["futures03"] } iroh-base = "0.91.1" reflink-copy = "0.1.24" -irpc = { version = "0.7.0", features = ["rpc", "quinn_endpoint_setup", "spans", "stream", "derive"], default-features = false } +irpc = { version = "0.7.0", features = ["rpc", "quinn_endpoint_setup", "spans", "stream", "derive"], default-features = false, path = "../irpc" } iroh-metrics = { version = "0.35" } [dev-dependencies] diff --git a/src/provider.rs b/src/provider.rs index 61af8f6e1..141d674c6 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -20,7 +20,7 @@ use iroh::{ }; use irpc::channel::oneshot; use n0_future::StreamExt; -use serde::de::DeserializeOwned; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tokio::{io::AsyncRead, select, sync::mpsc}; use tracing::{debug, debug_span, error, warn, Instrument}; @@ -33,6 +33,7 @@ use crate::{ }, Hash, }; +mod event_proto; /// Provider progress events, to keep track of what the provider is doing. /// @@ -129,7 +130,7 @@ pub enum Event { } /// Statistics about a successful or failed transfer. -#[derive(Debug)] +#[derive(Debug, Serialize, Deserialize)] pub struct TransferStats { /// The number of bytes sent that are part of the payload. pub payload_bytes_sent: u64, diff --git a/src/provider/event_proto.rs b/src/provider/event_proto.rs new file mode 100644 index 000000000..cc6e1aeab --- /dev/null +++ b/src/provider/event_proto.rs @@ -0,0 +1,290 @@ +use std::fmt::Debug; + +use iroh::NodeId; +use irpc::{ + channel::{none::NoSender, oneshot}, + rpc_requests, +}; +use serde::{Deserialize, Serialize}; +use snafu::Snafu; + +use crate::{protocol::ChunkRangesSeq, provider::TransferStats, Hash}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[repr(u8)] +pub enum EventMode { + /// We don't get these kinds of events at all + #[default] + None, + /// We get a notification for these kinds of events + Notify, + /// We can respond to these kinds of events, either by aborting or by + /// e.g. introducing a delay for throttling. + Request, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[repr(u8)] +pub enum EventMode2 { + /// We don't get these kinds of events at all + #[default] + None, + /// We get a notification for these kinds of events + Notify, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum AbortReason { + RateLimited, + Permission, +} + +#[derive(Debug, Snafu)] +pub enum ClientError { + RateLimited, + Permission, + #[snafu(transparent)] + Irpc { + source: irpc::Error, + }, +} + +impl From for ClientError { + fn from(value: AbortReason) -> Self { + match value { + AbortReason::RateLimited => ClientError::RateLimited, + AbortReason::Permission => ClientError::Permission, + } + } +} + +pub type EventResult = Result<(), AbortReason>; +pub type ClientResult = Result<(), ClientError>; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct EventMask { + connected: EventMode, + get: EventMode, + get_many: EventMode, + push: EventMode, + transfer: EventMode, + transfer_complete: EventMode2, + transfer_aborted: EventMode2, +} + +/// Newtype wrapper that wraps an event so that it is a distinct type for the notify variant. +#[derive(Debug, Serialize, Deserialize)] +pub struct Notify(T); + +#[derive(Debug, Default)] +pub struct Client { + mask: EventMask, + inner: Option>, +} + +/// A new get request was received from the provider. +#[derive(Debug, Serialize, Deserialize)] +pub struct GetRequestReceived { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// The root hash of the request. + pub hash: Hash, + /// The exact query ranges of the request. + pub ranges: ChunkRangesSeq, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct GetManyRequestReceived { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// The root hash of the request. + pub hashes: Vec, + /// The exact query ranges of the request. + pub ranges: ChunkRangesSeq, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct PushRequestReceived { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// The root hash of the request. + pub hash: Hash, + /// The exact query ranges of the request. + pub ranges: ChunkRangesSeq, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TransferProgress { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// The index of the blob in the request. 0 for the first blob or for raw blob requests. + pub index: u64, + /// The end offset of the chunk that was sent. + pub end_offset: u64, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TransferStarted { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// The index of the blob in the request. 0 for the first blob or for raw blob requests. + pub index: u64, + /// The hash of the blob. This is the hash of the request for the first blob, the child hash (index-1) for subsequent blobs. + pub hash: Hash, + /// The size of the blob. This is the full size of the blob, not the size we are sending. + pub size: u64, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TransferCompleted { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// Statistics about the transfer. + pub stats: Box, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TransferAborted { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// Statistics about the part of the transfer that was aborted. + pub stats: Option>, +} + +/// Client for progress notifications. +/// +/// For most event types, the client can be configured to either send notifications or requests that +/// can have a response. +impl Client { + /// A client that does not send anything. + pub const NONE: Self = Self { + mask: EventMask { + connected: EventMode::None, + get: EventMode::None, + get_many: EventMode::None, + push: EventMode::None, + transfer: EventMode::None, + transfer_complete: EventMode2::None, + transfer_aborted: EventMode2::None, + }, + inner: None, + }; + + pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult { + Ok(if let Some(client) = &self.inner { + match self.mask.connected { + EventMode::None => {} + EventMode::Notify => client.notify(Notify(f())).await?, + EventMode::Request => client.rpc(f()).await??, + } + }) + } + + pub async fn get_request(&self, f: impl Fn() -> GetRequestReceived) -> ClientResult { + Ok(if let Some(client) = &self.inner { + match self.mask.get { + EventMode::None => {} + EventMode::Notify => client.notify(Notify(f())).await?, + EventMode::Request => client.rpc(f()).await??, + } + }) + } + + pub async fn push_request(&self, f: impl Fn() -> PushRequestReceived) -> ClientResult { + Ok(if let Some(client) = &self.inner { + match self.mask.push { + EventMode::None => {} + EventMode::Notify => client.notify(Notify(f())).await?, + EventMode::Request => client.rpc(f()).await??, + } + }) + } + + pub async fn send_get_many_request( + &self, + f: impl Fn() -> GetManyRequestReceived, + ) -> ClientResult { + Ok(if let Some(client) = &self.inner { + match self.mask.get_many { + EventMode::None => {} + EventMode::Notify => client.notify(Notify(f())).await?, + EventMode::Request => client.rpc(f()).await??, + } + }) + } + + pub async fn transfer_progress(&self, f: impl Fn() -> TransferProgress) -> ClientResult { + Ok(if let Some(client) = &self.inner { + match self.mask.transfer { + EventMode::None => {} + EventMode::Notify => client.notify(Notify(f())).await?, + EventMode::Request => client.rpc(f()).await??, + } + }) + } +} + +#[rpc_requests(message = ProviderMessage)] +#[derive(Debug, Serialize, Deserialize)] +pub enum ProviderProto { + /// A new client connected to the provider. + #[rpc(tx = oneshot::Sender)] + #[wrap(ClientConnected)] + ClientConnected { connection_id: u64, node_id: NodeId }, + /// A new client connected to the provider. Notify variant. + #[rpc(tx = NoSender)] + ClientConnectedNotify(Notify), + /// A client disconnected from the provider. + #[rpc(tx = NoSender)] + #[wrap(ConnectionClosed)] + ConnectionClosed { connection_id: u64 }, + + #[rpc(tx = oneshot::Sender)] + /// A new get request was received from the provider. + GetRequestReceived(GetRequestReceived), + + #[rpc(tx = NoSender)] + /// A new get request was received from the provider. + GetRequestReceivedNotify(Notify), + /// A new get request was received from the provider. + #[rpc(tx = oneshot::Sender)] + GetManyRequestReceived(GetManyRequestReceived), + /// A new get request was received from the provider. + #[rpc(tx = NoSender)] + GetManyRequestReceivedNotify(Notify), + /// A new get request was received from the provider. + #[rpc(tx = oneshot::Sender)] + PushRequestReceived(PushRequestReceived), + /// A new get request was received from the provider. + #[rpc(tx = NoSender)] + PushRequestReceivedNotify(Notify), + /// Transfer for the nth blob started. + #[rpc(tx = NoSender)] + TransferStarted(TransferStarted), + /// Progress of the transfer. + #[rpc(tx = oneshot::Sender)] + TransferProgress(TransferProgress), + /// Progress of the transfer. + #[rpc(tx = NoSender)] + TransferProgressNotify(Notify), + /// Entire transfer completed. + #[rpc(tx = NoSender)] + TransferCompleted(TransferCompleted), + /// Entire transfer aborted. + #[rpc(tx = NoSender)] + TransferAborted(TransferAborted), +} From d17c6f6a4a246dac165b951800597340cb2bacc7 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 28 Aug 2025 15:03:26 +0200 Subject: [PATCH 02/23] Add transfer_completed and transfer_aborted fn. --- src/provider/event_proto.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/provider/event_proto.rs b/src/provider/event_proto.rs index cc6e1aeab..ac769a476 100644 --- a/src/provider/event_proto.rs +++ b/src/provider/event_proto.rs @@ -236,6 +236,24 @@ impl Client { } }) } + + pub async fn transfer_completed(&self, f: impl Fn() -> TransferCompleted) -> ClientResult { + Ok(if let Some(client) = &self.inner { + match self.mask.transfer_complete { + EventMode2::Notify => client.notify(f()).await?, + EventMode2::None => {} + } + }) + } + + pub async fn transfer_aborted(&self, f: impl Fn() -> TransferAborted) -> ClientResult { + Ok(if let Some(client) = &self.inner { + match self.mask.transfer_aborted { + EventMode2::Notify => client.notify(f()).await?, + EventMode2::None => {} + } + }) + } } #[rpc_requests(message = ProviderMessage)] From b23995e71b479ccba05370a9df8dbb6099259a0d Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Fri, 29 Aug 2025 13:23:35 +0200 Subject: [PATCH 03/23] Nicer proto --- src/provider/event_proto.rs | 620 ++++++++++++++++++++++++++---------- 1 file changed, 444 insertions(+), 176 deletions(-) diff --git a/src/provider/event_proto.rs b/src/provider/event_proto.rs index ac769a476..8713d8aac 100644 --- a/src/provider/event_proto.rs +++ b/src/provider/event_proto.rs @@ -1,36 +1,51 @@ use std::fmt::Debug; -use iroh::NodeId; use irpc::{ - channel::{none::NoSender, oneshot}, - rpc_requests, + channel::{mpsc, none::NoSender, oneshot}, + rpc_requests, Channels, WithChannels, }; use serde::{Deserialize, Serialize}; use snafu::Snafu; -use crate::{protocol::ChunkRangesSeq, provider::TransferStats, Hash}; +use crate::provider::{event_proto::irpc_ext::IrpcClientExt, TransferStats}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[repr(u8)] -pub enum EventMode { - /// We don't get these kinds of events at all +pub enum ConnectMode { + /// We don't get notification of connect events at all. #[default] None, - /// We get a notification for these kinds of events + /// We get a notification for connect events. Notify, - /// We can respond to these kinds of events, either by aborting or by - /// e.g. introducing a delay for throttling. + /// We get a request for connect events and can reject incoming connections. Request, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[repr(u8)] -pub enum EventMode2 { - /// We don't get these kinds of events at all +pub enum RequestMode { + /// We don't get request events at all. #[default] None, - /// We get a notification for these kinds of events + /// We get a notification for each request. Notify, + /// We get a request for each request, and can reject incoming requests. + Request, + /// We get a notification for each request as well as detailed transfer events. + NotifyLog, + /// We get a request for each request, and can reject incoming requests. + /// We also get detailed transfer events. + RequestLog, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[repr(u8)] +pub enum ThrottleMode { + /// We don't get these kinds of events at all + #[default] + None, + /// We call throttle to give the event handler a way to throttle requests + Throttle, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] @@ -58,111 +73,154 @@ impl From for ClientError { } } +impl From for ClientError { + fn from(value: irpc::channel::RecvError) -> Self { + ClientError::Irpc { + source: value.into(), + } + } +} + +impl From for ClientError { + fn from(value: irpc::channel::SendError) -> Self { + ClientError::Irpc { + source: value.into(), + } + } +} + pub type EventResult = Result<(), AbortReason>; pub type ClientResult = Result<(), ClientError>; #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub struct EventMask { - connected: EventMode, - get: EventMode, - get_many: EventMode, - push: EventMode, - transfer: EventMode, - transfer_complete: EventMode2, - transfer_aborted: EventMode2, + connected: ConnectMode, + get: RequestMode, + get_many: RequestMode, + push: RequestMode, + /// throttling is somewhat costly, so you can disable it completely + throttle: ThrottleMode, +} + +impl EventMask { + /// Everything is disabled. You won't get any events, but there is also no runtime cost. + pub const NONE: Self = Self { + connected: ConnectMode::None, + get: RequestMode::None, + get_many: RequestMode::None, + push: RequestMode::None, + throttle: ThrottleMode::None, + }; + + /// You get asked for every single thing that is going on and can intervene/throttle. + pub const ALL: Self = Self { + connected: ConnectMode::Request, + get: RequestMode::RequestLog, + get_many: RequestMode::RequestLog, + push: RequestMode::RequestLog, + throttle: ThrottleMode::Throttle, + }; + + /// You get notified for every single thing that is going on, but can't intervene. + pub const NOTIFY_ALL: Self = Self { + connected: ConnectMode::Notify, + get: RequestMode::NotifyLog, + get_many: RequestMode::NotifyLog, + push: RequestMode::NotifyLog, + throttle: ThrottleMode::None, + }; } /// Newtype wrapper that wraps an event so that it is a distinct type for the notify variant. #[derive(Debug, Serialize, Deserialize)] pub struct Notify(T); -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct Client { mask: EventMask, inner: Option>, } -/// A new get request was received from the provider. -#[derive(Debug, Serialize, Deserialize)] -pub struct GetRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - pub connection_id: u64, - /// The request id. There is a new id for each request. - pub request_id: u64, - /// The root hash of the request. - pub hash: Hash, - /// The exact query ranges of the request. - pub ranges: ChunkRangesSeq, +#[derive(Debug, Default)] +enum RequestUpdates { + /// Request tracking was not configured, all ops are no-ops + #[default] + None, + /// Active request tracking, all ops actually send + Active(mpsc::Sender), + /// Disabled request tracking, we just hold on to the sender so it drops + /// once the request is completed or aborted. + Disabled(mpsc::Sender), } -#[derive(Debug, Serialize, Deserialize)] -pub struct GetManyRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - pub connection_id: u64, - /// The request id. There is a new id for each request. - pub request_id: u64, - /// The root hash of the request. - pub hashes: Vec, - /// The exact query ranges of the request. - pub ranges: ChunkRangesSeq, +pub struct RequestTracker { + updates: RequestUpdates, + throttle: Option<(irpc::Client, u64, u64)>, } -#[derive(Debug, Serialize, Deserialize)] -pub struct PushRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - pub connection_id: u64, - /// The request id. There is a new id for each request. - pub request_id: u64, - /// The root hash of the request. - pub hash: Hash, - /// The exact query ranges of the request. - pub ranges: ChunkRangesSeq, -} +impl RequestTracker { + fn new( + updates: RequestUpdates, + throttle: Option<(irpc::Client, u64, u64)>, + ) -> Self { + Self { updates, throttle } + } -#[derive(Debug, Serialize, Deserialize)] -pub struct TransferProgress { - /// The connection id. Multiple requests can be sent over the same connection. - pub connection_id: u64, - /// The request id. There is a new id for each request. - pub request_id: u64, - /// The index of the blob in the request. 0 for the first blob or for raw blob requests. - pub index: u64, - /// The end offset of the chunk that was sent. - pub end_offset: u64, -} + /// A request tracker that doesn't track anything. + const NONE: Self = Self { + updates: RequestUpdates::None, + throttle: None, + }; -#[derive(Debug, Serialize, Deserialize)] -pub struct TransferStarted { - /// The connection id. Multiple requests can be sent over the same connection. - pub connection_id: u64, - /// The request id. There is a new id for each request. - pub request_id: u64, - /// The index of the blob in the request. 0 for the first blob or for raw blob requests. - pub index: u64, - /// The hash of the blob. This is the hash of the request for the first blob, the child hash (index-1) for subsequent blobs. - pub hash: Hash, - /// The size of the blob. This is the full size of the blob, not the size we are sending. - pub size: u64, -} + /// Transfer for index `index` started, size `size` + pub async fn transfer_started(&self, index: u64, size: u64) -> irpc::Result<()> { + if let RequestUpdates::Active(tx) = &self.updates { + tx.send(RequestUpdate::Started(TransferStarted { index, size })) + .await?; + } + Ok(()) + } -#[derive(Debug, Serialize, Deserialize)] -pub struct TransferCompleted { - /// The connection id. Multiple requests can be sent over the same connection. - pub connection_id: u64, - /// The request id. There is a new id for each request. - pub request_id: u64, - /// Statistics about the transfer. - pub stats: Box, -} + /// Transfer progress for the previously reported blob, end_offset is the new end offset in bytes. + pub async fn transfer_progress(&mut self, end_offset: u64) -> ClientResult { + if let RequestUpdates::Active(tx) = &mut self.updates { + tx.try_send(RequestUpdate::Progress(TransferProgress { end_offset })) + .await?; + } + if let Some((throttle, connection_id, request_id)) = &self.throttle { + throttle + .rpc(Throttle { + connection_id: *connection_id, + request_id: *request_id, + }) + .await??; + } + Ok(()) + } -#[derive(Debug, Serialize, Deserialize)] -pub struct TransferAborted { - /// The connection id. Multiple requests can be sent over the same connection. - pub connection_id: u64, - /// The request id. There is a new id for each request. - pub request_id: u64, - /// Statistics about the part of the transfer that was aborted. - pub stats: Option>, + /// Transfer completed for the previously reported blob. + pub async fn transfer_completed( + &mut self, + f: impl Fn() -> Box, + ) -> irpc::Result<()> { + if let RequestUpdates::Active(tx) = &self.updates { + tx.send(RequestUpdate::Completed(TransferCompleted { stats: f() })) + .await?; + } + Ok(()) + } + + /// Transfer aborted for the previously reported blob. + pub async fn transfer_aborted( + &mut self, + f: impl Fn() -> Option>, + ) -> irpc::Result<()> { + if let RequestUpdates::Active(tx) = &self.updates { + tx.send(RequestUpdate::Aborted(TransferAborted { stats: f() })) + .await?; + } + Ok(()) + } } /// Client for progress notifications. @@ -172,87 +230,132 @@ pub struct TransferAborted { impl Client { /// A client that does not send anything. pub const NONE: Self = Self { - mask: EventMask { - connected: EventMode::None, - get: EventMode::None, - get_many: EventMode::None, - push: EventMode::None, - transfer: EventMode::None, - transfer_complete: EventMode2::None, - transfer_aborted: EventMode2::None, - }, + mask: EventMask::NONE, inner: None, }; + /// A new client has been connected. pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult { Ok(if let Some(client) = &self.inner { match self.mask.connected { - EventMode::None => {} - EventMode::Notify => client.notify(Notify(f())).await?, - EventMode::Request => client.rpc(f()).await??, + ConnectMode::None => {} + ConnectMode::Notify => client.notify(Notify(f())).await?, + ConnectMode::Request => client.rpc(f()).await??, } }) } - pub async fn get_request(&self, f: impl Fn() -> GetRequestReceived) -> ClientResult { - Ok(if let Some(client) = &self.inner { - match self.mask.get { - EventMode::None => {} - EventMode::Notify => client.notify(Notify(f())).await?, - EventMode::Request => client.rpc(f()).await??, - } - }) - } - - pub async fn push_request(&self, f: impl Fn() -> PushRequestReceived) -> ClientResult { - Ok(if let Some(client) = &self.inner { - match self.mask.push { - EventMode::None => {} - EventMode::Notify => client.notify(Notify(f())).await?, - EventMode::Request => client.rpc(f()).await??, - } - }) + /// Start a get request. You will get back either an error if the request should not proceed, or a + /// [`RequestTracker`] that you can use to log progress for this particular request. + /// + /// Depending on the event sender config, the returned tracker might be a no-op. + pub async fn get_request( + &self, + f: impl FnOnce() -> GetRequestReceived, + ) -> Result { + self.request(f).await } - pub async fn send_get_many_request( + // Start a get_many request. You will get back either an error if the request should not proceed, or a + /// [`RequestTracker`] that you can use to log progress for this particular request. + /// + /// Depending on the event sender config, the returned tracker might be a no-op. + pub async fn get_many_request( &self, - f: impl Fn() -> GetManyRequestReceived, - ) -> ClientResult { - Ok(if let Some(client) = &self.inner { - match self.mask.get_many { - EventMode::None => {} - EventMode::Notify => client.notify(Notify(f())).await?, - EventMode::Request => client.rpc(f()).await??, - } - }) + f: impl FnOnce() -> GetManyRequestReceived, + ) -> Result { + self.request(f).await } - pub async fn transfer_progress(&self, f: impl Fn() -> TransferProgress) -> ClientResult { - Ok(if let Some(client) = &self.inner { - match self.mask.transfer { - EventMode::None => {} - EventMode::Notify => client.notify(Notify(f())).await?, - EventMode::Request => client.rpc(f()).await??, - } - }) + // Start a push request. You will get back either an error if the request should not proceed, or a + /// [`RequestTracker`] that you can use to log progress for this particular request. + /// + /// Depending on the event sender config, the returned tracker might be a no-op. + pub async fn push_request( + &self, + f: impl FnOnce() -> PushRequestReceived, + ) -> Result { + self.request(f).await } - pub async fn transfer_completed(&self, f: impl Fn() -> TransferCompleted) -> ClientResult { - Ok(if let Some(client) = &self.inner { - match self.mask.transfer_complete { - EventMode2::Notify => client.notify(f()).await?, - EventMode2::None => {} + /// Abstract request, to DRY the 3 to 4 request types. + /// + /// DRYing stuff with lots of bounds is no fun at all... + async fn request(&self, f: impl FnOnce() -> Req) -> Result + where + Req: Request, + ProviderProto: From, + ProviderMessage: From>, + Req: Channels< + ProviderProto, + Tx = oneshot::Sender, + Rx = mpsc::Receiver, + >, + ProviderProto: From>, + ProviderMessage: From, ProviderProto>>, + Notify: Channels>, + { + Ok(self.into_tracker(if let Some(client) = &self.inner { + match self.mask.get { + RequestMode::None => { + if self.mask.throttle == ThrottleMode::Throttle { + // if throttling is enabled, we need to call f to get connection_id and request_id + let msg = f(); + (RequestUpdates::None, msg.id()) + } else { + (RequestUpdates::None, (0, 0)) + } + } + RequestMode::Notify => { + let msg = f(); + let id = msg.id(); + ( + RequestUpdates::Disabled(client.notify_streaming(Notify(msg), 32).await?), + id, + ) + } + RequestMode::Request => { + let msg = f(); + let id = msg.id(); + let (tx, rx) = client.client_streaming(msg, 32).await?; + // bail out if the request is not allowed + rx.await??; + (RequestUpdates::Disabled(tx), id) + } + RequestMode::NotifyLog => { + let msg = f(); + let id = msg.id(); + ( + RequestUpdates::Active(client.notify_streaming(Notify(msg), 32).await?), + id, + ) + } + RequestMode::RequestLog => { + let msg = f(); + let id = msg.id(); + let (tx, rx) = client.client_streaming(msg, 32).await?; + // bail out if the request is not allowed + rx.await??; + (RequestUpdates::Active(tx), id) + } } - }) + } else { + (RequestUpdates::None, (0, 0)) + })) } - pub async fn transfer_aborted(&self, f: impl Fn() -> TransferAborted) -> ClientResult { - Ok(if let Some(client) = &self.inner { - match self.mask.transfer_aborted { - EventMode2::Notify => client.notify(f()).await?, - EventMode2::None => {} - } - }) + fn into_tracker( + &self, + (updates, (connection_id, request_id)): (RequestUpdates, (u64, u64)), + ) -> RequestTracker { + let throttle = match self.mask.throttle { + ThrottleMode::None => None, + ThrottleMode::Throttle => self + .inner + .clone() + .map(|client| (client, connection_id, request_id)), + }; + RequestTracker::new(updates, throttle) } } @@ -261,48 +364,213 @@ impl Client { pub enum ProviderProto { /// A new client connected to the provider. #[rpc(tx = oneshot::Sender)] - #[wrap(ClientConnected)] - ClientConnected { connection_id: u64, node_id: NodeId }, + ClientConnected(ClientConnected), /// A new client connected to the provider. Notify variant. #[rpc(tx = NoSender)] ClientConnectedNotify(Notify), + /// A client disconnected from the provider. #[rpc(tx = NoSender)] #[wrap(ConnectionClosed)] ConnectionClosed { connection_id: u64 }, - #[rpc(tx = oneshot::Sender)] + #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] /// A new get request was received from the provider. GetRequestReceived(GetRequestReceived), - #[rpc(tx = NoSender)] + #[rpc(rx = mpsc::Receiver, tx = NoSender)] /// A new get request was received from the provider. GetRequestReceivedNotify(Notify), + /// A new get request was received from the provider. - #[rpc(tx = oneshot::Sender)] + #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] GetManyRequestReceived(GetManyRequestReceived), + /// A new get request was received from the provider. - #[rpc(tx = NoSender)] + #[rpc(rx = mpsc::Receiver, tx = NoSender)] GetManyRequestReceivedNotify(Notify), + /// A new get request was received from the provider. - #[rpc(tx = oneshot::Sender)] + #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] PushRequestReceived(PushRequestReceived), + /// A new get request was received from the provider. - #[rpc(tx = NoSender)] + #[rpc(rx = mpsc::Receiver, tx = NoSender)] PushRequestReceivedNotify(Notify), - /// Transfer for the nth blob started. - #[rpc(tx = NoSender)] - TransferStarted(TransferStarted), - /// Progress of the transfer. + #[rpc(tx = oneshot::Sender)] - TransferProgress(TransferProgress), - /// Progress of the transfer. - #[rpc(tx = NoSender)] - TransferProgressNotify(Notify), - /// Entire transfer completed. - #[rpc(tx = NoSender)] - TransferCompleted(TransferCompleted), - /// Entire transfer aborted. - #[rpc(tx = NoSender)] - TransferAborted(TransferAborted), + Throttle(Throttle), +} + +trait Request { + fn id(&self) -> (u64, u64); +} + +mod proto { + use iroh::NodeId; + use serde::{Deserialize, Serialize}; + + use super::Request; + use crate::{protocol::ChunkRangesSeq, provider::TransferStats, Hash}; + + #[derive(Debug, Serialize, Deserialize)] + pub struct ClientConnected { + pub connection_id: u64, + pub node_id: NodeId, + } + + /// A new get request was received from the provider. + #[derive(Debug, Serialize, Deserialize)] + pub struct GetRequestReceived { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// The root hash of the request. + pub hash: Hash, + /// The exact query ranges of the request. + pub ranges: ChunkRangesSeq, + } + + impl Request for GetRequestReceived { + fn id(&self) -> (u64, u64) { + (self.connection_id, self.request_id) + } + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct GetManyRequestReceived { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// The root hash of the request. + pub hashes: Vec, + /// The exact query ranges of the request. + pub ranges: ChunkRangesSeq, + } + + impl Request for GetManyRequestReceived { + fn id(&self) -> (u64, u64) { + (self.connection_id, self.request_id) + } + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct PushRequestReceived { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// The root hash of the request. + pub hash: Hash, + /// The exact query ranges of the request. + pub ranges: ChunkRangesSeq, + } + + impl Request for PushRequestReceived { + fn id(&self) -> (u64, u64) { + (self.connection_id, self.request_id) + } + } + + /// Request to throttle sending for a specific request. + #[derive(Debug, Serialize, Deserialize)] + pub struct Throttle { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct TransferProgress { + /// The end offset of the chunk that was sent. + pub end_offset: u64, + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct TransferStarted { + pub index: u64, + pub size: u64, + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct TransferCompleted { + pub stats: Box, + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct TransferAborted { + pub stats: Option>, + } + + /// Stream of updates for a single request + #[derive(Debug, Serialize, Deserialize)] + pub enum RequestUpdate { + /// Start of transfer for a blob, mandatory event + Started(TransferStarted), + /// Progress for a blob - optional event + Progress(TransferProgress), + /// Successful end of transfer + Completed(TransferCompleted), + /// Aborted end of transfer + Aborted(TransferAborted), + } +} +use proto::*; + +mod irpc_ext { + use std::future::Future; + + use irpc::{ + channel::{mpsc, none::NoSender, oneshot}, + Channels, RpcMessage, Service, WithChannels, + }; + + pub trait IrpcClientExt { + fn notify_streaming( + &self, + msg: Req, + local_update_cap: usize, + ) -> impl Future>> + where + S: From, + S::Message: From>, + Req: Channels>, + Update: RpcMessage; + } + + impl IrpcClientExt for irpc::Client { + fn notify_streaming( + &self, + msg: Req, + local_update_cap: usize, + ) -> impl Future>> + where + S: From, + S::Message: From>, + Req: Channels>, + Update: RpcMessage, + { + let client = self.clone(); + async move { + let request = client.request().await?; + match request { + irpc::Request::Local(local) => { + let (req_tx, req_rx) = mpsc::channel(local_update_cap); + local + .send((msg, NoSender, req_rx)) + .await + .map_err(irpc::Error::from)?; + Ok(req_tx) + } + irpc::Request::Remote(remote) => { + let (s, r) = remote.write(msg).await?; + Ok(s.into()) + } + } + } + } + } } From a78c212e8f307210bed17d7ab85010aeff9dfa5d Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Fri, 29 Aug 2025 15:05:26 +0200 Subject: [PATCH 04/23] Update tests --- examples/custom-protocol.rs | 4 +- examples/mdns-discovery.rs | 4 +- examples/random_store.rs | 104 +++--- examples/transfer.rs | 4 +- src/api/blobs.rs | 17 +- src/net_protocol.rs | 8 +- src/provider.rs | 354 ++++++++------------- src/provider/{event_proto.rs => events.rs} | 47 ++- src/tests.rs | 52 +-- 9 files changed, 261 insertions(+), 333 deletions(-) rename src/provider/{event_proto.rs => events.rs} (94%) diff --git a/examples/custom-protocol.rs b/examples/custom-protocol.rs index c021b7f0a..6542acd18 100644 --- a/examples/custom-protocol.rs +++ b/examples/custom-protocol.rs @@ -48,7 +48,7 @@ use iroh::{ protocol::{AcceptError, ProtocolHandler, Router}, Endpoint, NodeId, }; -use iroh_blobs::{api::Store, store::mem::MemStore, BlobsProtocol, Hash}; +use iroh_blobs::{api::Store, provider::EventSender2, store::mem::MemStore, BlobsProtocol, Hash}; mod common; use common::{get_or_generate_secret_key, setup_logging}; @@ -100,7 +100,7 @@ async fn listen(text: Vec) -> Result<()> { proto.insert_and_index(text).await?; } // Build the iroh-blobs protocol handler, which is used to download blobs. - let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); // create a router that handles both our custom protocol and the iroh-blobs protocol. let node = Router::builder(endpoint) diff --git a/examples/mdns-discovery.rs b/examples/mdns-discovery.rs index b42f88f47..ef5d0619c 100644 --- a/examples/mdns-discovery.rs +++ b/examples/mdns-discovery.rs @@ -18,7 +18,7 @@ use clap::{Parser, Subcommand}; use iroh::{ discovery::mdns::MdnsDiscovery, protocol::Router, Endpoint, PublicKey, RelayMode, SecretKey, }; -use iroh_blobs::{store::mem::MemStore, BlobsProtocol, Hash}; +use iroh_blobs::{provider::EventSender2, store::mem::MemStore, BlobsProtocol, Hash}; mod common; use common::{get_or_generate_secret_key, setup_logging}; @@ -68,7 +68,7 @@ async fn accept(path: &Path) -> Result<()> { .await?; let builder = Router::builder(endpoint.clone()); let store = MemStore::new(); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); let builder = builder.accept(iroh_blobs::ALPN, blobs.clone()); let node = builder.spawn(); diff --git a/examples/random_store.rs b/examples/random_store.rs index ffdd9b826..6f933d511 100644 --- a/examples/random_store.rs +++ b/examples/random_store.rs @@ -6,7 +6,7 @@ use iroh::{SecretKey, Watcher}; use iroh_base::ticket::NodeTicket; use iroh_blobs::{ api::downloader::Shuffled, - provider::Event, + provider::{AbortReason, Event, EventMask, EventSender2, ProviderMessage}, store::fs::FsStore, test::{add_hash_sequences, create_random_blobs}, HashAndFormat, @@ -104,78 +104,66 @@ pub fn dump_provider_events( allow_push: bool, ) -> ( tokio::task::JoinHandle<()>, - mpsc::Sender, + EventSender2, ) { let (tx, mut rx) = mpsc::channel(100); let dump_task = tokio::spawn(async move { while let Some(event) = rx.recv().await { match event { - Event::ClientConnected { - node_id, - connection_id, - permitted, - } => { - permitted.send(true).await.ok(); - println!("Client connected: {node_id} {connection_id}"); + ProviderMessage::ClientConnected(msg) => { + println!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); } - Event::GetRequestReceived { - connection_id, - request_id, - hash, - ranges, - } => { - println!( - "Get request received: {connection_id} {request_id} {hash} {ranges:?}" - ); + ProviderMessage::ClientConnectedNotify(msg) => { + println!("{:?}", msg.inner); } - Event::TransferCompleted { - connection_id, - request_id, - stats, - } => { - println!("Transfer completed: {connection_id} {request_id} {stats:?}"); + ProviderMessage::ConnectionClosed(msg) => { + println!("{:?}", msg.inner); } - Event::TransferAborted { - connection_id, - request_id, - stats, - } => { - println!("Transfer aborted: {connection_id} {request_id} {stats:?}"); + ProviderMessage::GetRequestReceived(mut msg) => { + println!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + tokio::spawn(async move { + while let Ok(update) = msg.rx.recv().await { + info!("{update:?}"); + } + }); } - Event::TransferProgress { - connection_id, - request_id, - index, - end_offset, - } => { - info!("Transfer progress: {connection_id} {request_id} {index} {end_offset}"); + ProviderMessage::GetRequestReceivedNotify(msg) => { + println!("{:?}", msg.inner); } - Event::PushRequestReceived { - connection_id, - request_id, - hash, - ranges, - permitted, - } => { - if allow_push { - permitted.send(true).await.ok(); - println!( - "Push request received: {connection_id} {request_id} {hash} {ranges:?}" - ); + ProviderMessage::GetManyRequestReceived(mut msg) => { + println!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + tokio::spawn(async move { + while let Ok(update) = msg.rx.recv().await { + info!("{update:?}"); + } + }); + } + ProviderMessage::GetManyRequestReceivedNotify(msg) => { + println!("{:?}", msg.inner); + } + ProviderMessage::PushRequestReceived(msg) => { + println!("{:?}", msg.inner); + let res = if allow_push { + Ok(()) } else { - permitted.send(false).await.ok(); - println!( - "Push request denied: {connection_id} {request_id} {hash} {ranges:?}" - ); - } + Err(AbortReason::Permission) + }; + msg.tx.send(res).await.ok(); + } + ProviderMessage::PushRequestReceivedNotify(msg) => { + println!("{:?}", msg.inner); } - _ => { - info!("Received event: {:?}", event); + ProviderMessage::Throttle(msg) => { + println!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); } } } }); - (dump_task, tx) + (dump_task, EventSender2::new(tx, EventMask::ALL)) } #[tokio::main] @@ -237,7 +225,7 @@ async fn provide(args: ProvideArgs) -> anyhow::Result<()> { .bind() .await?; let (dump_task, events_tx) = dump_provider_events(args.allow_push); - let blobs = iroh_blobs::BlobsProtocol::new(&store, endpoint.clone(), Some(events_tx)); + let blobs = iroh_blobs::BlobsProtocol::new(&store, endpoint.clone(), events_tx); let router = iroh::protocol::Router::builder(endpoint.clone()) .accept(iroh_blobs::ALPN, blobs) .spawn(); diff --git a/examples/transfer.rs b/examples/transfer.rs index 48fba6ba3..baa1e343c 100644 --- a/examples/transfer.rs +++ b/examples/transfer.rs @@ -1,7 +1,7 @@ use std::path::PathBuf; use iroh::{protocol::Router, Endpoint}; -use iroh_blobs::{store::mem::MemStore, ticket::BlobTicket, BlobsProtocol}; +use iroh_blobs::{provider::EventSender2, store::mem::MemStore, ticket::BlobTicket, BlobsProtocol}; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -12,7 +12,7 @@ async fn main() -> anyhow::Result<()> { // We initialize an in-memory backing store for iroh-blobs let store = MemStore::new(); // Then we initialize a struct that can accept blobs requests over iroh connections - let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); // Grab all passed in arguments, the first one is the binary itself, so we skip it. let args: Vec = std::env::args().skip(1).collect(); diff --git a/src/api/blobs.rs b/src/api/blobs.rs index d0b948598..76f338359 100644 --- a/src/api/blobs.rs +++ b/src/api/blobs.rs @@ -57,7 +57,7 @@ use super::{ }; use crate::{ api::proto::{BatchRequest, ImportByteStreamUpdate}, - provider::StreamContext, + provider::{ReaderContext, WriterContext}, store::IROH_BLOCK_SIZE, util::temp_tag::TempTag, BlobFormat, Hash, HashAndFormat, @@ -1168,16 +1168,21 @@ pub(crate) trait WriteProgress { async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64); } -impl WriteProgress for StreamContext { - async fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize) { - StreamContext::notify_payload_write(self, index, offset, len); +impl WriteProgress for WriterContext { + async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) { + let end_offset = offset + len as u64; + self.payload_bytes_written += len as u64; + self.tracker.transfer_progress(end_offset).await.ok(); } fn log_other_write(&mut self, len: usize) { - StreamContext::log_other_write(self, len); + self.other_bytes_written += len as u64; } async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) { - StreamContext::send_transfer_started(self, index, hash, size).await + self.tracker + .transfer_started(index, hash, size) + .await + .ok(); } } diff --git a/src/net_protocol.rs b/src/net_protocol.rs index 3e7d9582e..ca64b1a7b 100644 --- a/src/net_protocol.rs +++ b/src/net_protocol.rs @@ -48,7 +48,7 @@ use tracing::error; use crate::{ api::Store, - provider::{Event, EventSender}, + provider::{Event, EventSender2}, ticket::BlobTicket, HashAndFormat, }; @@ -57,7 +57,7 @@ use crate::{ pub(crate) struct BlobsInner { pub(crate) store: Store, pub(crate) endpoint: Endpoint, - pub(crate) events: EventSender, + pub(crate) events: EventSender2, } /// A protocol handler for the blobs protocol. @@ -75,12 +75,12 @@ impl Deref for BlobsProtocol { } impl BlobsProtocol { - pub fn new(store: &Store, endpoint: Endpoint, events: Option>) -> Self { + pub fn new(store: &Store, endpoint: Endpoint, events: EventSender2) -> Self { Self { inner: Arc::new(BlobsInner { store: store.clone(), endpoint, - events: EventSender::new(events), + events, }), } } diff --git a/src/provider.rs b/src/provider.rs index 141d674c6..b10367911 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -9,7 +9,7 @@ use std::{ ops::{Deref, DerefMut}, pin::Pin, task::Poll, - time::Duration, + time::{Duration, Instant}, }; use anyhow::{Context, Result}; @@ -18,22 +18,23 @@ use iroh::{ endpoint::{self, RecvStream, SendStream}, NodeId, }; -use irpc::channel::oneshot; +use irpc::{channel::oneshot}; use n0_future::StreamExt; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tokio::{io::AsyncRead, select, sync::mpsc}; -use tracing::{debug, debug_span, error, warn, Instrument}; +use tracing::{debug, debug_span, error, trace, warn, Instrument}; use crate::{ - api::{self, blobs::Bitfield, Store}, - hashseq::HashSeq, - protocol::{ + api::{self, blobs::{Bitfield, WriteProgress}, Store}, hashseq::HashSeq, protocol::{ ChunkRangesSeq, GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, - }, - Hash, + }, provider::events::{ClientConnected, ConnectionClosed, GetRequestReceived, RequestTracker}, Hash }; -mod event_proto; +pub(crate) mod events; +pub use events::EventSender as EventSender2; +pub use events::ProviderMessage; +pub use events::EventMask; +pub use events::AbortReason; /// Provider progress events, to keep track of what the provider is doing. /// @@ -162,19 +163,33 @@ pub async fn read_request(reader: &mut ProgressReader) -> Result { } #[derive(Debug)] -pub struct StreamContext { +pub struct ReaderContext { + /// The start time of the transfer + pub t0: Instant, /// The connection ID from the connection pub connection_id: u64, /// The request ID from the recv stream pub request_id: u64, - /// The number of bytes written that are part of the payload - pub payload_bytes_sent: u64, - /// The number of bytes written that are not part of the payload - pub other_bytes_sent: u64, /// The number of bytes read from the stream pub bytes_read: u64, - /// The progress sender to send events to - pub progress: EventSender, +} + +#[derive(Debug)] +pub struct WriterContext { + /// The start time of the transfer + pub t0: Instant, + /// The connection ID from the connection + pub connection_id: u64, + /// The request ID from the recv stream + pub request_id: u64, + /// The number of bytes read from the stream + pub bytes_read: u64, + /// The number of payload bytes written to the stream + pub payload_bytes_written: u64, + /// The number of bytes written that are not part of the payload + pub other_bytes_written: u64, + /// Way to report progress + pub tracker: RequestTracker, } /// Wrapper for a [`quinn::SendStream`] with additional per request information. @@ -182,11 +197,43 @@ pub struct StreamContext { pub struct ProgressWriter { /// The quinn::SendStream to write to pub inner: SendStream, - pub(crate) context: StreamContext, + pub(crate) context: WriterContext, +} + +impl ProgressWriter { + fn new(inner: SendStream, context: ReaderContext, tracker: RequestTracker) -> Self { + Self { inner, context: WriterContext { + connection_id: context.connection_id, + request_id: context.request_id, + bytes_read: context.bytes_read, + t0: context.t0, + payload_bytes_written: 0, + other_bytes_written: 0, + tracker, + } } + } + + async fn transfer_aborted(&self) { + self.tracker.transfer_aborted(|| Some(Box::new(TransferStats { + payload_bytes_sent: self.payload_bytes_written, + other_bytes_sent: self.other_bytes_written, + bytes_read: self.bytes_read, + duration: self.context.t0.elapsed(), + }))).await.ok(); + } + + async fn transfer_completed(&self) { + self.tracker.transfer_completed(|| Box::new(TransferStats { + payload_bytes_sent: self.payload_bytes_written, + other_bytes_sent: self.other_bytes_written, + bytes_read: self.bytes_read, + duration: self.context.t0.elapsed(), + })).await.ok(); + } } impl Deref for ProgressWriter { - type Target = StreamContext; + type Target = WriterContext; fn deref(&self) -> &Self::Target { &self.context @@ -199,140 +246,11 @@ impl DerefMut for ProgressWriter { } } -impl StreamContext { - /// Increase the write count due to a non-payload write. - pub fn log_other_write(&mut self, len: usize) { - self.other_bytes_sent += len as u64; - } - - pub async fn send_transfer_completed(&mut self) { - self.progress - .send(|| Event::TransferCompleted { - connection_id: self.connection_id, - request_id: self.request_id, - stats: Box::new(TransferStats { - payload_bytes_sent: self.payload_bytes_sent, - other_bytes_sent: self.other_bytes_sent, - bytes_read: self.bytes_read, - duration: Duration::ZERO, - }), - }) - .await; - } - - pub async fn send_transfer_aborted(&mut self) { - self.progress - .send(|| Event::TransferAborted { - connection_id: self.connection_id, - request_id: self.request_id, - stats: Some(Box::new(TransferStats { - payload_bytes_sent: self.payload_bytes_sent, - other_bytes_sent: self.other_bytes_sent, - bytes_read: self.bytes_read, - duration: Duration::ZERO, - })), - }) - .await; - } - - /// Increase the write count due to a payload write, and notify the progress sender. - /// - /// `index` is the index of the blob in the request. - /// `offset` is the offset in the blob where the write started. - /// `len` is the length of the write. - pub fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize) { - self.payload_bytes_sent += len as u64; - self.progress.try_send(|| Event::TransferProgress { - connection_id: self.connection_id, - request_id: self.request_id, - index, - end_offset: offset + len as u64, - }); - } - - /// Send a get request received event. - /// - /// This sends all the required information to make sense of subsequent events such as - /// [`Event::TransferStarted`] and [`Event::TransferProgress`]. - pub async fn send_get_request_received(&self, hash: &Hash, ranges: &ChunkRangesSeq) { - self.progress - .send(|| Event::GetRequestReceived { - connection_id: self.connection_id, - request_id: self.request_id, - hash: *hash, - ranges: ranges.clone(), - }) - .await; - } - - /// Send a get request received event. - /// - /// This sends all the required information to make sense of subsequent events such as - /// [`Event::TransferStarted`] and [`Event::TransferProgress`]. - pub async fn send_get_many_request_received(&self, hashes: &[Hash], ranges: &ChunkRangesSeq) { - self.progress - .send(|| Event::GetManyRequestReceived { - connection_id: self.connection_id, - request_id: self.request_id, - hashes: hashes.to_vec(), - ranges: ranges.clone(), - }) - .await; - } - - /// Authorize a push request. - /// - /// This will send a request to the event sender, and wait for a response if a - /// progress sender is enabled. If not, it will always fail. - /// - /// We want to make accepting push requests very explicit, since this allows - /// remote nodes to add arbitrary data to our store. - #[must_use = "permit should be checked by the caller"] - pub async fn authorize_push_request(&self, hash: &Hash, ranges: &ChunkRangesSeq) -> bool { - let mut wait_for_permit = None; - // send the request, including the permit channel - self.progress - .send(|| { - let (tx, rx) = oneshot::channel(); - wait_for_permit = Some(rx); - Event::PushRequestReceived { - connection_id: self.connection_id, - request_id: self.request_id, - hash: *hash, - ranges: ranges.clone(), - permitted: tx, - } - }) - .await; - // wait for the permit, if necessary - if let Some(wait_for_permit) = wait_for_permit { - // if somebody does not handle the request, they will drop the channel, - // and this will fail immediately. - wait_for_permit.await.unwrap_or(false) - } else { - false - } - } - - /// Send a transfer started event. - pub async fn send_transfer_started(&self, index: u64, hash: &Hash, size: u64) { - self.progress - .send(|| Event::TransferStarted { - connection_id: self.connection_id, - request_id: self.request_id, - index, - hash: *hash, - size, - }) - .await; - } -} - /// Handle a single connection. pub async fn handle_connection( connection: endpoint::Connection, store: Store, - progress: EventSender, + progress: EventSender2, ) { let connection_id = connection.stable_id() as u64; let span = debug_span!("connection", connection_id); @@ -341,11 +259,14 @@ pub async fn handle_connection( warn!("failed to get node id"); return; }; - if !progress - .authorize_client_connection(connection_id, node_id) + if let Err(cause) =progress + .client_connected(|| ClientConnected { + connection_id, + node_id, + }) .await { - debug!("client not authorized to connect"); + debug!("client not authorized to connect: {cause}"); return; } while let Ok((writer, reader)) = connection.accept_bi().await { @@ -354,35 +275,24 @@ pub async fn handle_connection( let request_id = reader.id().index(); let span = debug_span!("stream", stream_id = %request_id); let store = store.clone(); - let mut writer = ProgressWriter { - inner: writer, - context: StreamContext { - connection_id, - request_id, - payload_bytes_sent: 0, - other_bytes_sent: 0, - bytes_read: 0, - progress: progress.clone(), - }, + let context = ReaderContext { + t0: Instant::now(), + connection_id: connection_id, + request_id: request_id, + bytes_read: 0, + }; + let reader = ProgressReader { + inner: reader, + context, }; tokio::spawn( - async move { - match handle_stream(store, reader, &mut writer).await { - Ok(()) => { - writer.send_transfer_completed().await; - } - Err(err) => { - warn!("error: {err:#?}",); - writer.send_transfer_aborted().await; - } - } - } + handle_stream(store, reader, writer, progress.clone()) .instrument(span), ); } progress - .send(Event::ConnectionClosed { connection_id }) - .await; + .connection_closed(|| ConnectionClosed { connection_id }) + .await.ok(); } .instrument(span) .await @@ -390,56 +300,69 @@ pub async fn handle_connection( async fn handle_stream( store: Store, - reader: RecvStream, - writer: &mut ProgressWriter, -) -> Result<()> { + mut reader: ProgressReader, + writer: SendStream, + progress: EventSender2, +) { // 1. Decode the request. debug!("reading request"); - let mut reader = ProgressReader { - inner: reader, - context: StreamContext { - connection_id: writer.connection_id, - request_id: writer.request_id, - payload_bytes_sent: 0, - other_bytes_sent: 0, - bytes_read: 0, - progress: writer.progress.clone(), - }, - }; let request = match read_request(&mut reader).await { Ok(request) => request, Err(e) => { - // todo: increase invalid requests metric counter - return Err(e); + // todo: event for read request failed + return; } }; match request { Request::Get(request) => { + let tracker = match progress.get_request(|| GetRequestReceived { + connection_id: reader.context.connection_id, + request_id: reader.context.request_id, + request: request.clone(), + }).await { + Ok(tracker) => tracker, + Err(e) => { + trace!("Request denied: {}", e); + return; + } + }; // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. - reader.inner.read_to_end(0).await?; - // move the context so we don't lose the bytes read - writer.context = reader.context; - handle_get(store, request, writer).await + let res = reader.inner.read_to_end(0).await; + let mut writer = ProgressWriter::new(writer, reader.context, tracker); + if res.is_err() { + writer.transfer_aborted().await; + return; + } + match handle_get(store, request, &mut writer).await { + Ok(()) => { + writer.transfer_completed().await; + } + Err(_) => { + writer.transfer_aborted().await; + } + } } Request::GetMany(request) => { - // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. - reader.inner.read_to_end(0).await?; - // move the context so we don't lose the bytes read - writer.context = reader.context; - handle_get_many(store, request, writer).await + todo!(); + // // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. + // reader.inner.read_to_end(0).await?; + // // move the context so we don't lose the bytes read + // writer.context = reader.context; + // handle_get_many(store, request, writer).await } Request::Observe(request) => { - // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. - reader.inner.read_to_end(0).await?; - handle_observe(store, request, writer).await + todo!(); + // // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. + // reader.inner.read_to_end(0).await?; + // handle_observe(store, request, writer).await } Request::Push(request) => { - writer.inner.finish()?; - handle_push(store, request, reader).await + todo!(); + // writer.inner.finish()?; + // handle_push(store, request, reader).await } - _ => anyhow::bail!("unsupported request: {request:?}"), - // Request::Push(request) => handle_push(store, request, writer).await, + _ => {}, } } @@ -450,13 +373,9 @@ pub async fn handle_get( store: Store, request: GetRequest, writer: &mut ProgressWriter, -) -> Result<()> { +) -> anyhow::Result<()> { let hash = request.hash; debug!(%hash, "get received request"); - - writer - .send_get_request_received(&hash, &request.ranges) - .await; let mut hash_seq = None; for (offset, ranges) in request.ranges.iter_non_empty_infinite() { if offset == 0 { @@ -496,9 +415,6 @@ pub async fn handle_get_many( writer: &mut ProgressWriter, ) -> Result<()> { debug!("get_many received request"); - writer - .send_get_many_request_received(&request.hashes, &request.ranges) - .await; let request_ranges = request.ranges.iter_infinite(); for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() { if !ranges.is_empty() { @@ -518,10 +434,6 @@ pub async fn handle_push( ) -> Result<()> { let hash = request.hash; debug!(%hash, "push received request"); - if !reader.authorize_push_request(&hash, &request.ranges).await { - debug!("push request not authorized"); - return Ok(()); - }; let mut request_ranges = request.ranges.iter_infinite(); let root_ranges = request_ranges.next().expect("infinite iterator"); if !root_ranges.is_empty() { @@ -602,7 +514,7 @@ async fn send_observe_item(writer: &mut ProgressWriter, item: &Bitfield) -> Resu use irpc::util::AsyncWriteVarintExt; let item = ObserveItem::from(item); let len = writer.inner.write_length_prefixed(item).await?; - writer.log_other_write(len); + writer.context.log_other_write(len); Ok(()) } @@ -701,11 +613,11 @@ impl EventSender { pub struct ProgressReader { inner: RecvStream, - context: StreamContext, + context: ReaderContext, } impl Deref for ProgressReader { - type Target = StreamContext; + type Target = ReaderContext; fn deref(&self) -> &Self::Target { &self.context diff --git a/src/provider/event_proto.rs b/src/provider/events.rs similarity index 94% rename from src/provider/event_proto.rs rename to src/provider/events.rs index 8713d8aac..55383e77c 100644 --- a/src/provider/event_proto.rs +++ b/src/provider/events.rs @@ -7,7 +7,7 @@ use irpc::{ use serde::{Deserialize, Serialize}; use snafu::Snafu; -use crate::provider::{event_proto::irpc_ext::IrpcClientExt, TransferStats}; +use crate::{provider::{events::irpc_ext::IrpcClientExt, TransferStats}, Hash}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[repr(u8)] @@ -136,7 +136,7 @@ impl EventMask { pub struct Notify(T); #[derive(Debug, Default, Clone)] -pub struct Client { +pub struct EventSender { mask: EventMask, inner: Option>, } @@ -150,9 +150,10 @@ enum RequestUpdates { Active(mpsc::Sender), /// Disabled request tracking, we just hold on to the sender so it drops /// once the request is completed or aborted. - Disabled(mpsc::Sender), + Disabled(#[allow(dead_code)] mpsc::Sender), } +#[derive(Debug)] pub struct RequestTracker { updates: RequestUpdates, throttle: Option<(irpc::Client, u64, u64)>, @@ -173,9 +174,9 @@ impl RequestTracker { }; /// Transfer for index `index` started, size `size` - pub async fn transfer_started(&self, index: u64, size: u64) -> irpc::Result<()> { + pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { - tx.send(RequestUpdate::Started(TransferStarted { index, size })) + tx.send(RequestUpdate::Started(TransferStarted { index, hash: *hash, size })) .await?; } Ok(()) @@ -200,7 +201,7 @@ impl RequestTracker { /// Transfer completed for the previously reported blob. pub async fn transfer_completed( - &mut self, + &self, f: impl Fn() -> Box, ) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { @@ -212,7 +213,7 @@ impl RequestTracker { /// Transfer aborted for the previously reported blob. pub async fn transfer_aborted( - &mut self, + &self, f: impl Fn() -> Option>, ) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { @@ -227,13 +228,17 @@ impl RequestTracker { /// /// For most event types, the client can be configured to either send notifications or requests that /// can have a response. -impl Client { +impl EventSender { /// A client that does not send anything. pub const NONE: Self = Self { mask: EventMask::NONE, inner: None, }; + pub fn new(client: tokio::sync::mpsc::Sender, mask: EventMask) -> Self { + Self { mask, inner: Some(irpc::Client::from(client)) } + } + /// A new client has been connected. pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult { Ok(if let Some(client) = &self.inner { @@ -245,6 +250,13 @@ impl Client { }) } + /// A new client has been connected. + pub async fn connection_closed(&self, f: impl Fn() -> ConnectionClosed) -> ClientResult { + Ok(if let Some(client) = &self.inner { + client.notify(f()).await?; + }) + } + /// Start a get request. You will get back either an error if the request should not proceed, or a /// [`RequestTracker`] that you can use to log progress for this particular request. /// @@ -371,8 +383,7 @@ pub enum ProviderProto { /// A client disconnected from the provider. #[rpc(tx = NoSender)] - #[wrap(ConnectionClosed)] - ConnectionClosed { connection_id: u64 }, + ConnectionClosed(ConnectionClosed), #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] /// A new get request was received from the provider. @@ -411,7 +422,7 @@ mod proto { use serde::{Deserialize, Serialize}; use super::Request; - use crate::{protocol::ChunkRangesSeq, provider::TransferStats, Hash}; + use crate::{protocol::{ChunkRangesSeq, GetRequest}, provider::TransferStats, Hash}; #[derive(Debug, Serialize, Deserialize)] pub struct ClientConnected { @@ -419,6 +430,11 @@ mod proto { pub node_id: NodeId, } + #[derive(Debug, Serialize, Deserialize)] + pub struct ConnectionClosed { + pub connection_id: u64, + } + /// A new get request was received from the provider. #[derive(Debug, Serialize, Deserialize)] pub struct GetRequestReceived { @@ -426,10 +442,8 @@ mod proto { pub connection_id: u64, /// The request id. There is a new id for each request. pub request_id: u64, - /// The root hash of the request. - pub hash: Hash, - /// The exact query ranges of the request. - pub ranges: ChunkRangesSeq, + /// The request + pub request: GetRequest, } impl Request for GetRequestReceived { @@ -492,6 +506,7 @@ mod proto { #[derive(Debug, Serialize, Deserialize)] pub struct TransferStarted { pub index: u64, + pub hash: Hash, pub size: u64, } @@ -518,7 +533,7 @@ mod proto { Aborted(TransferAborted), } } -use proto::*; +pub use proto::*; mod irpc_ext { use std::future::Future; diff --git a/src/tests.rs b/src/tests.rs index e7dc823e6..9b825bd08 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -16,7 +16,7 @@ use crate::{ hashseq::HashSeq, net_protocol::BlobsProtocol, protocol::{ChunkRangesSeq, GetManyRequest, ObserveRequest, PushRequest}, - provider::Event, + provider::{events::{AbortReason, RequestUpdate}, Event, EventMask, EventSender2, ProviderMessage}, store::{ fs::{ tests::{create_n0_bao, test_data, INTERESTING_SIZES}, @@ -341,32 +341,40 @@ async fn two_nodes_get_many_mem() -> TestResult<()> { fn event_handler( allowed_nodes: impl IntoIterator, ) -> ( - mpsc::Sender, + EventSender2, watch::Receiver, AbortOnDropHandle<()>, ) { let (count_tx, count_rx) = tokio::sync::watch::channel(0usize); - let (events_tx, mut events_rx) = mpsc::channel::(16); + let (events_tx, mut events_rx) = mpsc::channel::(16); let allowed_nodes = allowed_nodes.into_iter().collect::>(); let task = AbortOnDropHandle::new(tokio::task::spawn(async move { while let Some(event) = events_rx.recv().await { match event { - Event::ClientConnected { - node_id, permitted, .. - } => { - permitted.send(allowed_nodes.contains(&node_id)).await.ok(); + ProviderMessage::ClientConnected(msg) => { + let res = if allowed_nodes.contains(&msg.inner.node_id) { + Ok(()) + } else { + Err(AbortReason::Permission) + }; + msg.tx.send(res).await.ok(); } - Event::PushRequestReceived { permitted, .. } => { - permitted.send(true).await.ok(); - } - Event::TransferCompleted { .. } => { - count_tx.send_modify(|count| *count += 1); + ProviderMessage::PushRequestReceived(mut msg) => { + msg.tx.send(Ok(())).await.ok(); + let count_tx = count_tx.clone(); + tokio::task::spawn(async move { + while let Ok(Some(update)) = msg.rx.recv().await { + if let RequestUpdate::Completed(_) = update { + count_tx.send_modify(|x| *x += 1); + } + } + }); } _ => {} } } })); - (events_tx, count_rx, task) + (EventSender2::new(events_tx, EventMask::ALL), count_rx, task) } async fn two_nodes_push_blobs( @@ -409,7 +417,7 @@ async fn two_nodes_push_blobs_fs() -> TestResult<()> { let (r1, store1, _) = node_test_setup_fs(testdir.path().join("a")).await?; let (events_tx, count_rx, _task) = event_handler([r1.endpoint().node_id()]); let (r2, store2, _) = - node_test_setup_with_events_fs(testdir.path().join("b"), Some(events_tx)).await?; + node_test_setup_with_events_fs(testdir.path().join("b"), events_tx).await?; two_nodes_push_blobs(r1, &store1, r2, &store2, count_rx).await } @@ -418,7 +426,7 @@ async fn two_nodes_push_blobs_mem() -> TestResult<()> { tracing_subscriber::fmt::try_init().ok(); let (r1, store1) = node_test_setup_mem().await?; let (events_tx, count_rx, _task) = event_handler([r1.endpoint().node_id()]); - let (r2, store2) = node_test_setup_with_events_mem(Some(events_tx)).await?; + let (r2, store2) = node_test_setup_with_events_mem(events_tx).await?; two_nodes_push_blobs(r1, &store1, r2, &store2, count_rx).await } @@ -481,12 +489,12 @@ async fn check_presence(store: &Store, sizes: &[usize]) -> TestResult<()> { } pub async fn node_test_setup_fs(db_path: PathBuf) -> TestResult<(Router, FsStore, PathBuf)> { - node_test_setup_with_events_fs(db_path, None).await + node_test_setup_with_events_fs(db_path, EventSender2::NONE).await } pub async fn node_test_setup_with_events_fs( db_path: PathBuf, - events: Option>, + events: EventSender2, ) -> TestResult<(Router, FsStore, PathBuf)> { let store = crate::store::fs::FsStore::load(&db_path).await?; let ep = Endpoint::builder().bind().await?; @@ -496,11 +504,11 @@ pub async fn node_test_setup_with_events_fs( } pub async fn node_test_setup_mem() -> TestResult<(Router, MemStore)> { - node_test_setup_with_events_mem(None).await + node_test_setup_with_events_mem(EventSender2::NONE).await } pub async fn node_test_setup_with_events_mem( - events: Option>, + events: EventSender2, ) -> TestResult<(Router, MemStore)> { let store = MemStore::new(); let ep = Endpoint::builder().bind().await?; @@ -601,7 +609,7 @@ async fn node_serve_hash_seq() -> TestResult<()> { let root_tt = store.add_bytes(hash_seq).await?; let root = root_tt.hash; let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), None); + let blobs = crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); @@ -632,7 +640,7 @@ async fn node_serve_blobs() -> TestResult<()> { tts.push(store.add_bytes(test_data(size)).await?); } let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), None); + let blobs = crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); @@ -674,7 +682,7 @@ async fn node_smoke(store: &Store) -> TestResult<()> { let tt = store.add_bytes(b"hello world".to_vec()).temp_tag().await?; let hash = *tt.hash(); let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = crate::net_protocol::BlobsProtocol::new(store, endpoint.clone(), None); + let blobs = crate::net_protocol::BlobsProtocol::new(store, endpoint.clone(), EventSender2::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); From df1e1ef992bbc1b74167a4cf9b1671ca71ba08d5 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Mon, 1 Sep 2025 11:10:21 +0200 Subject: [PATCH 05/23] tests pass --- examples/random_store.rs | 9 +- src/api/blobs.rs | 7 +- src/net_protocol.rs | 8 +- src/provider.rs | 313 +++++++++++++++++++++++++++------------ src/provider/events.rs | 47 +++--- src/tests.rs | 20 +-- 6 files changed, 261 insertions(+), 143 deletions(-) diff --git a/examples/random_store.rs b/examples/random_store.rs index 6f933d511..5bc136d41 100644 --- a/examples/random_store.rs +++ b/examples/random_store.rs @@ -6,7 +6,7 @@ use iroh::{SecretKey, Watcher}; use iroh_base::ticket::NodeTicket; use iroh_blobs::{ api::downloader::Shuffled, - provider::{AbortReason, Event, EventMask, EventSender2, ProviderMessage}, + provider::{AbortReason, EventMask, EventSender2, ProviderMessage}, store::fs::FsStore, test::{add_hash_sequences, create_random_blobs}, HashAndFormat, @@ -100,12 +100,7 @@ pub fn get_or_generate_secret_key() -> Result { } } -pub fn dump_provider_events( - allow_push: bool, -) -> ( - tokio::task::JoinHandle<()>, - EventSender2, -) { +pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, EventSender2) { let (tx, mut rx) = mpsc::channel(100); let dump_task = tokio::spawn(async move { while let Some(event) = rx.recv().await { diff --git a/src/api/blobs.rs b/src/api/blobs.rs index 76f338359..d00a0a940 100644 --- a/src/api/blobs.rs +++ b/src/api/blobs.rs @@ -57,7 +57,7 @@ use super::{ }; use crate::{ api::proto::{BatchRequest, ImportByteStreamUpdate}, - provider::{ReaderContext, WriterContext}, + provider::WriterContext, store::IROH_BLOCK_SIZE, util::temp_tag::TempTag, BlobFormat, Hash, HashAndFormat, @@ -1180,9 +1180,6 @@ impl WriteProgress for WriterContext { } async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) { - self.tracker - .transfer_started(index, hash, size) - .await - .ok(); + self.tracker.transfer_started(index, hash, size).await.ok(); } } diff --git a/src/net_protocol.rs b/src/net_protocol.rs index ca64b1a7b..a1d6a1f5d 100644 --- a/src/net_protocol.rs +++ b/src/net_protocol.rs @@ -43,15 +43,9 @@ use iroh::{ protocol::{AcceptError, ProtocolHandler}, Endpoint, Watcher, }; -use tokio::sync::mpsc; use tracing::error; -use crate::{ - api::Store, - provider::{Event, EventSender2}, - ticket::BlobTicket, - HashAndFormat, -}; +use crate::{api::Store, provider::EventSender2, ticket::BlobTicket, HashAndFormat}; #[derive(Debug)] pub(crate) struct BlobsInner { diff --git a/src/provider.rs b/src/provider.rs index b10367911..2f7bd078f 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -18,23 +18,32 @@ use iroh::{ endpoint::{self, RecvStream, SendStream}, NodeId, }; -use irpc::{channel::oneshot}; +use irpc::channel::oneshot; use n0_future::StreamExt; +use quinn::{ClosedStream, ReadToEndError}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tokio::{io::AsyncRead, select, sync::mpsc}; use tracing::{debug, debug_span, error, trace, warn, Instrument}; use crate::{ - api::{self, blobs::{Bitfield, WriteProgress}, Store}, hashseq::HashSeq, protocol::{ + api::{ + self, + blobs::{Bitfield, WriteProgress}, + Store, + }, + hashseq::HashSeq, + protocol::{ ChunkRangesSeq, GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, - }, provider::events::{ClientConnected, ConnectionClosed, GetRequestReceived, RequestTracker}, Hash + }, + provider::events::{ + ClientConnected, ConnectionClosed, GetManyRequestReceived, GetRequestReceived, + PushRequestReceived, RequestTracker, + }, + Hash, }; pub(crate) mod events; -pub use events::EventSender as EventSender2; -pub use events::ProviderMessage; -pub use events::EventMask; -pub use events::AbortReason; +pub use events::{AbortReason, EventMask, EventSender as EventSender2, ProviderMessage}; /// Provider progress events, to keep track of what the provider is doing. /// @@ -155,13 +164,71 @@ pub struct TransferStats { /// leave the rest of the stream for the caller to read. /// /// It is up to the caller do decide if there should be more data. -pub async fn read_request(reader: &mut ProgressReader) -> Result { - let mut counting = CountingReader::new(&mut reader.inner); +pub async fn read_request(context: &mut StreamData) -> Result { + let mut counting = CountingReader::new(&mut context.reader); let res = Request::read_async(&mut counting).await?; - reader.bytes_read += counting.read(); + context.bytes_read += counting.read(); Ok(res) } +#[derive(Debug)] +pub struct StreamData { + pub t0: Instant, + pub connection_id: u64, + pub request_id: u64, + pub reader: RecvStream, + pub writer: SendStream, + pub events: EventSender2, + pub bytes_read: u64, +} + +impl StreamData { + /// We are done with reading. Return a ProgressWriter that contains the read stats and connection id + async fn into_writer( + mut self, + tracker: RequestTracker, + ) -> Result { + let res = self.reader.read_to_end(0).await; + if let Err(e) = res { + tracker.transfer_aborted(|| None).await.ok(); + return Err(e); + }; + Ok(ProgressWriter::new( + self.writer, + WriterContext { + t0: self.t0, + connection_id: self.connection_id, + request_id: self.request_id, + bytes_read: self.bytes_read, + payload_bytes_written: 0, + other_bytes_written: 0, + tracker, + }, + )) + } + + async fn into_reader( + mut self, + tracker: RequestTracker, + ) -> Result { + let res = self.writer.finish(); + if let Err(e) = res { + tracker.transfer_aborted(|| None).await.ok(); + return Err(e); + }; + Ok(ProgressReader::new( + self.reader, + ReaderContext { + t0: self.t0, + connection_id: self.connection_id, + request_id: self.request_id, + bytes_read: self.bytes_read, + tracker, + }, + )) + } +} + #[derive(Debug)] pub struct ReaderContext { /// The start time of the transfer @@ -172,6 +239,20 @@ pub struct ReaderContext { pub request_id: u64, /// The number of bytes read from the stream pub bytes_read: u64, + /// Progress tracking for the request + pub tracker: RequestTracker, +} + +impl ReaderContext { + pub fn new(context: StreamData, tracker: RequestTracker) -> Self { + Self { + t0: context.t0, + connection_id: context.connection_id, + request_id: context.request_id, + bytes_read: context.bytes_read, + tracker, + } + } } #[derive(Debug)] @@ -192,6 +273,20 @@ pub struct WriterContext { pub tracker: RequestTracker, } +impl WriterContext { + pub fn new(context: &StreamData, tracker: RequestTracker) -> Self { + Self { + t0: context.t0, + connection_id: context.connection_id, + request_id: context.request_id, + bytes_read: context.bytes_read, + payload_bytes_written: 0, + other_bytes_written: 0, + tracker, + } + } +} + /// Wrapper for a [`quinn::SendStream`] with additional per request information. #[derive(Debug)] pub struct ProgressWriter { @@ -201,34 +296,36 @@ pub struct ProgressWriter { } impl ProgressWriter { - fn new(inner: SendStream, context: ReaderContext, tracker: RequestTracker) -> Self { - Self { inner, context: WriterContext { - connection_id: context.connection_id, - request_id: context.request_id, - bytes_read: context.bytes_read, - t0: context.t0, - payload_bytes_written: 0, - other_bytes_written: 0, - tracker, - } } + fn new(inner: SendStream, context: WriterContext) -> Self { + Self { inner, context } } async fn transfer_aborted(&self) { - self.tracker.transfer_aborted(|| Some(Box::new(TransferStats { - payload_bytes_sent: self.payload_bytes_written, - other_bytes_sent: self.other_bytes_written, - bytes_read: self.bytes_read, - duration: self.context.t0.elapsed(), - }))).await.ok(); + self.tracker + .transfer_aborted(|| { + Some(Box::new(TransferStats { + payload_bytes_sent: self.payload_bytes_written, + other_bytes_sent: self.other_bytes_written, + bytes_read: self.bytes_read, + duration: self.context.t0.elapsed(), + })) + }) + .await + .ok(); } async fn transfer_completed(&self) { - self.tracker.transfer_completed(|| Box::new(TransferStats { - payload_bytes_sent: self.payload_bytes_written, - other_bytes_sent: self.other_bytes_written, - bytes_read: self.bytes_read, - duration: self.context.t0.elapsed(), - })).await.ok(); + self.tracker + .transfer_completed(|| { + Box::new(TransferStats { + payload_bytes_sent: self.payload_bytes_written, + other_bytes_sent: self.other_bytes_written, + bytes_read: self.bytes_read, + duration: self.context.t0.elapsed(), + }) + }) + .await + .ok(); } } @@ -259,7 +356,7 @@ pub async fn handle_connection( warn!("failed to get node id"); return; }; - if let Err(cause) =progress + if let Err(cause) = progress .client_connected(|| ClientConnected { connection_id, node_id, @@ -275,95 +372,87 @@ pub async fn handle_connection( let request_id = reader.id().index(); let span = debug_span!("stream", stream_id = %request_id); let store = store.clone(); - let context = ReaderContext { + let context = StreamData { t0: Instant::now(), connection_id: connection_id, request_id: request_id, + reader, + writer, + events: progress.clone(), bytes_read: 0, }; - let reader = ProgressReader { - inner: reader, - context, - }; - tokio::spawn( - handle_stream(store, reader, writer, progress.clone()) - .instrument(span), - ); + tokio::spawn(handle_stream(store, context).instrument(span)); } progress .connection_closed(|| ConnectionClosed { connection_id }) - .await.ok(); + .await + .ok(); } .instrument(span) .await } -async fn handle_stream( - store: Store, - mut reader: ProgressReader, - writer: SendStream, - progress: EventSender2, -) { +async fn handle_stream(store: Store, mut context: StreamData) -> anyhow::Result<()> { // 1. Decode the request. debug!("reading request"); - let request = match read_request(&mut reader).await { - Ok(request) => request, - Err(e) => { - // todo: event for read request failed - return; - } - }; + let request = read_request(&mut context).await?; match request { Request::Get(request) => { - let tracker = match progress.get_request(|| GetRequestReceived { - connection_id: reader.context.connection_id, - request_id: reader.context.request_id, - request: request.clone(), - }).await { - Ok(tracker) => tracker, - Err(e) => { - trace!("Request denied: {}", e); - return; - } - }; - // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. - let res = reader.inner.read_to_end(0).await; - let mut writer = ProgressWriter::new(writer, reader.context, tracker); - if res.is_err() { + let tracker = context + .events + .get_request(|| GetRequestReceived { + connection_id: context.connection_id, + request_id: context.request_id, + request: request.clone(), + }) + .await?; + let mut writer = context.into_writer(tracker).await?; + if handle_get(store, request, &mut writer).await.is_ok() { + writer.transfer_completed().await; + } else { writer.transfer_aborted().await; - return; - } - match handle_get(store, request, &mut writer).await { - Ok(()) => { - writer.transfer_completed().await; - } - Err(_) => { - writer.transfer_aborted().await; - } } } Request::GetMany(request) => { - todo!(); - // // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. - // reader.inner.read_to_end(0).await?; - // // move the context so we don't lose the bytes read - // writer.context = reader.context; - // handle_get_many(store, request, writer).await + let tracker = context + .events + .get_many_request(|| GetManyRequestReceived { + connection_id: context.connection_id, + request_id: context.request_id, + request: request.clone(), + }) + .await?; + let mut writer = context.into_writer(tracker).await?; + if handle_get_many(store, request, &mut writer).await.is_ok() { + writer.transfer_completed().await; + } else { + writer.transfer_aborted().await; + } } Request::Observe(request) => { - todo!(); - // // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. - // reader.inner.read_to_end(0).await?; - // handle_observe(store, request, writer).await + let mut writer = context.into_writer(RequestTracker::NONE).await?; + handle_observe(store, request, &mut writer).await.ok(); } Request::Push(request) => { - todo!(); - // writer.inner.finish()?; - // handle_push(store, request, reader).await + let tracker = context + .events + .push_request(|| PushRequestReceived { + connection_id: context.connection_id, + request_id: context.request_id, + request: request.clone(), + }) + .await?; + let mut reader = context.into_reader(tracker).await?; + if handle_push(store, request, &mut reader).await.is_ok() { + reader.transfer_completed().await; + } else { + reader.transfer_aborted().await; + } } - _ => {}, + _ => {} } + Ok(()) } /// Handle a single get request. @@ -430,7 +519,7 @@ pub async fn handle_get_many( pub async fn handle_push( store: Store, request: PushRequest, - mut reader: ProgressReader, + reader: &mut ProgressReader, ) -> Result<()> { let hash = request.hash; debug!(%hash, "push received request"); @@ -616,6 +705,40 @@ pub struct ProgressReader { context: ReaderContext, } +impl ProgressReader { + pub fn new(inner: RecvStream, context: ReaderContext) -> Self { + Self { inner, context } + } + + async fn transfer_aborted(&self) { + self.tracker + .transfer_aborted(|| { + Some(Box::new(TransferStats { + payload_bytes_sent: 0, + other_bytes_sent: 0, + bytes_read: self.bytes_read, + duration: self.context.t0.elapsed(), + })) + }) + .await + .ok(); + } + + async fn transfer_completed(&self) { + self.tracker + .transfer_completed(|| { + Box::new(TransferStats { + payload_bytes_sent: 0, + other_bytes_sent: 0, + bytes_read: self.bytes_read, + duration: self.context.t0.elapsed(), + }) + }) + .await + .ok(); + } +} + impl Deref for ProgressReader { type Target = ReaderContext; diff --git a/src/provider/events.rs b/src/provider/events.rs index 55383e77c..2ae4ba5be 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -7,7 +7,10 @@ use irpc::{ use serde::{Deserialize, Serialize}; use snafu::Snafu; -use crate::{provider::{events::irpc_ext::IrpcClientExt, TransferStats}, Hash}; +use crate::{ + provider::{events::irpc_ext::IrpcClientExt, TransferStats}, + Hash, +}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[repr(u8)] @@ -168,7 +171,7 @@ impl RequestTracker { } /// A request tracker that doesn't track anything. - const NONE: Self = Self { + pub const NONE: Self = Self { updates: RequestUpdates::None, throttle: None, }; @@ -176,8 +179,12 @@ impl RequestTracker { /// Transfer for index `index` started, size `size` pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { - tx.send(RequestUpdate::Started(TransferStarted { index, hash: *hash, size })) - .await?; + tx.send(RequestUpdate::Started(TransferStarted { + index, + hash: *hash, + size, + })) + .await?; } Ok(()) } @@ -200,10 +207,7 @@ impl RequestTracker { } /// Transfer completed for the previously reported blob. - pub async fn transfer_completed( - &self, - f: impl Fn() -> Box, - ) -> irpc::Result<()> { + pub async fn transfer_completed(&self, f: impl Fn() -> Box) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { tx.send(RequestUpdate::Completed(TransferCompleted { stats: f() })) .await?; @@ -236,7 +240,10 @@ impl EventSender { }; pub fn new(client: tokio::sync::mpsc::Sender, mask: EventMask) -> Self { - Self { mask, inner: Some(irpc::Client::from(client)) } + Self { + mask, + inner: Some(irpc::Client::from(client)), + } } /// A new client has been connected. @@ -422,7 +429,11 @@ mod proto { use serde::{Deserialize, Serialize}; use super::Request; - use crate::{protocol::{ChunkRangesSeq, GetRequest}, provider::TransferStats, Hash}; + use crate::{ + protocol::{GetManyRequest, GetRequest, PushRequest}, + provider::TransferStats, + Hash, + }; #[derive(Debug, Serialize, Deserialize)] pub struct ClientConnected { @@ -458,10 +469,8 @@ mod proto { pub connection_id: u64, /// The request id. There is a new id for each request. pub request_id: u64, - /// The root hash of the request. - pub hashes: Vec, - /// The exact query ranges of the request. - pub ranges: ChunkRangesSeq, + /// The request + pub request: GetManyRequest, } impl Request for GetManyRequestReceived { @@ -476,10 +485,8 @@ mod proto { pub connection_id: u64, /// The request id. There is a new id for each request. pub request_id: u64, - /// The root hash of the request. - pub hash: Hash, - /// The exact query ranges of the request. - pub ranges: ChunkRangesSeq, + /// The request + pub request: PushRequest, } impl Request for PushRequestReceived { @@ -539,7 +546,7 @@ mod irpc_ext { use std::future::Future; use irpc::{ - channel::{mpsc, none::NoSender, oneshot}, + channel::{mpsc, none::NoSender}, Channels, RpcMessage, Service, WithChannels, }; @@ -581,7 +588,7 @@ mod irpc_ext { Ok(req_tx) } irpc::Request::Remote(remote) => { - let (s, r) = remote.write(msg).await?; + let (s, _) = remote.write(msg).await?; Ok(s.into()) } } diff --git a/src/tests.rs b/src/tests.rs index 9b825bd08..e99d0fe02 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -16,7 +16,10 @@ use crate::{ hashseq::HashSeq, net_protocol::BlobsProtocol, protocol::{ChunkRangesSeq, GetManyRequest, ObserveRequest, PushRequest}, - provider::{events::{AbortReason, RequestUpdate}, Event, EventMask, EventSender2, ProviderMessage}, + provider::{ + events::{AbortReason, RequestUpdate}, + EventMask, EventSender2, ProviderMessage, + }, store::{ fs::{ tests::{create_n0_bao, test_data, INTERESTING_SIZES}, @@ -340,11 +343,7 @@ async fn two_nodes_get_many_mem() -> TestResult<()> { fn event_handler( allowed_nodes: impl IntoIterator, -) -> ( - EventSender2, - watch::Receiver, - AbortOnDropHandle<()>, -) { +) -> (EventSender2, watch::Receiver, AbortOnDropHandle<()>) { let (count_tx, count_rx) = tokio::sync::watch::channel(0usize); let (events_tx, mut events_rx) = mpsc::channel::(16); let allowed_nodes = allowed_nodes.into_iter().collect::>(); @@ -609,7 +608,8 @@ async fn node_serve_hash_seq() -> TestResult<()> { let root_tt = store.add_bytes(hash_seq).await?; let root = root_tt.hash; let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); + let blobs = + crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); @@ -640,7 +640,8 @@ async fn node_serve_blobs() -> TestResult<()> { tts.push(store.add_bytes(test_data(size)).await?); } let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); + let blobs = + crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); @@ -682,7 +683,8 @@ async fn node_smoke(store: &Store) -> TestResult<()> { let tt = store.add_bytes(b"hello world".to_vec()).temp_tag().await?; let hash = *tt.hash(); let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = crate::net_protocol::BlobsProtocol::new(store, endpoint.clone(), EventSender2::NONE); + let blobs = + crate::net_protocol::BlobsProtocol::new(store, endpoint.clone(), EventSender2::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); From a9ac8e57cde7975e1844298703f9a17a7dd897d5 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Mon, 1 Sep 2025 15:31:41 +0200 Subject: [PATCH 06/23] Everything works --- README.md | 4 +- examples/custom-protocol.rs | 6 +- examples/mdns-discovery.rs | 4 +- examples/random_store.rs | 48 +- examples/transfer.rs | 6 +- .../store/fs/util/entity_manager.txt | 7 + src/net_protocol.rs | 10 +- src/provider.rs | 425 ++++++------------ src/provider/events.rs | 249 ++++------ src/tests.rs | 24 +- 10 files changed, 295 insertions(+), 488 deletions(-) create mode 100644 proptest-regressions/store/fs/util/entity_manager.txt diff --git a/README.md b/README.md index 2f374e8fb..1a136e44d 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ Here is a basic example of how to set up `iroh-blobs` with `iroh`: ```rust,no_run use iroh::{protocol::Router, Endpoint}; -use iroh_blobs::{store::mem::MemStore, BlobsProtocol}; +use iroh_blobs::{store::mem::MemStore, BlobsProtocol, provider::events::EventSender}; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -44,7 +44,7 @@ async fn main() -> anyhow::Result<()> { // create a protocol handler using an in-memory blob store. let store = MemStore::new(); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); // build the router let router = Router::builder(endpoint) diff --git a/examples/custom-protocol.rs b/examples/custom-protocol.rs index 6542acd18..d4d29e27f 100644 --- a/examples/custom-protocol.rs +++ b/examples/custom-protocol.rs @@ -48,7 +48,9 @@ use iroh::{ protocol::{AcceptError, ProtocolHandler, Router}, Endpoint, NodeId, }; -use iroh_blobs::{api::Store, provider::EventSender2, store::mem::MemStore, BlobsProtocol, Hash}; +use iroh_blobs::{ + api::Store, provider::events::EventSender, store::mem::MemStore, BlobsProtocol, Hash, +}; mod common; use common::{get_or_generate_secret_key, setup_logging}; @@ -100,7 +102,7 @@ async fn listen(text: Vec) -> Result<()> { proto.insert_and_index(text).await?; } // Build the iroh-blobs protocol handler, which is used to download blobs. - let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); // create a router that handles both our custom protocol and the iroh-blobs protocol. let node = Router::builder(endpoint) diff --git a/examples/mdns-discovery.rs b/examples/mdns-discovery.rs index ef5d0619c..ab11bc864 100644 --- a/examples/mdns-discovery.rs +++ b/examples/mdns-discovery.rs @@ -18,7 +18,7 @@ use clap::{Parser, Subcommand}; use iroh::{ discovery::mdns::MdnsDiscovery, protocol::Router, Endpoint, PublicKey, RelayMode, SecretKey, }; -use iroh_blobs::{provider::EventSender2, store::mem::MemStore, BlobsProtocol, Hash}; +use iroh_blobs::{provider::events::EventSender, store::mem::MemStore, BlobsProtocol, Hash}; mod common; use common::{get_or_generate_secret_key, setup_logging}; @@ -68,7 +68,7 @@ async fn accept(path: &Path) -> Result<()> { .await?; let builder = Router::builder(endpoint.clone()); let store = MemStore::new(); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); let builder = builder.accept(iroh_blobs::ALPN, blobs.clone()); let node = builder.spawn(); diff --git a/examples/random_store.rs b/examples/random_store.rs index 5bc136d41..f36017e8d 100644 --- a/examples/random_store.rs +++ b/examples/random_store.rs @@ -6,11 +6,12 @@ use iroh::{SecretKey, Watcher}; use iroh_base::ticket::NodeTicket; use iroh_blobs::{ api::downloader::Shuffled, - provider::{AbortReason, EventMask, EventSender2, ProviderMessage}, + provider::events::{AbortReason, EventMask, EventSender, ProviderMessage}, store::fs::FsStore, test::{add_hash_sequences, create_random_blobs}, HashAndFormat, }; +use irpc::RpcMessage; use n0_future::StreamExt; use rand::{rngs::StdRng, Rng, SeedableRng}; use tokio::{signal::ctrl_c, sync::mpsc}; @@ -100,8 +101,15 @@ pub fn get_or_generate_secret_key() -> Result { } } -pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, EventSender2) { +pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, EventSender) { let (tx, mut rx) = mpsc::channel(100); + fn dump_updates(mut rx: irpc::channel::mpsc::Receiver) { + tokio::spawn(async move { + while let Ok(Some(update)) = rx.recv().await { + println!("{update:?}"); + } + }); + } let dump_task = tokio::spawn(async move { while let Some(event) = rx.recv().await { match event { @@ -115,29 +123,23 @@ pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, E ProviderMessage::ConnectionClosed(msg) => { println!("{:?}", msg.inner); } - ProviderMessage::GetRequestReceived(mut msg) => { + ProviderMessage::GetRequestReceived(msg) => { println!("{:?}", msg.inner); msg.tx.send(Ok(())).await.ok(); - tokio::spawn(async move { - while let Ok(update) = msg.rx.recv().await { - info!("{update:?}"); - } - }); + dump_updates(msg.rx); } ProviderMessage::GetRequestReceivedNotify(msg) => { println!("{:?}", msg.inner); + dump_updates(msg.rx); } - ProviderMessage::GetManyRequestReceived(mut msg) => { + ProviderMessage::GetManyRequestReceived(msg) => { println!("{:?}", msg.inner); msg.tx.send(Ok(())).await.ok(); - tokio::spawn(async move { - while let Ok(update) = msg.rx.recv().await { - info!("{update:?}"); - } - }); + dump_updates(msg.rx); } ProviderMessage::GetManyRequestReceivedNotify(msg) => { println!("{:?}", msg.inner); + dump_updates(msg.rx); } ProviderMessage::PushRequestReceived(msg) => { println!("{:?}", msg.inner); @@ -147,9 +149,25 @@ pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, E Err(AbortReason::Permission) }; msg.tx.send(res).await.ok(); + dump_updates(msg.rx); } ProviderMessage::PushRequestReceivedNotify(msg) => { println!("{:?}", msg.inner); + dump_updates(msg.rx); + } + ProviderMessage::ObserveRequestReceived(msg) => { + println!("{:?}", msg.inner); + let res = if allow_push { + Ok(()) + } else { + Err(AbortReason::Permission) + }; + msg.tx.send(res).await.ok(); + dump_updates(msg.rx); + } + ProviderMessage::ObserveRequestReceivedNotify(msg) => { + println!("{:?}", msg.inner); + dump_updates(msg.rx); } ProviderMessage::Throttle(msg) => { println!("{:?}", msg.inner); @@ -158,7 +176,7 @@ pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, E } } }); - (dump_task, EventSender2::new(tx, EventMask::ALL)) + (dump_task, EventSender::new(tx, EventMask::ALL)) } #[tokio::main] diff --git a/examples/transfer.rs b/examples/transfer.rs index baa1e343c..8347774ca 100644 --- a/examples/transfer.rs +++ b/examples/transfer.rs @@ -1,7 +1,9 @@ use std::path::PathBuf; use iroh::{protocol::Router, Endpoint}; -use iroh_blobs::{provider::EventSender2, store::mem::MemStore, ticket::BlobTicket, BlobsProtocol}; +use iroh_blobs::{ + provider::events::EventSender, store::mem::MemStore, ticket::BlobTicket, BlobsProtocol, +}; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -12,7 +14,7 @@ async fn main() -> anyhow::Result<()> { // We initialize an in-memory backing store for iroh-blobs let store = MemStore::new(); // Then we initialize a struct that can accept blobs requests over iroh connections - let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); // Grab all passed in arguments, the first one is the binary itself, so we skip it. let args: Vec = std::env::args().skip(1).collect(); diff --git a/proptest-regressions/store/fs/util/entity_manager.txt b/proptest-regressions/store/fs/util/entity_manager.txt new file mode 100644 index 000000000..94b6aa63c --- /dev/null +++ b/proptest-regressions/store/fs/util/entity_manager.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 0f2ebc49ab2f84e112f08407bb94654fbcb1f19050a4a8a6196383557696438a # shrinks to input = _TestCountersManagerProptestFsArgs { entries: [(15313427648878534792, 264348813928009031854006459208395772047), (1642534478798447378, 15989109311941500072752977306696275871), (8755041673862065815, 172763711808688570294350362332402629716), (4993597758667891804, 114145440157220458287429360639759690928), (15031383154962489250, 63217081714858286463391060323168548783), (17668469631267503333, 11878544422669770587175118199598836678), (10507570291819955314, 126584081645379643144412921692654648228), (3979008599365278329, 283717221942996985486273080647433218905), (8316838360288996639, 334043288511621783152802090833905919408), (15673798930962474157, 77551315511802713260542200115027244708), (12058791254144360414, 56638044274259821850511200885092637649), (8191628769638031337, 314181956273420400069887649110740549194), (6290369460137232066, 255779791286732775990301011955519176773), (11919824746661852269, 319400891587146831511371932480749645441), (12491631698789073154, 271279849791970841069522263758329847554), (53891048909263304, 12061234604041487609497959407391945555), (9486366498650667097, 311383186592430597410801882015456718030), (15696332331789302593, 306911490707714340526403119780178604150), (8699088947997536151, 312272624973367009520183311568498652066), (1144772544750976199, 200591877747619565555594857038887015), (5907208586200645081, 299942008952473970881666769409865744975), (3384528743842518913, 26230956866762934113564101494944411446), (13877357832690956494, 229457597607752760006918374695475345151), (2965687966026226090, 306489188264741716662410004273408761623), (13624286905717143613, 232801392956394366686194314010536008033), (3622356130274722018, 162030840677521022192355139208505458492), (17807768575470996347, 264107246314713159406963697924105744409), (5103434150074147746, 331686166459964582006209321975587627262), (5962771466034321974, 300961804728115777587520888809168362574), (2930645694242691907, 127752709774252686733969795258447263979), (16197574560597474644, 245410120683069493317132088266217906749), (12478835478062365617, 103838791113879912161511798836229961653), (5503595333662805357, 92368472243854403026472376408708548349), (18122734335129614364, 288955542597300001147753560885976966029), (12688080215989274550, 85237436689682348751672119832134138932), (4148468277722853958, 297778117327421209654837771300216669574), (8749445804640085302, 79595866493078234154562014325793780126), (12442730869682574563, 196176786402808588883611974143577417817), (6110644747049355904, 26592587989877021920275416199052685135), (5851164380497779369, 158876888501825038083692899057819261957), (9497384378514985275, 15279835675313542048650599472403150097), (10661092311826161857, 250089949043892591422587928179995867509), (10046856000675345423, 231369150063141386398059701278066296663)] } diff --git a/src/net_protocol.rs b/src/net_protocol.rs index a1d6a1f5d..aa45aa473 100644 --- a/src/net_protocol.rs +++ b/src/net_protocol.rs @@ -7,7 +7,7 @@ //! ```rust //! # async fn example() -> anyhow::Result<()> { //! use iroh::{protocol::Router, Endpoint}; -//! use iroh_blobs::{store, BlobsProtocol}; +//! use iroh_blobs::{provider::events::EventSender, store, BlobsProtocol}; //! //! // create a store //! let store = store::fs::FsStore::load("blobs").await?; @@ -19,7 +19,7 @@ //! let endpoint = Endpoint::builder().discovery_n0().bind().await?; //! //! // create a blobs protocol handler -//! let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); +//! let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); //! //! // create a router and add the blobs protocol handler //! let router = Router::builder(endpoint) @@ -45,13 +45,13 @@ use iroh::{ }; use tracing::error; -use crate::{api::Store, provider::EventSender2, ticket::BlobTicket, HashAndFormat}; +use crate::{api::Store, provider::events::EventSender, ticket::BlobTicket, HashAndFormat}; #[derive(Debug)] pub(crate) struct BlobsInner { pub(crate) store: Store, pub(crate) endpoint: Endpoint, - pub(crate) events: EventSender2, + pub(crate) events: EventSender, } /// A protocol handler for the blobs protocol. @@ -69,7 +69,7 @@ impl Deref for BlobsProtocol { } impl BlobsProtocol { - pub fn new(store: &Store, endpoint: Endpoint, events: EventSender2) -> Self { + pub fn new(store: &Store, endpoint: Endpoint, events: EventSender) -> Self { Self { inner: Arc::new(BlobsInner { store: store.clone(), diff --git a/src/provider.rs b/src/provider.rs index 2f7bd078f..a98c1b3a7 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -14,16 +14,12 @@ use std::{ use anyhow::{Context, Result}; use bao_tree::ChunkRanges; -use iroh::{ - endpoint::{self, RecvStream, SendStream}, - NodeId, -}; -use irpc::channel::oneshot; +use iroh::endpoint::{self, RecvStream, SendStream}; use n0_future::StreamExt; -use quinn::{ClosedStream, ReadToEndError}; +use quinn::{ClosedStream, ConnectionError, ReadToEndError}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use tokio::{io::AsyncRead, select, sync::mpsc}; -use tracing::{debug, debug_span, error, trace, warn, Instrument}; +use tokio::{io::AsyncRead, select}; +use tracing::{debug, debug_span, warn, Instrument}; use crate::{ api::{ @@ -32,112 +28,12 @@ use crate::{ Store, }, hashseq::HashSeq, - protocol::{ - ChunkRangesSeq, GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, - Request, - }, - provider::events::{ - ClientConnected, ConnectionClosed, GetManyRequestReceived, GetRequestReceived, - PushRequestReceived, RequestTracker, - }, + protocol::{GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request}, + provider::events::{ClientConnected, ClientError, ConnectionClosed, RequestTracker}, Hash, }; -pub(crate) mod events; -pub use events::{AbortReason, EventMask, EventSender as EventSender2, ProviderMessage}; - -/// Provider progress events, to keep track of what the provider is doing. -/// -/// ClientConnected -> -/// (GetRequestReceived -> (TransferStarted -> TransferProgress*n)*n -> (TransferCompleted | TransferAborted))*n -> -/// ConnectionClosed -#[derive(Debug)] -pub enum Event { - /// A new client connected to the provider. - ClientConnected { - connection_id: u64, - node_id: NodeId, - permitted: oneshot::Sender, - }, - /// Connection closed. - ConnectionClosed { connection_id: u64 }, - /// A new get request was received from the provider. - GetRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The root hash of the request. - hash: Hash, - /// The exact query ranges of the request. - ranges: ChunkRangesSeq, - }, - /// A new get request was received from the provider. - GetManyRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The root hash of the request. - hashes: Vec, - /// The exact query ranges of the request. - ranges: ChunkRangesSeq, - }, - /// A new get request was received from the provider. - PushRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The root hash of the request. - hash: Hash, - /// The exact query ranges of the request. - ranges: ChunkRangesSeq, - /// Complete this to permit the request. - permitted: oneshot::Sender, - }, - /// Transfer for the nth blob started. - TransferStarted { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The index of the blob in the request. 0 for the first blob or for raw blob requests. - index: u64, - /// The hash of the blob. This is the hash of the request for the first blob, the child hash (index-1) for subsequent blobs. - hash: Hash, - /// The size of the blob. This is the full size of the blob, not the size we are sending. - size: u64, - }, - /// Progress of the transfer. - TransferProgress { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The index of the blob in the request. 0 for the first blob or for raw blob requests. - index: u64, - /// The end offset of the chunk that was sent. - end_offset: u64, - }, - /// Entire transfer completed. - TransferCompleted { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// Statistics about the transfer. - stats: Box, - }, - /// Entire transfer aborted - TransferAborted { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// Statistics about the part of the transfer that was aborted. - stats: Option>, - }, -} +pub mod events; +use events::EventSender; /// Statistics about a successful or failed transfer. #[derive(Debug, Serialize, Deserialize)] @@ -150,8 +46,9 @@ pub struct TransferStats { pub other_bytes_sent: u64, /// The number of bytes read from the stream. /// - /// This is the size of the request. - pub bytes_read: u64, + /// In most cases this is just the request, for push requests this is + /// request, size header and hash pairs. + pub other_bytes_read: u64, /// Total duration from reading the request to transfer completed. pub duration: Duration, } @@ -167,22 +64,38 @@ pub struct TransferStats { pub async fn read_request(context: &mut StreamData) -> Result { let mut counting = CountingReader::new(&mut context.reader); let res = Request::read_async(&mut counting).await?; - context.bytes_read += counting.read(); + context.other_bytes_read += counting.read(); Ok(res) } #[derive(Debug)] pub struct StreamData { - pub t0: Instant, - pub connection_id: u64, - pub request_id: u64, - pub reader: RecvStream, - pub writer: SendStream, - pub events: EventSender2, - pub bytes_read: u64, + t0: Instant, + connection_id: u64, + request_id: u64, + reader: RecvStream, + writer: SendStream, + other_bytes_read: u64, + events: EventSender, } impl StreamData { + pub async fn accept( + conn: &endpoint::Connection, + events: &EventSender, + ) -> Result { + let (writer, reader) = conn.accept_bi().await?; + Ok(Self { + t0: Instant::now(), + connection_id: conn.stable_id() as u64, + request_id: reader.id().into(), + reader, + writer, + other_bytes_read: 0, + events: events.clone(), + }) + } + /// We are done with reading. Return a ProgressWriter that contains the read stats and connection id async fn into_writer( mut self, @@ -190,7 +103,10 @@ impl StreamData { ) -> Result { let res = self.reader.read_to_end(0).await; if let Err(e) = res { - tracker.transfer_aborted(|| None).await.ok(); + tracker + .transfer_aborted(|| Box::new(self.stats())) + .await + .ok(); return Err(e); }; Ok(ProgressWriter::new( @@ -199,7 +115,7 @@ impl StreamData { t0: self.t0, connection_id: self.connection_id, request_id: self.request_id, - bytes_read: self.bytes_read, + other_bytes_read: self.other_bytes_read, payload_bytes_written: 0, other_bytes_written: 0, tracker, @@ -213,7 +129,10 @@ impl StreamData { ) -> Result { let res = self.writer.finish(); if let Err(e) = res { - tracker.transfer_aborted(|| None).await.ok(); + tracker + .transfer_aborted(|| Box::new(self.stats())) + .await + .ok(); return Err(e); }; Ok(ProgressReader::new( @@ -222,11 +141,56 @@ impl StreamData { t0: self.t0, connection_id: self.connection_id, request_id: self.request_id, - bytes_read: self.bytes_read, + other_bytes_read: self.other_bytes_read, tracker, }, )) } + + async fn get_request( + &self, + f: impl FnOnce() -> GetRequest, + ) -> Result { + self.events + .request(f, self.connection_id, self.request_id) + .await + } + + async fn get_many_request( + &self, + f: impl FnOnce() -> GetManyRequest, + ) -> Result { + self.events + .request(f, self.connection_id, self.request_id) + .await + } + + async fn push_request( + &self, + f: impl FnOnce() -> PushRequest, + ) -> Result { + self.events + .request(f, self.connection_id, self.request_id) + .await + } + + async fn observe_request( + &self, + f: impl FnOnce() -> ObserveRequest, + ) -> Result { + self.events + .request(f, self.connection_id, self.request_id) + .await + } + + fn stats(&self) -> TransferStats { + TransferStats { + payload_bytes_sent: 0, + other_bytes_sent: 0, + other_bytes_read: self.other_bytes_read, + duration: self.t0.elapsed(), + } + } } #[derive(Debug)] @@ -238,7 +202,7 @@ pub struct ReaderContext { /// The request ID from the recv stream pub request_id: u64, /// The number of bytes read from the stream - pub bytes_read: u64, + pub other_bytes_read: u64, /// Progress tracking for the request pub tracker: RequestTracker, } @@ -249,10 +213,19 @@ impl ReaderContext { t0: context.t0, connection_id: context.connection_id, request_id: context.request_id, - bytes_read: context.bytes_read, + other_bytes_read: context.other_bytes_read, tracker, } } + + pub fn stats(&self) -> TransferStats { + TransferStats { + payload_bytes_sent: 0, + other_bytes_sent: 0, + other_bytes_read: self.other_bytes_read, + duration: self.t0.elapsed(), + } + } } #[derive(Debug)] @@ -264,7 +237,7 @@ pub struct WriterContext { /// The request ID from the recv stream pub request_id: u64, /// The number of bytes read from the stream - pub bytes_read: u64, + pub other_bytes_read: u64, /// The number of payload bytes written to the stream pub payload_bytes_written: u64, /// The number of bytes written that are not part of the payload @@ -279,12 +252,21 @@ impl WriterContext { t0: context.t0, connection_id: context.connection_id, request_id: context.request_id, - bytes_read: context.bytes_read, + other_bytes_read: context.other_bytes_read, payload_bytes_written: 0, other_bytes_written: 0, tracker, } } + + pub fn stats(&self) -> TransferStats { + TransferStats { + payload_bytes_sent: self.payload_bytes_written, + other_bytes_sent: self.other_bytes_written, + other_bytes_read: self.other_bytes_read, + duration: self.t0.elapsed(), + } + } } /// Wrapper for a [`quinn::SendStream`] with additional per request information. @@ -302,28 +284,14 @@ impl ProgressWriter { async fn transfer_aborted(&self) { self.tracker - .transfer_aborted(|| { - Some(Box::new(TransferStats { - payload_bytes_sent: self.payload_bytes_written, - other_bytes_sent: self.other_bytes_written, - bytes_read: self.bytes_read, - duration: self.context.t0.elapsed(), - })) - }) + .transfer_aborted(|| Box::new(self.stats())) .await .ok(); } async fn transfer_completed(&self) { self.tracker - .transfer_completed(|| { - Box::new(TransferStats { - payload_bytes_sent: self.payload_bytes_written, - other_bytes_sent: self.other_bytes_written, - bytes_read: self.bytes_read, - duration: self.context.t0.elapsed(), - }) - }) + .transfer_completed(|| Box::new(self.stats())) .await .ok(); } @@ -347,7 +315,7 @@ impl DerefMut for ProgressWriter { pub async fn handle_connection( connection: endpoint::Connection, store: Store, - progress: EventSender2, + progress: EventSender, ) { let connection_id = connection.stable_id() as u64; let span = debug_span!("connection", connection_id); @@ -366,21 +334,9 @@ pub async fn handle_connection( debug!("client not authorized to connect: {cause}"); return; } - while let Ok((writer, reader)) = connection.accept_bi().await { - // The stream ID index is used to identify this request. Requests only arrive in - // bi-directional RecvStreams initiated by the client, so this uniquely identifies them. - let request_id = reader.id().index(); - let span = debug_span!("stream", stream_id = %request_id); + while let Ok(context) = StreamData::accept(&connection, &progress).await { + let span = debug_span!("stream", stream_id = %context.request_id); let store = store.clone(); - let context = StreamData { - t0: Instant::now(), - connection_id: connection_id, - request_id: request_id, - reader, - writer, - events: progress.clone(), - bytes_read: 0, - }; tokio::spawn(handle_stream(store, context).instrument(span)); } progress @@ -399,14 +355,7 @@ async fn handle_stream(store: Store, mut context: StreamData) -> anyhow::Result< match request { Request::Get(request) => { - let tracker = context - .events - .get_request(|| GetRequestReceived { - connection_id: context.connection_id, - request_id: context.request_id, - request: request.clone(), - }) - .await?; + let tracker = context.get_request(|| request.clone()).await?; let mut writer = context.into_writer(tracker).await?; if handle_get(store, request, &mut writer).await.is_ok() { writer.transfer_completed().await; @@ -415,14 +364,7 @@ async fn handle_stream(store: Store, mut context: StreamData) -> anyhow::Result< } } Request::GetMany(request) => { - let tracker = context - .events - .get_many_request(|| GetManyRequestReceived { - connection_id: context.connection_id, - request_id: context.request_id, - request: request.clone(), - }) - .await?; + let tracker = context.get_many_request(|| request.clone()).await?; let mut writer = context.into_writer(tracker).await?; if handle_get_many(store, request, &mut writer).await.is_ok() { writer.transfer_completed().await; @@ -431,18 +373,16 @@ async fn handle_stream(store: Store, mut context: StreamData) -> anyhow::Result< } } Request::Observe(request) => { - let mut writer = context.into_writer(RequestTracker::NONE).await?; - handle_observe(store, request, &mut writer).await.ok(); + let tracker = context.observe_request(|| request.clone()).await?; + let mut writer = context.into_writer(tracker).await?; + if handle_observe(store, request, &mut writer).await.is_ok() { + writer.transfer_completed().await; + } else { + writer.transfer_aborted().await; + } } Request::Push(request) => { - let tracker = context - .events - .push_request(|| PushRequestReceived { - connection_id: context.connection_id, - request_id: context.request_id, - request: request.clone(), - }) - .await?; + let tracker = context.push_request(|| request.clone()).await?; let mut reader = context.into_reader(tracker).await?; if handle_push(store, request, &mut reader).await.is_ok() { reader.transfer_completed().await; @@ -607,99 +547,6 @@ async fn send_observe_item(writer: &mut ProgressWriter, item: &Bitfield) -> Resu Ok(()) } -/// Helper to lazyly create an [`Event`], in the case that the event creation -/// is expensive and we want to avoid it if the progress sender is disabled. -pub trait LazyEvent { - fn call(self) -> Event; -} - -impl LazyEvent for T -where - T: FnOnce() -> Event, -{ - fn call(self) -> Event { - self() - } -} - -impl LazyEvent for Event { - fn call(self) -> Event { - self - } -} - -/// A sender for provider events. -#[derive(Debug, Clone)] -pub struct EventSender(EventSenderInner); - -#[derive(Debug, Clone)] -enum EventSenderInner { - Disabled, - Enabled(mpsc::Sender), -} - -impl EventSender { - pub fn new(sender: Option>) -> Self { - match sender { - Some(sender) => Self(EventSenderInner::Enabled(sender)), - None => Self(EventSenderInner::Disabled), - } - } - - /// Send a client connected event, if the progress sender is enabled. - /// - /// This will permit the client to connect if the sender is disabled. - #[must_use = "permit should be checked by the caller"] - pub async fn authorize_client_connection(&self, connection_id: u64, node_id: NodeId) -> bool { - let mut wait_for_permit = None; - self.send(|| { - let (tx, rx) = oneshot::channel(); - wait_for_permit = Some(rx); - Event::ClientConnected { - connection_id, - node_id, - permitted: tx, - } - }) - .await; - if let Some(wait_for_permit) = wait_for_permit { - // if we have events configured, and they drop the channel, we consider that as a no! - // todo: this will be confusing and needs to be properly documented. - wait_for_permit.await.unwrap_or(false) - } else { - true - } - } - - /// Send an ephemeral event, if the progress sender is enabled. - /// - /// The event will only be created if the sender is enabled. - fn try_send(&self, event: impl LazyEvent) { - match &self.0 { - EventSenderInner::Enabled(sender) => { - let value = event.call(); - sender.try_send(value).ok(); - } - EventSenderInner::Disabled => {} - } - } - - /// Send a mandatory event, if the progress sender is enabled. - /// - /// The event only be created if the sender is enabled. - async fn send(&self, event: impl LazyEvent) { - match &self.0 { - EventSenderInner::Enabled(sender) => { - let value = event.call(); - if let Err(err) = sender.send(value).await { - error!("failed to send progress event: {:?}", err); - } - } - EventSenderInner::Disabled => {} - } - } -} - pub struct ProgressReader { inner: RecvStream, context: ReaderContext, @@ -712,28 +559,14 @@ impl ProgressReader { async fn transfer_aborted(&self) { self.tracker - .transfer_aborted(|| { - Some(Box::new(TransferStats { - payload_bytes_sent: 0, - other_bytes_sent: 0, - bytes_read: self.bytes_read, - duration: self.context.t0.elapsed(), - })) - }) + .transfer_aborted(|| Box::new(self.stats())) .await .ok(); } async fn transfer_completed(&self) { self.tracker - .transfer_completed(|| { - Box::new(TransferStats { - payload_bytes_sent: 0, - other_bytes_sent: 0, - bytes_read: self.bytes_read, - duration: self.context.t0.elapsed(), - }) - }) + .transfer_completed(|| Box::new(self.stats())) .await .ok(); } diff --git a/src/provider/events.rs b/src/provider/events.rs index 2ae4ba5be..c8b94c8b8 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize}; use snafu::Snafu; use crate::{ + protocol::{GetManyRequest, GetRequest, ObserveRequest, PushRequest}, provider::{events::irpc_ext::IrpcClientExt, TransferStats}, Hash, }; @@ -24,15 +25,27 @@ pub enum ConnectMode { Request, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[repr(u8)] +pub enum ObserveMode { + /// We don't get notification of connect events at all. + #[default] + None, + /// We get a notification for connect events. + Notify, + /// We get a request for connect events and can reject incoming connections. + Request, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[repr(u8)] pub enum RequestMode { /// We don't get request events at all. #[default] None, - /// We get a notification for each request. + /// We get a notification for each request, but no transfer events. Notify, - /// We get a request for each request, and can reject incoming requests. + /// We get a request for each request, and can reject incoming requests, but no transfer events. Request, /// We get a notification for each request as well as detailed transfer events. NotifyLog, @@ -101,6 +114,7 @@ pub struct EventMask { get: RequestMode, get_many: RequestMode, push: RequestMode, + observe: ObserveMode, /// throttling is somewhat costly, so you can disable it completely throttle: ThrottleMode, } @@ -113,6 +127,7 @@ impl EventMask { get_many: RequestMode::None, push: RequestMode::None, throttle: ThrottleMode::None, + observe: ObserveMode::None, }; /// You get asked for every single thing that is going on and can intervene/throttle. @@ -122,6 +137,7 @@ impl EventMask { get_many: RequestMode::RequestLog, push: RequestMode::RequestLog, throttle: ThrottleMode::Throttle, + observe: ObserveMode::Request, }; /// You get notified for every single thing that is going on, but can't intervene. @@ -131,6 +147,7 @@ impl EventMask { get_many: RequestMode::NotifyLog, push: RequestMode::NotifyLog, throttle: ThrottleMode::None, + observe: ObserveMode::Notify, }; } @@ -216,10 +233,7 @@ impl RequestTracker { } /// Transfer aborted for the previously reported blob. - pub async fn transfer_aborted( - &self, - f: impl Fn() -> Option>, - ) -> irpc::Result<()> { + pub async fn transfer_aborted(&self, f: impl Fn() -> Box) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { tx.send(RequestUpdate::Aborted(TransferAborted { stats: f() })) .await?; @@ -264,108 +278,82 @@ impl EventSender { }) } - /// Start a get request. You will get back either an error if the request should not proceed, or a - /// [`RequestTracker`] that you can use to log progress for this particular request. - /// - /// Depending on the event sender config, the returned tracker might be a no-op. - pub async fn get_request( - &self, - f: impl FnOnce() -> GetRequestReceived, - ) -> Result { - self.request(f).await - } - - // Start a get_many request. You will get back either an error if the request should not proceed, or a - /// [`RequestTracker`] that you can use to log progress for this particular request. - /// - /// Depending on the event sender config, the returned tracker might be a no-op. - pub async fn get_many_request( - &self, - f: impl FnOnce() -> GetManyRequestReceived, - ) -> Result { - self.request(f).await - } - - // Start a push request. You will get back either an error if the request should not proceed, or a - /// [`RequestTracker`] that you can use to log progress for this particular request. - /// - /// Depending on the event sender config, the returned tracker might be a no-op. - pub async fn push_request( - &self, - f: impl FnOnce() -> PushRequestReceived, - ) -> Result { - self.request(f).await - } - /// Abstract request, to DRY the 3 to 4 request types. /// /// DRYing stuff with lots of bounds is no fun at all... - async fn request(&self, f: impl FnOnce() -> Req) -> Result + pub(crate) async fn request( + &self, + f: impl FnOnce() -> Req, + connection_id: u64, + request_id: u64, + ) -> Result where - Req: Request, - ProviderProto: From, - ProviderMessage: From>, - Req: Channels< + ProviderProto: From>, + ProviderMessage: From, ProviderProto>>, + RequestReceived: Channels< ProviderProto, Tx = oneshot::Sender, Rx = mpsc::Receiver, >, - ProviderProto: From>, - ProviderMessage: From, ProviderProto>>, - Notify: Channels>, + ProviderProto: From>>, + ProviderMessage: From>, ProviderProto>>, + Notify>: + Channels>, { - Ok(self.into_tracker(if let Some(client) = &self.inner { - match self.mask.get { - RequestMode::None => { - if self.mask.throttle == ThrottleMode::Throttle { - // if throttling is enabled, we need to call f to get connection_id and request_id - let msg = f(); - (RequestUpdates::None, msg.id()) - } else { - (RequestUpdates::None, (0, 0)) + Ok(self.into_tracker(( + if let Some(client) = &self.inner { + match self.mask.get { + RequestMode::None => RequestUpdates::None, + RequestMode::Notify => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + RequestUpdates::Disabled(client.notify_streaming(Notify(msg), 32).await?) + } + RequestMode::Request => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + let (tx, rx) = client.client_streaming(msg, 32).await?; + // bail out if the request is not allowed + rx.await??; + RequestUpdates::Disabled(tx) + } + RequestMode::NotifyLog => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + RequestUpdates::Active(client.notify_streaming(Notify(msg), 32).await?) + } + RequestMode::RequestLog => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + let (tx, rx) = client.client_streaming(msg, 32).await?; + // bail out if the request is not allowed + rx.await??; + RequestUpdates::Active(tx) } } - RequestMode::Notify => { - let msg = f(); - let id = msg.id(); - ( - RequestUpdates::Disabled(client.notify_streaming(Notify(msg), 32).await?), - id, - ) - } - RequestMode::Request => { - let msg = f(); - let id = msg.id(); - let (tx, rx) = client.client_streaming(msg, 32).await?; - // bail out if the request is not allowed - rx.await??; - (RequestUpdates::Disabled(tx), id) - } - RequestMode::NotifyLog => { - let msg = f(); - let id = msg.id(); - ( - RequestUpdates::Active(client.notify_streaming(Notify(msg), 32).await?), - id, - ) - } - RequestMode::RequestLog => { - let msg = f(); - let id = msg.id(); - let (tx, rx) = client.client_streaming(msg, 32).await?; - // bail out if the request is not allowed - rx.await??; - (RequestUpdates::Active(tx), id) - } - } - } else { - (RequestUpdates::None, (0, 0)) - })) + } else { + RequestUpdates::None + }, + connection_id, + request_id, + ))) } fn into_tracker( &self, - (updates, (connection_id, request_id)): (RequestUpdates, (u64, u64)), + (updates, connection_id, request_id): (RequestUpdates, u64, u64), ) -> RequestTracker { let throttle = match self.mask.throttle { ThrottleMode::None => None, @@ -394,46 +382,45 @@ pub enum ProviderProto { #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] /// A new get request was received from the provider. - GetRequestReceived(GetRequestReceived), + GetRequestReceived(RequestReceived), #[rpc(rx = mpsc::Receiver, tx = NoSender)] /// A new get request was received from the provider. - GetRequestReceivedNotify(Notify), + GetRequestReceivedNotify(Notify>), /// A new get request was received from the provider. #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] - GetManyRequestReceived(GetManyRequestReceived), + GetManyRequestReceived(RequestReceived), /// A new get request was received from the provider. #[rpc(rx = mpsc::Receiver, tx = NoSender)] - GetManyRequestReceivedNotify(Notify), + GetManyRequestReceivedNotify(Notify>), /// A new get request was received from the provider. #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] - PushRequestReceived(PushRequestReceived), + PushRequestReceived(RequestReceived), /// A new get request was received from the provider. #[rpc(rx = mpsc::Receiver, tx = NoSender)] - PushRequestReceivedNotify(Notify), + PushRequestReceivedNotify(Notify>), + + /// A new get request was received from the provider. + #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] + ObserveRequestReceived(RequestReceived), + + /// A new get request was received from the provider. + #[rpc(rx = mpsc::Receiver, tx = NoSender)] + ObserveRequestReceivedNotify(Notify>), #[rpc(tx = oneshot::Sender)] Throttle(Throttle), } -trait Request { - fn id(&self) -> (u64, u64); -} - mod proto { use iroh::NodeId; use serde::{Deserialize, Serialize}; - use super::Request; - use crate::{ - protocol::{GetManyRequest, GetRequest, PushRequest}, - provider::TransferStats, - Hash, - }; + use crate::{provider::TransferStats, Hash}; #[derive(Debug, Serialize, Deserialize)] pub struct ClientConnected { @@ -448,51 +435,13 @@ mod proto { /// A new get request was received from the provider. #[derive(Debug, Serialize, Deserialize)] - pub struct GetRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - pub connection_id: u64, - /// The request id. There is a new id for each request. - pub request_id: u64, - /// The request - pub request: GetRequest, - } - - impl Request for GetRequestReceived { - fn id(&self) -> (u64, u64) { - (self.connection_id, self.request_id) - } - } - - #[derive(Debug, Serialize, Deserialize)] - pub struct GetManyRequestReceived { + pub struct RequestReceived { /// The connection id. Multiple requests can be sent over the same connection. pub connection_id: u64, /// The request id. There is a new id for each request. pub request_id: u64, /// The request - pub request: GetManyRequest, - } - - impl Request for GetManyRequestReceived { - fn id(&self) -> (u64, u64) { - (self.connection_id, self.request_id) - } - } - - #[derive(Debug, Serialize, Deserialize)] - pub struct PushRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - pub connection_id: u64, - /// The request id. There is a new id for each request. - pub request_id: u64, - /// The request - pub request: PushRequest, - } - - impl Request for PushRequestReceived { - fn id(&self) -> (u64, u64) { - (self.connection_id, self.request_id) - } + pub request: R, } /// Request to throttle sending for a specific request. @@ -524,7 +473,7 @@ mod proto { #[derive(Debug, Serialize, Deserialize)] pub struct TransferAborted { - pub stats: Option>, + pub stats: Box, } /// Stream of updates for a single request diff --git a/src/tests.rs b/src/tests.rs index e99d0fe02..0dda88fee 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -16,10 +16,7 @@ use crate::{ hashseq::HashSeq, net_protocol::BlobsProtocol, protocol::{ChunkRangesSeq, GetManyRequest, ObserveRequest, PushRequest}, - provider::{ - events::{AbortReason, RequestUpdate}, - EventMask, EventSender2, ProviderMessage, - }, + provider::events::{AbortReason, EventMask, EventSender, ProviderMessage, RequestUpdate}, store::{ fs::{ tests::{create_n0_bao, test_data, INTERESTING_SIZES}, @@ -343,7 +340,7 @@ async fn two_nodes_get_many_mem() -> TestResult<()> { fn event_handler( allowed_nodes: impl IntoIterator, -) -> (EventSender2, watch::Receiver, AbortOnDropHandle<()>) { +) -> (EventSender, watch::Receiver, AbortOnDropHandle<()>) { let (count_tx, count_rx) = tokio::sync::watch::channel(0usize); let (events_tx, mut events_rx) = mpsc::channel::(16); let allowed_nodes = allowed_nodes.into_iter().collect::>(); @@ -373,7 +370,7 @@ fn event_handler( } } })); - (EventSender2::new(events_tx, EventMask::ALL), count_rx, task) + (EventSender::new(events_tx, EventMask::ALL), count_rx, task) } async fn two_nodes_push_blobs( @@ -488,12 +485,12 @@ async fn check_presence(store: &Store, sizes: &[usize]) -> TestResult<()> { } pub async fn node_test_setup_fs(db_path: PathBuf) -> TestResult<(Router, FsStore, PathBuf)> { - node_test_setup_with_events_fs(db_path, EventSender2::NONE).await + node_test_setup_with_events_fs(db_path, EventSender::NONE).await } pub async fn node_test_setup_with_events_fs( db_path: PathBuf, - events: EventSender2, + events: EventSender, ) -> TestResult<(Router, FsStore, PathBuf)> { let store = crate::store::fs::FsStore::load(&db_path).await?; let ep = Endpoint::builder().bind().await?; @@ -503,11 +500,11 @@ pub async fn node_test_setup_with_events_fs( } pub async fn node_test_setup_mem() -> TestResult<(Router, MemStore)> { - node_test_setup_with_events_mem(EventSender2::NONE).await + node_test_setup_with_events_mem(EventSender::NONE).await } pub async fn node_test_setup_with_events_mem( - events: EventSender2, + events: EventSender, ) -> TestResult<(Router, MemStore)> { let store = MemStore::new(); let ep = Endpoint::builder().bind().await?; @@ -609,7 +606,7 @@ async fn node_serve_hash_seq() -> TestResult<()> { let root = root_tt.hash; let endpoint = Endpoint::builder().discovery_n0().bind().await?; let blobs = - crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); + crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); @@ -641,7 +638,7 @@ async fn node_serve_blobs() -> TestResult<()> { } let endpoint = Endpoint::builder().discovery_n0().bind().await?; let blobs = - crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); + crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); @@ -683,8 +680,7 @@ async fn node_smoke(store: &Store) -> TestResult<()> { let tt = store.add_bytes(b"hello world".to_vec()).temp_tag().await?; let hash = *tt.hash(); let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = - crate::net_protocol::BlobsProtocol::new(store, endpoint.clone(), EventSender2::NONE); + let blobs = crate::net_protocol::BlobsProtocol::new(store, endpoint.clone(), EventSender::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); From 1e4a58161f320b56e55bc90841f943c5d5676579 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Mon, 1 Sep 2025 15:35:53 +0200 Subject: [PATCH 07/23] minimize diff and required changes --- README.md | 4 ++-- examples/custom-protocol.rs | 6 ++---- examples/mdns-discovery.rs | 4 ++-- examples/random_store.rs | 2 +- examples/transfer.rs | 6 ++---- src/net_protocol.rs | 8 ++++---- src/tests.rs | 12 +++++------- 7 files changed, 18 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 1a136e44d..2f374e8fb 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ Here is a basic example of how to set up `iroh-blobs` with `iroh`: ```rust,no_run use iroh::{protocol::Router, Endpoint}; -use iroh_blobs::{store::mem::MemStore, BlobsProtocol, provider::events::EventSender}; +use iroh_blobs::{store::mem::MemStore, BlobsProtocol}; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -44,7 +44,7 @@ async fn main() -> anyhow::Result<()> { // create a protocol handler using an in-memory blob store. let store = MemStore::new(); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); // build the router let router = Router::builder(endpoint) diff --git a/examples/custom-protocol.rs b/examples/custom-protocol.rs index d4d29e27f..c021b7f0a 100644 --- a/examples/custom-protocol.rs +++ b/examples/custom-protocol.rs @@ -48,9 +48,7 @@ use iroh::{ protocol::{AcceptError, ProtocolHandler, Router}, Endpoint, NodeId, }; -use iroh_blobs::{ - api::Store, provider::events::EventSender, store::mem::MemStore, BlobsProtocol, Hash, -}; +use iroh_blobs::{api::Store, store::mem::MemStore, BlobsProtocol, Hash}; mod common; use common::{get_or_generate_secret_key, setup_logging}; @@ -102,7 +100,7 @@ async fn listen(text: Vec) -> Result<()> { proto.insert_and_index(text).await?; } // Build the iroh-blobs protocol handler, which is used to download blobs. - let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); // create a router that handles both our custom protocol and the iroh-blobs protocol. let node = Router::builder(endpoint) diff --git a/examples/mdns-discovery.rs b/examples/mdns-discovery.rs index ab11bc864..b42f88f47 100644 --- a/examples/mdns-discovery.rs +++ b/examples/mdns-discovery.rs @@ -18,7 +18,7 @@ use clap::{Parser, Subcommand}; use iroh::{ discovery::mdns::MdnsDiscovery, protocol::Router, Endpoint, PublicKey, RelayMode, SecretKey, }; -use iroh_blobs::{provider::events::EventSender, store::mem::MemStore, BlobsProtocol, Hash}; +use iroh_blobs::{store::mem::MemStore, BlobsProtocol, Hash}; mod common; use common::{get_or_generate_secret_key, setup_logging}; @@ -68,7 +68,7 @@ async fn accept(path: &Path) -> Result<()> { .await?; let builder = Router::builder(endpoint.clone()); let store = MemStore::new(); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); let builder = builder.accept(iroh_blobs::ALPN, blobs.clone()); let node = builder.spawn(); diff --git a/examples/random_store.rs b/examples/random_store.rs index f36017e8d..f23e804e1 100644 --- a/examples/random_store.rs +++ b/examples/random_store.rs @@ -238,7 +238,7 @@ async fn provide(args: ProvideArgs) -> anyhow::Result<()> { .bind() .await?; let (dump_task, events_tx) = dump_provider_events(args.allow_push); - let blobs = iroh_blobs::BlobsProtocol::new(&store, endpoint.clone(), events_tx); + let blobs = iroh_blobs::BlobsProtocol::new(&store, endpoint.clone(), Some(events_tx)); let router = iroh::protocol::Router::builder(endpoint.clone()) .accept(iroh_blobs::ALPN, blobs) .spawn(); diff --git a/examples/transfer.rs b/examples/transfer.rs index 8347774ca..48fba6ba3 100644 --- a/examples/transfer.rs +++ b/examples/transfer.rs @@ -1,9 +1,7 @@ use std::path::PathBuf; use iroh::{protocol::Router, Endpoint}; -use iroh_blobs::{ - provider::events::EventSender, store::mem::MemStore, ticket::BlobTicket, BlobsProtocol, -}; +use iroh_blobs::{store::mem::MemStore, ticket::BlobTicket, BlobsProtocol}; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -14,7 +12,7 @@ async fn main() -> anyhow::Result<()> { // We initialize an in-memory backing store for iroh-blobs let store = MemStore::new(); // Then we initialize a struct that can accept blobs requests over iroh connections - let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); // Grab all passed in arguments, the first one is the binary itself, so we skip it. let args: Vec = std::env::args().skip(1).collect(); diff --git a/src/net_protocol.rs b/src/net_protocol.rs index aa45aa473..1927cd23d 100644 --- a/src/net_protocol.rs +++ b/src/net_protocol.rs @@ -7,7 +7,7 @@ //! ```rust //! # async fn example() -> anyhow::Result<()> { //! use iroh::{protocol::Router, Endpoint}; -//! use iroh_blobs::{provider::events::EventSender, store, BlobsProtocol}; +//! use iroh_blobs::{store, BlobsProtocol}; //! //! // create a store //! let store = store::fs::FsStore::load("blobs").await?; @@ -19,7 +19,7 @@ //! let endpoint = Endpoint::builder().discovery_n0().bind().await?; //! //! // create a blobs protocol handler -//! let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); +//! let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); //! //! // create a router and add the blobs protocol handler //! let router = Router::builder(endpoint) @@ -69,12 +69,12 @@ impl Deref for BlobsProtocol { } impl BlobsProtocol { - pub fn new(store: &Store, endpoint: Endpoint, events: EventSender) -> Self { + pub fn new(store: &Store, endpoint: Endpoint, events: Option) -> Self { Self { inner: Arc::new(BlobsInner { store: store.clone(), endpoint, - events, + events: events.unwrap_or(EventSender::NONE), }), } } diff --git a/src/tests.rs b/src/tests.rs index 0dda88fee..911a5cfd2 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -494,7 +494,7 @@ pub async fn node_test_setup_with_events_fs( ) -> TestResult<(Router, FsStore, PathBuf)> { let store = crate::store::fs::FsStore::load(&db_path).await?; let ep = Endpoint::builder().bind().await?; - let blobs = BlobsProtocol::new(&store, ep.clone(), events); + let blobs = BlobsProtocol::new(&store, ep.clone(), Some(events)); let router = Router::builder(ep).accept(crate::ALPN, blobs).spawn(); Ok((router, store, db_path)) } @@ -508,7 +508,7 @@ pub async fn node_test_setup_with_events_mem( ) -> TestResult<(Router, MemStore)> { let store = MemStore::new(); let ep = Endpoint::builder().bind().await?; - let blobs = BlobsProtocol::new(&store, ep.clone(), events); + let blobs = BlobsProtocol::new(&store, ep.clone(), Some(events)); let router = Router::builder(ep).accept(crate::ALPN, blobs).spawn(); Ok((router, store)) } @@ -605,8 +605,7 @@ async fn node_serve_hash_seq() -> TestResult<()> { let root_tt = store.add_bytes(hash_seq).await?; let root = root_tt.hash; let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = - crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); + let blobs = crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), None); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); @@ -637,8 +636,7 @@ async fn node_serve_blobs() -> TestResult<()> { tts.push(store.add_bytes(test_data(size)).await?); } let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = - crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); + let blobs = crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), None); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); @@ -680,7 +678,7 @@ async fn node_smoke(store: &Store) -> TestResult<()> { let tt = store.add_bytes(b"hello world".to_vec()).temp_tag().await?; let hash = *tt.hash(); let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = crate::net_protocol::BlobsProtocol::new(store, endpoint.clone(), EventSender::NONE); + let blobs = crate::net_protocol::BlobsProtocol::new(store, endpoint.clone(), None); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); From 64499308baebe3c51f0e38a6a2ab8e986f06f56b Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Mon, 1 Sep 2025 17:12:59 +0200 Subject: [PATCH 08/23] clippy --- src/api/blobs.rs | 17 --- src/net_protocol.rs | 2 +- src/protocol.rs | 47 +++++---- src/provider.rs | 235 ++++++++++++++--------------------------- src/provider/events.rs | 151 +++++++++++++------------- src/tests.rs | 4 +- 6 files changed, 192 insertions(+), 264 deletions(-) diff --git a/src/api/blobs.rs b/src/api/blobs.rs index d00a0a940..8b618de1f 100644 --- a/src/api/blobs.rs +++ b/src/api/blobs.rs @@ -57,7 +57,6 @@ use super::{ }; use crate::{ api::proto::{BatchRequest, ImportByteStreamUpdate}, - provider::WriterContext, store::IROH_BLOCK_SIZE, util::temp_tag::TempTag, BlobFormat, Hash, HashAndFormat, @@ -1167,19 +1166,3 @@ pub(crate) trait WriteProgress { /// Notify the progress writer that a transfer has started. async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64); } - -impl WriteProgress for WriterContext { - async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) { - let end_offset = offset + len as u64; - self.payload_bytes_written += len as u64; - self.tracker.transfer_progress(end_offset).await.ok(); - } - - fn log_other_write(&mut self, len: usize) { - self.other_bytes_written += len as u64; - } - - async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) { - self.tracker.transfer_started(index, hash, size).await.ok(); - } -} diff --git a/src/net_protocol.rs b/src/net_protocol.rs index 1927cd23d..269ef0e14 100644 --- a/src/net_protocol.rs +++ b/src/net_protocol.rs @@ -74,7 +74,7 @@ impl BlobsProtocol { inner: Arc::new(BlobsInner { store: store.clone(), endpoint, - events: events.unwrap_or(EventSender::NONE), + events: events.unwrap_or(EventSender::DEFAULT), }), } } diff --git a/src/protocol.rs b/src/protocol.rs index 74e0f986d..05ee00678 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -392,7 +392,7 @@ pub use range_spec::{ChunkRangesSeq, NonEmptyRequestRangeSpecIter, RangeSpec}; use snafu::{GenerateImplicitData, Snafu}; use tokio::io::AsyncReadExt; -use crate::{api::blobs::Bitfield, provider::CountingReader, BlobFormat, Hash, HashAndFormat}; +use crate::{api::blobs::Bitfield, provider::RecvStreamExt, BlobFormat, Hash, HashAndFormat}; /// Maximum message size is limited to 100MiB for now. pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024; @@ -441,9 +441,7 @@ pub enum RequestType { } impl Request { - pub async fn read_async( - reader: &mut CountingReader<&mut iroh::endpoint::RecvStream>, - ) -> io::Result { + pub async fn read_async(reader: &mut iroh::endpoint::RecvStream) -> io::Result<(Self, usize)> { let request_type = reader.read_u8().await?; let request_type: RequestType = postcard::from_bytes(std::slice::from_ref(&request_type)) .map_err(|_| { @@ -453,22 +451,31 @@ impl Request { ) })?; Ok(match request_type { - RequestType::Get => reader - .read_to_end_as::(MAX_MESSAGE_SIZE) - .await? - .into(), - RequestType::GetMany => reader - .read_to_end_as::(MAX_MESSAGE_SIZE) - .await? - .into(), - RequestType::Observe => reader - .read_to_end_as::(MAX_MESSAGE_SIZE) - .await? - .into(), - RequestType::Push => reader - .read_length_prefixed::(MAX_MESSAGE_SIZE) - .await? - .into(), + RequestType::Get => { + let (r, size) = reader + .read_to_end_as::(MAX_MESSAGE_SIZE) + .await?; + (r.into(), size) + } + RequestType::GetMany => { + let (r, size) = reader + .read_to_end_as::(MAX_MESSAGE_SIZE) + .await?; + (r.into(), size) + } + RequestType::Observe => { + let (r, size) = reader + .read_to_end_as::(MAX_MESSAGE_SIZE) + .await?; + (r.into(), size) + } + RequestType::Push => { + let r = reader + .read_length_prefixed::(MAX_MESSAGE_SIZE) + .await?; + let size = postcard::experimental::serialized_size(&r).unwrap(); + (r.into(), size) + } _ => { return Err(io::Error::new( io::ErrorKind::InvalidData, diff --git a/src/provider.rs b/src/provider.rs index a98c1b3a7..1683daa57 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -6,9 +6,6 @@ use std::{ fmt::Debug, io, - ops::{Deref, DerefMut}, - pin::Pin, - task::Poll, time::{Duration, Instant}, }; @@ -18,7 +15,7 @@ use iroh::endpoint::{self, RecvStream, SendStream}; use n0_future::StreamExt; use quinn::{ClosedStream, ConnectionError, ReadToEndError}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use tokio::{io::AsyncRead, select}; +use tokio::select; use tracing::{debug, debug_span, warn, Instrument}; use crate::{ @@ -53,23 +50,9 @@ pub struct TransferStats { pub duration: Duration, } -/// Read the request from the getter. -/// -/// Will fail if there is an error while reading, or if no valid request is sent. -/// -/// This will read exactly the number of bytes needed for the request, and -/// leave the rest of the stream for the caller to read. -/// -/// It is up to the caller do decide if there should be more data. -pub async fn read_request(context: &mut StreamData) -> Result { - let mut counting = CountingReader::new(&mut context.reader); - let res = Request::read_async(&mut counting).await?; - context.other_bytes_read += counting.read(); - Ok(res) -} - +/// A pair of [`SendStream`] and [`RecvStream`] with additional context data. #[derive(Debug)] -pub struct StreamData { +pub struct StreamPair { t0: Instant, connection_id: u64, request_id: u64, @@ -79,7 +62,7 @@ pub struct StreamData { events: EventSender, } -impl StreamData { +impl StreamPair { pub async fn accept( conn: &endpoint::Connection, events: &EventSender, @@ -96,8 +79,22 @@ impl StreamData { }) } + /// Read the request. + /// + /// Will fail if there is an error while reading, or if no valid request is sent. + /// + /// This will read exactly the number of bytes needed for the request, and + /// leave the rest of the stream for the caller to read. + /// + /// It is up to the caller do decide if there should be more data. + pub async fn read_request(&mut self) -> Result { + let (res, size) = Request::read_async(&mut self.reader).await?; + self.other_bytes_read += size as u64; + Ok(res) + } + /// We are done with reading. Return a ProgressWriter that contains the read stats and connection id - async fn into_writer( + pub async fn into_writer( mut self, tracker: RequestTracker, ) -> Result { @@ -113,8 +110,6 @@ impl StreamData { self.writer, WriterContext { t0: self.t0, - connection_id: self.connection_id, - request_id: self.request_id, other_bytes_read: self.other_bytes_read, payload_bytes_written: 0, other_bytes_written: 0, @@ -123,7 +118,7 @@ impl StreamData { )) } - async fn into_reader( + pub async fn into_reader( mut self, tracker: RequestTracker, ) -> Result { @@ -135,19 +130,17 @@ impl StreamData { .ok(); return Err(e); }; - Ok(ProgressReader::new( - self.reader, - ReaderContext { + Ok(ProgressReader { + inner: self.reader, + context: ReaderContext { t0: self.t0, - connection_id: self.connection_id, - request_id: self.request_id, other_bytes_read: self.other_bytes_read, tracker, }, - )) + }) } - async fn get_request( + pub async fn get_request( &self, f: impl FnOnce() -> GetRequest, ) -> Result { @@ -156,7 +149,7 @@ impl StreamData { .await } - async fn get_many_request( + pub async fn get_many_request( &self, f: impl FnOnce() -> GetManyRequest, ) -> Result { @@ -165,7 +158,7 @@ impl StreamData { .await } - async fn push_request( + pub async fn push_request( &self, f: impl FnOnce() -> PushRequest, ) -> Result { @@ -174,7 +167,7 @@ impl StreamData { .await } - async fn observe_request( + pub async fn observe_request( &self, f: impl FnOnce() -> ObserveRequest, ) -> Result { @@ -194,31 +187,17 @@ impl StreamData { } #[derive(Debug)] -pub struct ReaderContext { +struct ReaderContext { /// The start time of the transfer - pub t0: Instant, - /// The connection ID from the connection - pub connection_id: u64, - /// The request ID from the recv stream - pub request_id: u64, + t0: Instant, /// The number of bytes read from the stream - pub other_bytes_read: u64, + other_bytes_read: u64, /// Progress tracking for the request - pub tracker: RequestTracker, + tracker: RequestTracker, } impl ReaderContext { - pub fn new(context: StreamData, tracker: RequestTracker) -> Self { - Self { - t0: context.t0, - connection_id: context.connection_id, - request_id: context.request_id, - other_bytes_read: context.other_bytes_read, - tracker, - } - } - - pub fn stats(&self) -> TransferStats { + fn stats(&self) -> TransferStats { TransferStats { payload_bytes_sent: 0, other_bytes_sent: 0, @@ -229,37 +208,21 @@ impl ReaderContext { } #[derive(Debug)] -pub struct WriterContext { +pub(crate) struct WriterContext { /// The start time of the transfer - pub t0: Instant, - /// The connection ID from the connection - pub connection_id: u64, - /// The request ID from the recv stream - pub request_id: u64, + t0: Instant, /// The number of bytes read from the stream - pub other_bytes_read: u64, + other_bytes_read: u64, /// The number of payload bytes written to the stream - pub payload_bytes_written: u64, + payload_bytes_written: u64, /// The number of bytes written that are not part of the payload - pub other_bytes_written: u64, + other_bytes_written: u64, /// Way to report progress - pub tracker: RequestTracker, + tracker: RequestTracker, } impl WriterContext { - pub fn new(context: &StreamData, tracker: RequestTracker) -> Self { - Self { - t0: context.t0, - connection_id: context.connection_id, - request_id: context.request_id, - other_bytes_read: context.other_bytes_read, - payload_bytes_written: 0, - other_bytes_written: 0, - tracker, - } - } - - pub fn stats(&self) -> TransferStats { + fn stats(&self) -> TransferStats { TransferStats { payload_bytes_sent: self.payload_bytes_written, other_bytes_sent: self.other_bytes_written, @@ -269,6 +232,22 @@ impl WriterContext { } } +impl WriteProgress for WriterContext { + async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) { + let end_offset = offset + len as u64; + self.payload_bytes_written += len as u64; + self.tracker.transfer_progress(end_offset).await.ok(); + } + + fn log_other_write(&mut self, len: usize) { + self.other_bytes_written += len as u64; + } + + async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) { + self.tracker.transfer_started(index, hash, size).await.ok(); + } +} + /// Wrapper for a [`quinn::SendStream`] with additional per request information. #[derive(Debug)] pub struct ProgressWriter { @@ -283,34 +262,22 @@ impl ProgressWriter { } async fn transfer_aborted(&self) { - self.tracker - .transfer_aborted(|| Box::new(self.stats())) + self.context + .tracker + .transfer_aborted(|| Box::new(self.context.stats())) .await .ok(); } async fn transfer_completed(&self) { - self.tracker - .transfer_completed(|| Box::new(self.stats())) + self.context + .tracker + .transfer_completed(|| Box::new(self.context.stats())) .await .ok(); } } -impl Deref for ProgressWriter { - type Target = WriterContext; - - fn deref(&self) -> &Self::Target { - &self.context - } -} - -impl DerefMut for ProgressWriter { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.context - } -} - /// Handle a single connection. pub async fn handle_connection( connection: endpoint::Connection, @@ -334,7 +301,7 @@ pub async fn handle_connection( debug!("client not authorized to connect: {cause}"); return; } - while let Ok(context) = StreamData::accept(&connection, &progress).await { + while let Ok(context) = StreamPair::accept(&connection, &progress).await { let span = debug_span!("stream", stream_id = %context.request_id); let store = store.clone(); tokio::spawn(handle_stream(store, context).instrument(span)); @@ -348,10 +315,10 @@ pub async fn handle_connection( .await } -async fn handle_stream(store: Store, mut context: StreamData) -> anyhow::Result<()> { +async fn handle_stream(store: Store, mut context: StreamPair) -> anyhow::Result<()> { // 1. Decode the request. debug!("reading request"); - let request = read_request(&mut context).await?; + let request = context.read_request().await?; match request { Request::Get(request) => { @@ -553,79 +520,41 @@ pub struct ProgressReader { } impl ProgressReader { - pub fn new(inner: RecvStream, context: ReaderContext) -> Self { - Self { inner, context } - } - async fn transfer_aborted(&self) { - self.tracker - .transfer_aborted(|| Box::new(self.stats())) + self.context + .tracker + .transfer_aborted(|| Box::new(self.context.stats())) .await .ok(); } async fn transfer_completed(&self) { - self.tracker - .transfer_completed(|| Box::new(self.stats())) + self.context + .tracker + .transfer_completed(|| Box::new(self.context.stats())) .await .ok(); } } -impl Deref for ProgressReader { - type Target = ReaderContext; - - fn deref(&self) -> &Self::Target { - &self.context - } +pub(crate) trait RecvStreamExt { + async fn read_to_end_as( + &mut self, + max_size: usize, + ) -> io::Result<(T, usize)>; } -impl DerefMut for ProgressReader { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.context - } -} - -pub struct CountingReader { - pub inner: R, - pub read: u64, -} - -impl CountingReader { - pub fn new(inner: R) -> Self { - Self { inner, read: 0 } - } - - pub fn read(&self) -> u64 { - self.read - } -} - -impl CountingReader<&mut iroh::endpoint::RecvStream> { - pub async fn read_to_end_as(&mut self, max_size: usize) -> io::Result { +impl RecvStreamExt for RecvStream { + async fn read_to_end_as( + &mut self, + max_size: usize, + ) -> io::Result<(T, usize)> { let data = self - .inner .read_to_end(max_size) .await .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; let value = postcard::from_bytes(&data) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - self.read += data.len() as u64; - Ok(value) - } -} - -impl AsyncRead for CountingReader { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - let this = self.get_mut(); - let result = Pin::new(&mut this.inner).poll_read(cx, buf); - if let Poll::Ready(Ok(())) = result { - this.read += buf.filled().len() as u64; - } - result + Ok((value, data.len())) } } diff --git a/src/provider/events.rs b/src/provider/events.rs index c8b94c8b8..1ecf13cb7 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -52,6 +52,8 @@ pub enum RequestMode { /// We get a request for each request, and can reject incoming requests. /// We also get detailed transfer events. RequestLog, + /// This request type is completely disabled. All requests will be rejected. + Disabled, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] @@ -108,24 +110,35 @@ impl From for ClientError { pub type EventResult = Result<(), AbortReason>; pub type ClientResult = Result<(), ClientError>; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct EventMask { - connected: ConnectMode, - get: RequestMode, - get_many: RequestMode, - push: RequestMode, - observe: ObserveMode, + /// Connection event mask + pub connected: ConnectMode, + /// Get request event mask + pub get: RequestMode, + /// Get many request event mask + pub get_many: RequestMode, + /// Push request event mask + pub push: RequestMode, + /// Observe request event mask + pub observe: ObserveMode, /// throttling is somewhat costly, so you can disable it completely - throttle: ThrottleMode, + pub throttle: ThrottleMode, +} + +impl Default for EventMask { + fn default() -> Self { + Self::DEFAULT + } } impl EventMask { - /// Everything is disabled. You won't get any events, but there is also no runtime cost. - pub const NONE: Self = Self { + /// All event notifications are fully disabled. Push requests are disabled by default. + pub const DEFAULT: Self = Self { connected: ConnectMode::None, get: RequestMode::None, get_many: RequestMode::None, - push: RequestMode::None, + push: RequestMode::Disabled, throttle: ThrottleMode::None, observe: ObserveMode::None, }; @@ -139,16 +152,6 @@ impl EventMask { throttle: ThrottleMode::Throttle, observe: ObserveMode::Request, }; - - /// You get notified for every single thing that is going on, but can't intervene. - pub const NOTIFY_ALL: Self = Self { - connected: ConnectMode::Notify, - get: RequestMode::NotifyLog, - get_many: RequestMode::NotifyLog, - push: RequestMode::NotifyLog, - throttle: ThrottleMode::None, - observe: ObserveMode::Notify, - }; } /// Newtype wrapper that wraps an event so that it is a distinct type for the notify variant. @@ -248,8 +251,8 @@ impl RequestTracker { /// can have a response. impl EventSender { /// A client that does not send anything. - pub const NONE: Self = Self { - mask: EventMask::NONE, + pub const DEFAULT: Self = Self { + mask: EventMask::DEFAULT, inner: None, }; @@ -262,20 +265,22 @@ impl EventSender { /// A new client has been connected. pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult { - Ok(if let Some(client) = &self.inner { + if let Some(client) = &self.inner { match self.mask.connected { ConnectMode::None => {} ConnectMode::Notify => client.notify(Notify(f())).await?, ConnectMode::Request => client.rpc(f()).await??, } - }) + }; + Ok(()) } /// A new client has been connected. pub async fn connection_closed(&self, f: impl Fn() -> ConnectionClosed) -> ClientResult { - Ok(if let Some(client) = &self.inner { + if let Some(client) = &self.inner { client.notify(f()).await?; - }) + }; + Ok(()) } /// Abstract request, to DRY the 3 to 4 request types. @@ -300,58 +305,61 @@ impl EventSender { Notify>: Channels>, { - Ok(self.into_tracker(( - if let Some(client) = &self.inner { - match self.mask.get { - RequestMode::None => RequestUpdates::None, - RequestMode::Notify => { - let msg = RequestReceived { - request: f(), - connection_id, - request_id, - }; - RequestUpdates::Disabled(client.notify_streaming(Notify(msg), 32).await?) - } - RequestMode::Request => { - let msg = RequestReceived { - request: f(), - connection_id, - request_id, - }; - let (tx, rx) = client.client_streaming(msg, 32).await?; - // bail out if the request is not allowed - rx.await??; - RequestUpdates::Disabled(tx) - } - RequestMode::NotifyLog => { - let msg = RequestReceived { - request: f(), - connection_id, - request_id, - }; - RequestUpdates::Active(client.notify_streaming(Notify(msg), 32).await?) - } - RequestMode::RequestLog => { - let msg = RequestReceived { - request: f(), - connection_id, - request_id, - }; - let (tx, rx) = client.client_streaming(msg, 32).await?; - // bail out if the request is not allowed - rx.await??; - RequestUpdates::Active(tx) - } + let client = self.inner.as_ref(); + Ok(self.create_tracker(( + match self.mask.get { + RequestMode::None => RequestUpdates::None, + RequestMode::Notify if client.is_some() => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + RequestUpdates::Disabled( + client.unwrap().notify_streaming(Notify(msg), 32).await?, + ) } - } else { - RequestUpdates::None + RequestMode::Request if client.is_some() => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?; + // bail out if the request is not allowed + rx.await??; + RequestUpdates::Disabled(tx) + } + RequestMode::NotifyLog if client.is_some() => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + RequestUpdates::Active(client.unwrap().notify_streaming(Notify(msg), 32).await?) + } + RequestMode::RequestLog if client.is_some() => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?; + // bail out if the request is not allowed + rx.await??; + RequestUpdates::Active(tx) + } + RequestMode::Disabled => { + return Err(ClientError::Permission); + } + _ => RequestUpdates::None, }, connection_id, request_id, ))) } - fn into_tracker( + fn create_tracker( &self, (updates, connection_id, request_id): (RequestUpdates, u64, u64), ) -> RequestTracker { @@ -372,6 +380,7 @@ pub enum ProviderProto { /// A new client connected to the provider. #[rpc(tx = oneshot::Sender)] ClientConnected(ClientConnected), + /// A new client connected to the provider. Notify variant. #[rpc(tx = NoSender)] ClientConnectedNotify(Notify), diff --git a/src/tests.rs b/src/tests.rs index 911a5cfd2..40d9519c9 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -485,7 +485,7 @@ async fn check_presence(store: &Store, sizes: &[usize]) -> TestResult<()> { } pub async fn node_test_setup_fs(db_path: PathBuf) -> TestResult<(Router, FsStore, PathBuf)> { - node_test_setup_with_events_fs(db_path, EventSender::NONE).await + node_test_setup_with_events_fs(db_path, EventSender::DEFAULT).await } pub async fn node_test_setup_with_events_fs( @@ -500,7 +500,7 @@ pub async fn node_test_setup_with_events_fs( } pub async fn node_test_setup_mem() -> TestResult<(Router, MemStore)> { - node_test_setup_with_events_mem(EventSender::NONE).await + node_test_setup_with_events_mem(EventSender::DEFAULT).await } pub async fn node_test_setup_with_events_mem( From b26aefb4e2926fd980084cc95df56d8c3ff6ed45 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Mon, 1 Sep 2025 18:00:55 +0200 Subject: [PATCH 09/23] Footgun protection --- Cargo.lock | 4 ++ Cargo.toml | 2 +- examples/random_store.rs | 2 +- src/provider/events.rs | 94 ++++++++++++++++++++++++++++++++++++++-- src/tests.rs | 11 ++++- 5 files changed, 106 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1a4de777e..4068354f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1943,6 +1943,8 @@ dependencies = [ [[package]] name = "irpc" version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9f8f1d0987ea9da3d74698f921d0a817a214c83b2635a33ed4bc3efa4de1acd" dependencies = [ "anyhow", "futures-buffered", @@ -1964,6 +1966,8 @@ dependencies = [ [[package]] name = "irpc-derive" version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e0b26b834d401a046dd9d47bc236517c746eddbb5d25ff3e1a6075bfa4eebdb" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 3a642632c..bcd5f42d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ self_cell = "1.1.0" genawaiter = { version = "0.99.1", features = ["futures03"] } iroh-base = "0.91.1" reflink-copy = "0.1.24" -irpc = { version = "0.7.0", features = ["rpc", "quinn_endpoint_setup", "spans", "stream", "derive"], default-features = false, path = "../irpc" } +irpc = { version = "0.7.0", features = ["rpc", "quinn_endpoint_setup", "spans", "stream", "derive"], default-features = false } iroh-metrics = { version = "0.35" } [dev-dependencies] diff --git a/examples/random_store.rs b/examples/random_store.rs index f23e804e1..c4c30348b 100644 --- a/examples/random_store.rs +++ b/examples/random_store.rs @@ -176,7 +176,7 @@ pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, E } } }); - (dump_task, EventSender::new(tx, EventMask::ALL)) + (dump_task, EventSender::new(tx, EventMask::ALL_READONLY)) } #[tokio::main] diff --git a/src/provider/events.rs b/src/provider/events.rs index 1ecf13cb7..b7fc58daa 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -1,4 +1,4 @@ -use std::fmt::Debug; +use std::{fmt::Debug, ops::Deref}; use irpc::{ channel::{mpsc, none::NoSender, oneshot}, @@ -143,12 +143,16 @@ impl EventMask { observe: ObserveMode::None, }; - /// You get asked for every single thing that is going on and can intervene/throttle. - pub const ALL: Self = Self { + /// All event notifications for read-only requests are fully enabled. + /// + /// If you want to enable push requests, which can write to the local store, you + /// need to do it manually. Providing constants that have push enabled would + /// risk misuse. + pub const ALL_READONLY: Self = Self { connected: ConnectMode::Request, get: RequestMode::RequestLog, get_many: RequestMode::RequestLog, - push: RequestMode::RequestLog, + push: RequestMode::Disabled, throttle: ThrottleMode::Throttle, observe: ObserveMode::Request, }; @@ -158,6 +162,14 @@ impl EventMask { #[derive(Debug, Serialize, Deserialize)] pub struct Notify(T); +impl Deref for Notify { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + #[derive(Debug, Default, Clone)] pub struct EventSender { mask: EventMask, @@ -263,6 +275,80 @@ impl EventSender { } } + /// Log request events at trace level. + pub fn tracing(&self, mask: EventMask) -> Self { + use tracing::trace; + let (tx, mut rx) = tokio::sync::mpsc::channel(32); + n0_future::task::spawn(async move { + fn log_request_events( + mut rx: irpc::channel::mpsc::Receiver, + connection_id: u64, + request_id: u64, + ) { + n0_future::task::spawn(async move { + while let Ok(Some(update)) = rx.recv().await { + trace!(%connection_id, %request_id, "{update:?}"); + } + }); + } + while let Some(msg) = rx.recv().await { + match msg { + ProviderMessage::ClientConnected(_) => todo!(), + ProviderMessage::ClientConnectedNotify(msg) => { + trace!("{:?}", msg.inner); + } + ProviderMessage::ConnectionClosed(msg) => { + trace!("{:?}", msg.inner); + } + ProviderMessage::GetRequestReceived(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::GetRequestReceivedNotify(msg) => { + trace!("{:?}", msg.inner); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::GetManyRequestReceived(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::GetManyRequestReceivedNotify(msg) => { + trace!("{:?}", msg.inner); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::PushRequestReceived(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::PushRequestReceivedNotify(msg) => { + trace!("{:?}", msg.inner); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::ObserveRequestReceived(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::ObserveRequestReceivedNotify(msg) => { + trace!("{:?}", msg.inner); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::Throttle(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + } + } + } + }); + Self { + mask, + inner: Some(irpc::Client::from(tx)), + } + } + /// A new client has been connected. pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult { if let Some(client) = &self.inner { diff --git a/src/tests.rs b/src/tests.rs index 40d9519c9..dc38eb436 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -370,7 +370,16 @@ fn event_handler( } } })); - (EventSender::new(events_tx, EventMask::ALL), count_rx, task) + ( + EventSender::new( + events_tx, + EventMask { + ..EventMask::ALL_READONLY + }, + ), + count_rx, + task, + ) } async fn two_nodes_push_blobs( From 6d86e4f2403d73c46e1cf66909d5c38297cb3e36 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 2 Sep 2025 12:51:02 +0200 Subject: [PATCH 10/23] Add limit example This shows how to limit serving content in various ways - by node id - by content hash - throttling - limiting max number of connections --- examples/limit.rs | 341 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 341 insertions(+) create mode 100644 examples/limit.rs diff --git a/examples/limit.rs b/examples/limit.rs new file mode 100644 index 000000000..f23fad96c --- /dev/null +++ b/examples/limit.rs @@ -0,0 +1,341 @@ +/// Example how to limit blob requests by hash and node id, and to add +/// restrictions on limited content. +mod common; +use std::{ + collections::{HashMap, HashSet}, + path::PathBuf, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + +use clap::Parser; +use common::setup_logging; +use iroh::{NodeId, SecretKey, Watcher}; +use iroh_blobs::{ + provider::events::{ + AbortReason, ConnectMode, EventMask, EventSender, ProviderMessage, RequestMode, + ThrottleMode, + }, + store::mem::MemStore, + ticket::BlobTicket, + BlobsProtocol, Hash, +}; +use rand::thread_rng; + +use crate::common::get_or_generate_secret_key; + +#[derive(Debug, Parser)] +#[command(version, about)] +pub enum Args { + ByNodeId { + /// Path for files to add + paths: Vec, + #[clap(long("allow"))] + /// Nodes that are allowed to download content. + allowed_nodes: Vec, + #[clap(long, default_value_t = 1)] + secrets: usize, + }, + ByHash { + /// Path for files to add + paths: Vec, + }, + Throttle { + /// Path for files to add + paths: Vec, + #[clap(long, default_value = "100")] + delay_ms: u64, + }, + MaxConnections { + /// Path for files to add + paths: Vec, + #[clap(long, default_value = "1")] + max_connections: usize, + }, + Get { + /// Ticket for the blob to download + ticket: BlobTicket, + }, +} + +fn limit_by_node_id(allowed_nodes: HashSet) -> EventSender { + let (tx, mut rx) = tokio::sync::mpsc::channel(32); + n0_future::task::spawn(async move { + while let Some(msg) = rx.recv().await { + match msg { + ProviderMessage::ClientConnected(msg) => { + let node_id = msg.node_id; + let res = if allowed_nodes.contains(&node_id) { + println!("Client connected: {node_id}"); + Ok(()) + } else { + println!("Client rejected: {node_id}"); + Err(AbortReason::Permission) + }; + msg.tx.send(res).await.ok(); + } + _ => {} + } + } + }); + EventSender::new( + tx, + EventMask { + connected: ConnectMode::Request, + ..EventMask::DEFAULT + }, + ) +} + +fn limit_by_hash(allowed_hashes: HashSet) -> EventSender { + let (tx, mut rx) = tokio::sync::mpsc::channel(32); + n0_future::task::spawn(async move { + while let Some(msg) = rx.recv().await { + match msg { + ProviderMessage::GetRequestReceived(msg) => { + let res = if !msg.request.ranges.is_blob() { + println!("HashSeq request not allowed"); + Err(AbortReason::Permission) + } else if !allowed_hashes.contains(&msg.request.hash) { + println!("Request for hash {} not allowed", msg.request.hash); + Err(AbortReason::Permission) + } else { + println!("Request for hash {} allowed", msg.request.hash); + Ok(()) + }; + msg.tx.send(res).await.ok(); + } + _ => {} + } + } + }); + EventSender::new( + tx, + EventMask { + get: RequestMode::Request, + ..EventMask::DEFAULT + }, + ) +} + +fn throttle(delay_ms: u64) -> EventSender { + let (tx, mut rx) = tokio::sync::mpsc::channel(32); + n0_future::task::spawn(async move { + while let Some(msg) = rx.recv().await { + match msg { + ProviderMessage::Throttle(msg) => { + n0_future::task::spawn(async move { + println!( + "Throttling {} {}, {}ms", + msg.connection_id, msg.request_id, delay_ms + ); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + msg.tx.send(Ok(())).await.ok(); + }); + } + _ => {} + } + } + }); + EventSender::new( + tx, + EventMask { + throttle: ThrottleMode::Throttle, + ..EventMask::DEFAULT + }, + ) +} + +fn limit_max_connections(max_connections: usize) -> EventSender { + let (tx, mut rx) = tokio::sync::mpsc::channel(32); + n0_future::task::spawn(async move { + let requests = Arc::new(AtomicUsize::new(0)); + while let Some(msg) = rx.recv().await { + match msg { + ProviderMessage::GetRequestReceived(mut msg) => { + let connection_id = msg.connection_id; + let request_id = msg.request_id; + let res = requests.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| { + if n >= max_connections { + None + } else { + Some(n + 1) + } + }); + match res { + Ok(n) => { + println!("Accepting request {n}, id ({connection_id},{request_id})"); + msg.tx.send(Ok(())).await.ok(); + } + Err(_) => { + println!( + "Connection limit of {} exceeded, rejecting request", + max_connections + ); + msg.tx.send(Err(AbortReason::RateLimited)).await.ok(); + continue; + } + } + let requests = requests.clone(); + n0_future::task::spawn(async move { + // just drain the per request events + while let Ok(Some(_)) = msg.rx.recv().await {} + println!("Stopping request, id ({connection_id},{request_id})"); + requests.fetch_sub(1, Ordering::SeqCst); + }); + } + _ => {} + } + } + }); + EventSender::new( + tx, + EventMask { + get: RequestMode::RequestLog, + ..EventMask::DEFAULT + }, + ) +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + setup_logging(); + let args = Args::parse(); + match args { + Args::Get { ticket } => { + let secret = get_or_generate_secret_key()?; + let endpoint = iroh::Endpoint::builder() + .secret_key(secret) + .discovery_n0() + .bind() + .await?; + let connection = endpoint + .connect(ticket.node_addr().clone(), iroh_blobs::ALPN) + .await?; + let (data, stats) = iroh_blobs::get::request::get_blob(connection, ticket.hash()) + .bytes_and_stats() + .await?; + println!("Downloaded {} bytes", data.len()); + println!("Stats: {:?}", stats); + } + Args::ByNodeId { + paths, + allowed_nodes, + secrets, + } => { + let mut allowed_nodes = allowed_nodes.into_iter().collect::>(); + if secrets > 0 { + println!("Generating {secrets} new secret keys for allowed nodes:"); + let mut rand = thread_rng(); + for _ in 0..secrets { + let secret = SecretKey::generate(&mut rand); + let public = secret.public(); + allowed_nodes.insert(public); + println!("IROH_SECRET={}", hex::encode(secret.to_bytes())); + } + } + let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?; + let store = MemStore::new(); + let mut hashes = HashMap::new(); + for path in paths { + let tag = store.add_path(&path).await?; + hashes.insert(path, tag.hash); + } + let _ = endpoint.home_relay().initialized().await; + let addr = endpoint.node_addr().initialized().await; + let events = limit_by_node_id(allowed_nodes.clone()); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); + let router = iroh::protocol::Router::builder(endpoint) + .accept(iroh_blobs::ALPN, blobs) + .spawn(); + println!("Node id: {}\n", router.endpoint().node_id()); + for id in &allowed_nodes { + println!("Allowed node: {id}"); + } + println!(); + for (path, hash) in &hashes { + let ticket = BlobTicket::new(addr.clone(), *hash, iroh_blobs::BlobFormat::Raw); + println!("{}: {ticket}", path.display()); + } + tokio::signal::ctrl_c().await?; + router.shutdown().await?; + } + Args::ByHash { paths } => { + let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?; + let store = MemStore::new(); + let mut hashes = HashMap::new(); + let mut allowed_hashes = HashSet::new(); + for (i, path) in paths.into_iter().enumerate() { + let tag = store.add_path(&path).await?; + hashes.insert(path, tag.hash); + if i == 0 { + allowed_hashes.insert(tag.hash); + } + } + let _ = endpoint.home_relay().initialized().await; + let addr = endpoint.node_addr().initialized().await; + let events = limit_by_hash(allowed_hashes.clone()); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); + let router = iroh::protocol::Router::builder(endpoint) + .accept(iroh_blobs::ALPN, blobs) + .spawn(); + for (i, (path, hash)) in hashes.iter().enumerate() { + let ticket = BlobTicket::new(addr.clone(), *hash, iroh_blobs::BlobFormat::Raw); + let permitted = if i == 0 { "" } else { "limited" }; + println!("{}: {ticket} ({permitted})", path.display()); + } + tokio::signal::ctrl_c().await?; + router.shutdown().await?; + } + Args::Throttle { paths, delay_ms } => { + let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?; + let store = MemStore::new(); + let mut hashes = HashMap::new(); + for path in paths { + let tag = store.add_path(&path).await?; + hashes.insert(path, tag.hash); + } + let _ = endpoint.home_relay().initialized().await; + let addr = endpoint.node_addr().initialized().await; + let events = throttle(delay_ms); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); + let router = iroh::protocol::Router::builder(endpoint) + .accept(iroh_blobs::ALPN, blobs) + .spawn(); + for (path, hash) in hashes { + let ticket = BlobTicket::new(addr.clone(), hash, iroh_blobs::BlobFormat::Raw); + println!("{}: {ticket}", path.display()); + } + tokio::signal::ctrl_c().await?; + router.shutdown().await?; + } + Args::MaxConnections { + paths, + max_connections, + } => { + let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?; + let store = MemStore::new(); + let mut hashes = HashMap::new(); + for path in paths { + let tag = store.add_path(&path).await?; + hashes.insert(path, tag.hash); + } + let _ = endpoint.home_relay().initialized().await; + let addr = endpoint.node_addr().initialized().await; + let events = limit_max_connections(max_connections); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); + let router = iroh::protocol::Router::builder(endpoint) + .accept(iroh_blobs::ALPN, blobs) + .spawn(); + for (path, hash) in hashes { + let ticket = BlobTicket::new(addr.clone(), hash, iroh_blobs::BlobFormat::Raw); + println!("{}: {ticket}", path.display()); + } + tokio::signal::ctrl_c().await?; + router.shutdown().await?; + } + } + Ok(()) +} From 4b87b6dcbc03540a3f36f93c0d8930386729b37b Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 2 Sep 2025 14:18:16 +0200 Subject: [PATCH 11/23] Add len to notify_payload_write --- src/provider.rs | 7 ++++--- src/provider/events.rs | 5 ++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/provider.rs b/src/provider.rs index 1683daa57..883f97811 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -234,9 +234,10 @@ impl WriterContext { impl WriteProgress for WriterContext { async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) { - let end_offset = offset + len as u64; - self.payload_bytes_written += len as u64; - self.tracker.transfer_progress(end_offset).await.ok(); + let len = len as u64; + let end_offset = offset + len; + self.payload_bytes_written += len; + self.tracker.transfer_progress(len, end_offset).await.ok(); } fn log_other_write(&mut self, len: usize) { diff --git a/src/provider/events.rs b/src/provider/events.rs index b7fc58daa..35b641011 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -222,7 +222,7 @@ impl RequestTracker { } /// Transfer progress for the previously reported blob, end_offset is the new end offset in bytes. - pub async fn transfer_progress(&mut self, end_offset: u64) -> ClientResult { + pub async fn transfer_progress(&mut self, len: u64, end_offset: u64) -> ClientResult { if let RequestUpdates::Active(tx) = &mut self.updates { tx.try_send(RequestUpdate::Progress(TransferProgress { end_offset })) .await?; @@ -232,6 +232,7 @@ impl RequestTracker { .rpc(Throttle { connection_id: *connection_id, request_id: *request_id, + size: len as u64, }) .await??; } @@ -546,6 +547,8 @@ mod proto { pub connection_id: u64, /// The request id. There is a new id for each request. pub request_id: u64, + /// Size of the chunk to be throttled. This will usually be 16 KiB. + pub size: u64, } #[derive(Debug, Serialize, Deserialize)] From f992a448a55aa84da6a9d5e1e5f9f203aad266d6 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 2 Sep 2025 14:18:54 +0200 Subject: [PATCH 12/23] clippy --- examples/limit.rs | 133 +++++++++++++++++++---------------------- src/provider/events.rs | 2 +- 2 files changed, 61 insertions(+), 74 deletions(-) diff --git a/examples/limit.rs b/examples/limit.rs index f23fad96c..09e1be132 100644 --- a/examples/limit.rs +++ b/examples/limit.rs @@ -64,19 +64,16 @@ fn limit_by_node_id(allowed_nodes: HashSet) -> EventSender { let (tx, mut rx) = tokio::sync::mpsc::channel(32); n0_future::task::spawn(async move { while let Some(msg) = rx.recv().await { - match msg { - ProviderMessage::ClientConnected(msg) => { - let node_id = msg.node_id; - let res = if allowed_nodes.contains(&node_id) { - println!("Client connected: {node_id}"); - Ok(()) - } else { - println!("Client rejected: {node_id}"); - Err(AbortReason::Permission) - }; - msg.tx.send(res).await.ok(); - } - _ => {} + if let ProviderMessage::ClientConnected(msg) = msg { + let node_id = msg.node_id; + let res = if allowed_nodes.contains(&node_id) { + println!("Client connected: {node_id}"); + Ok(()) + } else { + println!("Client rejected: {node_id}"); + Err(AbortReason::Permission) + }; + msg.tx.send(res).await.ok(); } } }); @@ -93,21 +90,18 @@ fn limit_by_hash(allowed_hashes: HashSet) -> EventSender { let (tx, mut rx) = tokio::sync::mpsc::channel(32); n0_future::task::spawn(async move { while let Some(msg) = rx.recv().await { - match msg { - ProviderMessage::GetRequestReceived(msg) => { - let res = if !msg.request.ranges.is_blob() { - println!("HashSeq request not allowed"); - Err(AbortReason::Permission) - } else if !allowed_hashes.contains(&msg.request.hash) { - println!("Request for hash {} not allowed", msg.request.hash); - Err(AbortReason::Permission) - } else { - println!("Request for hash {} allowed", msg.request.hash); - Ok(()) - }; - msg.tx.send(res).await.ok(); - } - _ => {} + if let ProviderMessage::GetRequestReceived(msg) = msg { + let res = if !msg.request.ranges.is_blob() { + println!("HashSeq request not allowed"); + Err(AbortReason::Permission) + } else if !allowed_hashes.contains(&msg.request.hash) { + println!("Request for hash {} not allowed", msg.request.hash); + Err(AbortReason::Permission) + } else { + println!("Request for hash {} allowed", msg.request.hash); + Ok(()) + }; + msg.tx.send(res).await.ok(); } } }); @@ -124,18 +118,15 @@ fn throttle(delay_ms: u64) -> EventSender { let (tx, mut rx) = tokio::sync::mpsc::channel(32); n0_future::task::spawn(async move { while let Some(msg) = rx.recv().await { - match msg { - ProviderMessage::Throttle(msg) => { - n0_future::task::spawn(async move { - println!( - "Throttling {} {}, {}ms", - msg.connection_id, msg.request_id, delay_ms - ); - tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; - msg.tx.send(Ok(())).await.ok(); - }); - } - _ => {} + if let ProviderMessage::Throttle(msg) = msg { + n0_future::task::spawn(async move { + println!( + "Throttling {} {}, {}ms", + msg.connection_id, msg.request_id, delay_ms + ); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + msg.tx.send(Ok(())).await.ok(); + }); } } }); @@ -153,40 +144,36 @@ fn limit_max_connections(max_connections: usize) -> EventSender { n0_future::task::spawn(async move { let requests = Arc::new(AtomicUsize::new(0)); while let Some(msg) = rx.recv().await { - match msg { - ProviderMessage::GetRequestReceived(mut msg) => { - let connection_id = msg.connection_id; - let request_id = msg.request_id; - let res = requests.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| { - if n >= max_connections { - None - } else { - Some(n + 1) - } - }); - match res { - Ok(n) => { - println!("Accepting request {n}, id ({connection_id},{request_id})"); - msg.tx.send(Ok(())).await.ok(); - } - Err(_) => { - println!( - "Connection limit of {} exceeded, rejecting request", - max_connections - ); - msg.tx.send(Err(AbortReason::RateLimited)).await.ok(); - continue; - } + if let ProviderMessage::GetRequestReceived(mut msg) = msg { + let connection_id = msg.connection_id; + let request_id = msg.request_id; + let res = requests.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| { + if n >= max_connections { + None + } else { + Some(n + 1) + } + }); + match res { + Ok(n) => { + println!("Accepting request {n}, id ({connection_id},{request_id})"); + msg.tx.send(Ok(())).await.ok(); + } + Err(_) => { + println!( + "Connection limit of {max_connections} exceeded, rejecting request" + ); + msg.tx.send(Err(AbortReason::RateLimited)).await.ok(); + continue; } - let requests = requests.clone(); - n0_future::task::spawn(async move { - // just drain the per request events - while let Ok(Some(_)) = msg.rx.recv().await {} - println!("Stopping request, id ({connection_id},{request_id})"); - requests.fetch_sub(1, Ordering::SeqCst); - }); } - _ => {} + let requests = requests.clone(); + n0_future::task::spawn(async move { + // just drain the per request events + while let Ok(Some(_)) = msg.rx.recv().await {} + println!("Stopping request, id ({connection_id},{request_id})"); + requests.fetch_sub(1, Ordering::SeqCst); + }); } } }); @@ -218,7 +205,7 @@ async fn main() -> anyhow::Result<()> { .bytes_and_stats() .await?; println!("Downloaded {} bytes", data.len()); - println!("Stats: {:?}", stats); + println!("Stats: {stats:?}"); } Args::ByNodeId { paths, diff --git a/src/provider/events.rs b/src/provider/events.rs index 35b641011..5e5972167 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -232,7 +232,7 @@ impl RequestTracker { .rpc(Throttle { connection_id: *connection_id, request_id: *request_id, - size: len as u64, + size: len, }) .await??; } From 4bddf77939f672d841f09fc2c522176c1e94a775 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 2 Sep 2025 14:27:31 +0200 Subject: [PATCH 13/23] nicer connection counter --- examples/limit.rs | 44 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/examples/limit.rs b/examples/limit.rs index 09e1be132..e7b86a8ca 100644 --- a/examples/limit.rs +++ b/examples/limit.rs @@ -29,6 +29,7 @@ use crate::common::get_or_generate_secret_key; #[derive(Debug, Parser)] #[command(version, about)] pub enum Args { + /// Limit requests by node id ByNodeId { /// Path for files to add paths: Vec, @@ -38,16 +39,19 @@ pub enum Args { #[clap(long, default_value_t = 1)] secrets: usize, }, + /// Limit requests by hash, only first hash is allowed ByHash { /// Path for files to add paths: Vec, }, + /// Throttle requests Throttle { /// Path for files to add paths: Vec, #[clap(long, default_value = "100")] delay_ms: u64, }, + /// Limit maximum number of connections. MaxConnections { /// Path for files to add paths: Vec, @@ -140,20 +144,39 @@ fn throttle(delay_ms: u64) -> EventSender { } fn limit_max_connections(max_connections: usize) -> EventSender { + #[derive(Default, Debug, Clone)] + struct ConnectionCounter(Arc<(AtomicUsize, usize)>); + + impl ConnectionCounter { + fn new(max: usize) -> Self { + Self(Arc::new((Default::default(), max))) + } + + fn inc(&self) -> Result { + let (c, max) = &*self.0; + c.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| { + if n >= *max { + None + } else { + Some(n + 1) + } + }) + } + + fn dec(&self) { + let (c, _) = &*self.0; + c.fetch_sub(1, Ordering::SeqCst); + } + } + let (tx, mut rx) = tokio::sync::mpsc::channel(32); n0_future::task::spawn(async move { - let requests = Arc::new(AtomicUsize::new(0)); + let requests = ConnectionCounter::new(max_connections); while let Some(msg) = rx.recv().await { if let ProviderMessage::GetRequestReceived(mut msg) = msg { let connection_id = msg.connection_id; let request_id = msg.request_id; - let res = requests.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| { - if n >= max_connections { - None - } else { - Some(n + 1) - } - }); + let res = requests.inc(); match res { Ok(n) => { println!("Accepting request {n}, id ({connection_id},{request_id})"); @@ -170,9 +193,12 @@ fn limit_max_connections(max_connections: usize) -> EventSender { let requests = requests.clone(); n0_future::task::spawn(async move { // just drain the per request events + // + // Note that we have requested updates for the request, now we also need to process them + // otherwise the request will be aborted! while let Ok(Some(_)) = msg.rx.recv().await {} println!("Stopping request, id ({connection_id},{request_id})"); - requests.fetch_sub(1, Ordering::SeqCst); + requests.dec(); }); } } From 33333a9afc5659aa1c341474ee40a9952f0e49da Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 2 Sep 2025 14:38:36 +0200 Subject: [PATCH 14/23] Add docs for the limit example. --- examples/limit.rs | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/examples/limit.rs b/examples/limit.rs index e7b86a8ca..fff910da3 100644 --- a/examples/limit.rs +++ b/examples/limit.rs @@ -1,5 +1,13 @@ /// Example how to limit blob requests by hash and node id, and to add -/// restrictions on limited content. +/// throttling or limiting the maximum number of connections. +/// +/// Limiting is done via a fn that returns an EventSender and internally +/// makes liberal use of spawn to spawn background tasks. +/// +/// This is fine, since the tasks will terminate as soon as the [BlobsProtocol] +/// instance holding the [EventSender] will be dropped. But for production +/// grade code you might nevertheless put the tasks into a [tokio::task::JoinSet] or +/// [n0_future::FuturesUnordered]. mod common; use std::{ collections::{HashMap, HashSet}, @@ -31,33 +39,37 @@ use crate::common::get_or_generate_secret_key; pub enum Args { /// Limit requests by node id ByNodeId { - /// Path for files to add + /// Path for files to add. paths: Vec, #[clap(long("allow"))] /// Nodes that are allowed to download content. allowed_nodes: Vec, + /// Number of secrets to generate for allowed node ids. #[clap(long, default_value_t = 1)] secrets: usize, }, /// Limit requests by hash, only first hash is allowed ByHash { - /// Path for files to add + /// Path for files to add. paths: Vec, }, /// Throttle requests Throttle { - /// Path for files to add + /// Path for files to add. paths: Vec, + /// Delay in milliseconds after sending a chunk group of 16 KiB. #[clap(long, default_value = "100")] delay_ms: u64, }, /// Limit maximum number of connections. MaxConnections { - /// Path for files to add + /// Path for files to add. paths: Vec, + /// Maximum number of concurrent get requests. #[clap(long, default_value = "1")] max_connections: usize, }, + /// Get a blob. Just for completeness sake. Get { /// Ticket for the blob to download ticket: BlobTicket, @@ -84,6 +96,8 @@ fn limit_by_node_id(allowed_nodes: HashSet) -> EventSender { EventSender::new( tx, EventMask { + // We want a request for each incoming connection so we can accept + // or reject them. We don't need any other events. connected: ConnectMode::Request, ..EventMask::DEFAULT }, @@ -112,6 +126,9 @@ fn limit_by_hash(allowed_hashes: HashSet) -> EventSender { EventSender::new( tx, EventMask { + // We want to get a request for each get request that we can answer + // with OK or not OK depending on the hash. We do not want detailed + // events once it has been decided to handle a request. get: RequestMode::Request, ..EventMask::DEFAULT }, @@ -128,6 +145,8 @@ fn throttle(delay_ms: u64) -> EventSender { "Throttling {} {}, {}ms", msg.connection_id, msg.request_id, delay_ms ); + // we could compute the delay from the size of the data to have a fixed rate. + // but the size is almost always 16 KiB (16 chunks). tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; msg.tx.send(Ok(())).await.ok(); }); @@ -137,6 +156,8 @@ fn throttle(delay_ms: u64) -> EventSender { EventSender::new( tx, EventMask { + // We want to get requests for each sent user data blob, so we can add a delay. + // Other than that, we don't need any events. throttle: ThrottleMode::Throttle, ..EventMask::DEFAULT }, @@ -206,6 +227,10 @@ fn limit_max_connections(max_connections: usize) -> EventSender { EventSender::new( tx, EventMask { + // For each get request, we want to get a request so we can decide + // based on the current connection count if we want to accept or reject. + // We also want detailed logging of events for the get request, so we can + // detect when the request is finished one way or another. get: RequestMode::RequestLog, ..EventMask::DEFAULT }, From 9a62a581b1cc903f5d8c5dcf8fd947a3f6cbd980 Mon Sep 17 00:00:00 2001 From: Frando Date: Wed, 3 Sep 2025 13:43:44 +0200 Subject: [PATCH 15/23] refactor: make limits example more DRY --- examples/limit.rs | 184 +++++++++++++++++++---------------------- src/provider/events.rs | 8 ++ 2 files changed, 94 insertions(+), 98 deletions(-) diff --git a/examples/limit.rs b/examples/limit.rs index fff910da3..2be92358a 100644 --- a/examples/limit.rs +++ b/examples/limit.rs @@ -18,9 +18,10 @@ use std::{ }, }; +use anyhow::Result; use clap::Parser; use common::setup_logging; -use iroh::{NodeId, SecretKey, Watcher}; +use iroh::{protocol::Router, NodeAddr, NodeId, SecretKey, Watcher}; use iroh_blobs::{ provider::events::{ AbortReason, ConnectMode, EventMask, EventSender, ProviderMessage, RequestMode, @@ -28,7 +29,7 @@ use iroh_blobs::{ }, store::mem::MemStore, ticket::BlobTicket, - BlobsProtocol, Hash, + BlobFormat, BlobsProtocol, Hash, }; use rand::thread_rng; @@ -77,7 +78,13 @@ pub enum Args { } fn limit_by_node_id(allowed_nodes: HashSet) -> EventSender { - let (tx, mut rx) = tokio::sync::mpsc::channel(32); + let mask = EventMask { + // We want a request for each incoming connection so we can accept + // or reject them. We don't need any other events. + connected: ConnectMode::Request, + ..EventMask::DEFAULT + }; + let (tx, mut rx) = EventSender::channel(32, mask); n0_future::task::spawn(async move { while let Some(msg) = rx.recv().await { if let ProviderMessage::ClientConnected(msg) = msg { @@ -93,19 +100,18 @@ fn limit_by_node_id(allowed_nodes: HashSet) -> EventSender { } } }); - EventSender::new( - tx, - EventMask { - // We want a request for each incoming connection so we can accept - // or reject them. We don't need any other events. - connected: ConnectMode::Request, - ..EventMask::DEFAULT - }, - ) + tx } fn limit_by_hash(allowed_hashes: HashSet) -> EventSender { - let (tx, mut rx) = tokio::sync::mpsc::channel(32); + let mask = EventMask { + // We want to get a request for each get request that we can answer + // with OK or not OK depending on the hash. We do not want detailed + // events once it has been decided to handle a request. + get: RequestMode::Request, + ..EventMask::DEFAULT + }; + let (tx, mut rx) = EventSender::channel(32, mask); n0_future::task::spawn(async move { while let Some(msg) = rx.recv().await { if let ProviderMessage::GetRequestReceived(msg) = msg { @@ -123,20 +129,17 @@ fn limit_by_hash(allowed_hashes: HashSet) -> EventSender { } } }); - EventSender::new( - tx, - EventMask { - // We want to get a request for each get request that we can answer - // with OK or not OK depending on the hash. We do not want detailed - // events once it has been decided to handle a request. - get: RequestMode::Request, - ..EventMask::DEFAULT - }, - ) + tx } fn throttle(delay_ms: u64) -> EventSender { - let (tx, mut rx) = tokio::sync::mpsc::channel(32); + let mask = EventMask { + // We want to get requests for each sent user data blob, so we can add a delay. + // Other than that, we don't need any events. + throttle: ThrottleMode::Throttle, + ..EventMask::DEFAULT + }; + let (tx, mut rx) = EventSender::channel(32, mask); n0_future::task::spawn(async move { while let Some(msg) = rx.recv().await { if let ProviderMessage::Throttle(msg) = msg { @@ -153,15 +156,7 @@ fn throttle(delay_ms: u64) -> EventSender { } } }); - EventSender::new( - tx, - EventMask { - // We want to get requests for each sent user data blob, so we can add a delay. - // Other than that, we don't need any events. - throttle: ThrottleMode::Throttle, - ..EventMask::DEFAULT - }, - ) + tx } fn limit_max_connections(max_connections: usize) -> EventSender { @@ -190,7 +185,15 @@ fn limit_max_connections(max_connections: usize) -> EventSender { } } - let (tx, mut rx) = tokio::sync::mpsc::channel(32); + let mask = EventMask { + // For each get request, we want to get a request so we can decide + // based on the current connection count if we want to accept or reject. + // We also want detailed logging of events for the get request, so we can + // detect when the request is finished one way or another. + get: RequestMode::RequestLog, + ..EventMask::DEFAULT + }; + let (tx, mut rx) = EventSender::channel(32, mask); n0_future::task::spawn(async move { let requests = ConnectionCounter::new(max_connections); while let Some(msg) = rx.recv().await { @@ -224,21 +227,11 @@ fn limit_max_connections(max_connections: usize) -> EventSender { } } }); - EventSender::new( - tx, - EventMask { - // For each get request, we want to get a request so we can decide - // based on the current connection count if we want to accept or reject. - // We also want detailed logging of events for the get request, so we can - // detect when the request is finished one way or another. - get: RequestMode::RequestLog, - ..EventMask::DEFAULT - }, - ) + tx } #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> Result<()> { setup_logging(); let args = Args::parse(); match args { @@ -274,35 +267,28 @@ async fn main() -> anyhow::Result<()> { println!("IROH_SECRET={}", hex::encode(secret.to_bytes())); } } - let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?; + let store = MemStore::new(); - let mut hashes = HashMap::new(); - for path in paths { - let tag = store.add_path(&path).await?; - hashes.insert(path, tag.hash); - } - let _ = endpoint.home_relay().initialized().await; - let addr = endpoint.node_addr().initialized().await; + let hashes = add_paths(&store, paths).await?; let events = limit_by_node_id(allowed_nodes.clone()); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); - let router = iroh::protocol::Router::builder(endpoint) - .accept(iroh_blobs::ALPN, blobs) - .spawn(); + let (router, addr) = setup(MemStore::new(), events).await?; + + for (path, hash) in hashes { + let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); + println!("{}: {ticket}", path.display()); + } + println!(); println!("Node id: {}\n", router.endpoint().node_id()); for id in &allowed_nodes { println!("Allowed node: {id}"); } - println!(); - for (path, hash) in &hashes { - let ticket = BlobTicket::new(addr.clone(), *hash, iroh_blobs::BlobFormat::Raw); - println!("{}: {ticket}", path.display()); - } + tokio::signal::ctrl_c().await?; router.shutdown().await?; } Args::ByHash { paths } => { - let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?; let store = MemStore::new(); + let mut hashes = HashMap::new(); let mut allowed_hashes = HashSet::new(); for (i, path) in paths.into_iter().enumerate() { @@ -312,15 +298,12 @@ async fn main() -> anyhow::Result<()> { allowed_hashes.insert(tag.hash); } } - let _ = endpoint.home_relay().initialized().await; - let addr = endpoint.node_addr().initialized().await; - let events = limit_by_hash(allowed_hashes.clone()); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); - let router = iroh::protocol::Router::builder(endpoint) - .accept(iroh_blobs::ALPN, blobs) - .spawn(); + + let events = limit_by_hash(allowed_hashes); + let (router, addr) = setup(MemStore::new(), events).await?; + for (i, (path, hash)) in hashes.iter().enumerate() { - let ticket = BlobTicket::new(addr.clone(), *hash, iroh_blobs::BlobFormat::Raw); + let ticket = BlobTicket::new(addr.clone(), *hash, BlobFormat::Raw); let permitted = if i == 0 { "" } else { "limited" }; println!("{}: {ticket} ({permitted})", path.display()); } @@ -328,22 +311,12 @@ async fn main() -> anyhow::Result<()> { router.shutdown().await?; } Args::Throttle { paths, delay_ms } => { - let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?; let store = MemStore::new(); - let mut hashes = HashMap::new(); - for path in paths { - let tag = store.add_path(&path).await?; - hashes.insert(path, tag.hash); - } - let _ = endpoint.home_relay().initialized().await; - let addr = endpoint.node_addr().initialized().await; + let hashes = add_paths(&store, paths).await?; let events = throttle(delay_ms); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); - let router = iroh::protocol::Router::builder(endpoint) - .accept(iroh_blobs::ALPN, blobs) - .spawn(); + let (router, addr) = setup(MemStore::new(), events).await?; for (path, hash) in hashes { - let ticket = BlobTicket::new(addr.clone(), hash, iroh_blobs::BlobFormat::Raw); + let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); println!("{}: {ticket}", path.display()); } tokio::signal::ctrl_c().await?; @@ -353,22 +326,12 @@ async fn main() -> anyhow::Result<()> { paths, max_connections, } => { - let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?; let store = MemStore::new(); - let mut hashes = HashMap::new(); - for path in paths { - let tag = store.add_path(&path).await?; - hashes.insert(path, tag.hash); - } - let _ = endpoint.home_relay().initialized().await; - let addr = endpoint.node_addr().initialized().await; + let hashes = add_paths(&store, paths).await?; let events = limit_max_connections(max_connections); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); - let router = iroh::protocol::Router::builder(endpoint) - .accept(iroh_blobs::ALPN, blobs) - .spawn(); + let (router, addr) = setup(MemStore::new(), events).await?; for (path, hash) in hashes { - let ticket = BlobTicket::new(addr.clone(), hash, iroh_blobs::BlobFormat::Raw); + let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); println!("{}: {ticket}", path.display()); } tokio::signal::ctrl_c().await?; @@ -377,3 +340,28 @@ async fn main() -> anyhow::Result<()> { } Ok(()) } + +async fn add_paths(store: &MemStore, paths: Vec) -> Result> { + let mut hashes = HashMap::new(); + for path in paths { + let tag = store.add_path(&path).await?; + hashes.insert(path, tag.hash); + } + Ok(hashes) +} + +async fn setup(store: MemStore, events: EventSender) -> Result<(Router, NodeAddr)> { + let secret = get_or_generate_secret_key()?; + let endpoint = iroh::Endpoint::builder() + .discovery_n0() + .secret_key(secret) + .bind() + .await?; + let _ = endpoint.home_relay().initialized().await; + let addr = endpoint.node_addr().initialized().await; + let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); + let router = Router::builder(endpoint) + .accept(iroh_blobs::ALPN, blobs) + .spawn(); + Ok((router, addr)) +} diff --git a/src/provider/events.rs b/src/provider/events.rs index 5e5972167..f2bddb23c 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -276,6 +276,14 @@ impl EventSender { } } + pub fn channel( + capacity: usize, + mask: EventMask, + ) -> (Self, tokio::sync::mpsc::Receiver) { + let (tx, rx) = tokio::sync::mpsc::channel(capacity); + (Self::new(tx, mask), rx) + } + /// Log request events at trace level. pub fn tracing(&self, mask: EventMask) -> Self { use tracing::trace; From 071db5e0a6c69edcac2a3d42de7103cfabcd866c Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 3 Sep 2025 15:05:11 +0200 Subject: [PATCH 16/23] Make sure to send a proper reset code when resetting a connection so the other side can know if reconnecting is OK --- examples/limit.rs | 12 +++--- src/protocol.rs | 7 +++ src/provider.rs | 97 +++++++++++++++++++++++++++--------------- src/provider/events.rs | 47 ++++++++++++++------ 4 files changed, 110 insertions(+), 53 deletions(-) diff --git a/examples/limit.rs b/examples/limit.rs index 2be92358a..830574fcc 100644 --- a/examples/limit.rs +++ b/examples/limit.rs @@ -234,14 +234,14 @@ fn limit_max_connections(max_connections: usize) -> EventSender { async fn main() -> Result<()> { setup_logging(); let args = Args::parse(); + let secret = get_or_generate_secret_key()?; + let endpoint = iroh::Endpoint::builder() + .secret_key(secret) + .discovery_n0() + .bind() + .await?; match args { Args::Get { ticket } => { - let secret = get_or_generate_secret_key()?; - let endpoint = iroh::Endpoint::builder() - .secret_key(secret) - .discovery_n0() - .bind() - .await?; let connection = endpoint .connect(ticket.node_addr().clone(), iroh_blobs::ALPN) .await?; diff --git a/src/protocol.rs b/src/protocol.rs index 05ee00678..ce10865a5 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -397,6 +397,13 @@ use crate::{api::blobs::Bitfield, provider::RecvStreamExt, BlobFormat, Hash, Has /// Maximum message size is limited to 100MiB for now. pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024; +/// Error code for a permission error +pub const ERR_PERMISSION: VarInt = VarInt::from_u32(1u32); +/// Error code for when a request is aborted due to a rate limit +pub const ERR_LIMIT: VarInt = VarInt::from_u32(2u32); +/// Error code for when a request is aborted due to internal error +pub const ERR_INTERNAL: VarInt = VarInt::from_u32(3u32); + /// The ALPN used with quic for the iroh blobs protocol. pub const ALPN: &[u8] = b"/iroh-bytes/4"; diff --git a/src/provider.rs b/src/provider.rs index 883f97811..49b57e13a 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -20,13 +20,12 @@ use tracing::{debug, debug_span, warn, Instrument}; use crate::{ api::{ - self, blobs::{Bitfield, WriteProgress}, - Store, + ExportBaoResult, Store, }, hashseq::HashSeq, protocol::{GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request}, - provider::events::{ClientConnected, ClientError, ConnectionClosed, RequestTracker}, + provider::events::{ClientConnected, ConnectionClosed, RequestTracker}, Hash, }; pub mod events; @@ -94,7 +93,7 @@ impl StreamPair { } /// We are done with reading. Return a ProgressWriter that contains the read stats and connection id - pub async fn into_writer( + async fn into_writer( mut self, tracker: RequestTracker, ) -> Result { @@ -118,7 +117,7 @@ impl StreamPair { )) } - pub async fn into_reader( + async fn into_reader( mut self, tracker: RequestTracker, ) -> Result { @@ -141,39 +140,71 @@ impl StreamPair { } pub async fn get_request( - &self, + mut self, f: impl FnOnce() -> GetRequest, - ) -> Result { - self.events + ) -> anyhow::Result { + let res = self + .events .request(f, self.connection_id, self.request_id) - .await + .await; + match res { + Err(e) => { + self.writer.reset(e.code()).ok(); + Err(e.into()) + } + Ok(tracker) => Ok(self.into_writer(tracker).await?), + } } pub async fn get_many_request( - &self, + mut self, f: impl FnOnce() -> GetManyRequest, - ) -> Result { - self.events + ) -> anyhow::Result { + let res = self + .events .request(f, self.connection_id, self.request_id) - .await + .await; + match res { + Err(e) => { + self.writer.reset(e.code()).ok(); + Err(e.into()) + } + Ok(tracker) => Ok(self.into_writer(tracker).await?), + } } pub async fn push_request( - &self, + mut self, f: impl FnOnce() -> PushRequest, - ) -> Result { - self.events + ) -> anyhow::Result { + let res = self + .events .request(f, self.connection_id, self.request_id) - .await + .await; + match res { + Err(e) => { + self.writer.reset(e.code()).ok(); + Err(e.into()) + } + Ok(tracker) => Ok(self.into_reader(tracker).await?), + } } pub async fn observe_request( - &self, + mut self, f: impl FnOnce() -> ObserveRequest, - ) -> Result { - self.events + ) -> anyhow::Result { + let res = self + .events .request(f, self.connection_id, self.request_id) - .await + .await; + match res { + Err(e) => { + self.writer.reset(e.code()).ok(); + Err(e.into()) + } + Ok(tracker) => Ok(self.into_writer(tracker).await?), + } } fn stats(&self) -> TransferStats { @@ -299,7 +330,8 @@ pub async fn handle_connection( }) .await { - debug!("client not authorized to connect: {cause}"); + connection.close(cause.code(), cause.reason()); + debug!("closing connection: {cause}"); return; } while let Ok(context) = StreamPair::accept(&connection, &progress).await { @@ -323,17 +355,16 @@ async fn handle_stream(store: Store, mut context: StreamPair) -> anyhow::Result< match request { Request::Get(request) => { - let tracker = context.get_request(|| request.clone()).await?; - let mut writer = context.into_writer(tracker).await?; - if handle_get(store, request, &mut writer).await.is_ok() { + let mut writer = context.get_request(|| request.clone()).await?; + let res = handle_get(store, request, &mut writer).await; + if res.is_ok() { writer.transfer_completed().await; } else { writer.transfer_aborted().await; } } Request::GetMany(request) => { - let tracker = context.get_many_request(|| request.clone()).await?; - let mut writer = context.into_writer(tracker).await?; + let mut writer = context.get_many_request(|| request.clone()).await?; if handle_get_many(store, request, &mut writer).await.is_ok() { writer.transfer_completed().await; } else { @@ -341,8 +372,7 @@ async fn handle_stream(store: Store, mut context: StreamPair) -> anyhow::Result< } } Request::Observe(request) => { - let tracker = context.observe_request(|| request.clone()).await?; - let mut writer = context.into_writer(tracker).await?; + let mut writer = context.observe_request(|| request.clone()).await?; if handle_observe(store, request, &mut writer).await.is_ok() { writer.transfer_completed().await; } else { @@ -350,8 +380,7 @@ async fn handle_stream(store: Store, mut context: StreamPair) -> anyhow::Result< } } Request::Push(request) => { - let tracker = context.push_request(|| request.clone()).await?; - let mut reader = context.into_reader(tracker).await?; + let mut reader = context.push_request(|| request.clone()).await?; if handle_push(store, request, &mut reader).await.is_ok() { reader.transfer_completed().await; } else { @@ -464,11 +493,11 @@ pub(crate) async fn send_blob( hash: Hash, ranges: ChunkRanges, writer: &mut ProgressWriter, -) -> api::Result<()> { - Ok(store +) -> ExportBaoResult<()> { + store .export_bao(hash, ranges) .write_quinn_with_progress(&mut writer.inner, &mut writer.context, &hash, index) - .await?) + .await } /// Handle a single push request. diff --git a/src/provider/events.rs b/src/provider/events.rs index f2bddb23c..5a922300a 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -8,7 +8,10 @@ use serde::{Deserialize, Serialize}; use snafu::Snafu; use crate::{ - protocol::{GetManyRequest, GetRequest, ObserveRequest, PushRequest}, + protocol::{ + GetManyRequest, GetRequest, ObserveRequest, PushRequest, ERR_INTERNAL, ERR_LIMIT, + ERR_PERMISSION, + }, provider::{events::irpc_ext::IrpcClientExt, TransferStats}, Hash, }; @@ -82,6 +85,24 @@ pub enum ClientError { }, } +impl ClientError { + pub fn code(&self) -> quinn::VarInt { + match self { + ClientError::RateLimited => ERR_LIMIT, + ClientError::Permission => ERR_PERMISSION, + ClientError::Irpc { .. } => ERR_INTERNAL, + } + } + + pub fn reason(&self) -> &'static [u8] { + match self { + ClientError::RateLimited => b"limit", + ClientError::Permission => b"permission", + ClientError::Irpc { .. } => b"internal", + } + } +} + impl From for ClientError { fn from(value: AbortReason) -> Self { match value { @@ -211,11 +232,14 @@ impl RequestTracker { /// Transfer for index `index` started, size `size` pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { - tx.send(RequestUpdate::Started(TransferStarted { - index, - hash: *hash, - size, - })) + tx.send( + TransferStarted { + index, + hash: *hash, + size, + } + .into(), + ) .await?; } Ok(()) @@ -224,8 +248,7 @@ impl RequestTracker { /// Transfer progress for the previously reported blob, end_offset is the new end offset in bytes. pub async fn transfer_progress(&mut self, len: u64, end_offset: u64) -> ClientResult { if let RequestUpdates::Active(tx) = &mut self.updates { - tx.try_send(RequestUpdate::Progress(TransferProgress { end_offset })) - .await?; + tx.try_send(TransferProgress { end_offset }.into()).await?; } if let Some((throttle, connection_id, request_id)) = &self.throttle { throttle @@ -242,8 +265,7 @@ impl RequestTracker { /// Transfer completed for the previously reported blob. pub async fn transfer_completed(&self, f: impl Fn() -> Box) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { - tx.send(RequestUpdate::Completed(TransferCompleted { stats: f() })) - .await?; + tx.send(TransferCompleted { stats: f() }.into()).await?; } Ok(()) } @@ -251,8 +273,7 @@ impl RequestTracker { /// Transfer aborted for the previously reported blob. pub async fn transfer_aborted(&self, f: impl Fn() -> Box) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { - tx.send(RequestUpdate::Aborted(TransferAborted { stats: f() })) - .await?; + tx.send(TransferAborted { stats: f() }.into()).await?; } Ok(()) } @@ -583,7 +604,7 @@ mod proto { } /// Stream of updates for a single request - #[derive(Debug, Serialize, Deserialize)] + #[derive(Debug, Serialize, Deserialize, derive_more::From)] pub enum RequestUpdate { /// Start of transfer for a blob, mandatory event Started(TransferStarted), From 2d72de0ac64e5e4f00dbb5c8cb712e201693b51b Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 3 Sep 2025 15:15:27 +0200 Subject: [PATCH 17/23] deny --- Cargo.lock | 48 +++++++++++++----------------------------------- 1 file changed, 13 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4068354f7..988d7955a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2094,11 +2094,11 @@ checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" [[package]] name = "matchers" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" dependencies = [ - "regex-automata 0.1.10", + "regex-automata", ] [[package]] @@ -2385,12 +2385,11 @@ dependencies = [ [[package]] name = "nu-ansi-term" -version = "0.46.0" +version = "0.50.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +checksum = "d4a28e057d01f97e61255210fcff094d74ed0466038633e95017f5beb68e4399" dependencies = [ - "overload", - "winapi", + "windows-sys 0.52.0", ] [[package]] @@ -2467,12 +2466,6 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "parking" version = "2.2.1" @@ -2907,7 +2900,7 @@ dependencies = [ "rand 0.9.2", "rand_chacha 0.9.0", "rand_xorshift", - "regex-syntax 0.8.5", + "regex-syntax", "rusty-fork", "tempfile", "unarray", @@ -3151,17 +3144,8 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", -] - -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" -dependencies = [ - "regex-syntax 0.6.29", + "regex-automata", + "regex-syntax", ] [[package]] @@ -3172,7 +3156,7 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.5", + "regex-syntax", ] [[package]] @@ -3181,12 +3165,6 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" -[[package]] -name = "regex-syntax" -version = "0.6.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" - [[package]] name = "regex-syntax" version = "0.8.5" @@ -4245,14 +4223,14 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "matchers", "nu-ansi-term", "once_cell", - "regex", + "regex-automata", "sharded-slab", "smallvec", "thread_local", From 2dac46c3bf6ff0fb3b0d1d5648043826afbdaa64 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 3 Sep 2025 15:25:46 +0200 Subject: [PATCH 18/23] Use async syntax for implementing ProtocolHandler --- src/net_protocol.rs | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/src/net_protocol.rs b/src/net_protocol.rs index 269ef0e14..47cda5344 100644 --- a/src/net_protocol.rs +++ b/src/net_protocol.rs @@ -36,7 +36,7 @@ //! # } //! ``` -use std::{fmt::Debug, future::Future, ops::Deref, sync::Arc}; +use std::{fmt::Debug, ops::Deref, sync::Arc}; use iroh::{ endpoint::Connection, @@ -100,25 +100,16 @@ impl BlobsProtocol { } impl ProtocolHandler for BlobsProtocol { - fn accept( - &self, - conn: Connection, - ) -> impl Future> + Send { + async fn accept(&self, conn: Connection) -> std::result::Result<(), AcceptError> { let store = self.store().clone(); let events = self.inner.events.clone(); - - Box::pin(async move { - crate::provider::handle_connection(conn, store, events).await; - Ok(()) - }) + crate::provider::handle_connection(conn, store, events).await; + Ok(()) } - fn shutdown(&self) -> impl Future + Send { - let store = self.store().clone(); - Box::pin(async move { - if let Err(cause) = store.shutdown().await { - error!("error shutting down store: {:?}", cause); - } - }) + async fn shutdown(&self) { + if let Err(cause) = self.store().shutdown().await { + error!("error shutting down store: {:?}", cause); + } } } From a67d7875e3a9107ba210d8e89092fc632207dbc2 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 3 Sep 2025 15:57:39 +0200 Subject: [PATCH 19/23] Use irpc::channel::SendError as default sink error. --- src/api.rs | 11 ++++++++- src/api/blobs.rs | 7 ++++-- src/api/downloader.rs | 9 ++++---- src/api/remote.rs | 42 ++++++++++++++++++---------------- src/provider.rs | 6 ++--- src/provider/events.rs | 52 +++++++++++++++++++++++++----------------- src/util.rs | 13 +++++++---- 7 files changed, 83 insertions(+), 57 deletions(-) diff --git a/src/api.rs b/src/api.rs index a2a34a2db..117c59e25 100644 --- a/src/api.rs +++ b/src/api.rs @@ -30,7 +30,7 @@ pub mod downloader; pub mod proto; pub mod remote; pub mod tags; -use crate::api::proto::WaitIdleRequest; +use crate::{api::proto::WaitIdleRequest, provider::events::ProgressError}; pub use crate::{store::util::Tag, util::temp_tag::TempTag}; pub(crate) type ApiClient = irpc::Client; @@ -97,6 +97,8 @@ pub enum ExportBaoError { ExportBaoIo { source: io::Error }, #[snafu(display("encode error: {source}"))] ExportBaoInner { source: bao_tree::io::EncodeError }, + #[snafu(display("client error: {source}"))] + ClientError { source: ProgressError }, } impl From for Error { @@ -107,6 +109,7 @@ impl From for Error { ExportBaoError::Request { source, .. } => Self::Io(source.into()), ExportBaoError::ExportBaoIo { source, .. } => Self::Io(source), ExportBaoError::ExportBaoInner { source, .. } => Self::Io(source.into()), + ExportBaoError::ClientError { source, .. } => Self::Io(source.into()), } } } @@ -152,6 +155,12 @@ impl From for ExportBaoError { } } +impl From for ExportBaoError { + fn from(value: ProgressError) -> Self { + ClientSnafu.into_error(value) + } +} + pub type ExportBaoResult = std::result::Result; #[derive(Debug, derive_more::Display, derive_more::From, Serialize, Deserialize)] diff --git a/src/api/blobs.rs b/src/api/blobs.rs index 8b618de1f..1822be5b2 100644 --- a/src/api/blobs.rs +++ b/src/api/blobs.rs @@ -57,6 +57,7 @@ use super::{ }; use crate::{ api::proto::{BatchRequest, ImportByteStreamUpdate}, + provider::events::ClientResult, store::IROH_BLOCK_SIZE, util::temp_tag::TempTag, BlobFormat, Hash, HashAndFormat, @@ -1112,7 +1113,9 @@ impl ExportBaoProgress { .write_chunk(leaf.data) .await .map_err(io::Error::from)?; - progress.notify_payload_write(index, leaf.offset, len).await; + progress + .notify_payload_write(index, leaf.offset, len) + .await?; } EncodedItem::Done => break, EncodedItem::Error(cause) => return Err(cause.into()), @@ -1158,7 +1161,7 @@ impl ExportBaoProgress { pub(crate) trait WriteProgress { /// Notify the progress writer that a payload write has happened. - async fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize); + async fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize) -> ClientResult; /// Log a write of some other data. fn log_other_write(&mut self, len: usize); diff --git a/src/api/downloader.rs b/src/api/downloader.rs index a2abbd7ea..1db1e6f07 100644 --- a/src/api/downloader.rs +++ b/src/api/downloader.rs @@ -3,7 +3,6 @@ use std::{ collections::{HashMap, HashSet}, fmt::Debug, future::{Future, IntoFuture}, - io, sync::Arc, }; @@ -113,7 +112,7 @@ async fn handle_download_impl( SplitStrategy::Split => handle_download_split_impl(store, pool, request, tx).await?, SplitStrategy::None => match request.request { FiniteRequest::Get(get) => { - let sink = IrpcSenderRefSink(tx).with_map_err(io::Error::other); + let sink = IrpcSenderRefSink(tx); execute_get(&pool, Arc::new(get), &request.providers, &store, sink).await?; } FiniteRequest::GetMany(_) => { @@ -144,7 +143,7 @@ async fn handle_download_split_impl( let (tx, rx) = tokio::sync::mpsc::channel::<(usize, DownloadProgessItem)>(16); progress_tx.send(rx).await.ok(); let sink = TokioMpscSenderSink(tx) - .with_map_err(io::Error::other) + .with_map_err(|_| irpc::channel::SendError::ReceiverClosed) .with_map(move |x| (id, x)); let res = execute_get(&pool, Arc::new(request), &providers, &store, sink).await; (hash, res) @@ -375,7 +374,7 @@ async fn split_request<'a>( providers: &Arc, pool: &ConnectionPool, store: &Store, - progress: impl Sink, + progress: impl Sink, ) -> anyhow::Result + Send + 'a>> { Ok(match request { FiniteRequest::Get(req) => { @@ -431,7 +430,7 @@ async fn execute_get( request: Arc, providers: &Arc, store: &Store, - mut progress: impl Sink, + mut progress: impl Sink, ) -> anyhow::Result<()> { let remote = store.remote(); let mut providers = providers.find_providers(request.content()); diff --git a/src/api/remote.rs b/src/api/remote.rs index 623200900..3d8a3a817 100644 --- a/src/api/remote.rs +++ b/src/api/remote.rs @@ -18,6 +18,7 @@ use crate::{ GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType, MAX_MESSAGE_SIZE, }, + provider::events::{ClientResult, ProgressError}, util::sink::{Sink, TokioMpscSenderSink}, }; @@ -478,9 +479,7 @@ impl Remote { let content = content.into(); let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); - let sink = TokioMpscSenderSink(tx) - .with_map(GetProgressItem::Progress) - .with_map_err(io::Error::other); + let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress); let this = self.clone(); let fut = async move { let res = this.fetch_sink(conn, content, sink).await.into(); @@ -503,7 +502,7 @@ impl Remote { &self, mut conn: impl GetConnection, content: impl Into, - progress: impl Sink, + progress: impl Sink, ) -> GetResult { let content = content.into(); let local = self @@ -556,9 +555,7 @@ impl Remote { pub fn execute_push(&self, conn: Connection, request: PushRequest) -> PushProgress { let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); - let sink = TokioMpscSenderSink(tx) - .with_map(PushProgressItem::Progress) - .with_map_err(io::Error::other); + let sink = TokioMpscSenderSink(tx).with_map(PushProgressItem::Progress); let this = self.clone(); let fut = async move { let res = this.execute_push_sink(conn, request, sink).await.into(); @@ -577,7 +574,7 @@ impl Remote { &self, conn: Connection, request: PushRequest, - progress: impl Sink, + progress: impl Sink, ) -> anyhow::Result { let hash = request.hash; debug!(%hash, "pushing"); @@ -632,9 +629,7 @@ impl Remote { pub fn execute_get_with_opts(&self, conn: Connection, request: GetRequest) -> GetProgress { let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); - let sink = TokioMpscSenderSink(tx) - .with_map(GetProgressItem::Progress) - .with_map_err(io::Error::other); + let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress); let this = self.clone(); let fut = async move { let res = this.execute_get_sink(&conn, request, sink).await.into(); @@ -658,7 +653,7 @@ impl Remote { &self, conn: &Connection, request: GetRequest, - mut progress: impl Sink, + mut progress: impl Sink, ) -> GetResult { let store = self.store(); let root = request.hash; @@ -721,9 +716,7 @@ impl Remote { pub fn execute_get_many(&self, conn: Connection, request: GetManyRequest) -> GetProgress { let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); - let sink = TokioMpscSenderSink(tx) - .with_map(GetProgressItem::Progress) - .with_map_err(io::Error::other); + let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress); let this = self.clone(); let fut = async move { let res = this.execute_get_many_sink(conn, request, sink).await.into(); @@ -747,7 +740,7 @@ impl Remote { &self, conn: Connection, request: GetManyRequest, - mut progress: impl Sink, + mut progress: impl Sink, ) -> GetResult { let store = self.store(); let hash_seq = request.hashes.iter().copied().collect::(); @@ -884,7 +877,7 @@ async fn get_blob_ranges_impl( header: AtBlobHeader, hash: Hash, store: &Store, - mut progress: impl Sink, + mut progress: impl Sink, ) -> GetResult { let (mut content, size) = header.next().await?; let Some(size) = NonZeroU64::new(size) else { @@ -1048,11 +1041,20 @@ struct StreamContext { impl WriteProgress for StreamContext where - S: Sink, + S: Sink, { - async fn notify_payload_write(&mut self, _index: u64, _offset: u64, len: usize) { + async fn notify_payload_write( + &mut self, + _index: u64, + _offset: u64, + len: usize, + ) -> ClientResult { self.payload_bytes_sent += len as u64; - self.sender.send(self.payload_bytes_sent).await.ok(); + self.sender + .send(self.payload_bytes_sent) + .await + .map_err(|e| ProgressError::Internal { source: e.into() })?; + Ok(()) } fn log_other_write(&mut self, _len: usize) {} diff --git a/src/provider.rs b/src/provider.rs index 49b57e13a..0134169c6 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -25,7 +25,7 @@ use crate::{ }, hashseq::HashSeq, protocol::{GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request}, - provider::events::{ClientConnected, ConnectionClosed, RequestTracker}, + provider::events::{ClientConnected, ClientResult, ConnectionClosed, RequestTracker}, Hash, }; pub mod events; @@ -264,11 +264,11 @@ impl WriterContext { } impl WriteProgress for WriterContext { - async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) { + async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult { let len = len as u64; let end_offset = offset + len; self.payload_bytes_written += len; - self.tracker.transfer_progress(len, end_offset).await.ok(); + self.tracker.transfer_progress(len, end_offset).await } fn log_other_write(&mut self, len: usize) { diff --git a/src/provider/events.rs b/src/provider/events.rs index 5a922300a..fff800dc9 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -1,4 +1,4 @@ -use std::{fmt::Debug, ops::Deref}; +use std::{fmt::Debug, io, ops::Deref}; use irpc::{ channel::{mpsc, none::NoSender, oneshot}, @@ -76,60 +76,70 @@ pub enum AbortReason { } #[derive(Debug, Snafu)] -pub enum ClientError { - RateLimited, +pub enum ProgressError { + Limit, Permission, #[snafu(transparent)] - Irpc { + Internal { source: irpc::Error, }, } -impl ClientError { +impl From for io::Error { + fn from(value: ProgressError) -> Self { + match value { + ProgressError::Limit => io::ErrorKind::QuotaExceeded.into(), + ProgressError::Permission => io::ErrorKind::PermissionDenied.into(), + ProgressError::Internal { source } => source.into(), + } + } +} + +impl ProgressError { pub fn code(&self) -> quinn::VarInt { match self { - ClientError::RateLimited => ERR_LIMIT, - ClientError::Permission => ERR_PERMISSION, - ClientError::Irpc { .. } => ERR_INTERNAL, + ProgressError::Limit => ERR_LIMIT, + ProgressError::Permission => ERR_PERMISSION, + ProgressError::Internal { .. } => ERR_INTERNAL, } } pub fn reason(&self) -> &'static [u8] { match self { - ClientError::RateLimited => b"limit", - ClientError::Permission => b"permission", - ClientError::Irpc { .. } => b"internal", + ProgressError::Limit => b"limit", + ProgressError::Permission => b"permission", + ProgressError::Internal { .. } => b"internal", } } } -impl From for ClientError { +impl From for ProgressError { fn from(value: AbortReason) -> Self { match value { - AbortReason::RateLimited => ClientError::RateLimited, - AbortReason::Permission => ClientError::Permission, + AbortReason::RateLimited => ProgressError::Limit, + AbortReason::Permission => ProgressError::Permission, } } } -impl From for ClientError { +impl From for ProgressError { fn from(value: irpc::channel::RecvError) -> Self { - ClientError::Irpc { + ProgressError::Internal { source: value.into(), } } } -impl From for ClientError { +impl From for ProgressError { fn from(value: irpc::channel::SendError) -> Self { - ClientError::Irpc { + ProgressError::Internal { source: value.into(), } } } pub type EventResult = Result<(), AbortReason>; -pub type ClientResult = Result<(), ClientError>; +pub type ClientResult = Result<(), ProgressError>; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct EventMask { @@ -407,7 +417,7 @@ impl EventSender { f: impl FnOnce() -> Req, connection_id: u64, request_id: u64, - ) -> Result + ) -> Result where ProviderProto: From>, ProviderMessage: From, ProviderProto>>, @@ -466,7 +476,7 @@ impl EventSender { RequestUpdates::Active(tx) } RequestMode::Disabled => { - return Err(ClientError::Permission); + return Err(ProgressError::Permission); } _ => RequestUpdates::None, }, diff --git a/src/util.rs b/src/util.rs index 3fdaacbca..40abf0343 100644 --- a/src/util.rs +++ b/src/util.rs @@ -363,7 +363,7 @@ pub(crate) mod outboard_with_progress { } pub(crate) mod sink { - use std::{future::Future, io}; + use std::future::Future; use irpc::RpcMessage; @@ -433,10 +433,13 @@ pub(crate) mod sink { pub struct TokioMpscSenderSink(pub tokio::sync::mpsc::Sender); impl Sink for TokioMpscSenderSink { - type Error = tokio::sync::mpsc::error::SendError; + type Error = irpc::channel::SendError; async fn send(&mut self, value: T) -> std::result::Result<(), Self::Error> { - self.0.send(value).await + self.0 + .send(value) + .await + .map_err(|_| irpc::channel::SendError::ReceiverClosed) } } @@ -483,10 +486,10 @@ pub(crate) mod sink { pub struct Drain; impl Sink for Drain { - type Error = io::Error; + type Error = irpc::channel::SendError; async fn send(&mut self, _offset: T) -> std::result::Result<(), Self::Error> { - io::Result::Ok(()) + Ok(()) } } } From 546f57e90af8518c1c70e06078199987a5fc76d1 Mon Sep 17 00:00:00 2001 From: Frando Date: Wed, 3 Sep 2025 15:05:28 +0200 Subject: [PATCH 20/23] fixup --- examples/limit.rs | 18 +++++++++++------- examples/random_store.rs | 6 +++--- src/tests.rs | 13 ++----------- 3 files changed, 16 insertions(+), 21 deletions(-) diff --git a/examples/limit.rs b/examples/limit.rs index 830574fcc..e72f9be59 100644 --- a/examples/limit.rs +++ b/examples/limit.rs @@ -271,7 +271,7 @@ async fn main() -> Result<()> { let store = MemStore::new(); let hashes = add_paths(&store, paths).await?; let events = limit_by_node_id(allowed_nodes.clone()); - let (router, addr) = setup(MemStore::new(), events).await?; + let (router, addr) = setup(store, events).await?; for (path, hash) in hashes { let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); @@ -299,12 +299,16 @@ async fn main() -> Result<()> { } } - let events = limit_by_hash(allowed_hashes); - let (router, addr) = setup(MemStore::new(), events).await?; + let events = limit_by_hash(allowed_hashes.clone()); + let (router, addr) = setup(store, events).await?; - for (i, (path, hash)) in hashes.iter().enumerate() { + for (path, hash) in hashes.iter() { let ticket = BlobTicket::new(addr.clone(), *hash, BlobFormat::Raw); - let permitted = if i == 0 { "" } else { "limited" }; + let permitted = if allowed_hashes.contains(hash) { + "allowed" + } else { + "forbidden" + }; println!("{}: {ticket} ({permitted})", path.display()); } tokio::signal::ctrl_c().await?; @@ -314,7 +318,7 @@ async fn main() -> Result<()> { let store = MemStore::new(); let hashes = add_paths(&store, paths).await?; let events = throttle(delay_ms); - let (router, addr) = setup(MemStore::new(), events).await?; + let (router, addr) = setup(store, events).await?; for (path, hash) in hashes { let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); println!("{}: {ticket}", path.display()); @@ -329,7 +333,7 @@ async fn main() -> Result<()> { let store = MemStore::new(); let hashes = add_paths(&store, paths).await?; let events = limit_max_connections(max_connections); - let (router, addr) = setup(MemStore::new(), events).await?; + let (router, addr) = setup(store, events).await?; for (path, hash) in hashes { let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); println!("{}: {ticket}", path.display()); diff --git a/examples/random_store.rs b/examples/random_store.rs index c4c30348b..d3f9a0fc4 100644 --- a/examples/random_store.rs +++ b/examples/random_store.rs @@ -14,7 +14,7 @@ use iroh_blobs::{ use irpc::RpcMessage; use n0_future::StreamExt; use rand::{rngs::StdRng, Rng, SeedableRng}; -use tokio::{signal::ctrl_c, sync::mpsc}; +use tokio::signal::ctrl_c; use tracing::info; #[derive(Parser, Debug)] @@ -102,7 +102,7 @@ pub fn get_or_generate_secret_key() -> Result { } pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, EventSender) { - let (tx, mut rx) = mpsc::channel(100); + let (tx, mut rx) = EventSender::channel(100, EventMask::ALL_READONLY); fn dump_updates(mut rx: irpc::channel::mpsc::Receiver) { tokio::spawn(async move { while let Ok(Some(update)) = rx.recv().await { @@ -176,7 +176,7 @@ pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, E } } }); - (dump_task, EventSender::new(tx, EventMask::ALL_READONLY)) + (dump_task, tx) } #[tokio::main] diff --git a/src/tests.rs b/src/tests.rs index dc38eb436..0ef0c027c 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -342,7 +342,7 @@ fn event_handler( allowed_nodes: impl IntoIterator, ) -> (EventSender, watch::Receiver, AbortOnDropHandle<()>) { let (count_tx, count_rx) = tokio::sync::watch::channel(0usize); - let (events_tx, mut events_rx) = mpsc::channel::(16); + let (events_tx, mut events_rx) = EventSender::channel(16, EventMask::ALL_READONLY); let allowed_nodes = allowed_nodes.into_iter().collect::>(); let task = AbortOnDropHandle::new(tokio::task::spawn(async move { while let Some(event) = events_rx.recv().await { @@ -370,16 +370,7 @@ fn event_handler( } } })); - ( - EventSender::new( - events_tx, - EventMask { - ..EventMask::ALL_READONLY - }, - ), - count_rx, - task, - ) + (events_tx, count_rx, task) } async fn two_nodes_push_blobs( From f399e2bc44089dc398b9ba7d7572fdbb556443fb Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 3 Sep 2025 16:19:38 +0200 Subject: [PATCH 21/23] Remove map_err that isn't needed anymore --- src/api/downloader.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/api/downloader.rs b/src/api/downloader.rs index 1db1e6f07..3555eca9c 100644 --- a/src/api/downloader.rs +++ b/src/api/downloader.rs @@ -142,9 +142,7 @@ async fn handle_download_split_impl( let hash = request.hash; let (tx, rx) = tokio::sync::mpsc::channel::<(usize, DownloadProgessItem)>(16); progress_tx.send(rx).await.ok(); - let sink = TokioMpscSenderSink(tx) - .with_map_err(|_| irpc::channel::SendError::ReceiverClosed) - .with_map(move |x| (id, x)); + let sink = TokioMpscSenderSink(tx).with_map(move |x| (id, x)); let res = execute_get(&pool, Arc::new(request), &providers, &store, sink).await; (hash, res) } From 255f23b7f5814c79fccfd2312871dc784325fc08 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 11 Sep 2025 15:47:25 +0300 Subject: [PATCH 22/23] Change connection limit example to actually limit connections, not get requests. --- examples/limit.rs | 42 ++++++++++++++++-------------------------- src/provider/events.rs | 2 +- 2 files changed, 17 insertions(+), 27 deletions(-) diff --git a/examples/limit.rs b/examples/limit.rs index e72f9be59..707766b85 100644 --- a/examples/limit.rs +++ b/examples/limit.rs @@ -190,40 +190,30 @@ fn limit_max_connections(max_connections: usize) -> EventSender { // based on the current connection count if we want to accept or reject. // We also want detailed logging of events for the get request, so we can // detect when the request is finished one way or another. - get: RequestMode::RequestLog, + connected: ConnectMode::Request, ..EventMask::DEFAULT }; let (tx, mut rx) = EventSender::channel(32, mask); n0_future::task::spawn(async move { let requests = ConnectionCounter::new(max_connections); while let Some(msg) = rx.recv().await { - if let ProviderMessage::GetRequestReceived(mut msg) = msg { - let connection_id = msg.connection_id; - let request_id = msg.request_id; - let res = requests.inc(); - match res { - Ok(n) => { - println!("Accepting request {n}, id ({connection_id},{request_id})"); - msg.tx.send(Ok(())).await.ok(); - } - Err(_) => { - println!( - "Connection limit of {max_connections} exceeded, rejecting request" - ); - msg.tx.send(Err(AbortReason::RateLimited)).await.ok(); - continue; - } + match msg { + ProviderMessage::ClientConnected(msg) => { + let connection_id = msg.connection_id; + let node_id = msg.node_id; + let res = if let Ok(n) = requests.inc() { + println!("Accepting connection {n}, node_id {node_id}, connection_id {connection_id}"); + Ok(()) + } else { + Err(AbortReason::RateLimited) + }; + msg.tx.send(res).await.ok(); } - let requests = requests.clone(); - n0_future::task::spawn(async move { - // just drain the per request events - // - // Note that we have requested updates for the request, now we also need to process them - // otherwise the request will be aborted! - while let Ok(Some(_)) = msg.rx.recv().await {} - println!("Stopping request, id ({connection_id},{request_id})"); + ProviderMessage::ConnectionClosed(msg) => { requests.dec(); - }); + println!("Connection closed, connection_id {}", msg.connection_id,); + } + _ => {} } } }); diff --git a/src/provider/events.rs b/src/provider/events.rs index fff800dc9..acdee1df3 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -401,7 +401,7 @@ impl EventSender { Ok(()) } - /// A new client has been connected. + /// A connection has been closed. pub async fn connection_closed(&self, f: impl Fn() -> ConnectionClosed) -> ClientResult { if let Some(client) = &self.inner { client.notify(f()).await?; From 67ce53334c622912e7373b1ff88c9bceb3023e71 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 11 Sep 2025 17:48:15 +0300 Subject: [PATCH 23/23] PR review --- examples/limit.rs | 8 ++--- src/provider/events.rs | 66 +++++++++++++++++++++++++++--------------- 2 files changed, 46 insertions(+), 28 deletions(-) diff --git a/examples/limit.rs b/examples/limit.rs index 707766b85..6aaa2921f 100644 --- a/examples/limit.rs +++ b/examples/limit.rs @@ -81,7 +81,7 @@ fn limit_by_node_id(allowed_nodes: HashSet) -> EventSender { let mask = EventMask { // We want a request for each incoming connection so we can accept // or reject them. We don't need any other events. - connected: ConnectMode::Request, + connected: ConnectMode::Intercept, ..EventMask::DEFAULT }; let (tx, mut rx) = EventSender::channel(32, mask); @@ -108,7 +108,7 @@ fn limit_by_hash(allowed_hashes: HashSet) -> EventSender { // We want to get a request for each get request that we can answer // with OK or not OK depending on the hash. We do not want detailed // events once it has been decided to handle a request. - get: RequestMode::Request, + get: RequestMode::Intercept, ..EventMask::DEFAULT }; let (tx, mut rx) = EventSender::channel(32, mask); @@ -136,7 +136,7 @@ fn throttle(delay_ms: u64) -> EventSender { let mask = EventMask { // We want to get requests for each sent user data blob, so we can add a delay. // Other than that, we don't need any events. - throttle: ThrottleMode::Throttle, + throttle: ThrottleMode::Intercept, ..EventMask::DEFAULT }; let (tx, mut rx) = EventSender::channel(32, mask); @@ -190,7 +190,7 @@ fn limit_max_connections(max_connections: usize) -> EventSender { // based on the current connection count if we want to accept or reject. // We also want detailed logging of events for the get request, so we can // detect when the request is finished one way or another. - connected: ConnectMode::Request, + connected: ConnectMode::Intercept, ..EventMask::DEFAULT }; let (tx, mut rx) = EventSender::channel(32, mask); diff --git a/src/provider/events.rs b/src/provider/events.rs index acdee1df3..e24e0efbb 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -16,6 +16,7 @@ use crate::{ Hash, }; +/// Mode for connect events. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[repr(u8)] pub enum ConnectMode { @@ -25,9 +26,10 @@ pub enum ConnectMode { /// We get a notification for connect events. Notify, /// We get a request for connect events and can reject incoming connections. - Request, + Intercept, } +/// Request mode for observe requests. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[repr(u8)] pub enum ObserveMode { @@ -37,9 +39,10 @@ pub enum ObserveMode { /// We get a notification for connect events. Notify, /// We get a request for connect events and can reject incoming connections. - Request, + Intercept, } +/// Request mode for all data related requests. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[repr(u8)] pub enum RequestMode { @@ -49,16 +52,20 @@ pub enum RequestMode { /// We get a notification for each request, but no transfer events. Notify, /// We get a request for each request, and can reject incoming requests, but no transfer events. - Request, + Intercept, /// We get a notification for each request as well as detailed transfer events. NotifyLog, /// We get a request for each request, and can reject incoming requests. /// We also get detailed transfer events. - RequestLog, + InterceptLog, /// This request type is completely disabled. All requests will be rejected. + /// + /// This means that requests of this kind will always be rejected, whereas + /// None means that we don't get any events, but requests will be processed normally. Disabled, } +/// Throttling mode for requests that support throttling. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[repr(u8)] pub enum ThrottleMode { @@ -66,15 +73,18 @@ pub enum ThrottleMode { #[default] None, /// We call throttle to give the event handler a way to throttle requests - Throttle, + Intercept, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum AbortReason { + /// The request was aborted because a limit was exceeded. It is OK to try again later. RateLimited, + /// The request was aborted because the client does not have permission to perform the operation. Permission, } +/// Errors that can occur when sending progress updates. #[derive(Debug, Snafu)] pub enum ProgressError { Limit, @@ -141,6 +151,10 @@ impl From for ProgressError { pub type EventResult = Result<(), AbortReason>; pub type ClientResult = Result<(), ProgressError>; +/// Event mask to configure which events are sent to the event handler. +/// +/// This can also be used to completely disable certain request types. E.g. +/// push requests are disabled by default, as they can write to the local store. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct EventMask { /// Connection event mask @@ -180,12 +194,12 @@ impl EventMask { /// need to do it manually. Providing constants that have push enabled would /// risk misuse. pub const ALL_READONLY: Self = Self { - connected: ConnectMode::Request, - get: RequestMode::RequestLog, - get_many: RequestMode::RequestLog, + connected: ConnectMode::Intercept, + get: RequestMode::InterceptLog, + get_many: RequestMode::InterceptLog, push: RequestMode::Disabled, - throttle: ThrottleMode::Throttle, - observe: ObserveMode::Request, + throttle: ThrottleMode::Intercept, + observe: ObserveMode::Intercept, }; } @@ -239,7 +253,7 @@ impl RequestTracker { throttle: None, }; - /// Transfer for index `index` started, size `size` + /// Transfer for index `index` started, size `size` in bytes. pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { tx.send( @@ -333,7 +347,10 @@ impl EventSender { } while let Some(msg) = rx.recv().await { match msg { - ProviderMessage::ClientConnected(_) => todo!(), + ProviderMessage::ClientConnected(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + } ProviderMessage::ClientConnectedNotify(msg) => { trace!("{:?}", msg.inner); } @@ -395,7 +412,7 @@ impl EventSender { match self.mask.connected { ConnectMode::None => {} ConnectMode::Notify => client.notify(Notify(f())).await?, - ConnectMode::Request => client.rpc(f()).await??, + ConnectMode::Intercept => client.rpc(f()).await??, } }; Ok(()) @@ -445,7 +462,7 @@ impl EventSender { client.unwrap().notify_streaming(Notify(msg), 32).await?, ) } - RequestMode::Request if client.is_some() => { + RequestMode::Intercept if client.is_some() => { let msg = RequestReceived { request: f(), connection_id, @@ -464,7 +481,7 @@ impl EventSender { }; RequestUpdates::Active(client.unwrap().notify_streaming(Notify(msg), 32).await?) } - RequestMode::RequestLog if client.is_some() => { + RequestMode::InterceptLog if client.is_some() => { let msg = RequestReceived { request: f(), connection_id, @@ -491,7 +508,7 @@ impl EventSender { ) -> RequestTracker { let throttle = match self.mask.throttle { ThrottleMode::None => None, - ThrottleMode::Throttle => self + ThrottleMode::Intercept => self .inner .clone() .map(|client| (client, connection_id, request_id)), @@ -515,38 +532,39 @@ pub enum ProviderProto { #[rpc(tx = NoSender)] ConnectionClosed(ConnectionClosed), - #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] /// A new get request was received from the provider. + #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] GetRequestReceived(RequestReceived), + /// A new get request was received from the provider (notify variant). #[rpc(rx = mpsc::Receiver, tx = NoSender)] - /// A new get request was received from the provider. GetRequestReceivedNotify(Notify>), - /// A new get request was received from the provider. + /// A new get many request was received from the provider. #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] GetManyRequestReceived(RequestReceived), - /// A new get request was received from the provider. + /// A new get many request was received from the provider (notify variant). #[rpc(rx = mpsc::Receiver, tx = NoSender)] GetManyRequestReceivedNotify(Notify>), - /// A new get request was received from the provider. + /// A new push request was received from the provider. #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] PushRequestReceived(RequestReceived), - /// A new get request was received from the provider. + /// A new push request was received from the provider (notify variant). #[rpc(rx = mpsc::Receiver, tx = NoSender)] PushRequestReceivedNotify(Notify>), - /// A new get request was received from the provider. + /// A new observe request was received from the provider. #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] ObserveRequestReceived(RequestReceived), - /// A new get request was received from the provider. + /// A new observe request was received from the provider (notify variant). #[rpc(rx = mpsc::Receiver, tx = NoSender)] ObserveRequestReceivedNotify(Notify>), + /// Request to throttle sending for a specific data request. #[rpc(tx = oneshot::Sender)] Throttle(Throttle), }