diff --git a/Cargo.lock b/Cargo.lock index 31aac2e02a..02f388527e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2700,6 +2700,7 @@ dependencies = [ "derive_more", "ed25519-dalek", "futures-lite 2.3.0", + "futures-util", "genawaiter", "indexmap 2.2.6", "iroh-base", @@ -2808,7 +2809,7 @@ dependencies = [ "rand_core", "rcgen 0.11.3", "regex", - "reqwest 0.12.4", + "reqwest 0.11.27", "ring 0.17.8", "rtnetlink", "rustls 0.21.12", diff --git a/iroh-gossip/Cargo.toml b/iroh-gossip/Cargo.toml index de0c0d2b5b..99505a84e5 100644 --- a/iroh-gossip/Cargo.toml +++ b/iroh-gossip/Cargo.toml @@ -32,6 +32,7 @@ iroh-base = { version = "0.16.0", path = "../iroh-base" } # net dependencies (optional) futures-lite = { version = "2.3", optional = true } +futures-util = { version = "0.3.30", optional = true } iroh-net = { path = "../iroh-net", version = "0.16.0", optional = true, default-features = false, features = ["test-utils"] } tokio = { version = "1", optional = true, features = ["io-util", "sync", "rt", "macros", "net", "fs"] } tokio-util = { version = "0.7.8", optional = true, features = ["codec"] } @@ -46,7 +47,7 @@ url = "2.4.0" [features] default = ["net"] -net = ["dep:futures-lite", "dep:iroh-net", "dep:tokio", "dep:tokio-util"] +net = ["dep:futures-lite", "dep:futures-util", "dep:iroh-net", "dep:tokio", "dep:tokio-util"] [[example]] name = "chat" diff --git a/iroh-gossip/src/net.rs b/iroh-gossip/src/net.rs index 4083e3a113..29b43ae32f 100644 --- a/iroh-gossip/src/net.rs +++ b/iroh-gossip/src/net.rs @@ -2,11 +2,12 @@ use anyhow::{anyhow, Context}; use bytes::{Bytes, BytesMut}; -use futures_lite::stream::Stream; +use futures_lite::{stream::Stream, StreamExt}; +use futures_util::future::FutureExt; use genawaiter::sync::{Co, Gen}; use iroh_net::{ - dialer::Dialer, - endpoint::{get_remote_node_id, Connection}, + conn_manager::{ConnDirection, ConnInfo, ConnManager}, + endpoint::Connection, key::PublicKey, AddrInfo, Endpoint, NodeAddr, }; @@ -15,7 +16,7 @@ use rand_core::SeedableRng; use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc, task::Poll, time::Instant}; use tokio::{ sync::{broadcast, mpsc, oneshot}, - task::JoinHandle, + task::{JoinHandle, JoinSet}, }; use tracing::{debug, error_span, trace, warn, Instrument}; @@ -82,7 +83,7 @@ impl Gossip { /// Spawn a gossip actor and get a handle for it pub fn from_endpoint(endpoint: Endpoint, config: proto::Config, my_addr: &AddrInfo) -> Self { let peer_id = endpoint.node_id(); - let dialer = Dialer::new(endpoint.clone()); + let conn_manager = ConnManager::new(endpoint.clone(), GOSSIP_ALPN); let state = proto::State::new( peer_id, encode_peer_data(my_addr).unwrap(), @@ -97,12 +98,12 @@ impl Gossip { let actor = Actor { endpoint, state, - dialer, + conn_manager, + conn_tasks: Default::default(), to_actor_rx, in_event_rx, in_event_tx, on_endpoints_rx, - conns: Default::default(), conn_send_tx: Default::default(), pending_sends: Default::default(), timers: Timers::new(), @@ -231,9 +232,7 @@ impl Gossip { /// /// Make sure to check the ALPN protocol yourself before passing the connection. 
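+    ///
+    /// A minimal accept-loop sketch (an assumption for illustration: the `Endpoint` is
+    /// configured with only [`GOSSIP_ALPN`], so no per-connection ALPN check is needed):
+    ///
+    /// ```ignore
+    /// while let Some(conn) = endpoint.accept().await {
+    ///     let conn = conn.await?;
+    ///     gossip.handle_connection(conn).await?;
+    /// }
+    /// ```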
pub async fn handle_connection(&self, conn: Connection) -> anyhow::Result<()> { - let peer_id = get_remote_node_id(&conn)?; - self.send(ToActor::ConnIncoming(peer_id, ConnOrigin::Accept, conn)) - .await?; + self.send(ToActor::ConnIncoming(conn)).await?; Ok(()) } @@ -283,19 +282,11 @@ impl Future for JoinTopicFut { } } -/// Whether a connection is initiated by us (Dial) or by the remote peer (Accept) -#[derive(Debug)] -enum ConnOrigin { - Accept, - Dial, -} - /// Input messages for the gossip [`Actor`]. #[derive(derive_more::Debug)] enum ToActor { - /// Handle a new QUIC connection, either from accept (external to the actor) or from connect - /// (happens internally in the actor). - ConnIncoming(PublicKey, ConnOrigin, #[debug(skip)] Connection), + /// Handle a new incoming QUIC connection. + ConnIncoming(iroh_net::endpoint::Connection), /// Join a topic with a list of peers. Reply with oneshot once at least one peer joined. Join( TopicId, @@ -329,8 +320,8 @@ struct Actor { /// Protocol state state: proto::State, endpoint: Endpoint, - /// Dial machine to connect to peers - dialer: Dialer, + /// Connection manager to dial and accept connections. + conn_manager: ConnManager, /// Input messages to the actor to_actor_rx: mpsc::Receiver, /// Sender for the state input (cloned into the connection loops) @@ -341,10 +332,10 @@ struct Actor { on_endpoints_rx: mpsc::Receiver>, /// Queued timers timers: Timers, - /// Currently opened quinn connections to peers - conns: HashMap, /// Channels to send outbound messages into the connection loops conn_send_tx: HashMap>, + /// Connection loop tasks + conn_tasks: JoinSet<(PublicKey, anyhow::Result<()>)>, /// Queued messages that were to be sent before a dial completed pending_sends: HashMap>, /// Broadcast senders for active topic subscriptions from the application @@ -353,6 +344,12 @@ struct Actor { subscribers_all: Option>, } +impl Drop for Actor { + fn drop(&mut self) { + self.conn_tasks.abort_all(); + } +} + impl Actor { pub async fn run(mut self) -> anyhow::Result<()> { let mut i = 0; @@ -384,15 +381,27 @@ impl Actor { } } } - (peer_id, res) = self.dialer.next_conn() => { - trace!(?i, "tick: dialer"); + Some(res) = self.conn_manager.next() => { + trace!(?i, "tick: conn_manager"); match res { - Ok(conn) => { - debug!(peer = ?peer_id, "dial successful"); - self.handle_to_actor_msg(ToActor::ConnIncoming(peer_id, ConnOrigin::Dial, conn), Instant::now()).await.context("dialer.next -> conn -> handle_to_actor_msg")?; - } + Ok(conn) => self.handle_new_connection(conn).await, Err(err) => { - warn!(peer = ?peer_id, "dial failed: {err}"); + self.handle_in_event(InEvent::PeerDisconnected(err.node_id), Instant::now()).await?; + } + } + } + Some(res) = self.conn_tasks.join_next(), if !self.conn_tasks.is_empty() => { + match res { + Err(err) if !err.is_cancelled() => warn!(?err, "connection loop panicked"), + Err(_err) => {}, + Ok((node_id, result)) => { + self.conn_manager.remove(&node_id); + self.conn_send_tx.remove(&node_id); + self.handle_in_event(InEvent::PeerDisconnected(node_id), Instant::now()).await?; + match result { + Ok(()) => debug!(peer=%node_id.fmt_short(), "connection closed without error"), + Err(err) => debug!(peer=%node_id.fmt_short(), "connection closed with error {err:?}"), + } } } } @@ -421,38 +430,9 @@ impl Actor { async fn handle_to_actor_msg(&mut self, msg: ToActor, now: Instant) -> anyhow::Result<()> { trace!("handle to_actor {msg:?}"); match msg { - ToActor::ConnIncoming(peer_id, origin, conn) => { - self.conns.insert(peer_id, conn.clone()); 
- self.dialer.abort_dial(&peer_id); - let (send_tx, send_rx) = mpsc::channel(SEND_QUEUE_CAP); - self.conn_send_tx.insert(peer_id, send_tx.clone()); - - // Spawn a task for this connection - let in_event_tx = self.in_event_tx.clone(); - tokio::spawn( - async move { - debug!("connection established"); - match connection_loop(peer_id, conn, origin, send_rx, &in_event_tx).await { - Ok(()) => { - debug!("connection closed without error") - } - Err(err) => { - debug!("connection closed with error {err:?}") - } - } - in_event_tx - .send(InEvent::PeerDisconnected(peer_id)) - .await - .ok(); - } - .instrument(error_span!("gossip_conn", peer = %peer_id.fmt_short())), - ); - - // Forward queued pending sends - if let Some(send_queue) = self.pending_sends.remove(&peer_id) { - for msg in send_queue { - send_tx.send(msg).await?; - } + ToActor::ConnIncoming(conn) => { + if let Err(err) = self.conn_manager.accept(conn) { + warn!(?err, "failed to accept connection"); } } ToActor::Join(topic_id, peers, reply) => { @@ -502,9 +482,6 @@ impl Actor { } else { debug!("handle in_event {event:?}"); }; - if let InEvent::PeerDisconnected(peer) = &event { - self.conn_send_tx.remove(peer); - } let out = self.state.handle(event, now); for event in out { if matches!(event, OutEvent::ScheduleTimer(_, _)) { @@ -518,10 +495,13 @@ impl Actor { if let Err(_err) = send.send(message).await { warn!("conn receiver for {peer_id:?} dropped"); self.conn_send_tx.remove(&peer_id); + self.conn_manager.remove(&peer_id); } } else { - debug!(peer = ?peer_id, "dial"); - self.dialer.queue_dial(peer_id, GOSSIP_ALPN); + if !self.conn_manager.is_pending(&peer_id) { + debug!(peer = ?peer_id, "dial"); + self.conn_manager.dial(peer_id); + } // TODO: Enforce max length self.pending_sends.entry(peer_id).or_default().push(message); } @@ -544,12 +524,11 @@ impl Actor { self.timers.insert(now + delay, timer); } OutEvent::DisconnectPeer(peer) => { - if let Some(conn) = self.conns.remove(&peer) { - conn.close(0u8.into(), b"close from disconnect"); - } self.conn_send_tx.remove(&peer); self.pending_sends.remove(&peer); - self.dialer.abort_dial(&peer); + if let Some(conn) = self.conn_manager.remove(&peer) { + conn.close(0u8.into(), b"close from disconnect"); + } } OutEvent::PeerData(node_id, data) => match decode_peer_data(&data) { Err(err) => warn!("Failed to decode {data:?} from {node_id}: {err}"), @@ -566,6 +545,33 @@ impl Actor { Ok(()) } + async fn handle_new_connection(&mut self, new_conn: ConnInfo) { + let ConnInfo { + conn, + node_id, + direction, + } = new_conn; + let (send_tx, send_rx) = mpsc::channel(SEND_QUEUE_CAP); + self.conn_send_tx.insert(node_id, send_tx.clone()); + + // Spawn a task for this connection + let pending_sends = self.pending_sends.remove(&node_id); + let in_event_tx = self.in_event_tx.clone(); + debug!(peer=%node_id.fmt_short(), ?direction, "connection established"); + self.conn_tasks.spawn( + connection_loop( + node_id, + conn, + direction, + send_rx, + in_event_tx, + pending_sends, + ) + .map(move |r| (node_id, r)) + .instrument(error_span!("gossip_conn", peer = %node_id.fmt_short())), + ); + } + fn subscribe_all(&mut self) -> broadcast::Receiver<(TopicId, Event)> { if let Some(tx) = self.subscribers_all.as_mut() { tx.subscribe() @@ -602,16 +608,26 @@ async fn wait_for_neighbor_up(mut sub: broadcast::Receiver) -> anyhow::Re async fn connection_loop( from: PublicKey, conn: Connection, - origin: ConnOrigin, + direction: ConnDirection, mut send_rx: mpsc::Receiver, - in_event_tx: &mpsc::Sender, + in_event_tx: mpsc::Sender, + 
mut pending_sends: Option<Vec<ProtoMessage>>,
 ) -> anyhow::Result<()> {
-    let (mut send, mut recv) = match origin {
-        ConnOrigin::Accept => conn.accept_bi().await?,
-        ConnOrigin::Dial => conn.open_bi().await?,
+    let (mut send, mut recv) = match direction {
+        ConnDirection::Accept => conn.accept_bi().await?,
+        ConnDirection::Dial => conn.open_bi().await?,
     };
     let mut send_buf = BytesMut::new();
     let mut recv_buf = BytesMut::new();
+
+    // Forward sends that were queued before the connection was established.
+    if let Some(mut send_queue) = pending_sends.take() {
+        for msg in send_queue.drain(..) {
+            write_message(&mut send, &mut send_buf, &msg).await?;
+        }
+    }
+
+    // Loop over sending and receiving messages.
     loop {
         tokio::select! {
             biased;
diff --git a/iroh-net/Cargo.toml b/iroh-net/Cargo.toml
index 9d0f7cdc36..aedfcef97b 100644
--- a/iroh-net/Cargo.toml
+++ b/iroh-net/Cargo.toml
@@ -53,9 +53,10 @@ quinn = { package = "iroh-quinn", version = "0.10.4" }
 quinn-proto = { package = "iroh-quinn-proto", version = "0.10.7" }
 quinn-udp = { package = "iroh-quinn-udp", version = "0.4" }
 rand = "0.8"
+rand_chacha = { version = "0.3.1", optional = true }
 rand_core = "0.6.4"
 rcgen = "0.11"
-reqwest = { version = "0.12.4", default-features = false, features = ["rustls-tls"] }
+reqwest = { version = "0.11.19", default-features = false, features = ["rustls-tls"] }
 ring = "0.17"
 rustls = { version = "0.21.11", default-features = false, features = ["dangerous_configuration"] }
 serde = { version = "1", features = ["derive", "rc"] }
@@ -125,7 +126,7 @@ duct = "0.13.6"
 default = ["metrics"]
 iroh-relay = ["clap", "toml", "rustls-pemfile", "regex", "serde_with", "tracing-subscriber"]
 metrics = ["iroh-metrics/metrics"]
-test-utils = ["axum"]
+test-utils = ["axum", "rand_chacha"]
 
 [[bin]]
 name = "iroh-relay"
diff --git a/iroh-net/src/conn_manager.rs b/iroh-net/src/conn_manager.rs
new file mode 100644
index 0000000000..4b9828558b
--- /dev/null
+++ b/iroh-net/src/conn_manager.rs
@@ -0,0 +1,440 @@
+//! A connection manager to ensure a single connection between each pair of peers.
+
+use std::{
+    collections::HashMap,
+    pin::Pin,
+    task::{ready, Context, Poll, Waker},
+};
+
+use futures_lite::{Future, Stream};
+use futures_util::FutureExt;
+use tokio::{
+    sync::mpsc,
+    task::{AbortHandle, JoinSet},
+};
+use tracing::{debug, error};
+
+use crate::{
+    endpoint::{get_remote_node_id, Connection},
+    Endpoint, NodeId,
+};
+
+const DUPLICATE_REASON: &[u8] = b"abort_duplicate";
+const DUPLICATE_CODE: u32 = 123;
+
+/// A connection manager.
+///
+/// The [`ConnManager`] does not accept connections from the endpoint by itself. Instead, you
+/// should run an accept loop yourself, and push connections with a matching ALPN into the manager
+/// with [`ConnManager::accept`]. The connection will be dropped if we already have a connection to
+/// that node. If we are currently dialing the node, an incoming connection will only be accepted
+/// if the peer's node id sorts higher than our node id. This tie-break ensures that two nodes do
+/// not end up with duplicate connections if they both dial each other at the same time.
+///
+/// The [`ConnManager`] implements [`Stream`]. It will yield new connections, both from dialing and
+/// accepting.
+#[derive(Debug)]
+pub struct ConnManager {
+    endpoint: Endpoint,
+    alpn: &'static [u8],
+    active: HashMap<NodeId, ConnInfo>,
+    pending: HashMap<NodeId, PendingState>,
+    tasks: JoinSet<(NodeId, Result<Connection, InitError>)>,
+    accept_tx: mpsc::Sender<Connection>,
+    accept_rx: mpsc::Receiver<Connection>,
+    waker: Option<Waker>,
+}
+
+impl ConnManager {
+    /// Create a new connection manager.
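+    ///
+    /// A usage sketch (assumptions: `endpoint` is a bound [`Endpoint`] whose ALPN list
+    /// includes the hypothetical `MY_ALPN`, and `futures_lite::StreamExt` is in scope):
+    ///
+    /// ```ignore
+    /// let mut manager = ConnManager::new(endpoint.clone(), MY_ALPN);
+    /// // Dialing is idempotent: a no-op if we are already dialing or connected.
+    /// manager.dial(node_id);
+    /// // The manager is a `Stream` over newly established connections, whether
+    /// // they resulted from our dial or from an accepted incoming connection.
+    /// if let Some(info) = manager.next().await.transpose()? {
+    ///     println!("connected to {} ({:?})", info.node_id, info.direction);
+    /// }
+    /// ```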
+    pub fn new(endpoint: Endpoint, alpn: &'static [u8]) -> Self {
+        let (accept_tx, accept_rx) = mpsc::channel(128);
+        Self {
+            endpoint,
+            alpn,
+            active: Default::default(),
+            accept_tx,
+            accept_rx,
+            tasks: JoinSet::new(),
+            pending: HashMap::new(),
+            waker: None,
+        }
+    }
+
+    /// Start to dial a node.
+    ///
+    /// This is a no-op if a connection to the node is already active or if we are already
+    /// dialing the node.
+    ///
+    /// Returns `true` if this initiates connecting to the node.
+    pub fn dial(&mut self, node_id: NodeId) -> bool {
+        if self.is_pending(&node_id) || self.is_connected(&node_id) {
+            false
+        } else {
+            self.spawn(
+                node_id,
+                ConnDirection::Dial,
+                connect_task(self.endpoint.clone(), node_id, self.alpn),
+            );
+            true
+        }
+    }
+
+    fn spawn(
+        &mut self,
+        node_id: NodeId,
+        direction: ConnDirection,
+        fut: impl Future<Output = Result<Connection, InitError>> + Send + 'static,
+    ) {
+        let abort_handle = self.tasks.spawn(fut.map(move |res| (node_id, res)));
+        let pending_state = PendingState {
+            direction,
+            abort_handle,
+        };
+        self.pending.insert(node_id, pending_state);
+        if let Some(waker) = self.waker.take() {
+            waker.wake();
+        }
+    }
+
+    /// Get a sender to push new connections towards the [`ConnManager`].
+    ///
+    /// This does not check the connection's ALPN, so you should make sure that the ALPN matches
+    /// the [`ConnManager`]'s expected ALPN before passing the connection to the sender.
+    ///
+    /// If we are currently dialing the node, the connection will be dropped if the peer's node id
+    /// sorts lower than our node id. Otherwise, the connection will be yielded from the manager
+    /// stream.
+    pub fn accept_sender(&self) -> AcceptSender {
+        let tx = self.accept_tx.clone();
+        AcceptSender { tx }
+    }
+
+    /// Accept a connection.
+    ///
+    /// This does not check the connection's ALPN, so you should make sure that the ALPN matches
+    /// the [`ConnManager`]'s expected ALPN before passing the connection to the sender.
+    ///
+    /// If we are currently dialing the node, the connection will be dropped if the peer's node id
+    /// sorts lower than our node id. Otherwise, the connection will be accepted and later yielded
+    /// from the manager stream.
+    pub fn accept(&mut self, conn: quinn::Connection) -> anyhow::Result<()> {
+        let node_id = get_remote_node_id(&conn)?;
+        // We are already connected: drop the connection, keep using the existing conn.
+        if self.is_connected(&node_id) {
+            return Ok(());
+        }
+
+        let accept = match self.pending.get(&node_id) {
+            // We are currently dialing the node, but the incoming conn "wins": accept and abort
+            // our dial.
+            Some(state)
+                if state.direction == ConnDirection::Dial && node_id > self.our_node_id() =>
+            {
+                state.abort_handle.abort();
+                true
+            }
+            // We are currently processing a connection for this node: do not accept a second conn.
+            Some(_state) => false,
+            // The connection is new: accept.
+            None => true,
+        };
+
+        if accept {
+            self.spawn(node_id, ConnDirection::Accept, accept_task(conn));
+        } else {
+            conn.close(DUPLICATE_CODE.into(), DUPLICATE_REASON);
+        }
+        Ok(())
+    }
+
+    /// Remove the connection to a node.
+    ///
+    /// Also aborts a pending dial to the node, if any.
+    ///
+    /// Returns the connection if it existed.
+    pub fn remove(&mut self, node_id: &NodeId) -> Option<ConnInfo> {
+        if let Some(state) = self.pending.remove(node_id) {
+            state.abort_handle.abort();
+        }
+        self.active.remove(node_id)
+    }
+
+    /// Returns the connection to a node, if connected.
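+    ///
+    /// A lookup sketch (`node_id` is a hypothetical peer we may be connected to):
+    ///
+    /// ```ignore
+    /// if let Some(info) = manager.get(&node_id) {
+    ///     println!("already connected ({:?})", info.direction);
+    /// }
+    /// ```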
+    pub fn get(&self, node_id: &NodeId) -> Option<&ConnInfo> {
+        self.active.get(node_id)
+    }
+
+    /// Returns `true` if we are currently establishing a connection to the node.
+    pub fn is_pending(&self, node_id: &NodeId) -> bool {
+        self.pending.contains_key(node_id)
+    }
+
+    /// Returns `true` if we are connected to the node.
+    pub fn is_connected(&self, node_id: &NodeId) -> bool {
+        self.active.contains_key(node_id)
+    }
+
+    fn our_node_id(&self) -> NodeId {
+        self.endpoint.node_id()
+    }
+}
+
+impl Stream for ConnManager {
+    type Item = Result<ConnInfo, ConnectError>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        // Create new tasks for incoming connections.
+        while let Poll::Ready(Some(conn)) = Pin::new(&mut self.accept_rx).poll_recv(cx) {
+            if let Err(error) = self.accept(conn) {
+                tracing::warn!(?error, "skipping invalid connection attempt");
+            }
+        }
+
+        // Poll for finished connection tasks.
+        loop {
+            let join_res = ready!(self.tasks.poll_join_next(cx));
+            debug!(?join_res, "connection task joined");
+            let (node_id, res) = match join_res {
+                None => {
+                    self.waker = Some(cx.waker().to_owned());
+                    return Poll::Pending;
+                }
+                Some(Err(err)) if err.is_cancelled() => continue,
+                // We are merely forwarding a panic here, which should never occur.
+                Some(Err(err)) => panic!("connection manager task panicked with {err:?}"),
+                Some(Ok(res)) => res,
+            };
+            match res {
+                Err(InitError::IsDuplicate) => continue,
+                Err(InitError::Other(reason)) => {
+                    let Some(PendingState { direction, .. }) = self.pending.remove(&node_id) else {
+                        // TODO: unreachable?
+                        tracing::warn!(node_id=%node_id.fmt_short(), "missing pending state, dropping connection");
+                        continue;
+                    };
+                    let err = ConnectError {
+                        node_id,
+                        reason,
+                        direction,
+                    };
+                    break Poll::Ready(Some(Err(err)));
+                }
+                Ok(conn) => {
+                    let Some(PendingState { direction, .. }) = self.pending.remove(&node_id) else {
+                        // TODO: unreachable?
+                        tracing::warn!(node_id=%node_id.fmt_short(), "missing pending state, dropping connection");
+                        continue;
+                    };
+                    let info = ConnInfo {
+                        conn,
+                        node_id,
+                        direction,
+                    };
+                    self.active.insert(node_id, info.clone());
+                    break Poll::Ready(Some(Ok(info)));
+                }
+            }
+        }
+    }
+}
+
+async fn accept_task(conn: Connection) -> Result<Connection, InitError> {
+    let mut stream = conn.open_uni().await?;
+    stream.write_all(&[0]).await?;
+    stream.finish().await?;
+    Ok(conn)
+}
+
+async fn connect_task(
+    ep: Endpoint,
+    node_id: NodeId,
+    alpn: &'static [u8],
+) -> Result<Connection, InitError> {
+    let conn = ep.connect_by_node_id(&node_id, alpn).await?;
+    let mut stream = conn.accept_uni().await?;
+    stream.read_to_end(1).await?;
+    Ok(conn)
+}
+
+#[derive(Debug)]
+struct PendingState {
+    direction: ConnDirection,
+    abort_handle: AbortHandle,
+}
+
+/// A sender to push new connections into a [`ConnManager`].
+///
+/// See [`ConnManager::accept_sender`] for details.
+#[derive(Debug, Clone)]
+pub struct AcceptSender {
+    tx: mpsc::Sender<Connection>,
+}
+
+impl AcceptSender {
+    /// Send a new connection to the [`ConnManager`].
+    pub async fn send(&self, conn: Connection) -> anyhow::Result<()> {
+        self.tx.send(conn).await?;
+        Ok(())
+    }
+}
+
+/// The error returned from [`ConnManager::poll_next`].
+#[derive(thiserror::Error, Debug)]
+#[error("Connection to node {} direction {:?} failed: {:?}", self.node_id, self.direction, self.reason)]
+pub struct ConnectError {
+    /// The node id of the peer to which the connection failed.
+    pub node_id: NodeId,
+    /// The direction of the connection.
+    pub direction: ConnDirection,
+    /// The actual error that occurred.
+    #[source]
+    pub reason: anyhow::Error,
+}
+
+#[derive(Debug)]
+enum InitError {
+    IsDuplicate,
+    Other(anyhow::Error),
+}
+
+impl From<anyhow::Error> for InitError {
+    fn from(value: anyhow::Error) -> Self {
+        Self::Other(value)
+    }
+}
+
+impl From<quinn::ConnectionError> for InitError {
+    fn from(value: quinn::ConnectionError) -> Self {
+        match &value {
+            quinn::ConnectionError::ApplicationClosed(err)
+                if &err.reason[..] == DUPLICATE_REASON
+                    && err.error_code == DUPLICATE_CODE.into() =>
+            {
+                Self::IsDuplicate
+            }
+            _ => Self::Other(value.into()),
+        }
+    }
+}
+
+impl From<quinn::ReadToEndError> for InitError {
+    fn from(value: quinn::ReadToEndError) -> Self {
+        match value {
+            quinn::ReadToEndError::Read(quinn::ReadError::ConnectionLost(err)) => err.into(),
+            err => Self::Other(err.into()),
+        }
+    }
+}
+
+impl From<quinn::WriteError> for InitError {
+    fn from(value: quinn::WriteError) -> Self {
+        match value {
+            quinn::WriteError::ConnectionLost(err) => err.into(),
+            err => Self::Other(err.into()),
+        }
+    }
+}
+
+/// Whether we accepted the connection or initiated it.
+#[derive(Debug, Clone, Copy, Eq, PartialEq)]
+pub enum ConnDirection {
+    /// We accepted this connection from the other peer.
+    Accept,
+    /// We initiated this connection by connecting to the other peer.
+    Dial,
+}
+
+/// A new connection as emitted from [`ConnManager`].
+#[derive(Debug, Clone, derive_more::Deref)]
+pub struct ConnInfo {
+    /// The QUIC connection.
+    #[deref]
+    pub conn: Connection,
+    /// The node id of the other peer.
+    pub node_id: NodeId,
+    /// Whether we accepted or initiated this connection.
+    pub direction: ConnDirection,
+}
+
+#[cfg(test)]
+mod tests {
+    use std::time::Duration;
+
+    use futures_lite::StreamExt;
+    use tokio::task::JoinSet;
+
+    use crate::test_utils::TestEndpointFactory;
+
+    use super::{AcceptSender, ConnManager};
+
+    const TEST_ALPN: &[u8] = b"test";
+
+    async fn accept_loop(ep: crate::Endpoint, accept_sender: AcceptSender) -> anyhow::Result<()> {
+        while let Some(conn) = ep.accept().await {
+            let conn = conn.await?;
+            tracing::debug!(me=%ep.node_id().fmt_short(), "conn incoming");
+            accept_sender.send(conn).await?;
+        }
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_conn_manager() -> anyhow::Result<()> {
+        let _guard = iroh_test::logging::setup();
+        let mut factory = TestEndpointFactory::run().await?;
+
+        let alpns = vec![TEST_ALPN.to_vec()];
+        let ep1 = factory.create_endpoint(alpns.clone()).await?;
+        let ep2 = factory.create_endpoint(alpns.clone()).await?;
+        let n1 = ep1.node_id();
+        let n2 = ep2.node_id();
+        tracing::info!(?n1, ?n2, "endpoints created");
+        factory.on_node(&n1, Duration::from_secs(2)).await?;
+        factory.on_node(&n2, Duration::from_secs(2)).await?;
+
+        let mut conn_manager1 = ConnManager::new(ep1.clone(), TEST_ALPN);
+        let mut conn_manager2 = ConnManager::new(ep2.clone(), TEST_ALPN);
+
+        let accept1 = conn_manager1.accept_sender();
+        let accept2 = conn_manager2.accept_sender();
+        let mut tasks = JoinSet::new();
+        tasks.spawn(accept_loop(ep1, accept1));
+        tasks.spawn(accept_loop(ep2, accept2));
+
+        for i in 0u8..20 {
+            tracing::info!(i, "start dial");
+            conn_manager1.dial(n2);
+            conn_manager2.dial(n1);
+            let (conn1, conn2) = tokio::join!(conn_manager1.next(), conn_manager2.next());
+            let conn1 = conn1.unwrap().unwrap();
+            let conn2 = conn2.unwrap().unwrap();
+
+            tracing::info!(?conn1.direction, "conn1");
+            tracing::info!(?conn2.direction, "conn2");
+            assert!(conn1.direction != conn2.direction);
+            assert_eq!(conn1.node_id, n2);
+            assert_eq!(conn2.node_id, n1);
+
+            let mut
s1 = conn1.open_uni().await.unwrap(); + s1.write_all(&[i]).await?; + s1.finish().await?; + + let mut s2 = conn2.accept_uni().await.unwrap(); + let x = s2.read_to_end(1).await.unwrap(); + + assert_eq!(&x, &[i]); + assert!(conn_manager1.remove(&n2).is_some()); + assert!(conn_manager2.remove(&n1).is_some()); + } + + tasks.abort_all(); + while let Some(r) = tasks.join_next().await { + assert!(r.unwrap_err().is_cancelled()); + } + + Ok(()) + } +} diff --git a/iroh-net/src/lib.rs b/iroh-net/src/lib.rs index 911060df6f..b613598a4a 100644 --- a/iroh-net/src/lib.rs +++ b/iroh-net/src/lib.rs @@ -11,6 +11,7 @@ #![deny(missing_docs, rustdoc::broken_intra_doc_links)] pub mod config; +pub mod conn_manager; pub mod defaults; pub mod dialer; mod disco; diff --git a/iroh-net/src/test_utils.rs b/iroh-net/src/test_utils.rs index 0cbf8bd857..bec4f2cf60 100644 --- a/iroh-net/src/test_utils.rs +++ b/iroh-net/src/test_utils.rs @@ -1,12 +1,14 @@ //! Internal utilities to support testing. use anyhow::Result; +use rand::SeedableRng; use tokio::sync::oneshot; use tracing::{error_span, info_span, Instrument}; use crate::{ key::SecretKey, - relay::{RelayMap, RelayNode, RelayUrl}, + relay::{RelayMap, RelayMode, RelayNode, RelayUrl}, + Endpoint, }; pub use dns_and_pkarr_servers::DnsPkarrServer; @@ -68,6 +70,76 @@ pub async fn run_relay_server() -> Result<(RelayMap, RelayUrl, CleanupDropGuard) Ok((m, url, CleanupDropGuard(tx))) } +/// A factory for endpoints with preconfigured local discovery and relays. +#[derive(Debug)] +pub struct TestEndpointFactory { + relay_map: RelayMap, + relay_url: RelayUrl, + _relay_drop_guard: CleanupDropGuard, + dns_pkarr_server: DnsPkarrServer, + rng: rand_chacha::ChaCha12Rng, +} + +impl TestEndpointFactory { + /// Starts local relay and discovery servers and returns a [`TestEndpointFactory`] to create + /// readily configured endpoints. + /// + /// The local servers will shut down once the [`TestEndpointFactory`] is dropped. + pub async fn run() -> anyhow::Result { + let dns_pkarr_server = DnsPkarrServer::run().await?; + let (relay_map, relay_url, relay_drop_guard) = run_relay_server().await?; + Ok(Self { + relay_map, + relay_url, + dns_pkarr_server, + _relay_drop_guard: relay_drop_guard, + rng: rand_chacha::ChaCha12Rng::seed_from_u64(77), + }) + } + + /// Create a new endpoint builder which already has discovery and relays configured. + pub fn create_endpoint_builder(&self, secret_key: SecretKey) -> crate::endpoint::Builder { + Endpoint::builder() + .secret_key(secret_key.clone()) + .relay_mode(RelayMode::Custom(self.relay_map.clone())) + .insecure_skip_relay_cert_verify(true) + .dns_resolver(self.dns_pkarr_server.dns_resolver()) + .discovery(self.dns_pkarr_server.discovery(secret_key)) + } + + /// Create a new endpoint with the specified ALPNs. + /// + /// The endpoint will have discovery and relays configured, and have a predictable secret key. + pub async fn create_endpoint(&mut self, alpns: Vec>) -> anyhow::Result { + let secret_key = SecretKey::generate_with_rng(&mut self.rng); + self.create_endpoint_builder(secret_key) + .alpns(alpns) + .bind(0) + .await + } + + /// Returns the URL of the local relay server. + pub fn relay_url(&self) -> &RelayUrl { + &self.relay_url + } + + /// Returns a relay map which contains the local relay server. + pub fn relay_map(&self) -> &RelayMap { + &self.relay_map + } + + /// Wait until a Pkarr announce for a node is published to the server. + /// + /// If `timeout` elapses an error is returned. 
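+    ///
+    /// A usage sketch mirroring the `conn_manager` test above (`factory` comes from
+    /// [`TestEndpointFactory::run`], `ep` from [`Self::create_endpoint`]):
+    ///
+    /// ```ignore
+    /// factory.on_node(&ep.node_id(), std::time::Duration::from_secs(2)).await?;
+    /// ```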
+ pub async fn on_node( + &self, + node_id: &crate::NodeId, + timeout: std::time::Duration, + ) -> Result<()> { + self.dns_pkarr_server.on_node(node_id, timeout).await + } +} + pub(crate) mod dns_and_pkarr_servers { use anyhow::Result; use iroh_base::key::{NodeId, SecretKey};