Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: improve derp connection establishment #1631

Merged
merged 10 commits into from
Oct 16, 2023
154 changes: 105 additions & 49 deletions iroh-net/src/derp/http/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ use std::time::Duration;
use anyhow::bail;
use bytes::Bytes;
use futures::future::BoxFuture;
use futures::StreamExt;
use hyper::upgrade::{Parts, Upgraded};
use hyper::{header::UPGRADE, Body, Request};
use iroh_metrics::inc;
use rand::Rng;
use rustls::client::Resumption;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::oneshot;
Expand Down Expand Up @@ -47,8 +49,8 @@ pub enum ClientError {
#[error("error sending a packet")]
Send,
/// There was an error receiving a packet
#[error("error receiving a packet")]
Receive,
#[error("error receiving a packet: {0:?}")]
Receive(anyhow::Error),
/// There was a connection timeout error
#[error("connect timeout")]
ConnectTimeout,
Expand Down Expand Up @@ -88,6 +90,9 @@ pub enum ClientError {
/// The ping request timed out
#[error("ping timeout")]
PingTimeout,
/// The ping request was aborted
#[error("ping aborted")]
PingAborted,
/// This [`Client`] cannot acknowledge pings
#[error("cannot acknowledge pings")]
CannotAckPings,
Expand Down Expand Up @@ -137,6 +142,7 @@ struct InnerClient {
is_prober: bool,
server_public_key: Option<PublicKey>,
url: Option<Url>,
tls_connector: tokio_rustls::TlsConnector,
}

/// Build a Client.
Expand Down Expand Up @@ -245,6 +251,29 @@ impl ClientBuilder {
/// Will error if there is no region or no url set.
pub fn build(self, key: SecretKey) -> anyhow::Result<Client> {
anyhow::ensure!(self.get_region.is_some() || self.url.is_some(), "The `get_region` call back or `server_url` must be set so the Client knows how to dial the derp server.");

// TODO: review TLS config
let mut roots = rustls::RootCertStore::empty();
roots.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
let mut config = rustls::client::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth();
#[cfg(test)]
config
.dangerous()
.set_certificate_verifier(Arc::new(NoCertVerifier));

config.resumption = Resumption::default();

let tls_connector: tokio_rustls::TlsConnector = Arc::new(config).into();

Ok(Client {
inner: Arc::new(InnerClient {
secret_key: key,
Expand All @@ -260,6 +289,7 @@ impl ClientBuilder {
is_prober: self.is_prober,
server_public_key: self.server_public_key,
url: self.url,
tls_connector,
}),
})
}
Expand Down Expand Up @@ -460,30 +490,15 @@ impl Client {

let res = if self.use_https(derp_node.as_deref()) {
debug!("Starting TLS handshake");
// TODO: review TLS config
let mut roots = rustls::RootCertStore::empty();
roots.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
#[allow(unused_mut)]
let mut config = rustls::client::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(roots)
.with_no_client_auth();
#[cfg(test)]
config
.dangerous()
.set_certificate_verifier(Arc::new(NoCertVerifier));

let tls_connector: tokio_rustls::TlsConnector = Arc::new(config).into();

let hostname = self
.tls_servername(derp_node.as_deref())
.ok_or_else(|| ClientError::InvalidUrl("no tls servername".into()))?;
let tls_stream = tls_connector.connect(hostname, tcp_stream).await?;
let tls_stream = self
.inner
.tls_connector
.connect(hostname, tcp_stream)
.await?;
debug!("tls_connector connect success");
let (mut request_sender, connection) = hyper::client::conn::Builder::new()
.handshake(tls_stream)
Expand Down Expand Up @@ -569,7 +584,8 @@ impl Client {
derp_client.close().await;
return Err(ClientError::Send);
}
debug!("built");

trace!("connect_0 done");
Ok(derp_client)
}

Expand Down Expand Up @@ -624,24 +640,44 @@ impl Client {
if reg.nodes.is_empty() {
return Err(ClientError::NoNodeForTarget(target));
}
let mut first_err: Option<ClientError> = None;
// TODO (ramfox): these dials should probably happen in parallel, and we should return the
// first one to respond.
for node in reg.nodes.iter() {
if node.stun_only {
if first_err.is_none() {
first_err = Some(ClientError::StunOnlyNodesFound(target.clone()));
// usually 1 IPv4, 1 IPv6 and 2x http
const DIAL_PARALLELISM: usize = 4;

let this = self.clone();
let mut dials = futures::stream::iter(reg.nodes.clone().into_iter())
.map(|node| {
let this = this.clone();
let target = target.clone();
async move {
if node.stun_only {
return Err(ClientError::StunOnlyNodesFound(target));
}
let conn = this.dial_node(&node).await;
match conn {
Ok(conn) => Ok((conn, node)),
Err(e) => Err(e),
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optional, but isn't this this.dial_node(&node).await.map(|c| (c, node))?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

}
})
.buffer_unordered(DIAL_PARALLELISM);

let mut first_err = None;
while let Some(res) = dials.next().await {
match res {
Ok((conn, node)) => {
// return on the first successful one
trace!("dialed region");
return Ok((conn, node));
}
Err(e) => {
if first_err.is_none() {
first_err = Some(e);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to @rklaehn's comment below: unless you return a Vec of errors, you'll always be making an arbitrary choice. It is probably easier to log all the error cases here with tracing::error!(), or at least with warn!.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

}
}
continue;
}
let conn = self.dial_node(node).await;
match conn {
Ok(conn) => return Ok((conn, node.clone())),
Err(e) => first_err = Some(e),
}
}
let err = first_err.unwrap();
Err(err)

Err(first_err.unwrap())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems a bit arbitrary to just return the first error.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logging them at least now

}

/// Returns a TCP connection to node n, racing IPv4 and IPv6
Expand Down Expand Up @@ -704,6 +740,7 @@ impl Client {
node: &DerpNode,
dst_primary: UseIp,
) -> Result<TcpStream, ClientError> {
trace!("dial start: {:?}", dst_primary);
if matches!(dst_primary, UseIp::Ipv4(_)) && self.prefer_ipv6().await {
tokio::time::sleep(Duration::from_millis(200)).await;
// Start v4 dial
Expand Down Expand Up @@ -758,28 +795,37 @@ impl Client {
.map_err(ClientError::DialIO)?;
// TODO: ipv6 vs ipv4 specific connection

trace!("dial done: {:?}", dst_primary);
Ok(tcp_stream)
}

/// Send a ping to the server. Return once we get an expected pong.
///
/// There must be a task polling `recv_detail` to process the `pong` response.
pub async fn ping(&self) -> Result<(), ClientError> {
debug!("ping");
let (client, _) = self.connect().await?;
pub async fn ping(&self) -> Result<Duration, ClientError> {
let ping = rand::thread_rng().gen::<[u8; 8]>();
debug!("ping: {}", hex::encode(ping));
let (client, _) = self.connect().await?;

let start = Instant::now();
let (send, recv) = oneshot::channel();
self.register_ping(ping, send).await;
if client.send_ping(ping).await.is_err() {
self.close_for_reconnect().await;
let _ = self.unregister_ping(ping).await;
return Err(ClientError::Send);
}
if tokio::time::timeout(PING_TIMEOUT, recv).await.is_err() {
self.unregister_ping(ping).await;
return Err(ClientError::PingTimeout);
match tokio::time::timeout(PING_TIMEOUT, recv).await {
Ok(Ok(())) => Ok(start.elapsed()),
Err(_) => {
self.unregister_ping(ping).await;
Err(ClientError::PingTimeout)
}
Ok(Err(_)) => {
self.unregister_ping(ping).await;
Err(ClientError::PingAborted)
}
}
Ok(())
}

/// Send a pong back to the server.
Expand Down Expand Up @@ -834,6 +880,7 @@ impl Client {
}

if let ReceivedMessage::Pong(ping) = msg {
trace!("got pong: {}", hex::encode(ping));
if let Some(chan) = self.unregister_ping(ping).await {
if chan.send(()).is_err() {
warn!("pong recieved for ping {ping:?}, but the receiving channel was closed");
Expand All @@ -843,13 +890,13 @@ impl Client {
}
return Ok((msg, conn_gen));
}
Err(_) => {
Err(e) => {
self.close_for_reconnect().await;
if self.inner.is_closed.load(Ordering::SeqCst) {
return Err(ClientError::Closed);
}
// TODO(ramfox): more specific error?
return Err(ClientError::Receive);
return Err(ClientError::Receive(e));
}
}
}
Expand All @@ -874,7 +921,8 @@ impl Client {

/// Close the underlying derp connection. The next time the client takes some action that
/// requires a connection, it will call `connect`.
async fn close_for_reconnect(&self) {
pub async fn close_for_reconnect(&self) {
debug!("close for reconnect");
let mut client = self.inner.derp_client.lock().await;
if let Some(client) = client.take() {
client.close().await
Expand All @@ -887,6 +935,14 @@ impl Client {
self.close_for_reconnect().await;
}

/// Reports whether the underlying derp connection is currently established.
///
/// Yields `false` as soon as the client has been flagged as closed; otherwise
/// the answer is `true` exactly when a derp client is stored behind the
/// connection lock.
pub async fn is_connected(&self) -> bool {
let closed = self.inner.is_closed.load(Ordering::Relaxed);
if closed {
false
} else {
// Only take the lock when the client has not been closed.
let slot = self.inner.derp_client.lock().await;
slot.is_some()
}
}

/// Send a request to subscribe as a "watcher" on the server.
///
/// This returns the public key of the remote derp server that we have meshed to,
Expand Down
Loading
Loading