Skip to content

Commit

Permalink
Use bounded channels in transport (#987)
Browse files Browse the repository at this point in the history
* Implement DialFuture

* Update with recommended changes to buffer size, `expect()` and `close()`
  • Loading branch information
mattrutherford committed Mar 28, 2019
1 parent 03ce6a6 commit 7549948
Showing 1 changed file with 49 additions and 20 deletions.
69 changes: 49 additions & 20 deletions core/src/transport/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,50 @@ use rw_stream_sink::RwStreamSink;
use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64};

lazy_static! {
static ref HUB: Mutex<FnvHashMap<NonZeroU64, mpsc::UnboundedSender<Channel<Bytes>>>> = Mutex::new(FnvHashMap::default());
static ref HUB: Mutex<FnvHashMap<NonZeroU64, mpsc::Sender<Channel<Bytes>>>> = Mutex::new(FnvHashMap::default());
}

/// Transport that supports `/memory/N` multiaddresses.
#[derive(Debug, Copy, Clone, Default)]
pub struct MemoryTransport;

/// Connection to a `MemoryTransport` currently being opened.
pub struct DialFuture {
sender: mpsc::Sender<Channel<Bytes>>,
channel_to_send: Option<Channel<Bytes>>,
channel_to_return: Option<Channel<Bytes>>,
}

impl Future for DialFuture {
type Item = Channel<Bytes>;
type Error = MemoryTransportError;

fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
if let Some(c) = self.channel_to_send.take() {
match self.sender.start_send(c) {
Err(_) => return Err(MemoryTransportError::Unreachable),
Ok(AsyncSink::NotReady(t)) => {
self.channel_to_send = Some(t);
return Ok(Async::NotReady)
},
_ => (),
}
}
match self.sender.close() {
Err(_) => Err(MemoryTransportError::Unreachable),
Ok(Async::NotReady) => Ok(Async::NotReady),
Ok(Async::Ready(_)) => Ok(Async::Ready(self.channel_to_return.take()
.expect("Future should not be polled again once complete"))),
}
}
}

impl Transport for MemoryTransport {
type Output = Channel<Bytes>;
type Error = MemoryTransportError;
type Listener = Listener;
type ListenerUpgrade = FutureResult<Self::Output, Self::Error>;
type Dial = FutureResult<Self::Output, Self::Error>;
type Dial = DialFuture;

fn listen_on(self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), TransportError<Self::Error>> {
let port = if let Ok(port) = parse_memory_addr(&addr) {
Expand All @@ -68,7 +99,7 @@ impl Transport for MemoryTransport {

let actual_addr = Protocol::Memory(port.get()).into();

let (tx, rx) = mpsc::unbounded();
let (tx, rx) = mpsc::channel(2);
match hub.entry(port) {
Entry::Occupied(_) => return Err(TransportError::Other(MemoryTransportError::Unreachable)),
Entry::Vacant(e) => e.insert(tx),
Expand All @@ -82,7 +113,7 @@ impl Transport for MemoryTransport {
Ok((listener, actual_addr))
}

fn dial(self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
fn dial(self, addr: Multiaddr) -> Result<DialFuture, TransportError<Self::Error>> {
let port = if let Ok(port) = parse_memory_addr(&addr) {
if let Some(port) = NonZeroU64::new(port) {
port
Expand All @@ -94,20 +125,18 @@ impl Transport for MemoryTransport {
};

let hub = HUB.lock();
let chan = if let Some(tx) = hub.get(&port) {
let (a_tx, a_rx) = mpsc::unbounded();
let (b_tx, b_rx) = mpsc::unbounded();
let a = RwStreamSink::new(Chan { incoming: a_rx, outgoing: b_tx });
let b = RwStreamSink::new(Chan { incoming: b_rx, outgoing: a_tx });
if tx.unbounded_send(b).is_err() {
return Err(TransportError::Other(MemoryTransportError::Unreachable));
}
a
if let Some(sender) = hub.get(&port) {
let (a_tx, a_rx) = mpsc::channel(4096);
let (b_tx, b_rx) = mpsc::channel(4096);
Ok(DialFuture {
sender: sender.clone(),
channel_to_send: Some(RwStreamSink::new(Chan { incoming: a_rx, outgoing: b_tx })),
channel_to_return: Some(RwStreamSink::new(Chan { incoming: b_rx, outgoing: a_tx })),

})
} else {
return Err(TransportError::Other(MemoryTransportError::Unreachable));
};

Ok(future::ok(chan))
Err(TransportError::Other(MemoryTransportError::Unreachable))
}
}

fn nat_traversal(&self, server: &Multiaddr, observed: &Multiaddr) -> Option<Multiaddr> {
Expand Down Expand Up @@ -145,7 +174,7 @@ pub struct Listener {
/// Port we're listening on.
port: NonZeroU64,
/// Receives incoming connections.
receiver: mpsc::UnboundedReceiver<Channel<Bytes>>,
receiver: mpsc::Receiver<Channel<Bytes>>,
}

impl Stream for Listener {
Expand Down Expand Up @@ -197,8 +226,8 @@ pub type Channel<T> = RwStreamSink<Chan<T>>;
///
/// Implements `Sink` and `Stream`.
pub struct Chan<T = Bytes> {
incoming: mpsc::UnboundedReceiver<T>,
outgoing: mpsc::UnboundedSender<T>,
incoming: mpsc::Receiver<T>,
outgoing: mpsc::Sender<T>,
}

impl<T> Stream for Chan<T> {
Expand Down

0 comments on commit 7549948

Please sign in to comment.