From 78e621b13bb812424f7fcc58312304a4c1a8de03 Mon Sep 17 00:00:00 2001 From: piegames Date: Sat, 14 Aug 2021 22:17:55 +0200 Subject: [PATCH 1/9] Transit refactor I: small code cleanup --- src/core/key.rs | 9 +++++++-- src/transit.rs | 20 +++++++++----------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/core/key.rs b/src/core/key.rs index 1535e39a..1e5fdbbe 100644 --- a/src/core/key.rs +++ b/src/core/key.rs @@ -53,8 +53,8 @@ impl Key { let derived_key = self.derive_subkey_from_purpose(&transit_purpose); trace!( "Input key: {}, Transit key: {}, Transit purpose: '{}'", - hex::encode(&**self), - hex::encode(&**derived_key), + self.to_hex(), + derived_key.to_hex(), &transit_purpose ); derived_key @@ -65,6 +65,11 @@ impl Key

{ pub fn new(key: Box) -> Self { Self(key, std::marker::PhantomData) } + + pub fn to_hex(&self) -> String { + hex::encode(&**self) + } + /** * Derive a new sub-key from this one */ diff --git a/src/transit.rs b/src/transit.rs index 2522d4d1..9411ba14 100644 --- a/src/transit.rs +++ b/src/transit.rs @@ -114,6 +114,7 @@ pub struct TransitType { */ #[derive(Serialize, Deserialize, Debug, PartialEq)] #[serde(rename_all = "kebab-case", tag = "type")] +#[non_exhaustive] pub enum Ability { /** * Try to connect directly to the other side. @@ -137,6 +138,7 @@ impl Ability { #[derive(Serialize, Deserialize, Debug, PartialEq)] #[serde(rename_all = "kebab-case", tag = "type")] +#[non_exhaustive] pub enum Hint { DirectTcpV1(DirectHint), RelayV1(RelayHint), @@ -272,7 +274,7 @@ impl TransitConnector { ) -> Result { let transit_key = Arc::new(transit_key); /* TODO This Deref thing is getting out of hand. Maybe implementing AsRef or some other trait may help? */ - debug!("transit key {}", hex::encode(&***transit_key)); + debug!("transit key {}", transit_key.to_hex()); let port = self.port; let listener = self.listener; @@ -386,7 +388,7 @@ impl TransitConnector { ) -> Result { let transit_key = Arc::new(transit_key); /* TODO This Deref thing is getting out of hand. Maybe implementing AsRef or some other trait may help? */ - debug!("transit key {}", hex::encode(&***transit_key)); + debug!("transit key {}", transit_key.to_hex()); let port = self.port; let listener = self.listener; @@ -637,7 +639,7 @@ fn make_relay_handshake(key: &Key, tside: &str) -> String { let sub_key = key.derive_subkey_from_purpose::("transit_relay_token"); format!( "please relay {} for side {}\n", - hex::encode(&**sub_key), + sub_key.to_hex(), tside ) } @@ -677,9 +679,7 @@ async fn follower_handshake_exchange( .write_all( format!( "transit receiver {} ready\n\n", - hex::encode( - &**key.derive_subkey_from_purpose::("transit_receiver") - ) + key.derive_subkey_from_purpose::("transit_receiver").to_hex(), ) .as_bytes(), ) @@ -693,7 +693,7 @@ async fn follower_handshake_exchange( let expected_tx_handshake = format!( "transit sender {} ready\n\ngo\n", - hex::encode(&**key.derive_subkey_from_purpose::("transit_sender")) + key.derive_subkey_from_purpose::("transit_sender").to_hex(), ); ensure!( &rx[..] == expected_tx_handshake.as_bytes(), @@ -739,9 +739,7 @@ async fn leader_handshake_exchange( .write_all( format!( "transit sender {} ready\n\n", - hex::encode( - &**key.derive_subkey_from_purpose::("transit_sender") - ) + key.derive_subkey_from_purpose::("transit_sender").to_hex() ) .as_bytes(), ) @@ -754,7 +752,7 @@ async fn leader_handshake_exchange( let expected_rx_handshake = format!( "transit receiver {} ready\n\n", - hex::encode(&**key.derive_subkey_from_purpose::("transit_receiver")) + key.derive_subkey_from_purpose::("transit_receiver").to_hex() ); ensure!( &rx[..] == expected_rx_handshake.as_bytes(), From dad9dc142607f38c4aaaf98c3a5272f1d6ae91f9 Mon Sep 17 00:00:00 2001 From: piegames Date: Sun, 22 Aug 2021 14:21:08 +0200 Subject: [PATCH 2/9] Transit refactor II: refactoring This is mainly a preparation for what is to come next - Factored out redundancy (duplicate code) between leader and follower handshake - Made the interface depend less on how `transfer` implements it - Thus, some of the structs were moved into transfer::messages, with added conversion - Some bits of preparation for UDT support --- src/core/test.rs | 12 +- src/transfer.rs | 295 +++++++++----------- src/transfer/messages.rs | 207 ++++++++++++++ src/transit.rs | 589 ++++++++++++++++----------------------- 4 files changed, 591 insertions(+), 512 deletions(-) create mode 100644 src/transfer/messages.rs diff --git a/src/core/test.rs b/src/core/test.rs index 37bcfa7f..27df24a9 100644 --- a/src/core/test.rs +++ b/src/core/test.rs @@ -47,9 +47,7 @@ pub async fn test_file_rust2rust() -> eyre::Result<()> { std::fs::metadata("examples/example-file.bin") .unwrap() .len(), - |sent, total| { - log::info!("Sent {} of {} bytes", sent, total); - }, + |_sent, _total| {}, ) .await?, ) @@ -72,13 +70,7 @@ pub async fn test_file_rust2rust() -> eyre::Result<()> { .await?; let mut buffer = Vec::::new(); - req.accept( - |received, total| { - log::info!("Received {} of {} bytes", received, total); - }, - &mut buffer, - ) - .await?; + req.accept(|_received, _total| {}, &mut buffer).await?; Ok(buffer) })?; diff --git a/src/transfer.rs b/src/transfer.rs index 30972825..1f9ea50c 100644 --- a/src/transfer.rs +++ b/src/transfer.rs @@ -5,7 +5,7 @@ //! It is bound to an [`APPID`](APPID). Only applications using that APPID (and thus this protocol) can interoperate with //! the original Python implementation (and other compliant implementations). //! -//! At its core, [`PeerMessage`s](PeerMessage) are exchanged over an established wormhole connection with the other side. +//! At its core, "peer messages" are exchanged over an established wormhole connection with the other side. //! They are used to set up a [transit] portal and to exchange a file offer/accept. Then, the file is transmitted over the transit relay. use futures::{AsyncRead, AsyncWrite}; @@ -26,6 +26,9 @@ use sha2::{digest::FixedOutput, Digest, Sha256}; use std::path::PathBuf; use transit::{TransitConnectError, TransitConnector, TransitError}; +mod messages; +use messages::*; + const APPID_RAW: &str = "lothar.com/wormhole/text-or-file-xfer"; /// The App ID associated with this protocol. @@ -156,105 +159,6 @@ impl TransitAck { } } -/** - * The type of message exchanged over the wormhole for this protocol - */ -#[derive(Deserialize, Serialize, Debug, PartialEq)] -#[serde(rename_all = "kebab-case")] -pub enum PeerMessage { - Offer(OfferType), - Answer(AnswerType), - /** Tell the other side you got an error */ - Error(String), - /** Used to set up a transit channel */ - Transit(Arc), - #[serde(other)] - Unknown, -} - -impl PeerMessage { - pub fn new_offer_message(msg: impl Into) -> Self { - PeerMessage::Offer(OfferType::Message(msg.into())) - } - - pub fn new_offer_file(name: impl Into, size: u64) -> Self { - PeerMessage::Offer(OfferType::File { - filename: name.into(), - filesize: size, - }) - } - - pub fn new_offer_directory( - name: impl Into, - mode: impl Into, - compressed_size: u64, - numbytes: u64, - numfiles: u64, - ) -> Self { - PeerMessage::Offer(OfferType::Directory { - dirname: name.into(), - mode: mode.into(), - zipsize: compressed_size, - numbytes, - numfiles, - }) - } - - pub fn new_message_ack(msg: impl Into) -> Self { - PeerMessage::Answer(AnswerType::MessageAck(msg.into())) - } - - pub fn new_file_ack(msg: impl Into) -> Self { - PeerMessage::Answer(AnswerType::FileAck(msg.into())) - } - - pub fn new_error_message(msg: impl Into) -> Self { - PeerMessage::Error(msg.into()) - } - - pub fn new_transit(abilities: Vec, hints: Vec) -> Self { - PeerMessage::Transit(Arc::new(transit::TransitType { - abilities_v1: abilities, - hints_v1: hints, - })) - } - - #[cfg(test)] - pub fn serialize(&self) -> String { - json!(self).to_string() - } - - pub fn serialize_vec(&self) -> Vec { - serde_json::to_vec(self).unwrap() - } -} - -#[derive(Serialize, Deserialize, Debug, PartialEq)] -#[serde(rename_all = "kebab-case")] -pub enum OfferType { - Message(String), - File { - filename: PathBuf, - filesize: u64, - }, - Directory { - dirname: PathBuf, - mode: String, - zipsize: u64, - numbytes: u64, - numfiles: u64, - }, - #[serde(other)] - Unknown, -} - -#[derive(Serialize, Deserialize, Debug, PartialEq)] -#[serde(rename_all = "snake_case")] -pub enum AnswerType { - MessageAck(String), - FileAck(String), -} - pub async fn send_file_or_folder( wormhole: &mut Wormhole, relay_url: &RelayUrl, @@ -310,9 +214,15 @@ where let connector = transit::init(transit::Ability::all_abilities(), relay_url).await?; // We want to do some transit - debug!("Sending transit message '{:?}", connector.our_side_ttype()); + debug!("Sending transit message '{:?}", connector.our_hints()); wormhole - .send(PeerMessage::Transit(connector.our_side_ttype().clone()).serialize_vec()) + .send( + PeerMessage::new_transit( + connector.our_abilities().to_vec(), + (**connector.our_hints()).clone().into(), + ) + .serialize_vec(), + ) .await?; // Send file offer message. @@ -322,24 +232,23 @@ where .await?; // Wait for their transit response - let other_side_ttype = { - let maybe_transit = serde_json::from_slice(&wormhole.receive().await?)?; - debug!("received transit message: {:?}", maybe_transit); - - match maybe_transit { - PeerMessage::Transit(tmsg) => tmsg, + let (their_abilities, their_hints): (Vec, transit::Hints) = + match serde_json::from_slice(&wormhole.receive().await?)? { + PeerMessage::Transit(transit) => { + debug!("received transit message: {:?}", transit); + (transit.abilities_v1, transit.hints_v1.into()) + }, PeerMessage::Error(err) => { bail!(TransferError::PeerError(err)); }, - _ => { - let error = TransferError::unexpected_message("transit", maybe_transit); + other => { + let error = TransferError::unexpected_message("transit", other); let _ = wormhole .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) .await; bail!(error) }, - } - }; + }; { // Wait for file_ack @@ -363,17 +272,35 @@ where } } - let mut transit = connector + let mut transit = match connector .leader_connect( wormhole.key().derive_transit_key(wormhole.appid()), - Arc::try_unwrap(other_side_ttype).unwrap(), + Arc::new(their_hints), ) - .await?; + .await + { + Ok(transit) => transit, + Err(error) => { + let error = TransferError::TransitConnect(error); + let _ = wormhole + .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) + .await; + return Err(error); + }, + }; debug!("Beginning file transfer"); // 11. send the file as encrypted records. - let checksum = send_records(&mut transit, file, file_size, progress_handler).await?; + let checksum = match send_records(&mut transit, file, file_size, progress_handler).await { + Err(TransferError::Transit(error)) => { + let _ = wormhole + .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) + .await; + Err(TransferError::Transit(error)) + }, + other => other, + }?; // 13. wait for the transit ack with sha256 sum from the peer. debug!("sent file. Waiting for ack"); @@ -415,9 +342,15 @@ where } // We want to do some transit - debug!("Sending transit message '{:?}", connector.our_side_ttype()); + debug!("Sending transit message '{:?}", connector.our_hints()); wormhole - .send(PeerMessage::Transit(connector.our_side_ttype().clone()).serialize_vec()) + .send( + PeerMessage::new_transit( + connector.our_abilities().to_vec(), + (**connector.our_hints()).clone().into(), + ) + .serialize_vec(), + ) .await?; use tar::Builder; @@ -478,21 +411,23 @@ where .await?; // Wait for their transit response - let other_side_ttype = { - let maybe_transit = serde_json::from_slice(&wormhole.receive().await?)?; - debug!("received transit message: {:?}", maybe_transit); - - match maybe_transit { - PeerMessage::Transit(tmsg) => tmsg, - _ => { - let error = TransferError::unexpected_message("transit", maybe_transit); + let (their_abilities, their_hints): (Vec, transit::Hints) = + match serde_json::from_slice(&wormhole.receive().await?)? { + PeerMessage::Transit(transit) => { + debug!("received transit message: {:?}", transit); + (transit.abilities_v1, transit.hints_v1.into()) + }, + PeerMessage::Error(err) => { + bail!(TransferError::PeerError(err)); + }, + other => { + let error = TransferError::unexpected_message("transit", other); let _ = wormhole .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) .await; bail!(error) }, - } - }; + }; { // Wait for file_ack @@ -516,12 +451,22 @@ where } } - let mut transit = connector + let mut transit = match connector .leader_connect( wormhole.key().derive_transit_key(wormhole.appid()), - Arc::try_unwrap(other_side_ttype).unwrap(), + Arc::new(their_hints), ) - .await?; + .await + { + Ok(transit) => transit, + Err(error) => { + let error = TransferError::TransitConnect(error); + let _ = wormhole + .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) + .await; + return Err(error); + }, + }; debug!("Beginning file transfer"); @@ -580,9 +525,15 @@ where std::io::Result::Ok(hasher.finalize_fixed()) }); - let checksum = send_records(&mut transit, &mut reader, length, progress_handler) - .await - .unwrap(); + let checksum = match send_records(&mut transit, &mut reader, length, progress_handler).await { + Err(TransferError::Transit(error)) => { + let _ = wormhole + .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) + .await; + Err(TransferError::Transit(error)) + }, + other => other, + }?; /* This should always be ready by now, but just in case */ let sha256sum = file_sender.await.unwrap(); @@ -620,31 +571,35 @@ pub async fn request_file<'a>( let connector = transit::init(transit::Ability::all_abilities(), relay_url).await?; // send the transit message - debug!("Sending transit message '{:?}", connector.our_side_ttype()); + debug!("Sending transit message '{:?}", connector.our_hints()); wormhole .send( - crate::transfer::PeerMessage::Transit(connector.our_side_ttype().clone()) - .serialize_vec(), + PeerMessage::new_transit( + connector.our_abilities().to_vec(), + (**connector.our_hints()).clone().into(), + ) + .serialize_vec(), ) .await?; // receive transit message - let other_side_ttype = match serde_json::from_slice(&wormhole.receive().await?)? { - PeerMessage::Transit(transit) => { - debug!("received transit message: {:?}", transit); - transit - }, - PeerMessage::Error(err) => { - bail!(TransferError::PeerError(err)); - }, - other => { - let error = TransferError::unexpected_message("transit", other); - let _ = wormhole - .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) - .await; - bail!(error) - }, - }; + let (their_abilities, their_hints): (Vec, transit::Hints) = + match serde_json::from_slice(&wormhole.receive().await?)? { + PeerMessage::Transit(transit) => { + debug!("received transit message: {:?}", transit); + (transit.abilities_v1, transit.hints_v1.into()) + }, + PeerMessage::Error(err) => { + bail!(TransferError::PeerError(err)); + }, + other => { + let error = TransferError::unexpected_message("transit", other); + let _ = wormhole + .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) + .await; + bail!(error) + }, + }; // 3. receive file offer message from peer let maybe_offer = serde_json::from_slice(&wormhole.receive().await?)?; @@ -680,7 +635,7 @@ pub async fn request_file<'a>( filename, filesize, connector, - other_side_ttype, + their_hints: Arc::new(their_hints), }; Ok(req) @@ -698,7 +653,7 @@ pub struct ReceiveRequest<'a> { /// **Security warning:** this is untrusted and unverified input pub filename: PathBuf, pub filesize: u64, - other_side_ttype: Arc, + their_hints: Arc, } impl<'a> ReceiveRequest<'a> { @@ -722,25 +677,46 @@ impl<'a> ReceiveRequest<'a> { .send(PeerMessage::new_file_ack("ok").serialize_vec()) .await?; - let mut transit = self + let mut transit = match self .connector .follower_connect( self.wormhole .key() .derive_transit_key(self.wormhole.appid()), - self.other_side_ttype.clone(), + self.their_hints.clone(), ) - .await?; + .await + { + Ok(transit) => transit, + Err(error) => { + let error = TransferError::TransitConnect(error); + let _ = self + .wormhole + .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) + .await; + return Err(error); + }, + }; debug!("Beginning file transfer"); // TODO here's the right position for applying the output directory and to check for malicious (relative) file paths - tcp_file_receive( + match tcp_file_receive( &mut transit, self.filesize, progress_handler, content_handler, ) .await + { + Err(TransferError::Transit(error)) => { + let _ = self + .wormhole + .send(PeerMessage::Error(format!("{}", error)).serialize_vec()) + .await; + Err(TransferError::Transit(error)) + }, + other => other, + } } /** @@ -793,7 +769,6 @@ where // send the encrypted record transit.send_record(&plaintext[0..n]).await?; sent_size += n as u64; - debug!("sent {} bytes out of {} bytes", sent_size, file_size); progress_handler(sent_size, file_size); // sha256 of the input diff --git a/src/transfer/messages.rs b/src/transfer/messages.rs new file mode 100644 index 00000000..5078ed31 --- /dev/null +++ b/src/transfer/messages.rs @@ -0,0 +1,207 @@ +//! Over-the-wire messages for the file transfer (including transit) +//! +//! The transit protocol does not specify how to deliver the information to +//! the other side, so it is up to the file transfer to do that. + +use crate::transit::{self, Ability, DirectHint}; +use serde_derive::{Deserialize, Serialize}; +#[cfg(test)] +use serde_json::json; +use std::path::PathBuf; + +/** + * The type of message exchanged over the wormhole for this protocol + */ +#[derive(Deserialize, Serialize, Debug, PartialEq)] +#[serde(rename_all = "kebab-case")] +pub enum PeerMessage { + Offer(OfferType), + Answer(AnswerType), + /** Tell the other side you got an error */ + Error(String), + /** Used to set up a transit channel */ + Transit(TransitType), + #[serde(other)] + Unknown, +} + +impl PeerMessage { + pub fn new_offer_message(msg: impl Into) -> Self { + PeerMessage::Offer(OfferType::Message(msg.into())) + } + + pub fn new_offer_file(name: impl Into, size: u64) -> Self { + PeerMessage::Offer(OfferType::File { + filename: name.into(), + filesize: size, + }) + } + + pub fn new_offer_directory( + name: impl Into, + mode: impl Into, + compressed_size: u64, + numbytes: u64, + numfiles: u64, + ) -> Self { + PeerMessage::Offer(OfferType::Directory { + dirname: name.into(), + mode: mode.into(), + zipsize: compressed_size, + numbytes, + numfiles, + }) + } + + pub fn new_message_ack(msg: impl Into) -> Self { + PeerMessage::Answer(AnswerType::MessageAck(msg.into())) + } + + pub fn new_file_ack(msg: impl Into) -> Self { + PeerMessage::Answer(AnswerType::FileAck(msg.into())) + } + + pub fn new_error_message(msg: impl Into) -> Self { + PeerMessage::Error(msg.into()) + } + + pub fn new_transit(abilities: Vec, hints: Vec) -> Self { + PeerMessage::Transit(TransitType { + abilities_v1: abilities, + hints_v1: hints, + }) + } + + #[cfg(test)] + pub fn serialize(&self) -> String { + json!(self).to_string() + } + + pub fn serialize_vec(&self) -> Vec { + serde_json::to_vec(self).unwrap() + } +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[serde(rename_all = "kebab-case")] +pub enum OfferType { + Message(String), + File { + filename: PathBuf, + filesize: u64, + }, + Directory { + dirname: PathBuf, + mode: String, + zipsize: u64, + numbytes: u64, + numfiles: u64, + }, + #[serde(other)] + Unknown, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum AnswerType { + MessageAck(String), + FileAck(String), +} + +/** + * A set of hints for both sides to find each other + */ +#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[serde(rename_all = "kebab-case")] +pub struct TransitType { + pub abilities_v1: Vec, + pub hints_v1: Vec, +} + +impl From for Vec { + fn from(hints: transit::Hints) -> Self { + hints + .direct_tcp + .into_iter() + .map(Hint::DirectTcpV1) + .chain(std::iter::once(Hint::new_relay(hints.relay))) + .collect() + } +} + +impl Into for Vec { + fn into(self) -> transit::Hints { + let mut direct_tcp = Vec::new(); + let mut relay = Vec::new(); + + /* There is only one "relay hint", though it may contain multiple + * items. Yes, this is inconsistent and weird, watch your step. + */ + for hint in self { + match hint { + Hint::DirectTcpV1(hint) => direct_tcp.push(hint), + Hint::DirectUdtV1(_) => unimplemented!(), + Hint::RelayV1(RelayHint { hints }) => relay.extend(hints), + } + } + + transit::Hints { direct_tcp, relay } + } +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[serde(rename_all = "kebab-case", tag = "type")] +#[non_exhaustive] +pub enum Hint { + DirectTcpV1(DirectHint), + DirectUdtV1(DirectHint), + /* Weirdness alarm: a "relay hint" contains multiple "direct hints". This means + * that there may be multiple direct hints, but if there are multiple relay hints + * it's still only one item because it internally has a list. + */ + RelayV1(RelayHint), +} + +impl Hint { + pub fn new_direct_tcp(priority: f32, hostname: &str, port: u16) -> Self { + Hint::DirectTcpV1(DirectHint { + hostname: hostname.to_string(), + port, + }) + } + + pub fn new_direct_udt(priority: f32, hostname: &str, port: u16) -> Self { + Hint::DirectUdtV1(DirectHint { + hostname: hostname.to_string(), + port, + }) + } + + pub fn new_relay(h: Vec) -> Self { + Hint::RelayV1(RelayHint { hints: h }) + } +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub struct RelayHint { + pub hints: Vec, +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_transit() { + let abilities = vec![Ability::DirectTcpV1, Ability::RelayV1]; + let hints = vec![ + Hint::new_direct_tcp(0.0, "192.168.1.8", 46295), + Hint::new_relay(vec![DirectHint { + hostname: "magic-wormhole-transit.debian.net".to_string(), + port: 4001, + }]), + ]; + let t = crate::transfer::PeerMessage::new_transit(abilities, hints); + assert_eq!(t.serialize(), "{\"transit\":{\"abilities-v1\":[{\"type\":\"direct-tcp-v1\"},{\"type\":\"relay-v1\"}],\"hints-v1\":[{\"hostname\":\"192.168.1.8\",\"port\":46295,\"type\":\"direct-tcp-v1\"},{\"hints\":[{\"hostname\":\"magic-wormhole-transit.debian.net\",\"port\":4001}],\"type\":\"relay-v1\"}]}}") + } +} diff --git a/src/transit.rs b/src/transit.rs index 9411ba14..cbb88fbb 100644 --- a/src/transit.rs +++ b/src/transit.rs @@ -4,7 +4,7 @@ //! but it depends on some kind of secure communication channel to talk to the other side. Conveniently, Wormhole provides //! exactly such a thing :) //! -//! Both clients exchange messages containing hints on how to find each other. These may be local IP Addresses for in case they +//! Both clients exchange messages containing hints on how to find each other. These may be local IP addresses for in case they //! are in the same network, or the URL to a relay server. In case a direct connection fails, both will connect to the relay server //! which will transparently glue the connections together. //! @@ -23,7 +23,7 @@ use async_std::{ #[allow(unused_imports)] /* We need them for the docs */ use futures::{future::TryFutureExt, Sink, Stream, StreamExt}; use log::*; -use std::{net::ToSocketAddrs, str::FromStr, sync::Arc}; +use std::{collections::HashSet, net::ToSocketAddrs, str::FromStr, sync::Arc}; use xsalsa20poly1305 as secretbox; use xsalsa20poly1305::aead::{Aead, NewAead}; @@ -97,32 +97,33 @@ pub enum TransitError { ), } -/** - * A set of hints for both sides to find each other - */ -#[derive(Serialize, Deserialize, Debug, PartialEq)] -#[serde(rename_all = "kebab-case")] -pub struct TransitType { - pub abilities_v1: Vec, - pub hints_v1: Vec, -} - /** * Defines a way to find the other side. * - * Each ability comes with a set of [hints](Hint) to encode how to meet up. + * Each ability comes with a set of [`Hints`] to encode how to meet up. */ -#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "kebab-case", tag = "type")] #[non_exhaustive] pub enum Ability { /** - * Try to connect directly to the other side. + * Try to connect directly to the other side via TCP. * * This usually requires both participants to be in the same network. [`DirectHint`s](DirectHint) are sent, * which encode all local IP addresses for the other side to find us. */ DirectTcpV1, + /** + * Try to connect directly to the other side via UDT. + * + * This supersedes [`Ability::DirectTcpV1`] because it has several advantages: + * + * - Works with stateful firewalls, no need to open up ports + * - Works despite many NAT types if combined with STUN + * - UDT has a few other interesting performance-related properties that make it better + * suited than TCP (it's literally called "UDP-based Data Transfer Protocol") + */ + DirectUdtV1, /** Try to meet the other side at a relay. */ RelayV1, /* TODO Fix once https://github.com/serde-rs/serde/issues/912 is done */ @@ -132,45 +133,48 @@ pub enum Ability { impl Ability { pub fn all_abilities() -> Vec { - vec![Self::DirectTcpV1, Self::RelayV1] + vec![Self::DirectTcpV1, Self::DirectUdtV1, Self::RelayV1] } -} - -#[derive(Serialize, Deserialize, Debug, PartialEq)] -#[serde(rename_all = "kebab-case", tag = "type")] -#[non_exhaustive] -pub enum Hint { - DirectTcpV1(DirectHint), - RelayV1(RelayHint), -} -impl Hint { - pub fn new_direct(priority: f32, hostname: &str, port: u16) -> Self { - Hint::DirectTcpV1(DirectHint { - priority, - hostname: hostname.to_string(), - port, - }) + /** + * If you absolutely don't want to use any relay servers. + * + * If the other side forces relay usage or doesn't support any of your connection modes + * the attempt will fail. + */ + pub fn force_direct() -> Vec { + vec![Self::DirectTcpV1, Self::DirectUdtV1] } - pub fn new_relay(h: Vec) -> Self { - Hint::RelayV1(RelayHint { hints: h }) + /** + * If you don't want to disclose your IP address to your peer + * + * If the other side forces a the usage of a direct connection the attempt will fail. + * Note that the other side might control the relay server being used, if you really + * don't want your IP to potentially be disclosed use TOR instead (not supported by + * the Rust implementation yet). + */ + pub fn force_relay() -> Vec { + vec![Self::RelayV1] } } -#[derive(Serialize, Deserialize, Debug, PartialEq)] +#[derive(Clone, Debug, Default)] +pub struct Hints { + pub direct_tcp: Vec, + pub relay: Vec, +} + +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Hash)] pub struct DirectHint { - pub priority: f32, + // DirectHint also contains a `priority` field, but it is underspecified + // and we won't use it + // pub priority: f32, pub hostname: String, pub port: u16, } -#[derive(Serialize, Deserialize, Debug, PartialEq)] -pub struct RelayHint { - pub hints: Vec, -} - -#[derive(Debug, PartialEq)] +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] enum HostType { Direct, Relay, @@ -203,65 +207,67 @@ impl FromStr for RelayUrl { /** * Initialize a relay handshake * - * Bind a port and generate our [`TransitType`]. This does not do any communication yet. + * Bind a port and generate our [`Hints`]. This does not do any communication yet. */ pub async fn init( abilities: Vec, relay_url: &RelayUrl, ) -> Result { - let listener = TcpListener::bind("[::]:0").await?; - let port = listener.local_addr()?.port(); + let mut our_hints = Hints::default(); + let mut listener = None; - let mut our_hints: Vec = Vec::new(); + /* Detect our local IP addresses if the ability is enabled */ if abilities.contains(&Ability::DirectTcpV1) { - our_hints.extend( + listener = Some(TcpListener::bind("[::]:0").await?); + let port = listener.as_ref().unwrap().local_addr()?.port(); + + our_hints.direct_tcp.extend( get_if_addrs::get_if_addrs()? .iter() .filter(|iface| !iface.is_loopback()) - .map(|ip| { - Hint::DirectTcpV1(DirectHint { - priority: 0.0, - hostname: ip.ip().to_string(), - port, - }) + .map(|ip| DirectHint { + hostname: ip.ip().to_string(), + port, }), ); } + if abilities.contains(&Ability::RelayV1) { - our_hints.push(Hint::new_relay(vec![DirectHint { - priority: 0.0, + our_hints.relay.push(DirectHint { hostname: relay_url.host.clone(), port: relay_url.port, - }])); + }); } Ok(TransitConnector { listener, - port, - our_side_ttype: Arc::new(TransitType { - abilities_v1: abilities, - hints_v1: our_hints, - }), + our_abilities: abilities, + our_hints: Arc::new(our_hints), }) } /** * A partially set up [`Transit`] connection. * - * For the transit handshake, each side generates a [ttype](`TransitType`) with all the hints to find the other. You need - * to exchange it (as in: send yours, receive theirs) with them. This is outside of the transit protocol to be protocol - * agnostic. + * For the transit handshake, each side generates a [`Hints`] with all the information to find the other. You need + * to exchange it (as in: send yours, receive theirs) with them. This is outside of the transit protocol, because we + * are protocol agnostic. */ pub struct TransitConnector { - listener: TcpListener, - port: u16, - our_side_ttype: Arc, + /* Only `Some` if direct-tcp-v1 ability has been enabled. */ + listener: Option, + our_abilities: Vec, + our_hints: Arc, } impl TransitConnector { + pub fn our_abilities(&self) -> &[Ability] { + &self.our_abilities + } + /** Send this one to the other side */ - pub fn our_side_ttype(&self) -> &Arc { - &self.our_side_ttype + pub fn our_hints(&self) -> &Arc { + &self.our_hints } /** @@ -270,74 +276,17 @@ impl TransitConnector { pub async fn leader_connect( self, transit_key: Key, - other_side_ttype: TransitType, + their_hints: Arc, ) -> Result { + let Self { + listener, + our_abilities, + our_hints, + } = self; let transit_key = Arc::new(transit_key); - /* TODO This Deref thing is getting out of hand. Maybe implementing AsRef or some other trait may help? */ - debug!("transit key {}", transit_key.to_hex()); - - let port = self.port; - let listener = self.listener; - // let other_side_ttype = Arc::new(other_side_ttype); - // TODO remove this one day - let ttype = &*Box::leak(Box::new(other_side_ttype)); - - // 8. listen for connections on the port and simultaneously try connecting to the peer port. - // extract peer's ip/hostname from 'ttype' - let (mut direct_hosts, mut relay_hosts) = get_direct_relay_hosts(ttype); - - let mut hosts: Vec<(HostType, &DirectHint)> = Vec::new(); - hosts.append(&mut direct_hosts); - hosts.append(&mut relay_hosts); - // TODO: combine our relay hints with the peer's relay hints. - - let mut handshake_futures = Vec::new(); - for host in hosts { - // TODO use async scopes to borrow instead of cloning one day - let transit_key = transit_key.clone(); - let future = async_std::task::spawn( - //async_std::future::timeout(Duration::from_secs(5), - async move { - debug!("host: {:?}", host); - let mut direct_host_iter = format!("{}:{}", host.1.hostname, host.1.port) - .to_socket_addrs() - .unwrap(); - let direct_host = direct_host_iter.next().unwrap(); - debug!("peer host: {}", direct_host); - - TcpStream::connect(direct_host) - .err_into::() - .and_then(|socket| leader_handshake_exchange(socket, host.0, &*transit_key)) - .await - }, - ); //); - handshake_futures.push(future); - } - handshake_futures.push(async_std::task::spawn(async move { - debug!("local host {}", port); - - /* Mixing and matching two different futures library probably isn't the - * best idea, but here we are. Simply be careful about prelude::* imports - * and don't have both StreamExt/FutureExt/… imported at once - */ - use futures::stream::TryStreamExt; - async_std::stream::StreamExt::skip_while(listener.incoming() - .err_into::() - .and_then(move |socket| { - /* Pinning a future + moving some value from outer scope is a bit painful */ - let transit_key = transit_key.clone(); - Box::pin(async move { - leader_handshake_exchange(socket, HostType::Direct, &*transit_key).await - }) - }), - Result::is_err) - /* We only care about the first that succeeds */ - .next() - .await - /* Next always returns Some because Incoming is an infinite stream. We gotta succeed _sometime_. */ - .unwrap() - })); + let mut handshake_futures = + Self::connect(true, transit_key, our_hints, their_hints, listener).await?; /* Try to get a Transit out of the first handshake that succeeds. If all fail, * we fail. @@ -363,17 +312,19 @@ impl TransitConnector { } let mut transit = transit; - /* Cancel all remaining non-finished handshakes */ + /* Cancel all remaining non-finished handshakes. We could send "nevermind" to explicitly tell + * the other side (probably, this is mostly for relay server statistics), but eeh, nevermind :) + */ handshake_futures .into_iter() .map(async_std::task::JoinHandle::cancel) .for_each(std::mem::drop); - debug!( - "Sending 'go' message to {}", + transit.socket.write_all(b"go\n").await?; + info!( + "Established transit connection to '{}'", transit.socket.peer_addr().unwrap() ); - transit.socket.write_all(b"go\n").await?; Ok(transit) } @@ -384,76 +335,17 @@ impl TransitConnector { pub async fn follower_connect( self, transit_key: Key, - other_side_ttype: Arc, + their_hints: Arc, ) -> Result { + let Self { + listener, + our_abilities, + our_hints, + } = self; let transit_key = Arc::new(transit_key); - /* TODO This Deref thing is getting out of hand. Maybe implementing AsRef or some other trait may help? */ - debug!("transit key {}", transit_key.to_hex()); - - let port = self.port; - let listener = self.listener; - // let other_side_ttype = Arc::new(other_side_ttype); - let ttype = &*Box::leak(Box::new(other_side_ttype)); // TODO remove this one day - - // 4. listen for connections on the port and simultaneously try connecting to the - // peer listening port. - let (mut direct_hosts, mut relay_hosts) = get_direct_relay_hosts(ttype); - - let mut hosts: Vec<(HostType, &DirectHint)> = Vec::new(); - hosts.append(&mut direct_hosts); - hosts.append(&mut relay_hosts); - // TODO: combine our relay hints with the peer's relay hints. - - let mut handshake_futures = Vec::new(); - for host in hosts { - let transit_key = transit_key.clone(); - let future = async_std::task::spawn( - //async_std::future::timeout(Duration::from_secs(5), - async move { - debug!("host: {:?}", host); - let mut direct_host_iter = format!("{}:{}", host.1.hostname, host.1.port) - .to_socket_addrs() - .unwrap(); - let direct_host = direct_host_iter.next().unwrap(); - - debug!("peer host: {}", direct_host); - - TcpStream::connect(direct_host) - .err_into::() - .and_then(|socket| { - follower_handshake_exchange(socket, host.0, &*transit_key) - }) - .await - }, - ); //); - handshake_futures.push(future); - } - handshake_futures.push(async_std::task::spawn(async move { - debug!("local host {}", port); - - /* Mixing and matching two different futures library probably isn't the - * best idea, but here we are. Simply be careful about prelude::* imports - * and don't have both StreamExt/FutureExt/… imported at once - */ - use futures::stream::TryStreamExt; - async_std::stream::StreamExt::skip_while(listener.incoming() - .err_into::() - .and_then(move |socket| { - /* Pinning a future + moving some value from outer scope is a bit painful */ - let transit_key = transit_key.clone(); - use futures::future::FutureExt; - async move { - follower_handshake_exchange(socket, HostType::Direct, &*transit_key).await - }.boxed() - }), - Result::is_err) - /* We only care about the first that succeeds */ - .next() - .await - /* Next always returns Some because Incoming is an infinite stream. We gotta succeed _sometime_. */ - .unwrap() - })); + let mut handshake_futures = + Self::connect(false, transit_key, our_hints, their_hints, listener).await?; /* Try to get a Transit out of the first handshake that succeeds. If all fail, * we fail. @@ -486,6 +378,100 @@ impl TransitConnector { Ok(transit) } + + /** Try to establish a connection with the peer. + * + * This encapsulates code that is common to both the leader and the follower. + */ + async fn connect( + is_leader: bool, + transit_key: Arc>, + our_hints: Arc, + their_hints: Arc, + listener: Option, + ) -> Result< + Vec>>, + TransitConnectError, + > { + // 8. listen for connections on the port and simultaneously try connecting to the peer port. + let tside = Arc::new(hex::encode(rand::random::<[u8; 8]>())); + let mut hosts: HashSet<(HostType, DirectHint)> = HashSet::new(); + hosts.extend( + their_hints + .direct_tcp + .iter() + .map(|hint| (HostType::Direct, hint.clone())), + ); + hosts.extend( + their_hints + .relay + .iter() + .map(|hint| (HostType::Relay, hint.clone())), + ); + hosts.extend( + our_hints + .relay + .iter() + .map(|hint| (HostType::Relay, hint.clone())), + ); + + let mut handshake_futures = Vec::new(); + for host in hosts { + let transit_key = transit_key.clone(); + let tside = tside.clone(); + let future = async_std::task::spawn( + //async_std::future::timeout(Duration::from_secs(5), + async move { + debug!("host: {:?}", host); + let mut direct_host_iter = format!("{}:{}", host.1.hostname, host.1.port) + .to_socket_addrs() + .unwrap(); + let direct_host = direct_host_iter.next().unwrap(); + + debug!("peer host: {}", direct_host); + + TcpStream::connect(direct_host) + .err_into::() + .and_then(|socket| { + handshake_exchange(is_leader, &*tside, socket, host.0, &*transit_key) + }) + .await + }, + ); //); + handshake_futures.push(future); + } + + /* If we allowed the other side to make direct connections to us, we have a listening + * socket that will attempt to complete handshakes + */ + if let Some(listener) = listener { + handshake_futures.push(async_std::task::spawn(async move { + /* Mixing and matching two different futures libraries probably isn't the + * best idea, but here we are. Simply be careful about prelude::* imports + * and don't have both StreamExt/FutureExt/… imported at once + */ + use futures::stream::TryStreamExt; + async_std::stream::StreamExt::skip_while(listener.incoming() + .err_into::() + .and_then(move |socket| { + /* Pinning a future + moving some value from outer scope is a bit painful */ + let transit_key = transit_key.clone(); + let tside = tside.clone(); + Box::pin(async move { + handshake_exchange(is_leader, &*tside, socket, HostType::Direct, &*transit_key).await + }) + }), + Result::is_err) + /* We only care about the first that succeeds */ + .next() + .await + /* Next always returns Some because Incoming is an infinite stream. We gotta succeed _sometime_. */ + .unwrap() + })); + } + + Ok(handshake_futures) + } } /** @@ -496,7 +482,7 @@ impl TransitConnector { */ pub struct Transit { /** Raw transit connection */ - pub socket: TcpStream, + socket: TcpStream, /** Our key, used for sending */ pub skey: Key, /** Their key, used for receiving */ @@ -630,40 +616,46 @@ impl Transit { } } -fn generate_transit_side() -> String { - let x: [u8; 8] = rand::random(); - hex::encode(x) -} - fn make_relay_handshake(key: &Key, tside: &str) -> String { let sub_key = key.derive_subkey_from_purpose::("transit_relay_token"); - format!( - "please relay {} for side {}\n", - sub_key.to_hex(), - tside - ) + format!("please relay {} for side {}\n", sub_key.to_hex(), tside) } -async fn follower_handshake_exchange( +/** + * Do a transit handshake exchange, to establish a direct connection. + * + * This automatically does the relay handshake first if necessary. On the follower + * side, the future will successfully run to completion if a connection could be + * established. On the leader side, the handshake is not 100% completed: the caller + * must write `Ok\n` into the stream that should be used (and optionally `Nevermind\n` + * into all others). + */ +async fn handshake_exchange( + is_leader: bool, + tside: &str, mut socket: TcpStream, host_type: HostType, key: &Key, ) -> Result { - // create record keys - /* The order here is correct. The "sender" and "receiver" side are a misnomer and should be called - * "leader" and "follower" instead. As a follower, we use the leader key for receiving and our - * key for sending. - */ - let rkey = key.derive_subkey_from_purpose("transit_record_sender_key"); - let skey = key.derive_subkey_from_purpose("transit_record_receiver_key"); - - // exchange handshake - let tside = generate_transit_side(); + // 9. create record keys + let (rkey, skey) = if is_leader { + let rkey = key.derive_subkey_from_purpose("transit_record_receiver_key"); + let skey = key.derive_subkey_from_purpose("transit_record_sender_key"); + (rkey, skey) + } else { + /* The order here is correct. The "sender" and "receiver" side are a misnomer and should be called + * "leader" and "follower" instead. As a follower, we use the leader key for receiving and our + * key for sending. + */ + let rkey = key.derive_subkey_from_purpose("transit_record_sender_key"); + let skey = key.derive_subkey_from_purpose("transit_record_receiver_key"); + (rkey, skey) + }; if host_type == HostType::Relay { trace!("initiating relay handshake"); socket - .write_all(make_relay_handshake(key, &tside).as_bytes()) + .write_all(make_relay_handshake(key, tside).as_bytes()) .await?; let mut rx = [0u8; 3]; socket.read_exact(&mut rx).await?; @@ -671,92 +663,62 @@ async fn follower_handshake_exchange( ensure!(ok_msg == rx, TransitHandshakeError::RelayHandshakeFailed); } - { - // for receive mode, send receive_handshake_msg and compare. + if is_leader { + // for transmit mode, send send_handshake_msg and compare. // the received message with send_handshake_msg - socket .write_all( format!( - "transit receiver {} ready\n\n", - key.derive_subkey_from_purpose::("transit_receiver").to_hex(), + "transit sender {} ready\n\n", + key.derive_subkey_from_purpose::("transit_sender") + .to_hex() ) .as_bytes(), ) .await?; - // The received message "transit receiver $hash ready\n\n" has exactly 87 bytes - // Three bytes for the "go\n" ack + // The received message "transit sender $hash ready\n\n" has exactly 89 bytes // TODO do proper line parsing one day, this is atrocious - let mut rx: [u8; 90] = [0; 90]; + let mut rx: [u8; 89] = [0; 89]; socket.read_exact(&mut rx).await?; - let expected_tx_handshake = format!( - "transit sender {} ready\n\ngo\n", - key.derive_subkey_from_purpose::("transit_sender").to_hex(), + let expected_rx_handshake = format!( + "transit receiver {} ready\n\n", + key.derive_subkey_from_purpose::("transit_receiver") + .to_hex() ); ensure!( - &rx[..] == expected_tx_handshake.as_bytes(), - TransitHandshakeError::HandshakeFailed + &rx[..] == expected_rx_handshake.as_bytes(), + TransitHandshakeError::HandshakeFailed, ); - } - - Ok(Transit { - socket, - skey, - rkey, - snonce: Default::default(), - rnonce: Default::default(), - }) -} - -async fn leader_handshake_exchange( - mut socket: TcpStream, - host_type: HostType, - key: &Key, -) -> Result { - // 9. create record keys - let skey = key.derive_subkey_from_purpose("transit_record_sender_key"); - let rkey = key.derive_subkey_from_purpose("transit_record_receiver_key"); - - // 10. exchange handshake over tcp - let tside = generate_transit_side(); - - if host_type == HostType::Relay { - socket - .write_all(make_relay_handshake(key, &tside).as_bytes()) - .await?; - let mut rx = [0u8; 3]; - socket.read_exact(&mut rx).await?; - let ok_msg: [u8; 3] = *b"ok\n"; - ensure!(ok_msg == rx, TransitHandshakeError::RelayHandshakeFailed); - } - - { - // for transmit mode, send send_handshake_msg and compare. + } else { + // for receive mode, send receive_handshake_msg and compare. // the received message with send_handshake_msg socket .write_all( format!( - "transit sender {} ready\n\n", - key.derive_subkey_from_purpose::("transit_sender").to_hex() + "transit receiver {} ready\n\n", + key.derive_subkey_from_purpose::("transit_receiver") + .to_hex(), ) .as_bytes(), ) .await?; - // The received message "transit sender $hash ready\n\n" has exactly 89 bytes + // The received message "transit receiver $hash ready\n\n" has exactly 87 bytes + // Three bytes for the "go\n" ack // TODO do proper line parsing one day, this is atrocious - let mut rx: [u8; 89] = [0; 89]; + let mut rx: [u8; 90] = [0; 90]; socket.read_exact(&mut rx).await?; - let expected_rx_handshake = format!( - "transit receiver {} ready\n\n", - key.derive_subkey_from_purpose::("transit_receiver").to_hex() + let expected_tx_handshake = format!( + "transit sender {} ready\n\ngo\n", + key.derive_subkey_from_purpose::("transit_sender") + .to_hex(), ); ensure!( - &rx[..] == expected_rx_handshake.as_bytes(), - TransitHandshakeError::HandshakeFailed, + &rx[..] == expected_tx_handshake.as_bytes(), + TransitHandshakeError::HandshakeFailed ); } @@ -768,60 +730,3 @@ async fn leader_handshake_exchange( rnonce: Default::default(), }) } - -#[allow(clippy::type_complexity)] -fn get_direct_relay_hosts<'a, 'b: 'a>( - ttype: &'b TransitType, -) -> ( - Vec<(HostType, &'a DirectHint)>, - Vec<(HostType, &'a DirectHint)>, -) { - let direct_hosts: Vec<(HostType, &DirectHint)> = ttype - .hints_v1 - .iter() - .filter_map(|hint| match hint { - Hint::DirectTcpV1(dt) => Some((HostType::Direct, dt)), - _ => None, - }) - .collect(); - let relay_hosts_list: Vec<&Vec> = ttype - .hints_v1 - .iter() - .filter_map(|hint| match hint { - Hint::RelayV1(rt) => Some(&rt.hints), - _ => None, - }) - .collect(); - - let _hosts: Vec<(HostType, &DirectHint)> = Vec::new(); - let maybe_relay_hosts = relay_hosts_list.first(); - let relay_hosts: Vec<(HostType, &DirectHint)> = match maybe_relay_hosts { - Some(relay_host_vec) => relay_host_vec - .iter() - .map(|host| (HostType::Relay, host)) - .collect(), - None => vec![], - }; - - (direct_hosts, relay_hosts) -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_transit() { - let abilities = vec![Ability::DirectTcpV1, Ability::RelayV1]; - let hints = vec![ - Hint::new_direct(0.0, "192.168.1.8", 46295), - Hint::new_relay(vec![DirectHint { - priority: 2.0, - hostname: "magic-wormhole-transit.debian.net".to_string(), - port: 4001, - }]), - ]; - let t = crate::transfer::PeerMessage::new_transit(abilities, hints); - assert_eq!(t.serialize(), "{\"transit\":{\"abilities-v1\":[{\"type\":\"direct-tcp-v1\"},{\"type\":\"relay-v1\"}],\"hints-v1\":[{\"hostname\":\"192.168.1.8\",\"port\":46295,\"priority\":0.0,\"type\":\"direct-tcp-v1\"},{\"hints\":[{\"hostname\":\"magic-wormhole-transit.debian.net\",\"port\":4001,\"priority\":2.0}],\"type\":\"relay-v1\"}]}}") - } -} From 103b33c978ea45beb0bc90b7e4fcd3b8096fe697 Mon Sep 17 00:00:00 2001 From: piegames Date: Thu, 26 Aug 2021 19:34:24 +0200 Subject: [PATCH 3/9] Transit refactor III: replace spawned tasks with Stream --- src/transfer.rs | 5 + src/transit.rs | 323 +++++++++++++++++++++++++++--------------------- 2 files changed, 184 insertions(+), 144 deletions(-) diff --git a/src/transfer.rs b/src/transfer.rs index 1f9ea50c..de750271 100644 --- a/src/transfer.rs +++ b/src/transfer.rs @@ -275,6 +275,7 @@ where let mut transit = match connector .leader_connect( wormhole.key().derive_transit_key(wormhole.appid()), + Arc::new(their_abilities), Arc::new(their_hints), ) .await @@ -454,6 +455,7 @@ where let mut transit = match connector .leader_connect( wormhole.key().derive_transit_key(wormhole.appid()), + Arc::new(their_abilities), Arc::new(their_hints), ) .await @@ -635,6 +637,7 @@ pub async fn request_file<'a>( filename, filesize, connector, + their_abilities: Arc::new(their_abilities), their_hints: Arc::new(their_hints), }; @@ -653,6 +656,7 @@ pub struct ReceiveRequest<'a> { /// **Security warning:** this is untrusted and unverified input pub filename: PathBuf, pub filesize: u64, + their_abilities: Arc>, their_hints: Arc, } @@ -683,6 +687,7 @@ impl<'a> ReceiveRequest<'a> { self.wormhole .key() .derive_transit_key(self.wormhole.appid()), + self.their_abilities.clone(), self.their_hints.clone(), ) .await diff --git a/src/transit.rs b/src/transit.rs index cbb88fbb..84849c3f 100644 --- a/src/transit.rs +++ b/src/transit.rs @@ -13,7 +13,7 @@ //! **Notice:** while the resulting TCP connection is naturally bi-directional, the handshake is not symmetric. There *must* be one //! "leader" side and one "follower" side (formerly called "sender" and "receiver"). -use crate::{core::WormholeError, Key, KeyPurpose}; +use crate::{Key, KeyPurpose}; use serde_derive::{Deserialize, Serialize}; use async_std::{ @@ -21,7 +21,7 @@ use async_std::{ net::{TcpListener, TcpStream}, }; #[allow(unused_imports)] /* We need them for the docs */ -use futures::{future::TryFutureExt, Sink, Stream, StreamExt}; +use futures::{future::TryFutureExt, Sink, SinkExt, Stream, StreamExt, TryStreamExt}; use log::*; use std::{collections::HashSet, net::ToSocketAddrs, str::FromStr, sync::Arc}; use xsalsa20poly1305 as secretbox; @@ -43,14 +43,11 @@ impl KeyPurpose for TransitTxKey {} #[derive(Debug, thiserror::Error)] #[non_exhaustive] pub enum TransitConnectError { - #[error("All (relay) handshakes failed, could not establish a connection with the peer")] + /** Incompatible abilities, or wrong hints */ + #[error("{}", _0)] + Protocol(Box), + #[error("All (relay) handshakes failed or timed out; could not establish a connection with the peer")] Handshake, - #[error("Wormhole connection error")] - Wormhole( - #[from] - #[source] - WormholeError, - ), #[error("IO error")] IO( #[from] @@ -68,12 +65,6 @@ enum TransitHandshakeError { HandshakeFailed, #[error("Relay handshake failed")] RelayHandshakeFailed, - #[error("Wormhole connection error")] - Wormhole( - #[from] - #[source] - WormholeError, - ), #[error("IO error")] IO( #[from] @@ -241,7 +232,7 @@ pub async fn init( Ok(TransitConnector { listener, - our_abilities: abilities, + our_abilities: Arc::new(abilities), our_hints: Arc::new(our_hints), }) } @@ -256,12 +247,12 @@ pub async fn init( pub struct TransitConnector { /* Only `Some` if direct-tcp-v1 ability has been enabled. */ listener: Option, - our_abilities: Vec, + our_abilities: Arc>, our_hints: Arc, } impl TransitConnector { - pub fn our_abilities(&self) -> &[Ability] { + pub fn our_abilities(&self) -> &Arc> { &self.our_abilities } @@ -276,6 +267,7 @@ impl TransitConnector { pub async fn leader_connect( self, transit_key: Key, + their_abilities: Arc>, their_hints: Arc, ) -> Result { let Self { @@ -285,40 +277,67 @@ impl TransitConnector { } = self; let transit_key = Arc::new(transit_key); - let mut handshake_futures = - Self::connect(true, transit_key, our_hints, their_hints, listener).await?; + let start = std::time::Instant::now(); + let mut connection_stream = Box::pin( + Self::connect( + true, + transit_key, + our_abilities.clone(), + our_hints, + their_abilities, + their_hints, + listener, + ) + .filter_map(|result| async { + match result { + Ok(val) => Some(val), + Err(err) => { + log::debug!("Some leader handshake failed: {}", err); + None + }, + } + }), + ); - /* Try to get a Transit out of the first handshake that succeeds. If all fail, - * we fail. - */ - let transit; - loop { - ensure!( - !handshake_futures.is_empty(), - TransitConnectError::Handshake - ); + let (mut transit, host_type) = connection_stream + .next() + .await + .ok_or(TransitConnectError::Handshake)?; - match futures::future::select_all(handshake_futures).await { - (Ok(transit2), _index, remaining) => { - transit = transit2; - handshake_futures = remaining; - break; - }, - (Err(e), _index, remaining) => { - debug!("Some handshake failed {:#}", e); - handshake_futures = remaining; - }, - } + if host_type == HostType::Relay && our_abilities.contains(&Ability::DirectTcpV1) { + log::debug!( + "Established transit connection over relay. Trying to find a direct connection …" + ); + /* Measure the time it took us to get a response. Based on this, wait some more for more responses + * in case we like one better. + */ + let elapsed = start.elapsed(); + let to_wait = if elapsed.as_secs() > 5 { + /* If our RTT was *that* long, let's just be happy we even got one connection */ + std::time::Duration::from_secs(1) + } else { + elapsed.mul_f32(0.3) + }; + let _ = async_std::future::timeout(to_wait, async { + while let Some((new_transit, new_host_type)) = connection_stream.next().await { + /* We already got a connection, so we're only interested in direct ones */ + if new_host_type == HostType::Direct { + transit = new_transit; + log::debug!("Found direct connection; using that instead."); + break; + } + } + }) + .await; + log::debug!("Did not manage to establish a better connection in time."); + } else { + log::debug!("Established direct transit connection"); } - let mut transit = transit; /* Cancel all remaining non-finished handshakes. We could send "nevermind" to explicitly tell * the other side (probably, this is mostly for relay server statistics), but eeh, nevermind :) */ - handshake_futures - .into_iter() - .map(async_std::task::JoinHandle::cancel) - .for_each(std::mem::drop); + std::mem::drop(connection_stream); transit.socket.write_all(b"go\n").await?; info!( @@ -335,6 +354,7 @@ impl TransitConnector { pub async fn follower_connect( self, transit_key: Key, + their_abilities: Arc>, their_hints: Arc, ) -> Result { let Self { @@ -344,69 +364,94 @@ impl TransitConnector { } = self; let transit_key = Arc::new(transit_key); - let mut handshake_futures = - Self::connect(false, transit_key, our_hints, their_hints, listener).await?; - - /* Try to get a Transit out of the first handshake that succeeds. If all fail, - * we fail. - */ - let transit; - loop { - ensure!( - !handshake_futures.is_empty(), - TransitConnectError::Handshake - ); + let mut connection_stream = Box::pin( + Self::connect( + false, + transit_key, + our_abilities, + our_hints, + their_abilities, + their_hints, + listener, + ) + .filter_map(|result| async { + match result { + Ok(val) => Some(val), + Err(err) => { + log::debug!("Some follower handshake failed: {}", err); + None + }, + } + }), + ); - match futures::future::select_all(handshake_futures).await { - (Ok(transit2), _index, remaining) => { - transit = transit2; - handshake_futures = remaining; - break; - }, - (Err(e), _index, remaining) => { - debug!("Some handshake failed {:#}", e); - handshake_futures = remaining; - }, - } - } + let transit = match async_std::future::timeout( + std::time::Duration::from_secs(60), + &mut connection_stream.next(), + ) + .await + { + Ok(Some((transit, host_type))) => { + log::debug!( + "Established a {} transit connection.", + if host_type == HostType::Direct { + "direct" + } else { + "relay" + } + ); + Ok(transit) + }, + Ok(None) | Err(_) => Err(TransitConnectError::Handshake), + }; - /* Cancel all remaining non-finished handshakes */ - handshake_futures - .into_iter() - .map(async_std::task::JoinHandle::cancel) - .for_each(std::mem::drop); + /* Cancel all remaining non-finished handshakes. We could send "nevermind" to explicitly tell + * the other side (probably, this is mostly for relay server statistics), but eeh, nevermind :) + */ + std::mem::drop(connection_stream); - Ok(transit) + transit } /** Try to establish a connection with the peer. * * This encapsulates code that is common to both the leader and the follower. + * + * ## Panics + * + * If the receiving end of the channel for the results is closed before all futures in the return + * value are cancelled/dropped. */ - async fn connect( + fn connect( is_leader: bool, transit_key: Arc>, + our_abilities: Arc>, our_hints: Arc, + _their_abilities: Arc>, their_hints: Arc, listener: Option, - ) -> Result< - Vec>>, - TransitConnectError, - > { + ) -> impl Stream> { + assert!(listener.is_some() == our_abilities.contains(&Ability::DirectTcpV1)); + // 8. listen for connections on the port and simultaneously try connecting to the peer port. let tside = Arc::new(hex::encode(rand::random::<[u8; 8]>())); + + /* Process the peer's information first. We only take up to 20 items to prevent DOS, + * this should be absolutely more than enough. + */ let mut hosts: HashSet<(HostType, DirectHint)> = HashSet::new(); hosts.extend( their_hints .direct_tcp .iter() - .map(|hint| (HostType::Direct, hint.clone())), - ); - hosts.extend( - their_hints - .relay - .iter() - .map(|hint| (HostType::Relay, hint.clone())), + .map(|hint| (HostType::Direct, hint.clone())) + .chain( + their_hints + .relay + .iter() + .map(|hint| (HostType::Relay, hint.clone())), + ) + .take(20), ); hosts.extend( our_hints @@ -415,62 +460,55 @@ impl TransitConnector { .map(|hint| (HostType::Relay, hint.clone())), ); - let mut handshake_futures = Vec::new(); - for host in hosts { - let transit_key = transit_key.clone(); - let tside = tside.clone(); - let future = async_std::task::spawn( - //async_std::future::timeout(Duration::from_secs(5), - async move { - debug!("host: {:?}", host); - let mut direct_host_iter = format!("{}:{}", host.1.hostname, host.1.port) - .to_socket_addrs() - .unwrap(); - let direct_host = direct_host_iter.next().unwrap(); - - debug!("peer host: {}", direct_host); - - TcpStream::connect(direct_host) - .err_into::() - .and_then(|socket| { - handshake_exchange(is_leader, &*tside, socket, host.0, &*transit_key) - }) - .await - }, - ); //); - handshake_futures.push(future); - } + /* Now try to find a connection */ + let mut iterator: Box> = Box::new( + hosts + .into_iter() + .map(|host| { + let transit_key = transit_key.clone(); + let tside = tside.clone(); + async move { + let mut direct_host_iter = format!("{}:{}", host.1.hostname, host.1.port) + .to_socket_addrs() + .unwrap(); + let direct_host = direct_host_iter.next().unwrap(); + + let transit = TcpStream::connect(direct_host) + .err_into::() + .and_then(|socket| { + handshake_exchange(is_leader, tside, socket, host.0, transit_key) + }) + .await?; + + Ok((transit, host.0)) + } + }) + .map(futures::stream::once) + .map(|stream| stream.left_stream()), + ); /* If we allowed the other side to make direct connections to us, we have a listening * socket that will attempt to complete handshakes */ if let Some(listener) = listener { - handshake_futures.push(async_std::task::spawn(async move { - /* Mixing and matching two different futures libraries probably isn't the - * best idea, but here we are. Simply be careful about prelude::* imports - * and don't have both StreamExt/FutureExt/… imported at once - */ - use futures::stream::TryStreamExt; - async_std::stream::StreamExt::skip_while(listener.incoming() - .err_into::() - .and_then(move |socket| { - /* Pinning a future + moving some value from outer scope is a bit painful */ - let transit_key = transit_key.clone(); - let tside = tside.clone(); - Box::pin(async move { - handshake_exchange(is_leader, &*tside, socket, HostType::Direct, &*transit_key).await - }) - }), - Result::is_err) - /* We only care about the first that succeeds */ - .next() - .await - /* Next always returns Some because Incoming is an infinite stream. We gotta succeed _sometime_. */ - .unwrap() - })); + let transit_key = transit_key.clone(); + let tside = tside.clone(); + let stream = futures::stream::unfold(listener, |listener| async move { + let next = listener.accept().await; + Some((next, listener)) + }) + .err_into::() + .and_then(move |(socket, _listener)| { + /* We need to clone again because it will be moved into the returned future */ + let transit_key = transit_key.clone(); + let tside = tside.clone(); + handshake_exchange(is_leader, tside, socket, HostType::Direct, transit_key) + }) + .map(|result| result.map(|transit| (transit, HostType::Direct))); + iterator = Box::new(iterator.chain(std::iter::once(stream.right_stream()))) } - Ok(handshake_futures) + futures::stream::select_all(iterator.map(Box::pin)) } } @@ -616,11 +654,6 @@ impl Transit { } } -fn make_relay_handshake(key: &Key, tside: &str) -> String { - let sub_key = key.derive_subkey_from_purpose::("transit_relay_token"); - format!("please relay {} for side {}\n", sub_key.to_hex(), tside) -} - /** * Do a transit handshake exchange, to establish a direct connection. * @@ -632,10 +665,10 @@ fn make_relay_handshake(key: &Key, tside: &str) -> String { */ async fn handshake_exchange( is_leader: bool, - tside: &str, + tside: Arc, mut socket: TcpStream, host_type: HostType, - key: &Key, + key: Arc>, ) -> Result { // 9. create record keys let (rkey, skey) = if is_leader { @@ -654,8 +687,10 @@ async fn handshake_exchange( if host_type == HostType::Relay { trace!("initiating relay handshake"); + + let sub_key = key.derive_subkey_from_purpose::("transit_relay_token"); socket - .write_all(make_relay_handshake(key, tside).as_bytes()) + .write_all(format!("please relay {} for side {}\n", sub_key.to_hex(), tside).as_bytes()) .await?; let mut rx = [0u8; 3]; socket.read_exact(&mut rx).await?; From 04cf173e38da593364b2087400de28ad67d5648a Mon Sep 17 00:00:00 2001 From: piegames Date: Sat, 18 Sep 2021 01:20:22 +0200 Subject: [PATCH 4/9] Transit refactor IV: do firewall hole punching --- Cargo.toml | 3 + src/transit.rs | 358 ++++++++++++++++++++++++++++++++++++------------- 2 files changed, 271 insertions(+), 90 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d4343838..40da2656 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,9 @@ thiserror = "1.0.24" futures = "0.3.12" async-std = { version = "1.9.0", features = ["attributes", "unstable"] } async-tungstenite = { version = "0.14.0", features = ["async-std-runtime", "async-tls"] } +async-io = "1.6.0" +socket2 = "0.4.1" +libc = "0.2.101" # for "bin" feature clap = { version = "2.33.3", optional = true } diff --git a/src/transit.rs b/src/transit.rs index 84849c3f..64a8b0da 100644 --- a/src/transit.rs +++ b/src/transit.rs @@ -23,7 +23,7 @@ use async_std::{ #[allow(unused_imports)] /* We need them for the docs */ use futures::{future::TryFutureExt, Sink, SinkExt, Stream, StreamExt, TryStreamExt}; use log::*; -use std::{collections::HashSet, net::ToSocketAddrs, str::FromStr, sync::Arc}; +use std::{collections::HashSet, str::FromStr, sync::Arc}; use xsalsa20poly1305 as secretbox; use xsalsa20poly1305::aead::{Aead, NewAead}; @@ -65,6 +65,12 @@ enum TransitHandshakeError { HandshakeFailed, #[error("Relay handshake failed")] RelayHandshakeFailed, + #[error("Malformed peer address")] + BadAddress( + #[from] + #[source] + std::net::AddrParseError, + ), #[error("IO error")] IO( #[from] @@ -105,6 +111,7 @@ pub enum Ability { */ DirectTcpV1, /** + * UNSTABLE; NOT IMPLEMENTED! * Try to connect directly to the other side via UDT. * * This supersedes [`Ability::DirectTcpV1`] because it has several advantages: @@ -156,7 +163,8 @@ pub struct Hints { pub relay: Vec, } -#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Hash)] +#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Hash, derive_more::Display)] +#[display(fmt = "tcp://{}:{}", hostname, port)] pub struct DirectHint { // DirectHint also contains a `priority` field, but it is underspecified // and we won't use it @@ -165,6 +173,28 @@ pub struct DirectHint { pub port: u16, } +use std::convert::{TryFrom, TryInto}; + +impl TryFrom<&DirectHint> for std::net::IpAddr { + type Error = std::net::AddrParseError; + fn try_from(hint: &DirectHint) -> Result { + hint.hostname.parse() + } +} + +impl TryFrom<&DirectHint> for std::net::SocketAddr { + type Error = std::net::AddrParseError; + /** This does not do the obvious thing and also implicitly maps all V4 addresses into V6 */ + fn try_from(hint: &DirectHint) -> Result { + let addr = hint.try_into()?; + let addr = match addr { + std::net::IpAddr::V4(v4) => std::net::IpAddr::V6(v4.to_ipv6_mapped()), + std::net::IpAddr::V6(_) => addr, + }; + Ok(std::net::SocketAddr::new(addr, hint.port)) + } +} + #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] enum HostType { Direct, @@ -195,6 +225,48 @@ impl FromStr for RelayUrl { } } +/** + * Bind to a port with SO_REUSEADDR, connect to the destination and then hide the blood behind a pretty [`async_std::net::TcpStream`] + * + * We want an `async_std::net::TcpStream`, but with SO_REUSEADDR set. + * The former is just a wrapper around `async_io::Async`, of which we + * copy the `connect` method to add a statement that will set the socket flag. + * See https://github.com/smol-rs/async-net/issues/20. + */ +async fn connect_custom( + local_addr: &socket2::SockAddr, + dest_addr: &socket2::SockAddr, +) -> std::io::Result { + let socket = socket2::Socket::new(socket2::Domain::IPV6, socket2::Type::STREAM, None)?; + socket.set_nonblocking(true)?; + /* Set our custum options */ + socket.set_reuse_address(true)?; + socket.set_reuse_port(true)?; + + socket.bind(local_addr)?; + + /* Initiate connect */ + match socket.connect(dest_addr) { + Ok(_) => {}, + #[cfg(unix)] + Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {}, + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {}, + Err(err) => return Err(err), + } + + let stream = async_io::Async::new(std::net::TcpStream::from(socket))?; + /* The stream becomes writable when connected. */ + stream.writable().await?; + + /* Check if there was an error while connecting. */ + stream + .get_ref() + .take_error() + .and_then(|maybe_err| maybe_err.map_or(Ok(()), Result::Err))?; + /* Convert our mess to `async_std::net::TcpStream */ + Ok(stream.into_inner()?.into()) +} + /** * Initialize a relay handshake * @@ -209,18 +281,45 @@ pub async fn init( /* Detect our local IP addresses if the ability is enabled */ if abilities.contains(&Ability::DirectTcpV1) { - listener = Some(TcpListener::bind("[::]:0").await?); - let port = listener.as_ref().unwrap().local_addr()?.port(); + /* Bind a port and find out which addresses it has */ + let socket = + socket2::Socket::new(socket2::Domain::IPV6, socket2::Type::STREAM, None).unwrap(); + socket.set_nonblocking(false).unwrap(); + socket.set_reuse_address(true).unwrap(); + socket.set_reuse_port(true).unwrap(); + + socket + .bind(&"[::]:0".parse::().unwrap().into()) + .unwrap(); + + let port = socket.local_addr().unwrap().as_socket().unwrap().port(); + + /* Do the same thing again, but this time open a listener on that port. + * This sadly doubles the number of hints, but the method above doesn't work + * for systems which don't have any firewalls. + */ + let socket2 = TcpListener::bind("[::]:0").await?; + let port2 = socket2.local_addr().unwrap().port(); our_hints.direct_tcp.extend( get_if_addrs::get_if_addrs()? .iter() .filter(|iface| !iface.is_loopback()) - .map(|ip| DirectHint { - hostname: ip.ip().to_string(), - port, - }), + .flat_map(|ip| + /* TODO replace with array once into_iter works as it should */ + vec![ + DirectHint { + hostname: ip.ip().to_string(), + port, + }, + DirectHint { + hostname: ip.ip().to_string(), + port: port2, + }, + ].into_iter()), ); + + listener = Some((socket, socket2)); } if abilities.contains(&Ability::RelayV1) { @@ -231,7 +330,7 @@ pub async fn init( } Ok(TransitConnector { - listener, + sockets: listener, our_abilities: Arc::new(abilities), our_hints: Arc::new(our_hints), }) @@ -245,8 +344,11 @@ pub async fn init( * are protocol agnostic. */ pub struct TransitConnector { - /* Only `Some` if direct-tcp-v1 ability has been enabled. */ - listener: Option, + /* Only `Some` if direct-tcp-v1 ability has been enabled. + * The first socket is the port from which we will start connection attempts. + * For in case the user is behind no firewalls, we must also listen to the second socket. + */ + sockets: Option<(socket2::Socket, TcpListener)>, our_abilities: Arc>, our_hints: Arc, } @@ -271,7 +373,7 @@ impl TransitConnector { their_hints: Arc, ) -> Result { let Self { - listener, + sockets, our_abilities, our_hints, } = self; @@ -286,23 +388,29 @@ impl TransitConnector { our_hints, their_abilities, their_hints, - listener, + sockets, ) .filter_map(|result| async { match result { Ok(val) => Some(val), Err(err) => { - log::debug!("Some leader handshake failed: {}", err); + log::debug!("Some leader handshake failed: {:?}", err); None }, } }), ); - let (mut transit, host_type) = connection_stream - .next() - .await - .ok_or(TransitConnectError::Handshake)?; + let (mut transit, host_type) = async_std::future::timeout( + std::time::Duration::from_secs(60), + connection_stream.next(), + ) + .await + .map_err(|_| { + log::debug!("`leader_connect` timed out"); + TransitConnectError::Handshake + })? + .ok_or(TransitConnectError::Handshake)?; if host_type == HostType::Relay && our_abilities.contains(&Ability::DirectTcpV1) { log::debug!( @@ -358,7 +466,7 @@ impl TransitConnector { their_hints: Arc, ) -> Result { let Self { - listener, + sockets, our_abilities, our_hints, } = self; @@ -372,13 +480,13 @@ impl TransitConnector { our_hints, their_abilities, their_hints, - listener, + sockets, ) .filter_map(|result| async { match result { Ok(val) => Some(val), Err(err) => { - log::debug!("Some follower handshake failed: {}", err); + log::debug!("Some follower handshake failed: {:?}", err); None }, } @@ -402,7 +510,10 @@ impl TransitConnector { ); Ok(transit) }, - Ok(None) | Err(_) => Err(TransitConnectError::Handshake), + Ok(None) | Err(_) => { + log::debug!("`follower_connect` timed out"); + Err(TransitConnectError::Handshake) + }, }; /* Cancel all remaining non-finished handshakes. We could send "nevermind" to explicitly tell @@ -427,88 +538,155 @@ impl TransitConnector { transit_key: Arc>, our_abilities: Arc>, our_hints: Arc, - _their_abilities: Arc>, + their_abilities: Arc>, their_hints: Arc, - listener: Option, - ) -> impl Stream> { - assert!(listener.is_some() == our_abilities.contains(&Ability::DirectTcpV1)); + socket: Option<(socket2::Socket, TcpListener)>, + ) -> impl Stream> + 'static { + assert!(socket.is_some() == our_abilities.contains(&Ability::DirectTcpV1)); // 8. listen for connections on the port and simultaneously try connecting to the peer port. let tside = Arc::new(hex::encode(rand::random::<[u8; 8]>())); - /* Process the peer's information first. We only take up to 20 items to prevent DOS, - * this should be absolutely more than enough. + /* Iterator of futures yielding a connection. They'll be then mapped with the handshake, collected into + * a Vec and polled concurrently. */ - let mut hosts: HashSet<(HostType, DirectHint)> = HashSet::new(); - hosts.extend( - their_hints - .direct_tcp - .iter() - .map(|hint| (HostType::Direct, hint.clone())) - .chain( + use futures::future::BoxFuture; + type BoxIterator = Box>; + type ConnectorFuture = + BoxFuture<'static, Result<(TcpStream, HostType), TransitHandshakeError>>; + // type ConnectorIterator = Box>; + let mut connectors: BoxIterator = Box::new(std::iter::empty()); + + /* Create direct connection sockets, if we support it. If peer doesn't support it, their list of hints will + * be empty and no entries will be pushed. + */ + let socket2 = if let Some((socket, socket2)) = socket { + let local_addr = Arc::new(socket.local_addr().unwrap()); + dbg!(&their_hints.direct_tcp); + /* Connect to each hint of the peer */ + connectors = Box::new( + connectors.chain( their_hints - .relay - .iter() - .map(|hint| (HostType::Relay, hint.clone())), - ) - .take(20), - ); - hosts.extend( - our_hints - .relay - .iter() - .map(|hint| (HostType::Relay, hint.clone())), - ); - - /* Now try to find a connection */ - let mut iterator: Box> = Box::new( - hosts - .into_iter() - .map(|host| { - let transit_key = transit_key.clone(); - let tside = tside.clone(); - async move { - let mut direct_host_iter = format!("{}:{}", host.1.hostname, host.1.port) - .to_socket_addrs() - .unwrap(); - let direct_host = direct_host_iter.next().unwrap(); + .direct_tcp + .clone() + .into_iter() + /* Nobody should have that many IP addresses, even with NATing */ + .take(10) + .map(move |hint| { + let local_addr = local_addr.clone(); + async move { + let dest_addr = std::net::SocketAddr::try_from(&hint)?; + log::debug!("Connecting directly to {}", dest_addr); + let socket = connect_custom(&local_addr, &dest_addr.into()).await?; + log::debug!("Connected to {}!", dest_addr); + Ok((socket, HostType::Direct)) + } + }) + .map(|fut| Box::pin(fut) as ConnectorFuture), + ), + ) as BoxIterator; + Some(socket2) + } else { + None + }; - let transit = TcpStream::connect(direct_host) + /* Relay hints. Make sure that both sides adverize it, since it is fine to support it without providing own hints. */ + if our_abilities.contains(&Ability::RelayV1) && their_abilities.contains(&Ability::RelayV1) + { + connectors = Box::new( + connectors.chain( + /* TODO maybe take 2 at random instead of always the first two? */ + /* TODO also deduplicate the results list */ + our_hints + .relay + .clone() + .into_iter() + .take(2) + .chain( + their_hints + .relay + .clone() + .into_iter() + .take(2) + ) + .map(|host| async move { + log::debug!("Connecting to relay {}", host); + let transit = TcpStream::connect((host.hostname.as_str(), host.port)) .err_into::() - .and_then(|socket| { - handshake_exchange(is_leader, tside, socket, host.0, transit_key) - }) .await?; + log::debug!("Connected to {}!", host); - Ok((transit, host.0)) + Ok((transit, HostType::Relay)) + }) + .map(|fut| Box::pin(fut) as ConnectorFuture) + ), + ) as BoxIterator; + } + + /* Do a handshake on all our found connections */ + let transit_key2 = transit_key.clone(); + let tside2 = tside.clone(); + let mut connectors = Box::new( + connectors + .map(move |fut| { + let transit_key = transit_key2.clone(); + let tside = tside2.clone(); + async move { + let (socket, host_type) = fut.await?; + let transit = + handshake_exchange(is_leader, tside, socket, host_type, transit_key) + .await?; + Ok((transit, host_type)) } }) - .map(futures::stream::once) - .map(|stream| stream.left_stream()), - ); - - /* If we allowed the other side to make direct connections to us, we have a listening - * socket that will attempt to complete handshakes - */ - if let Some(listener) = listener { - let transit_key = transit_key.clone(); - let tside = tside.clone(); - let stream = futures::stream::unfold(listener, |listener| async move { - let next = listener.accept().await; - Some((next, listener)) - }) - .err_into::() - .and_then(move |(socket, _listener)| { - /* We need to clone again because it will be moved into the returned future */ - let transit_key = transit_key.clone(); - let tside = tside.clone(); - handshake_exchange(is_leader, tside, socket, HostType::Direct, transit_key) - }) - .map(|result| result.map(|transit| (transit, HostType::Direct))); - iterator = Box::new(iterator.chain(std::iter::once(stream.right_stream()))) + .map(|fut| { + Box::pin(fut) as BoxFuture> + }), + ) + as BoxIterator>>; + + /* Also listen on some port just in case. */ + if let Some(socket2) = socket2 { + connectors = Box::new( + connectors.chain( + std::iter::once(async move { + let transit_key = transit_key.clone(); + let tside = tside.clone(); + let connect = || async { + let (stream, peer) = socket2.accept().await?; + log::debug!("Got connection from {}!", peer); + let transit = handshake_exchange( + is_leader, + tside.clone(), + stream, + HostType::Direct, + transit_key.clone(), + ) + .await?; + Result::<_, TransitHandshakeError>::Ok((transit, HostType::Direct)) + }; + loop { + match connect().await { + Ok(success) => break Ok(success), + Err(err) => { + log::debug!( + "Some handshake failed on the listening port: {:?}", + err + ); + continue; + }, + } + } + }) + .map(|fut| { + Box::pin(fut) + as BoxFuture> + }), + ), + ) + as BoxIterator>>; } - - futures::stream::select_all(iterator.map(Box::pin)) + connectors.collect::>() } } From 0141e1ca70a8272b7c88fb11f3ae8f0d416b10f0 Mon Sep 17 00:00:00 2001 From: piegames Date: Sat, 18 Sep 2021 01:21:25 +0200 Subject: [PATCH 5/9] Transit: fix wrong error message on sender hangup --- src/transit.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transit.rs b/src/transit.rs index 64a8b0da..dc4573cc 100644 --- a/src/transit.rs +++ b/src/transit.rs @@ -734,7 +734,12 @@ impl Transit { // 2. read that many bytes into an array (or a vector?) let mut buffer = Vec::with_capacity(length); - socket.take(length as u64).read_to_end(&mut buffer).await?; + let len = socket.take(length as u64).read_to_end(&mut buffer).await?; + use std::io::{Error, ErrorKind}; + ensure!( + len == length, + Error::new(ErrorKind::UnexpectedEof, "failed to read whole message") + ); buffer }; From 07677477a89593d0c477c12dd03b8077f4a490f4 Mon Sep 17 00:00:00 2001 From: piegames Date: Sat, 18 Sep 2021 23:14:27 +0200 Subject: [PATCH 6/9] Transit: add support for STUN --- Cargo.toml | 4 +- src/transit.rs | 164 +++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 149 insertions(+), 19 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 40da2656..51715e44 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,7 +37,7 @@ futures_ringbuf = "0.3.1" tar = "0.4.33" chrono = "0.4.19" -derive_more = { version = "0.99.0", default-features = false, features = ["display", "deref"] } +derive_more = { version = "0.99.0", default-features = false, features = ["display", "deref", "from"] } thiserror = "1.0.24" futures = "0.3.12" @@ -46,6 +46,8 @@ async-tungstenite = { version = "0.14.0", features = ["async-std-runtime", "asyn async-io = "1.6.0" socket2 = "0.4.1" libc = "0.2.101" +stun_codec = "0.1.13" +bytecodec = "0.4.15" # for "bin" feature clap = { version = "2.33.3", optional = true } diff --git a/src/transit.rs b/src/transit.rs index dc4573cc..229d0172 100644 --- a/src/transit.rs +++ b/src/transit.rs @@ -267,6 +267,96 @@ async fn connect_custom( Ok(stream.into_inner()?.into()) } +/** Perform a STUN query to get the external IP address */ +async fn get_external_ip() -> std::io::Result<(std::net::SocketAddr, TcpStream)> { + let mut socket = connect_custom( + &"[::]:0".parse::().unwrap().into(), + &"stun.stunprotocol.org:3478" + .to_socket_addrs()? + /* If you find yourself behind a NAT66, open an issue */ + .find(|x| x.is_ipv4()) + /* TODO add a helper method to stdlib for this */ + .map(|addr| match addr { + std::net::SocketAddr::V4(v4) => std::net::SocketAddr::new( + std::net::IpAddr::V6(v4.ip().to_ipv6_mapped()), + v4.port(), + ), + std::net::SocketAddr::V6(_) => unreachable!(), + }) + .unwrap() + .into(), + ) + .await?; + + use bytecodec::{DecodeExt, EncodeExt}; + use std::net::{SocketAddr, ToSocketAddrs}; + use stun_codec::{ + rfc5389::{ + self, + attributes::{MappedAddress, Software, XorMappedAddress}, + Attribute, + }, + Message, MessageClass, MessageDecoder, MessageEncoder, TransactionId, + }; + + fn get_binding_request() -> Result, bytecodec::Error> { + use rand::Rng; + let random_bytes = rand::thread_rng().gen::<[u8; 12]>(); + + let mut message = Message::new( + MessageClass::Request, + rfc5389::methods::BINDING, + TransactionId::new(random_bytes), + ); + + message.add_attribute(Attribute::Software(Software::new( + "magic-wormhole-rust".to_owned(), + )?)); + + // Encodes the message + let mut encoder = MessageEncoder::new(); + let bytes = encoder.encode_into_bytes(message.clone())?; + Ok(bytes) + } + + fn decode_address(buf: &[u8]) -> Result { + let mut decoder = MessageDecoder::::new(); + let decoded = decoder.decode_from_bytes(buf)??; + + println!("Decoded message: {:?}", decoded); + + let external_addr1 = decoded + .get_attribute::() + .map(|x| x.address()); + //let external_addr2 = decoded.get_attribute::().map(|x|x.address()); + let external_addr3 = decoded + .get_attribute::() + .map(|x| x.address()); + let external_addr = external_addr1 + // .or(external_addr2) + .or(external_addr3); + let external_addr = external_addr.unwrap(); + + Ok(external_addr) + } + + /* Connect the plugs */ + + socket + .write_all(get_binding_request().expect("TODO").as_ref()) + .await?; + + let mut buf = [0u8; 256]; + /* Read header first */ + socket.read_exact(&mut buf[..20]).await?; + let len: u16 = u16::from_be_bytes([buf[2], buf[3]]); + /* Read the rest of the message */ + socket.read_exact(&mut buf[20..][..len as usize]).await?; + let external_addr = decode_address(&buf[..20 + len as usize]).expect("TODO"); + + Ok((external_addr, socket)) +} + /** * Initialize a relay handshake * @@ -279,28 +369,49 @@ pub async fn init( let mut our_hints = Hints::default(); let mut listener = None; - /* Detect our local IP addresses if the ability is enabled */ + /* Detect our IP addresses if the ability is enabled */ if abilities.contains(&Ability::DirectTcpV1) { - /* Bind a port and find out which addresses it has */ - let socket = - socket2::Socket::new(socket2::Domain::IPV6, socket2::Type::STREAM, None).unwrap(); - socket.set_nonblocking(false).unwrap(); - socket.set_reuse_address(true).unwrap(); - socket.set_reuse_port(true).unwrap(); - - socket - .bind(&"[::]:0".parse::().unwrap().into()) - .unwrap(); - - let port = socket.local_addr().unwrap().as_socket().unwrap().port(); + /* Do a STUN query to get our public IP. If it works, we must reuse the same socket (port) + * so that we will be NATted to the same port again. If it doesn't, simply bind a new socket + * and use that instead. + */ + let socket: MaybeConnectedSocket = match get_external_ip().await { + Ok((external_ip, stream)) => { + log::debug!("Our external IP address is {}", external_ip); + our_hints.direct_tcp.push(DirectHint { + hostname: external_ip.ip().to_string(), + port: external_ip.port(), + }); + stream.into() + }, + Err(err) => { + log::debug!("Failed to get external address via STUN, {}", err); + let socket = + socket2::Socket::new(socket2::Domain::IPV6, socket2::Type::STREAM, None) + .unwrap(); + socket.set_nonblocking(false).unwrap(); + socket.set_reuse_address(true).unwrap(); + socket.set_reuse_port(true).unwrap(); + + socket + .bind(&"[::]:0".parse::().unwrap().into()) + .unwrap(); + + socket.into() + }, + }; - /* Do the same thing again, but this time open a listener on that port. + /* Get a second socket, but this time open a listener on that port. * This sadly doubles the number of hints, but the method above doesn't work - * for systems which don't have any firewalls. + * for systems which don't have any firewalls. Also, this time we can't reuse + * the port. In theory, we could, but it really confused the kernel to the point + * of `accept` calls never returning again. */ let socket2 = TcpListener::bind("[::]:0").await?; - let port2 = socket2.local_addr().unwrap().port(); + /* Find our ports, iterate all our local addresses, combine them with the ports and that's our hints */ + let port = socket.local_addr()?.as_socket().unwrap().port(); + let port2 = socket2.local_addr()?.port(); our_hints.direct_tcp.extend( get_if_addrs::get_if_addrs()? .iter() @@ -336,6 +447,23 @@ pub async fn init( }) } +#[derive(derive_more::From)] +enum MaybeConnectedSocket { + #[from] + Socket(socket2::Socket), + #[from] + Stream(TcpStream), +} + +impl MaybeConnectedSocket { + fn local_addr(&self) -> std::io::Result { + match &self { + Self::Socket(socket) => socket.local_addr(), + Self::Stream(stream) => Ok(stream.local_addr()?.into()), + } + } +} + /** * A partially set up [`Transit`] connection. * @@ -348,7 +476,7 @@ pub struct TransitConnector { * The first socket is the port from which we will start connection attempts. * For in case the user is behind no firewalls, we must also listen to the second socket. */ - sockets: Option<(socket2::Socket, TcpListener)>, + sockets: Option<(MaybeConnectedSocket, TcpListener)>, our_abilities: Arc>, our_hints: Arc, } @@ -540,7 +668,7 @@ impl TransitConnector { our_hints: Arc, their_abilities: Arc>, their_hints: Arc, - socket: Option<(socket2::Socket, TcpListener)>, + socket: Option<(MaybeConnectedSocket, TcpListener)>, ) -> impl Stream> + 'static { assert!(socket.is_some() == our_abilities.contains(&Ability::DirectTcpV1)); From 324d100b08f271eccee0773285c9ec4fa4932c45 Mon Sep 17 00:00:00 2001 From: piegames Date: Sat, 18 Sep 2021 23:39:40 +0200 Subject: [PATCH 7/9] Fix Windows compilation --- src/transit.rs | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/src/transit.rs b/src/transit.rs index 229d0172..dd32585c 100644 --- a/src/transit.rs +++ b/src/transit.rs @@ -225,6 +225,30 @@ impl FromStr for RelayUrl { } } +fn set_socket_opts(socket: &socket2::Socket) -> std::io::Result<()> { + socket.set_nonblocking(true)?; + + /* See https://stackoverflow.com/a/14388707/6094756. + * On most BSD and Linux systems, we need both REUSEADDR and REUSEPORT; + * and if they don't support the latter we won't compile. + * On Windows, there is only REUSEADDR but it does what we want. + */ + socket.set_reuse_address(true)?; + #[cfg(all(unix, not(any(target_os = "solaris", target_os = "illumos"))))] + { + socket.set_reuse_port(true)?; + } + #[cfg(not(any( + all(unix, not(any(target_os = "solaris", target_os = "illumos"))), + target_os = "windows" + )))] + { + compile_error!("Your system is not supported yet, please raise an error"); + } + + Ok(()) +} + /** * Bind to a port with SO_REUSEADDR, connect to the destination and then hide the blood behind a pretty [`async_std::net::TcpStream`] * @@ -238,10 +262,8 @@ async fn connect_custom( dest_addr: &socket2::SockAddr, ) -> std::io::Result { let socket = socket2::Socket::new(socket2::Domain::IPV6, socket2::Type::STREAM, None)?; - socket.set_nonblocking(true)?; /* Set our custum options */ - socket.set_reuse_address(true)?; - socket.set_reuse_port(true)?; + set_socket_opts(&socket)?; socket.bind(local_addr)?; @@ -389,9 +411,7 @@ pub async fn init( let socket = socket2::Socket::new(socket2::Domain::IPV6, socket2::Type::STREAM, None) .unwrap(); - socket.set_nonblocking(false).unwrap(); - socket.set_reuse_address(true).unwrap(); - socket.set_reuse_port(true).unwrap(); + set_socket_opts(&socket)?; socket .bind(&"[::]:0".parse::().unwrap().into()) From 4810f27b7ae43470a7fbc37cba99944654e89b38 Mon Sep 17 00:00:00 2001 From: piegames Date: Sun, 19 Sep 2021 00:17:18 +0200 Subject: [PATCH 8/9] Code cleanup --- src/transfer/messages.rs | 34 ++++++++++------- src/transit.rs | 82 ++++++++++++++++++++++------------------ 2 files changed, 67 insertions(+), 49 deletions(-) diff --git a/src/transfer/messages.rs b/src/transfer/messages.rs index 5078ed31..fa5dffda 100644 --- a/src/transfer/messages.rs +++ b/src/transfer/messages.rs @@ -1,13 +1,13 @@ //! Over-the-wire messages for the file transfer (including transit) //! //! The transit protocol does not specify how to deliver the information to -//! the other side, so it is up to the file transfer to do that. +//! the other side, so it is up to the file transfer to do that. hfoo use crate::transit::{self, Ability, DirectHint}; use serde_derive::{Deserialize, Serialize}; #[cfg(test)] use serde_json::json; -use std::path::PathBuf; +use std::{collections::HashSet, path::PathBuf}; /** * The type of message exchanged over the wormhole for this protocol @@ -131,15 +131,17 @@ impl From for Vec { impl Into for Vec { fn into(self) -> transit::Hints { - let mut direct_tcp = Vec::new(); - let mut relay = Vec::new(); + let mut direct_tcp = HashSet::new(); + let mut relay = HashSet::new(); /* There is only one "relay hint", though it may contain multiple * items. Yes, this is inconsistent and weird, watch your step. */ for hint in self { match hint { - Hint::DirectTcpV1(hint) => direct_tcp.push(hint), + Hint::DirectTcpV1(hint) => { + direct_tcp.insert(hint); + }, Hint::DirectUdtV1(_) => unimplemented!(), Hint::RelayV1(RelayHint { hints }) => relay.extend(hints), } @@ -163,22 +165,24 @@ pub enum Hint { } impl Hint { - pub fn new_direct_tcp(priority: f32, hostname: &str, port: u16) -> Self { + pub fn new_direct_tcp(_priority: f32, hostname: &str, port: u16) -> Self { Hint::DirectTcpV1(DirectHint { hostname: hostname.to_string(), port, }) } - pub fn new_direct_udt(priority: f32, hostname: &str, port: u16) -> Self { + pub fn new_direct_udt(_priority: f32, hostname: &str, port: u16) -> Self { Hint::DirectUdtV1(DirectHint { hostname: hostname.to_string(), port, }) } - pub fn new_relay(h: Vec) -> Self { - Hint::RelayV1(RelayHint { hints: h }) + pub fn new_relay(h: HashSet) -> Self { + Hint::RelayV1(RelayHint { + hints: h.into_iter().collect(), + }) } } @@ -196,10 +200,14 @@ mod test { let abilities = vec![Ability::DirectTcpV1, Ability::RelayV1]; let hints = vec![ Hint::new_direct_tcp(0.0, "192.168.1.8", 46295), - Hint::new_relay(vec![DirectHint { - hostname: "magic-wormhole-transit.debian.net".to_string(), - port: 4001, - }]), + Hint::new_relay( + vec![DirectHint { + hostname: "magic-wormhole-transit.debian.net".to_string(), + port: 4001, + }] + .into_iter() + .collect(), + ), ]; let t = crate::transfer::PeerMessage::new_transit(abilities, hints); assert_eq!(t.serialize(), "{\"transit\":{\"abilities-v1\":[{\"type\":\"direct-tcp-v1\"},{\"type\":\"relay-v1\"}],\"hints-v1\":[{\"hostname\":\"192.168.1.8\",\"port\":46295,\"type\":\"direct-tcp-v1\"},{\"hints\":[{\"hostname\":\"magic-wormhole-transit.debian.net\",\"port\":4001}],\"type\":\"relay-v1\"}]}}") diff --git a/src/transit.rs b/src/transit.rs index dd32585c..e060aa31 100644 --- a/src/transit.rs +++ b/src/transit.rs @@ -159,8 +159,8 @@ impl Ability { #[derive(Clone, Debug, Default)] pub struct Hints { - pub direct_tcp: Vec, - pub relay: Vec, + pub direct_tcp: HashSet, + pub relay: HashSet, } #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Hash, derive_more::Display)] @@ -289,8 +289,26 @@ async fn connect_custom( Ok(stream.into_inner()?.into()) } +#[derive(Debug, thiserror::Error)] +enum StunError { + #[error("No V4 addresses were found for the selected STUN server")] + ServerIsV4Only, + #[error("IO error")] + IO( + #[from] + #[source] + std::io::Error, + ), + #[error("Malformed STUN packet")] + Codec( + #[from] + #[source] + bytecodec::Error, + ), +} + /** Perform a STUN query to get the external IP address */ -async fn get_external_ip() -> std::io::Result<(std::net::SocketAddr, TcpStream)> { +async fn get_external_ip() -> Result<(std::net::SocketAddr, TcpStream), StunError> { let mut socket = connect_custom( &"[::]:0".parse::().unwrap().into(), &"stun.stunprotocol.org:3478" @@ -305,7 +323,7 @@ async fn get_external_ip() -> std::io::Result<(std::net::SocketAddr, TcpStream)> ), std::net::SocketAddr::V6(_) => unreachable!(), }) - .unwrap() + .ok_or(StunError::ServerIsV4Only)? .into(), ) .await?; @@ -364,9 +382,7 @@ async fn get_external_ip() -> std::io::Result<(std::net::SocketAddr, TcpStream)> /* Connect the plugs */ - socket - .write_all(get_binding_request().expect("TODO").as_ref()) - .await?; + socket.write_all(get_binding_request()?.as_ref()).await?; let mut buf = [0u8; 256]; /* Read header first */ @@ -374,7 +390,7 @@ async fn get_external_ip() -> std::io::Result<(std::net::SocketAddr, TcpStream)> let len: u16 = u16::from_be_bytes([buf[2], buf[3]]); /* Read the rest of the message */ socket.read_exact(&mut buf[20..][..len as usize]).await?; - let external_addr = decode_address(&buf[..20 + len as usize]).expect("TODO"); + let external_addr = decode_address(&buf[..20 + len as usize])?; Ok((external_addr, socket)) } @@ -400,7 +416,7 @@ pub async fn init( let socket: MaybeConnectedSocket = match get_external_ip().await { Ok((external_ip, stream)) => { log::debug!("Our external IP address is {}", external_ip); - our_hints.direct_tcp.push(DirectHint { + our_hints.direct_tcp.insert(DirectHint { hostname: external_ip.ip().to_string(), port: external_ip.port(), }); @@ -454,7 +470,7 @@ pub async fn init( } if abilities.contains(&Ability::RelayV1) { - our_hints.relay.push(DirectHint { + our_hints.relay.insert(DirectHint { hostname: relay_url.host.clone(), port: relay_url.port, }); @@ -702,7 +718,6 @@ impl TransitConnector { type BoxIterator = Box>; type ConnectorFuture = BoxFuture<'static, Result<(TcpStream, HostType), TransitHandshakeError>>; - // type ConnectorIterator = Box>; let mut connectors: BoxIterator = Box::new(std::iter::empty()); /* Create direct connection sockets, if we support it. If peer doesn't support it, their list of hints will @@ -710,7 +725,6 @@ impl TransitConnector { */ let socket2 = if let Some((socket, socket2)) = socket { let local_addr = Arc::new(socket.local_addr().unwrap()); - dbg!(&their_hints.direct_tcp); /* Connect to each hint of the peer */ connectors = Box::new( connectors.chain( @@ -741,33 +755,29 @@ impl TransitConnector { /* Relay hints. Make sure that both sides adverize it, since it is fine to support it without providing own hints. */ if our_abilities.contains(&Ability::RelayV1) && their_abilities.contains(&Ability::RelayV1) { + /* Collect intermediate into HashSet for deduplication */ + let relay_hints = our_hints + .relay + .clone() + .into_iter() + .take(2) + .chain(their_hints.relay.clone().into_iter().take(2)) + .collect::>(); connectors = Box::new( connectors.chain( - /* TODO maybe take 2 at random instead of always the first two? */ - /* TODO also deduplicate the results list */ - our_hints - .relay - .clone() - .into_iter() - .take(2) - .chain( - their_hints - .relay - .clone() - .into_iter() - .take(2) - ) - .map(|host| async move { - log::debug!("Connecting to relay {}", host); - let transit = TcpStream::connect((host.hostname.as_str(), host.port)) - .err_into::() - .await?; - log::debug!("Connected to {}!", host); + relay_hints + .into_iter() + .map(|host| async move { + log::debug!("Connecting to relay {}", host); + let transit = TcpStream::connect((host.hostname.as_str(), host.port)) + .err_into::() + .await?; + log::debug!("Connected to {}!", host); - Ok((transit, HostType::Relay)) - }) - .map(|fut| Box::pin(fut) as ConnectorFuture) - ), + Ok((transit, HostType::Relay)) + }) + .map(|fut| Box::pin(fut) as ConnectorFuture), + ), ) as BoxIterator; } From 3e3dd4ee592a25cfaa630c3edcb8149df9dc2a03 Mon Sep 17 00:00:00 2001 From: piegames Date: Tue, 21 Sep 2021 01:10:32 +0200 Subject: [PATCH 9/9] Add timeout to STUN and switch to self-hosted server --- src/transit.rs | 56 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/src/transit.rs b/src/transit.rs index e060aa31..7a81a31f 100644 --- a/src/transit.rs +++ b/src/transit.rs @@ -29,6 +29,10 @@ use xsalsa20poly1305::aead::{Aead, NewAead}; /// ULR to a default hosted relay server. Please don't abuse or DOS. pub const DEFAULT_RELAY_SERVER: &str = "tcp:transit.magic-wormhole.io:4001"; +// No need to make public, it's hard-coded anyways (: +// Open an issue if you want an API for this +// Use for non-production testing +const PUBLIC_STUN_SERVER: &str = "stun.piegames.de:3478"; #[derive(Debug)] pub struct TransitKey; @@ -293,6 +297,8 @@ async fn connect_custom( enum StunError { #[error("No V4 addresses were found for the selected STUN server")] ServerIsV4Only, + #[error("Connection timed out")] + Timeout, #[error("IO error")] IO( #[from] @@ -311,7 +317,7 @@ enum StunError { async fn get_external_ip() -> Result<(std::net::SocketAddr, TcpStream), StunError> { let mut socket = connect_custom( &"[::]:0".parse::().unwrap().into(), - &"stun.stunprotocol.org:3478" + &PUBLIC_STUN_SERVER .to_socket_addrs()? /* If you find yourself behind a NAT66, open an issue */ .find(|x| x.is_ipv4()) @@ -413,29 +419,35 @@ pub async fn init( * so that we will be NATted to the same port again. If it doesn't, simply bind a new socket * and use that instead. */ - let socket: MaybeConnectedSocket = match get_external_ip().await { - Ok((external_ip, stream)) => { - log::debug!("Our external IP address is {}", external_ip); - our_hints.direct_tcp.insert(DirectHint { - hostname: external_ip.ip().to_string(), - port: external_ip.port(), - }); - stream.into() - }, - Err(err) => { - log::debug!("Failed to get external address via STUN, {}", err); - let socket = - socket2::Socket::new(socket2::Domain::IPV6, socket2::Type::STREAM, None) + let socket: MaybeConnectedSocket = + match async_std::future::timeout(std::time::Duration::from_secs(4), get_external_ip()) + .await + .map_err(|_| StunError::Timeout) + { + Ok(Ok((external_ip, stream))) => { + log::debug!("Our external IP address is {}", external_ip); + our_hints.direct_tcp.insert(DirectHint { + hostname: external_ip.ip().to_string(), + port: external_ip.port(), + }); + stream.into() + }, + // TODO replace with .flatten() once stable + // https://github.com/rust-lang/rust/issues/70142 + Err(err) | Ok(Err(err)) => { + log::debug!("Failed to get external address via STUN, {}", err); + let socket = + socket2::Socket::new(socket2::Domain::IPV6, socket2::Type::STREAM, None) + .unwrap(); + set_socket_opts(&socket)?; + + socket + .bind(&"[::]:0".parse::().unwrap().into()) .unwrap(); - set_socket_opts(&socket)?; - - socket - .bind(&"[::]:0".parse::().unwrap().into()) - .unwrap(); - socket.into() - }, - }; + socket.into() + }, + }; /* Get a second socket, but this time open a listener on that port. * This sadly doubles the number of hints, but the method above doesn't work