From 6229f8af5182747872742689767728562266ce66 Mon Sep 17 00:00:00 2001 From: "Franz Heinzmann (Frando)" Date: Thu, 23 May 2024 01:01:35 +0200 Subject: [PATCH] refactor: make the conn manager actually work as intended --- iroh-gossip/src/net.rs | 76 +++--- iroh-net/src/conn_manager.rs | 440 +++++++++++++++++++++++++++++++++++ iroh-net/src/lib.rs | 1 + 3 files changed, 478 insertions(+), 39 deletions(-) create mode 100644 iroh-net/src/conn_manager.rs diff --git a/iroh-gossip/src/net.rs b/iroh-gossip/src/net.rs index f350e7b44e..237866b5a9 100644 --- a/iroh-gossip/src/net.rs +++ b/iroh-gossip/src/net.rs @@ -6,7 +6,7 @@ use futures_lite::{stream::Stream, StreamExt}; use futures_util::future::FutureExt; use genawaiter::sync::{Co, Gen}; use iroh_net::{ - dialer::{ConnDirection, ConnManager, NewConnection}, + dialer::{ConnDirection, ConnInfo, ConnManager}, endpoint::Connection, key::PublicKey, AddrInfo, Endpoint, NodeAddr, @@ -382,9 +382,14 @@ impl Actor { } } } - Some(new_conn) = self.conn_manager.next() => { + Some(res) = self.conn_manager.next() => { trace!(?i, "tick: conn_manager"); - self.handle_new_connection(new_conn).await; + match res { + Ok(conn) => self.handle_new_connection(conn).await, + Err(err) => { + self.handle_in_event(InEvent::PeerDisconnected(err.node_id), Instant::now()).await?; + } + } } Some(res) = self.conn_tasks.join_next(), if !self.conn_tasks.is_empty() => { match res { @@ -393,7 +398,7 @@ impl Actor { Ok((node_id, result)) => { self.conn_manager.remove(&node_id); self.conn_send_tx.remove(&node_id); - self.handle_in_event(InEvent::PeerDisconnected(node_id), Instant::now()).await ?; + self.handle_in_event(InEvent::PeerDisconnected(node_id), Instant::now()).await?; match result { Ok(()) => { debug!(peer=%node_id.fmt_short(), "connection closed without error"); @@ -430,11 +435,11 @@ impl Actor { async fn handle_to_actor_msg(&mut self, msg: ToActor, now: Instant) -> anyhow::Result<()> { trace!("handle to_actor {msg:?}"); match msg { - ToActor::AcceptConn(conn) => match self.conn_manager.accept(conn) { - Err(err) => warn!(?err, "failed to accept connection"), - Ok(None) => {} - Ok(Some(conn)) => self.handle_new_connection(conn).await, - }, + ToActor::AcceptConn(conn) => { + if let Err(err) = self.conn_manager.accept(conn) { + warn!(?err, "failed to accept connection"); + } + } ToActor::Join(topic_id, peers, reply) => { self.handle_in_event(InEvent::Command(topic_id, Command::Join(peers)), now) .await?; @@ -498,7 +503,7 @@ impl Actor { self.conn_manager.remove(&peer_id); } } else { - if !self.conn_manager.is_dialing(&peer_id) { + if !self.conn_manager.is_pending(&peer_id) { debug!(peer = ?peer_id, "dial"); self.conn_manager.dial(peer_id); } @@ -545,38 +550,31 @@ impl Actor { Ok(()) } - async fn handle_new_connection(&mut self, new_conn: NewConnection) { - let NewConnection { + async fn handle_new_connection(&mut self, new_conn: ConnInfo) { + let ConnInfo { conn, - node_id: peer_id, + node_id, direction, } = new_conn; - match conn { - Ok(conn) => { - let (send_tx, send_rx) = mpsc::channel(SEND_QUEUE_CAP); - self.conn_send_tx.insert(peer_id, send_tx.clone()); - - // Spawn a task for this connection - let pending_sends = self.pending_sends.remove(&peer_id); - let in_event_tx = self.in_event_tx.clone(); - debug!(peer=%peer_id.fmt_short(), ?direction, "connection established"); - self.conn_tasks.spawn( - connection_loop( - peer_id, - conn, - direction, - send_rx, - in_event_tx, - pending_sends, - ) - .map(move |r| (peer_id, r)) - .instrument(error_span!("gossip_conn", peer = %peer_id.fmt_short())), - ); - } - Err(err) => { - warn!(peer=%peer_id.fmt_short(), "connecting to node failed: {err:?}"); - } - } + let (send_tx, send_rx) = mpsc::channel(SEND_QUEUE_CAP); + self.conn_send_tx.insert(node_id, send_tx.clone()); + + // Spawn a task for this connection + let pending_sends = self.pending_sends.remove(&node_id); + let in_event_tx = self.in_event_tx.clone(); + debug!(peer=%node_id.fmt_short(), ?direction, "connection established"); + self.conn_tasks.spawn( + connection_loop( + node_id, + conn, + direction, + send_rx, + in_event_tx, + pending_sends, + ) + .map(move |r| (node_id, r)) + .instrument(error_span!("gossip_conn", peer = %node_id.fmt_short())), + ); } fn subscribe_all(&mut self) -> broadcast::Receiver<(TopicId, Event)> { diff --git a/iroh-net/src/conn_manager.rs b/iroh-net/src/conn_manager.rs new file mode 100644 index 0000000000..4b9828558b --- /dev/null +++ b/iroh-net/src/conn_manager.rs @@ -0,0 +1,440 @@ +//! A connection manager to ensure a single connection between each pair of peers. + +use std::{ + collections::HashMap, + pin::Pin, + task::{ready, Context, Poll, Waker}, +}; + +use futures_lite::{Future, Stream}; +use futures_util::FutureExt; +use tokio::{ + sync::mpsc, + task::{AbortHandle, JoinSet}, +}; +use tracing::{debug, error}; + +use crate::{ + endpoint::{get_remote_node_id, Connection}, + Endpoint, NodeId, +}; + +const DUPLICATE_REASON: &[u8] = b"abort_duplicate"; +const DUPLICATE_CODE: u32 = 123; + +/// A connection manager. +/// +/// The [`ConnManager`] does not accept connections from the endpoint by itself. Instead, you +/// should run an accept loop yourself, and push connections with a matching ALPN into the manager +/// with [`ConnManager::accept`]. The connection will be dropped if we already have a connection to +/// that node. If we are currently dialing the node, the connection will only be accepted if the +/// peer's node id sorts lower than our node id. Through this, it is ensured that we will not get +/// double connections with a node if both we and them dial each other at the same time. +/// +/// The [`ConnManager`] implements [`Stream`]. It will yield new connections, both from dialing and +/// accepting. +#[derive(Debug)] +pub struct ConnManager { + endpoint: Endpoint, + alpn: &'static [u8], + active: HashMap, + pending: HashMap, + tasks: JoinSet<(NodeId, Result)>, + accept_tx: mpsc::Sender, + accept_rx: mpsc::Receiver, + waker: Option, +} + +impl ConnManager { + /// Create a new connection manager. + pub fn new(endpoint: Endpoint, alpn: &'static [u8]) -> Self { + let (accept_tx, accept_rx) = mpsc::channel(128); + Self { + endpoint, + alpn, + active: Default::default(), + accept_tx, + accept_rx, + tasks: JoinSet::new(), + pending: HashMap::new(), + waker: None, + } + } + + /// Start to dial a node. + /// + /// This is a no-op if the a connection to the node is already active or if we are currently + /// dialing the node already. + /// + /// Returns `true` if this is initiates connecting to the node. + pub fn dial(&mut self, node_id: NodeId) -> bool { + if self.is_pending(&node_id) || self.is_connected(&node_id) { + false + } else { + self.spawn( + node_id, + ConnDirection::Dial, + connect_task(self.endpoint.clone(), node_id, self.alpn), + ); + true + } + } + + fn spawn( + &mut self, + node_id: NodeId, + direction: ConnDirection, + fut: impl Future> + Send + 'static, + ) { + let abort_handle = self.tasks.spawn(fut.map(move |res| (node_id, res))); + let pending_state = PendingState { + direction, + abort_handle, + }; + self.pending.insert(node_id, pending_state); + if let Some(waker) = self.waker.take() { + waker.wake(); + } + } + + /// Get a sender to push new connections towards the [`ConnManager`] + /// + /// This does not check the connection's ALPN, so you should make sure that the ALPN matches + /// the [`ConnManager`]'s execpected ALPN before passing the connection to the sender. + /// + /// If we are currently dialing the node, the connection will be dropped if the peer's node id + /// sorty higher than our node id. Otherwise, the connection will be yielded from the manager + /// stream. + pub fn accept_sender(&self) -> AcceptSender { + let tx = self.accept_tx.clone(); + AcceptSender { tx } + } + + /// Accept a connection. + /// + /// This does not check the connection's ALPN, so you should make sure that the ALPN matches + /// the [`ConnManager`]'s execpected ALPN before passing the connection to the sender. + /// + /// If we are currently dialing the node, the connection will be dropped if the peer's node id + /// sorty higher than our node id. Otherwise, the connection will be returned. + pub fn accept(&mut self, conn: quinn::Connection) -> anyhow::Result<()> { + let node_id = get_remote_node_id(&conn)?; + // We are already connected: drop the connection, keep using the existing conn. + if self.is_connected(&node_id) { + return Ok(()); + } + + let accept = match self.pending.get(&node_id) { + // We are currently dialing the node, but the incoming conn "wins": accept and abort + // our dial. + Some(state) + if state.direction == ConnDirection::Dial && node_id > self.our_node_id() => + { + state.abort_handle.abort(); + true + } + // We are currently processing a connection for this node: do not accept a second conn. + Some(_state) => false, + // The connection is new: accept. + None => true, + }; + + if accept { + self.spawn(node_id, ConnDirection::Accept, accept_task(conn)); + } else { + conn.close(DUPLICATE_CODE.into(), DUPLICATE_REASON); + } + Ok(()) + } + + /// Remove the connection to a node. + /// + /// Also aborts pending dials to the node, if existing. + /// + /// Returns the connection if it existed. + pub fn remove(&mut self, node_id: &NodeId) -> Option { + if let Some(state) = self.pending.remove(node_id) { + state.abort_handle.abort(); + } + self.active.remove(node_id) + } + + /// Returns the connection to a node, if connected. + pub fn get(&self, node_id: &NodeId) -> Option<&ConnInfo> { + self.active.get(node_id) + } + + /// Returns `true` if we are currently establishing a connection to the node. + pub fn is_pending(&self, node_id: &NodeId) -> bool { + self.pending.contains_key(node_id) + } + + /// Returns `true` if we are connected to the node. + pub fn is_connected(&self, node_id: &NodeId) -> bool { + self.active.contains_key(node_id) + } + + fn our_node_id(&self) -> NodeId { + self.endpoint.node_id() + } +} + +impl Stream for ConnManager { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + tracing::debug!("poll_next in"); + // Create new tasks for incoming connections. + while let Poll::Ready(Some(conn)) = Pin::new(&mut self.accept_rx).poll_recv(cx) { + // self.accept(conn) + debug!("accept - polled"); + if let Err(error) = self.accept(conn) { + tracing::warn!(?error, "skipping invalid connection attempt"); + } + } + + // Poll for finished tasks, + loop { + let join_res = ready!(self.tasks.poll_join_next(cx)); + debug!(?join_res, "join res"); + let (node_id, res) = match join_res { + None => { + self.waker = Some(cx.waker().to_owned()); + return Poll::Pending; + } + Some(Err(err)) if err.is_cancelled() => continue, + // we are merely forwarding a panic here, which should never occur. + Some(Err(err)) => panic!("connection manager task paniced with {err:?}"), + Some(Ok(res)) => res, + }; + match res { + Err(InitError::IsDuplicate) => continue, + Err(InitError::Other(reason)) => { + let Some(PendingState { direction, .. }) = self.pending.remove(&node_id) else { + // TODO: unreachable? + tracing::warn!(node_id=%node_id.fmt_short(), "missing pending state, dropping connection"); + continue; + }; + let err = ConnectError { + node_id, + reason, + direction, + }; + break Poll::Ready(Some(Err(err))); + } + Ok(conn) => { + let Some(PendingState { direction, .. }) = self.pending.remove(&node_id) else { + // TODO: unreachable? + tracing::warn!(node_id=%node_id.fmt_short(), "missing pending state, dropping connection"); + continue; + }; + let info = ConnInfo { + conn, + node_id, + direction, + }; + self.active.insert(node_id, info.clone()); + break Poll::Ready(Some(Ok(info))); + } + } + } + } +} + +async fn accept_task(conn: Connection) -> Result { + let mut stream = conn.open_uni().await?; + stream.write_all(&[0]).await?; + stream.finish().await?; + Ok(conn) +} + +async fn connect_task( + ep: Endpoint, + node_id: NodeId, + alpn: &'static [u8], +) -> Result { + let conn = ep.connect_by_node_id(&node_id, alpn).await?; + let mut stream = conn.accept_uni().await?; + stream.read_to_end(1).await?; + Ok(conn) +} + +#[derive(Debug)] +struct PendingState { + direction: ConnDirection, + abort_handle: AbortHandle, +} + +/// A sender to push new connections into a [`ConnManager`]. +/// +/// See [`ConnManager::accept_sender`] for details. +#[derive(Debug, Clone)] +pub struct AcceptSender { + tx: mpsc::Sender, +} + +impl AcceptSender { + /// Send a new connection to the [`ConnManager`]. + pub async fn send(&self, conn: Connection) -> anyhow::Result<()> { + self.tx.send(conn).await?; + Ok(()) + } +} + +/// The error returned from [`ConnManager::poll_next`]. +#[derive(thiserror::Error, Debug)] +#[error("Connection to node {} direction {:?} failed: {:?}", self.node_id, self.direction, self.reason)] +pub struct ConnectError { + /// The node id of the peer to which the connection failed. + pub node_id: NodeId, + /// The direction of the connection. + pub direction: ConnDirection, + /// The actual error that ocurred. + #[source] + pub reason: anyhow::Error, +} + +#[derive(Debug)] +enum InitError { + IsDuplicate, + Other(anyhow::Error), +} + +impl From for InitError { + fn from(value: anyhow::Error) -> Self { + Self::Other(value) + } +} + +impl From for InitError { + fn from(value: quinn::ConnectionError) -> Self { + match &value { + quinn::ConnectionError::ApplicationClosed(err) + if &err.reason[..] == DUPLICATE_REASON + && err.error_code == DUPLICATE_CODE.into() => + { + Self::IsDuplicate + } + _ => Self::Other(value.into()), + } + } +} + +impl From for InitError { + fn from(value: quinn::ReadToEndError) -> Self { + match value { + quinn::ReadToEndError::Read(quinn::ReadError::ConnectionLost(err)) => err.into(), + err @ _ => Self::Other(err.into()), + } + } +} + +impl From for InitError { + fn from(value: quinn::WriteError) -> Self { + match value { + quinn::WriteError::ConnectionLost(err) => err.into(), + err @ _ => Self::Other(err.into()), + } + } +} + +/// Whether we accepted the connection or initiated it. +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum ConnDirection { + /// We accepted this connection from the other peer. + Accept, + /// We initiated this connection by connecting to the other peer. + Dial, +} + +/// A new connection as emitted from [`ConnManager`]. +#[derive(Debug, Clone, derive_more::Deref)] +pub struct ConnInfo { + /// The QUIC connection. + #[deref] + pub conn: Connection, + /// The node id of the other peer. + pub node_id: NodeId, + /// Whether we accepted or initiated this connection. + pub direction: ConnDirection, +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use futures_lite::StreamExt; + use tokio::task::JoinSet; + + use crate::test_utils::TestEndpointFactory; + + use super::{AcceptSender, ConnManager}; + + const TEST_ALPN: &[u8] = b"test"; + + async fn accept_loop(ep: crate::Endpoint, accept_sender: AcceptSender) -> anyhow::Result<()> { + while let Some(conn) = ep.accept().await { + let conn = conn.await?; + tracing::debug!(me=%ep.node_id().fmt_short(), "conn incoming"); + accept_sender.send(conn).await?; + } + Ok(()) + } + + #[tokio::test] + async fn test_conn_manager() -> anyhow::Result<()> { + let _guard = iroh_test::logging::setup(); + let mut factory = TestEndpointFactory::run().await?; + + let alpns = vec![TEST_ALPN.to_vec()]; + let ep1 = factory.create_endpoint(alpns.clone()).await?; + let ep2 = factory.create_endpoint(alpns.clone()).await?; + let n1 = ep1.node_id(); + let n2 = ep2.node_id(); + tracing::info!(?n1, ?n2, "endpoints created"); + factory.on_node(&n1, Duration::from_secs(2)).await?; + factory.on_node(&n2, Duration::from_secs(2)).await?; + + let mut conn_manager1 = ConnManager::new(ep1.clone(), TEST_ALPN); + let mut conn_manager2 = ConnManager::new(ep2.clone(), TEST_ALPN); + + let accept1 = conn_manager1.accept_sender(); + let accept2 = conn_manager2.accept_sender(); + let mut tasks = JoinSet::new(); + tasks.spawn(accept_loop(ep1, accept1)); + tasks.spawn(accept_loop(ep2, accept2)); + + for i in 0u8..20 { + tracing::info!(i, "start dial"); + conn_manager1.dial(n2); + conn_manager2.dial(n1); + let (conn1, conn2) = tokio::join!(conn_manager1.next(), conn_manager2.next()); + let conn1 = conn1.unwrap().unwrap(); + let conn2 = conn2.unwrap().unwrap(); + + tracing::info!(?conn1.direction, "conn1"); + tracing::info!(?conn2.direction, "conn2"); + assert!(conn1.direction != conn2.direction); + assert_eq!(conn1.node_id, n2); + assert_eq!(conn2.node_id, n1); + + let mut s1 = conn1.open_uni().await.unwrap(); + s1.write_all(&[i]).await?; + s1.finish().await?; + + let mut s2 = conn2.accept_uni().await.unwrap(); + let x = s2.read_to_end(1).await.unwrap(); + + assert_eq!(&x, &[i]); + assert!(conn_manager1.remove(&n2).is_some()); + assert!(conn_manager2.remove(&n1).is_some()); + } + + tasks.abort_all(); + while let Some(r) = tasks.join_next().await { + assert!(r.unwrap_err().is_cancelled()); + } + + Ok(()) + } +} diff --git a/iroh-net/src/lib.rs b/iroh-net/src/lib.rs index 911060df6f..122b79f3cb 100644 --- a/iroh-net/src/lib.rs +++ b/iroh-net/src/lib.rs @@ -13,6 +13,7 @@ pub mod config; pub mod defaults; pub mod dialer; +pub mod conn_manager; mod disco; pub mod discovery; pub mod dns;