diff --git a/Cargo.toml b/Cargo.toml index d4343838..51715e44 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,12 +37,17 @@ futures_ringbuf = "0.3.1" tar = "0.4.33" chrono = "0.4.19" -derive_more = { version = "0.99.0", default-features = false, features = ["display", "deref"] } +derive_more = { version = "0.99.0", default-features = false, features = ["display", "deref", "from"] } thiserror = "1.0.24" futures = "0.3.12" async-std = { version = "1.9.0", features = ["attributes", "unstable"] } async-tungstenite = { version = "0.14.0", features = ["async-std-runtime", "async-tls"] } +async-io = "1.6.0" +socket2 = "0.4.1" +libc = "0.2.101" +stun_codec = "0.1.13" +bytecodec = "0.4.15" # for "bin" feature clap = { version = "2.33.3", optional = true } diff --git a/src/core/key.rs b/src/core/key.rs index 1535e39a..1e5fdbbe 100644 --- a/src/core/key.rs +++ b/src/core/key.rs @@ -53,8 +53,8 @@ impl Key { let derived_key = self.derive_subkey_from_purpose(&transit_purpose); trace!( "Input key: {}, Transit key: {}, Transit purpose: '{}'", - hex::encode(&**self), - hex::encode(&**derived_key), + self.to_hex(), + derived_key.to_hex(), &transit_purpose ); derived_key @@ -65,6 +65,11 @@ impl Key

{ pub fn new(key: Box) -> Self { Self(key, std::marker::PhantomData) } + + pub fn to_hex(&self) -> String { + hex::encode(&**self) + } + /** * Derive a new sub-key from this one */ diff --git a/src/core/test.rs b/src/core/test.rs index 37bcfa7f..27df24a9 100644 --- a/src/core/test.rs +++ b/src/core/test.rs @@ -47,9 +47,7 @@ pub async fn test_file_rust2rust() -> eyre::Result<()> { std::fs::metadata("examples/example-file.bin") .unwrap() .len(), - |sent, total| { - log::info!("Sent {} of {} bytes", sent, total); - }, + |_sent, _total| {}, ) .await?, ) @@ -72,13 +70,7 @@ pub async fn test_file_rust2rust() -> eyre::Result<()> { .await?; let mut buffer = Vec::::new(); - req.accept( - |received, total| { - log::info!("Received {} of {} bytes", received, total); - }, - &mut buffer, - ) - .await?; + req.accept(|_received, _total| {}, &mut buffer).await?; Ok(buffer) })?; diff --git a/src/transfer.rs b/src/transfer.rs index 30972825..de750271 100644 --- a/src/transfer.rs +++ b/src/transfer.rs @@ -5,7 +5,7 @@ //! It is bound to an [`APPID`](APPID). Only applications using that APPID (and thus this protocol) can interoperate with //! the original Python implementation (and other compliant implementations). //! -//! At its core, [`PeerMessage`s](PeerMessage) are exchanged over an established wormhole connection with the other side. +//! At its core, "peer messages" are exchanged over an established wormhole connection with the other side. //! They are used to set up a [transit] portal and to exchange a file offer/accept. Then, the file is transmitted over the transit relay. use futures::{AsyncRead, AsyncWrite}; @@ -26,6 +26,9 @@ use sha2::{digest::FixedOutput, Digest, Sha256}; use std::path::PathBuf; use transit::{TransitConnectError, TransitConnector, TransitError}; +mod messages; +use messages::*; + const APPID_RAW: &str = "lothar.com/wormhole/text-or-file-xfer"; /// The App ID associated with this protocol. @@ -156,105 +159,6 @@ impl TransitAck { } } -/** - * The type of message exchanged over the wormhole for this protocol - */ -#[derive(Deserialize, Serialize, Debug, PartialEq)] -#[serde(rename_all = "kebab-case")] -pub enum PeerMessage { - Offer(OfferType), - Answer(AnswerType), - /** Tell the other side you got an error */ - Error(String), - /** Used to set up a transit channel */ - Transit(Arc), - #[serde(other)] - Unknown, -} - -impl PeerMessage { - pub fn new_offer_message(msg: impl Into) -> Self { - PeerMessage::Offer(OfferType::Message(msg.into())) - } - - pub fn new_offer_file(name: impl Into, size: u64) -> Self { - PeerMessage::Offer(OfferType::File { - filename: name.into(), - filesize: size, - }) - } - - pub fn new_offer_directory( - name: impl Into, - mode: impl Into, - compressed_size: u64, - numbytes: u64, - numfiles: u64, - ) -> Self { - PeerMessage::Offer(OfferType::Directory { - dirname: name.into(), - mode: mode.into(), - zipsize: compressed_size, - numbytes, - numfiles, - }) - } - - pub fn new_message_ack(msg: impl Into) -> Self { - PeerMessage::Answer(AnswerType::MessageAck(msg.into())) - } - - pub fn new_file_ack(msg: impl Into) -> Self { - PeerMessage::Answer(AnswerType::FileAck(msg.into())) - } - - pub fn new_error_message(msg: impl Into) -> Self { - PeerMessage::Error(msg.into()) - } - - pub fn new_transit(abilities: Vec, hints: Vec) -> Self { - PeerMessage::Transit(Arc::new(transit::TransitType { - abilities_v1: abilities, - hints_v1: hints, - })) - } - - #[cfg(test)] - pub fn serialize(&self) -> String { - json!(self).to_string() - } - - pub fn serialize_vec(&self) -> Vec { - serde_json::to_vec(self).unwrap() - } -} - -#[derive(Serialize, Deserialize, Debug, PartialEq)] -#[serde(rename_all = "kebab-case")] -pub enum OfferType { - Message(String), - File { - filename: PathBuf, - filesize: u64, - }, - Directory { - dirname: PathBuf, - mode: String, - zipsize: u64, - numbytes: u64, - numfiles: u64, - }, - #[serde(other)] - Unknown, -} - -#[derive(Serialize, Deserialize, Debug, PartialEq)] -#[serde(rename_all = "snake_case")] -pub enum AnswerType { - MessageAck(String), - FileAck(String), -} - pub async fn send_file_or_folder( wormhole: &mut Wormhole, relay_url: &RelayUrl, @@ -310,9 +214,15 @@ where let connector = transit::init(transit::Ability::all_abilities(), relay_url).await?; // We want to do some transit - debug!("Sending transit message '{:?}", connector.our_side_ttype()); + debug!("Sending transit message '{:?}", connector.our_hints()); wormhole - .send(PeerMessage::Transit(connector.our_side_ttype().clone()).serialize_vec()) + .send( + PeerMessage::new_transit( + connector.our_abilities().to_vec(), + (**connector.our_hints()).clone().into(), + ) + .serialize_vec(), + ) .await?; // Send file offer message. @@ -322,24 +232,23 @@ where .await?; // Wait for their transit response - let other_side_ttype = { - let maybe_transit = serde_json::from_slice(&wormhole.receive().await?)?; - debug!("received transit message: {:?}", maybe_transit); - - match maybe_transit { - PeerMessage::Transit(tmsg) => tmsg, + let (their_abilities, their_hints): (Vec, transit::Hints) = + match serde_json::from_slice(&wormhole.receive().await?)? { + PeerMessage::Transit(transit) => { + debug!("received transit message: {:?}", transit); + (transit.abilities_v1, transit.hints_v1.into()) + }, PeerMessage::Error(err) => { bail!(TransferError::PeerError(err)); }, - _ => { - let error = TransferError::unexpected_message("transit", maybe_transit); + other => { + let error = TransferError::unexpected_message("transit", other); let _ = wormhole .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) .await; bail!(error) }, - } - }; + }; { // Wait for file_ack @@ -363,17 +272,36 @@ where } } - let mut transit = connector + let mut transit = match connector .leader_connect( wormhole.key().derive_transit_key(wormhole.appid()), - Arc::try_unwrap(other_side_ttype).unwrap(), + Arc::new(their_abilities), + Arc::new(their_hints), ) - .await?; + .await + { + Ok(transit) => transit, + Err(error) => { + let error = TransferError::TransitConnect(error); + let _ = wormhole + .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) + .await; + return Err(error); + }, + }; debug!("Beginning file transfer"); // 11. send the file as encrypted records. - let checksum = send_records(&mut transit, file, file_size, progress_handler).await?; + let checksum = match send_records(&mut transit, file, file_size, progress_handler).await { + Err(TransferError::Transit(error)) => { + let _ = wormhole + .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) + .await; + Err(TransferError::Transit(error)) + }, + other => other, + }?; // 13. wait for the transit ack with sha256 sum from the peer. debug!("sent file. Waiting for ack"); @@ -415,9 +343,15 @@ where } // We want to do some transit - debug!("Sending transit message '{:?}", connector.our_side_ttype()); + debug!("Sending transit message '{:?}", connector.our_hints()); wormhole - .send(PeerMessage::Transit(connector.our_side_ttype().clone()).serialize_vec()) + .send( + PeerMessage::new_transit( + connector.our_abilities().to_vec(), + (**connector.our_hints()).clone().into(), + ) + .serialize_vec(), + ) .await?; use tar::Builder; @@ -478,21 +412,23 @@ where .await?; // Wait for their transit response - let other_side_ttype = { - let maybe_transit = serde_json::from_slice(&wormhole.receive().await?)?; - debug!("received transit message: {:?}", maybe_transit); - - match maybe_transit { - PeerMessage::Transit(tmsg) => tmsg, - _ => { - let error = TransferError::unexpected_message("transit", maybe_transit); + let (their_abilities, their_hints): (Vec, transit::Hints) = + match serde_json::from_slice(&wormhole.receive().await?)? { + PeerMessage::Transit(transit) => { + debug!("received transit message: {:?}", transit); + (transit.abilities_v1, transit.hints_v1.into()) + }, + PeerMessage::Error(err) => { + bail!(TransferError::PeerError(err)); + }, + other => { + let error = TransferError::unexpected_message("transit", other); let _ = wormhole .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) .await; bail!(error) }, - } - }; + }; { // Wait for file_ack @@ -516,12 +452,23 @@ where } } - let mut transit = connector + let mut transit = match connector .leader_connect( wormhole.key().derive_transit_key(wormhole.appid()), - Arc::try_unwrap(other_side_ttype).unwrap(), + Arc::new(their_abilities), + Arc::new(their_hints), ) - .await?; + .await + { + Ok(transit) => transit, + Err(error) => { + let error = TransferError::TransitConnect(error); + let _ = wormhole + .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) + .await; + return Err(error); + }, + }; debug!("Beginning file transfer"); @@ -580,9 +527,15 @@ where std::io::Result::Ok(hasher.finalize_fixed()) }); - let checksum = send_records(&mut transit, &mut reader, length, progress_handler) - .await - .unwrap(); + let checksum = match send_records(&mut transit, &mut reader, length, progress_handler).await { + Err(TransferError::Transit(error)) => { + let _ = wormhole + .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) + .await; + Err(TransferError::Transit(error)) + }, + other => other, + }?; /* This should always be ready by now, but just in case */ let sha256sum = file_sender.await.unwrap(); @@ -620,31 +573,35 @@ pub async fn request_file<'a>( let connector = transit::init(transit::Ability::all_abilities(), relay_url).await?; // send the transit message - debug!("Sending transit message '{:?}", connector.our_side_ttype()); + debug!("Sending transit message '{:?}", connector.our_hints()); wormhole .send( - crate::transfer::PeerMessage::Transit(connector.our_side_ttype().clone()) - .serialize_vec(), + PeerMessage::new_transit( + connector.our_abilities().to_vec(), + (**connector.our_hints()).clone().into(), + ) + .serialize_vec(), ) .await?; // receive transit message - let other_side_ttype = match serde_json::from_slice(&wormhole.receive().await?)? { - PeerMessage::Transit(transit) => { - debug!("received transit message: {:?}", transit); - transit - }, - PeerMessage::Error(err) => { - bail!(TransferError::PeerError(err)); - }, - other => { - let error = TransferError::unexpected_message("transit", other); - let _ = wormhole - .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) - .await; - bail!(error) - }, - }; + let (their_abilities, their_hints): (Vec, transit::Hints) = + match serde_json::from_slice(&wormhole.receive().await?)? { + PeerMessage::Transit(transit) => { + debug!("received transit message: {:?}", transit); + (transit.abilities_v1, transit.hints_v1.into()) + }, + PeerMessage::Error(err) => { + bail!(TransferError::PeerError(err)); + }, + other => { + let error = TransferError::unexpected_message("transit", other); + let _ = wormhole + .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) + .await; + bail!(error) + }, + }; // 3. receive file offer message from peer let maybe_offer = serde_json::from_slice(&wormhole.receive().await?)?; @@ -680,7 +637,8 @@ pub async fn request_file<'a>( filename, filesize, connector, - other_side_ttype, + their_abilities: Arc::new(their_abilities), + their_hints: Arc::new(their_hints), }; Ok(req) @@ -698,7 +656,8 @@ pub struct ReceiveRequest<'a> { /// **Security warning:** this is untrusted and unverified input pub filename: PathBuf, pub filesize: u64, - other_side_ttype: Arc, + their_abilities: Arc>, + their_hints: Arc, } impl<'a> ReceiveRequest<'a> { @@ -722,25 +681,47 @@ impl<'a> ReceiveRequest<'a> { .send(PeerMessage::new_file_ack("ok").serialize_vec()) .await?; - let mut transit = self + let mut transit = match self .connector .follower_connect( self.wormhole .key() .derive_transit_key(self.wormhole.appid()), - self.other_side_ttype.clone(), + self.their_abilities.clone(), + self.their_hints.clone(), ) - .await?; + .await + { + Ok(transit) => transit, + Err(error) => { + let error = TransferError::TransitConnect(error); + let _ = self + .wormhole + .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) + .await; + return Err(error); + }, + }; debug!("Beginning file transfer"); // TODO here's the right position for applying the output directory and to check for malicious (relative) file paths - tcp_file_receive( + match tcp_file_receive( &mut transit, self.filesize, progress_handler, content_handler, ) .await + { + Err(TransferError::Transit(error)) => { + let _ = self + .wormhole + .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) + .await; + Err(TransferError::Transit(error)) + }, + other => other, + } } /** @@ -793,7 +774,6 @@ where // send the encrypted record transit.send_record(&plaintext[0..n]).await?; sent_size += n as u64; - debug!("sent {} bytes out of {} bytes", sent_size, file_size); progress_handler(sent_size, file_size); // sha256 of the input diff --git a/src/transfer/messages.rs b/src/transfer/messages.rs new file mode 100644 index 00000000..fa5dffda --- /dev/null +++ b/src/transfer/messages.rs @@ -0,0 +1,215 @@ +//! Over-the-wire messages for the file transfer (including transit) +//! +//! The transit protocol does not specify how to deliver the information to +//! the other side, so it is up to the file transfer to do that. hfoo + +use crate::transit::{self, Ability, DirectHint}; +use serde_derive::{Deserialize, Serialize}; +#[cfg(test)] +use serde_json::json; +use std::{collections::HashSet, path::PathBuf}; + +/** + * The type of message exchanged over the wormhole for this protocol + */ +#[derive(Deserialize, Serialize, Debug, PartialEq)] +#[serde(rename_all = "kebab-case")] +pub enum PeerMessage { + Offer(OfferType), + Answer(AnswerType), + /** Tell the other side you got an error */ + Error(String), + /** Used to set up a transit channel */ + Transit(TransitType), + #[serde(other)] + Unknown, +} + +impl PeerMessage { + pub fn new_offer_message(msg: impl Into) -> Self { + PeerMessage::Offer(OfferType::Message(msg.into())) + } + + pub fn new_offer_file(name: impl Into, size: u64) -> Self { + PeerMessage::Offer(OfferType::File { + filename: name.into(), + filesize: size, + }) + } + + pub fn new_offer_directory( + name: impl Into, + mode: impl Into, + compressed_size: u64, + numbytes: u64, + numfiles: u64, + ) -> Self { + PeerMessage::Offer(OfferType::Directory { + dirname: name.into(), + mode: mode.into(), + zipsize: compressed_size, + numbytes, + numfiles, + }) + } + + pub fn new_message_ack(msg: impl Into) -> Self { + PeerMessage::Answer(AnswerType::MessageAck(msg.into())) + } + + pub fn new_file_ack(msg: impl Into) -> Self { + PeerMessage::Answer(AnswerType::FileAck(msg.into())) + } + + pub fn new_error_message(msg: impl Into) -> Self { + PeerMessage::Error(msg.into()) + } + + pub fn new_transit(abilities: Vec, hints: Vec) -> Self { + PeerMessage::Transit(TransitType { + abilities_v1: abilities, + hints_v1: hints, + }) + } + + #[cfg(test)] + pub fn serialize(&self) -> String { + json!(self).to_string() + } + + pub fn serialize_vec(&self) -> Vec { + serde_json::to_vec(self).unwrap() + } +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[serde(rename_all = "kebab-case")] +pub enum OfferType { + Message(String), + File { + filename: PathBuf, + filesize: u64, + }, + Directory { + dirname: PathBuf, + mode: String, + zipsize: u64, + numbytes: u64, + numfiles: u64, + }, + #[serde(other)] + Unknown, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum AnswerType { + MessageAck(String), + FileAck(String), +} + +/** + * A set of hints for both sides to find each other + */ +#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[serde(rename_all = "kebab-case")] +pub struct TransitType { + pub abilities_v1: Vec, + pub hints_v1: Vec, +} + +impl From for Vec { + fn from(hints: transit::Hints) -> Self { + hints + .direct_tcp + .into_iter() + .map(Hint::DirectTcpV1) + .chain(std::iter::once(Hint::new_relay(hints.relay))) + .collect() + } +} + +impl Into for Vec { + fn into(self) -> transit::Hints { + let mut direct_tcp = HashSet::new(); + let mut relay = HashSet::new(); + + /* There is only one "relay hint", though it may contain multiple + * items. Yes, this is inconsistent and weird, watch your step. + */ + for hint in self { + match hint { + Hint::DirectTcpV1(hint) => { + direct_tcp.insert(hint); + }, + Hint::DirectUdtV1(_) => unimplemented!(), + Hint::RelayV1(RelayHint { hints }) => relay.extend(hints), + } + } + + transit::Hints { direct_tcp, relay } + } +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[serde(rename_all = "kebab-case", tag = "type")] +#[non_exhaustive] +pub enum Hint { + DirectTcpV1(DirectHint), + DirectUdtV1(DirectHint), + /* Weirdness alarm: a "relay hint" contains multiple "direct hints". This means + * that there may be multiple direct hints, but if there are multiple relay hints + * it's still only one item because it internally has a list. + */ + RelayV1(RelayHint), +} + +impl Hint { + pub fn new_direct_tcp(_priority: f32, hostname: &str, port: u16) -> Self { + Hint::DirectTcpV1(DirectHint { + hostname: hostname.to_string(), + port, + }) + } + + pub fn new_direct_udt(_priority: f32, hostname: &str, port: u16) -> Self { + Hint::DirectUdtV1(DirectHint { + hostname: hostname.to_string(), + port, + }) + } + + pub fn new_relay(h: HashSet) -> Self { + Hint::RelayV1(RelayHint { + hints: h.into_iter().collect(), + }) + } +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub struct RelayHint { + pub hints: Vec, +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_transit() { + let abilities = vec![Ability::DirectTcpV1, Ability::RelayV1]; + let hints = vec![ + Hint::new_direct_tcp(0.0, "192.168.1.8", 46295), + Hint::new_relay( + vec![DirectHint { + hostname: "magic-wormhole-transit.debian.net".to_string(), + port: 4001, + }] + .into_iter() + .collect(), + ), + ]; + let t = crate::transfer::PeerMessage::new_transit(abilities, hints); + assert_eq!(t.serialize(), "{\"transit\":{\"abilities-v1\":[{\"type\":\"direct-tcp-v1\"},{\"type\":\"relay-v1\"}],\"hints-v1\":[{\"hostname\":\"192.168.1.8\",\"port\":46295,\"type\":\"direct-tcp-v1\"},{\"hints\":[{\"hostname\":\"magic-wormhole-transit.debian.net\",\"port\":4001}],\"type\":\"relay-v1\"}]}}") + } +} diff --git a/src/transit.rs b/src/transit.rs index 2522d4d1..7a81a31f 100644 --- a/src/transit.rs +++ b/src/transit.rs @@ -4,7 +4,7 @@ //! but it depends on some kind of secure communication channel to talk to the other side. Conveniently, Wormhole provides //! exactly such a thing :) //! -//! Both clients exchange messages containing hints on how to find each other. These may be local IP Addresses for in case they +//! Both clients exchange messages containing hints on how to find each other. These may be local IP addresses for in case they //! are in the same network, or the URL to a relay server. In case a direct connection fails, both will connect to the relay server //! which will transparently glue the connections together. //! @@ -13,7 +13,7 @@ //! **Notice:** while the resulting TCP connection is naturally bi-directional, the handshake is not symmetric. There *must* be one //! "leader" side and one "follower" side (formerly called "sender" and "receiver"). -use crate::{core::WormholeError, Key, KeyPurpose}; +use crate::{Key, KeyPurpose}; use serde_derive::{Deserialize, Serialize}; use async_std::{ @@ -21,14 +21,18 @@ use async_std::{ net::{TcpListener, TcpStream}, }; #[allow(unused_imports)] /* We need them for the docs */ -use futures::{future::TryFutureExt, Sink, Stream, StreamExt}; +use futures::{future::TryFutureExt, Sink, SinkExt, Stream, StreamExt, TryStreamExt}; use log::*; -use std::{net::ToSocketAddrs, str::FromStr, sync::Arc}; +use std::{collections::HashSet, str::FromStr, sync::Arc}; use xsalsa20poly1305 as secretbox; use xsalsa20poly1305::aead::{Aead, NewAead}; /// ULR to a default hosted relay server. Please don't abuse or DOS. pub const DEFAULT_RELAY_SERVER: &str = "tcp:transit.magic-wormhole.io:4001"; +// No need to make public, it's hard-coded anyways (: +// Open an issue if you want an API for this +// Use for non-production testing +const PUBLIC_STUN_SERVER: &str = "stun.piegames.de:3478"; #[derive(Debug)] pub struct TransitKey; @@ -43,14 +47,11 @@ impl KeyPurpose for TransitTxKey {} #[derive(Debug, thiserror::Error)] #[non_exhaustive] pub enum TransitConnectError { - #[error("All (relay) handshakes failed, could not establish a connection with the peer")] + /** Incompatible abilities, or wrong hints */ + #[error("{}", _0)] + Protocol(Box), + #[error("All (relay) handshakes failed or timed out; could not establish a connection with the peer")] Handshake, - #[error("Wormhole connection error")] - Wormhole( - #[from] - #[source] - WormholeError, - ), #[error("IO error")] IO( #[from] @@ -68,11 +69,11 @@ enum TransitHandshakeError { HandshakeFailed, #[error("Relay handshake failed")] RelayHandshakeFailed, - #[error("Wormhole connection error")] - Wormhole( + #[error("Malformed peer address")] + BadAddress( #[from] #[source] - WormholeError, + std::net::AddrParseError, ), #[error("IO error")] IO( @@ -97,31 +98,34 @@ pub enum TransitError { ), } -/** - * A set of hints for both sides to find each other - */ -#[derive(Serialize, Deserialize, Debug, PartialEq)] -#[serde(rename_all = "kebab-case")] -pub struct TransitType { - pub abilities_v1: Vec, - pub hints_v1: Vec, -} - /** * Defines a way to find the other side. * - * Each ability comes with a set of [hints](Hint) to encode how to meet up. + * Each ability comes with a set of [`Hints`] to encode how to meet up. */ -#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "kebab-case", tag = "type")] +#[non_exhaustive] pub enum Ability { /** - * Try to connect directly to the other side. + * Try to connect directly to the other side via TCP. * * This usually requires both participants to be in the same network. [`DirectHint`s](DirectHint) are sent, * which encode all local IP addresses for the other side to find us. */ DirectTcpV1, + /** + * UNSTABLE; NOT IMPLEMENTED! + * Try to connect directly to the other side via UDT. + * + * This supersedes [`Ability::DirectTcpV1`] because it has several advantages: + * + * - Works with stateful firewalls, no need to open up ports + * - Works despite many NAT types if combined with STUN + * - UDT has a few other interesting performance-related properties that make it better + * suited than TCP (it's literally called "UDP-based Data Transfer Protocol") + */ + DirectUdtV1, /** Try to meet the other side at a relay. */ RelayV1, /* TODO Fix once https://github.com/serde-rs/serde/issues/912 is done */ @@ -131,44 +135,71 @@ pub enum Ability { impl Ability { pub fn all_abilities() -> Vec { - vec![Self::DirectTcpV1, Self::RelayV1] + vec![Self::DirectTcpV1, Self::DirectUdtV1, Self::RelayV1] } -} -#[derive(Serialize, Deserialize, Debug, PartialEq)] -#[serde(rename_all = "kebab-case", tag = "type")] -pub enum Hint { - DirectTcpV1(DirectHint), - RelayV1(RelayHint), -} - -impl Hint { - pub fn new_direct(priority: f32, hostname: &str, port: u16) -> Self { - Hint::DirectTcpV1(DirectHint { - priority, - hostname: hostname.to_string(), - port, - }) + /** + * If you absolutely don't want to use any relay servers. + * + * If the other side forces relay usage or doesn't support any of your connection modes + * the attempt will fail. + */ + pub fn force_direct() -> Vec { + vec![Self::DirectTcpV1, Self::DirectUdtV1] } - pub fn new_relay(h: Vec) -> Self { - Hint::RelayV1(RelayHint { hints: h }) + /** + * If you don't want to disclose your IP address to your peer + * + * If the other side forces a the usage of a direct connection the attempt will fail. + * Note that the other side might control the relay server being used, if you really + * don't want your IP to potentially be disclosed use TOR instead (not supported by + * the Rust implementation yet). + */ + pub fn force_relay() -> Vec { + vec![Self::RelayV1] } } -#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[derive(Clone, Debug, Default)] +pub struct Hints { + pub direct_tcp: HashSet, + pub relay: HashSet, +} + +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Hash, derive_more::Display)] +#[display(fmt = "tcp://{}:{}", hostname, port)] pub struct DirectHint { - pub priority: f32, + // DirectHint also contains a `priority` field, but it is underspecified + // and we won't use it + // pub priority: f32, pub hostname: String, pub port: u16, } -#[derive(Serialize, Deserialize, Debug, PartialEq)] -pub struct RelayHint { - pub hints: Vec, +use std::convert::{TryFrom, TryInto}; + +impl TryFrom<&DirectHint> for std::net::IpAddr { + type Error = std::net::AddrParseError; + fn try_from(hint: &DirectHint) -> Result { + hint.hostname.parse() + } +} + +impl TryFrom<&DirectHint> for std::net::SocketAddr { + type Error = std::net::AddrParseError; + /** This does not do the obvious thing and also implicitly maps all V4 addresses into V6 */ + fn try_from(hint: &DirectHint) -> Result { + let addr = hint.try_into()?; + let addr = match addr { + std::net::IpAddr::V4(v4) => std::net::IpAddr::V6(v4.to_ipv6_mapped()), + std::net::IpAddr::V6(_) => addr, + }; + Ok(std::net::SocketAddr::new(addr, hint.port)) + } } -#[derive(Debug, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] enum HostType { Direct, Relay, @@ -198,68 +229,314 @@ impl FromStr for RelayUrl { } } +fn set_socket_opts(socket: &socket2::Socket) -> std::io::Result<()> { + socket.set_nonblocking(true)?; + + /* See https://stackoverflow.com/a/14388707/6094756. + * On most BSD and Linux systems, we need both REUSEADDR and REUSEPORT; + * and if they don't support the latter we won't compile. + * On Windows, there is only REUSEADDR but it does what we want. + */ + socket.set_reuse_address(true)?; + #[cfg(all(unix, not(any(target_os = "solaris", target_os = "illumos"))))] + { + socket.set_reuse_port(true)?; + } + #[cfg(not(any( + all(unix, not(any(target_os = "solaris", target_os = "illumos"))), + target_os = "windows" + )))] + { + compile_error!("Your system is not supported yet, please raise an error"); + } + + Ok(()) +} + +/** + * Bind to a port with SO_REUSEADDR, connect to the destination and then hide the blood behind a pretty [`async_std::net::TcpStream`] + * + * We want an `async_std::net::TcpStream`, but with SO_REUSEADDR set. + * The former is just a wrapper around `async_io::Async`, of which we + * copy the `connect` method to add a statement that will set the socket flag. + * See https://github.com/smol-rs/async-net/issues/20. + */ +async fn connect_custom( + local_addr: &socket2::SockAddr, + dest_addr: &socket2::SockAddr, +) -> std::io::Result { + let socket = socket2::Socket::new(socket2::Domain::IPV6, socket2::Type::STREAM, None)?; + /* Set our custum options */ + set_socket_opts(&socket)?; + + socket.bind(local_addr)?; + + /* Initiate connect */ + match socket.connect(dest_addr) { + Ok(_) => {}, + #[cfg(unix)] + Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {}, + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {}, + Err(err) => return Err(err), + } + + let stream = async_io::Async::new(std::net::TcpStream::from(socket))?; + /* The stream becomes writable when connected. */ + stream.writable().await?; + + /* Check if there was an error while connecting. */ + stream + .get_ref() + .take_error() + .and_then(|maybe_err| maybe_err.map_or(Ok(()), Result::Err))?; + /* Convert our mess to `async_std::net::TcpStream */ + Ok(stream.into_inner()?.into()) +} + +#[derive(Debug, thiserror::Error)] +enum StunError { + #[error("No V4 addresses were found for the selected STUN server")] + ServerIsV4Only, + #[error("Connection timed out")] + Timeout, + #[error("IO error")] + IO( + #[from] + #[source] + std::io::Error, + ), + #[error("Malformed STUN packet")] + Codec( + #[from] + #[source] + bytecodec::Error, + ), +} + +/** Perform a STUN query to get the external IP address */ +async fn get_external_ip() -> Result<(std::net::SocketAddr, TcpStream), StunError> { + let mut socket = connect_custom( + &"[::]:0".parse::().unwrap().into(), + &PUBLIC_STUN_SERVER + .to_socket_addrs()? + /* If you find yourself behind a NAT66, open an issue */ + .find(|x| x.is_ipv4()) + /* TODO add a helper method to stdlib for this */ + .map(|addr| match addr { + std::net::SocketAddr::V4(v4) => std::net::SocketAddr::new( + std::net::IpAddr::V6(v4.ip().to_ipv6_mapped()), + v4.port(), + ), + std::net::SocketAddr::V6(_) => unreachable!(), + }) + .ok_or(StunError::ServerIsV4Only)? + .into(), + ) + .await?; + + use bytecodec::{DecodeExt, EncodeExt}; + use std::net::{SocketAddr, ToSocketAddrs}; + use stun_codec::{ + rfc5389::{ + self, + attributes::{MappedAddress, Software, XorMappedAddress}, + Attribute, + }, + Message, MessageClass, MessageDecoder, MessageEncoder, TransactionId, + }; + + fn get_binding_request() -> Result, bytecodec::Error> { + use rand::Rng; + let random_bytes = rand::thread_rng().gen::<[u8; 12]>(); + + let mut message = Message::new( + MessageClass::Request, + rfc5389::methods::BINDING, + TransactionId::new(random_bytes), + ); + + message.add_attribute(Attribute::Software(Software::new( + "magic-wormhole-rust".to_owned(), + )?)); + + // Encodes the message + let mut encoder = MessageEncoder::new(); + let bytes = encoder.encode_into_bytes(message.clone())?; + Ok(bytes) + } + + fn decode_address(buf: &[u8]) -> Result { + let mut decoder = MessageDecoder::::new(); + let decoded = decoder.decode_from_bytes(buf)??; + + println!("Decoded message: {:?}", decoded); + + let external_addr1 = decoded + .get_attribute::() + .map(|x| x.address()); + //let external_addr2 = decoded.get_attribute::().map(|x|x.address()); + let external_addr3 = decoded + .get_attribute::() + .map(|x| x.address()); + let external_addr = external_addr1 + // .or(external_addr2) + .or(external_addr3); + let external_addr = external_addr.unwrap(); + + Ok(external_addr) + } + + /* Connect the plugs */ + + socket.write_all(get_binding_request()?.as_ref()).await?; + + let mut buf = [0u8; 256]; + /* Read header first */ + socket.read_exact(&mut buf[..20]).await?; + let len: u16 = u16::from_be_bytes([buf[2], buf[3]]); + /* Read the rest of the message */ + socket.read_exact(&mut buf[20..][..len as usize]).await?; + let external_addr = decode_address(&buf[..20 + len as usize])?; + + Ok((external_addr, socket)) +} + /** * Initialize a relay handshake * - * Bind a port and generate our [`TransitType`]. This does not do any communication yet. + * Bind a port and generate our [`Hints`]. This does not do any communication yet. */ pub async fn init( abilities: Vec, relay_url: &RelayUrl, ) -> Result { - let listener = TcpListener::bind("[::]:0").await?; - let port = listener.local_addr()?.port(); + let mut our_hints = Hints::default(); + let mut listener = None; - let mut our_hints: Vec = Vec::new(); + /* Detect our IP addresses if the ability is enabled */ if abilities.contains(&Ability::DirectTcpV1) { - our_hints.extend( + /* Do a STUN query to get our public IP. If it works, we must reuse the same socket (port) + * so that we will be NATted to the same port again. If it doesn't, simply bind a new socket + * and use that instead. + */ + let socket: MaybeConnectedSocket = + match async_std::future::timeout(std::time::Duration::from_secs(4), get_external_ip()) + .await + .map_err(|_| StunError::Timeout) + { + Ok(Ok((external_ip, stream))) => { + log::debug!("Our external IP address is {}", external_ip); + our_hints.direct_tcp.insert(DirectHint { + hostname: external_ip.ip().to_string(), + port: external_ip.port(), + }); + stream.into() + }, + // TODO replace with .flatten() once stable + // https://github.com/rust-lang/rust/issues/70142 + Err(err) | Ok(Err(err)) => { + log::debug!("Failed to get external address via STUN, {}", err); + let socket = + socket2::Socket::new(socket2::Domain::IPV6, socket2::Type::STREAM, None) + .unwrap(); + set_socket_opts(&socket)?; + + socket + .bind(&"[::]:0".parse::().unwrap().into()) + .unwrap(); + + socket.into() + }, + }; + + /* Get a second socket, but this time open a listener on that port. + * This sadly doubles the number of hints, but the method above doesn't work + * for systems which don't have any firewalls. Also, this time we can't reuse + * the port. In theory, we could, but it really confused the kernel to the point + * of `accept` calls never returning again. + */ + let socket2 = TcpListener::bind("[::]:0").await?; + + /* Find our ports, iterate all our local addresses, combine them with the ports and that's our hints */ + let port = socket.local_addr()?.as_socket().unwrap().port(); + let port2 = socket2.local_addr()?.port(); + our_hints.direct_tcp.extend( get_if_addrs::get_if_addrs()? .iter() .filter(|iface| !iface.is_loopback()) - .map(|ip| { - Hint::DirectTcpV1(DirectHint { - priority: 0.0, - hostname: ip.ip().to_string(), - port, - }) - }), + .flat_map(|ip| + /* TODO replace with array once into_iter works as it should */ + vec![ + DirectHint { + hostname: ip.ip().to_string(), + port, + }, + DirectHint { + hostname: ip.ip().to_string(), + port: port2, + }, + ].into_iter()), ); + + listener = Some((socket, socket2)); } + if abilities.contains(&Ability::RelayV1) { - our_hints.push(Hint::new_relay(vec![DirectHint { - priority: 0.0, + our_hints.relay.insert(DirectHint { hostname: relay_url.host.clone(), port: relay_url.port, - }])); + }); } Ok(TransitConnector { - listener, - port, - our_side_ttype: Arc::new(TransitType { - abilities_v1: abilities, - hints_v1: our_hints, - }), + sockets: listener, + our_abilities: Arc::new(abilities), + our_hints: Arc::new(our_hints), }) } +#[derive(derive_more::From)] +enum MaybeConnectedSocket { + #[from] + Socket(socket2::Socket), + #[from] + Stream(TcpStream), +} + +impl MaybeConnectedSocket { + fn local_addr(&self) -> std::io::Result { + match &self { + Self::Socket(socket) => socket.local_addr(), + Self::Stream(stream) => Ok(stream.local_addr()?.into()), + } + } +} + /** * A partially set up [`Transit`] connection. * - * For the transit handshake, each side generates a [ttype](`TransitType`) with all the hints to find the other. You need - * to exchange it (as in: send yours, receive theirs) with them. This is outside of the transit protocol to be protocol - * agnostic. + * For the transit handshake, each side generates a [`Hints`] with all the information to find the other. You need + * to exchange it (as in: send yours, receive theirs) with them. This is outside of the transit protocol, because we + * are protocol agnostic. */ pub struct TransitConnector { - listener: TcpListener, - port: u16, - our_side_ttype: Arc, + /* Only `Some` if direct-tcp-v1 ability has been enabled. + * The first socket is the port from which we will start connection attempts. + * For in case the user is behind no firewalls, we must also listen to the second socket. + */ + sockets: Option<(MaybeConnectedSocket, TcpListener)>, + our_abilities: Arc>, + our_hints: Arc, } impl TransitConnector { + pub fn our_abilities(&self) -> &Arc> { + &self.our_abilities + } + /** Send this one to the other side */ - pub fn our_side_ttype(&self) -> &Arc { - &self.our_side_ttype + pub fn our_hints(&self) -> &Arc { + &self.our_hints } /** @@ -268,110 +545,89 @@ impl TransitConnector { pub async fn leader_connect( self, transit_key: Key, - other_side_ttype: TransitType, + their_abilities: Arc>, + their_hints: Arc, ) -> Result { + let Self { + sockets, + our_abilities, + our_hints, + } = self; let transit_key = Arc::new(transit_key); - /* TODO This Deref thing is getting out of hand. Maybe implementing AsRef or some other trait may help? */ - debug!("transit key {}", hex::encode(&***transit_key)); - - let port = self.port; - let listener = self.listener; - // let other_side_ttype = Arc::new(other_side_ttype); - // TODO remove this one day - let ttype = &*Box::leak(Box::new(other_side_ttype)); - - // 8. listen for connections on the port and simultaneously try connecting to the peer port. - // extract peer's ip/hostname from 'ttype' - let (mut direct_hosts, mut relay_hosts) = get_direct_relay_hosts(ttype); - - let mut hosts: Vec<(HostType, &DirectHint)> = Vec::new(); - hosts.append(&mut direct_hosts); - hosts.append(&mut relay_hosts); - // TODO: combine our relay hints with the peer's relay hints. - - let mut handshake_futures = Vec::new(); - for host in hosts { - // TODO use async scopes to borrow instead of cloning one day - let transit_key = transit_key.clone(); - let future = async_std::task::spawn( - //async_std::future::timeout(Duration::from_secs(5), - async move { - debug!("host: {:?}", host); - let mut direct_host_iter = format!("{}:{}", host.1.hostname, host.1.port) - .to_socket_addrs() - .unwrap(); - let direct_host = direct_host_iter.next().unwrap(); - - debug!("peer host: {}", direct_host); - - TcpStream::connect(direct_host) - .err_into::() - .and_then(|socket| leader_handshake_exchange(socket, host.0, &*transit_key)) - .await - }, - ); //); - handshake_futures.push(future); - } - handshake_futures.push(async_std::task::spawn(async move { - debug!("local host {}", port); - /* Mixing and matching two different futures library probably isn't the - * best idea, but here we are. Simply be careful about prelude::* imports - * and don't have both StreamExt/FutureExt/… imported at once - */ - use futures::stream::TryStreamExt; - async_std::stream::StreamExt::skip_while(listener.incoming() - .err_into::() - .and_then(move |socket| { - /* Pinning a future + moving some value from outer scope is a bit painful */ - let transit_key = transit_key.clone(); - Box::pin(async move { - leader_handshake_exchange(socket, HostType::Direct, &*transit_key).await - }) - }), - Result::is_err) - /* We only care about the first that succeeds */ - .next() - .await - /* Next always returns Some because Incoming is an infinite stream. We gotta succeed _sometime_. */ - .unwrap() - })); + let start = std::time::Instant::now(); + let mut connection_stream = Box::pin( + Self::connect( + true, + transit_key, + our_abilities.clone(), + our_hints, + their_abilities, + their_hints, + sockets, + ) + .filter_map(|result| async { + match result { + Ok(val) => Some(val), + Err(err) => { + log::debug!("Some leader handshake failed: {:?}", err); + None + }, + } + }), + ); - /* Try to get a Transit out of the first handshake that succeeds. If all fail, - * we fail. - */ - let transit; - loop { - ensure!( - !handshake_futures.is_empty(), - TransitConnectError::Handshake + let (mut transit, host_type) = async_std::future::timeout( + std::time::Duration::from_secs(60), + connection_stream.next(), + ) + .await + .map_err(|_| { + log::debug!("`leader_connect` timed out"); + TransitConnectError::Handshake + })? + .ok_or(TransitConnectError::Handshake)?; + + if host_type == HostType::Relay && our_abilities.contains(&Ability::DirectTcpV1) { + log::debug!( + "Established transit connection over relay. Trying to find a direct connection …" ); - - match futures::future::select_all(handshake_futures).await { - (Ok(transit2), _index, remaining) => { - transit = transit2; - handshake_futures = remaining; - break; - }, - (Err(e), _index, remaining) => { - debug!("Some handshake failed {:#}", e); - handshake_futures = remaining; - }, - } + /* Measure the time it took us to get a response. Based on this, wait some more for more responses + * in case we like one better. + */ + let elapsed = start.elapsed(); + let to_wait = if elapsed.as_secs() > 5 { + /* If our RTT was *that* long, let's just be happy we even got one connection */ + std::time::Duration::from_secs(1) + } else { + elapsed.mul_f32(0.3) + }; + let _ = async_std::future::timeout(to_wait, async { + while let Some((new_transit, new_host_type)) = connection_stream.next().await { + /* We already got a connection, so we're only interested in direct ones */ + if new_host_type == HostType::Direct { + transit = new_transit; + log::debug!("Found direct connection; using that instead."); + break; + } + } + }) + .await; + log::debug!("Did not manage to establish a better connection in time."); + } else { + log::debug!("Established direct transit connection"); } - let mut transit = transit; - /* Cancel all remaining non-finished handshakes */ - handshake_futures - .into_iter() - .map(async_std::task::JoinHandle::cancel) - .for_each(std::mem::drop); + /* Cancel all remaining non-finished handshakes. We could send "nevermind" to explicitly tell + * the other side (probably, this is mostly for relay server statistics), but eeh, nevermind :) + */ + std::mem::drop(connection_stream); - debug!( - "Sending 'go' message to {}", + transit.socket.write_all(b"go\n").await?; + info!( + "Established transit connection to '{}'", transit.socket.peer_addr().unwrap() ); - transit.socket.write_all(b"go\n").await?; Ok(transit) } @@ -382,107 +638,225 @@ impl TransitConnector { pub async fn follower_connect( self, transit_key: Key, - other_side_ttype: Arc, + their_abilities: Arc>, + their_hints: Arc, ) -> Result { + let Self { + sockets, + our_abilities, + our_hints, + } = self; let transit_key = Arc::new(transit_key); - /* TODO This Deref thing is getting out of hand. Maybe implementing AsRef or some other trait may help? */ - debug!("transit key {}", hex::encode(&***transit_key)); - - let port = self.port; - let listener = self.listener; - // let other_side_ttype = Arc::new(other_side_ttype); - let ttype = &*Box::leak(Box::new(other_side_ttype)); // TODO remove this one day - - // 4. listen for connections on the port and simultaneously try connecting to the - // peer listening port. - let (mut direct_hosts, mut relay_hosts) = get_direct_relay_hosts(ttype); - - let mut hosts: Vec<(HostType, &DirectHint)> = Vec::new(); - hosts.append(&mut direct_hosts); - hosts.append(&mut relay_hosts); - // TODO: combine our relay hints with the peer's relay hints. - - let mut handshake_futures = Vec::new(); - for host in hosts { - let transit_key = transit_key.clone(); - - let future = async_std::task::spawn( - //async_std::future::timeout(Duration::from_secs(5), - async move { - debug!("host: {:?}", host); - let mut direct_host_iter = format!("{}:{}", host.1.hostname, host.1.port) - .to_socket_addrs() - .unwrap(); - let direct_host = direct_host_iter.next().unwrap(); - debug!("peer host: {}", direct_host); + let mut connection_stream = Box::pin( + Self::connect( + false, + transit_key, + our_abilities, + our_hints, + their_abilities, + their_hints, + sockets, + ) + .filter_map(|result| async { + match result { + Ok(val) => Some(val), + Err(err) => { + log::debug!("Some follower handshake failed: {:?}", err); + None + }, + } + }), + ); - TcpStream::connect(direct_host) - .err_into::() - .and_then(|socket| { - follower_handshake_exchange(socket, host.0, &*transit_key) - }) - .await - }, - ); //); - handshake_futures.push(future); - } - handshake_futures.push(async_std::task::spawn(async move { - debug!("local host {}", port); + let transit = match async_std::future::timeout( + std::time::Duration::from_secs(60), + &mut connection_stream.next(), + ) + .await + { + Ok(Some((transit, host_type))) => { + log::debug!( + "Established a {} transit connection.", + if host_type == HostType::Direct { + "direct" + } else { + "relay" + } + ); + Ok(transit) + }, + Ok(None) | Err(_) => { + log::debug!("`follower_connect` timed out"); + Err(TransitConnectError::Handshake) + }, + }; - /* Mixing and matching two different futures library probably isn't the - * best idea, but here we are. Simply be careful about prelude::* imports - * and don't have both StreamExt/FutureExt/… imported at once - */ - use futures::stream::TryStreamExt; - async_std::stream::StreamExt::skip_while(listener.incoming() - .err_into::() - .and_then(move |socket| { - /* Pinning a future + moving some value from outer scope is a bit painful */ - let transit_key = transit_key.clone(); - use futures::future::FutureExt; - async move { - follower_handshake_exchange(socket, HostType::Direct, &*transit_key).await - }.boxed() - }), - Result::is_err) - /* We only care about the first that succeeds */ - .next() - .await - /* Next always returns Some because Incoming is an infinite stream. We gotta succeed _sometime_. */ - .unwrap() - })); + /* Cancel all remaining non-finished handshakes. We could send "nevermind" to explicitly tell + * the other side (probably, this is mostly for relay server statistics), but eeh, nevermind :) + */ + std::mem::drop(connection_stream); + + transit + } + + /** Try to establish a connection with the peer. + * + * This encapsulates code that is common to both the leader and the follower. + * + * ## Panics + * + * If the receiving end of the channel for the results is closed before all futures in the return + * value are cancelled/dropped. + */ + fn connect( + is_leader: bool, + transit_key: Arc>, + our_abilities: Arc>, + our_hints: Arc, + their_abilities: Arc>, + their_hints: Arc, + socket: Option<(MaybeConnectedSocket, TcpListener)>, + ) -> impl Stream> + 'static { + assert!(socket.is_some() == our_abilities.contains(&Ability::DirectTcpV1)); + + // 8. listen for connections on the port and simultaneously try connecting to the peer port. + let tside = Arc::new(hex::encode(rand::random::<[u8; 8]>())); - /* Try to get a Transit out of the first handshake that succeeds. If all fail, - * we fail. + /* Iterator of futures yielding a connection. They'll be then mapped with the handshake, collected into + * a Vec and polled concurrently. */ - let transit; - loop { - ensure!( - !handshake_futures.is_empty(), - TransitConnectError::Handshake - ); + use futures::future::BoxFuture; + type BoxIterator = Box>; + type ConnectorFuture = + BoxFuture<'static, Result<(TcpStream, HostType), TransitHandshakeError>>; + let mut connectors: BoxIterator = Box::new(std::iter::empty()); + + /* Create direct connection sockets, if we support it. If peer doesn't support it, their list of hints will + * be empty and no entries will be pushed. + */ + let socket2 = if let Some((socket, socket2)) = socket { + let local_addr = Arc::new(socket.local_addr().unwrap()); + /* Connect to each hint of the peer */ + connectors = Box::new( + connectors.chain( + their_hints + .direct_tcp + .clone() + .into_iter() + /* Nobody should have that many IP addresses, even with NATing */ + .take(10) + .map(move |hint| { + let local_addr = local_addr.clone(); + async move { + let dest_addr = std::net::SocketAddr::try_from(&hint)?; + log::debug!("Connecting directly to {}", dest_addr); + let socket = connect_custom(&local_addr, &dest_addr.into()).await?; + log::debug!("Connected to {}!", dest_addr); + Ok((socket, HostType::Direct)) + } + }) + .map(|fut| Box::pin(fut) as ConnectorFuture), + ), + ) as BoxIterator; + Some(socket2) + } else { + None + }; - match futures::future::select_all(handshake_futures).await { - (Ok(transit2), _index, remaining) => { - transit = transit2; - handshake_futures = remaining; - break; - }, - (Err(e), _index, remaining) => { - debug!("Some handshake failed {:#}", e); - handshake_futures = remaining; - }, - } + /* Relay hints. Make sure that both sides adverize it, since it is fine to support it without providing own hints. */ + if our_abilities.contains(&Ability::RelayV1) && their_abilities.contains(&Ability::RelayV1) + { + /* Collect intermediate into HashSet for deduplication */ + let relay_hints = our_hints + .relay + .clone() + .into_iter() + .take(2) + .chain(their_hints.relay.clone().into_iter().take(2)) + .collect::>(); + connectors = Box::new( + connectors.chain( + relay_hints + .into_iter() + .map(|host| async move { + log::debug!("Connecting to relay {}", host); + let transit = TcpStream::connect((host.hostname.as_str(), host.port)) + .err_into::() + .await?; + log::debug!("Connected to {}!", host); + + Ok((transit, HostType::Relay)) + }) + .map(|fut| Box::pin(fut) as ConnectorFuture), + ), + ) as BoxIterator; } - /* Cancel all remaining non-finished handshakes */ - handshake_futures - .into_iter() - .map(async_std::task::JoinHandle::cancel) - .for_each(std::mem::drop); - - Ok(transit) + /* Do a handshake on all our found connections */ + let transit_key2 = transit_key.clone(); + let tside2 = tside.clone(); + let mut connectors = Box::new( + connectors + .map(move |fut| { + let transit_key = transit_key2.clone(); + let tside = tside2.clone(); + async move { + let (socket, host_type) = fut.await?; + let transit = + handshake_exchange(is_leader, tside, socket, host_type, transit_key) + .await?; + Ok((transit, host_type)) + } + }) + .map(|fut| { + Box::pin(fut) as BoxFuture> + }), + ) + as BoxIterator>>; + + /* Also listen on some port just in case. */ + if let Some(socket2) = socket2 { + connectors = Box::new( + connectors.chain( + std::iter::once(async move { + let transit_key = transit_key.clone(); + let tside = tside.clone(); + let connect = || async { + let (stream, peer) = socket2.accept().await?; + log::debug!("Got connection from {}!", peer); + let transit = handshake_exchange( + is_leader, + tside.clone(), + stream, + HostType::Direct, + transit_key.clone(), + ) + .await?; + Result::<_, TransitHandshakeError>::Ok((transit, HostType::Direct)) + }; + loop { + match connect().await { + Ok(success) => break Ok(success), + Err(err) => { + log::debug!( + "Some handshake failed on the listening port: {:?}", + err + ); + continue; + }, + } + } + }) + .map(|fut| { + Box::pin(fut) + as BoxFuture> + }), + ), + ) + as BoxIterator>>; + } + connectors.collect::>() } } @@ -494,7 +868,7 @@ impl TransitConnector { */ pub struct Transit { /** Raw transit connection */ - pub socket: TcpStream, + socket: TcpStream, /** Our key, used for sending */ pub skey: Key, /** Their key, used for receiving */ @@ -530,7 +904,12 @@ impl Transit { // 2. read that many bytes into an array (or a vector?) let mut buffer = Vec::with_capacity(length); - socket.take(length as u64).read_to_end(&mut buffer).await?; + let len = socket.take(length as u64).read_to_end(&mut buffer).await?; + use std::io::{Error, ErrorKind}; + ensure!( + len == length, + Error::new(ErrorKind::UnexpectedEof, "failed to read whole message") + ); buffer }; @@ -628,40 +1007,43 @@ impl Transit { } } -fn generate_transit_side() -> String { - let x: [u8; 8] = rand::random(); - hex::encode(x) -} - -fn make_relay_handshake(key: &Key, tside: &str) -> String { - let sub_key = key.derive_subkey_from_purpose::("transit_relay_token"); - format!( - "please relay {} for side {}\n", - hex::encode(&**sub_key), - tside - ) -} - -async fn follower_handshake_exchange( +/** + * Do a transit handshake exchange, to establish a direct connection. + * + * This automatically does the relay handshake first if necessary. On the follower + * side, the future will successfully run to completion if a connection could be + * established. On the leader side, the handshake is not 100% completed: the caller + * must write `Ok\n` into the stream that should be used (and optionally `Nevermind\n` + * into all others). + */ +async fn handshake_exchange( + is_leader: bool, + tside: Arc, mut socket: TcpStream, host_type: HostType, - key: &Key, + key: Arc>, ) -> Result { - // create record keys - /* The order here is correct. The "sender" and "receiver" side are a misnomer and should be called - * "leader" and "follower" instead. As a follower, we use the leader key for receiving and our - * key for sending. - */ - let rkey = key.derive_subkey_from_purpose("transit_record_sender_key"); - let skey = key.derive_subkey_from_purpose("transit_record_receiver_key"); - - // exchange handshake - let tside = generate_transit_side(); + // 9. create record keys + let (rkey, skey) = if is_leader { + let rkey = key.derive_subkey_from_purpose("transit_record_receiver_key"); + let skey = key.derive_subkey_from_purpose("transit_record_sender_key"); + (rkey, skey) + } else { + /* The order here is correct. The "sender" and "receiver" side are a misnomer and should be called + * "leader" and "follower" instead. As a follower, we use the leader key for receiving and our + * key for sending. + */ + let rkey = key.derive_subkey_from_purpose("transit_record_sender_key"); + let skey = key.derive_subkey_from_purpose("transit_record_receiver_key"); + (rkey, skey) + }; if host_type == HostType::Relay { trace!("initiating relay handshake"); + + let sub_key = key.derive_subkey_from_purpose::("transit_relay_token"); socket - .write_all(make_relay_handshake(key, &tside).as_bytes()) + .write_all(format!("please relay {} for side {}\n", sub_key.to_hex(), tside).as_bytes()) .await?; let mut rx = [0u8; 3]; socket.read_exact(&mut rx).await?; @@ -669,96 +1051,62 @@ async fn follower_handshake_exchange( ensure!(ok_msg == rx, TransitHandshakeError::RelayHandshakeFailed); } - { - // for receive mode, send receive_handshake_msg and compare. + if is_leader { + // for transmit mode, send send_handshake_msg and compare. // the received message with send_handshake_msg - socket .write_all( format!( - "transit receiver {} ready\n\n", - hex::encode( - &**key.derive_subkey_from_purpose::("transit_receiver") - ) + "transit sender {} ready\n\n", + key.derive_subkey_from_purpose::("transit_sender") + .to_hex() ) .as_bytes(), ) .await?; - // The received message "transit receiver $hash ready\n\n" has exactly 87 bytes - // Three bytes for the "go\n" ack + // The received message "transit sender $hash ready\n\n" has exactly 89 bytes // TODO do proper line parsing one day, this is atrocious - let mut rx: [u8; 90] = [0; 90]; + let mut rx: [u8; 89] = [0; 89]; socket.read_exact(&mut rx).await?; - let expected_tx_handshake = format!( - "transit sender {} ready\n\ngo\n", - hex::encode(&**key.derive_subkey_from_purpose::("transit_sender")) + let expected_rx_handshake = format!( + "transit receiver {} ready\n\n", + key.derive_subkey_from_purpose::("transit_receiver") + .to_hex() ); ensure!( - &rx[..] == expected_tx_handshake.as_bytes(), - TransitHandshakeError::HandshakeFailed + &rx[..] == expected_rx_handshake.as_bytes(), + TransitHandshakeError::HandshakeFailed, ); - } - - Ok(Transit { - socket, - skey, - rkey, - snonce: Default::default(), - rnonce: Default::default(), - }) -} - -async fn leader_handshake_exchange( - mut socket: TcpStream, - host_type: HostType, - key: &Key, -) -> Result { - // 9. create record keys - let skey = key.derive_subkey_from_purpose("transit_record_sender_key"); - let rkey = key.derive_subkey_from_purpose("transit_record_receiver_key"); - - // 10. exchange handshake over tcp - let tside = generate_transit_side(); - - if host_type == HostType::Relay { - socket - .write_all(make_relay_handshake(key, &tside).as_bytes()) - .await?; - let mut rx = [0u8; 3]; - socket.read_exact(&mut rx).await?; - let ok_msg: [u8; 3] = *b"ok\n"; - ensure!(ok_msg == rx, TransitHandshakeError::RelayHandshakeFailed); - } - - { - // for transmit mode, send send_handshake_msg and compare. + } else { + // for receive mode, send receive_handshake_msg and compare. // the received message with send_handshake_msg socket .write_all( format!( - "transit sender {} ready\n\n", - hex::encode( - &**key.derive_subkey_from_purpose::("transit_sender") - ) + "transit receiver {} ready\n\n", + key.derive_subkey_from_purpose::("transit_receiver") + .to_hex(), ) .as_bytes(), ) .await?; - // The received message "transit sender $hash ready\n\n" has exactly 89 bytes + // The received message "transit receiver $hash ready\n\n" has exactly 87 bytes + // Three bytes for the "go\n" ack // TODO do proper line parsing one day, this is atrocious - let mut rx: [u8; 89] = [0; 89]; + let mut rx: [u8; 90] = [0; 90]; socket.read_exact(&mut rx).await?; - let expected_rx_handshake = format!( - "transit receiver {} ready\n\n", - hex::encode(&**key.derive_subkey_from_purpose::("transit_receiver")) + let expected_tx_handshake = format!( + "transit sender {} ready\n\ngo\n", + key.derive_subkey_from_purpose::("transit_sender") + .to_hex(), ); ensure!( - &rx[..] == expected_rx_handshake.as_bytes(), - TransitHandshakeError::HandshakeFailed, + &rx[..] == expected_tx_handshake.as_bytes(), + TransitHandshakeError::HandshakeFailed ); } @@ -770,60 +1118,3 @@ async fn leader_handshake_exchange( rnonce: Default::default(), }) } - -#[allow(clippy::type_complexity)] -fn get_direct_relay_hosts<'a, 'b: 'a>( - ttype: &'b TransitType, -) -> ( - Vec<(HostType, &'a DirectHint)>, - Vec<(HostType, &'a DirectHint)>, -) { - let direct_hosts: Vec<(HostType, &DirectHint)> = ttype - .hints_v1 - .iter() - .filter_map(|hint| match hint { - Hint::DirectTcpV1(dt) => Some((HostType::Direct, dt)), - _ => None, - }) - .collect(); - let relay_hosts_list: Vec<&Vec> = ttype - .hints_v1 - .iter() - .filter_map(|hint| match hint { - Hint::RelayV1(rt) => Some(&rt.hints), - _ => None, - }) - .collect(); - - let _hosts: Vec<(HostType, &DirectHint)> = Vec::new(); - let maybe_relay_hosts = relay_hosts_list.first(); - let relay_hosts: Vec<(HostType, &DirectHint)> = match maybe_relay_hosts { - Some(relay_host_vec) => relay_host_vec - .iter() - .map(|host| (HostType::Relay, host)) - .collect(), - None => vec![], - }; - - (direct_hosts, relay_hosts) -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_transit() { - let abilities = vec![Ability::DirectTcpV1, Ability::RelayV1]; - let hints = vec![ - Hint::new_direct(0.0, "192.168.1.8", 46295), - Hint::new_relay(vec![DirectHint { - priority: 2.0, - hostname: "magic-wormhole-transit.debian.net".to_string(), - port: 4001, - }]), - ]; - let t = crate::transfer::PeerMessage::new_transit(abilities, hints); - assert_eq!(t.serialize(), "{\"transit\":{\"abilities-v1\":[{\"type\":\"direct-tcp-v1\"},{\"type\":\"relay-v1\"}],\"hints-v1\":[{\"hostname\":\"192.168.1.8\",\"port\":46295,\"priority\":0.0,\"type\":\"direct-tcp-v1\"},{\"hints\":[{\"hostname\":\"magic-wormhole-transit.debian.net\",\"port\":4001,\"priority\":2.0}],\"type\":\"relay-v1\"}]}}") - } -}