Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cli: fix exec server not reading all stdin with immediate close #195257

Merged
merged 1 commit into from
Oct 10, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
75 changes: 68 additions & 7 deletions cli/src/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ impl<S: Serialization, C: Send + Sync> RpcDispatcher<S, C> {
struct StreamRec {
write: Option<WriteHalf<DuplexStream>>,
q: Vec<Vec<u8>>,
ended: bool,
}

#[derive(Clone, Default)]
Expand All @@ -540,13 +541,24 @@ struct Streams {

impl Streams {
pub async fn remove(&self, id: u32) {
let stream = self.map.lock().unwrap().remove(&id);
if let Some(s) = stream {
// if there's no 'write' right now, it'll shut down in the write_loop
if let Some(mut w) = s.write {
let _ = w.shutdown().await;
let mut remove = None;

{
let mut map = self.map.lock().unwrap();
if let Some(s) = map.get_mut(&id) {
if let Some(w) = s.write.take() {
map.remove(&id);
remove = Some(w);
} else {
s.ended = true; // will shut down in write loop
}
}
}

// do this outside of the sync lock:
if let Some(mut w) = remove {
let _ = w.shutdown().await;
}
}

pub fn write(&self, id: u32, buf: Vec<u8>) {
Expand All @@ -566,6 +578,7 @@ impl Streams {
StreamRec {
write: Some(stream),
q: Vec::new(),
ended: false,
},
);
}
Expand Down Expand Up @@ -595,8 +608,13 @@ async fn write_loop(
};

if stream_rec.q.is_empty() {
stream_rec.write = Some(w);
return;
if stream_rec.ended {
lock.remove(&id);
break;
} else {
stream_rec.write = Some(w);
return;
}
}

std::mem::swap(&mut stream_rec.q, &mut items_vec);
Expand Down Expand Up @@ -691,3 +709,46 @@ pub enum MaybeSync {
Future(BoxFuture<'static, Option<Vec<u8>>>),
Sync(Option<Vec<u8>>),
}

#[cfg(test)]
mod tests {
use super::*;

#[tokio::test]
async fn test_remove() {
let streams = Streams::default();
let (writer, mut reader) = tokio::io::duplex(1024);
streams.insert(1, tokio::io::split(writer).1);
streams.remove(1).await;

assert!(streams.map.lock().unwrap().get(&1).is_none());
let mut buffer = Vec::new();
assert_eq!(reader.read_to_end(&mut buffer).await.unwrap(), 0);
}

#[tokio::test]
async fn test_write() {
let streams = Streams::default();
let (writer, mut reader) = tokio::io::duplex(1024);
streams.insert(1, tokio::io::split(writer).1);
streams.write(1, vec![1, 2, 3]);

let mut buffer = [0; 3];
assert_eq!(reader.read_exact(&mut buffer).await.unwrap(), 3);
assert_eq!(buffer, [1, 2, 3]);
}

#[tokio::test]
async fn test_write_with_immediate_end() {
let streams = Streams::default();
let (writer, mut reader) = tokio::io::duplex(1);
streams.insert(1, tokio::io::split(writer).1);
streams.write(1, vec![1, 2, 3]); // spawn write loop
streams.write(1, vec![4, 5, 6]); // enqueued while writing
streams.remove(1).await; // end stream

let mut buffer = Vec::new();
assert_eq!(reader.read_to_end(&mut buffer).await.unwrap(), 6);
assert_eq!(buffer, vec![1, 2, 3, 4, 5, 6]);
}
}