diff --git a/Cargo.lock b/Cargo.lock index 7e794a9fa..2994a25fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -237,6 +237,17 @@ version = "4.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de" +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -744,7 +755,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.61.1", + "windows-sys 0.52.0", ] [[package]] @@ -1214,9 +1225,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.82" +version = "0.3.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b011eec8cc36da2aab2d5cff675ec18454fad408585853910a202391cf9f8e65" +checksum = "464a3709c7f55f1f721e5389aa6ea4e3bc6aba669353300af094b29ffbdde1d8" dependencies = [ "once_cell", "wasm-bindgen", @@ -1377,6 +1388,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -1712,7 +1724,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.61.1", + "windows-sys 0.52.0", ] [[package]] @@ -2323,9 +2335,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da95793dfc411fbbd93f5be7715b0578ec61fe87cb1a42b12eb625caa5c5ea60" 
+checksum = "0d759f433fa64a2d763d1340820e46e111a7a5ab75f993d1852d70b03dbb80fd" dependencies = [ "cfg-if", "once_cell", @@ -2336,9 +2348,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.55" +version = "0.4.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "551f88106c6d5e7ccc7cd9a16f312dd3b5d36ea8b4954304657d5dfba115d4a0" +checksum = "836d9622d604feee9e5de25ac10e3ea5f2d65b41eac0d9ce72eb5deae707ce7c" dependencies = [ "cfg-if", "js-sys", @@ -2349,9 +2361,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04264334509e04a7bf8690f2384ef5265f05143a4bff3889ab7a3269adab59c2" +checksum = "48cb0d2638f8baedbc542ed444afc0644a29166f1595371af4fecf8ce1e7eeb3" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2359,9 +2371,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "420bc339d9f322e562942d52e115d57e950d12d88983a14c79b86859ee6c7ebc" +checksum = "cefb59d5cd5f92d9dcf80e4683949f15ca4b511f4ac0a6e14d4e1ac60c6ecd40" dependencies = [ "bumpalo", "proc-macro2", @@ -2372,21 +2384,29 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76f218a38c84bcb33c25ec7059b07847d465ce0e0a76b995e134a45adcb6af76" +checksum = "cbc538057e648b67f72a982e708d485b2efa771e1ac05fec311f9f63e5800db4" dependencies = [ "unicode-ident", ] [[package]] name = "wasm-bindgen-test" -version = "0.3.55" +version = "0.3.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfc379bfb624eb59050b509c13e77b4eb53150c350db69628141abce842f2373" +checksum = "25e90e66d265d3a1efc0e72a54809ab90b9c0c515915c67cdf658689d2c22c6c" dependencies = [ + 
"async-trait", + "cast", "js-sys", + "libm", "minicov", + "nu-ansi-term", + "num-traits", + "oorandom", + "serde", + "serde_json", "wasm-bindgen", "wasm-bindgen-futures", "wasm-bindgen-test-macro", @@ -2394,9 +2414,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-test-macro" -version = "0.3.55" +version = "0.3.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "085b2df989e1e6f9620c1311df6c996e83fe16f57792b272ce1e024ac16a90f1" +checksum = "7150335716dce6028bead2b848e72f47b45e7b9422f64cccdc23bedca89affc1" dependencies = [ "proc-macro2", "quote", @@ -2405,9 +2425,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.82" +version = "0.3.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a1f95c0d03a47f4ae1f7a64643a6bb97465d9b740f0fa8f90ea33915c99a9a1" +checksum = "9b32828d774c412041098d182a8b38b16ea816958e07cf40eec2bc080ae137ac" dependencies = [ "js-sys", "wasm-bindgen", @@ -2454,7 +2474,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.61.1", + "windows-sys 0.52.0", ] [[package]] diff --git a/quinn-proto/src/config/transport.rs b/quinn-proto/src/config/transport.rs index a2313832f..d2de9ac0a 100644 --- a/quinn-proto/src/config/transport.rs +++ b/quinn-proto/src/config/transport.rs @@ -1,4 +1,8 @@ -use std::{fmt, num::NonZeroU32, sync::Arc}; +use std::{ + fmt, + num::{NonZeroU8, NonZeroU32}, + sync::Arc, +}; #[cfg(feature = "qlog")] use std::{io, sync::Mutex, time::Instant}; @@ -15,7 +19,7 @@ use crate::{ /// When multipath is required and has not been explicitly enabled, this value will be used for /// [`TransportConfig::max_concurrent_multipath_paths`]. 
const DEFAULT_CONCURRENT_MULTIPATH_PATHS_WHEN_ENABLED_: NonZeroU32 = { - match NonZeroU32::new(4) { + match NonZeroU32::new(12) { Some(v) => v, None => panic!("to enable multipath this must be positive, which clearly it is"), } }; @@ -78,7 +82,7 @@ pub struct TransportConfig { pub(crate) default_path_max_idle_timeout: Option, pub(crate) default_path_keep_alive_interval: Option, - pub(crate) nat_traversal_concurrency_limit: Option, + pub(crate) max_remote_nat_traversal_addresses: Option, pub(crate) qlog_sink: QlogSink, } @@ -443,18 +447,19 @@ impl TransportConfig { .map(Into::into) } - /// Sets the maximum number of concurrent nat traversal attempts to initiate as a client, or to - /// allow as a server. + /// Sets the maximum number of nat traversal addresses this endpoint allows the remote to + /// advertise /// - /// Setting this to any nonzero value will enable the Nat Traversal Extension for QUIC, - /// see + /// Setting this to any nonzero value will enable Iroh's holepunching, loosely based on the Nat + /// Traversal Extension for QUIC, see + /// /// /// This implementation expects the multipath extension to be enabled as well. If not yet /// enabled via [`Self::max_concurrent_multipath_paths`], a default value of /// [`DEFAULT_CONCURRENT_MULTIPATH_PATHS_WHEN_ENABLED`] will be used.
- pub fn set_max_nat_traversal_concurrent_attempts(&mut self, max_concurrent: u32) -> &mut Self { - self.nat_traversal_concurrency_limit = NonZeroU32::new(max_concurrent); - if max_concurrent != 0 && self.max_concurrent_multipath_paths.is_none() { + pub fn set_max_remote_nat_traversal_addresses(&mut self, max_addresses: u8) -> &mut Self { + self.max_remote_nat_traversal_addresses = NonZeroU8::new(max_addresses); + if max_addresses != 0 && self.max_concurrent_multipath_paths.is_none() { self.max_concurrent_multipath_paths( DEFAULT_CONCURRENT_MULTIPATH_PATHS_WHEN_ENABLED_.get(), ); @@ -462,14 +467,6 @@ impl TransportConfig { self } - /// Gets the maximum number of concurrent attempts for nat traversal - /// - /// If this is `Some`, the value is guaranteed to be non zero. - pub fn get_nat_traversal_concurrency_limit(&self) -> Option { - self.nat_traversal_concurrency_limit - .map(|non_zero| VarInt::from_u32(non_zero.get())) - } - /// qlog capture configuration to use for a particular connection #[cfg(feature = "qlog")] pub fn qlog_stream(&mut self, stream: Option) -> &mut Self { @@ -526,7 +523,7 @@ impl Default for TransportConfig { default_path_keep_alive_interval: None, // nat traversal disabled by default - nat_traversal_concurrency_limit: None, + max_remote_nat_traversal_addresses: None, qlog_sink: QlogSink::default(), } @@ -565,7 +562,7 @@ impl fmt::Debug for TransportConfig { max_concurrent_multipath_paths, default_path_max_idle_timeout, default_path_keep_alive_interval, - nat_traversal_concurrency_limit, + max_remote_nat_traversal_addresses, qlog_sink, } = self; let mut s = fmt.debug_struct("TransportConfig"); @@ -610,8 +607,8 @@ impl fmt::Debug for TransportConfig { default_path_keep_alive_interval, ) .field( - "nat_traversal_concurrency_limit", - nat_traversal_concurrency_limit, + "max_remote_nat_traversal_addresses", + max_remote_nat_traversal_addresses, ); if cfg!(feature = "qlog") { s.field("qlog_stream", &qlog_sink.is_enabled()); diff --git 
a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index 65ba04c21..3d7e91c70 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -28,6 +28,7 @@ use crate::{ connection::timer::{ConnTimer, PathTimer}, crypto::{self, KeyPair, Keys, PacketKey}, frame::{self, Close, Datagram, FrameStruct, NewToken, ObservedAddr}, + iroh_hp, packet::{ FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, LongType, Packet, PacketNumber, PartialDecode, SpaceId, @@ -302,6 +303,8 @@ pub struct Connection { // TODO(flub): Make this a more efficient data structure. Like ranges of abandoned // paths. Or a set together with a minimum. Or something. abandoned_paths: FxHashSet, + + iroh_hp: Option, } impl Connection { @@ -441,6 +444,9 @@ impl Connection { remote_max_path_id: PathId::ZERO, max_path_id_with_cids: PathId::ZERO, abandoned_paths: Default::default(), + + // iroh's nat traversal + iroh_hp: None, }; if path_validated { this.on_path_validated(PathId::ZERO); @@ -608,12 +614,13 @@ impl Connection { if self.abandoned_paths.contains(&path_id) || Some(path_id) > self.max_path_id() { return Err(ClosePathError::ClosedPath); } - if self - .paths - .keys() - .filter(|&id| !self.abandoned_paths.contains(id)) - .count() - < 2 + if self.paths.contains_key(&path_id) + && self + .paths + .keys() + .filter(|&id| !self.abandoned_paths.contains(id)) + .count() + < 2 { return Err(ClosePathError::LastOpenPath); } @@ -805,8 +812,7 @@ impl Connection { // for the path to be opened we need to send a packet on the path. 
Sending a challenge // guarantees this - data.challenge = Some(self.rng.random()); - data.challenge_pending = true; + data.send_new_challenge = true; let path = vacant_entry.insert(PathState { data, prev: None }); @@ -836,6 +842,20 @@ impl Connection { max_datagrams: usize, buf: &mut Vec, ) -> Option { + if let Some(address) = self.spaces[SpaceId::Data].pending.hole_punch_to.pop() { + trace!(dst = ?address, "RAND_DATA packet"); + buf.reserve_exact(8); // send 8 bytes of random data + let tmp: [u8; 8] = self.rng.random(); + buf.put_slice(&tmp); + return Some(Transmit { + destination: address.into(), + ecn: None, + size: 8, + segment_size: None, + src_ip: None, + }); + } + assert!(max_datagrams != 0); let max_datagrams = match self.config.enable_segmentation_offload { false => 1, @@ -1572,13 +1592,14 @@ impl Connection { path_id: PathId, ) -> Option { let (prev_cid, prev_path) = self.paths.get_mut(&path_id)?.prev.as_mut()?; - if !prev_path.challenge_pending { + // TODO (matheus23): We could use !prev_path.is_validating() here instead to + // (possibly) also re-send challenges when they get lost. + if !prev_path.send_new_challenge { return None; - } - prev_path.challenge_pending = false; - let token = prev_path - .challenge - .expect("previous path challenge pending without token"); + }; + prev_path.send_new_challenge = false; + let token = self.rng.random(); + prev_path.challenges_sent.insert(token, now); let destination = prev_path.remote; debug_assert_eq!( self.highest_space, @@ -1669,13 +1690,15 @@ impl Connection { first_decode, remaining, }) => { + let span = trace_span!("pkt", %path_id); + let _guard = span.enter(); // If this packet could initiate a migration and we're a client or a server that // forbids migration, drop the datagram. This could be relaxed to heuristically // permit NAT-rebinding-like migration. 
if let Some(known_remote) = self.path(path_id).map(|path| path.remote) { if remote != known_remote && !self.side.remote_may_migrate(&self.state) { trace!( - ?path_id, + %path_id, ?remote, path_remote = ?self.path(path_id).map(|p| p.remote), "discarding packet from unrecognized peer" @@ -1800,90 +1823,107 @@ impl Connection { } } }, - Timer::PerPath(path_id, timer) => match timer { - PathTimer::PathIdle => { - // TODO(flub): TransportErrorCode::NO_ERROR but where's the API to get - // that into a VarInt? - self.close_path(now, path_id, TransportErrorCode::NO_ERROR.into()) - .ok(); - } + // TODO: add path_id as span somehow + Timer::PerPath(path_id, timer) => { + let span = trace_span!("per-path timer fired", %path_id, ?timer); + let _guard = span.enter(); + match timer { + PathTimer::PathIdle => { + self.close_path(now, path_id, TransportErrorCode::NO_ERROR.into()) + .ok(); + } - PathTimer::PathKeepAlive => { - trace!(?path_id, "sending keep-alive on path"); - self.ping_path(path_id).ok(); - } - PathTimer::LossDetection => { - self.on_loss_detection_timeout(now, path_id); - self.config.qlog_sink.emit_recovery_metrics( - self.path_data(path_id).pto_count, - &mut self.paths.get_mut(&path_id).unwrap().data, - now, - self.orig_rem_cid, - ); - } - PathTimer::PathValidation => { - let Some(path) = self.paths.get_mut(&path_id) else { - continue; - }; - debug!("path validation failed"); - if let Some((_, prev)) = path.prev.take() { - path.data = prev; + PathTimer::PathKeepAlive => { + trace!("sending keep-alive on path"); + self.ping_path(path_id).ok(); } - path.data.challenge = None; - path.data.challenge_pending = false; - } - PathTimer::PathOpen => { - let Some(path) = self.path_mut(path_id) else { - continue; - }; - path.challenge = None; - path.challenge_pending = false; - debug!("new path validation failed"); - if let Err(err) = self.close_path( - now, - path_id, - TransportErrorCode::UNSTABLE_INTERFACE.into(), - ) { - warn!(?err, "failed closing path"); + 
PathTimer::LossDetection => { + self.on_loss_detection_timeout(now, path_id); + self.config.qlog_sink.emit_recovery_metrics( + self.path_data(path_id).pto_count, + &mut self.paths.get_mut(&path_id).unwrap().data, + now, + self.orig_rem_cid, + ); } + PathTimer::PathValidation => { + let Some(path) = self.paths.get_mut(&path_id) else { + continue; + }; + self.timers + .stop(Timer::PerPath(path_id, PathTimer::PathChallengeLost)); + debug!("path validation failed"); + if let Some((_, prev)) = path.prev.take() { + path.data = prev; + } + path.data.challenges_sent.clear(); + path.data.send_new_challenge = false; + } + PathTimer::PathChallengeLost => { + let Some(path) = self.paths.get_mut(&path_id) else { + continue; + }; + trace!("path challenge deemed lost"); + path.data.send_new_challenge = true; + } + PathTimer::PathOpen => { + let Some(path) = self.path_mut(path_id) else { + continue; + }; + path.challenges_sent.clear(); + path.send_new_challenge = false; + debug!("new path validation failed"); + if let Err(err) = self.close_path( + now, + path_id, + TransportErrorCode::UNSTABLE_INTERFACE.into(), + ) { + warn!(?err, "failed closing path"); + } - self.events.push_back(Event::Path(PathEvent::LocallyClosed { - id: path_id, - error: PathError::ValidationFailed, - })); - } - PathTimer::Pacing => trace!(?path_id, "pacing timer expired"), - PathTimer::MaxAckDelay => { - trace!("max ack delay reached"); - // This timer is only armed in the Data space - self.spaces[SpaceId::Data] - .for_path(path_id) - .pending_acks - .on_max_ack_delay_timeout() - } - PathTimer::PathAbandoned => { - // The path was abandoned and 3*PTO has expired since. Clean up all - // remaining state and install stateless reset token. 
- if let Some(loc_cid_state) = self.local_cid_state.remove(&path_id) { - let (min_seq, max_seq) = loc_cid_state.active_seq(); - for seq in min_seq..=max_seq { - self.endpoint_events.push_back( - EndpointEventInner::RetireConnectionId( - now, path_id, seq, false, - ), - ); + self.events.push_back(Event::Path(PathEvent::LocallyClosed { + id: path_id, + error: PathError::ValidationFailed, + })); + } + PathTimer::Pacing => trace!("pacing timer expired"), + PathTimer::MaxAckDelay => { + trace!("max ack delay reached"); + // This timer is only armed in the Data space + self.spaces[SpaceId::Data] + .for_path(path_id) + .pending_acks + .on_max_ack_delay_timeout() + } + PathTimer::PathAbandoned => { + // The path was abandoned and 3*PTO has expired since. Clean up all + // remaining state and install stateless reset token. + self.timers.stop_per_path(path_id); + if let Some(loc_cid_state) = self.local_cid_state.remove(&path_id) { + let (min_seq, max_seq) = loc_cid_state.active_seq(); + for seq in min_seq..=max_seq { + self.endpoint_events.push_back( + EndpointEventInner::RetireConnectionId( + now, path_id, seq, false, + ), + ); + } } + self.drop_path_state(path_id, now); + } + PathTimer::PathNotAbandoned => { + // The peer failed to respond with a PATH_ABANDON when we sent such a + // frame. + warn!("missing PATH_ABANDON from peer"); + // TODO(flub): What should the error code be? + self.close( + now, + TransportErrorCode::NO_ERROR.into(), + "peer ignored PATH_ABANDON frame".into(), + ); } - self.drop_path_state(path_id, now); - } - PathTimer::PathNotAbandoned => { - // The peer failed to respond with a PATH_ABANDON when we sent such a - // frame. - warn!(?path_id, "missing PATH_ABANDON from peer"); - // TODO(flub): What should the error code be? 
- self.close(now, 0u8.into(), "peer ignored PATH_ABANDON frame".into()); } - }, + } } } } @@ -2382,7 +2422,7 @@ impl Connection { .remove_in_flight(&info); let app_limited = self.app_limited; let path = self.path_data_mut(path_id); - if info.ack_eliciting && path.challenge.is_none() { + if info.ack_eliciting && !path.challenges_sent.is_empty() { // Only pass ACKs to the congestion controller if we are not validating the current // path, so as to ignore any ACKs from older paths still coming in. let rtt = path.rtt; @@ -2445,7 +2485,7 @@ impl Connection { let (_, space) = match self.pto_time_and_space(now, path_id) { Some(x) => x, None => { - error!("PTO expired while unset"); + error!(?path_id, "PTO expired while unset"); return; } }; @@ -3410,8 +3450,8 @@ impl Connection { } Ok((packet, number)) => { let span = match number { - Some(pn) => trace_span!("recv", space = ?packet.header.space(), pn, %path_id), - None => trace_span!("recv", space = ?packet.header.space(), %path_id), + Some(pn) => trace_span!("recv", space = ?packet.header.space(), pn), + None => trace_span!("recv", space = ?packet.header.space()), }; let _guard = span.enter(); @@ -4005,16 +4045,26 @@ impl Connection { .paths .get_mut(&path_id) .expect("payload is processed only after the path becomes known"); - if path.data.challenge == Some(token) && remote == path.data.remote { + + if remote != path.data.remote { + debug!(token, "ignoring invalid PATH_RESPONSE"); + } else if let Some(&challenge_sent) = path.data.challenges_sent.get(&token) { self.timers .stop(Timer::PerPath(path_id, PathTimer::PathValidation)); + self.timers + .stop(Timer::PerPath(path_id, PathTimer::PathChallengeLost)); if !path.data.validated { trace!("new path validated"); } self.timers .stop(Timer::PerPath(path_id, PathTimer::PathOpen)); - path.data.challenge = None; + path.data.challenges_sent.clear(); + path.data.send_new_challenge = false; path.data.validated = true; + path.data.rtt.update( + Duration::ZERO, + 
now.saturating_duration_since(challenge_sent), + ); self.events .push_back(Event::Path(PathEvent::Opened { id: path_id })); // mark the path as open from the application perspective now that Opened @@ -4028,8 +4078,8 @@ impl Connection { } } if let Some((_, ref mut prev)) = path.prev { - prev.challenge = None; - prev.challenge_pending = false; + prev.challenges_sent.clear(); + prev.send_new_challenge = false; } } else { debug!(token, "ignoring invalid PATH_RESPONSE"); @@ -4110,20 +4160,20 @@ impl Connection { } } Frame::NewConnectionId(frame) => { - let path_id = match (frame.path_id, self.max_path_id()) { - (Some(path_id), Some(current_max)) if path_id <= current_max => path_id, - (Some(_large_path_id), Some(_current_max)) => { + let path_id = if let Some(path_id) = frame.path_id { + if !self.is_multipath_negotiated() { return Err(TransportError::PROTOCOL_VIOLATION( - "PATH_NEW_CONNECTION_ID contains path_id exceeding current max", + "received PATH_NEW_CONNECTION_ID frame when multipath was not negotiated", )); } - (Some(_path_id), None) => { + if path_id > self.local_max_path_id { return Err(TransportError::PROTOCOL_VIOLATION( - "received PATH_NEW_CONNECTION_ID frame when multipath was not negotiated", + "PATH_NEW_CONNECTION_ID contains path_id exceeding current max", )); } - - (None, _) => PathId::ZERO, + path_id + } else { + PathId::ZERO }; if self.abandoned_paths.contains(&path_id) { @@ -4313,7 +4363,7 @@ impl Connection { // TODO(flub): which error code? 
self.close( now, - 0u8.into(), + TransportErrorCode::NO_ERROR.into(), Bytes::from_static(b"last path abandoned by peer"), ); } @@ -4369,17 +4419,16 @@ impl Connection { } Frame::MaxPathId(frame::MaxPathId(path_id)) => { span.record("path", tracing::field::debug(&path_id)); - if let Some(current_max) = self.max_path_id() { - // frames that do not increase the path id are ignored - self.remote_max_path_id = self.remote_max_path_id.max(path_id); - if self.max_path_id() != Some(current_max) { - self.issue_first_path_cids(now); - } - } else { + if !self.is_multipath_negotiated() { return Err(TransportError::PROTOCOL_VIOLATION( "received MAX_PATH_ID frame when multipath was not negotiated", )); } + // frames that do not increase the path id are ignored + if path_id > self.remote_max_path_id { + self.remote_max_path_id = path_id; + self.issue_first_path_cids(now); + } } Frame::PathsBlocked(frame::PathsBlocked(max_path_id)) => { // Receipt of a value of Maximum Path Identifier or Path Identifier that is higher than the local maximum value MUST @@ -4431,14 +4480,103 @@ impl Connection { )); } } - Frame::AddAddress(_addr) => { - // TODO(@divma): handle + Frame::AddAddress(addr) => { + let Some(hp_state) = self.iroh_hp.as_mut() else { + return Err(TransportError::PROTOCOL_VIOLATION( + "received ADD_ADDRESS frame when iroh's nat traversal was not negotiated", + )); + }; + + let Ok(mut client_state) = hp_state.client_side() else { + return Err(TransportError::PROTOCOL_VIOLATION( + "client sent ADD_ADDRESS frame", + )); + }; + + if !client_state.check_remote_address(&addr) { + // if the address is not valid we flag it, but update anyway + warn!(?addr, "server sent ilegal ADD_ADDRESS frame"); + } + + match client_state.add_remote_address(addr) { + Ok(maybe_added) => { + if let Some(added) = maybe_added { + self.events.push_back(Event::NatTraversal( + iroh_hp::Event::AddressAdded(added), + )); + } + } + Err(e) => { + warn!(?e, "failed to add remote address") + } + } } - 
Frame::PunchMeNow(_frame) => { - // TODO(@divma): handle + Frame::RemoveAddress(addr) => { + let Some(hp_state) = self.iroh_hp.as_mut() else { + return Err(TransportError::PROTOCOL_VIOLATION( + "received REMOVE_ADDRESS frame when iroh's nat traversal was not negotiated", + )); + }; + + let Ok(mut client_state) = hp_state.client_side() else { + return Err(TransportError::PROTOCOL_VIOLATION( + "client sent REMOVE_ADDRESS frame", + )); + }; + + if let Some(removed_addr) = client_state.remove_remote_address(addr) { + self.events + .push_back(Event::NatTraversal(iroh_hp::Event::AddressRemoved( + removed_addr, + ))); + } } - Frame::RemoveAddress(_frame) => { - // TODO(@divma): handle + Frame::ReachOut(reach_out) => { + let Some(hp_state) = self.iroh_hp.as_mut() else { + return Err(TransportError::PROTOCOL_VIOLATION( + "received REACH_OUT frame when iroh's nat traversal was not negotiated", + )); + }; + + match hp_state.handle_reach_out(reach_out) { + Ok(None) => { + // no action required here + } + Ok(Some(info)) => { + let iroh_hp::RandDataNeeded { + ip, + port, + round, + is_new_round, + } = info; + if is_new_round { + // TODO(@divma): this depends on round starting on 1 right now, + // because the round should be greater to the default one, which is + // zero + self.spaces[SpaceId::Data].pending.hole_punch_round = round; + self.spaces[SpaceId::Data].pending.hole_punch_to.clear(); + } + + self.spaces[SpaceId::Data] + .pending + .hole_punch_to + .push((ip, port)); + } + Err(iroh_hp::Error::WrongConnectionSide) => { + return Err(TransportError::PROTOCOL_VIOLATION( + "server sent REACH_OUT frames for nat traversal", + )); + } + Err(iroh_hp::Error::TooManyAddresses) => { + return Err(TransportError::PROTOCOL_VIOLATION( + "client exceeded allowed REACH_OUT frames for this round", + )); + } + Err(error) => { + warn!(%error,"error handling REACH_OUT frame"); + // TODO(@divma): check if this is reachable + } + } } } } @@ -4536,14 +4674,12 @@ impl Connection { })); } } - 
new_path.challenge = Some(self.rng.random()); - new_path.challenge_pending = true; + new_path.send_new_challenge = true; let mut prev = mem::replace(path, new_path); // Don't clobber the original path if the previous one hasn't been validated yet - if prev.challenge.is_none() { - prev.challenge = Some(self.rng.random()); - prev.challenge_pending = true; + if !prev.is_validating_path() { + prev.send_new_challenge = true; // We haven't updated the remote CID yet, this captures the remote CID we were using on // the previous path. @@ -4681,6 +4817,31 @@ impl Connection { self.stats.frame_tx.handshake_done.saturating_add(1); } + // REACH_OUT + // TODO(@divma): path explusive considerations + if let Some((round, addresses)) = space.pending.reach_out.as_mut() { + while let Some(local_addr) = addresses.pop() { + let reach_out = frame::ReachOut::new(*round, local_addr); + if buf.remaining_mut() > reach_out.size() { + trace!(%round, ?local_addr, "REACH_OUT"); + reach_out.write(buf); + let sent_reachouts = sent + .retransmits + .get_or_create() + .reach_out + .get_or_insert_with(|| (*round, Default::default())); + sent_reachouts.1.push(local_addr); + self.stats.frame_tx.reach_out = self.stats.frame_tx.reach_out.saturating_add(1); + } else { + addresses.push(local_addr); + break; + } + } + if addresses.is_empty() { + space.pending.reach_out = None; + } + } + // OBSERVED_ADDR if !path_exclusive_only && space_id == SpaceId::Data @@ -4779,44 +4940,48 @@ impl Connection { } // PATH_CHALLENGE - if buf.remaining_mut() > 9 && space_id == SpaceId::Data { - // Transmit challenges with every outgoing packet on an unvalidated path - if let Some(token) = path.challenge { - sent.non_retransmits = true; - sent.requires_padding = true; - trace!("PATH_CHALLENGE {:08x}", token); - buf.write(frame::FrameType::PATH_CHALLENGE); - buf.write(token); + if buf.remaining_mut() > 9 && space_id == SpaceId::Data && path.send_new_challenge { + path.send_new_challenge = false; - if is_multipath_negotiated 
&& !path.validated && path.challenge_pending { - // queue informing the path status along with the challenge - space.pending.path_status.insert(path_id); - } + // Generate a new challenge every time we send a new PATH_CHALLENGE + let token = self.rng.random(); + path.challenges_sent.insert(token, now); + sent.non_retransmits = true; + sent.requires_padding = true; + trace!("PATH_CHALLENGE {:08x}", token); + buf.write(frame::FrameType::PATH_CHALLENGE); + buf.write(token); + self.stats.frame_tx.path_challenge += 1; + let pto = self.ack_frequency.max_ack_delay_for_pto() + path.rtt.pto_base(); + self.timers.set( + Timer::PerPath(path_id, PathTimer::PathChallengeLost), + now + pto, + ); - // But only send a packet solely for that purpose at most once - path.challenge_pending = false; + if is_multipath_negotiated && !path.validated && path.send_new_challenge { + // queue informing the path status along with the challenge + space.pending.path_status.insert(path_id); + } - // Always include an OBSERVED_ADDR frame with a PATH_CHALLENGE, regardless - // of whether one has already been sent on this path. - if space_id == SpaceId::Data - && self - .config - .address_discovery_role - .should_report(&self.peer_params.address_discovery_role) - { - let frame = - frame::ObservedAddr::new(path.remote, self.next_observed_addr_seq_no); - if buf.remaining_mut() > frame.size() { - frame.write(buf); + // Always include an OBSERVED_ADDR frame with a PATH_CHALLENGE, regardless + // of whether one has already been sent on this path. 
+ if space_id == SpaceId::Data + && self + .config + .address_discovery_role + .should_report(&self.peer_params.address_discovery_role) + { + let frame = frame::ObservedAddr::new(path.remote, self.next_observed_addr_seq_no); + if buf.remaining_mut() > frame.size() { + frame.write(buf); - self.next_observed_addr_seq_no = - self.next_observed_addr_seq_no.saturating_add(1u8); - path.observed_addr_sent = true; + self.next_observed_addr_seq_no = + self.next_observed_addr_seq_no.saturating_add(1u8); + path.observed_addr_sent = true; - self.stats.frame_tx.observed_addr += 1; - sent.retransmits.get_or_create().observed_addr = true; - space.pending.observed_addr = false; - } + self.stats.frame_tx.observed_addr += 1; + sent.retransmits.get_or_create().observed_addr = true; + space.pending.observed_addr = false; } } } @@ -5166,6 +5331,43 @@ impl Connection { self.stats.frame_tx.stream += sent.stream_frames.len() as u64; } + // ADD_ADDRESS + // TODO(@divma): check if we need to do path exclusive filters + while space_id == SpaceId::Data && frame::AddAddress::SIZE_BOUND <= buf.remaining_mut() { + if let Some(added_address) = space.pending.add_address.pop_last() { + trace!( + seq = %added_address.seq_no, + ip = ?added_address.ip, + port = added_address.port, + "ADD_ADDRESS", + ); + added_address.write(buf); + sent.retransmits + .get_or_create() + .add_address + .insert(added_address); + self.stats.frame_tx.add_address = self.stats.frame_tx.add_address.saturating_add(1); + } else { + break; + } + } + + // REMOVE_ADDRESS + while space_id == SpaceId::Data && frame::RemoveAddress::SIZE_BOUND <= buf.remaining_mut() { + if let Some(removed_address) = space.pending.remove_address.pop_last() { + trace!(seq = %removed_address.seq_no, "REMOVE_ADDRESS"); + removed_address.write(buf); + sent.retransmits + .get_or_create() + .remove_address + .insert(removed_address); + self.stats.frame_tx.remove_address = + self.stats.frame_tx.remove_address.saturating_add(1); + } else { + break; + } + } + 
sent } @@ -5280,6 +5482,7 @@ impl Connection { } self.ack_frequency.peer_max_ack_delay = get_max_ack_delay(¶ms); + let mut multipath_enabled = None; if let (Some(local_max_path_id), Some(remote_max_path_id)) = ( self.config.get_initial_max_path_id(), params.initial_max_path_id, @@ -5287,7 +5490,55 @@ impl Connection { // multipath is enabled, register the local and remote maximums self.local_max_path_id = local_max_path_id; self.remote_max_path_id = remote_max_path_id; - debug!(initial_max_path_id=%local_max_path_id.min(remote_max_path_id), "multipath negotiated"); + let initial_max_path_id = local_max_path_id.min(remote_max_path_id); + debug!(%initial_max_path_id, "multipath negotiated"); + multipath_enabled = Some(initial_max_path_id); + } + + if let Some((max_locally_allowed_remote_addresses, max_remotely_allowed_remote_addresses)) = + self.config + .max_remote_nat_traversal_addresses + .zip(params.max_remote_nat_traversal_addresses) + { + if let Some(max_initial_paths) = + multipath_enabled.map(|path_id| path_id.saturating_add(1u8)) + { + let max_local_addresses = max_remotely_allowed_remote_addresses.get(); + let max_remote_addresses = max_locally_allowed_remote_addresses.get(); + self.iroh_hp = Some(iroh_hp::State::new( + max_remote_addresses, + max_local_addresses, + self.side(), + )); + debug!( + %max_remote_addresses, %max_local_addresses, + "iroh hole punching negotiated" + ); + + match self.side() { + Side::Client => { + if max_initial_paths.as_u32() < max_remote_addresses as u32 + 1 { + // in this case the client might try to open `max_remote_addresses` new + // paths, but the current multipath configuration will not allow it + warn!(%max_initial_paths, %max_remote_addresses, "local client configuration might cause nat traversal issues") + } else if max_local_addresses as u64 + > params.active_connection_id_limit.into_inner() + { + // the server allows us to send at most `params.active_connection_id_limit` + // but they might need at least 
`max_local_addresses` to effectively send + // `PATH_CHALLENGE` frames to each advertised local address + warn!(%max_local_addresses, remote_cid_limit=%params.active_connection_id_limit.into_inner(), "remote server configuration might cause nat traversal issues") + } + } + Side::Server => { + if (max_initial_paths.as_u32() as u64) < crate::LOC_CID_COUNT { + warn!(%max_initial_paths, local_cid_limit=%crate::LOC_CID_COUNT, "local server configuration might cause nat traversal issues") + } + } + } + } else { + debug!("iroh nat traversal enabled for both endpoints, but multipath is missing") + } } self.peer_params = params; @@ -5511,6 +5762,14 @@ impl Connection { self.path_data(PathId::ZERO).current_mtu() } + /// Triggers path validation on all paths + #[cfg(test)] + pub(crate) fn trigger_path_validation(&mut self) { + for path in self.paths.values_mut() { + path.data.send_new_challenge = true; + } + } + /// Whether we have 1-RTT data to send /// /// This checks for frames that can only be sent in the data space (1-RTT): @@ -5523,11 +5782,11 @@ impl Connection { /// may need to be sent. fn can_send_1rtt(&self, path_id: PathId, max_size: usize) -> SendableFrames { let path_exclusive = self.paths.get(&path_id).is_some_and(|path| { - path.data.challenge_pending + path.data.send_new_challenge || path .prev .as_ref() - .is_some_and(|(_, path)| path.challenge_pending) + .is_some_and(|(_, path)| path.send_new_challenge) || !path.data.path_responses.is_empty() }); let other = self.streams.can_send_stream_data() @@ -5648,10 +5907,14 @@ impl Connection { ); } - /// Returns the maximum [`PathId`] to be used in this connection. + /// Returns the maximum [`PathId`] to be used for sending in this connection. /// /// This is calculated as minimum between the local and remote's maximums when multipath is /// enabled, or `None` when disabled. + /// + /// For data that's received, we should use [`Self::local_max_path_id`] instead. 
+ /// The reasoning is that the remote might already have updated to its own newer + /// [`Self::max_path_id`] after sending out a `MAX_PATH_ID` frame, but it got re-ordered. fn max_path_id(&self) -> Option { if self.is_multipath_negotiated() { Some(self.remote_max_path_id.min(self.local_max_path_id)) } else { @@ -5659,6 +5922,142 @@ impl Connection { None } } + + /// Add addresses the local endpoint considers are reachable for nat traversal + /// + /// If adding any address fails, an error is returned. Previous addresses might have been + /// added. + // TODO(@divma): this combined api has the issue that an error does not mean nothing was done + pub fn add_nat_traversal_address(&mut self, address: SocketAddr) -> Result<(), iroh_hp::Error> { + let hp_state = self + .iroh_hp + .as_mut() + .ok_or(iroh_hp::Error::ExtensionNotNegotiated)?; + + if let Some(added) = hp_state.add_local_address(address)? { + self.spaces[SpaceId::Data].pending.add_address.insert(added); + }; + Ok(()) + } + + /// Removes an address the endpoint no longer considers reachable for nat traversal + /// + /// Addresses not present in the set will be silently ignored. 
+ pub fn remove_nat_traversal_address( + &mut self, + address: SocketAddr, + ) -> Result<(), iroh_hp::Error> { + let is_server = self.side().is_server(); + let hp_state = self + .iroh_hp + .as_mut() + .ok_or(iroh_hp::Error::ExtensionNotNegotiated)?; + if let Some(removed) = hp_state.remove_local_address(address) { + if is_server { + self.spaces[SpaceId::Data] + .pending + .remove_address + .insert(removed); + } + } + Ok(()) + } + + /// Get the current local nat traversal addresses + pub fn get_local_nat_traversal_addresses(&self) -> Result, iroh_hp::Error> { + let hp_state = self + .iroh_hp + .as_ref() + .ok_or(iroh_hp::Error::ExtensionNotNegotiated)?; + Ok(hp_state.get_local_nat_traversal_addresses()) + } + + /// Get the currently advertised nat traversal addresses by the server + pub fn get_remote_nat_traversal_addresses(&self) -> Result, iroh_hp::Error> { + let hp_state = self + .iroh_hp + .as_ref() + .ok_or(iroh_hp::Error::ExtensionNotNegotiated)?; + hp_state.get_remote_nat_traversal_addresses() + } + + /// Initiates a new nat traversal round + /// + /// A nat traversal round involves advertising the client's local addresses in `REACH_OUT` + /// frames, and initiating probing of the known remote addresses. When a new round is + /// initiated, the previous one is cancelled, and paths that have not been opened are closed. + /// + /// Returns the server addresses that are now being probed. 
+ pub fn initiate_nat_traversal_round( + &mut self, + now: Instant, + ) -> Result, iroh_hp::Error> { + let hp_state = self + .iroh_hp + .as_mut() + .ok_or(iroh_hp::Error::ExtensionNotNegotiated)?; + let iroh_hp::NatTraversalRound { + new_round, + reach_out_at, + addresses_to_probe, + prev_round_path_ids, + } = hp_state.initiate_nat_traversal_round()?; + + self.spaces[SpaceId::Data].pending.reach_out = Some((new_round, reach_out_at)); + + for path_id in prev_round_path_ids { + // TODO(@divma): this sounds reasonable but we need to check if this actually works for the + // purposes of the protocol + let validated = self + .path(path_id) + .map(|path| path.validated) + .unwrap_or(false); + + if !validated { + let _ = + self.close_path(now, path_id, TransportErrorCode::APPLICATION_ABANDON.into()); + } + } + + let mut err = None; + + let mut path_ids = Vec::with_capacity(addresses_to_probe.len()); + let mut probed_addresses = Vec::with_capacity(addresses_to_probe.len()); + let ipv6 = self.paths.values().any(|p| p.data.remote.is_ipv6()); + + for (ip, port) in addresses_to_probe { + // If this endpoint is an IPv6 endpoint we use IPv6 addresses for all remotes. 
+ let remote = match ip { + IpAddr::V4(addr) if ipv6 => SocketAddr::new(addr.to_ipv6_mapped().into(), port), + IpAddr::V4(addr) => SocketAddr::new(addr.into(), port), + IpAddr::V6(_) if ipv6 => SocketAddr::new(ip, port), + IpAddr::V6(_) => { + trace!("not using IPv6 nat candidate for IPv4 socket"); + continue; + } + }; + match self.open_path_ensure(remote, PathStatus::Backup, now) { + Ok((path_id, path_was_known)) if !path_was_known => { + path_ids.push(path_id); + probed_addresses.push(remote); + } + Ok((path_id, _)) => { + trace!(%path_id, %remote,"nat traversal: path existed for remote") + } + Err(e) => { + debug!(%remote, %e,"nat traversal: failed to probe remote"); + err.get_or_insert(e); + } + } + } + + let hp_state = self.iroh_hp.as_mut().expect("previously validated"); + hp_state + .set_round_path_ids(path_ids) + .expect("connection side validated"); + + Ok(probed_addresses) + } } impl fmt::Debug for Connection { @@ -5973,6 +6372,8 @@ pub enum Event { DatagramsUnblocked, /// (Multi)Path events Path(PathEvent), + /// Iroh's nat traversal events + NatTraversal(iroh_hp::Event), } impl From for Event { diff --git a/quinn-proto/src/connection/paths.rs b/quinn-proto/src/connection/paths.rs index dc103f707..207919887 100644 --- a/quinn-proto/src/connection/paths.rs +++ b/quinn-proto/src/connection/paths.rs @@ -1,5 +1,6 @@ use std::{cmp, net::SocketAddr}; +use identity_hash::IntMap; use thiserror::Error; use tracing::{debug, trace}; @@ -49,7 +50,7 @@ impl PathId { pub const ZERO: Self = Self(0); /// The number of bytes this [`PathId`] uses when encoded as a [`VarInt`] - pub(crate) fn size(&self) -> usize { + pub(crate) const fn size(&self) -> usize { VarInt(self.0 as u64).size() } @@ -128,8 +129,10 @@ pub(super) struct PathData { pub(super) congestion: Box, /// Pacing state pub(super) pacing: Pacer, - pub(super) challenge: Option, - pub(super) challenge_pending: bool, + /// Actually sent challenges (on the wire) + pub(super) challenges_sent: IntMap, + /// Whether 
to *immediately* trigger another PATH_CHALLENGE (via Connection::can_send) + pub(super) send_new_challenge: bool, /// Pending responses to PATH_CHALLENGE frames pub(super) path_responses: PathResponses, /// Whether we're certain the peer can both send and receive on this address @@ -224,8 +227,8 @@ impl PathData { now, ), congestion, - challenge: None, - challenge_pending: false, + challenges_sent: Default::default(), + send_new_challenge: false, path_responses: PathResponses::default(), validated: false, total_sent: 0, @@ -278,8 +281,8 @@ impl PathData { pacing: Pacer::new(smoothed_rtt, congestion.window(), prev.current_mtu(), now), sending_ecn: true, congestion, - challenge: None, - challenge_pending: false, + challenges_sent: Default::default(), + send_new_challenge: false, path_responses: PathResponses::default(), validated: false, total_sent: 0, @@ -301,6 +304,11 @@ impl PathData { } } + /// Whether we're in the process of validating this path with PATH_CHALLENGEs + pub(super) fn is_validating_path(&self) -> bool { + !self.challenges_sent.is_empty() || self.send_new_challenge + } + /// Resets RTT, congestion control and MTU states. /// /// This is useful when it is known the underlying path has changed. 
@@ -348,11 +356,21 @@ impl PathData { /// Increment the total size of sent UDP datagrams pub(super) fn inc_total_sent(&mut self, inc: u64) { self.total_sent = self.total_sent.saturating_add(inc); + trace!( + remote = %self.remote, + anti_amplification_budget = %(self.total_recvd * 3).saturating_sub(self.total_sent), + "anti amplification budget decreased" + ); } /// Increment the total size of received UDP datagrams pub(super) fn inc_total_recvd(&mut self, inc: u64) { self.total_recvd = self.total_recvd.saturating_add(inc); + trace!( + remote = %self.remote, + anti_amplification_budget = %(self.total_recvd * 3).saturating_sub(self.total_sent), + "anti amplification budget increased" + ); } #[cfg(feature = "qlog")] diff --git a/quinn-proto/src/connection/spaces.rs b/quinn-proto/src/connection/spaces.rs index d099bdc82..b86dc8d9f 100644 --- a/quinn-proto/src/connection/spaces.rs +++ b/quinn-proto/src/connection/spaces.rs @@ -2,6 +2,7 @@ use std::{ cmp, collections::{BTreeMap, BTreeSet, VecDeque}, mem, + net::IpAddr, ops::{Bound, Index, IndexMut}, }; @@ -12,7 +13,11 @@ use tracing::{error, trace}; use super::{PathId, assembler::Assembler}; use crate::{ Dir, Duration, Instant, SocketAddr, StreamId, TransportError, TransportErrorCode, VarInt, - connection::StreamsState, crypto::Keys, frame, packet::SpaceId, range_set::ArrayRangeSet, + connection::StreamsState, + crypto::Keys, + frame::{self, AddAddress, RemoveAddress}, + packet::SpaceId, + range_set::ArrayRangeSet, shared::IssuedCid, }; @@ -551,6 +556,20 @@ pub struct Retransmits { pub(super) path_status: BTreeSet, /// If a PATH_CIDS_BLOCKED frame needs to be sent for a path pub(super) path_cids_blocked: Vec, + + // Nat traversal data + /// Addresses to report in `ADD_ADDRESS` frames + pub(super) add_address: BTreeSet, + /// Address IDs to remove in `REMOVE_ADDRESS` frames + pub(super) remove_address: BTreeSet, + /// Round and local addresses to advertise in `REACH_OUT` frames + pub(super) reach_out: Option<(VarInt, 
Vec<(IpAddr, u16)>)>, + /// Round of the nat traversal rand data that are pending + /// + /// This is only used for bitwise operations on the pending data. + pub(super) hole_punch_round: VarInt, + /// Remote addresses to which random data needs to be sent + pub(super) hole_punch_to: Vec<(IpAddr, u16)>, } impl Retransmits { @@ -574,6 +593,10 @@ impl Retransmits { && self.path_status.is_empty() && !self.max_path_id && !self.paths_blocked + && self.add_address.is_empty() + && self.remove_address.is_empty() + && self.reach_out.is_none() + && self.hole_punch_to.is_empty() } } @@ -600,6 +623,26 @@ impl ::std::ops::BitOrAssign for Retransmits { self.path_abandon.append(&mut rhs.path_abandon); self.max_path_id |= rhs.max_path_id; self.paths_blocked |= rhs.paths_blocked; + self.add_address.extend(rhs.add_address.iter().copied()); + self.remove_address + .extend(rhs.remove_address.iter().copied()); + // if there are two rounds, prefer the most recent reach out set + let lhs_round = self.reach_out.as_ref().map(|(round, _)| *round); + let rhs_round = rhs.reach_out.as_ref().map(|(round, _)| *round); + match (lhs_round, rhs_round) { + (None, Some(_)) => self.reach_out = rhs.reach_out.clone(), + (Some(lhs_round), Some(rhs_round)) if rhs_round > lhs_round => { + self.reach_out = rhs.reach_out.clone() + } + _ => {} + } + + if self.hole_punch_round < rhs.hole_punch_round { + self.hole_punch_round = rhs.hole_punch_round; + self.hole_punch_to = rhs.hole_punch_to.clone(); + } else if self.hole_punch_round == rhs.hole_punch_round { + self.hole_punch_to.extend_from_slice(&rhs.hole_punch_to); + } } } diff --git a/quinn-proto/src/connection/stats.rs b/quinn-proto/src/connection/stats.rs index 7394a28c7..43792f820 100644 --- a/quinn-proto/src/connection/stats.rs +++ b/quinn-proto/src/connection/stats.rs @@ -67,7 +67,7 @@ pub struct FrameStats { pub paths_blocked: u64, pub path_cids_blocked: u64, pub add_address: u64, - pub punch_me_now: u64, + pub reach_out: u64, pub remove_address: u64, } 
@@ -127,7 +127,7 @@ impl FrameStats { self.path_cids_blocked = self.path_cids_blocked.saturating_add(1) } Frame::AddAddress(_) => self.add_address = self.add_address.saturating_add(1), - Frame::PunchMeNow(_) => self.punch_me_now = self.punch_me_now.saturating_add(1), + Frame::ReachOut(_) => self.reach_out = self.reach_out.saturating_add(1), Frame::RemoveAddress(_) => self.remove_address = self.remove_address.saturating_add(1), } } diff --git a/quinn-proto/src/connection/timer.rs b/quinn-proto/src/connection/timer.rs index 8e1ff7f36..f54f79dba 100644 --- a/quinn-proto/src/connection/timer.rs +++ b/quinn-proto/src/connection/timer.rs @@ -44,25 +44,28 @@ pub(crate) enum PathTimer { PathIdle = 1, /// When to give up on validating a new path from RFC9000 migration PathValidation = 2, + /// When to resend a path challenge deemed lost + PathChallengeLost = 3, /// When to give up on validating a new (multi)path - PathOpen = 3, + PathOpen = 4, /// When to send a `PING` frame to keep the path alive - PathKeepAlive = 4, + PathKeepAlive = 5, /// When pacing will allow us to send a packet - Pacing = 5, + Pacing = 6, /// When to send an immediate ACK if there are unacked ack-eliciting packets of the peer - MaxAckDelay = 6, + MaxAckDelay = 7, /// When to clean up state for an abandoned path - PathAbandoned = 7, + PathAbandoned = 8, /// When the peer fails to confirm abandoning the path - PathNotAbandoned = 8, + PathNotAbandoned = 9, } impl PathTimer { - const VALUES: [Self; 9] = [ + const VALUES: [Self; 10] = [ Self::LossDetection, Self::PathIdle, Self::PathValidation, + Self::PathChallengeLost, Self::PathOpen, Self::PathKeepAlive, Self::Pacing, @@ -278,6 +281,15 @@ impl TimerTable { } } + /// Stops all per-path timers + pub(super) fn stop_per_path(&mut self, path_id: PathId) { + for timer in PathTimer::VALUES { + if let Some(e) = self.path_timers.get_mut(&path_id) { + e.stop(timer); + } + } + } + /// Get the next queued timeout pub(super) fn peek(&mut self) -> Option { // TODO: 
this is currently linear in the number of paths diff --git a/quinn-proto/src/frame.rs b/quinn-proto/src/frame.rs index 54bfc90a1..54f04cd48 100644 --- a/quinn-proto/src/frame.rs +++ b/quinn-proto/src/frame.rs @@ -150,12 +150,12 @@ frame_types! { MAX_PATH_ID = 0x15228c0c, PATHS_BLOCKED = 0x15228c0d, PATH_CIDS_BLOCKED = 0x15228c0e, - // NAT TRAVERSAL - ADD_IPV4_ADDRESS = 0x3d7e90, - ADD_IPV6_ADDRESS = 0x3d7e91, - PUNCH_IPV4_ADDR = 0x3d7e92, - PUNCH_IPV6_ADDR = 0x3d7e93, - REMOVE_ADDRESS = 0x3d7e94, + // IROH'S NAT TRAVERSAL + ADD_IPV4_ADDRESS = 0x3d7f90, + ADD_IPV6_ADDRESS = 0x3d7f91, + REACH_OUT_AT_IPV4 = 0x3d7f92, + REACH_OUT_AT_IPV6 = 0x3d7f93, + REMOVE_ADDRESS = 0x3d7f94, } const STREAM_TYS: RangeInclusive = RangeInclusive::new(0x08, 0x0f); @@ -195,7 +195,7 @@ pub(crate) enum Frame { PathsBlocked(PathsBlocked), PathCidsBlocked(PathCidsBlocked), AddAddress(AddAddress), - PunchMeNow(PunchMeNow), + ReachOut(ReachOut), RemoveAddress(RemoveAddress), } @@ -247,7 +247,7 @@ impl Frame { PathsBlocked(_) => FrameType::PATHS_BLOCKED, PathCidsBlocked(_) => FrameType::PATH_CIDS_BLOCKED, AddAddress(ref frame) => frame.get_type(), - PunchMeNow(ref frame) => frame.get_type(), + ReachOut(ref frame) => frame.get_type(), RemoveAddress(_) => self::RemoveAddress::TYPE, } } @@ -985,10 +985,10 @@ impl Iter { let add_address = AddAddress::read(&mut self.bytes, is_ipv6)?; Frame::AddAddress(add_address) } - FrameType::PUNCH_IPV4_ADDR | FrameType::PUNCH_IPV6_ADDR => { - let is_ipv6 = ty == FrameType::PUNCH_IPV6_ADDR; - let punch_me = PunchMeNow::read(&mut self.bytes, is_ipv6)?; - Frame::PunchMeNow(punch_me) + FrameType::REACH_OUT_AT_IPV4 | FrameType::REACH_OUT_AT_IPV6 => { + let is_ipv6 = ty == FrameType::REACH_OUT_AT_IPV6; + let reach_out = ReachOut::read(&mut self.bytes, is_ipv6)?; + Frame::ReachOut(reach_out) } FrameType::REMOVE_ADDRESS => { Frame::RemoveAddress(RemoveAddress::read(&mut self.bytes)?) 
@@ -1478,7 +1478,7 @@ impl PathBackup { /// Conjuction of the information contained in the add address frames /// ([`FrameType::ADD_IPV4_ADDRESS`], [`FrameType::ADD_IPV6_ADDRESS`]). -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, PartialEq, Eq, Copy, Clone, PartialOrd, Ord)] // TODO(@divma): remove #[allow(dead_code)] pub(crate) struct AddAddress { @@ -1502,12 +1502,8 @@ impl AddAddress { } .size(); - pub(crate) const fn new(remote: std::net::SocketAddr, seq_no: VarInt) -> Self { - Self { - ip: remote.ip(), - port: remote.port(), - seq_no, - } + pub(crate) const fn new((ip, port): (IpAddr, u16), seq_no: VarInt) -> Self { + Self { ip, port, seq_no } } /// Get the [`FrameType`] for this frame. @@ -1560,21 +1556,22 @@ impl AddAddress { /// Give the [`SocketAddr`] encoded in the frame pub(crate) fn socket_addr(&self) -> SocketAddr { - (self.ip, self.port).into() + self.ip_port().into() + } + + pub(crate) fn ip_port(&self) -> (IpAddr, u16) { + (self.ip, self.port) } } -/// Conjuction of the information contained in the punch me now frames -/// ([`FrameType::PUNCH_IPV4_ADDR`], [`FrameType::PUNCH_IPV6_ADDR`]) +/// Conjuction of the information contained in the reach out frames +/// ([`FrameType::REACH_OUT_AT_IPV4`], [`FrameType::REACH_OUT_AT_IPV6`]) #[derive(Debug, PartialEq, Eq, Clone)] -// TODO(@divma): remove. 
Beg the draft people for a better name +// TODO(@divma): remove #[allow(dead_code)] -pub(crate) struct PunchMeNow { +pub(crate) struct ReachOut { /// The sequence number of the NAT Traversal attempts - // TODO(@divma): type assumed, spec is un-spec-ific pub(crate) round: VarInt, - /// The sequence number of the address that was paired with this address - pub(crate) paired_with: VarInt, /// Address to use pub(crate) ip: IpAddr, /// Port to use with this address @@ -1583,35 +1580,25 @@ pub(crate) struct PunchMeNow { // TODO(@divma): remove #[allow(dead_code)] -impl PunchMeNow { +impl ReachOut { /// Smallest number of bytes this type of frame is guaranteed to fit within pub(crate) const SIZE_BOUND: usize = Self { round: VarInt::MAX, - paired_with: VarInt::MAX, ip: IpAddr::V6(std::net::Ipv6Addr::LOCALHOST), port: u16::MAX, } .size(); - pub(crate) const fn new( - round: VarInt, - paired_with: VarInt, - local_addr: std::net::SocketAddr, - ) -> Self { - Self { - round, - paired_with, - ip: local_addr.ip(), - port: local_addr.port(), - } + pub(crate) const fn new(round: VarInt, (ip, port): (IpAddr, u16)) -> Self { + Self { round, ip, port } } /// Get the [`FrameType`] for this frame pub(crate) const fn get_type(&self) -> FrameType { if self.ip.is_ipv6() { - FrameType::PUNCH_IPV6_ADDR + FrameType::REACH_OUT_AT_IPV6 } else { - FrameType::PUNCH_IPV4_ADDR + FrameType::REACH_OUT_AT_IPV4 } } @@ -1619,17 +1606,15 @@ impl PunchMeNow { pub(crate) const fn size(&self) -> usize { let type_size = VarInt(self.get_type().0).size(); let round_bytes = self.round.size(); - let paired_with_bytes = self.paired_with.size(); let ip_bytes = if self.ip.is_ipv6() { 16 } else { 4 }; let port_bytes = 2; - type_size + round_bytes + paired_with_bytes + ip_bytes + port_bytes + type_size + round_bytes + ip_bytes + port_bytes } /// Unconditionally write this frame to `buf` pub(crate) fn write(&self, buf: &mut W) { buf.write(self.get_type()); buf.write(self.round); - buf.write(self.paired_with); match 
self.ip { IpAddr::V4(ipv4_addr) => { buf.write(ipv4_addr); @@ -1644,22 +1629,16 @@ impl PunchMeNow { /// Read the frame contents from the buffer /// /// Should only be called when the frame type has been identified as - /// [`FrameType::PUNCH_IPV4_ADDR`] or [`FrameType::PUNCH_IPV6_ADDR`]. + /// [`FrameType::REACH_OUT_AT_IPV4`] or [`FrameType::REACH_OUT_AT_IPV6`]. pub(crate) fn read(bytes: &mut R, is_ipv6: bool) -> coding::Result { let round = bytes.get()?; - let paired_with = bytes.get()?; let ip = if is_ipv6 { IpAddr::V6(bytes.get()?) } else { IpAddr::V4(bytes.get()?) }; let port = bytes.get()?; - Ok(Self { - round, - paired_with, - ip, - port, - }) + Ok(Self { round, ip, port }) } /// Give the [`SocketAddr`] encoded in the frame @@ -1669,7 +1648,7 @@ impl PunchMeNow { } /// Frame signaling an address is no longer being advertised -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(Debug, PartialEq, Eq, Copy, Clone, PartialOrd, Ord)] // TODO(@divma): remove #[allow(dead_code)] pub(crate) struct RemoveAddress { @@ -1976,18 +1955,17 @@ mod test { /// Test that encoding and decoding [`AddAddress`] produces the same result #[test] - fn test_punch_me_now_roundrip() { - let punch_me = PunchMeNow { + fn test_reach_out_roundrip() { + let reach_out = ReachOut { round: VarInt(42), - paired_with: VarInt(24), ip: std::net::Ipv6Addr::LOCALHOST.into(), port: 4242, }; - let mut buf = Vec::with_capacity(punch_me.size()); - punch_me.write(&mut buf); + let mut buf = Vec::with_capacity(reach_out.size()); + reach_out.write(&mut buf); assert_eq!( - punch_me.size(), + reach_out.size(), buf.len(), "expected written bytes and actual size differ" ); @@ -1995,7 +1973,7 @@ mod test { let mut decoded = frames(buf); assert_eq!(decoded.len(), 1); match decoded.pop().expect("non empty") { - Frame::PunchMeNow(decoded) => assert_eq!(decoded, punch_me), + Frame::ReachOut(decoded) => assert_eq!(decoded, reach_out), x => panic!("incorrect frame {x:?}"), } } diff --git a/quinn-proto/src/iroh_hp.rs 
b/quinn-proto/src/iroh_hp.rs new file mode 100644 index 000000000..d9b2f37cf --- /dev/null +++ b/quinn-proto/src/iroh_hp.rs @@ -0,0 +1,319 @@ +//! iroh NAT Traversal + +use std::{ + collections::hash_map::Entry, + net::{IpAddr, SocketAddr}, +}; + +use rustc_hash::{FxHashMap, FxHashSet}; + +use crate::{ + PathId, Side, VarInt, + frame::{AddAddress, ReachOut, RemoveAddress}, +}; + +/// Maximum number of addresses to handle, applied both to local and remote addresses, regardless +/// of configuration parameters +const MAX_ADDRESSES: u8 = 20; + +/// Errors that the nat traversal state might encounter. +#[derive(Debug, thiserror::Error)] +pub enum Error { + /// An endpoint (local or remote) tried to add too many addresses to their advertised set + #[error("Tried to add too many addresses to their advertised set")] + TooManyAddresses, + /// The operation is not allowed for this endpoint's connection side + #[error("Not allowed for this endpoint's connection side")] + WrongConnectionSide, + /// The extension was not negotiated + #[error("Iroh's nat traversal was not negotiated")] + ExtensionNotNegotiated, + /// Not enough addresses to complete the operation + #[error("Not enough addresses")] + NotEnoughAddresses, + /// Nat traversal attempt failed due to a multipath error + #[error("Failed to establish paths {0}")] + Multipath(super::PathError), +} + +pub(crate) struct NatTraversalRound { + /// Sequence number to use for the new reach out frames + pub(crate) new_round: VarInt, + /// Addresses to use to send reach out frames + pub(crate) reach_out_at: Vec<(IpAddr, u16)>, + /// Remotes to probe by attempting to open new paths + pub(crate) addresses_to_probe: Vec<(IpAddr, u16)>, + /// [`PathId`]s of the cancelled round + pub(crate) prev_round_path_ids: Vec, +} + +pub(crate) struct RandDataNeeded { + /// Destination address of the hole punching random data + pub(crate) ip: IpAddr, + /// Destination port of the hole punching random data + pub(crate) port: u16, + /// Round to 
which this hole punching random data belongs to + pub(crate) round: VarInt, + /// Whether this starts a new round + pub(crate) is_new_round: bool, +} + +/// Event emitted when the client receives ADD_ADDRESS or REMOVE_ADDRESS frames. +#[derive(Debug, Clone)] +pub enum Event { + /// An ADD_ADDRESS frame was received. + AddressAdded(SocketAddr), + /// A REMOVE_ADDRESS frame was received. + AddressRemoved(SocketAddr), +} + +/// State kept for Iroh's nat traversal +#[derive(Debug)] +pub(crate) struct State { + /// Max number of remote addresses we allow + /// + /// This is set by the local endpoint. + max_remote_addresses: usize, + /// Max number of local addresses allowed + /// + /// This is set by the remote endpoint. + max_local_addresses: usize, + /// Candidate addresses the remote server reports as potentially reachable, to use for nat + /// traversal attempts. Always canonical. + remote_addresses: FxHashMap, + /// Candidate addresses the local client reports as potentially reachable, to use for nat + /// traversal attempts. Always canonical. + local_addresses: FxHashMap<(IpAddr, u16), VarInt>, + /// The next id to use for local addresses sent to the client + next_local_addr_id: VarInt, + /// Local connection side + side: Side, + /// Current nat holepunching round + /// + /// Clients initiate hole punching rounds and are thus responsible for incrementing the count. + /// Servers keep track of the client's most recent round and cancel probing related to previous + /// rounds. 
+ round: VarInt, + /// [`PathId`]s used to probe remotes assigned to this round + round_path_ids: Vec, + /// Addresses to which servers send random data to attempt to hole punch to clients + server_sent_rand_data: FxHashSet<(IpAddr, u16)>, +} + +/// Nat traversal api exclusive to clients +pub(crate) struct ClientSide<'a> { + state: &'a mut State, +} + +impl State { + /// Adds a local address to use for nat traversal + /// + /// When this endpoint is the server within the connection, these addresses will be sent to the + /// client in add address frames. For clients, these addresses will be sent in reach out frames + /// when nat traversal attempts are initiated. + /// + /// If a frame should be sent, it is returned. + pub(crate) fn add_local_address( + &mut self, + address: SocketAddr, + ) -> Result, Error> { + let address = (address.ip().to_canonical(), address.port()); + let allow_new = self.local_addresses.len() < self.max_local_addresses; + let is_server = self.side.is_server(); + match self.local_addresses.entry(address) { + Entry::Occupied(_) => Ok(None), + Entry::Vacant(vacant_entry) if allow_new => { + let id = self.next_local_addr_id; + self.next_local_addr_id = self.next_local_addr_id.saturating_add(1u8); + vacant_entry.insert(id); + if is_server { + Ok(Some(AddAddress::new(address, id))) + } else { + Ok(None) + } + } + _ => Err(Error::TooManyAddresses), + } + } + + /// Removes a local address from the advertised set for nat traversal + /// + /// When this endpoint is the server, removed addresses must be reported with remove address + /// frames. Clients will simply stop reporting these addresses in reach out frames. + /// + /// If a frame should be sent, it is returned. 
+ pub(crate) fn remove_local_address(&mut self, address: SocketAddr) -> Option { + let id = self + .local_addresses + .remove(&(address.ip(), address.port()))?; + if self.side.is_server() { + Some(RemoveAddress::new(id)) + } else { + None + } + } + + pub(crate) fn client_side(&mut self) -> Result, Error> { + if self.side.is_client() { + Ok(ClientSide { state: self }) + } else { + Err(Error::WrongConnectionSide) + } + } + + pub(crate) fn new(max_remote_addresses: u8, max_local_addresses: u8, side: Side) -> Self { + Self { + remote_addresses: Default::default(), + local_addresses: Default::default(), + next_local_addr_id: Default::default(), + side, + round: Default::default(), + round_path_ids: Default::default(), + server_sent_rand_data: Default::default(), + max_remote_addresses: max_remote_addresses.min(MAX_ADDRESSES).into(), + max_local_addresses: max_local_addresses.min(MAX_ADDRESSES).into(), + } + } + + pub(crate) fn get_local_nat_traversal_addresses(&self) -> Vec { + self.local_addresses + .keys() + .copied() + .map(Into::into) + .collect() + } + + pub(crate) fn get_remote_nat_traversal_addresses(&self) -> Result, Error> { + if !self.side.is_client() { + return Err(Error::WrongConnectionSide); + } + + Ok(self + .remote_addresses + .values() + .copied() + .map(Into::into) + .collect()) + } + + /// Initiates a new nat traversal round + /// + /// A nat traversal round involves advertising the client's local addresses in `REACH_OUT` + /// frames, and initiating probing of the known remote addresses. When a new round is + /// initiated, the previous one is cancelled, and paths that have not been opened should be + /// closed. 
+ pub(crate) fn initiate_nat_traversal_round(&mut self) -> Result { + if self.side.is_server() { + return Err(Error::WrongConnectionSide); + } + + if self.local_addresses.is_empty() { + return Err(Error::NotEnoughAddresses); + } + + let prev_round_path_ids = std::mem::take(&mut self.round_path_ids); + self.round = self.round.saturating_add(1u8); + + Ok(NatTraversalRound { + new_round: self.round, + reach_out_at: self.local_addresses.keys().copied().collect(), + addresses_to_probe: self.remote_addresses.values().copied().collect(), + prev_round_path_ids, + }) + } + + /// Add a [`PathId`] as part of the current attempts to create paths based on the server's + /// advertised addresses. + pub(crate) fn set_round_path_ids(&mut self, path_ids: Vec) -> Result<(), Error> { + if self.side.is_server() { + return Err(Error::WrongConnectionSide); + } + self.round_path_ids = path_ids; + Ok(()) + } + + /// Handles a received [`ReachOut`] + /// + /// It returns the token that should be sent in response to this frame as a challenge, and + /// whether this starts a new nat traversal round. + /// + /// If this frame was ignored, it returns `None`. + pub(crate) fn handle_reach_out( + &mut self, + reach_out: ReachOut, + ) -> Result, Error> { + let ReachOut { round, ip, port } = reach_out; + if self.side.is_client() { + return Err(Error::WrongConnectionSide); + } + + if round >= self.round { + let is_new_round = round > self.round; + if is_new_round { + self.server_sent_rand_data.clear(); + } + if self.server_sent_rand_data.len() >= self.max_remote_addresses { + return Err(Error::TooManyAddresses); + } + self.server_sent_rand_data.insert((ip, port)); + let info = RandDataNeeded { + ip, + port, + round, + is_new_round, + }; + return Ok(Some(info)); + } + + Ok(None) + } +} + +impl<'a> ClientSide<'a> { + /// Adds an address to the remote set + /// + /// On success returns the address if it was new to the set. It will error when the set has no + /// capacity for the address. 
+ pub(crate) fn add_remote_address( + &mut self, + add_addr: AddAddress, + ) -> Result, Error> { + let AddAddress { seq_no, ip, port } = add_addr; + let address = (ip.to_canonical(), port); + let allow_new = self.state.remote_addresses.len() < self.state.max_remote_addresses; + match self.state.remote_addresses.entry(seq_no) { + Entry::Occupied(mut occupied_entry) => { + let old_value = occupied_entry.insert(address); + // The value might be different. This should not happen, but we assume that the new + // address is more recent than the previous, and thus worth updating + Ok((address != old_value).then_some(address.into())) + } + Entry::Vacant(vacant_entry) if allow_new => { + vacant_entry.insert(address); + Ok(Some(address.into())) + } + _ => Err(Error::TooManyAddresses), + } + } + + /// Removes an address from the remote set + /// + /// Returns whether the address was present. + pub(crate) fn remove_remote_address( + &mut self, + remove_addr: RemoveAddress, + ) -> Option { + self.state + .remote_addresses + .remove(&remove_addr.seq_no) + .map(Into::into) + } + + /// Checks that a received remote address is valid + /// + /// An address is valid as long as it does not change the value of a known address id. 
+ pub(crate) fn check_remote_address(&self, add_addr: &AddAddress) -> bool { + let existing = self.state.remote_addresses.get(&add_addr.seq_no); + existing.is_none() || existing == Some(&add_addr.ip_port()) + } +} diff --git a/quinn-proto/src/lib.rs b/quinn-proto/src/lib.rs index 87d8366bb..9f3c3711e 100644 --- a/quinn-proto/src/lib.rs +++ b/quinn-proto/src/lib.rs @@ -103,6 +103,8 @@ mod address_discovery; mod token_memory_cache; pub use token_memory_cache::TokenMemoryCache; +pub mod iroh_hp; + #[cfg(feature = "arbitrary")] use arbitrary::Arbitrary; @@ -327,7 +329,7 @@ pub struct Transmit { // /// The maximum number of CIDs we bother to issue per path -const LOC_CID_COUNT: u64 = 8; +const LOC_CID_COUNT: u64 = 12; const RESET_TOKEN_SIZE: usize = 16; const MAX_CID_SIZE: usize = 20; const MIN_INITIAL_SIZE: u16 = 1200; diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index 4b3474260..c9fe23c9a 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -1339,6 +1339,72 @@ fn migration() { ); } +#[test] +fn path_challenge_retransmit() { + let _guard = subscribe(); + let mut pair = Pair::default(); + let (client_ch, server_ch) = pair.connect(); + pair.drive(); + + pair.client_conn_mut(client_ch).ping(); + pair.drive(); + + println!("-------- server wants path validation --------"); + pair.server_conn_mut(server_ch).trigger_path_validation(); + pair.drive_server(); // Send the path challenge + println!("-------- client loses messages --------"); + // Have the client lose the challenge + pair.client.inbound.clear(); + + pair.drive(); + + let client_tx = pair.client_conn_mut(client_ch).stats().frame_tx; + let server_tx = pair.server_conn_mut(server_ch).stats().frame_tx; + + assert_eq!( + server_tx.path_challenge, 2, + "expected server to send two path challenges" + ); + assert_eq!( + client_tx.path_response, 1, + "expected client to send one path response" + ); +} + +#[test] +fn path_response_retransmit() { + let _guard = 
subscribe(); + let mut pair = Pair::default(); + let (client_ch, server_ch) = pair.connect(); + pair.drive(); + + pair.client_conn_mut(client_ch).ping(); + pair.drive(); + + println!("-------- server wants path validation --------"); + pair.server_conn_mut(server_ch).trigger_path_validation(); + pair.drive_server(); // Send the path challenge + pair.drive_client(); // Send the path response + println!("-------- server loses messages --------"); + // Have the server lose the path response + pair.server.inbound.clear(); + + // The server should decide to re-send the path challenge + pair.drive(); + + let client_tx = pair.client_conn_mut(client_ch).stats().frame_tx; + let server_tx = pair.server_conn_mut(server_ch).stats().frame_tx; + + assert_eq!( + server_tx.path_challenge, 2, + "expected server to send two path challenges" + ); + assert_eq!( + client_tx.path_response, 2, + "expected client to send two path responses" + ); +} + fn test_flow_control(config: TransportConfig, window_size: usize) { let _guard = subscribe(); let mut pair = Pair::new( diff --git a/quinn-proto/src/tests/multipath.rs b/quinn-proto/src/tests/multipath.rs index 5c621ab56..b7bb2b1d2 100644 --- a/quinn-proto/src/tests/multipath.rs +++ b/quinn-proto/src/tests/multipath.rs @@ -363,6 +363,78 @@ fn issue_max_path_id() { assert_eq!(stats.frame_rx.path_new_connection_id, client_path_new_cids); } +/// A copy of [`issue_max_path_id`], but reordering the `MAX_PATH_ID` frame +/// that's sent from the server to the client, so that some `NEW_CONNECTION_ID` +/// frames arrive with higher path IDs than the most recently received +/// `MAX_PATH_ID` frame on the client side. +#[test] +fn issue_max_path_id_reordered() { + let _guard = subscribe(); + + // We enable multipath but initially do not allow any paths to be opened. 
+ let multipath_transport_cfg = Arc::new(TransportConfig { + max_concurrent_multipath_paths: NonZeroU32::new(1), + ..TransportConfig::default() + }); + let server_cfg = Arc::new(ServerConfig { + transport: multipath_transport_cfg.clone(), + ..server_config() + }); + let server = Endpoint::new(Default::default(), Some(server_cfg), true, None); + let client = Endpoint::new(Default::default(), None, true, None); + + let mut pair = Pair::new_from_endpoint(client, server); + + // The client is allowed to create more paths immediately. + let client_multipath_transport_cfg = Arc::new(TransportConfig { + max_concurrent_multipath_paths: NonZeroU32::new(MAX_PATHS), + ..TransportConfig::default() + }); + let client_cfg = ClientConfig { + transport: client_multipath_transport_cfg, + ..client_config() + }; + let (_client_ch, server_ch) = pair.connect_with(client_cfg); + pair.drive(); + info!("connected"); + + // Server should only have sent NEW_CONNECTION_ID frames for now. + let server_new_cids = CidQueue::LEN as u64 - 1; + let mut server_path_new_cids = 0; + let stats = pair.server_conn_mut(server_ch).stats(); + assert_eq!(stats.frame_tx.max_path_id, 0); + assert_eq!(stats.frame_tx.new_connection_id, server_new_cids); + assert_eq!(stats.frame_tx.path_new_connection_id, server_path_new_cids); + + // Client should have sent PATH_NEW_CONNECTION_ID frames for PathId::ZERO. 
+ let client_new_cids = 0; + let mut client_path_new_cids = CidQueue::LEN as u64; + assert_eq!(stats.frame_rx.new_connection_id, client_new_cids); + assert_eq!(stats.frame_rx.path_new_connection_id, client_path_new_cids); + + // Server increases MAX_PATH_ID, but we reorder the frame + pair.server_conn_mut(server_ch) + .set_max_concurrent_paths(Instant::now(), NonZeroU32::new(MAX_PATHS).unwrap()) + .unwrap(); + pair.drive_server(); + // reorder the frames on the incoming side + let p = pair.client.inbound.pop_front().unwrap(); + pair.client.inbound.push_back(p); + pair.drive(); + let stats = pair.server_conn_mut(server_ch).stats(); + + // Server should have sent MAX_PATH_ID and new CIDs + server_path_new_cids += (MAX_PATHS as u64 - 1) * CidQueue::LEN as u64; + assert_eq!(stats.frame_tx.max_path_id, 1); + assert_eq!(stats.frame_tx.new_connection_id, server_new_cids); + assert_eq!(stats.frame_tx.path_new_connection_id, server_path_new_cids); + + // Client should have sent CIDs for new paths + client_path_new_cids += (MAX_PATHS as u64 - 1) * CidQueue::LEN as u64; + assert_eq!(stats.frame_rx.new_connection_id, client_new_cids); + assert_eq!(stats.frame_rx.path_new_connection_id, client_path_new_cids); +} + #[test] fn open_path() { let _guard = subscribe(); diff --git a/quinn-proto/src/transport_parameters.rs b/quinn-proto/src/transport_parameters.rs index fdedf20e2..bae0b9cc0 100644 --- a/quinn-proto/src/transport_parameters.rs +++ b/quinn-proto/src/transport_parameters.rs @@ -9,6 +9,7 @@ use std::{ convert::TryFrom, net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}, + num::NonZeroU8, }; use bytes::{Buf, BufMut}; @@ -116,11 +117,11 @@ macro_rules! make_struct { /// The role of this peer in address discovery, if any. 
pub(crate) address_discovery_role: address_discovery::Role, - // Multipath extension + /// Multipath extension pub(crate) initial_max_path_id: Option, /// Nat traversal draft - pub nat_traversal: Option, + pub max_remote_nat_traversal_addresses: Option, } // We deliberately don't implement the `Default` trait, since that would be public, and @@ -146,7 +147,7 @@ macro_rules! make_struct { write_order: None, address_discovery_role: address_discovery::Role::Disabled, initial_max_path_id: None, - nat_traversal: None, + max_remote_nat_traversal_addresses: None, } } } @@ -196,7 +197,7 @@ impl TransportParameters { }), address_discovery_role: config.address_discovery_role, initial_max_path_id: config.get_initial_max_path_id(), - nat_traversal: config.get_nat_traversal_concurrency_limit(), + max_remote_nat_traversal_addresses: config.max_remote_nat_traversal_addresses, ..Self::default() } } @@ -214,7 +215,7 @@ impl TransportParameters { || cached.max_datagram_frame_size > self.max_datagram_frame_size || cached.grease_quic_bit && !self.grease_quic_bit || cached.address_discovery_role != self.address_discovery_role - || cached.nat_traversal != self.nat_traversal + || cached.max_remote_nat_traversal_addresses != self.max_remote_nat_traversal_addresses { return Err(TransportError::PROTOCOL_VIOLATION( "0-RTT accepted with incompatible transport parameters", @@ -414,11 +415,11 @@ impl TransportParameters { w.write(val); } } - TransportParameterId::NatTraversal => { - if let Some(val) = self.nat_traversal { + TransportParameterId::IrohNatTraversal => { + if let Some(val) = self.max_remote_nat_traversal_addresses { w.write_var(id as u64); - w.write_var(val.size() as u64); - w.write(val); + w.write(VarInt(1)); + w.write(val.get()); } } id => { @@ -546,21 +547,18 @@ impl TransportParameters { params.initial_max_path_id = Some(value); } - TransportParameterId::NatTraversal => { - if params.nat_traversal.is_some() { + TransportParameterId::IrohNatTraversal => { + if 
params.max_remote_nat_traversal_addresses.is_some() { return Err(Error::Malformed); } - - let value: VarInt = r.get()?; - if len != value.size() { + if len != 1 { return Err(Error::Malformed); } - if value.into_inner() == 0 { - return Err(Error::IllegalValue); - } + let value: u8 = r.get()?; + let value = NonZeroU8::new(value).ok_or(Error::IllegalValue)?; - params.nat_traversal = Some(value); + params.max_remote_nat_traversal_addresses = Some(value); } _ => { macro_rules! parse { @@ -731,8 +729,9 @@ pub(crate) enum TransportParameterId { // https://datatracker.ietf.org/doc/html/draft-ietf-quic-multipath InitialMaxPathId = 0x0f739bbc1b666d0c, - // https://www.ietf.org/archive/id/draft-seemann-quic-nat-traversal-02.html - NatTraversal = 0x3d7e9f0bca12fea6, + // inspired by https://www.ietf.org/archive/id/draft-seemann-quic-nat-traversal-02.html, + // simplified to iroh's needs + IrohNatTraversal = 0x3d7f91120401, } impl TransportParameterId { @@ -761,7 +760,7 @@ impl TransportParameterId { Self::MinAckDelayDraft07, Self::ObservedAddr, Self::InitialMaxPathId, - Self::NatTraversal, + Self::IrohNatTraversal, ]; } @@ -803,6 +802,7 @@ impl TryFrom for TransportParameterId { id if Self::MinAckDelayDraft07 == id => Self::MinAckDelayDraft07, id if Self::ObservedAddr == id => Self::ObservedAddr, id if Self::InitialMaxPathId == id => Self::InitialMaxPathId, + id if Self::IrohNatTraversal == id => Self::IrohNatTraversal, _ => return Err(()), }; Ok(param) @@ -843,6 +843,7 @@ mod test { min_ack_delay: Some(2_000u32.into()), address_discovery_role: address_discovery::Role::SendOnly, initial_max_path_id: Some(PathId::MAX), + max_remote_nat_traversal_addresses: Some(5u8.try_into().unwrap()), ..TransportParameters::default() }; params.write(&mut buf); diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index 464e07760..a89665736 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -28,7 +28,7 @@ use crate::{ }; use proto::{ ConnectionError, 
     ConnectionHandle, ConnectionStats, Dir, EndpointEvent, PathError, PathEvent,
-    PathId, PathStats, PathStatus, Side, StreamEvent, StreamId, congestion::Controller,
+    PathId, PathStats, PathStatus, Side, StreamEvent, StreamId, congestion::Controller, iroh_hp,
 };
 
 /// In-progress connection attempt future
@@ -492,6 +501,15 @@ impl Connection {
         self.0.state.lock("path_events").path_events.subscribe()
     }
 
+    /// A broadcast receiver of [`iroh_hp::Event`]s for updates about server addresses
+    pub fn nat_traversal_updates(&self) -> tokio::sync::broadcast::Receiver<iroh_hp::Event> {
+        self.0
+            .state
+            .lock("nat_traversal_updates")
+            .nat_traversal_updates
+            .subscribe()
+    }
+
     /// Wait for the connection to be closed for any reason
     ///
     /// Despite the return type's name, closed connections are often not an error condition at the
@@ -841,6 +850,56 @@
         let conn = self.0.state.lock("is_multipath_enabled");
         conn.inner.is_multipath_negotiated()
     }
+
+    /// Registers one address at which this endpoint might be reachable
+    ///
+    /// When the NAT traversal extension is negotiated, servers send these addresses to clients in
+    /// `ADD_ADDRESS` frames. This allows clients to obtain server address candidates to initiate
+    /// NAT traversal attempts. Clients provide their own reachable addresses in `REACH_OUT` frames
+    /// when [`Self::initiate_nat_traversal_round`] is called.
+    pub fn add_nat_traversal_address(&self, address: SocketAddr) -> Result<(), iroh_hp::Error> {
+        let mut conn = self.0.state.lock("add_nat_traversal_address");
+        conn.inner.add_nat_traversal_address(address)
+    }
+
+    /// Removes an address from the set of addresses at which this endpoint is reachable
+    ///
+    /// When the NAT traversal extension is negotiated, servers send address removals to
+    /// clients in `REMOVE_ADDRESS` frames. This allows clients to stop using outdated
+    /// server address candidates that are no longer valid for NAT traversal.
+    ///
+    /// For clients, removed addresses will no longer be advertised in `REACH_OUT` frames.
+    ///
+    /// Addresses not present in the set will be silently ignored.
+    pub fn remove_nat_traversal_address(&self, address: SocketAddr) -> Result<(), iroh_hp::Error> {
+        let mut conn = self.0.state.lock("remove_nat_traversal_address");
+        conn.inner.remove_nat_traversal_address(address)
+    }
+
+    /// Get the current local nat traversal addresses
+    pub fn get_local_nat_traversal_addresses(&self) -> Result<Vec<SocketAddr>, iroh_hp::Error> {
+        let conn = self.0.state.lock("get_local_nat_traversal_addresses");
+        conn.inner.get_local_nat_traversal_addresses()
+    }
+
+    /// Get the currently advertised nat traversal addresses by the server
+    pub fn get_remote_nat_traversal_addresses(&self) -> Result<Vec<SocketAddr>, iroh_hp::Error> {
+        let conn = self.0.state.lock("get_remote_nat_traversal_addresses");
+        conn.inner.get_remote_nat_traversal_addresses()
+    }
+
+    /// Initiates a new nat traversal round
+    ///
+    /// A nat traversal round involves advertising the client's local addresses in `REACH_OUT`
+    /// frames, and initiating probing of the known remote addresses. When a new round is
+    /// initiated, the previous one is cancelled, and paths that have not been opened are closed.
+    ///
+    /// Returns the server addresses that are now being probed.
+    pub fn initiate_nat_traversal_round(&self) -> Result<Vec<SocketAddr>, iroh_hp::Error> {
+        let mut conn = self.0.state.lock("initiate_nat_traversal_round");
+        let now = conn.runtime.now();
+        conn.inner.initiate_nat_traversal_round(now)
+    }
 }
 
 pin_project! {
@@ -1137,6 +1196,7 @@
                 send_buffer: Vec::new(),
                 buffered_transmit: None,
                 observed_external_addr: watch::Sender::new(None),
+                nat_traversal_updates: tokio::sync::broadcast::channel(32).0,
                 on_closed: Vec::new(),
             }),
             shared: Shared::default(),
@@ -1276,6 +1336,7 @@ pub(crate) struct State {
     /// Our last external address reported by the peer. When multipath is enabled, this will be the
     /// last report across all paths.
     pub(crate) observed_external_addr: watch::Sender<Option<SocketAddr>>,
+    pub(crate) nat_traversal_updates: tokio::sync::broadcast::Sender<iroh_hp::Event>,
     on_closed: Vec>,
 }
@@ -1457,6 +1518,9 @@ impl State {
             Path(evt @ PathEvent::RemoteStatus { .. }) => {
                 self.path_events.send(evt).ok();
             }
+            NatTraversal(update) => {
+                self.nat_traversal_updates.send(update).ok();
+            }
         }
     }
 }