diff --git a/iroh-cli/src/commands/doctor.rs b/iroh-cli/src/commands/doctor.rs index 02dca5f53d..b7e007c242 100644 --- a/iroh-cli/src/commands/doctor.rs +++ b/iroh-cli/src/commands/doctor.rs @@ -398,7 +398,7 @@ impl Gui { .. }) => { let relay_url = relay_url - .map(|x| x.to_string()) + .map(|x| x.relay_url.to_string()) .unwrap_or_else(|| "unknown".to_string()); let latency = format_latency(latency); let addrs = addrs diff --git a/iroh-cli/src/commands/node.rs b/iroh-cli/src/commands/node.rs index 695e3e8dcc..1c456dd0c2 100644 --- a/iroh-cli/src/commands/node.rs +++ b/iroh-cli/src/commands/node.rs @@ -96,7 +96,7 @@ async fn fmt_connections( let node_id: Cell = conn_info.node_id.to_string().into(); let relay_url = conn_info .relay_url - .map_or(String::new(), |url| url.to_string()) + .map_or(String::new(), |url_info| url_info.relay_url.to_string()) .into(); let conn_type = conn_info.conn_type.to_string().into(); let latency = match conn_info.latency { @@ -132,7 +132,7 @@ fn fmt_connection(info: ConnectionInfo) -> String { table.add_row([bold_cell("current time"), timestamp.into()]); table.add_row([bold_cell("node id"), node_id.to_string().into()]); let relay_url = relay_url - .map(|r| r.to_string()) + .map(|r| r.relay_url.to_string()) .unwrap_or_else(|| String::from("unknown")); table.add_row([bold_cell("relay url"), relay_url.into()]); table.add_row([bold_cell("connection type"), conn_type.to_string().into()]); diff --git a/iroh-net/src/disco.rs b/iroh-net/src/disco.rs index e218e469da..668533643a 100644 --- a/iroh-net/src/disco.rs +++ b/iroh-net/src/disco.rs @@ -24,6 +24,7 @@ use std::{ }; use anyhow::{anyhow, bail, ensure, Context, Result}; +use serde::{Deserialize, Serialize}; use url::Url; use crate::{key, net::ip::to_canonical, relay::RelayUrl}; @@ -133,7 +134,7 @@ pub struct Pong { } /// Addresses to which we can send. This is either a UDP or a relay address. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum SendAddr { /// UDP, the ip addr. Udp(SocketAddr), diff --git a/iroh-net/src/discovery.rs b/iroh-net/src/discovery.rs index 117ab689b3..65a34d5b59 100644 --- a/iroh-net/src/discovery.rs +++ b/iroh-net/src/discovery.rs @@ -207,16 +207,24 @@ impl DiscoveryTask { Ok(stream) } + /// We need discovery if we have no paths to the node, or if the paths we do have + /// have timed out. fn needs_discovery(ep: &MagicEndpoint, node_id: NodeId) -> bool { match ep.connection_info(node_id) { // No connection info means no path to node -> start discovery. None => true, - Some(info) => match info.last_received() { - // No path to node -> start discovery. - None => true, - // If we haven't received for MAX_AGE, start discovery. - Some(elapsed) => elapsed > MAX_AGE, - }, + Some(info) => { + match (info.last_received(), info.last_alive_relay()) { + // No path to node -> start discovery. + (None, None) => true, + // If we haven't received on direct addresses or the relay for MAX_AGE, + // start discovery. + (Some(elapsed), Some(elapsed_relay)) => { + elapsed > MAX_AGE && elapsed_relay > MAX_AGE + } + (Some(elapsed), _) | (_, Some(elapsed)) => elapsed > MAX_AGE, + } + } } } @@ -237,6 +245,10 @@ impl DiscoveryTask { }; match next { Some(Ok(r)) => { + if r.addr_info.is_empty() { + debug!(provenance = %r.provenance, addr = ?r.addr_info, "discovery: empty address found"); + continue; + } debug!(provenance = %r.provenance, addr = ?r.addr_info, "discovery: new address found"); let addr = NodeAddr { info: r.addr_info, @@ -551,8 +563,6 @@ mod test_dns_pkarr { use anyhow::Result; use iroh_base::key::SecretKey; - use tokio::task::JoinHandle; - use tokio_util::sync::CancellationToken; use url::Url; use crate::{ @@ -560,22 +570,21 @@ mod test_dns_pkarr { dns::node_info::{lookup_by_id, NodeInfo}, relay::{RelayMap, RelayMode}, test_utils::{ + dns_and_pkarr_servers::run_dns_and_pkarr_servers, dns_server::{create_dns_resolver, run_dns_server}, + pkarr_dns_state::State, run_relay_server, }, AddrInfo, MagicEndpoint, NodeAddr, }; - use self::{pkarr_relay::run_pkarr_relay, state::State}; - #[tokio::test] async fn dns_resolve() -> Result<()> { let _logging_guard = iroh_test::logging::setup(); - let cancel = CancellationToken::new(); let origin = "testdns.example".to_string(); let state = State::new(origin.clone()); - let (nameserver, dns_task) = run_dns_server(state.clone(), cancel.clone()).await?; + let (nameserver, _dns_drop_guard) = run_dns_server(state.clone()).await?; let secret_key = SecretKey::generate(); let node_info = NodeInfo::new( @@ -590,8 +599,6 @@ mod test_dns_pkarr { assert_eq!(resolved, node_info.into()); - cancel.cancel(); - dns_task.await??; Ok(()) } @@ -600,11 +607,10 @@ mod test_dns_pkarr { let _logging_guard = iroh_test::logging::setup(); let origin = "testdns.example".to_string(); - let cancel = CancellationToken::new(); let timeout = Duration::from_secs(2); - let (nameserver, pkarr_url, state, task) = - run_dns_and_pkarr_servers(origin.clone(), cancel.clone()).await?; + let (nameserver, pkarr_url, state, _dns_drop_guard, _pkarr_drop_guard) = + run_dns_and_pkarr_servers(origin.clone()).await?; let secret_key = SecretKey::generate(); let node_id = secret_key.public(); @@ -628,9 +634,6 @@ mod test_dns_pkarr { }; assert_eq!(resolved, expected); - - cancel.cancel(); - task.await??; Ok(()) } @@ -641,11 +644,10 @@ mod test_dns_pkarr { let _logging_guard = iroh_test::logging::setup(); let origin = "testdns.example".to_string(); - let cancel = CancellationToken::new(); let timeout = Duration::from_secs(2); - let (nameserver, pkarr_url, state, task) = - run_dns_and_pkarr_servers(&origin, cancel.clone()).await?; + let (nameserver, pkarr_url, state, _dns_drop_guard, _pkarr_drop_guard) = + run_dns_and_pkarr_servers(&origin).await?; let (relay_map, _relay_url, _relay_guard) = run_relay_server().await?; let ep1 = ep_with_discovery(relay_map.clone(), nameserver, &origin, &pkarr_url).await?; @@ -657,8 +659,34 @@ mod test_dns_pkarr { // we connect only by node id! let res = ep2.connect(ep1.node_id().into(), TEST_ALPN).await; assert!(res.is_ok(), "connection established"); - cancel.cancel(); - task.await??; + Ok(()) + } + + #[tokio::test] + async fn pkarr_publish_dns_discover_empty_node_addr() -> Result<()> { + let _logging_guard = iroh_test::logging::setup(); + + let origin = "testdns.example".to_string(); + let timeout = Duration::from_secs(2); + + let (nameserver, pkarr_url, state, _dns_drop_guard, _pkarr_drop_guard) = + run_dns_and_pkarr_servers(&origin).await?; + let (relay_map, _relay_url, _relay_guard) = run_relay_server().await?; + + let ep1 = ep_with_discovery(relay_map.clone(), nameserver, &origin, &pkarr_url).await?; + let ep2 = ep_with_discovery(relay_map, nameserver, &origin, &pkarr_url).await?; + + // wait until our shared state received the update from pkarr publishing + state.on_node(&ep1.node_id(), timeout).await?; + + let node_addr = NodeAddr::new(ep1.node_id()); + + // add empty node address. We *should* launch discovery before attempting to dial. + ep2.add_node_addr(node_addr)?; + + // we connect only by node id! + let res = ep2.connect(ep1.node_id().into(), TEST_ALPN).await; + assert!(res.is_ok(), "connection established"); Ok(()) } @@ -685,203 +713,4 @@ mod test_dns_pkarr { .await?; Ok(ep) } - - async fn run_dns_and_pkarr_servers( - origin: impl ToString, - cancel: CancellationToken, - ) -> Result<(SocketAddr, Url, State, JoinHandle>)> { - let state = State::new(origin.to_string()); - let (nameserver, dns_task) = run_dns_server(state.clone(), cancel.clone()).await?; - let (pkarr_url, pkarr_task) = run_pkarr_relay(state.clone(), cancel.clone()).await?; - let join_handle = tokio::task::spawn(async move { - dns_task.await??; - pkarr_task.await??; - Ok(()) - }); - Ok((nameserver, pkarr_url, state, join_handle)) - } - - mod state { - use anyhow::{bail, Result}; - use parking_lot::{Mutex, MutexGuard}; - use pkarr::SignedPacket; - use std::{ - collections::{hash_map, HashMap}, - future::Future, - ops::Deref, - sync::Arc, - time::Duration, - }; - - use crate::dns::node_info::{node_id_from_hickory_name, NodeInfo}; - use crate::test_utils::dns_server::QueryHandler; - use crate::NodeId; - - #[derive(Debug, Clone)] - pub struct State { - packets: Arc>>, - origin: String, - notify: Arc, - } - - impl State { - pub fn new(origin: String) -> Self { - Self { - packets: Default::default(), - origin, - notify: Arc::new(tokio::sync::Notify::new()), - } - } - - pub fn on_update(&self) -> tokio::sync::futures::Notified<'_> { - self.notify.notified() - } - - pub async fn on_node(&self, node: &NodeId, timeout: Duration) -> Result<()> { - let timeout = tokio::time::sleep(timeout); - tokio::pin!(timeout); - while self.get(node).is_none() { - tokio::select! { - _ = &mut timeout => bail!("timeout"), - _ = self.on_update() => {} - } - } - Ok(()) - } - - pub fn upsert(&self, signed_packet: SignedPacket) -> anyhow::Result { - let node_id = NodeId::from_bytes(&signed_packet.public_key().to_bytes())?; - let mut map = self.packets.lock(); - let updated = match map.entry(node_id) { - hash_map::Entry::Vacant(e) => { - e.insert(signed_packet); - true - } - hash_map::Entry::Occupied(mut e) => { - if signed_packet.more_recent_than(e.get()) { - e.insert(signed_packet); - true - } else { - false - } - } - }; - if updated { - self.notify.notify_waiters(); - } - Ok(updated) - } - - /// Returns a mutex guard, do not hold over await points - pub fn get(&self, node_id: &NodeId) -> Option + '_> { - let map = self.packets.lock(); - if map.contains_key(node_id) { - let guard = MutexGuard::map(map, |state| state.get_mut(node_id).unwrap()); - Some(guard) - } else { - None - } - } - - pub fn resolve_dns( - &self, - query: &hickory_proto::op::Message, - reply: &mut hickory_proto::op::Message, - ttl: u32, - ) -> Result<()> { - for query in query.queries() { - let Some(node_id) = node_id_from_hickory_name(query.name()) else { - continue; - }; - let packet = self.get(&node_id); - let Some(packet) = packet.as_ref() else { - continue; - }; - let node_info = NodeInfo::from_pkarr_signed_packet(packet)?; - for record in node_info.to_hickory_records(&self.origin, ttl)? { - reply.add_answer(record); - } - } - Ok(()) - } - } - - impl QueryHandler for State { - fn resolve( - &self, - query: &hickory_proto::op::Message, - reply: &mut hickory_proto::op::Message, - ) -> impl Future> + Send { - const TTL: u32 = 30; - let res = self.resolve_dns(query, reply, TTL); - futures::future::ready(res) - } - } - } - - mod pkarr_relay { - use std::net::{Ipv4Addr, SocketAddr}; - - use anyhow::Result; - use axum::{ - extract::{Path, State}, - response::IntoResponse, - routing::put, - Router, - }; - use bytes::Bytes; - use tokio::task::JoinHandle; - use tokio_util::sync::CancellationToken; - use tracing::warn; - use url::Url; - - use super::State as AppState; - - pub async fn run_pkarr_relay( - state: AppState, - cancel: CancellationToken, - ) -> Result<(Url, JoinHandle>)> { - let bind_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0)); - let app = Router::new() - .route("/pkarr/:key", put(pkarr_put)) - .with_state(state); - let listener = tokio::net::TcpListener::bind(bind_addr).await?; - let bound_addr = listener.local_addr()?; - let url: Url = format!("http://{bound_addr}/pkarr") - .parse() - .expect("valid url"); - let join_handle = tokio::task::spawn(async move { - let serve = axum::serve(listener, app); - let serve = serve.with_graceful_shutdown(cancel.cancelled_owned()); - serve.await?; - Ok(()) - }); - Ok((url, join_handle)) - } - - async fn pkarr_put( - State(state): State, - Path(key): Path, - body: Bytes, - ) -> Result { - let key = pkarr::PublicKey::try_from(key.as_str())?; - let signed_packet = pkarr::SignedPacket::from_relay_response(key, body)?; - let _updated = state.upsert(signed_packet)?; - Ok(http::StatusCode::NO_CONTENT) - } - - #[derive(Debug)] - struct AppError(anyhow::Error); - impl> From for AppError { - fn from(value: T) -> Self { - Self(value.into()) - } - } - impl IntoResponse for AppError { - fn into_response(self) -> axum::response::Response { - warn!(err = ?self, "request failed"); - (http::StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response() - } - } - } } diff --git a/iroh-net/src/magic_endpoint.rs b/iroh-net/src/magic_endpoint.rs index 558b0d49fe..f17a9f050d 100644 --- a/iroh-net/src/magic_endpoint.rs +++ b/iroh-net/src/magic_endpoint.rs @@ -442,39 +442,14 @@ impl MagicEndpoint { self.add_node_addr(node_addr.clone())?; } - let NodeAddr { node_id, info } = node_addr; + let NodeAddr { node_id, info } = node_addr.clone(); // Get the mapped IPv6 address from the magic socket. Quinn will connect to this address. - let (addr, discovery) = match self.msock.get_mapping_addr(&node_id) { - Some(addr) => { - // We got a mapped address, which means we either spoke to this endpoint before, or - // the user provided addressing info with the [`NodeAddr`]. - // This does not mean that we can actually connect to any of these addresses. - // Therefore, we will invoke the discovery service if we haven't received from the - // endpoint on any of the existing paths recently. - // If the user provided addresses in this connect call, we will add a delay - // followed by a recheck before starting the discovery, to give the magicsocket a - // chance to test the newly provided addresses. - let delay = (!info.is_empty()).then_some(DISCOVERY_WAIT_PERIOD); - let discovery = DiscoveryTask::maybe_start_after_delay(self, node_id, delay) - .ok() - .flatten(); - (addr, discovery) - } - - None => { - // We have not spoken to this endpoint before, and the user provided no direct - // addresses or relay URLs. Thus, we start a discovery task and wait for the first - // result to arrive, and only then continue, because otherwise we wouldn't have any - // path to the remote endpoint. - let mut discovery = DiscoveryTask::start(self.clone(), node_id)?; - discovery.first_arrived().await?; - let addr = self.msock.get_mapping_addr(&node_id).ok_or_else(|| { - anyhow!("Failed to retrieve the mapped address from the magic socket. Unable to dial node {node_id:?}") - })?; - (addr, Some(discovery)) - } - }; + // Start discovery for this node if it's enabled and we have no valid or verified + // address information for this node. + let (addr, discovery) = self + .get_mapping_addr_and_maybe_start_discovery(node_addr) + .await?; debug!( "connecting to {}: (via {} - {:?})", @@ -522,6 +497,65 @@ impl MagicEndpoint { connect.await.context("failed connecting to provider") } + /// Return the quic mapped address for this `node_id` and possibly start discovery + /// services if discovery is enabled on this magic endpoint. + /// + /// This will launch discovery in all cases except if: + /// 1) we do not have discovery enabled + /// 2) we have discovery enabled, but already have at least one verified, unexpired + /// addresses for this `node_id` + /// + /// # Errors + /// + /// This method may fail if we have no way of dialing the node. This can occur if + /// we were given no dialing information in the [`NodeAddr`] and no discovery + /// services were configured or if discovery failed to fetch any dialing information. + async fn get_mapping_addr_and_maybe_start_discovery( + &self, + node_addr: NodeAddr, + ) -> Result<(SocketAddr, Option)> { + let node_id = node_addr.node_id; + + // Only return a mapped addr if we have some way of dialing this node, in other + // words, we have either a relay URL or at least one direct address. + let addr = if self.msock.has_send_address(node_id) { + self.msock.get_mapping_addr(&node_id) + } else { + None + }; + match addr { + Some(addr) => { + // We have some way of dialing this node, but that doesn't actually mean + // we can actually connect to any of these addresses. + // Therefore, we will invoke the discovery service if we haven't received from the + // endpoint on any of the existing paths recently. + // If the user provided addresses in this connect call, we will add a delay + // followed by a recheck before starting the discovery, to give the magicsocket a + // chance to test the newly provided addresses. + let delay = (!node_addr.info.is_empty()).then_some(DISCOVERY_WAIT_PERIOD); + let discovery = DiscoveryTask::maybe_start_after_delay(self, node_id, delay) + .ok() + .flatten(); + Ok((addr, discovery)) + } + + None => { + // We have no known addresses or relay URLs for this node. + // So, we start a discovery task and wait for the first result to arrive, and + // only then continue, because otherwise we wouldn't have any + // path to the remote endpoint. + let mut discovery = DiscoveryTask::start(self.clone(), node_id)?; + discovery.first_arrived().await?; + if self.msock.has_send_address(node_id) { + let addr = self.msock.get_mapping_addr(&node_id).expect("checked"); + Ok((addr, Some(discovery))) + } else { + bail!("Failed to retrieve the mapped address from the magic socket. Unable to dial node {node_id:?}"); + } + } + } + } + /// Inform the magic socket about addresses of the peer. /// /// This updates the magic socket's *netmap* with these addresses, which are used as candidates @@ -530,9 +564,8 @@ impl MagicEndpoint { /// Note: updating the magic socket's *netmap* will also prune any connections that are *not* /// present in the netmap. /// - /// If no UDP addresses are added, and `relay_url` is `None`, it will error. - /// If no UDP addresses are added, and the given `relay_url` cannot be dialed, it will error. - // TODO: This is infallible, stop returning a result. + /// # Errors + /// Will return an error if we attempt to add our own [`PublicKey`] to the node map. pub fn add_node_addr(&self, node_addr: NodeAddr) -> Result<()> { // Connecting to ourselves is not supported. if node_addr.node_id == self.node_id() { diff --git a/iroh-net/src/magicsock.rs b/iroh-net/src/magicsock.rs index 9931276829..55734012e1 100644 --- a/iroh-net/src/magicsock.rs +++ b/iroh-net/src/magicsock.rs @@ -1324,6 +1324,13 @@ impl MagicSock { self.inner.node_map.node_info(&node_key) } + /// Returns `true` if we have at least one candidate address where we can send packets to. + pub fn has_send_address(&self, node_key: PublicKey) -> bool { + self.connection_info(node_key) + .map(|info| info.has_send_address()) + .unwrap_or(false) + } + /// Returns the local endpoints as a stream. /// /// The [`MagicSock`] continuously monitors the local endpoints, the network addresses diff --git a/iroh-net/src/magicsock/node_map/node_state.rs b/iroh-net/src/magicsock/node_map/node_state.rs index f9bfabb0c3..aaa9f066f8 100644 --- a/iroh-net/src/magicsock/node_map/node_state.rs +++ b/iroh-net/src/magicsock/node_map/node_state.rs @@ -236,7 +236,7 @@ impl NodeState { NodeInfo { id: self.id, node_id: self.node_id, - relay_url: self.relay_url(), + relay_url: self.relay_url.clone().map(|r| r.into()), addrs, conn_type, latency, @@ -1364,6 +1364,27 @@ pub struct DirectAddrInfo { pub last_payload: Option, } +/// Information about a relay URL. +#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] +pub struct RelayUrlInfo { + /// The relay url + pub relay_url: RelayUrl, + /// How long ago was the relay url last used. + pub last_alive: Option, + /// Latency of the relay url. + pub latency: Option, +} + +impl From<(RelayUrl, PathState)> for RelayUrlInfo { + fn from(value: (RelayUrl, PathState)) -> Self { + RelayUrlInfo { + relay_url: value.0, + last_alive: value.1.last_alive().map(|i| i.elapsed()), + latency: value.1.latency(), + } + } +} + /// Details about an iroh node which is known to this node. #[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)] pub struct NodeInfo { @@ -1371,8 +1392,8 @@ pub struct NodeInfo { pub id: usize, /// The public key of the endpoint. pub node_id: NodeId, - /// relay server, if available. - pub relay_url: Option, + /// relay server information, if available. + pub relay_url: Option, /// List of addresses at which this node might be reachable, plus any latency information we /// have about that address and the last time the address was used. pub addrs: Vec, @@ -1393,6 +1414,17 @@ impl NodeInfo { .filter_map(|addr| addr.last_control.map(|x| x.0).min(addr.last_payload)) .min() } + + /// Get the duration since the last activity we received from this endpoint + /// on the relay url. + pub fn last_alive_relay(&self) -> Option { + self.relay_url.as_ref().and_then(|r| r.last_alive) + } + + /// Returns `true` if this info contains either a relay URL or at least one direct address. + pub fn has_send_address(&self) -> bool { + self.relay_url.is_some() || !self.addrs.is_empty() + } } /// The type of connection we have to the endpoint. @@ -1427,16 +1459,26 @@ mod tests { #[test] fn test_endpoint_infos() { - let new_relay_and_state = - |url: Option| url.map(|url| (url, PathState::default())); - let now = Instant::now(); let elapsed = Duration::from_secs(3); let later = now + elapsed; let send_addr: RelayUrl = "https://my-relay.com".parse().unwrap(); - // endpoint with a `best_addr` that has a latency let pong_src = SendAddr::Udp("0.0.0.0:1".parse().unwrap()); let latency = Duration::from_millis(50); + + let new_relay_and_state = |url: RelayUrl| Some((url, PathState::default())); + + let relay_and_state = |url: RelayUrl| { + let relay_state = PathState::with_pong_reply(PongReply { + latency, + pong_at: now, + from: SendAddr::Relay(send_addr.clone()), + pong_src: pong_src.clone(), + }); + Some((url, relay_state)) + }; + + // endpoint with a `best_addr` that has a latency but no relay let (a_endpoint, a_socket_addr) = { let ip_port = IpPort { ip: Ipv4Addr::UNSPECIFIED.into(), @@ -1458,7 +1500,7 @@ mod tests { quic_mapped_addr: QuicMappedAddr::generate(), node_id: key.public(), last_full_ping: None, - relay_url: new_relay_and_state(Some(send_addr.clone())), + relay_url: None, best_addr: BestAddr::from_parts( ip_port.into(), latency, @@ -1477,19 +1519,13 @@ mod tests { // endpoint w/ no best addr but a relay w/ latency let b_endpoint = { // let socket_addr = "0.0.0.0:9".parse().unwrap(); - let relay_state = PathState::with_pong_reply(PongReply { - latency, - pong_at: now, - from: SendAddr::Relay(send_addr.clone()), - pong_src: pong_src.clone(), - }); let key = SecretKey::generate(); NodeState { id: 1, quic_mapped_addr: QuicMappedAddr::generate(), node_id: key.public(), last_full_ping: None, - relay_url: Some((send_addr.clone(), relay_state)), + relay_url: relay_and_state(send_addr.clone()), best_addr: BestAddr::default(), direct_addr_state: BTreeMap::default(), sent_pings: HashMap::new(), @@ -1509,7 +1545,7 @@ mod tests { quic_mapped_addr: QuicMappedAddr::generate(), node_id: key.public(), last_full_ping: None, - relay_url: new_relay_and_state(Some(send_addr.clone())), + relay_url: new_relay_and_state(send_addr.clone()), best_addr: BestAddr::default(), direct_addr_state: endpoint_state, sent_pings: HashMap::new(), @@ -1519,7 +1555,7 @@ mod tests { } }; - // endpoint w/ expired best addr + // endpoint w/ expired best addr and relay w/ latency let (d_endpoint, d_socket_addr) = { let socket_addr: SocketAddr = "0.0.0.0:7".parse().unwrap(); let expired = now.checked_sub(Duration::from_secs(100)).unwrap(); @@ -1532,12 +1568,6 @@ mod tests { pong_src: pong_src.clone(), }), )]); - let relay_state = PathState::with_pong_reply(PongReply { - latency, - pong_at: now, - from: SendAddr::Relay(send_addr.clone()), - pong_src, - }); let key = SecretKey::generate(); ( NodeState { @@ -1545,7 +1575,7 @@ mod tests { quic_mapped_addr: QuicMappedAddr::generate(), node_id: key.public(), last_full_ping: None, - relay_url: Some((send_addr.clone(), relay_state)), + relay_url: relay_and_state(send_addr.clone()), best_addr: BestAddr::from_parts( socket_addr, Duration::from_millis(80), @@ -1568,7 +1598,7 @@ mod tests { NodeInfo { id: a_endpoint.id, node_id: a_endpoint.node_id, - relay_url: a_endpoint.relay_url(), + relay_url: None, addrs: Vec::from([DirectAddrInfo { addr: a_socket_addr, latency: Some(latency), @@ -1582,7 +1612,11 @@ mod tests { NodeInfo { id: b_endpoint.id, node_id: b_endpoint.node_id, - relay_url: b_endpoint.relay_url(), + relay_url: Some(RelayUrlInfo { + relay_url: b_endpoint.relay_url.as_ref().unwrap().0.clone(), + last_alive: None, + latency: Some(latency), + }), addrs: Vec::new(), conn_type: ConnectionType::Relay(send_addr.clone()), latency: Some(latency), @@ -1591,7 +1625,11 @@ mod tests { NodeInfo { id: c_endpoint.id, node_id: c_endpoint.node_id, - relay_url: c_endpoint.relay_url(), + relay_url: Some(RelayUrlInfo { + relay_url: c_endpoint.relay_url.as_ref().unwrap().0.clone(), + last_alive: None, + latency: None, + }), addrs: Vec::new(), conn_type: ConnectionType::Relay(send_addr.clone()), latency: None, @@ -1600,7 +1638,11 @@ mod tests { NodeInfo { id: d_endpoint.id, node_id: d_endpoint.node_id, - relay_url: d_endpoint.relay_url(), + relay_url: Some(RelayUrlInfo { + relay_url: d_endpoint.relay_url.as_ref().unwrap().0.clone(), + last_alive: None, + latency: Some(latency), + }), addrs: Vec::from([DirectAddrInfo { addr: d_socket_addr, latency: Some(latency), @@ -1640,9 +1682,18 @@ mod tests { }); let mut got = node_map.node_infos(later); got.sort_by_key(|p| p.id); + remove_non_deterministic_fields(&mut got); assert_eq!(expect, got); } + fn remove_non_deterministic_fields(infos: &mut [NodeInfo]) { + for info in infos.iter_mut() { + if info.relay_url.is_some() { + info.relay_url.as_mut().unwrap().last_alive = None; + } + } + } + #[test] fn test_prune_direct_addresses() { // When we handle a call-me-maybe with more than MAX_INACTIVE_DIRECT_ADDRESSES we do diff --git a/iroh-net/src/test_utils.rs b/iroh-net/src/test_utils.rs index 652d81a09d..88ac5b11f9 100644 --- a/iroh-net/src/test_utils.rs +++ b/iroh-net/src/test_utils.rs @@ -4,8 +4,10 @@ use anyhow::Result; use tokio::sync::oneshot; use tracing::{error_span, info_span, Instrument}; -use crate::key::SecretKey; -use crate::relay::{RelayMap, RelayNode, RelayUrl}; +use crate::{ + key::SecretKey, + relay::{RelayMap, RelayNode, RelayUrl}, +}; /// A drop guard to clean up test infrastructure. /// @@ -63,6 +65,34 @@ pub async fn run_relay_server() -> Result<(RelayMap, RelayUrl, CleanupDropGuard) Ok((m, url, CleanupDropGuard(tx))) } +#[cfg(test)] +pub(crate) mod dns_and_pkarr_servers { + use anyhow::Result; + use std::net::SocketAddr; + use url::Url; + + use super::CleanupDropGuard; + + use crate::test_utils::{ + dns_server::run_dns_server, pkarr_dns_state::State, pkarr_relay::run_pkarr_relay, + }; + + pub async fn run_dns_and_pkarr_servers( + origin: impl ToString, + ) -> Result<(SocketAddr, Url, State, CleanupDropGuard, CleanupDropGuard)> { + let state = State::new(origin.to_string()); + let (nameserver, dns_drop_guard) = run_dns_server(state.clone()).await?; + let (pkarr_url, pkarr_drop_guard) = run_pkarr_relay(state.clone()).await?; + Ok(( + nameserver, + pkarr_url, + state, + dns_drop_guard, + pkarr_drop_guard, + )) + } +} + #[cfg(test)] pub(crate) mod dns_server { use std::net::{Ipv4Addr, SocketAddr}; @@ -74,9 +104,10 @@ pub(crate) mod dns_server { serialize::binary::BinDecodable, }; use hickory_resolver::{config::NameServerConfig, TokioAsyncResolver}; - use tokio::{net::UdpSocket, task::JoinHandle}; - use tokio_util::sync::CancellationToken; - use tracing::{debug, warn}; + use tokio::{net::UdpSocket, sync::oneshot}; + use tracing::{debug, error, warn}; + + use super::CleanupDropGuard; /// Trait used by [`run_dns_server`] for answering DNS queries. pub trait QueryHandler: Send + Sync + 'static { @@ -105,18 +136,25 @@ pub(crate) mod dns_server { /// Must pass a [`QueryHandler`] that answers queries. Can be a [`ResolveCallback`] or a struct. pub async fn run_dns_server( resolver: impl QueryHandler, - cancel: CancellationToken, - ) -> Result<(SocketAddr, JoinHandle>)> { + ) -> Result<(SocketAddr, CleanupDropGuard)> { let bind_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0)); let socket = UdpSocket::bind(bind_addr).await?; let bound_addr = socket.local_addr()?; - let s = TestDnsServer { - socket, - cancel, - resolver, - }; - let join_handle = tokio::task::spawn(async move { s.run().await }); - Ok((bound_addr, join_handle)) + let s = TestDnsServer { socket, resolver }; + let (tx, mut rx) = oneshot::channel(); + tokio::task::spawn(async move { + tokio::select! { + _ = &mut rx => { + debug!("shutting down dns server"); + } + res = s.run() => { + if let Err(e) = res { + error!("error running dns server {e:?}"); + } + } + } + }); + Ok((bound_addr, CleanupDropGuard(tx))) } /// Create a DNS resolver with a single nameserver. @@ -132,24 +170,18 @@ pub(crate) mod dns_server { struct TestDnsServer { resolver: R, socket: UdpSocket, - cancel: CancellationToken, } impl TestDnsServer { async fn run(self) -> Result<()> { let mut buf = [0; 1450]; loop { - tokio::select! { - _ = self.cancel.cancelled() => break, - res = self.socket.recv_from(&mut buf) => { - let (len, from) = res?; - if let Err(err) = self.handle_datagram(from, &buf[..len]).await { - warn!(?err, %from, "failed to handle incoming datagram"); - } - } - }; + let res = self.socket.recv_from(&mut buf).await; + let (len, from) = res?; + if let Err(err) = self.handle_datagram(from, &buf[..len]).await { + warn!(?err, %from, "failed to handle incoming datagram"); + } } - Ok(()) } async fn handle_datagram(&self, from: SocketAddr, buf: &[u8]) -> Result<()> { @@ -166,3 +198,197 @@ pub(crate) mod dns_server { } } } + +#[cfg(test)] +pub(crate) mod pkarr_relay { + use std::future::IntoFuture; + use std::net::{Ipv4Addr, SocketAddr}; + + use anyhow::Result; + use axum::{ + extract::{Path, State}, + response::IntoResponse, + routing::put, + Router, + }; + use bytes::Bytes; + use tokio::sync::oneshot; + use tracing::{debug, error, warn}; + use url::Url; + + use crate::test_utils::pkarr_dns_state::State as AppState; + + use super::CleanupDropGuard; + + pub async fn run_pkarr_relay(state: AppState) -> Result<(Url, CleanupDropGuard)> { + let bind_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0)); + let app = Router::new() + .route("/pkarr/:key", put(pkarr_put)) + .with_state(state); + let listener = tokio::net::TcpListener::bind(bind_addr).await?; + let bound_addr = listener.local_addr()?; + let url: Url = format!("http://{bound_addr}/pkarr") + .parse() + .expect("valid url"); + + let (tx, mut rx) = oneshot::channel(); + tokio::spawn(async move { + let serve = axum::serve(listener, app); + tokio::select! { + _ = &mut rx => { + debug!("shutting down pkarr server"); + } + res = serve.into_future() => { + if let Err(e) = res { + error!("pkarr server error: {e:?}"); + } + } + } + }); + Ok((url, CleanupDropGuard(tx))) + } + + async fn pkarr_put( + State(state): State, + Path(key): Path, + body: Bytes, + ) -> Result { + let key = pkarr::PublicKey::try_from(key.as_str())?; + let signed_packet = pkarr::SignedPacket::from_relay_response(key, body)?; + let _updated = state.upsert(signed_packet)?; + Ok(http::StatusCode::NO_CONTENT) + } + + #[derive(Debug)] + struct AppError(anyhow::Error); + impl> From for AppError { + fn from(value: T) -> Self { + Self(value.into()) + } + } + impl IntoResponse for AppError { + fn into_response(self) -> axum::response::Response { + warn!(err = ?self, "request failed"); + (http::StatusCode::INTERNAL_SERVER_ERROR, self.0.to_string()).into_response() + } + } +} + +#[cfg(test)] +pub(crate) mod pkarr_dns_state { + use anyhow::{bail, Result}; + use parking_lot::{Mutex, MutexGuard}; + use pkarr::SignedPacket; + use std::{ + collections::{hash_map, HashMap}, + future::Future, + ops::Deref, + sync::Arc, + time::Duration, + }; + + use crate::dns::node_info::{node_id_from_hickory_name, NodeInfo}; + use crate::test_utils::dns_server::QueryHandler; + use crate::NodeId; + + #[derive(Debug, Clone)] + pub struct State { + packets: Arc>>, + origin: String, + notify: Arc, + } + + impl State { + pub fn new(origin: String) -> Self { + Self { + packets: Default::default(), + origin, + notify: Arc::new(tokio::sync::Notify::new()), + } + } + + pub fn on_update(&self) -> tokio::sync::futures::Notified<'_> { + self.notify.notified() + } + + pub async fn on_node(&self, node: &NodeId, timeout: Duration) -> Result<()> { + let timeout = tokio::time::sleep(timeout); + tokio::pin!(timeout); + while self.get(node).is_none() { + tokio::select! { + _ = &mut timeout => bail!("timeout"), + _ = self.on_update() => {} + } + } + Ok(()) + } + + pub fn upsert(&self, signed_packet: SignedPacket) -> anyhow::Result { + let node_id = NodeId::from_bytes(&signed_packet.public_key().to_bytes())?; + let mut map = self.packets.lock(); + let updated = match map.entry(node_id) { + hash_map::Entry::Vacant(e) => { + e.insert(signed_packet); + true + } + hash_map::Entry::Occupied(mut e) => { + if signed_packet.more_recent_than(e.get()) { + e.insert(signed_packet); + true + } else { + false + } + } + }; + if updated { + self.notify.notify_waiters(); + } + Ok(updated) + } + + /// Returns a mutex guard, do not hold over await points + pub fn get(&self, node_id: &NodeId) -> Option + '_> { + let map = self.packets.lock(); + if map.contains_key(node_id) { + let guard = MutexGuard::map(map, |state| state.get_mut(node_id).unwrap()); + Some(guard) + } else { + None + } + } + + pub fn resolve_dns( + &self, + query: &hickory_proto::op::Message, + reply: &mut hickory_proto::op::Message, + ttl: u32, + ) -> Result<()> { + for query in query.queries() { + let Some(node_id) = node_id_from_hickory_name(query.name()) else { + continue; + }; + let packet = self.get(&node_id); + let Some(packet) = packet.as_ref() else { + continue; + }; + let node_info = NodeInfo::from_pkarr_signed_packet(packet)?; + for record in node_info.to_hickory_records(&self.origin, ttl)? { + reply.add_answer(record); + } + } + Ok(()) + } + } + + impl QueryHandler for State { + fn resolve( + &self, + query: &hickory_proto::op::Message, + reply: &mut hickory_proto::op::Message, + ) -> impl Future> + Send { + const TTL: u32 = 30; + let res = self.resolve_dns(query, reply, TTL); + futures::future::ready(res) + } + } +}