Skip to content

Commit

Permalink
refactor(transport): remove Invoke::NestedOutgoing
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net>
  • Loading branch information
rvolosatovs committed Jun 14, 2024
1 parent 2bfbd14 commit f71439d
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 26 deletions.
5 changes: 2 additions & 3 deletions crates/runtime-wasmtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -823,9 +823,8 @@ pub fn polyfill<'a, T, C, V>(
C::Error: Into<wasmtime::Error>,
C::Context: Clone + 'static,
<C::Session as Session>::TransportError: Into<wasmtime::Error>,
<C::Outgoing as wrpc_transport::Index<C::NestedOutgoing>>::Error: Into<wasmtime::Error>,
C::NestedOutgoing: 'static,
<C::NestedOutgoing as wrpc_transport::Index<C::NestedOutgoing>>::Error: Into<wasmtime::Error>,
<C::Outgoing as wrpc_transport::Index<C::Outgoing>>::Error: Into<wasmtime::Error>,
C::Outgoing: 'static,
C::Incoming: Unpin + Sized + 'static,
<C::Incoming as wrpc_transport::Index<C::Incoming>>::Error:
Into<Box<dyn std::error::Error + Send + Sync>>,
Expand Down
79 changes: 61 additions & 18 deletions crates/transport-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use async_nats::{HeaderMap, Message, ServerInfo, StatusCode, Subject, Subscriber
use bytes::{Buf as _, Bytes, BytesMut};
use futures::sink::SinkExt as _;
use futures::{Stream, StreamExt};
use tokio::io::{AsyncWrite, AsyncWriteExt as _};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _, ReadBuf};
use tokio::sync::oneshot;
use tokio::try_join;
use tokio_util::codec::Encoder;
Expand Down Expand Up @@ -313,11 +313,11 @@ pub struct Reader {
nested: Arc<std::sync::Mutex<SubscriberTree>>,
}

impl wrpc_transport::Index<Reader> for Reader {
impl wrpc_transport::Index<Self> for Reader {
type Error = anyhow::Error;

#[instrument(level = "trace", skip_all)]
fn index(&self, path: &[usize]) -> anyhow::Result<Reader> {
fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
let mut nested = self
.nested
.lock()
Expand All @@ -331,12 +331,12 @@ impl wrpc_transport::Index<Reader> for Reader {
}
}

impl tokio::io::AsyncRead for Reader {
impl AsyncRead for Reader {
#[instrument(level = "trace", skip_all)]
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let cap = buf.remaining();
if cap == 0 {
Expand Down Expand Up @@ -384,11 +384,11 @@ impl SubjectWriter {
}
}

impl wrpc_transport::Index<SubjectWriter> for SubjectWriter {
impl wrpc_transport::Index<Self> for SubjectWriter {
type Error = Infallible;

#[instrument(level = "trace", skip_all)]
fn index(&self, path: &[usize]) -> Result<SubjectWriter, Self::Error> {
fn index(&self, path: &[usize]) -> Result<Self, Self::Error> {
Ok(Self {
nats: Arc::clone(&self.nats),
tx: index_path(self.tx.as_str(), path).into(),
Expand All @@ -397,7 +397,7 @@ impl wrpc_transport::Index<SubjectWriter> for SubjectWriter {
}
}

impl tokio::io::AsyncWrite for SubjectWriter {
impl AsyncWrite for SubjectWriter {
#[instrument(level = "trace", skip_all)]
fn poll_write(
mut self: Pin<&mut Self>,
Expand Down Expand Up @@ -439,7 +439,7 @@ impl tokio::io::AsyncWrite for SubjectWriter {
}

#[derive(Debug, Default)]
pub enum ParamWriter {
pub enum RootParamWriter {
#[default]
Corrupted,
Handshaking {
Expand All @@ -456,7 +456,7 @@ pub enum ParamWriter {
Active(SubjectWriter),
}

impl ParamWriter {
impl RootParamWriter {
fn new(
tx: SubjectWriter,
sub: Subscriber,
Expand All @@ -473,7 +473,7 @@ impl ParamWriter {
}
}

impl ParamWriter {
impl RootParamWriter {
#[instrument(level = "trace", skip_all)]
fn poll_active(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match &mut *self {
Expand Down Expand Up @@ -584,7 +584,7 @@ impl ParamWriter {
}
}

impl wrpc_transport::Index<IndexedParamWriter> for ParamWriter {
impl wrpc_transport::Index<IndexedParamWriter> for RootParamWriter {
type Error = std::io::Error;

#[instrument(level = "trace", skip_all)]
Expand All @@ -610,7 +610,7 @@ impl wrpc_transport::Index<IndexedParamWriter> for ParamWriter {
}
}

impl tokio::io::AsyncWrite for ParamWriter {
impl AsyncWrite for RootParamWriter {
#[instrument(level = "trace", skip_all)]
fn poll_write(
mut self: Pin<&mut Self>,
Expand Down Expand Up @@ -730,7 +730,7 @@ impl wrpc_transport::Index<Self> for IndexedParamWriter {
}
}

impl tokio::io::AsyncWrite for IndexedParamWriter {
impl AsyncWrite for IndexedParamWriter {
#[instrument(level = "trace", skip_all)]
fn poll_write(
mut self: Pin<&mut Self>,
Expand Down Expand Up @@ -778,13 +778,57 @@ impl tokio::io::AsyncWrite for IndexedParamWriter {
}
}

pub enum ParamWriter {
Root(RootParamWriter),
Nested(IndexedParamWriter),
}

impl wrpc_transport::Index<Self> for ParamWriter {
type Error = std::io::Error;

fn index(&self, path: &[usize]) -> Result<Self, Self::Error> {
match self {
ParamWriter::Root(w) => w.index(path),
ParamWriter::Nested(w) => w.index(path),
}
.map(Self::Nested)
}
}

impl AsyncWrite for ParamWriter {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
match &mut *self {
ParamWriter::Root(w) => pin!(w).poll_write(cx, buf),
ParamWriter::Nested(w) => pin!(w).poll_write(cx, buf),
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match &mut *self {
ParamWriter::Root(w) => pin!(w).poll_flush(cx),
ParamWriter::Nested(w) => pin!(w).poll_flush(cx),
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match &mut *self {
ParamWriter::Root(w) => pin!(w).poll_shutdown(cx),
ParamWriter::Nested(w) => pin!(w).poll_shutdown(cx),
}
}
}

#[derive(Debug)]
pub enum ClientErrorWriter {
Handshaking(oneshot::Receiver<SubjectWriter>),
Active(SubjectWriter),
}

impl tokio::io::AsyncWrite for ClientErrorWriter {
impl AsyncWrite for ClientErrorWriter {
#[instrument(level = "trace", skip_all)]
fn poll_write(
mut self: Pin<&mut Self>,
Expand Down Expand Up @@ -890,7 +934,6 @@ impl wrpc_transport::Invoke for Client {
type Context = Option<HeaderMap>;
type Session = Session<ClientErrorWriter>;
type Outgoing = ParamWriter;
type NestedOutgoing = IndexedParamWriter;
type Incoming = Reader;

#[instrument(level = "trace", skip(self))]
Expand Down Expand Up @@ -979,7 +1022,7 @@ impl wrpc_transport::Invoke for Client {
.context("failed to send handshake")?;
let (error_tx_tx, error_tx_rx) = oneshot::channel();
Ok(wrpc_transport::Invocation {
outgoing: ParamWriter::new(
outgoing: ParamWriter::Root(RootParamWriter::new(
SubjectWriter::new(
Arc::clone(&self.nats),
param_tx.clone(),
Expand All @@ -988,7 +1031,7 @@ impl wrpc_transport::Invoke for Client {
handshake_rx,
error_tx_tx,
params,
),
)),
incoming: Reader {
buffer: Bytes::default(),
incoming: result_rx,
Expand Down
1 change: 0 additions & 1 deletion crates/transport-quic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,6 @@ impl wrpc_transport::Invoke for Client {
type Context = ();
type Session = Session;
type Outgoing = Outgoing;
type NestedOutgoing = Outgoing;
type Incoming = Incoming;

#[instrument(level = "trace", skip(self))]
Expand Down
5 changes: 1 addition & 4 deletions crates/transport/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,7 @@ pub trait Invoke: Sync + Send {
type Session: Session + Sync + Send;

/// Outgoing multiplexed byte stream
type Outgoing: AsyncWrite + Index<Self::NestedOutgoing> + Sync + Send;

/// Outgoing multiplexed byte stream, nested at a particular path
type NestedOutgoing: AsyncWrite + Index<Self::NestedOutgoing> + Sync + Send;
type Outgoing: AsyncWrite + Index<Self::Outgoing> + Sync + Send;

/// Incoming multiplexed byte stream
type Incoming: AsyncRead + Index<Self::Incoming> + Sync + Send;
Expand Down

0 comments on commit f71439d

Please sign in to comment.