Skip to content

Commit

Permalink
feat(refactor): remove unnecessary mutex
Browse files Browse the repository at this point in the history
BREAKING CHANGE: updates some apis to be async
  • Loading branch information
joshuef committed Jun 6, 2021
1 parent 078918c commit 30814cb
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 40 deletions.
4 changes: 3 additions & 1 deletion src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,9 @@ impl QuicP2p {
})?
.0;
let bootstrapped_peer = successful_connection.connection.remote_address();
endpoint.add_new_connection_to_pool(successful_connection);
endpoint
.add_new_connection_to_pool(successful_connection)
.await;
Ok(bootstrapped_peer)
}

Expand Down
41 changes: 18 additions & 23 deletions src/connection_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,26 @@
// specific language governing permissions and limitations relating to use of the SAFE Network
// Software.

use std::{
collections::BTreeMap,
net::SocketAddr,
sync::{Arc, Mutex, PoisonError},
};
use std::{collections::BTreeMap, net::SocketAddr, sync::Arc};

use tokio::sync::RwLock;

// Pool for keeping open connections. Pooled connections are associated with a `ConnectionRemover`
// which can be used to remove them from the pool.
#[derive(Clone)]
pub(crate) struct ConnectionPool {
store: Arc<Mutex<Store>>,
store: Arc<RwLock<Store>>,
}

impl ConnectionPool {
pub fn new() -> Self {
Self {
store: Arc::new(Mutex::new(Store::default())),
store: Arc::new(RwLock::new(Store::default())),
}
}

pub fn insert(&self, addr: SocketAddr, conn: quinn::Connection) -> ConnectionRemover {
let mut store = self.store.lock().unwrap_or_else(PoisonError::into_inner);
pub async fn insert(&self, addr: SocketAddr, conn: quinn::Connection) -> ConnectionRemover {
let mut store = self.store.write().await;

let key = Key {
addr,
Expand All @@ -42,19 +40,19 @@ impl ConnectionPool {
}
}

pub fn has(&self, addr: &SocketAddr) -> bool {
let mut store = self.store.lock().unwrap_or_else(PoisonError::into_inner);
pub async fn has(&self, addr: &SocketAddr) -> bool {
let store = self.store.read().await;

// Efficiently fetch the first entry whose key is equal to `key` and check if it exists
store
.map
.range_mut(Key::min(*addr)..=Key::max(*addr))
.range(Key::min(*addr)..=Key::max(*addr))
.next()
.is_some()
}

pub fn remove(&self, addr: &SocketAddr) -> Vec<quinn::Connection> {
let mut store = self.store.lock().unwrap_or_else(PoisonError::into_inner);
pub async fn remove(&self, addr: &SocketAddr) -> Vec<quinn::Connection> {
let mut store = self.store.write().await;

let keys_to_remove = store
.map
Expand All @@ -69,14 +67,11 @@ impl ConnectionPool {
.collect::<Vec<_>>()
}

pub fn get(&self, addr: &SocketAddr) -> Option<(quinn::Connection, ConnectionRemover)> {
let mut store = self.store.lock().unwrap_or_else(PoisonError::into_inner);
pub async fn get(&self, addr: &SocketAddr) -> Option<(quinn::Connection, ConnectionRemover)> {
let store = self.store.read().await;

// Efficiently fetch the first entry whose key is equal to `key`.
let (key, conn) = store
.map
.range_mut(Key::min(*addr)..=Key::max(*addr))
.next()?;
let (key, conn) = store.map.range(Key::min(*addr)..=Key::max(*addr)).next()?;

let conn = conn.clone();
let remover = ConnectionRemover {
Expand All @@ -91,14 +86,14 @@ impl ConnectionPool {
// Handle for removing a connection from the pool.
#[derive(Clone)]
pub(crate) struct ConnectionRemover {
store: Arc<Mutex<Store>>,
store: Arc<RwLock<Store>>,
key: Key,
}

impl ConnectionRemover {
// Remove the connection from the pool.
pub fn remove(&self) {
let mut store = self.store.lock().unwrap_or_else(PoisonError::into_inner);
pub async fn remove(&self) {
let mut store = self.store.write().await;
let _ = store.map.remove(&self.key);
}

Expand Down
17 changes: 9 additions & 8 deletions src/connections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,31 @@ impl Connection {
}

pub async fn open_bi(&self) -> Result<(SendStream, RecvStream)> {
let (send_stream, recv_stream) = self.handle_error(self.quic_conn.open_bi().await)?;
let (send_stream, recv_stream) = self.handle_error(self.quic_conn.open_bi().await).await?;
Ok((SendStream::new(send_stream), RecvStream::new(recv_stream)))
}

/// Send message to peer using a uni-directional stream.
pub async fn send_uni(&self, msg: Bytes) -> Result<()> {
let mut send_stream = self.handle_error(self.quic_conn.open_uni().await)?;
self.handle_error(send_msg(&mut send_stream, msg).await)?;
let mut send_stream = self.handle_error(self.quic_conn.open_uni().await).await?;
self.handle_error(send_msg(&mut send_stream, msg).await)
.await?;

// We try to make sure the stream is gracefully closed and the bytes get sent,
// but if it was already closed (perhaps by the peer) then we
// don't remove the connection from the pool.
match send_stream.finish().await {
Ok(()) | Err(quinn::WriteError::Stopped(_)) => Ok(()),
Err(err) => {
self.handle_error(Err(err))?;
self.handle_error(Err(err)).await?;
Ok(())
}
}
}

fn handle_error<T, E>(&self, result: Result<T, E>) -> Result<T, E> {
async fn handle_error<T, E>(&self, result: Result<T, E>) -> Result<T, E> {
if result.is_err() {
self.remover.remove()
self.remover.remove().await
}

result
Expand Down Expand Up @@ -154,7 +155,7 @@ pub(super) fn listen_for_incoming_connections(
..
}) => {
let peer_address = connection.remote_address();
let pool_handle = connection_pool.insert(peer_address, connection);
let pool_handle = connection_pool.insert(peer_address, connection).await;
let _ = connection_tx.send(peer_address);
listen_for_incoming_messages(
uni_streams,
Expand Down Expand Up @@ -197,7 +198,7 @@ pub(super) fn listen_for_incoming_messages(

log::trace!("The connection to {:?} has been terminated.", src);
let _ = disconnection_tx.send(src);
remover.remove();
remover.remove().await;
});
}

Expand Down
24 changes: 16 additions & 8 deletions src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ impl Endpoint {
endpoint.connect_to(contact).await?;
let connection = endpoint
.get_connection(&contact)
.await
.ok_or(Error::MissingConnection)?;
let (mut send, mut recv) = connection.open_bi().await?;
send.send(WireMsg::EndpointVerificationReq(addr)).await?;
Expand Down Expand Up @@ -283,9 +284,10 @@ impl Endpoint {
}

/// Removes all existing connections to a given peer
pub fn disconnect_from(&self, peer_addr: &SocketAddr) -> Result<()> {
pub async fn disconnect_from(&self, peer_addr: &SocketAddr) -> Result<()> {
self.connection_pool
.remove(peer_addr)
.await
.iter()
.for_each(|conn| {
conn.close(0u8.into(), b"");
Expand Down Expand Up @@ -331,7 +333,7 @@ impl Endpoint {
/// from the pool and the subsequent call to `connect_to` is guaranteed to reopen new connection
/// too.
pub async fn connect_to(&self, node_addr: &SocketAddr) -> Result<()> {
if self.connection_pool.has(node_addr) {
if self.connection_pool.has(node_addr).await {
trace!("We are already connected to this peer: {}", node_addr);
}

Expand Down Expand Up @@ -369,7 +371,7 @@ impl Endpoint {

trace!("Successfully connected to peer: {}", node_addr);

self.add_new_connection_to_pool(new_conn);
self.add_new_connection_to_pool(new_conn).await;

self.connection_deduplicator
.complete(node_addr, Ok(()))
Expand Down Expand Up @@ -426,10 +428,11 @@ impl Endpoint {
Ok(new_connection)
}

pub(crate) fn add_new_connection_to_pool(&self, conn: quinn::NewConnection) {
pub(crate) async fn add_new_connection_to_pool(&self, conn: quinn::NewConnection) {
let guard = self
.connection_pool
.insert(conn.connection.remote_address(), conn.connection);
.insert(conn.connection.remote_address(), conn.connection)
.await;

listen_for_incoming_messages(
conn.uni_streams,
Expand All @@ -442,8 +445,8 @@ impl Endpoint {
}

/// Get an existing connection for the peer address.
pub(crate) fn get_connection(&self, peer_addr: &SocketAddr) -> Option<Connection> {
if let Some((conn, guard)) = self.connection_pool.get(peer_addr) {
pub(crate) async fn get_connection(&self, peer_addr: &SocketAddr) -> Option<Connection> {
if let Some((conn, guard)) = self.connection_pool.get(peer_addr).await {
trace!("Connection exists in the connection pool: {}", peer_addr);
Some(Connection::new(conn, guard))
} else {
Expand All @@ -459,14 +462,18 @@ impl Endpoint {
self.connect_to(peer_addr).await?;
let connection = self
.get_connection(peer_addr)
.await
.ok_or(Error::MissingConnection)?;
connection.open_bi().await
}

/// Sends a message to a peer. This will attempt to use an existing connection
/// to the destination peer. If a connection does not exist, this will fail with `Error::MissingConnection`
pub async fn try_send_message(&self, msg: Bytes, dest: &SocketAddr) -> Result<()> {
let connection = self.get_connection(dest).ok_or(Error::MissingConnection)?;
let connection = self
.get_connection(dest)
.await
.ok_or(Error::MissingConnection)?;
connection.send_uni(msg).await?;
Ok(())
}
Expand Down Expand Up @@ -502,6 +509,7 @@ impl Endpoint {
endpoint.connect_to(&node).await?;
let connection = endpoint
.get_connection(&node)
.await
.ok_or(Error::MissingConnection)?;
let (mut send_stream, mut recv_stream) = connection.open_bi().await?;
send_stream.send(WireMsg::EndpointEchoReq).await?;
Expand Down

0 comments on commit 30814cb

Please sign in to comment.