cli: ensure ordering of rpc server messages #183558

Merged
merged 2 commits on May 26, 2023
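Background for the fix, before the diff: previously each incoming stream_data message was handled by its own async handler, so nothing guaranteed segments were written to the stream in the order they arrived; the new code enqueues segments synchronously on the dispatch path and drains them from a single write loop. Below is a minimal sketch of both behaviors — not code from the PR — using illustrative names and a plain mpsc channel standing in for the PR's per-stream queue:

```rust
// Sketch only; assumes tokio = { version = "1", features = ["full"] }.
use std::sync::{Arc, Mutex};

#[tokio::main]
async fn main() {
    // Hazard: one spawned task per segment, as with an async handler.
    let out = Arc::new(Mutex::new(Vec::new()));
    let mut tasks = Vec::new();
    for seg in 0..8u32 {
        let out = out.clone();
        tasks.push(tokio::spawn(async move {
            tokio::task::yield_now().await; // scheduling point: completion order is arbitrary
            out.lock().unwrap().push(seg);
        }));
    }
    for t in tasks {
        t.await.unwrap();
    }
    println!("per-message tasks: {:?}", out.lock().unwrap()); // may be out of order

    // Fix: enqueue synchronously on arrival, drain from one task.
    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
    for seg in 0..8u32 {
        tx.send(seg).unwrap(); // synchronous enqueue preserves arrival order
    }
    drop(tx);
    let mut ordered = Vec::new();
    while let Some(seg) = rx.recv().await {
        ordered.push(seg);
    }
    println!("single drain task: {:?}", ordered); // always [0, 1, ..., 7]
}
```

The same reasoning explains the switch from register_async to register_sync for METHOD_STREAM_DATA in the diff below: a synchronous handler pushes to the queue in exactly the order messages are dispatched.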
108 changes: 91 additions & 17 deletions cli/src/rpc.rs
@@ -273,30 +273,21 @@ impl<S: Serialization, C: Send + Sync + 'static> RpcMethodBuilder<S, C> {

/// Builds into a usable, sync rpc dispatcher.
pub fn build(mut self, log: log::Logger) -> RpcDispatcher<S, C> {
-		let streams: Arc<tokio::sync::Mutex<HashMap<u32, WriteHalf<DuplexStream>>>> =
-			Arc::new(tokio::sync::Mutex::new(HashMap::new()));
+		let streams = Streams::default();

let s1 = streams.clone();
self.register_async(METHOD_STREAM_ENDED, move |m: StreamEndedParams, _| {
let s1 = s1.clone();
async move {
-			if let Some(mut s) = s1.lock().await.remove(&m.stream) {
-				let _ = s.shutdown().await;
-			}
+			s1.remove(m.stream).await;
Ok(())
}
});

let s2 = streams.clone();
-		self.register_async(METHOD_STREAM_DATA, move |m: StreamDataIncomingParams, _| {
-			let s2 = s2.clone();
-			async move {
-				let mut lock = s2.lock().await;
-				if let Some(stream) = lock.get_mut(&m.stream) {
-					let _ = stream.write_all(&m.segment).await;
-				}
-				Ok(())
-			}
+		self.register_sync(METHOD_STREAM_DATA, move |m: StreamDataIncomingParams, _| {
+			s2.write(m.stream, m.segment);
+			Ok(())
});

RpcDispatcher {
@@ -400,7 +391,7 @@ pub struct RpcDispatcher<S, C> {
serializer: Arc<S>,
methods: Arc<HashMap<&'static str, Method>>,
calls: Arc<Mutex<HashMap<u32, DispatchMethod>>>,
-	streams: Arc<tokio::sync::Mutex<HashMap<u32, WriteHalf<DuplexStream>>>>,
+	streams: Streams,
}

static MESSAGE_ID_COUNTER: AtomicU32 = AtomicU32::new(0);
@@ -483,10 +474,9 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
return;
}

-		let mut streams_map = self.streams.lock().await;
for (stream_id, duplex) in dto.streams {
let (mut read, write) = tokio::io::split(duplex);
-			streams_map.insert(stream_id, write);
+			self.streams.insert(stream_id, write);

let write_tx = write_tx.clone();
let serial = self.serializer.clone();
@@ -538,6 +528,90 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
}
}

struct StreamRec {
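	// `write` is parked here while idle and taken by a running write_loop;
	// `q` buffers segments in arrival order until the loop drains them.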
write: Option<WriteHalf<DuplexStream>>,
q: Vec<Vec<u8>>,
}

#[derive(Clone, Default)]
struct Streams {
map: Arc<std::sync::Mutex<HashMap<u32, StreamRec>>>,
}

impl Streams {
pub async fn remove(&self, id: u32) {
let stream = self.map.lock().unwrap().remove(&id);
if let Some(s) = stream {
// if there's no 'write' right now, it'll shut down in the write_loop
if let Some(mut w) = s.write {
let _ = w.shutdown().await;
}
}
}

pub fn write(&self, id: u32, buf: Vec<u8>) {
let mut map = self.map.lock().unwrap();
if let Some(s) = map.get_mut(&id) {
s.q.push(buf);

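		// A parked write half means no write_loop is running: take it and spawn
		// one. Otherwise the running loop will pick up the just-queued segment.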
if let Some(w) = s.write.take() {
tokio::spawn(write_loop(id, w, self.map.clone()));
}
}
}

pub fn insert(&self, id: u32, stream: WriteHalf<DuplexStream>) {
self.map.lock().unwrap().insert(
id,
StreamRec {
write: Some(stream),
q: Vec::new(),
},
);
}
}

/// Write loop started by `Streams.write`. It takes the WriteHalf, and
/// runs until there are no more items in the 'write queue'. At that point, if
/// the record still exists in the `streams` (i.e. we haven't shut down), it'll
/// return the WriteHalf so that the next `write` call starts
/// the loop again. Otherwise, it'll shut down the WriteHalf.
///
/// This is the equivalent of the same write_loop in the server_multiplexer.
/// I couldn't figure out a nice way to abstract it without introducing
/// performance overhead...
async fn write_loop(
id: u32,
mut w: WriteHalf<DuplexStream>,
streams: Arc<std::sync::Mutex<HashMap<u32, StreamRec>>>,
) {
let mut items_vec = vec![];
loop {
{
let mut lock = streams.lock().unwrap();
let stream_rec = match lock.get_mut(&id) {
Some(b) => b,
None => break,
};

if stream_rec.q.is_empty() {
stream_rec.write = Some(w);
return;
}

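			// Swap the queue out and release the lock before writing, so the
			// std (non-async) Mutex is never held across an .await point.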
std::mem::swap(&mut stream_rec.q, &mut items_vec);
}

for item in items_vec.drain(..) {
if w.write_all(&item).await.is_err() {
break;
}
}
}

	let _ = w.shutdown().await; // got here from `break` above, meaning our record got cleared. Close the stream if so
}

const METHOD_STREAMS_STARTED: &str = "streams_started";
const METHOD_STREAM_DATA: &str = "stream_data";
const METHOD_STREAM_ENDED: &str = "stream_ended";
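To experiment with the take-and-park handoff above in isolation, here is a condensed, self-contained harness that follows the same pattern as Streams::write and write_loop; the names (Rec, drain) and the byte payloads are hypothetical illustration, not code from the PR:

```rust
// Sketch only; assumes tokio = { version = "1", features = ["full"] }.
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream, WriteHalf};

struct Rec {
    write: Option<WriteHalf<DuplexStream>>, // parked while no drain task runs
    q: Vec<Vec<u8>>,                        // segments in arrival order
}

type Map = Arc<Mutex<HashMap<u32, Rec>>>;

// Synchronous enqueue; must be called from within a tokio runtime.
fn write(map: &Map, id: u32, buf: Vec<u8>) {
    let mut m = map.lock().unwrap();
    if let Some(rec) = m.get_mut(&id) {
        rec.q.push(buf);
        if let Some(w) = rec.write.take() {
            tokio::spawn(drain(id, w, map.clone())); // first writer starts the loop
        }
    }
}

async fn drain(id: u32, mut w: WriteHalf<DuplexStream>, map: Map) {
    let mut batch = Vec::new();
    loop {
        {
            let mut m = map.lock().unwrap();
            let rec = match m.get_mut(&id) {
                Some(r) => r,
                None => break, // record removed: shut the stream down below
            };
            if rec.q.is_empty() {
                rec.write = Some(w); // park the write half for the next caller
                return;
            }
            std::mem::swap(&mut rec.q, &mut batch);
        }
        for item in batch.drain(..) {
            if w.write_all(&item).await.is_err() {
                break;
            }
        }
    }
    let _ = w.shutdown().await;
}

#[tokio::main]
async fn main() {
    let (client, server) = tokio::io::duplex(64);
    let (_server_read, server_write) = tokio::io::split(server);
    let map: Map = Arc::new(Mutex::new(HashMap::new()));
    map.lock()
        .unwrap()
        .insert(1, Rec { write: Some(server_write), q: Vec::new() });

    for i in 0..5u8 {
        write(&map, 1, vec![i]); // synchronous calls: queue order == call order
    }

    let (mut client_read, _client_write) = tokio::io::split(client);
    let mut buf = [0u8; 5];
    client_read.read_exact(&mut buf).await.unwrap();
    assert_eq!(buf, [0, 1, 2, 3, 4]); // segments arrive in write order
}
```

Because write parks the WriteHalf under the same lock that guards the queue, at most one drain task can exist per stream, which is what makes the FIFO ordering hold end to end.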
2 changes: 1 addition & 1 deletion cli/src/tunnels/server_multiplexer.rs
@@ -105,7 +105,7 @@ impl ServerMultiplexer {
}
}

-/// Write loop started by `handle_server_message`. It take sthe ServerBridge, and
+/// Write loop started by `handle_server_message`. It takes the ServerBridge, and
/// runs until there's no more items in the 'write queue'. At that point, if the
/// record still exists in the bridges_lock (i.e. we haven't shut down), it'll
/// return the ServerBridge so that the next handle_server_message call starts
8 changes: 3 additions & 5 deletions cli/src/util/sync.rs
@@ -4,11 +4,9 @@
*--------------------------------------------------------------------------------------------*/
use async_trait::async_trait;
use std::{marker::PhantomData, sync::Arc};
-use tokio::{
-	sync::{
-		broadcast, mpsc,
-		watch::{self, error::RecvError},
-	},
+use tokio::sync::{
+	broadcast, mpsc,
+	watch::{self, error::RecvError},
};

#[derive(Clone)]