Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cli: add streams to rpc, generic 'spawn' command #179732

Merged
merged 4 commits into from Apr 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 3 additions & 2 deletions cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion cli/Cargo.toml
Expand Up @@ -17,7 +17,7 @@ clap = { version = "3.0", features = ["derive", "env"] }
open = { version = "2.1.0" }
reqwest = { version = "0.11.9", default-features = false, features = ["json", "stream", "native-tls"] }
tokio = { version = "1.24.2", features = ["full"] }
tokio-util = { version = "0.7", features = ["compat"] }
tokio-util = { version = "0.7", features = ["compat", "codec"] }
flate2 = { version = "1.0.22" }
zip = { version = "0.5.13", default-features = false, features = ["time", "deflate"] }
regex = { version = "1.5.5" }
Expand Down Expand Up @@ -54,6 +54,7 @@ thiserror = "1.0"
cfg-if = "1.0.0"
pin-project = "1.0"
console = "0.15"
bytes = "1.4"

[build-dependencies]
serde = { version = "1.0" }
Expand Down
4 changes: 2 additions & 2 deletions cli/src/commands/tunnels.rs
Expand Up @@ -190,7 +190,7 @@ pub async fn rename(ctx: CommandContext, rename_args: TunnelRenameArgs) -> Resul
let auth = Auth::new(&ctx.paths, ctx.log.clone());
let mut dt = dev_tunnels::DevTunnels::new(&ctx.log, auth, &ctx.paths);
dt.rename_tunnel(&rename_args.name).await?;
ctx.log.result(&format!(
ctx.log.result(format!(
"Successfully renamed this gateway to {}",
&rename_args.name
));
Expand Down Expand Up @@ -287,7 +287,7 @@ pub async fn prune(ctx: CommandContext) -> Result<i32, AnyError> {
.filter(|s| s.get_running_pid().is_none())
.try_for_each(|s| {
ctx.log
.result(&format!("Deleted {}", s.server_dir.display()));
.result(format!("Deleted {}", s.server_dir.display()));
s.delete()
})
.map_err(AnyError::from)?;
Expand Down
4 changes: 3 additions & 1 deletion cli/src/commands/update.rs
Expand Up @@ -3,6 +3,8 @@
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

use std::sync::Arc;

use indicatif::ProgressBar;

use crate::{
Expand All @@ -17,7 +19,7 @@ use super::{args::StandaloneUpdateArgs, CommandContext};
pub async fn update(ctx: CommandContext, args: StandaloneUpdateArgs) -> Result<i32, AnyError> {
let update_service = UpdateService::new(
ctx.log.clone(),
ReqwestSimpleHttp::with_client(ctx.http.clone()),
Arc::new(ReqwestSimpleHttp::with_client(ctx.http.clone())),
);
let update_service = SelfUpdate::new(&update_service)?;

Expand Down
2 changes: 1 addition & 1 deletion cli/src/commands/version.rs
Expand Up @@ -58,5 +58,5 @@ pub async fn show(ctx: CommandContext) -> Result<i32, AnyError> {
}

fn print_now_using(log: &log::Logger, version: &RequestedVersion, path: &Path) {
log.result(&format!("Now using {} from {}", version, path.display()));
log.result(format!("Now using {} from {}", version, path.display()));
}
15 changes: 13 additions & 2 deletions cli/src/json_rpc.rs
Expand Up @@ -50,7 +50,7 @@ pub async fn start_json_rpc<C: Send + Sync + 'static, S: Clone>(
mut msg_rx: impl Receivable<Vec<u8>>,
mut shutdown_rx: Barrier<S>,
) -> io::Result<Option<S>> {
let (write_tx, mut write_rx) = mpsc::unbounded_channel::<Vec<u8>>();
let (write_tx, mut write_rx) = mpsc::channel::<Vec<u8>>(8);
let mut read = BufReader::new(read);

let mut read_buf = String::new();
Expand Down Expand Up @@ -84,7 +84,18 @@ pub async fn start_json_rpc<C: Send + Sync + 'static, S: Clone>(
let write_tx = write_tx.clone();
tokio::spawn(async move {
if let Some(v) = fut.await {
write_tx.send(v).ok();
let _ = write_tx.send(v).await;
}
});
},
MaybeSync::Stream((dto, fut)) => {
if let Some(dto) = dto {
dispatcher.register_stream(write_tx.clone(), dto).await;
}
let write_tx = write_tx.clone();
tokio::spawn(async move {
if let Some(v) = fut.await {
let _ = write_tx.send(v).await;
}
});
}
Expand Down
10 changes: 4 additions & 6 deletions cli/src/log.rs
Expand Up @@ -27,21 +27,19 @@ pub fn next_counter() -> u32 {

// Log level
#[derive(clap::ArgEnum, PartialEq, Eq, PartialOrd, Clone, Copy, Debug, Serialize, Deserialize)]
#[derive(Default)]
pub enum Level {
Trace = 0,
Debug,
Info,
#[default]
Info,
Warn,
Error,
Critical,
Off,
}

impl Default for Level {
fn default() -> Self {
Level::Info
}
}


impl fmt::Display for Level {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
Expand Down
65 changes: 53 additions & 12 deletions cli/src/msgpack_rpc.rs
Expand Up @@ -8,6 +8,7 @@ use tokio::{
pin,
sync::mpsc,
};
use tokio_util::codec::Decoder;

use crate::{
rpc::{self, MaybeSync, Serialization},
Expand Down Expand Up @@ -38,42 +39,52 @@ pub fn new_msgpack_rpc() -> rpc::RpcBuilder<MsgPackSerializer> {
rpc::RpcBuilder::new(MsgPackSerializer {})
}

#[allow(clippy::read_zero_byte_vec)] // false positive
pub async fn start_msgpack_rpc<C: Send + Sync + 'static, S: Clone>(
dispatcher: rpc::RpcDispatcher<MsgPackSerializer, C>,
read: impl AsyncRead + Unpin,
mut write: impl AsyncWrite + Unpin,
mut msg_rx: impl Receivable<Vec<u8>>,
mut shutdown_rx: Barrier<S>,
) -> io::Result<Option<S>> {
let (write_tx, mut write_rx) = mpsc::unbounded_channel::<Vec<u8>>();
let (write_tx, mut write_rx) = mpsc::channel::<Vec<u8>>(8);
let mut read = BufReader::new(read);
let mut decode_buf = vec![];
let mut decoder = U32PrefixedCodec {};
let mut decoder_buf = bytes::BytesMut::new();

let shutdown_fut = shutdown_rx.wait();
pin!(shutdown_fut);

loop {
tokio::select! {
u = read.read_u32() => {
let msg_length = u? as usize;
decode_buf.resize(msg_length, 0);
tokio::select! {
r = read.read_exact(&mut decode_buf) => match dispatcher.dispatch(&decode_buf[..r?]) {
r = read.read_buf(&mut decoder_buf) => {
r?;

while let Some(frame) = decoder.decode(&mut decoder_buf)? {
match dispatcher.dispatch(&frame) {
MaybeSync::Sync(Some(v)) => {
write_tx.send(v).ok();
let _ = write_tx.send(v).await;
},
MaybeSync::Sync(None) => continue,
MaybeSync::Future(fut) => {
let write_tx = write_tx.clone();
tokio::spawn(async move {
if let Some(v) = fut.await {
write_tx.send(v).ok();
let _ = write_tx.send(v).await;
}
});
}
},
r = &mut shutdown_fut => return Ok(r.ok()),
MaybeSync::Stream((stream, fut)) => {
if let Some(stream) = stream {
dispatcher.register_stream(write_tx.clone(), stream).await;
}
let write_tx = write_tx.clone();
tokio::spawn(async move {
if let Some(v) = fut.await {
let _ = write_tx.send(v).await;
}
});
}
}
};
},
Some(m) = write_rx.recv() => {
Expand All @@ -88,3 +99,33 @@ pub async fn start_msgpack_rpc<C: Send + Sync + 'static, S: Clone>(
write.flush().await?;
}
}

/// Reader that reads length-prefixed msgpack messages in a cancellation-safe
/// way using Tokio's codecs.
pub struct U32PrefixedCodec {}

const U32_SIZE: usize = 4;

impl tokio_util::codec::Decoder for U32PrefixedCodec {
type Item = Vec<u8>;
type Error = io::Error;

fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if src.len() < 4 {
src.reserve(U32_SIZE - src.len());
return Ok(None);
}

let mut be_bytes = [0; U32_SIZE];
be_bytes.copy_from_slice(&src[..U32_SIZE]);
let required_len = U32_SIZE + (u32::from_be_bytes(be_bytes) as usize);
if src.len() < required_len {
src.reserve(required_len - src.len());
return Ok(None);
}

let msg = src[U32_SIZE..].to_vec();
src.resize(0, 0);
Ok(Some(msg))
}
}