diff --git a/iroh-gossip/src/net.rs b/iroh-gossip/src/net.rs
index 9ef6f08152..f350e7b44e 100644
--- a/iroh-gossip/src/net.rs
+++ b/iroh-gossip/src/net.rs
@@ -384,12 +384,7 @@ impl Actor {
                 }
                 Some(new_conn) = self.conn_manager.next() => {
                     trace!(?i, "tick: conn_manager");
-                    let node_id = new_conn.node_id;
-                    if let Err(err) = self.handle_new_connection(new_conn).await {
-                        warn!(peer=%node_id.fmt_short(), ?err, "failed to handle new connection");
-                        self.conn_manager.remove(&node_id);
-                        self.conn_send_tx.remove(&node_id);
-                    }
+                    self.handle_new_connection(new_conn).await;
                 }
                 Some(res) = self.conn_tasks.join_next(), if !self.conn_tasks.is_empty() => {
                     match res {
@@ -435,11 +430,11 @@ impl Actor {
     async fn handle_to_actor_msg(&mut self, msg: ToActor, now: Instant) -> anyhow::Result<()> {
         trace!("handle to_actor {msg:?}");
         match msg {
-            ToActor::AcceptConn(conn) => {
-                if let Err(err) = self.conn_manager.push_accept(conn) {
-                    warn!(?err, "failed to accept connection");
-                }
-            }
+            ToActor::AcceptConn(conn) => match self.conn_manager.accept(conn) {
+                Err(err) => warn!(?err, "failed to accept connection"),
+                Ok(None) => {}
+                Ok(Some(conn)) => self.handle_new_connection(conn).await,
+            },
             ToActor::Join(topic_id, peers, reply) => {
                 self.handle_in_event(InEvent::Command(topic_id, Command::Join(peers)), now)
                     .await?;
@@ -550,7 +545,7 @@ impl Actor {
         Ok(())
     }

-    async fn handle_new_connection(&mut self, new_conn: NewConnection) -> anyhow::Result<()> {
+    async fn handle_new_connection(&mut self, new_conn: NewConnection) {
         let NewConnection {
             conn,
             node_id: peer_id,
@@ -582,8 +577,6 @@ impl Actor {
                 warn!(peer=%peer_id.fmt_short(), "connecting to node failed: {err:?}");
             }
         }
-
-        Ok(())
     }

     fn subscribe_all(&mut self) -> broadcast::Receiver<(TopicId, Event)> {
diff --git a/iroh-net/Cargo.toml b/iroh-net/Cargo.toml
index 9b52533a49..a20f119c72 100644
--- a/iroh-net/Cargo.toml
+++ b/iroh-net/Cargo.toml
@@ -55,6 +55,7 @@ rand = "0.8"
 rand_core = "0.6.4"
 rcgen = "0.11"
 reqwest = { version = "0.11.19", default-features = false, features = ["rustls-tls"] }
+rand_chacha = { version = "0.3.1", optional = true }
 ring = "0.17"
 rustls = { version = "0.21.11", default-features = false, features = ["dangerous_configuration"] }
 serde = { version = "1", features = ["derive", "rc"] }
@@ -124,7 +125,7 @@ duct = "0.13.6"
 default = ["metrics"]
 iroh-relay = ["clap", "toml", "rustls-pemfile", "regex", "serde_with", "tracing-subscriber"]
 metrics = ["iroh-metrics/metrics"]
-test-utils = ["axum"]
+test-utils = ["axum", "rand_chacha"]

 [[bin]]
 name = "iroh-relay"
diff --git a/iroh-net/src/dialer.rs b/iroh-net/src/dialer.rs
index 5c2fb28402..075dd1a425 100644
--- a/iroh-net/src/dialer.rs
+++ b/iroh-net/src/dialer.rs
@@ -1,7 +1,7 @@
 //! A dialer to dial nodes

 use std::{
-    collections::{HashMap, VecDeque},
+    collections::HashMap,
     pin::Pin,
     task::{ready, Poll},
 };
@@ -11,8 +11,7 @@ use crate::{
     key::PublicKey,
     Endpoint, NodeAddr, NodeId,
 };
-use anyhow::anyhow;
-use futures_lite::{future::Boxed as BoxFuture, Stream};
+use futures_lite::{future::Boxed as BoxFuture, Future, Stream};
 use tokio::task::JoinSet;
 use tokio_util::sync::CancellationToken;
 use tracing::error;
@@ -36,44 +35,69 @@ pub struct ConnManager {
     dialer: Dialer,
     alpn: &'static [u8],
-    newly_accepted: VecDeque<NodeId>,
-    active: HashMap<NodeId, quinn::Connection>,
+    active: HashMap<NodeId, quinn::Connection>,
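+    // Connections pushed into the accept sender are queued here until the manager's
+    // stream is polled, where they are deduplicated against our own in-flight dials.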
+    accept_tx: flume::Sender<quinn::Connection>,
+    accept_rx: flume::Receiver<quinn::Connection>,
 }

 impl ConnManager {
     /// Create a new connection manager.
     pub fn new(endpoint: Endpoint, alpn: &'static [u8]) -> Self {
         let dialer = Dialer::new(endpoint);
+        let (accept_tx, accept_rx) = flume::bounded(128);
         Self {
             dialer,
             alpn,
-            newly_accepted: Default::default(),
             active: Default::default(),
+            accept_tx,
+            accept_rx,
         }
     }

-    /// Push a newly accepted connection into the manager.
+    /// Get a sender to push new connections towards the [`ConnManager`].
     ///
     /// This does not check the connection's ALPN, so you should make sure that the ALPN matches
-    /// the [`ConnManager`]'s execpected ALPN before passing the connection to [`Self::accept].
+    /// the [`ConnManager`]'s expected ALPN before passing the connection to the sender.
     ///
     /// If we are currently dialing the node, the connection will be dropped if the peer's node id
-    /// sorty higher than our node id.
+    /// sorts lower than our node id. Otherwise, the connection will be yielded from the manager
+    /// stream.
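+    ///
+    /// # Example
+    ///
+    /// A rough sketch of feeding an endpoint's accepted connections into the manager;
+    /// the `endpoint` and `manager` bindings are illustrative:
+    ///
+    /// ```ignore
+    /// let sender = manager.accept_sender();
+    /// tokio::spawn(async move {
+    ///     while let Some(incoming) = endpoint.accept().await {
+    ///         if let Ok(conn) = incoming.await {
+    ///             // The ALPN is not checked here; verify it before sending.
+    ///             sender.send_async(conn).await.ok();
+    ///         }
+    ///     }
+    /// });
+    /// ```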
+    pub fn accept_sender(&self) -> flume::Sender<quinn::Connection> {
+        self.accept_tx.clone()
+    }
+
+    /// Accept a connection.
+    ///
+    /// This does not check the connection's ALPN, so you should make sure that the ALPN matches
+    /// the [`ConnManager`]'s expected ALPN before calling this method.
     ///
-    /// Returns an error if getting the peer's node id from the TLS certificate fails.
-    pub fn accept(&mut self, conn: quinn::Connection) -> anyhow::Result<()> {
+    /// If we are currently dialing the node, the connection will be dropped if the peer's node id
+    /// sorts lower than our node id. Otherwise, the connection will be returned.
+    ///
+    /// Returns an error if getting the peer's node id from the TLS certificate fails.
+    pub fn accept(&mut self, conn: quinn::Connection) -> anyhow::Result<Option<NewConnection>> {
         let node_id = get_remote_node_id(&conn)?;
+        tracing::info!(me=%self.our_node_id().fmt_short(), peer=%node_id.fmt_short(), is_connected=self.is_connected(&node_id), is_dialing=self.is_dialing(&node_id), "incoming accept");
+        // We are already connected: drop the connection, keep using the existing conn.
         if self.is_connected(&node_id) {
-            return Ok(());
+            return Ok(None);
         }
-        // If we are also dialing this node, only accept if node id is greater than ours.
-        // this deduplicates connections.
-        if !self.dialer.is_pending(&node_id) || node_id > self.our_node_id() {
+
+        // If we are currently dialing as well, only accept if the peer's node id sorts
+        // higher than ours. This deduplicates the two connections.
+        if !self.is_dialing(&node_id) || node_id > self.our_node_id() {
             self.dialer.abort_dial(&node_id);
             self.active.insert(node_id, conn.clone());
-            self.newly_accepted.push_back(node_id);
+            tracing::info!(me=%self.our_node_id().fmt_short(), peer=%node_id.fmt_short(), "accept: OK, our dial aborted");
+            let c = NewConnection {
+                conn: Ok(conn.clone()),
+                node_id,
+                direction: ConnDirection::Accept,
+            };
+            Ok(Some(c))
+        } else {
+            conn.close(0u32.into(), b"prefer_ours");
+            tracing::info!(me=%self.our_node_id().fmt_short(), peer=%node_id.fmt_short(), "accept: drop");
+            Ok(None)
         }
-        Ok(())
     }

     /// Remove the connection to a node.
@@ -124,18 +148,37 @@ impl Stream for ConnManager {
         mut self: Pin<&mut Self>,
         cx: &mut std::task::Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        while let Some(node_id) = self.newly_accepted.pop_front() {
-            if let Some(conn) = self.active.get(&node_id) {
-                let c = NewConnection {
-                    conn: Ok(conn.clone()),
-                    node_id,
-                    direction: ConnDirection::Accept,
-                };
-                return Poll::Ready(Some(c));
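+        // Drain the accept channel first: every connection pushed via the accept
+        // sender runs through `accept` for deduplication before being yielded.
+        // Only when the channel has nothing ready do we poll the dialer below.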
+        loop {
+            let conn = loop {
+                let recv = self.accept_rx.recv_async();
+                tokio::pin!(recv);
+                match recv.poll(cx) {
+                    Poll::Pending => break None,
+                    Poll::Ready(Ok(conn)) => break Some(conn),
+                    Poll::Ready(Err(_err)) => {
+                        continue;
+                    }
+                }
+            };
+            match conn {
+                Some(conn) => match self.as_mut().accept(conn) {
+                    Err(err) => {
+                        tracing::warn!("dropping invalid connection: {err:?}");
+                        continue;
+                    }
+                    Ok(None) => continue,
+                    Ok(Some(new_conn)) => {
+                        return Poll::Ready(Some(new_conn));
+                    }
+                },
+                None => break,
             }
         }
+
         if let Some((node_id, conn)) = ready!(Pin::new(&mut self.dialer).poll_next(cx)) {
-            if !self.active.contains_key(&node_id) {
+            tracing::info!(me=%self.our_node_id().fmt_short(), peer=%node_id.fmt_short(), is_connected=self.is_connected(&node_id), is_dialing=self.is_dialing(&node_id), success=conn.is_ok(), "incoming dial");
+            if !self.is_connected(&node_id) {
                 if let Ok(conn) = &conn {
                     self.active.insert(node_id, conn.clone());
                 }
@@ -147,7 +190,7 @@ impl Stream for ConnManager {
                 return Poll::Ready(Some(c));
             }
         }
-        Poll::Ready(None)
+        Poll::Pending
     }
 }

@@ -180,7 +223,7 @@ pub struct NewConnection {
 #[derive(Debug)]
 pub struct Dialer {
     endpoint: Endpoint,
-    pending: JoinSet<(PublicKey, anyhow::Result<quinn::Connection>)>,
+    pending: JoinSet<(PublicKey, Option<anyhow::Result<quinn::Connection>>)>,
     pending_dials: HashMap<NodeId, CancellationToken>,
 }

@@ -205,19 +248,26 @@ impl Dialer {
         let cancel = CancellationToken::new();
         self.pending_dials.insert(node_id, cancel.clone());
         let endpoint = self.endpoint.clone();
+        let me = endpoint.node_id();
         self.pending.spawn(async move {
-            let res = tokio::select! {
+            tokio::select! {
                 biased;
-                _ = cancel.cancelled() => Err(anyhow!("Cancelled")),
-                res = endpoint.connect(NodeAddr::new(node_id), alpn) => res
-            };
-            (node_id, res)
+                _ = cancel.cancelled() => {
+                    tracing::info!(me=%me.fmt_short(), peer=%node_id.fmt_short(), "dial cancelled");
+                    (node_id, None)
+                }
+                res = endpoint.connect(NodeAddr::new(node_id), alpn) => {
+                    tracing::info!(me=%me.fmt_short(), peer=%node_id.fmt_short(), success=res.is_ok(), "dial complete");
+                    (node_id, Some(res))
+                }
+            }
         });
     }

     /// Abort a pending dial
     pub fn abort_dial(&mut self, node_id: &NodeId) {
         if let Some(cancel) = self.pending_dials.remove(node_id) {
+            tracing::info!(me=%self.endpoint.node_id().fmt_short(), peer=%node_id.fmt_short(), "abort dial");
             cancel.cancel();
         }
     }
@@ -227,32 +277,6 @@ impl Dialer {
     pub fn is_pending(&self, node: &NodeId) -> bool {
         self.pending_dials.contains_key(node)
     }

-    /// Wait for the next dial operation to complete
-    pub async fn next_conn(&mut self) -> (PublicKey, anyhow::Result<quinn::Connection>) {
-        match self.pending_dials.is_empty() {
-            false => {
-                let (node_id, res) = loop {
-                    match self.pending.join_next().await {
-                        Some(Ok((node_id, res))) => {
-                            self.pending_dials.remove(&node_id);
-                            break (node_id, res);
-                        }
-                        Some(Err(e)) => {
-                            error!("next conn error: {:?}", e);
-                        }
-                        None => {
-                            error!("no more pending conns available");
-                            std::future::pending().await
-                        }
-                    }
-                };
-
-                (node_id, res)
-            }
-            true => std::future::pending().await,
-        }
-    }
-
     /// Number of pending connections to be opened.
     pub fn pending_count(&self) -> usize {
         self.pending_dials.len()
     }
@@ -266,19 +290,101 @@ impl Stream for Dialer {
         mut self: Pin<&mut Self>,
         cx: &mut std::task::Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        match self.pending.poll_join_next(cx) {
-            Poll::Ready(Some(Ok((node_id, result)))) => {
-                self.pending_dials.remove(&node_id);
-                Poll::Ready(Some((node_id, result)))
-            }
-            Poll::Ready(Some(Err(e))) => {
-                error!("dialer error: {:?}", e);
-                Poll::Pending
+        loop {
+            match self.pending.poll_join_next(cx) {
+                Poll::Ready(Some(Ok((node_id, result)))) => {
+                    self.pending_dials.remove(&node_id);
+                    match result {
+                        // `None` means the dial was cancelled: skip it and poll again.
+                        None => continue,
+                        Some(result) => return Poll::Ready(Some((node_id, result))),
+                    }
+                }
+                Poll::Ready(Some(Err(e))) => {
+                    error!("dialer error: {:?}", e);
+                    continue;
+                }
+                _ => return Poll::Pending,
             }
-            _ => Poll::Pending,
         }
     }
 }

 /// Future for a pending dial operation
 pub type DialFuture = BoxFuture<(PublicKey, anyhow::Result<quinn::Connection>)>;
+
+#[cfg(test)]
+mod tests {
+    use std::time::Duration;
+
+    use futures_lite::StreamExt;
+    use tokio::task::JoinSet;
+
+    use crate::{dialer::ConnManager, endpoint::Connection, test_utils::TestEndpointFactory};
+
+    const TEST_ALPN: &[u8] = b"test";
+
+    async fn accept_loop(ep: crate::Endpoint, tx: flume::Sender<Connection>) -> anyhow::Result<()> {
+        while let Some(conn) = ep.accept().await {
+            let conn = conn.await?;
+            tx.send_async(conn).await?;
+        }
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_conn_manager() -> anyhow::Result<()> {
+        let _guard = iroh_test::logging::setup();
+        let mut factory = TestEndpointFactory::run().await?;
+
+        let alpns = vec![TEST_ALPN.to_vec()];
+        let ep1 = factory.create_endpoint(alpns.clone()).await?;
+        let ep2 = factory.create_endpoint(alpns.clone()).await?;
+        let n1 = ep1.node_id();
+        let n2 = ep2.node_id();
+        tracing::info!(?n1, ?n2, "endpoints created");
+        factory.on_node(&n1, Duration::from_secs(2)).await?;
+        factory.on_node(&n2, Duration::from_secs(2)).await?;
+
+        let mut conn_manager1 = ConnManager::new(ep1.clone(), TEST_ALPN);
+        let mut conn_manager2 = ConnManager::new(ep2.clone(), TEST_ALPN);
+
+        let accept1 = conn_manager1.accept_sender();
+        let accept2 = conn_manager2.accept_sender();
+        let mut tasks = JoinSet::new();
+        tasks.spawn(accept_loop(ep1, accept1));
+        tasks.spawn(accept_loop(ep2, accept2));
+
+        for i in 0u8..20 {
+            tracing::info!(i, "start dial");
+            conn_manager1.dial(n2);
+            conn_manager2.dial(n1);
+            let (conn1, conn2) = tokio::join!(conn_manager1.next(), conn_manager2.next());
+            let conn1 = conn1.unwrap();
+            let conn2 = conn2.unwrap();
+
+            tracing::info!(?conn1.direction, "conn1");
+            tracing::info!(?conn2.direction, "conn2");
+            assert!(conn1.direction != conn2.direction);
+            assert_eq!(conn1.node_id, n2);
+            assert_eq!(conn2.node_id, n1);
+            let conn1 = conn1.conn.unwrap();
+            let conn2 = conn2.conn.unwrap();
+            let mut s1 = conn1.open_uni().await.unwrap();
+            s1.write_all(&[i]).await?;
+            s1.finish().await?;
+            let mut s2 = conn2.accept_uni().await.unwrap();
+            let x = s2.read_to_end(1).await.unwrap();
+            assert_eq!(&x, &[i]);
+            assert!(conn_manager1.remove(&n2).is_some());
+            assert!(conn_manager2.remove(&n1).is_some());
+        }
+
+        tasks.abort_all();
+        while let Some(r) = tasks.join_next().await {
+            assert!(r.unwrap_err().is_cancelled());
+        }
+
+        Ok(())
+    }
+}
diff --git a/iroh-net/src/test_utils.rs b/iroh-net/src/test_utils.rs
index 0cbf8bd857..78891ec96a 100644
--- a/iroh-net/src/test_utils.rs
+++ b/iroh-net/src/test_utils.rs
@@ -1,12 +1,14 @@
 //! Internal utilities to support testing.

 use anyhow::Result;
+use rand::SeedableRng;
 use tokio::sync::oneshot;
 use tracing::{error_span, info_span, Instrument};

 use crate::{
     key::SecretKey,
-    relay::{RelayMap, RelayNode, RelayUrl},
+    relay::{RelayMap, RelayMode, RelayNode, RelayUrl},
+    Endpoint,
 };

 pub use dns_and_pkarr_servers::DnsPkarrServer;
@@ -68,6 +70,72 @@ pub async fn run_relay_server() -> Result<(RelayMap, RelayUrl, CleanupDropGuard)> {
     Ok((m, url, CleanupDropGuard(tx)))
 }

+/// A factory for endpoints with preconfigured local discovery and relays.
+#[derive(Debug)]
+pub struct TestEndpointFactory {
+    relay_map: RelayMap,
+    relay_url: RelayUrl,
+    _relay_drop_guard: CleanupDropGuard,
+    dns_pkarr_server: DnsPkarrServer,
+    rng: rand_chacha::ChaCha12Rng,
+}
+
+impl TestEndpointFactory {
+    /// Starts local relay and discovery servers and returns a [`TestEndpointFactory`] to create
+    /// readily configured endpoints.
+    ///
+    /// The local servers will shut down once the [`TestEndpointFactory`] is dropped.
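+    ///
+    /// # Example
+    ///
+    /// A rough sketch of a test setup built on this factory; the `b"my-alpn"` ALPN is
+    /// illustrative, and error handling is elided:
+    ///
+    /// ```ignore
+    /// let mut factory = TestEndpointFactory::run().await?;
+    /// let ep = factory.create_endpoint(vec![b"my-alpn".to_vec()]).await?;
+    /// factory.on_node(&ep.node_id(), Duration::from_secs(2)).await?;
+    /// ```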
+    pub async fn run() -> anyhow::Result<Self> {
+        let dns_pkarr_server = DnsPkarrServer::run().await?;
+        let (relay_map, relay_url, relay_drop_guard) = run_relay_server().await?;
+        Ok(Self {
+            relay_map,
+            relay_url,
+            dns_pkarr_server,
+            _relay_drop_guard: relay_drop_guard,
+            rng: rand_chacha::ChaCha12Rng::seed_from_u64(77),
+        })
+    }
+
+    /// Create a new endpoint builder which already has discovery and relays configured.
+    pub fn create_endpoint_builder(&self, secret_key: SecretKey) -> crate::endpoint::Builder {
+        Endpoint::builder()
+            .secret_key(secret_key.clone())
+            .relay_mode(RelayMode::Custom(self.relay_map.clone()))
+            .insecure_skip_relay_cert_verify(true)
+            .dns_resolver(self.dns_pkarr_server.dns_resolver())
+            .discovery(self.dns_pkarr_server.discovery(secret_key))
+    }
+
+    /// Create a new endpoint with the specified ALPNs.
+    ///
+    /// The endpoint will have discovery and relays configured, and have a predictable secret key.
+    pub async fn create_endpoint(&mut self, alpns: Vec<Vec<u8>>) -> anyhow::Result<Endpoint> {
+        let secret_key = SecretKey::generate_with_rng(&mut self.rng);
+        self.create_endpoint_builder(secret_key)
+            .alpns(alpns)
+            .bind(0)
+            .await
+    }
+
+    /// Returns the URL of the local relay server.
+    pub fn relay_url(&self) -> &RelayUrl {
+        &self.relay_url
+    }
+
+    /// Returns a relay map which contains the local relay server.
+    pub fn relay_map(&self) -> &RelayMap {
+        &self.relay_map
+    }
+
+    /// Wait until a Pkarr announce for a node is published to the server.
+    ///
+    /// If `timeout` elapses before the announce is published, an error is returned.
+    pub async fn on_node(&self, node_id: &crate::NodeId, timeout: std::time::Duration) -> Result<()> {
+        self.dns_pkarr_server.on_node(node_id, timeout).await
+    }
+}
+
 pub(crate) mod dns_and_pkarr_servers {
     use anyhow::Result;
     use iroh_base::key::{NodeId, SecretKey};