Skip to content

Commit

Permalink
feat: connection state
Browse files Browse the repository at this point in the history
  • Loading branch information
merklefruit committed Jan 24, 2024
1 parent deed07a commit 4ba2bf7
Show file tree
Hide file tree
Showing 12 changed files with 228 additions and 93 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions msg-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ repository.workspace = true
[dependencies]
futures.workspace = true
tokio.workspace = true
tokio-util.workspace = true
80 changes: 71 additions & 9 deletions msg-common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use std::{
pin::Pin,
task::{Context, Poll},
time::SystemTime,
};

use futures::{future::BoxFuture, Sink, SinkExt, Stream};
use tokio::sync::mpsc::{
self,
error::{SendError, TryRecvError, TrySendError},
Receiver, Sender,
error::{TryRecvError, TrySendError},
Receiver,
};

use futures::future::BoxFuture;
use tokio_util::sync::{PollSendError, PollSender};

/// Returns the current UNIX timestamp in microseconds.
#[inline]
Expand All @@ -36,8 +37,10 @@ pub mod constants {

/// A bounded, bi-directional channel for sending and receiving messages.
/// Relies on Tokio's [`mpsc`] channel.
///
/// Channel also implements the [`Stream`] and [`Sink`] traits for convenience.
pub struct Channel<S, R> {
tx: Sender<S>,
tx: PollSender<S>,
rx: Receiver<R>,
}

Expand All @@ -49,14 +52,21 @@ pub struct Channel<S, R> {
/// the tuple can be used to send messages of type `S` and receive messages of
/// type `R`. The second channel can be used to send messages of type `R` and
/// receive messages of type `S`.
pub fn channel<S, R>(tx_buffer: usize, rx_buffer: usize) -> (Channel<S, R>, Channel<R, S>) {
pub fn channel<S, R>(tx_buffer: usize, rx_buffer: usize) -> (Channel<S, R>, Channel<R, S>)
where
S: Send,
R: Send,
{
let (tx1, rx1) = mpsc::channel(tx_buffer);
let (tx2, rx2) = mpsc::channel(rx_buffer);

let tx1 = PollSender::new(tx1);
let tx2 = PollSender::new(tx2);

(Channel { tx: tx1, rx: rx2 }, Channel { tx: tx2, rx: rx1 })
}

impl<S, R> Channel<S, R> {
impl<S: Send + 'static, R> Channel<S, R> {
/// Sends a value, waiting until there is capacity.
///
/// A successful send occurs when it is determined that the other end of the
Expand All @@ -66,7 +76,7 @@ impl<S, R> Channel<S, R> {
/// value of `Ok` does not mean that the data will be received. It is
/// possible for the corresponding receiver to hang up immediately after
/// this function returns `Ok`.
pub async fn send(&mut self, msg: S) -> Result<(), SendError<S>> {
pub async fn send(&mut self, msg: S) -> Result<(), PollSendError<S>> {
self.tx.send(msg).await
}

Expand All @@ -77,7 +87,11 @@ impl<S, R> Channel<S, R> {
/// with [`send`], this function has two failure cases instead of one (one for
/// disconnection, one for a full buffer).
pub fn try_send(&mut self, msg: S) -> Result<(), TrySendError<S>> {
self.tx.try_send(msg)
if let Some(tx) = self.tx.get_ref() {
tx.try_send(msg)
} else {
Err(TrySendError::Closed(msg))
}
}

/// Receives the next value for this receiver.
Expand Down Expand Up @@ -135,3 +149,51 @@ impl<S, R> Channel<S, R> {
self.rx.poll_recv(cx)
}
}

impl<S, R> Stream for Channel<S, R> {
type Item = R;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.rx.poll_recv(cx)
}
}

impl<S: Send + 'static, R> Sink<S> for Channel<S, R> {
type Error = PollSendError<S>;

fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx.poll_ready_unpin(cx)
}

fn start_send(mut self: Pin<&mut Self>, item: S) -> Result<(), Self::Error> {
self.tx.start_send_unpin(item)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx.poll_flush_unpin(cx)
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx.poll_close_unpin(cx)
}
}

// impl<S: Send + 'static, R: Send + 'static> Sink<R> for Channel<S, R> {
// type Error = PollSendError<R>;

// fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// self.tx.poll_ready_unpin(cx)
// }

// fn start_send(mut self: Pin<&mut Self>, item: R) -> Result<(), Self::Error> {
// self.tx.start_send_unpin(item)
// }

// fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// self.tx.poll_flush_unpin(cx)
// }

// fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// self.tx.poll_close_unpin(cx)
// }
// }
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@ use futures::{FutureExt, Stream};
use std::{pin::Pin, task::Poll, time::Duration};
use tokio::time::sleep;

/// Helper trait alias for backoff streams.
/// We define any stream that yields `Duration`s as a backoff
pub trait Backoff: Stream<Item = Duration> + Unpin {}

/// Blanket implementation of `Backoff` for any stream that yields `Duration`s.
impl<T> Backoff for T where T: Stream<Item = Duration> + Unpin {}

/// A stream that yields exponentially increasing backoff durations.
pub struct ExponentialBackoff {
/// Current number of retries.
Expand Down
5 changes: 5 additions & 0 deletions msg-socket/src/connection/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pub mod state;
pub use state::ConnectionState;

pub mod backoff;
pub use backoff::{Backoff, ExponentialBackoff};
34 changes: 34 additions & 0 deletions msg-socket/src/connection/state.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use std::net::SocketAddr;

use super::Backoff;

/// Abstraction to represent the state of a connection.
///
/// * `C` is the channel type, which is used to send and receive generic messages.
/// * `B` is the backoff type, used to control the backoff state for inactive connections.
pub enum ConnectionState<C, B> {
Active {
/// Channel to control the underlying connection. This is used to send
/// and receive any kind of message in any direction.
channel: C,
},
Inactive {
addr: SocketAddr,
/// The current backoff state for inactive connections.
backoff: B,
},
}

impl<C, B: Backoff> ConnectionState<C, B> {
/// Returns `true` if the connection is active.
#[allow(unused)]
pub fn is_active(&self) -> bool {
matches!(self, Self::Active { .. })
}

/// Returns `true` if the connection is inactive.
#[allow(unused)]
pub fn is_inactive(&self) -> bool {
matches!(self, Self::Inactive { .. })
}
}
2 changes: 1 addition & 1 deletion msg-socket/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mod rep;
mod req;
mod sub;

mod backoff;
mod connection;

use bytes::Bytes;
pub use pubs::{PubError, PubOptions, PubSocket};
Expand Down
113 changes: 69 additions & 44 deletions msg-socket/src/req/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ use std::{
use tokio::sync::{mpsc, oneshot};
use tokio_util::codec::Framed;

use crate::{req::SocketState, ReqMessage};
use crate::{
connection::{ConnectionState, ExponentialBackoff},
req::SocketState,
ReqMessage,
};

use super::{Command, ReqError, ReqOptions};
use msg_wire::{
Expand All @@ -34,8 +38,9 @@ pub(crate) struct ReqDriver<T: Transport> {
pub(crate) id_counter: u32,
/// Commands from the socket.
pub(crate) from_socket: mpsc::Receiver<Command>,
/// The actual [`Framed`] connection with the `Req`-specific codec.
pub(crate) conn: Framed<T::Io, reqrep::Codec>,
/// The transport controller, wrapped in a [`ConnectionState`] for backoff.
/// The [`Framed`] object can send and receive messages from the socket.
pub(crate) conn_state: ConnectionState<Framed<T::Io, reqrep::Codec>, ExponentialBackoff>,
/// The outgoing message queue.
pub(crate) egress_queue: VecDeque<reqrep::Message>,
/// The currently pending requests, if any. Uses [`FxHashMap`] for performance.
Expand Down Expand Up @@ -138,63 +143,77 @@ where
let this = self.get_mut();

loop {
// Try to flush pending messages
if this.should_flush(cx) {
if let Poll::Ready(Ok(_)) = this.conn.poll_flush_unpin(cx) {
this.should_flush = false;
if let ConnectionState::Active { ref mut channel } = this.conn_state {
if let Poll::Ready(Ok(_)) = channel.poll_flush_unpin(cx) {
this.should_flush = false;
}
}

// TODO: what to do with an inactive connection here?
}

// Check for incoming messages from the socket
match this.conn.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(msg))) => {
this.on_message(msg);
match this.conn_state {
ConnectionState::Active { ref mut channel } => {
// Check for incoming messages from the socket
match channel.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(msg))) => {
this.on_message(msg);

continue;
}
Poll::Ready(Some(Err(e))) => {
if let reqrep::Error::Io(e) = e {
tracing::error!("Socket error: {:?}", e);
if e.kind() == std::io::ErrorKind::Other {
tracing::error!("Other error: {:?}", e);
continue;
}
Poll::Ready(Some(Err(e))) => {
if let reqrep::Error::Io(e) = e {
tracing::error!("Socket error: {:?}", e);
if e.kind() == std::io::ErrorKind::Other {
tracing::error!("Other error: {:?}", e);
return Poll::Ready(());
}
}

continue;
}
Poll::Ready(None) => {
tracing::debug!("Socket closed, shutting down backend");
return Poll::Ready(());
}
Poll::Pending => {}
}

// Check for outgoing messages to the socket
if channel.poll_ready_unpin(cx).is_ready() {
// Drain the egress queue
if let Some(msg) = this.egress_queue.pop_front() {
// Generate the new message
let size = msg.size();
tracing::debug!("Sending msg {}", msg.id());
match channel.start_send_unpin(msg) {
Ok(_) => {
this.socket_state.stats.increment_tx(size);

this.should_flush = true;
// We might be able to send more queued messages
continue;
}
Err(e) => {
tracing::error!("Failed to send message to socket: {:?}", e);
return Poll::Ready(());
}
}
}
}
continue;
}
Poll::Ready(None) => {
tracing::debug!("Socket closed, shutting down backend");
return Poll::Ready(());
ConnectionState::Inactive { addr, ref backoff } => {
// TODO: handle backoff in case of an inactive connection
}
Poll::Pending => {}
}

// Check for request timeouts
while this.timeout_check_interval.poll_tick(cx).is_ready() {
this.check_timeouts();
}

if this.conn.poll_ready_unpin(cx).is_ready() {
// Drain the egress queue
if let Some(msg) = this.egress_queue.pop_front() {
// Generate the new message
let size = msg.size();
tracing::debug!("Sending msg {}", msg.id());
match this.conn.start_send_unpin(msg) {
Ok(_) => {
this.socket_state.stats.increment_tx(size);

this.should_flush = true;
// We might be able to send more queued messages
continue;
}
Err(e) => {
tracing::error!("Failed to send message to socket: {:?}", e);
return Poll::Ready(());
}
}
}
}

// Check for outgoing messages from the socket handle
match this.from_socket.poll_recv(cx) {
Poll::Ready(Some(Command::Send { message, response })) => {
Expand Down Expand Up @@ -236,7 +255,13 @@ where
tracing::debug!(
"Socket dropped, shutting down backend and flushing connection"
);
let _ = ready!(this.conn.poll_close_unpin(cx));

if let ConnectionState::Active { ref mut channel } = this.conn_state {
let _ = ready!(channel.poll_close_unpin(cx));
}

// TODO: handle inactive connection here?

return Poll::Ready(());
}
Poll::Pending => {}
Expand Down
5 changes: 3 additions & 2 deletions msg-socket/src/req/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use msg_transport::Transport;
use msg_wire::{auth, reqrep};

use super::{Command, ReqDriver, ReqError, ReqOptions, DEFAULT_BUFFER_SIZE};
use crate::backoff::ExponentialBackoff;
use crate::connection::{ConnectionState, ExponentialBackoff};
use crate::{req::stats::SocketStats, req::SocketState};

/// The request socket.
Expand Down Expand Up @@ -119,13 +119,14 @@ where

let mut framed = Framed::new(stream, reqrep::Codec::new());
framed.set_backpressure_boundary(self.options.backpressure_boundary);
let conn = ConnectionState::Active { channel: framed };

// Create the socket backend
let driver: ReqDriver<T> = ReqDriver {
options: Arc::clone(&self.options),
id_counter: 0,
from_socket,
conn: framed,
conn_state: conn,
egress_queue: VecDeque::new(),
// TODO: we should limit the amount of active outgoing requests, and that should be the capacity.
// If we do this, we'll never have to re-allocate.
Expand Down
Loading

0 comments on commit 4ba2bf7

Please sign in to comment.