Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(iroh-net): simplify relay handshake #2164

Merged
merged 14 commits into from
Apr 15, 2024
394 changes: 200 additions & 194 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions iroh-net/src/defaults.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use url::Url;
use crate::relay::{RelayMap, RelayNode};

/// Hostname of the default NA relay.
pub const NA_RELAY_HOSTNAME: &str = "use1-1.derp.iroh.network.";
pub const NA_RELAY_HOSTNAME: &str = "use1-1.relay.iroh.network.";
/// Hostname of the default EU relay.
pub const EU_RELAY_HOSTNAME: &str = "euw1-1.derp.iroh.network.";
pub const EU_RELAY_HOSTNAME: &str = "euw1-1.relay.iroh.network.";

/// STUN port as defined by [RFC 8489](<https://www.rfc-editor.org/rfc/rfc8489#section-18.6>)
pub const DEFAULT_RELAY_STUN_PORT: u16 = 3478;
Expand Down
102 changes: 12 additions & 90 deletions iroh-net/src/relay/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use anyhow::{anyhow, bail, ensure, Context, Result};
use anyhow::{anyhow, bail, ensure, Result};
use bytes::Bytes;
use futures::stream::Stream;
use futures::{Sink, SinkExt, StreamExt};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc;
Expand All @@ -15,10 +14,10 @@ use tracing::{debug, info_span, trace, Instrument};
use super::codec::PER_CLIENT_READ_QUEUE_DEPTH;
use super::{
codec::{
recv_frame, write_frame, DerpCodec, Frame, FrameType, MAX_PACKET_SIZE,
PER_CLIENT_SEND_QUEUE_DEPTH, PROTOCOL_VERSION,
write_frame, DerpCodec, Frame, MAX_PACKET_SIZE, PER_CLIENT_SEND_QUEUE_DEPTH,
PROTOCOL_VERSION,
},
types::{ClientInfo, RateLimiter, ServerInfo},
types::{ClientInfo, RateLimiter},
};

use crate::key::{PublicKey, SecretKey};
Expand Down Expand Up @@ -74,8 +73,6 @@ pub struct InnerClient {
/// JoinHandle for the [`ClientWriter`] task
writer_task: AbortingJoinHandle<Result<()>>,
reader_task: AbortingJoinHandle<()>,
/// [`PublicKey`] of the server we are connected to
server_public_key: PublicKey,
}

impl Client {
Expand Down Expand Up @@ -150,11 +147,6 @@ impl Client {
.ok();
self.inner.reader_task.abort();
}

/// The [`PublicKey`] of the [`super::server::Server`] this [`Client`] is connected with.
pub fn server_public_key(self) -> PublicKey {
self.inner.server_public_key
}
}

fn process_incoming_frame(frame: Frame) -> Result<ReceivedMessage> {
Expand Down Expand Up @@ -255,9 +247,6 @@ pub struct ClientBuilder {
reader: RelayReader,
writer: FramedWrite<Box<dyn AsyncWrite + Unpin + Send + Sync + 'static>, DerpCodec>,
local_addr: SocketAddr,
is_prober: bool,
server_public_key: Option<PublicKey>,
can_ack_pings: bool,
}

impl ClientBuilder {
Expand All @@ -272,84 +261,28 @@ impl ClientBuilder {
reader: FramedRead::new(reader, DerpCodec),
writer: FramedWrite::new(writer, DerpCodec),
local_addr,
is_prober: false,
server_public_key: None,
can_ack_pings: false,
}
}

pub fn prober(mut self, is_prober: bool) -> Self {
self.is_prober = is_prober;
self
}

// Set the expected server_public_key. If this is not what is sent by the
// [`super::server::Server`], it is an error.
pub fn server_public_key(mut self, key: Option<PublicKey>) -> Self {
self.server_public_key = key;
self
}

pub fn can_ack_pings(mut self, can_ack_pings: bool) -> Self {
self.can_ack_pings = can_ack_pings;
self
}

async fn server_handshake(&mut self) -> Result<(PublicKey, Option<RateLimiter>)> {
async fn server_handshake(&mut self) -> Result<Option<RateLimiter>> {
debug!("server_handshake: started");
let server_key = recv_server_key(&mut self.reader)
.await
.context("failed to receive server key")?;

debug!("server_handshake: received server_key: {:?}", server_key);

if let Some(expected_key) = &self.server_public_key {
if *expected_key != server_key {
bail!("unexpected server key, expected {expected_key:?} got {server_key:?}");
}
}
let client_info = ClientInfo {
version: PROTOCOL_VERSION,
can_ack_pings: self.can_ack_pings,
is_prober: self.is_prober,
mesh_key: None,
};
debug!("server_handshake: sending client_key: {:?}", &client_info);
let shared_secret = self.secret_key.shared(&server_key);
crate::relay::codec::send_client_key(
&mut self.writer,
&shared_secret,
&self.secret_key.public(),
&client_info,
)
.await?;

let Frame::ServerInfo { encrypted_message } =
recv_frame(FrameType::ServerInfo, &mut self.reader).await?
else {
bail!("expected server info");
};
let mut buf = encrypted_message.to_vec();
shared_secret.open(&mut buf)?;
let info: ServerInfo = postcard::from_bytes(&buf)?;
if info.version != PROTOCOL_VERSION {
bail!(
"incompatible protocol version, expected {PROTOCOL_VERSION}, got {}",
info.version
);
}
let rate_limiter = RateLimiter::new(
info.token_bucket_bytes_per_second,
info.token_bucket_bytes_burst,
)?;
crate::relay::codec::send_client_key(&mut self.writer, &self.secret_key, &client_info)
.await?;

// TODO: add some actual configuration
let rate_limiter = RateLimiter::new(0, 0)?;

debug!("server_handshake: done");
Ok((server_key, rate_limiter))
Ok(rate_limiter)
}

pub async fn build(mut self) -> Result<(Client, ClientReceiver)> {
// exchange information with the server
let (server_public_key, rate_limiter) = self.server_handshake().await?;
let rate_limiter = self.server_handshake().await?;

// create task to handle writing to the server
let (writer_sender, writer_recv) = mpsc::channel(PER_CLIENT_SEND_QUEUE_DEPTH);
Expand Down Expand Up @@ -411,7 +344,6 @@ impl ClientBuilder {
writer_channel: writer_sender,
writer_task: writer_task.into(),
reader_task: reader_task.into(),
server_public_key,
}),
};

Expand All @@ -423,16 +355,6 @@ impl ClientBuilder {
}
}

pub(crate) async fn recv_server_key<S: Stream<Item = anyhow::Result<Frame>> + Unpin>(
stream: S,
) -> Result<PublicKey> {
if let Frame::ServerKey { key } = recv_frame(FrameType::ServerKey, stream).await? {
Ok(key)
} else {
bail!("expected server key");
}
}

#[derive(derive_more::Debug, Clone)]
/// The type of message received by the [`Client`] from the [`super::server::Server`].
pub enum ReceivedMessage {
Expand Down
Loading
Loading