Skip to content

Commit

Permalink
refactor: Deduplicate endpoint building logic
Browse files Browse the repository at this point in the history
This is an internal-only refactor to deduplicate the code for building
an endpoint, between the `new` and `new_client` constructors.
  • Loading branch information
Chris Connelly committed Aug 26, 2021
1 parent 335bc6c commit 9537cfc
Showing 1 changed file with 50 additions and 46 deletions.
96 changes: 50 additions & 46 deletions src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,33 +137,8 @@ impl<I: ConnId> Endpoint<I> {
let mut builder = quinn::Endpoint::builder();
let _ = builder.listen(config.server.clone());

let local_addr = local_addr.into();
let (quic_endpoint, quic_incoming) = builder.bind(&local_addr)?;
let local_addr = quic_endpoint.local_addr().map_err(EndpointError::Socket)?;

let (message_tx, message_rx) = mpsc::channel(STANDARD_CHANNEL_SIZE);
let (connection_tx, connection_rx) = mpsc::channel(STANDARD_CHANNEL_SIZE);
let (disconnection_tx, disconnection_rx) = mpsc::channel(STANDARD_CHANNEL_SIZE);

#[cfg(feature = "no-igd")]
let (termination_tx, _) = broadcast::channel(1);

#[cfg(not(feature = "no-igd"))]
let (termination_tx, termination_rx) = broadcast::channel(1);

let connection_pool = ConnectionPool::new();

let mut endpoint = Self {
local_addr,
public_addr: None,
quic_endpoint,
message_tx: message_tx.clone(),
disconnection_tx: disconnection_tx.clone(),
config,
termination_tx,
connection_pool: connection_pool.clone(),
connection_deduplicator: ConnectionDeduplicator::new(),
};
let (mut endpoint, quic_incoming, channels) =
Self::build_endpoint(local_addr.into(), config, builder)?;

let contact = endpoint.connect_to_any(contacts).await;
let public_addr = endpoint.resolve_public_addr(contact).await?;
Expand All @@ -174,9 +149,9 @@ impl<I: ConnId> Endpoint<I> {
PORT_FORWARD_TIMEOUT,
forward_port(
public_addr.port(),
local_addr,
endpoint.local_addr(),
endpoint.config.upnp_lease_duration,
termination_rx,
channels.termination.1,
),
)
.await
Expand All @@ -185,10 +160,10 @@ impl<I: ConnId> Endpoint<I> {

listen_for_incoming_connections(
quic_incoming,
connection_pool,
message_tx,
connection_tx,
disconnection_tx,
endpoint.connection_pool.clone(),
channels.message.0.clone(),
channels.connection.0,
channels.disconnection.0.clone(),
endpoint.clone(),
);

Expand All @@ -204,9 +179,9 @@ impl<I: ConnId> Endpoint<I> {

Ok((
endpoint,
IncomingConnections(connection_rx),
IncomingMessages(message_rx),
DisconnectionEvents(disconnection_rx),
IncomingConnections(channels.connection.1),
IncomingMessages(channels.message.1),
DisconnectionEvents(channels.disconnection.1),
contact,
))
}
Expand All @@ -222,29 +197,38 @@ impl<I: ConnId> Endpoint<I> {
) -> Result<Self, ClientEndpointError> {
let config = InternalConfig::try_from_config(config)?;

let local_addr = local_addr.into();
let (quic_endpoint, _) = quinn::Endpoint::builder().bind(&local_addr)?;
let (endpoint, _, _) =
Self::build_endpoint(local_addr.into(), config, quinn::Endpoint::builder())?;

Ok(endpoint)
}

// A private helper for initialising an endpoint.
fn build_endpoint(
local_addr: SocketAddr,
config: InternalConfig,
builder: quinn::EndpointBuilder,
) -> Result<(Self, quinn::Incoming, Channels), quinn::EndpointError> {
let (quic_endpoint, quic_incoming) = builder.bind(&local_addr)?;
let local_addr = quic_endpoint
.local_addr()
.map_err(ClientEndpointError::Socket)?;
.map_err(quinn::EndpointError::Socket)?;

let (message_tx, _) = mpsc::channel(STANDARD_CHANNEL_SIZE);
let (disconnection_tx, _) = mpsc::channel(STANDARD_CHANNEL_SIZE);
let (termination_tx, _) = broadcast::channel(1);
let channels = Channels::new();

let endpoint = Self {
local_addr,
public_addr: None,
quic_endpoint,
message_tx,
disconnection_tx,
message_tx: channels.message.0.clone(),
disconnection_tx: channels.disconnection.0.clone(),
config,
termination_tx,
termination_tx: channels.termination.0.clone(),
connection_pool: ConnectionPool::new(),
connection_deduplicator: ConnectionDeduplicator::new(),
};

Ok(endpoint)
Ok((endpoint, quic_incoming, channels))
}

/// Endpoint local address
Expand Down Expand Up @@ -648,6 +632,26 @@ impl<I: ConnId> Endpoint<I> {
}
}

// a private helper struct for passing a bunch of channel-related things
type Msg = (SocketAddr, Bytes);
struct Channels {
connection: (MpscSender<SocketAddr>, MpscReceiver<SocketAddr>),
message: (MpscSender<Msg>, MpscReceiver<Msg>),
disconnection: (MpscSender<SocketAddr>, MpscReceiver<SocketAddr>),
termination: (Sender<()>, broadcast::Receiver<()>),
}

impl Channels {
fn new() -> Self {
Self {
connection: mpsc::channel(STANDARD_CHANNEL_SIZE),
message: mpsc::channel(STANDARD_CHANNEL_SIZE),
disconnection: mpsc::channel(STANDARD_CHANNEL_SIZE),
termination: broadcast::channel(1),
}
}
}

#[cfg(test)]
mod tests {
use super::Endpoint;
Expand Down

0 comments on commit 9537cfc

Please sign in to comment.