37 changes: 29 additions & 8 deletions src/endpoint.rs
Expand Up @@ -7,7 +7,7 @@
// specific language governing permissions and limitations relating to use of the SAFE Network
// Software.

use crate::connection_pool::Id;
use crate::connection_pool::ConnId;

use super::error::Error;
use super::wire_msg::WireMsg;
Expand Down Expand Up @@ -70,7 +70,7 @@ impl IncomingConnections {
/// Endpoint instance which can be used to create connections to peers,
/// and listen to incoming messages from other peers.
#[derive(Clone)]
pub struct Endpoint<I: Id> {
pub struct Endpoint<I: ConnId> {
local_addr: SocketAddr,
public_addr: Option<SocketAddr>,
quic_endpoint: quinn::Endpoint,
Expand All @@ -84,7 +84,7 @@ pub struct Endpoint<I: Id> {
connection_deduplicator: ConnectionDeduplicator,
}

impl<I: Id> std::fmt::Debug for Endpoint<I> {
impl<I: ConnId> std::fmt::Debug for Endpoint<I> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Endpoint")
.field("local_addr", &self.local_addr)
Expand All @@ -94,7 +94,7 @@ impl<I: Id> std::fmt::Debug for Endpoint<I> {
}
}

impl<I: Id> Endpoint<I> {
impl<I: ConnId> Endpoint<I> {
pub(crate) async fn new(
quic_endpoint: quinn::Endpoint,
quic_incoming: quinn::Incoming,
Expand Down Expand Up @@ -141,7 +141,7 @@ impl<I: Id> Endpoint<I> {
info!("Verifying provided public IP address");
endpoint.connect_to(contact).await?;
let connection = endpoint
.get_connection(&contact)
.get_connection(contact)
.await
.ok_or(Error::MissingConnection)?;
let (mut send, mut recv) = connection.open_bi().await?;
Expand Down Expand Up @@ -354,7 +354,7 @@ impl<I: Id> Endpoint<I> {

trace!("Successfully connected to peer: {}", node_addr);

self.add_new_connection_to_pool(final_conn).await;
self.add_new_connection_to_pool(final_conn).await?;

self.connection_deduplicator
.complete(node_addr, Ok(()))
Expand Down Expand Up @@ -452,8 +452,12 @@ impl<I: Id> Endpoint<I> {
Ok(new_connection)
}

pub(crate) async fn add_new_connection_to_pool(&self, conn: quinn::NewConnection) {
let id = Id::generate(&conn.connection.remote_address());
pub(crate) async fn add_new_connection_to_pool(
&self,
conn: quinn::NewConnection,
) -> Result<()> {
let id = ConnId::generate(&conn.connection.remote_address())
.map_err(|err| Error::ConnectionIdGeneration(err.to_string()))?;
let guard = self
.connection_pool
.insert(id, conn.connection.remote_address(), conn.connection)
Expand All @@ -467,6 +471,7 @@ impl<I: Id> Endpoint<I> {
self.disconnection_tx.clone(),
self.clone(),
);
Ok(())
}

/// Get an existing connection for the peer address.
Expand All @@ -479,6 +484,22 @@ impl<I: Id> Endpoint<I> {
}
}

/// Get the connection ID of an existing connection with the provided socket address
pub async fn get_connection_id(&self, addr: &SocketAddr) -> Option<I> {
self.connection_pool
.get_by_addr(addr)
.await
.map(|(_, remover)| remover.id())
}

/// Get the SocketAddr of a connection using the connection ID
pub async fn get_socket_addr_by_id(&self, addr: &I) -> Option<SocketAddr> {
self.connection_pool
.get_by_id(addr)
.await
.map(|(_, remover)| *remover.remote_addr())
}

/// Open a bi-directional peer with a given peer
pub async fn open_bidirectional_stream(
&self,
Expand Down
3 changes: 3 additions & 0 deletions src/error.rs
Expand Up @@ -130,4 +130,7 @@ pub enum Error {
/// Couldn't resolve Public IP address
#[error("Unresolved Public IP address")]
UnresolvedPublicIp,
/// Couldn't generate connection ID
#[error("Couldn't generate connection ID")]
ConnectionIdGeneration(String),
}
3 changes: 1 addition & 2 deletions src/lib.rs
Expand Up @@ -34,7 +34,6 @@
unused_parens,
while_true,
clippy::unicode_not_nfc,
clippy::wrong_pub_self_convention,
warnings
)]
#![warn(
Expand Down Expand Up @@ -65,7 +64,7 @@ mod wire_msg;

pub use api::QuicP2p;
pub use config::Config;
pub use connection_pool::Id;
pub use connection_pool::ConnId;
pub use connections::{DisconnectionEvents, RecvStream, SendStream};
pub use endpoint::{Endpoint, IncomingConnections, IncomingMessages};
pub use error::{Error, Result};
Expand Down
4 changes: 2 additions & 2 deletions src/tests/common.rs
Expand Up @@ -445,7 +445,7 @@ async fn multiple_connections_with_many_concurrent_messages() -> Result<()> {
let mut hash_results = BTreeSet::new();
send_endpoint.connect_to(&server_addr).await?;
for (index, message) in messages.iter().enumerate().take(num_messages_each) {
let _ = hash_results.insert(hash(&message));
let _ = hash_results.insert(hash(message));
info!("sender #{} sending message #{}", id, index);
send_endpoint
.send_message(message.clone(), &server_addr)
Expand Down Expand Up @@ -554,7 +554,7 @@ async fn multiple_connections_with_many_larger_concurrent_messages() -> Result<(

send_endpoint.connect_to(&server_addr).await?;
for (index, message) in messages.iter().enumerate().take(num_messages_each) {
let _ = hash_results.insert(hash(&message));
let _ = hash_results.insert(hash(message));

info!("sender #{} sending message #{}", id, index);
send_endpoint
Expand Down
8 changes: 4 additions & 4 deletions src/tests/mod.rs
Expand Up @@ -7,7 +7,7 @@
// specific language governing permissions and limitations relating to use of the SAFE Network
// Software.

use crate::{Config, Id, QuicP2p};
use crate::{Config, ConnId, QuicP2p};
use anyhow::Result;
use bytes::Bytes;
use std::{
Expand All @@ -17,9 +17,9 @@ use std::{

mod common;

impl Id for [u8; 32] {
fn generate(_socket_addr: &SocketAddr) -> Self {
rand::random()
impl ConnId for [u8; 32] {
fn generate(_socket_addr: &SocketAddr) -> Result<Self, Box<dyn std::error::Error>> {
Ok(rand::random())
}
}

Expand Down