diff --git a/cli/src/rpc.rs b/cli/src/rpc.rs
index 02cbac61bee19..acd53dc38e077 100644
--- a/cli/src/rpc.rs
+++ b/cli/src/rpc.rs
@@ -273,30 +273,21 @@ impl RpcMethodBuilder {
 
 	/// Builds into a usable, sync rpc dispatcher.
 	pub fn build(mut self, log: log::Logger) -> RpcDispatcher<S> {
-		let streams: Arc<tokio::sync::Mutex<HashMap<u32, WriteHalf<DuplexStream>>>> =
-			Arc::new(tokio::sync::Mutex::new(HashMap::new()));
+		let streams = Streams::default();
 
 		let s1 = streams.clone();
 		self.register_async(METHOD_STREAM_ENDED, move |m: StreamEndedParams, _| {
 			let s1 = s1.clone();
 			async move {
-				if let Some(mut s) = s1.lock().await.remove(&m.stream) {
-					let _ = s.shutdown().await;
-				}
+				s1.remove(m.stream).await;
 				Ok(())
 			}
 		});
 
 		let s2 = streams.clone();
-		self.register_async(METHOD_STREAM_DATA, move |m: StreamDataIncomingParams, _| {
-			let s2 = s2.clone();
-			async move {
-				let mut lock = s2.lock().await;
-				if let Some(stream) = lock.get_mut(&m.stream) {
-					let _ = stream.write_all(&m.segment).await;
-				}
-				Ok(())
-			}
+		self.register_sync(METHOD_STREAM_DATA, move |m: StreamDataIncomingParams, _| {
+			s2.write(m.stream, m.segment);
+			Ok(())
 		});
 
 		RpcDispatcher {
@@ -400,7 +391,7 @@ pub struct RpcDispatcher {
 	serializer: Arc<S>,
 	methods: Arc<HashMap<&'static str, Method>>,
 	calls: Arc<std::sync::Mutex<HashMap<u32, DispatchMethod>>>,
-	streams: Arc<tokio::sync::Mutex<HashMap<u32, WriteHalf<DuplexStream>>>>,
+	streams: Streams,
 }
 
 static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0);
@@ -483,10 +474,9 @@ impl RpcDispatcher {
 			return;
 		}
 
-		let mut streams_map = self.streams.lock().await;
 		for (stream_id, duplex) in dto.streams {
 			let (mut read, write) = tokio::io::split(duplex);
-			streams_map.insert(stream_id, write);
+			self.streams.insert(stream_id, write);
 
 			let write_tx = write_tx.clone();
 			let serial = self.serializer.clone();
@@ -538,6 +528,90 @@ impl RpcDispatcher {
 	}
 }
 
+struct StreamRec {
+	write: Option<WriteHalf<DuplexStream>>,
+	q: Vec<Vec<u8>>,
+}
+
+#[derive(Clone, Default)]
+struct Streams {
+	map: Arc<std::sync::Mutex<HashMap<u32, StreamRec>>>,
+}
+
+impl Streams {
+	pub async fn remove(&self, id: u32) {
+		let stream = self.map.lock().unwrap().remove(&id);
+		if let Some(s) = stream {
+			// if there's no 'write' right now, it'll shut down in the write_loop
+			if let Some(mut w) = s.write {
+				let _ = w.shutdown().await;
+			}
+		}
+	}
+
+	pub fn write(&self, id: u32, buf: Vec<u8>) {
+		let mut map = self.map.lock().unwrap();
+		if let Some(s) = map.get_mut(&id) {
+			s.q.push(buf);
+
+			if let Some(w) = s.write.take() {
+				tokio::spawn(write_loop(id, w, self.map.clone()));
+			}
+		}
+	}
+
+	pub fn insert(&self, id: u32, stream: WriteHalf<DuplexStream>) {
+		self.map.lock().unwrap().insert(
+			id,
+			StreamRec {
+				write: Some(stream),
+				q: Vec::new(),
+			},
+		);
+	}
+}
+
+/// Write loop started by `Streams.write`. It takes the WriteHalf, and
+/// runs until there's no more items in the 'write queue'. At that point, if the
+/// record still exists in the `streams` (i.e. we haven't shut down), it'll
+/// return the WriteHalf so that the next `write` call starts
+/// the loop again. Otherwise, it'll shut down the WriteHalf.
+///
+/// This is the equivalent of the same write_loop in the server_multiplexer.
+/// I couldn't figure out a nice way to abstract it without introducing
+/// performance overhead...
+async fn write_loop(
+	id: u32,
+	mut w: WriteHalf<DuplexStream>,
+	streams: Arc<std::sync::Mutex<HashMap<u32, StreamRec>>>,
+) {
+	let mut items_vec = vec![];
+	loop {
+		{
+			let mut lock = streams.lock().unwrap();
+			let stream_rec = match lock.get_mut(&id) {
+				Some(b) => b,
+				None => break,
+			};
+
+			if stream_rec.q.is_empty() {
+				stream_rec.write = Some(w);
+				return;
+			}
+
+			std::mem::swap(&mut stream_rec.q, &mut items_vec);
+		}
+
+		for item in items_vec.drain(..) {
+			if w.write_all(&item).await.is_err() {
+				break;
+			}
+		}
+	}
+
+	let _ = w.shutdown().await; // got here from `break` above, meaning our record got cleared. Close the bridge if so
+}
+
 const METHOD_STREAMS_STARTED: &str = "streams_started";
 const METHOD_STREAM_DATA: &str = "stream_data";
 const METHOD_STREAM_ENDED: &str = "stream_ended";
diff --git a/cli/src/tunnels/server_multiplexer.rs b/cli/src/tunnels/server_multiplexer.rs
index 34782ff375b79..65eb1df7ad5b2 100644
--- a/cli/src/tunnels/server_multiplexer.rs
+++ b/cli/src/tunnels/server_multiplexer.rs
@@ -105,7 +105,7 @@ impl ServerMultiplexer {
 	}
 }
 
-/// Write loop started by `handle_server_message`. It take sthe ServerBridge, and
+/// Write loop started by `handle_server_message`. It takes the ServerBridge, and
 /// runs until there's no more items in the 'write queue'. At that point, if the
 /// record still exists in the bridges_lock (i.e. we haven't shut down), it'll
 /// return the ServerBridge so that the next handle_server_message call starts
diff --git a/cli/src/util/sync.rs b/cli/src/util/sync.rs
index 2b506bd54e3f2..8b653cd2d535c 100644
--- a/cli/src/util/sync.rs
+++ b/cli/src/util/sync.rs
@@ -4,11 +4,9 @@
  *--------------------------------------------------------------------------------------------*/
 use async_trait::async_trait;
 use std::{marker::PhantomData, sync::Arc};
-use tokio::{
-	sync::{
-		broadcast, mpsc,
-		watch::{self, error::RecvError},
-	},
+use tokio::sync::{
+	broadcast, mpsc,
+	watch::{self, error::RecvError},
 };
 
 #[derive(Clone)]
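
Note for reviewers: the `Streams`/`write_loop` change above replaces a `tokio::sync::Mutex` held across writes with a synchronous map of per-stream write queues, so the `stream_data` handler never blocks on I/O; whichever caller finds the `WriteHalf` present spawns a single drain task, and the task hands the writer back when the queue empties. Below is a minimal, self-contained sketch of that queue-plus-`Option<WriteHalf>` handoff pattern. The `Sink`, `Rec`, and `drain` names and the demo `main` are illustrative only, not part of this change; it assumes `tokio = { version = "1", features = ["full"] }`.

```rust
// Illustrative sketch (not this PR's code): the queue + Option<WriteHalf>
// handoff used by `Streams`/`write_loop` above, reduced to its essentials.
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream, WriteHalf};

struct Rec {
	// `Some` while no drain task is running; taken by the task that owns the writer.
	write: Option<WriteHalf<DuplexStream>>,
	q: Vec<Vec<u8>>,
}

#[derive(Clone, Default)]
struct Sink {
	map: Arc<Mutex<HashMap<u32, Rec>>>,
}

impl Sink {
	fn insert(&self, id: u32, w: WriteHalf<DuplexStream>) {
		self.map
			.lock()
			.unwrap()
			.insert(id, Rec { write: Some(w), q: Vec::new() });
	}

	// Synchronous: pushes onto the queue and spawns at most one drain task.
	fn write(&self, id: u32, buf: Vec<u8>) {
		let mut map = self.map.lock().unwrap();
		if let Some(rec) = map.get_mut(&id) {
			rec.q.push(buf);
			if let Some(w) = rec.write.take() {
				tokio::spawn(drain(id, w, self.map.clone()));
			}
		}
	}
}

async fn drain(id: u32, mut w: WriteHalf<DuplexStream>, map: Arc<Mutex<HashMap<u32, Rec>>>) {
	let mut batch = Vec::new();
	loop {
		{
			// The lock is held only to swap the queue out, never across an await.
			let mut lock = map.lock().unwrap();
			let rec = match lock.get_mut(&id) {
				Some(r) => r,
				None => break, // record removed while we were writing: shut down
			};
			if rec.q.is_empty() {
				rec.write = Some(w); // hand the writer back for the next `write` call
				return;
			}
			std::mem::swap(&mut rec.q, &mut batch);
		}
		for item in batch.drain(..) {
			if w.write_all(&item).await.is_err() {
				break; // write failed; loop once more and re-check the record
			}
		}
	}
	let _ = w.shutdown().await;
}

#[tokio::main]
async fn main() {
	let (a, mut b) = tokio::io::duplex(64);
	let (_read_a, write_a) = tokio::io::split(a);

	let sink = Sink::default();
	sink.insert(1, write_a);
	sink.write(1, b"hello".to_vec()); // queues the segment and spawns the drain task

	let mut buf = [0u8; 5];
	b.read_exact(&mut buf).await.unwrap();
	assert_eq!(&buf, b"hello");
}
```

The design choice this mirrors: because only one task can hold the `WriteHalf` at a time (it is moved out of the record while draining), writes stay ordered without an async mutex, and the hot path, pushing onto the queue, is a short synchronous critical section.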