diff --git a/src/bin/main.rs b/src/bin/main.rs index 95af572c..a5bee420 100644 --- a/src/bin/main.rs +++ b/src/bin/main.rs @@ -367,7 +367,7 @@ async fn main() -> eyre::Result<()> { loop { let (wormhole, _code, relay_server) = parse_and_connect(&mut term, matches, true, forwarding::APP_CONFIG).await?; - let relay_server = vec![transit::RelayHint::from_url(relay_server)]; + let relay_server = vec![transit::RelayHint::from_urls(None, [relay_server])]; async_std::task::spawn(forwarding::serve(wormhole, relay_server, targets.clone())); } } else if let Some(matches) = matches.subcommand_matches("connect") { @@ -380,7 +380,7 @@ async fn main() -> eyre::Result<()> { let bind_address: std::net::IpAddr = matches.value_of("bind").unwrap().parse()?; let (wormhole, _code, relay_server) = parse_and_connect(&mut term, matches, false, forwarding::APP_CONFIG).await?; - let relay_server = vec![transit::RelayHint::from_url(relay_server)]; + let relay_server = vec![transit::RelayHint::from_urls(None, [relay_server])]; forwarding::connect(wormhole, relay_server, Some(bind_address), &custom_ports).await?; } else { diff --git a/src/transfer.rs b/src/transfer.rs index ab221ce7..9d91ea4f 100644 --- a/src/transfer.rs +++ b/src/transfer.rs @@ -254,7 +254,7 @@ where H: FnMut(u64, u64) + 'static, { let _peer_version: AppVersion = serde_json::from_value(wormhole.peer_version.clone())?; - let relay_hints = vec![transit::RelayHint::from_url(relay_url)]; + let relay_hints = vec![transit::RelayHint::from_urls(None, [relay_url])]; // if peer_version.supports_v2() && false { // v2::send_file(wormhole, relay_url, file, file_name, file_size, progress_handler, peer_version).await // } else { @@ -288,7 +288,7 @@ where M: Into, H: FnMut(u64, u64) + 'static, { - let relay_hints = vec![transit::RelayHint::from_url(relay_url)]; + let relay_hints = vec![transit::RelayHint::from_urls(None, [relay_url])]; v1::send_folder( wormhole, relay_hints, @@ -309,7 +309,7 @@ pub async fn request_file( mut wormhole: Wormhole, relay_url: url::Url, ) -> Result { - let relay_hints = vec![transit::RelayHint::from_url(relay_url)]; + let relay_hints = vec![transit::RelayHint::from_urls(None, [relay_url])]; let connector = transit::init(transit::Abilities::ALL_ABILITIES, None, relay_hints).await?; // send the transit message diff --git a/src/transfer/messages.rs b/src/transfer/messages.rs index 2c6bd241..35dc61de 100644 --- a/src/transfer/messages.rs +++ b/src/transfer/messages.rs @@ -179,15 +179,36 @@ mod test { let abilities = Abilities::ALL_ABILITIES; let hints = transit::Hints::new( [DirectHint::new("192.168.1.8", 46295)], - [RelayHint::from_url( - "tcp://magic-wormhole-transit.debian.net:4001" - .parse() - .unwrap(), + [RelayHint::new( + None, + [DirectHint::new("magic-wormhole-transit.debian.net", 4001)], + [], )], ); - let t = - serde_json::json!(crate::transfer::PeerMessage::transit(abilities, hints)).to_string(); - assert_eq!(t, "{\"transit\":{\"abilities-v1\":[{\"type\":\"direct-tcp-v1\"},{\"type\":\"relay-v1\",\"url-hints\":true}],\"hints-v1\":[{\"hostname\":\"192.168.1.8\",\"port\":46295,\"type\":\"direct-tcp-v1\"},{\"hints\":[{\"hostname\":\"magic-wormhole-transit.debian.net\",\"port\":4001}],\"type\":\"relay-v1\",\"urls\":[\"tcp://magic-wormhole-transit.debian.net:4001\"]}]}}") + assert_eq!( + serde_json::json!(crate::transfer::PeerMessage::transit(abilities, hints)), + serde_json::json!({ + "transit": { + "abilities-v1": [{"type":"direct-tcp-v1"},{"type":"relay-v1"},{"type":"relay-v2"}], + "hints-v1": [ + {"hostname":"192.168.1.8","port":46295,"type":"direct-tcp-v1"}, + { + "type": "relay-v1", + "hints": [ + {"hostname": "magic-wormhole-transit.debian.net", "port": 4001 } + ] + }, + { + "type": "relay-v2", + "hints": [ + {"type": "tcp", "hostname": "magic-wormhole-transit.debian.net", "port": 4001} + ], + "name": null + } + ], + } + }) + ); } #[test] diff --git a/src/transit.rs b/src/transit.rs index b87595aa..f6962ca4 100644 --- a/src/transit.rs +++ b/src/transit.rs @@ -8,7 +8,7 @@ //! are in the same network, or the URL to a relay server. In case a direct connection fails, both will connect to the relay server //! which will transparently glue the connections together. //! -//! Each side might implement (or use/enable) some [abilities](Ability). +//! Each side might implement (or use/enable) some [abilities](Abilities). //! //! **Notice:** while the resulting TCP connection is naturally bi-directional, the handshake is not symmetric. There *must* be one //! "leader" side and one "follower" side (formerly called "sender" and "receiver"). @@ -98,33 +98,26 @@ pub enum TransitError { ), } -#[derive(Copy, Clone, Debug, Serialize, Deserialize)] -#[serde(rename_all = "kebab-case")] -pub struct RelayAbility { - pub url_hints: bool, -} - -impl Default for RelayAbility { - fn default() -> Self { - Self { url_hints: true } - } -} - /** * Defines a way to find the other side. * * Each ability comes with a set of [`Hints`] to encode how to meet up. */ -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Default)] pub struct Abilities { + /** Direct connection to the peer */ pub direct_tcp_v1: bool, - pub relay_v1: Option, + /** Connection over a TCP relay */ + pub relay_v1: bool, + /** Connection over a TCP or WebSocket relay */ + pub relay_v2: bool, } impl Abilities { pub const ALL_ABILITIES: Self = Self { direct_tcp_v1: true, - relay_v1: Some(RelayAbility { url_hints: true }), + relay_v1: true, + relay_v2: true, }; /** @@ -135,7 +128,8 @@ impl Abilities { */ pub const FORCE_DIRECT: Self = Self { direct_tcp_v1: true, - relay_v1: None, + relay_v1: false, + relay_v2: false, }; /** @@ -148,7 +142,8 @@ impl Abilities { */ pub const FORCE_RELAY: Self = Self { direct_tcp_v1: false, - relay_v1: Some(RelayAbility { url_hints: true }), + relay_v1: true, + relay_v2: true, }; pub fn can_direct(&self) -> bool { @@ -156,32 +151,18 @@ impl Abilities { } pub fn can_relay(&self) -> bool { - self.relay_v1.is_some() + self.relay_v1 || self.relay_v2 } /** Keep only abilities that both sides support */ pub fn intersect(mut self, other: &Self) -> Self { self.direct_tcp_v1 &= other.direct_tcp_v1; - self.relay_v1 = match (self.relay_v1, other.relay_v1) { - (Some(RelayAbility { url_hints: true }), Some(RelayAbility { url_hints: true })) => { - Some(RelayAbility { url_hints: true }) - }, - (Some(_), Some(_)) => Some(RelayAbility { url_hints: false }), - _ => None, - }; + self.relay_v1 &= other.relay_v1; + self.relay_v2 &= other.relay_v2; self } } -impl Default for Abilities { - fn default() -> Self { - Self { - direct_tcp_v1: false, - relay_v1: None, - } - } -} - impl serde::Serialize for Abilities { fn serialize(&self, ser: S) -> Result where @@ -193,10 +174,14 @@ impl serde::Serialize for Abilities { "type": "direct-tcp-v1", })); } - if let Some(relay_v1) = self.relay_v1 { + if self.relay_v1 { hints.push(serde_json::json!({ "type": "relay-v1", - "url-hints": relay_v1.url_hints, + })); + } + if self.relay_v2 { + hints.push(serde_json::json!({ + "type": "relay-v2", })); } serde_json::Value::Array(hints).serialize(ser) @@ -212,10 +197,8 @@ impl<'de> serde::Deserialize<'de> for Abilities { #[serde(rename_all = "kebab-case", tag = "type")] enum Ability { DirectTcpV1, - RelayV1 { - #[serde(default)] - url_hints: bool, - }, + RelayV1, + RelayV2, #[serde(other)] Other, } @@ -227,8 +210,11 @@ impl<'de> serde::Deserialize<'de> for Abilities { Ability::DirectTcpV1 => { abilities.direct_tcp_v1 = true; }, - Ability::RelayV1 { url_hints } => { - abilities.relay_v1 = Some(RelayAbility { url_hints }); + Ability::RelayV1 => { + abilities.relay_v1 = true; + }, + Ability::RelayV2 => { + abilities.relay_v2 = true; }, _ => (), } @@ -245,16 +231,39 @@ enum HintSerde { DirectTcpV1(DirectHint), RelayV1 { hints: HashSet, - /** Newer encoding. When present, the `hints` field is redundant. - */ - urls: Option>, }, + RelayV2(RelayHint), #[serde(other)] Unknown, } -impl From> for Hints { - fn from(hints: Vec) -> Hints { +/** Information about how to find a peer */ +#[derive(Clone, Debug, Default)] +pub struct Hints { + /** Hints for direct connection */ + pub direct_tcp: HashSet, + /** List of relay servers */ + pub relay: Vec, +} + +impl Hints { + pub fn new( + direct_tcp: impl IntoIterator, + relay: impl IntoIterator, + ) -> Self { + Self { + direct_tcp: direct_tcp.into_iter().collect(), + relay: relay.into_iter().collect(), + } + } +} + +impl<'de> serde::Deserialize<'de> for Hints { + fn deserialize(de: D) -> Result + where + D: serde::Deserializer<'de>, + { + let hints: Vec = serde::Deserialize::deserialize(de)?; let mut direct_tcp = HashSet::new(); let mut relay = Vec::::new(); let mut relay_v2 = Vec::::new(); @@ -264,18 +273,14 @@ impl From> for Hints { HintSerde::DirectTcpV1(hint) => { direct_tcp.insert(hint); }, - HintSerde::RelayV1 { hints, urls: None } => { + HintSerde::RelayV1 { hints } => { relay.push(RelayHint { tcp: hints, ..RelayHint::default() }); }, - HintSerde::RelayV1 { - hints: _, - urls: Some(urls), - } => { - let hint = RelayHint::new(urls); - hint.merge_into(&mut relay_v2); + HintSerde::RelayV2(hint) => { + relay_v2.push(hint); }, /* Ignore unknown hints */ _ => {}, @@ -288,48 +293,7 @@ impl From> for Hints { } relay.extend(relay_v2.into_iter().map(Into::into)); - Hints { direct_tcp, relay } - } -} - -#[derive(Clone, Debug, Default)] -pub struct Hints { - pub direct_tcp: HashSet, - pub relay: Vec, -} - -impl Hints { - pub fn new( - direct_tcp: impl IntoIterator, - relay: impl IntoIterator, - ) -> Self { - Self { - direct_tcp: direct_tcp.into_iter().collect(), - relay: relay.into_iter().collect(), - } - } - - fn iter_serde(&self) -> impl IntoIterator + '_ { - self.direct_tcp - .iter() - .cloned() - .map(HintSerde::DirectTcpV1) - .chain(self.relay.iter().flat_map(|hint| { - [HintSerde::RelayV1 { - hints: hint.tcp.clone(), - urls: Some(hint.iter_urls().into_iter().collect()), - }] - })) - } -} - -impl<'de> serde::Deserialize<'de> for Hints { - fn deserialize(de: D) -> Result - where - D: serde::Deserializer<'de>, - { - let hints: Vec = serde::Deserialize::deserialize(de)?; - Ok(hints.into()) + Ok(Hints { direct_tcp, relay }) } } @@ -338,10 +302,20 @@ impl serde::Serialize for Hints { where S: serde::Serializer, { - ser.collect_seq(self.iter_serde()) + let direct = self.direct_tcp.iter().cloned().map(HintSerde::DirectTcpV1); + let relay = self.relay.iter().flat_map(|hint| { + [ + HintSerde::RelayV1 { + hints: hint.tcp.clone(), + }, + HintSerde::RelayV2(hint.clone()), + ] + }); + ser.collect_seq(direct.chain(relay)) } } +/** hostname and port for direct connection */ #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Hash, derive_more::Display)] #[display(fmt = "tcp://{}:{}", hostname, port)] pub struct DirectHint { @@ -361,7 +335,31 @@ impl DirectHint { } } -/** Hint describing a relay server +/* Wire representation of a single relay hint (Helper struct for serialization) */ +#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[serde(rename_all = "kebab-case", tag = "type")] +#[non_exhaustive] +struct RelayHintSerde { + name: Option, + #[serde(rename = "hints")] + endpoints: Vec, +} + +/* Wire representation of a single relay endpoint (Helper struct for serialization) */ +#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[serde(rename_all = "kebab-case", tag = "type")] +#[non_exhaustive] +enum RelayHintSerdeInner { + Tcp(DirectHint), + Websocket { + url: url::Url, + }, + #[serde(other)] + Unknown, +} + +/** + * Hint describing a relay server * * A server may be reachable at multiple locations. Any two must be relayable * over that server, therefore a client may pick only one of these per hint. @@ -372,37 +370,50 @@ impl DirectHint { /* RelayHint::default() gives the empty server (cannot be reached), and is only there for struct update syntax */ #[derive(Clone, Debug, Eq, PartialEq, Default)] pub struct RelayHint { + /** Human readable name */ + pub name: Option, + /** TCP endpoints of that relay */ pub tcp: HashSet, + /** WebSockets endpoints of that relay */ pub ws: HashSet, - pub other: HashSet, } impl RelayHint { - pub fn from_url(url: url::Url) -> Self { - Self::new(std::iter::once(url)) + pub fn new( + name: Option, + tcp: impl IntoIterator, + ws: impl IntoIterator, + ) -> Self { + Self { + name, + tcp: tcp.into_iter().collect(), + ws: ws.into_iter().collect(), + } } - pub fn new(urls: impl IntoIterator) -> Self { - let mut tcp = HashSet::new(); - let mut ws = HashSet::new(); - let mut other = HashSet::new(); - for hint in urls { - match hint.scheme() { + pub fn from_urls(name: Option, urls: impl IntoIterator) -> Self { + let mut this = Self { + name, + ..Self::default() + }; + for url in urls.into_iter() { + match url.scheme() { "tcp" => { - tcp.insert(DirectHint { - hostname: hint.host_str().expect("TODO Error handling").into(), - port: hint.port().expect("TODO Error handling"), + this.tcp.insert(DirectHint { + hostname: url.host_str().expect("TODO error handling").into(), + port: url.port().expect("TODO error handling"), }); }, "ws" | "wss" => { - ws.insert(hint); + this.ws.insert(url); }, _ => { - other.insert(hint); + // Do we fail or do we ignore? + todo!("TODO error handling"); }, } } - RelayHint { tcp, ws, other } + this } pub fn can_merge(&self, other: &Self) -> bool { @@ -417,7 +428,6 @@ impl RelayHint { pub fn merge_mut(&mut self, other: Self) { self.tcp.extend(other.tcp); self.ws.extend(other.ws); - self.other.extend(other.other); } pub fn merge_into(self, collection: &mut Vec) { @@ -429,36 +439,56 @@ impl RelayHint { } collection.push(self); } - - pub fn iter_urls(&self) -> impl IntoIterator + '_ { - self.tcp - .iter() - .map(|hint| { - format!("tcp://{}:{}", hint.hostname, hint.port) - .parse() - .unwrap() - }) - .chain(self.other.iter().cloned()) - .chain(self.ws.iter().cloned()) - } } -impl From for HashSet { - fn from(hint: RelayHint) -> HashSet { - let mut urls = hint.other; - urls.extend(hint.ws); - urls.extend(hint.tcp.into_iter().map(|hint| { - format!("tcp://{}:{}", hint.hostname, hint.port) - .parse() - .unwrap() - })); - urls +impl serde::Serialize for RelayHint { + fn serialize(&self, ser: S) -> Result + where + S: serde::Serializer, + { + let mut hints = Vec::new(); + hints.extend(self.tcp.iter().cloned().map(RelayHintSerdeInner::Tcp)); + hints.extend( + self.ws + .iter() + .cloned() + .map(|h| RelayHintSerdeInner::Websocket { url: h }), + ); + + serde_json::json!({ + "name": self.name, + "hints": hints, + }) + .serialize(ser) } } -impl From> for RelayHint { - fn from(urls: HashSet) -> RelayHint { - Self::new(urls) +impl<'de> serde::Deserialize<'de> for RelayHint { + fn deserialize(de: D) -> Result + where + D: serde::Deserializer<'de>, + { + let raw = RelayHintSerde::deserialize(de)?; + let mut hint = RelayHint { + name: raw.name, + tcp: HashSet::new(), + ws: HashSet::new(), + }; + + for e in raw.endpoints { + match e { + RelayHintSerdeInner::Tcp(tcp) => { + hint.tcp.insert(tcp); + }, + RelayHintSerdeInner::Websocket { url } => { + hint.ws.insert(url); + }, + /* Ignore unknown hints */ + _ => {}, + } + } + + Ok(hint) } } @@ -1408,3 +1438,70 @@ async fn handshake_exchange( rnonce: Default::default(), }) } + +#[cfg(test)] +mod test { + use super::*; + use serde_json::json; + + #[test] + pub fn test_abilities_encoding() { + assert_eq!( + serde_json::to_value(Abilities::ALL_ABILITIES).unwrap(), + json!([{"type": "direct-tcp-v1"}, {"type": "relay-v1"}, {"type": "relay-v2"}]) + ); + assert_eq!( + serde_json::to_value(Abilities::FORCE_DIRECT).unwrap(), + json!([{"type": "direct-tcp-v1"}]) + ); + } + + #[test] + pub fn test_hints_encoding() { + assert_eq!( + serde_json::to_value(Hints::new( + [DirectHint { + hostname: "localhost".into(), + port: 1234 + }], + [RelayHint::new( + Some("default".into()), + [DirectHint::new("transit.magic-wormhole.io", 4001)], + ["ws://transit.magic-wormhole.io/relay".parse().unwrap(),], + )] + )) + .unwrap(), + json!([ + { + "type": "direct-tcp-v1", + "hostname": "localhost", + "port": 1234 + }, + { + "type": "relay-v1", + "hints": [ + { + "hostname": "transit.magic-wormhole.io", + "port": 4001, + } + ] + }, + { + "type": "relay-v2", + "name": "default", + "hints": [ + { + "type": "tcp", + "hostname": "transit.magic-wormhole.io", + "port": 4001, + }, + { + "type": "websocket", + "url": "ws://transit.magic-wormhole.io/relay", + }, + ] + } + ]) + ) + } +}