Skip to content

Commit

Permalink
feat(iroh-net): add MagicEndpoint::conn_type_stream returns a strea…
Browse files Browse the repository at this point in the history
…m that reports connection type changes for a `node_id` (#2161)

## Description
`MagicEndpoint::conn_type_stream` returns a `Stream` that reports
changes for a `magicsock::Endpoint` with a given `node_id` in a
`magicsock::NodeMap`.

It will error if no address information for that `node_id` exists in the
`NodeMap`.

This PR also adjusts the `Endpoint::info()` method to use the same
`ConnectionType` that gets reported to the stream.

## Change checklist

- [x] Self-review.
- [x] Documentation updates if relevant.
- [x] Tests if relevant.
  • Loading branch information
ramfox committed Apr 12, 2024
1 parent b07547b commit 7986394
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 23 deletions.
110 changes: 108 additions & 2 deletions iroh-net/src/magic_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::{
discovery::{Discovery, DiscoveryTask},
dns::{default_resolver, DnsResolver},
key::{PublicKey, SecretKey},
magicsock::{self, MagicSock},
magicsock::{self, ConnectionTypeStream, MagicSock},
relay::{RelayMap, RelayMode, RelayUrl},
tls, NodeId,
};
Expand Down Expand Up @@ -402,6 +402,16 @@ impl MagicEndpoint {
self.connect(addr, alpn).await
}

/// Returns a stream that reports changes in the [`crate::magicsock::ConnectionType`]
/// for the given `node_id`.
///
/// # Errors
///
/// Will error if we do not have any address information for the given `node_id`
pub fn conn_type_stream(&self, node_id: &PublicKey) -> Result<ConnectionTypeStream> {
self.msock.conn_type_stream(node_id)
}

/// Connect to a remote endpoint.
///
/// A [`NodeAddr`] is required. It must contain the [`NodeId`] to dial and may also contain a
Expand Down Expand Up @@ -630,7 +640,7 @@ mod tests {
use rand_core::SeedableRng;
use tracing::{error_span, info, info_span, Instrument};

use crate::test_utils::run_relay_server;
use crate::{magicsock::ConnectionType, test_utils::run_relay_server};

use super::*;

Expand Down Expand Up @@ -971,4 +981,100 @@ mod tests {
p1_connect.await.unwrap();
p2_connect.await.unwrap();
}

#[tokio::test]
async fn magic_endpoint_conn_type_stream() {
let _logging_guard = iroh_test::logging::setup();
let (relay_map, relay_url, _relay_guard) = run_relay_server().await.unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let ep1_secret_key = SecretKey::generate_with_rng(&mut rng);
let ep2_secret_key = SecretKey::generate_with_rng(&mut rng);
let ep1 = MagicEndpoint::builder()
.secret_key(ep1_secret_key)
.insecure_skip_relay_cert_verify(true)
.alpns(vec![TEST_ALPN.to_vec()])
.relay_mode(RelayMode::Custom(relay_map.clone()))
.bind(0)
.await
.unwrap();
let ep2 = MagicEndpoint::builder()
.secret_key(ep2_secret_key)
.insecure_skip_relay_cert_verify(true)
.alpns(vec![TEST_ALPN.to_vec()])
.relay_mode(RelayMode::Custom(relay_map))
.bind(0)
.await
.unwrap();

async fn handle_direct_conn(ep: MagicEndpoint, node_id: PublicKey) -> Result<()> {
let node_addr = NodeAddr::new(node_id);
ep.add_node_addr(node_addr)?;
let stream = ep.conn_type_stream(&node_id)?;
async fn get_direct_event(
src: &PublicKey,
dst: &PublicKey,
mut stream: ConnectionTypeStream,
) -> Result<()> {
let src = src.fmt_short();
let dst = dst.fmt_short();
while let Some(conn_type) = stream.next().await {
tracing::info!(me = %src, dst = %dst, conn_type = ?conn_type);
if matches!(conn_type, ConnectionType::Direct(_)) {
return Ok(());
}
}
anyhow::bail!("conn_type stream ended before `ConnectionType::Direct`");
}
tokio::time::timeout(
Duration::from_secs(15),
get_direct_event(&ep.node_id(), &node_id, stream),
)
.await??;
Ok(())
}

let ep1_nodeid = ep1.node_id();
let ep2_nodeid = ep2.node_id();

let ep1_nodeaddr = ep1.my_addr().await.unwrap();
tracing::info!(
"node id 1 {ep1_nodeid}, relay URL {:?}",
ep1_nodeaddr.relay_url()
);
tracing::info!("node id 2 {ep2_nodeid}");

let res_ep1 = tokio::spawn(handle_direct_conn(ep1.clone(), ep2_nodeid));

let ep1_abort_handle = res_ep1.abort_handle();
let _ep1_guard = CallOnDrop::new(move || {
ep1_abort_handle.abort();
});

let res_ep2 = tokio::spawn(handle_direct_conn(ep2.clone(), ep1_nodeid));
let ep2_abort_handle = res_ep2.abort_handle();
let _ep2_guard = CallOnDrop::new(move || {
ep2_abort_handle.abort();
});
async fn accept(ep: MagicEndpoint) -> (PublicKey, String, quinn::Connection) {
let incoming = ep.accept().await.unwrap();
accept_conn(incoming).await.unwrap()
}

// create a node addr with no direct connections
let ep1_nodeaddr = NodeAddr::from_parts(ep1_nodeid, Some(relay_url), vec![]);

let accept_res = tokio::spawn(accept(ep1.clone()));
let accept_abort_handle = accept_res.abort_handle();
let _accept_guard = CallOnDrop::new(move || {
accept_abort_handle.abort();
});

let _conn_2 = ep2.connect(ep1_nodeaddr, TEST_ALPN).await.unwrap();

let (got_id, _, _conn) = accept_res.await.unwrap();
assert_eq!(ep2_nodeid, got_id);

res_ep1.await.unwrap().unwrap();
res_ep2.await.unwrap().unwrap();
}
}
21 changes: 20 additions & 1 deletion iroh-net/src/magicsock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ mod udp_conn;
pub use crate::net::UdpSocket;

pub use self::metrics::Metrics;
pub use self::node_map::{ConnectionType, ControlMsg, DirectAddrInfo, EndpointInfo};
pub use self::node_map::{
ConnectionType, ConnectionTypeStream, ControlMsg, DirectAddrInfo, EndpointInfo,
};
pub use self::timer::Timer;

/// How long we consider a STUN-derived endpoint valid for. UDP NAT mappings typically
Expand Down Expand Up @@ -1349,6 +1351,23 @@ impl MagicSock {
}
}

/// Returns a stream that reports the [`ConnectionType`] we have to the
/// given `node_id`.
///
/// The `NodeMap` continuously monitors the `node_id`'s endpoint for
/// [`ConnectionType`] changes, and sends the latest [`ConnectionType`]
/// on the stream.
///
/// The current [`ConnectionType`] will the the initial entry on the stream.
///
/// # Errors
///
/// Will return an error if there is no address information known about the
/// given `node_id`.
pub fn conn_type_stream(&self, node_id: &PublicKey) -> Result<node_map::ConnectionTypeStream> {
self.inner.node_map.conn_type_stream(node_id)
}

/// Get the cached version of the Ipv4 and Ipv6 addrs of the current connection.
pub fn local_addr(&self) -> Result<(SocketAddr, Option<SocketAddr>)> {
Ok(self.inner.local_addr())
Expand Down
56 changes: 55 additions & 1 deletion iroh-net/src/magicsock/node_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ use std::{
hash::Hash,
net::{IpAddr, SocketAddr},
path::Path,
pin::Pin,
task::{Context, Poll},
time::Instant,
};

use anyhow::{ensure, Context};
use anyhow::{ensure, Context as _};
use futures::Stream;
use iroh_metrics::inc;
use parking_lot::Mutex;
use stun_rs::TransactionId;
Expand Down Expand Up @@ -209,6 +212,19 @@ impl NodeMap {
self.inner.lock().endpoint_infos(now)
}

/// Returns a stream of [`ConnectionType`].
///
/// Sends the current [`ConnectionType`] whenever any changes to the
/// connection type for `public_key` has occured.
///
/// # Errors
///
/// Will return an error if there is not an entry in the [`NodeMap`] for
/// the `public_key`
pub fn conn_type_stream(&self, public_key: &PublicKey) -> anyhow::Result<ConnectionTypeStream> {
self.inner.lock().conn_type_stream(public_key)
}

/// Get the [`EndpointInfo`]s for each endpoint
pub fn endpoint_info(&self, public_key: &PublicKey) -> Option<EndpointInfo> {
self.inner.lock().endpoint_info(public_key)
Expand Down Expand Up @@ -389,6 +405,25 @@ impl NodeMapInner {
.map(|ep| ep.info(Instant::now()))
}

/// Returns a stream of [`ConnectionType`].
///
/// Sends the current [`ConnectionType`] whenever any changes to the
/// connection type for `public_key` has occured.
///
/// # Errors
///
/// Will return an error if there is not an entry in the [`NodeMap`] for
/// the `public_key`
fn conn_type_stream(&self, public_key: &PublicKey) -> anyhow::Result<ConnectionTypeStream> {
match self.get(EndpointId::NodeKey(public_key)) {
Some(ep) => Ok(ConnectionTypeStream {
initial: Some(ep.conn_type.get()),
inner: ep.conn_type.watch().into_stream(),
}),
None => anyhow::bail!("No endpoint for {public_key:?} found"),
}
}

fn handle_pong(&mut self, sender: PublicKey, src: &DiscoMessageSource, pong: Pong) {
if let Some(ep) = self.get_mut(EndpointId::NodeKey(&sender)).as_mut() {
let insert = ep.handle_pong(&pong, src.into());
Expand Down Expand Up @@ -536,6 +571,25 @@ impl NodeMapInner {
}
}

/// Stream returning `ConnectionTypes`
#[derive(Debug)]
pub struct ConnectionTypeStream {
initial: Option<ConnectionType>,
inner: watchable::WatcherStream<ConnectionType>,
}

impl Stream for ConnectionTypeStream {
type Item = ConnectionType;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = &mut *self;
if let Some(initial_conn_type) = this.initial.take() {
return Poll::Ready(Some(initial_conn_type));
}
Pin::new(&mut this.inner).poll_next(cx)
}
}

/// An (Ip, Port) pair.
///
/// NOTE: storing an [`IpPort`] is safer than storing a [`SocketAddr`] because for IPv6 socket
Expand Down
Loading

0 comments on commit 7986394

Please sign in to comment.