Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
refactor(all): rename Id trait to ConnId
- wraps the return type of the generate function in a Result
- updates tests and doc-tests
- implements ConnId for XorName
  • Loading branch information
lionel-faber authored and joshuef committed Aug 3, 2021
1 parent 10c8ea0 commit 738c0cd
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 71 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Expand Up @@ -20,7 +20,9 @@ serde_json = "1.0.59"
structopt = "~0.3.15"
thiserror = "1.0.23"
webpki = "~0.21.3"
tiny-keccak = "2.0.2"
tracing = "~0.1.26"
xor_name = "1.2.1"

[dependencies.backoff]
version = "0.3.0"
Expand Down Expand Up @@ -62,7 +64,7 @@ tracing = "~0.1.26"

[dev-dependencies]
anyhow = "1.0.36"
rand = "~0.8.4"
rand = "~0.7.3"
tracing-test = "0.1"

[dev-dependencies.tiny-keccak]
Expand Down
12 changes: 6 additions & 6 deletions examples/p2p_node.rs
Expand Up @@ -14,16 +14,16 @@

use anyhow::Result;
use bytes::Bytes;
use qp2p::{Config, Id, QuicP2p};
use qp2p::{Config, ConnId, QuicP2p};
use std::env;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};

#[derive(Default, Ord, PartialEq, PartialOrd, Eq, Clone, Copy)]
struct ConnId(pub [u8; 32]);
struct XId(pub [u8; 32]);

impl Id for ConnId {
fn generate(_socket_addr: &SocketAddr) -> Self {
ConnId(rand::random())
impl ConnId for XId {
fn generate(_socket_addr: &SocketAddr) -> Result<Self, Box<dyn std::error::Error>> {
Ok(XId(rand::random()))
}
}

Expand All @@ -36,7 +36,7 @@ async fn main() -> Result<()> {
let args: Vec<String> = env::args().collect();

// instantiate QuicP2p with custom config
let qp2p: QuicP2p<ConnId> = QuicP2p::with_config(
let qp2p: QuicP2p<XId> = QuicP2p::with_config(
Some(Config {
local_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
idle_timeout_msec: Some(1000 * 3600), // 1 hour idle timeout.
Expand Down
46 changes: 23 additions & 23 deletions src/api.rs
Expand Up @@ -10,7 +10,7 @@
use super::{
bootstrap_cache::BootstrapCache,
config::{Config, SerialisableCertificate},
connection_pool::Id,
connection_pool::ConnId,
connections::DisconnectionEvents,
endpoint::{Endpoint, IncomingConnections, IncomingMessages},
error::{Error, Result},
Expand All @@ -33,7 +33,7 @@ const MAIDSAFE_DOMAIN: &str = "maidsafe.net";

/// Main QuicP2p instance to communicate with QuicP2p using an async API
#[derive(Debug, Clone)]
pub struct QuicP2p<I: Id> {
pub struct QuicP2p<I: ConnId> {
local_addr: SocketAddr,
allow_random_port: bool,
bootstrap_cache: BootstrapCache,
Expand All @@ -43,7 +43,7 @@ pub struct QuicP2p<I: Id> {
phantom: PhantomData<I>,
}

impl<I: Id> QuicP2p<I> {
impl<I: ConnId> QuicP2p<I> {
/// Construct `QuicP2p` with supplied parameters, ready to be used.
/// If config is not specified it'll call `Config::read_or_construct_default()`
///
Expand All @@ -55,23 +55,23 @@ impl<I: Id> QuicP2p<I> {
/// # Example
///
/// ```
/// use qp2p::{QuicP2p, Config, Id};
/// use qp2p::{QuicP2p, Config, ConnId};
/// use std::net::{IpAddr, Ipv4Addr, SocketAddr};
///
/// # #[derive(Default, Ord, PartialEq, PartialOrd, Eq, Clone, Copy)]
/// # struct ConnId(pub [u8; 32]);
/// # struct XId(pub [u8; 32]);
/// #
/// # impl Id for ConnId {
/// # fn generate(_socket_addr: &SocketAddr) -> Self {
/// # ConnId(rand::random())
/// # impl ConnId for XId {
/// # fn generate(_socket_addr: &SocketAddr) -> Result<Self, Box<dyn std::error::Error>> {
/// # Ok(XId(rand::random()))
/// # }
/// # }
///
/// let mut config = Config::default();
/// config.local_ip = Some(IpAddr::V4(Ipv4Addr::LOCALHOST));
/// config.local_port = Some(3000);
/// let hcc = &["127.0.0.1:8080".parse().unwrap()];
/// let quic_p2p = QuicP2p::<ConnId>::with_config(Some(config), hcc, true).expect("Error initializing QuicP2p");
/// let quic_p2p = QuicP2p::<XId>::with_config(Some(config), hcc, true).expect("Error initializing QuicP2p");
/// ```
pub fn with_config(
cfg: Option<Config>,
Expand Down Expand Up @@ -154,15 +154,15 @@ impl<I: Id> QuicP2p<I> {
/// # Example
///
/// ```
/// use qp2p::{QuicP2p, Config, Error, Id};
/// use qp2p::{QuicP2p, Config, Error, ConnId};
/// use std::net::{IpAddr, Ipv4Addr, SocketAddr};
///
/// # #[derive(Default, Ord, PartialEq, PartialOrd, Eq, Clone, Copy)]
/// # struct ConnId(pub [u8; 32]);
/// # struct XId(pub [u8; 32]);
/// #
/// # impl Id for ConnId {
/// # fn generate(_socket_addr: &SocketAddr) -> Self {
/// # ConnId(rand::random())
/// # impl ConnId for XId {
/// # fn generate(_socket_addr: &SocketAddr) -> Result<Self, Box<dyn std::error::Error>> {
/// # Ok(XId(rand::random()))
/// # }
/// # }
///
Expand All @@ -171,12 +171,12 @@ impl<I: Id> QuicP2p<I> {
/// let mut config = Config::default();
/// config.local_ip = Some(IpAddr::V4(Ipv4Addr::LOCALHOST));
/// config.local_port = Some(3000);
/// let mut quic_p2p = QuicP2p::<ConnId>::with_config(Some(config.clone()), Default::default(), true)?;
/// let mut quic_p2p = QuicP2p::<XId>::with_config(Some(config.clone()), Default::default(), true)?;
/// let (mut endpoint, _, _, _) = quic_p2p.new_endpoint().await?;
/// let peer_addr = endpoint.socket_addr();
///
/// config.local_port = Some(3001);
/// let mut quic_p2p = QuicP2p::<ConnId>::with_config(Some(config), &[peer_addr], true)?;
/// let mut quic_p2p = QuicP2p::<XId>::with_config(Some(config), &[peer_addr], true)?;
/// let endpoint = quic_p2p.bootstrap().await?;
/// Ok(())
/// }
Expand Down Expand Up @@ -212,23 +212,23 @@ impl<I: Id> QuicP2p<I> {
/// # Example
///
/// ```
/// use qp2p::{QuicP2p, Config, Error, Id};
/// use qp2p::{QuicP2p, Config, Error, ConnId};
/// use std::net::{IpAddr, Ipv4Addr, SocketAddr};
///
/// # #[derive(Default, Ord, PartialEq, PartialOrd, Eq, Clone, Copy)]
/// # struct ConnId(pub [u8; 32]);
/// # struct XId(pub [u8; 32]);
/// #
/// # impl Id for ConnId {
/// # fn generate(_socket_addr: &SocketAddr) -> Self {
/// # ConnId(rand::random())
/// # impl ConnId for XId {
/// # fn generate(_socket_addr: &SocketAddr) -> Result<Self, Box<dyn std::error::Error>> {
/// # Ok(XId(rand::random()))
/// # }
/// # }
///
/// #[tokio::main]
/// async fn main() -> Result<(), Error> {
/// let mut config = Config::default();
/// config.local_ip = Some(IpAddr::V4(Ipv4Addr::LOCALHOST));
/// let mut quic_p2p = QuicP2p::<ConnId>::with_config(Some(config.clone()), Default::default(), true)?;
/// let mut quic_p2p = QuicP2p::<XId>::with_config(Some(config.clone()), Default::default(), true)?;
/// let (endpoint, incoming_connections, incoming_messages, disconnections) = quic_p2p.new_endpoint().await?;
/// Ok(())
/// }
Expand Down Expand Up @@ -300,7 +300,7 @@ impl<I: Id> QuicP2p<I> {
let bootstrapped_peer = successful_connection.connection.remote_address();
endpoint
.add_new_connection_to_pool(successful_connection)
.await;
.await?;
Ok(bootstrapped_peer)
}

Expand Down
65 changes: 48 additions & 17 deletions src/connection_pool.rs
Expand Up @@ -9,16 +9,18 @@

use std::{collections::BTreeMap, net::SocketAddr, sync::Arc};

use tiny_keccak::{Hasher, Sha3};
use tokio::sync::RwLock;
use xor_name::XorName;

// Pool for keeping open connections. Pooled connections are associated with a `ConnectionRemover`
// which can be used to remove them from the pool.
#[derive(Clone)]
pub(crate) struct ConnectionPool<T: Id> {
store: Arc<RwLock<Store<T>>>,
pub(crate) struct ConnectionPool<I: ConnId> {
store: Arc<RwLock<Store<I>>>,
}

impl<T: Id> ConnectionPool<T> {
impl<I: ConnId> ConnectionPool<I> {
pub fn new() -> Self {
Self {
store: Arc::new(RwLock::new(Store::default())),
Expand All @@ -27,10 +29,10 @@ impl<T: Id> ConnectionPool<T> {

pub async fn insert(
&self,
id: T,
id: I,
addr: SocketAddr,
conn: quinn::Connection,
) -> ConnectionRemover<T> {
) -> ConnectionRemover<I> {
let mut store = self.store.write().await;

let key = Key {
Expand Down Expand Up @@ -59,7 +61,7 @@ impl<T: Id> ConnectionPool<T> {
}

#[allow(unused)]
pub async fn has_id(&self, id: &T) -> bool {
pub async fn has_id(&self, id: &I) -> bool {
let store = self.store.read().await;

store.id_map.contains_key(id)
Expand All @@ -77,14 +79,28 @@ impl<T: Id> ConnectionPool<T> {

keys_to_remove
.iter()
.filter_map(|key| store.key_map.remove(&key).map(|entry| entry.0))
.filter_map(|key| store.key_map.remove(key).map(|entry| entry.0))
.collect::<Vec<_>>()
}

pub async fn get_by_id(&self, addr: &I) -> Option<(quinn::Connection, ConnectionRemover<I>)> {
let store = self.store.read().await;

let (conn, key) = store.id_map.get(addr)?;

let remover = ConnectionRemover {
store: self.store.clone(),
key: *key,
id: *addr,
};

Some((conn.clone(), remover))
}

pub async fn get_by_addr(
&self,
addr: &SocketAddr,
) -> Option<(quinn::Connection, ConnectionRemover<T>)> {
) -> Option<(quinn::Connection, ConnectionRemover<I>)> {
let store = self.store.read().await;

// Efficiently fetch the first entry whose key is equal to `key`.
Expand All @@ -106,13 +122,13 @@ impl<T: Id> ConnectionPool<T> {

// Handle for removing a connection from the pool.
#[derive(Clone)]
pub(crate) struct ConnectionRemover<T: Id> {
store: Arc<RwLock<Store<T>>>,
pub(crate) struct ConnectionRemover<I: ConnId> {
store: Arc<RwLock<Store<I>>>,
key: Key,
id: T,
id: I,
}

impl<T: Id> ConnectionRemover<T> {
impl<I: ConnId> ConnectionRemover<I> {
// Remove the connection from the pool.
pub async fn remove(&self) {
let mut store = self.store.write().await;
Expand All @@ -123,22 +139,37 @@ impl<T: Id> ConnectionRemover<T> {
pub fn remote_addr(&self) -> &SocketAddr {
&self.key.addr
}

pub fn id(&self) -> I {
self.id
}
}

#[derive(Default)]
struct Store<T: Id> {
id_map: BTreeMap<T, (quinn::Connection, Key)>,
key_map: BTreeMap<Key, (quinn::Connection, T)>,
struct Store<I: ConnId> {
id_map: BTreeMap<I, (quinn::Connection, Key)>,
key_map: BTreeMap<Key, (quinn::Connection, I)>,
id_gen: IdGen,
}

/// Unique key identifying a connection. Two connections will always have distict keys even if they
/// have the same socket address.
pub trait Id:
pub trait ConnId:
Clone + Copy + Eq + PartialEq + Ord + PartialOrd + Default + Send + Sync + 'static
{
/// Generate
fn generate(socket_addr: &SocketAddr) -> Self;
fn generate(socket_addr: &SocketAddr) -> Result<Self, Box<dyn std::error::Error>>;
}

impl ConnId for XorName {
fn generate(addr: &SocketAddr) -> Result<Self, Box<dyn std::error::Error>> {
let data = bincode::serialize(addr)?;
let mut hasher = Sha3::v256();
let mut output = [0u8; 32];
hasher.update(&data);
hasher.finalize(&mut output);
Ok(XorName(output))
}
}

// Unique key identifying a connection. Two connections will always have distict keys even if they
Expand Down
17 changes: 9 additions & 8 deletions src/connections.rs
Expand Up @@ -10,7 +10,7 @@
use crate::Endpoint;

use super::{
connection_pool::{ConnectionPool, ConnectionRemover, Id},
connection_pool::{ConnId, ConnectionPool, ConnectionRemover},
error::{Error, Result},
wire_msg::WireMsg,
};
Expand All @@ -23,7 +23,7 @@ use tracing::{trace, warn};

/// Connection instance to a node which can be used to send messages to it
#[derive(Clone)]
pub(crate) struct Connection<I: Id> {
pub(crate) struct Connection<I: ConnId> {
quic_conn: quinn::Connection,
remover: ConnectionRemover<I>,
}
Expand All @@ -39,7 +39,7 @@ impl DisconnectionEvents {
}
}

impl<I: Id> Connection<I> {
impl<I: ConnId> Connection<I> {
pub(crate) fn new(quic_conn: quinn::Connection, remover: ConnectionRemover<I>) -> Self {
Self { quic_conn, remover }
}
Expand Down Expand Up @@ -147,7 +147,7 @@ pub async fn send_msg(mut send_stream: &mut quinn::SendStream, msg: Bytes) -> Re
Ok(())
}

pub(super) fn listen_for_incoming_connections<I: Id>(
pub(super) fn listen_for_incoming_connections<I: ConnId>(
mut quinn_incoming: quinn::Incoming,
connection_pool: ConnectionPool<I>,
message_tx: Sender<(SocketAddr, Bytes)>,
Expand All @@ -166,7 +166,8 @@ pub(super) fn listen_for_incoming_connections<I: Id>(
..
}) => {
let peer_address = connection.remote_address();
let id = Id::generate(&peer_address);
let id = ConnId::generate(&peer_address)
.map_err(|err| Error::ConnectionIdGeneration(err.to_string()))?;
let pool_handle =
connection_pool.insert(id, peer_address, connection).await;
let _ = connection_tx.send(peer_address).await;
Expand All @@ -193,7 +194,7 @@ pub(super) fn listen_for_incoming_connections<I: Id>(
});
}

pub(super) fn listen_for_incoming_messages<I: Id>(
pub(super) fn listen_for_incoming_messages<I: ConnId>(
mut uni_streams: quinn::IncomingUniStreams,
mut bi_streams: quinn::IncomingBiStreams,
remover: ConnectionRemover<I>,
Expand Down Expand Up @@ -269,7 +270,7 @@ async fn read_on_uni_streams(
}

// Read messages sent by peer in a bidirectional stream.
async fn read_on_bi_streams<I: Id>(
async fn read_on_bi_streams<I: ConnId>(
bi_streams: &mut quinn::IncomingBiStreams,
peer_addr: SocketAddr,
message_tx: Sender<(SocketAddr, Bytes)>,
Expand Down Expand Up @@ -354,7 +355,7 @@ async fn handle_endpoint_echo_req(
Ok(())
}

async fn handle_endpoint_verification_req<I: Id>(
async fn handle_endpoint_verification_req<I: ConnId>(
peer_addr: SocketAddr,
addr_sent: SocketAddr,
send_stream: &mut quinn::SendStream,
Expand Down

0 comments on commit 738c0cd

Please sign in to comment.