diff --git a/src/endpoint.rs b/src/endpoint.rs index 5c6acb4c..4baaf906 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -137,33 +137,8 @@ impl Endpoint { let mut builder = quinn::Endpoint::builder(); let _ = builder.listen(config.server.clone()); - let local_addr = local_addr.into(); - let (quic_endpoint, quic_incoming) = builder.bind(&local_addr)?; - let local_addr = quic_endpoint.local_addr().map_err(EndpointError::Socket)?; - - let (message_tx, message_rx) = mpsc::channel(STANDARD_CHANNEL_SIZE); - let (connection_tx, connection_rx) = mpsc::channel(STANDARD_CHANNEL_SIZE); - let (disconnection_tx, disconnection_rx) = mpsc::channel(STANDARD_CHANNEL_SIZE); - - #[cfg(feature = "no-igd")] - let (termination_tx, _) = broadcast::channel(1); - - #[cfg(not(feature = "no-igd"))] - let (termination_tx, termination_rx) = broadcast::channel(1); - - let connection_pool = ConnectionPool::new(); - - let mut endpoint = Self { - local_addr, - public_addr: None, - quic_endpoint, - message_tx: message_tx.clone(), - disconnection_tx: disconnection_tx.clone(), - config, - termination_tx, - connection_pool: connection_pool.clone(), - connection_deduplicator: ConnectionDeduplicator::new(), - }; + let (mut endpoint, quic_incoming, channels) = + Self::build_endpoint(local_addr.into(), config, builder)?; let contact = endpoint.connect_to_any(contacts).await; let public_addr = endpoint.resolve_public_addr(contact).await?; @@ -174,9 +149,9 @@ impl Endpoint { PORT_FORWARD_TIMEOUT, forward_port( public_addr.port(), - local_addr, + endpoint.local_addr(), endpoint.config.upnp_lease_duration, - termination_rx, + channels.termination.1, ), ) .await @@ -185,10 +160,10 @@ impl Endpoint { listen_for_incoming_connections( quic_incoming, - connection_pool, - message_tx, - connection_tx, - disconnection_tx, + endpoint.connection_pool.clone(), + channels.message.0.clone(), + channels.connection.0, + channels.disconnection.0.clone(), endpoint.clone(), ); @@ -204,9 +179,9 @@ impl Endpoint { Ok(( endpoint, - IncomingConnections(connection_rx), - IncomingMessages(message_rx), - DisconnectionEvents(disconnection_rx), + IncomingConnections(channels.connection.1), + IncomingMessages(channels.message.1), + DisconnectionEvents(channels.disconnection.1), contact, )) } @@ -222,29 +197,38 @@ impl Endpoint { ) -> Result { let config = InternalConfig::try_from_config(config)?; - let local_addr = local_addr.into(); - let (quic_endpoint, _) = quinn::Endpoint::builder().bind(&local_addr)?; + let (endpoint, _, _) = + Self::build_endpoint(local_addr.into(), config, quinn::Endpoint::builder())?; + + Ok(endpoint) + } + + // A private helper for initialising an endpoint. + fn build_endpoint( + local_addr: SocketAddr, + config: InternalConfig, + builder: quinn::EndpointBuilder, + ) -> Result<(Self, quinn::Incoming, Channels), quinn::EndpointError> { + let (quic_endpoint, quic_incoming) = builder.bind(&local_addr)?; let local_addr = quic_endpoint .local_addr() - .map_err(ClientEndpointError::Socket)?; + .map_err(quinn::EndpointError::Socket)?; - let (message_tx, _) = mpsc::channel(STANDARD_CHANNEL_SIZE); - let (disconnection_tx, _) = mpsc::channel(STANDARD_CHANNEL_SIZE); - let (termination_tx, _) = broadcast::channel(1); + let channels = Channels::new(); let endpoint = Self { local_addr, public_addr: None, quic_endpoint, - message_tx, - disconnection_tx, + message_tx: channels.message.0.clone(), + disconnection_tx: channels.disconnection.0.clone(), config, - termination_tx, + termination_tx: channels.termination.0.clone(), connection_pool: ConnectionPool::new(), connection_deduplicator: ConnectionDeduplicator::new(), }; - Ok(endpoint) + Ok((endpoint, quic_incoming, channels)) } /// Endpoint local address @@ -648,6 +632,26 @@ impl Endpoint { } } +// a private helper struct for passing a bunch of channel-related things +type Msg = (SocketAddr, Bytes); +struct Channels { + connection: (MpscSender, MpscReceiver), + message: (MpscSender, MpscReceiver), + disconnection: (MpscSender, MpscReceiver), + termination: (Sender<()>, broadcast::Receiver<()>), +} + +impl Channels { + fn new() -> Self { + Self { + connection: mpsc::channel(STANDARD_CHANNEL_SIZE), + message: mpsc::channel(STANDARD_CHANNEL_SIZE), + disconnection: mpsc::channel(STANDARD_CHANNEL_SIZE), + termination: broadcast::channel(1), + } + } +} + #[cfg(test)] mod tests { use super::Endpoint;