Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 37 additions & 17 deletions src/util/connection_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,22 @@ impl Default for Options {
}
}

impl Options {
/// Set the on_connected callback
pub fn with_on_connected<F, Fut>(mut self, f: F) -> Self
where
F: Fn(Endpoint, Connection) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = io::Result<()>> + Send + 'static,
{
self.on_connected = Some(Arc::new(move |ep, conn| {
let ep = ep.clone();
let conn = conn.clone();
Box::pin(f(ep, conn))
}));
self
}
}

/// A reference to a connection that is owned by a connection pool.
#[derive(Debug)]
pub struct ConnectionRef {
Expand Down Expand Up @@ -524,9 +540,9 @@ mod tests {

use iroh::{
discovery::static_provider::StaticProvider,
endpoint::Connection,
endpoint::{Connection, ConnectionType},
protocol::{AcceptError, ProtocolHandler, Router},
NodeAddr, NodeId, SecretKey, Watcher,
Endpoint, NodeAddr, NodeId, SecretKey, Watcher,
};
use n0_future::{io, stream, BufferedStreamExt, StreamExt};
use n0_snafu::ResultExt;
Expand Down Expand Up @@ -770,37 +786,41 @@ mod tests {
Ok(())
}

/// Uses an on_connected callback that delays for a long time.
///
/// This checks that the pool timeout includes on_connected delay.
/// Uses an on_connected callback to ensure that the connection is direct.
#[tokio::test]
// #[traced_test]
async fn on_connected_timeout() -> TestResult<()> {
async fn on_connected_direct() -> TestResult<()> {
let n = 1;
let (ids, routers, discovery) = echo_servers(n).await?;
let endpoint = iroh::Endpoint::builder()
.discovery(discovery)
.bind()
.await?;
let on_connected: OnConnected = Arc::new(|_, _| {
Box::pin(async {
tokio::time::sleep(Duration::from_secs(20)).await;
Ok(())
})
});
let on_connected = |ep: Endpoint, conn: Connection| async move {
let Ok(id) = conn.remote_node_id() else {
return Err(io::Error::other("unable to get node id"));
};
let Some(watcher) = ep.conn_type(id) else {
return Err(io::Error::other("unable to get conn_type watcher"));
};
let mut stream = watcher.stream();
while let Some(status) = stream.next().await {
if let ConnectionType::Direct { .. } = status {
return Ok(());
}
}
Err(io::Error::other("connection closed before becoming direct"))
};
let pool = ConnectionPool::new(
endpoint,
ECHO_ALPN,
Options {
on_connected: Some(on_connected),
..test_options()
},
test_options().with_on_connected(on_connected),
);
let client = EchoClient { pool };
let msg = b"Hello, pool!".to_vec();
for id in &ids {
let res = client.echo(*id, msg.clone()).await;
assert!(matches!(res, Err(PoolConnectError::Timeout)));
assert!(res.is_ok());
}
shutdown_routers(routers).await;
Ok(())
Expand Down
Loading