From e8e1c749f20a033d623604ca2ec54b959d9ad174 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 1 Oct 2025 14:23:30 +0300 Subject: [PATCH] Add a little helper on connection_pool::Options to deal with the Arc/Box/Pin madness Also add a test to use on_connected to wait for direct connections. --- src/util/connection_pool.rs | 54 +++++++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs index aa9c15292..68b1476ff 100644 --- a/src/util/connection_pool.rs +++ b/src/util/connection_pool.rs @@ -65,6 +65,22 @@ impl Default for Options { } } +impl Options { + /// Set the on_connected callback + pub fn with_on_connected(mut self, f: F) -> Self + where + F: Fn(Endpoint, Connection) -> Fut + Send + Sync + 'static, + Fut: std::future::Future> + Send + 'static, + { + self.on_connected = Some(Arc::new(move |ep, conn| { + let ep = ep.clone(); + let conn = conn.clone(); + Box::pin(f(ep, conn)) + })); + self + } +} + /// A reference to a connection that is owned by a connection pool. #[derive(Debug)] pub struct ConnectionRef { @@ -524,9 +540,9 @@ mod tests { use iroh::{ discovery::static_provider::StaticProvider, - endpoint::Connection, + endpoint::{Connection, ConnectionType}, protocol::{AcceptError, ProtocolHandler, Router}, - NodeAddr, NodeId, SecretKey, Watcher, + Endpoint, NodeAddr, NodeId, SecretKey, Watcher, }; use n0_future::{io, stream, BufferedStreamExt, StreamExt}; use n0_snafu::ResultExt; @@ -770,37 +786,41 @@ mod tests { Ok(()) } - /// Uses an on_connected callback that delays for a long time. - /// - /// This checks that the pool timeout includes on_connected delay. + /// Uses an on_connected callback to ensure that the connection is direct. #[tokio::test] // #[traced_test] - async fn on_connected_timeout() -> TestResult<()> { + async fn on_connected_direct() -> TestResult<()> { let n = 1; let (ids, routers, discovery) = echo_servers(n).await?; let endpoint = iroh::Endpoint::builder() .discovery(discovery) .bind() .await?; - let on_connected: OnConnected = Arc::new(|_, _| { - Box::pin(async { - tokio::time::sleep(Duration::from_secs(20)).await; - Ok(()) - }) - }); + let on_connected = |ep: Endpoint, conn: Connection| async move { + let Ok(id) = conn.remote_node_id() else { + return Err(io::Error::other("unable to get node id")); + }; + let Some(watcher) = ep.conn_type(id) else { + return Err(io::Error::other("unable to get conn_type watcher")); + }; + let mut stream = watcher.stream(); + while let Some(status) = stream.next().await { + if let ConnectionType::Direct { .. } = status { + return Ok(()); + } + } + Err(io::Error::other("connection closed before becoming direct")) + }; let pool = ConnectionPool::new( endpoint, ECHO_ALPN, - Options { - on_connected: Some(on_connected), - ..test_options() - }, + test_options().with_on_connected(on_connected), ); let client = EchoClient { pool }; let msg = b"Hello, pool!".to_vec(); for id in &ids { let res = client.echo(*id, msg.clone()).await; - assert!(matches!(res, Err(PoolConnectError::Timeout))); + assert!(res.is_ok()); } shutdown_routers(routers).await; Ok(())