Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ jobs:
steps:
- uses: actions/checkout@v2
- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt, clippy
- uses: swatinem/rust-cache@v2
- name: cargo fmt
run: cargo fmt --all -- --check
Expand Down
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ crate-type = ["cdylib", "rlib"]
[dependencies]
# we require serde even in non-rpc mode
serde = { workspace = true }
# just for the oneshot and mpsc queues
tokio = { workspace = true, features = ["sync"] }
# just for the oneshot and mpsc queues, and tokio::select!
tokio = { workspace = true, features = ["sync", "macros"] }
# for PollSender (which for some reason is not available in the main tokio api)
tokio-util = { version = "0.7.14", default-features = false }
# errors
Expand Down
2 changes: 1 addition & 1 deletion examples/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ pub async fn reference_bench(n: u64) -> anyhow::Result<()> {

#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt().init();
tracing_subscriber::fmt::init();
println!("Local use");
local().await?;
println!("Remote use");
Expand Down
2 changes: 1 addition & 1 deletion examples/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ async fn remote() -> anyhow::Result<()> {

#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt().init();
tracing_subscriber::fmt::init();
println!("Local use");
local().await?;
println!("Remote use");
Expand Down
2 changes: 1 addition & 1 deletion irpc-iroh/examples/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use self::storage::StorageApi;

#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt().init();
tracing_subscriber::fmt::init();
println!("Local use");
local().await?;
println!("Remote use");
Expand Down
33 changes: 25 additions & 8 deletions irpc-iroh/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use irpc::{
};
use n0_future::{future::Boxed as BoxFuture, TryFutureExt};
use serde::de::DeserializeOwned;
use tracing::{trace, trace_span, warn, Instrument};
use tracing::{debug, error_span, trace, trace_span, warn, Instrument};

/// Returns a client that connects to a irpc service using an [`iroh::Endpoint`].
pub fn client<S: irpc::Service>(
Expand Down Expand Up @@ -207,6 +207,10 @@ pub async fn handle_connection<R: DeserializeOwned + 'static>(
connection: Connection,
handler: Handler<R>,
) -> io::Result<()> {
if let Ok(remote) = connection.remote_node_id() {
tracing::Span::current().record("remote", tracing::field::display(remote.fmt_short()));
}
debug!("connection accepted");
loop {
let Some((msg, rx, tx)) = read_request_raw(&connection).await? else {
return Ok(());
Expand Down Expand Up @@ -270,19 +274,32 @@ pub async fn read_request_raw<R: DeserializeOwned + 'static>(
pub async fn listen<R: DeserializeOwned + 'static>(endpoint: iroh::Endpoint, handler: Handler<R>) {
let mut request_id = 0u64;
let mut tasks = n0_future::task::JoinSet::new();
while let Some(incoming) = endpoint.accept().await {
loop {
let incoming = tokio::select! {
Some(res) = tasks.join_next(), if !tasks.is_empty() => {
res.expect("irpc connection task panicked");
continue;
}
incoming = endpoint.accept() => {
match incoming {
None => break,
Some(incoming) => incoming
}
}
};
let handler = handler.clone();
let fut = async move {
let connection = match incoming.await {
Ok(connection) => connection,
match incoming.await {
Ok(connection) => match handle_connection(connection, handler).await {
Err(err) => warn!("connection closed with error: {err:?}"),
Ok(()) => debug!("connection closed"),
},
Err(cause) => {
warn!("failed to accept connection {cause:?}");
return io::Result::Ok(());
warn!("failed to accept connection: {cause:?}");
}
};
handle_connection(connection, handler).await
};
let span = trace_span!("rpc", id = request_id);
let span = error_span!("rpc", id = request_id, remote = tracing::field::Empty);
tasks.spawn(fut.instrument(span));
request_id += 1;
}
Expand Down
34 changes: 26 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1534,7 +1534,7 @@ pub mod rpc {
use quinn::ConnectionError;
use serde::de::DeserializeOwned;
use smallvec::SmallVec;
use tracing::{trace, trace_span, warn, Instrument};
use tracing::{debug, error_span, trace, warn, Instrument};

use crate::{
channel::{
Expand Down Expand Up @@ -2054,19 +2054,32 @@ pub mod rpc {
) {
let mut request_id = 0u64;
let mut tasks = JoinSet::new();
while let Some(incoming) = endpoint.accept().await {
loop {
let incoming = tokio::select! {
Some(res) = tasks.join_next(), if !tasks.is_empty() => {
res.expect("irpc connection task panicked");
continue;
}
incoming = endpoint.accept() => {
match incoming {
None => break,
Some(incoming) => incoming
}
}
};
let handler = handler.clone();
let fut = async move {
let connection = match incoming.await {
Ok(connection) => connection,
match incoming.await {
Ok(connection) => match handle_connection(connection, handler).await {
Err(err) => warn!("connection closed with error: {err:?}"),
Ok(()) => debug!("connection closed"),
},
Err(cause) => {
warn!("failed to accept connection {cause:?}");
return io::Result::Ok(());
warn!("failed to accept connection: {cause:?}");
}
};
handle_connection(connection, handler).await
};
let span = trace_span!("rpc", id = request_id);
let span = error_span!("rpc", id = request_id, remote = tracing::field::Empty);
tasks.spawn(fut.instrument(span));
request_id += 1;
}
Expand All @@ -2077,6 +2090,11 @@ pub mod rpc {
connection: quinn::Connection,
handler: Handler<R>,
) -> io::Result<()> {
tracing::Span::current().record(
"remote",
tracing::field::display(connection.remote_address()),
);
debug!("connection accepted");
loop {
let Some((msg, rx, tx)) = read_request_raw(&connection).await? else {
return Ok(());
Expand Down
8 changes: 6 additions & 2 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#[cfg(feature = "quinn_endpoint_setup")]
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "quinn_endpoint_setup")))]
mod quinn_setup_utils {
use std::sync::Arc;
use std::{sync::Arc, time::Duration};

use anyhow::Result;
use quinn::{crypto::rustls::QuicClientConfig, ClientConfig, ServerConfig};
Expand All @@ -31,7 +31,11 @@ mod quinn_setup_utils {
let quic_client_config =
quinn::crypto::rustls::QuicClientConfig::try_from(crypto_client_config)?;

Ok(ClientConfig::new(Arc::new(quic_client_config)))
let mut transport_config = quinn::TransportConfig::default();
transport_config.keep_alive_interval(Some(Duration::from_secs(1)));
let mut client_config = ClientConfig::new(Arc::new(quic_client_config));
client_config.transport_config(Arc::new(transport_config));
Ok(client_config)
}

/// Create a quinn server config with a self-signed certificate
Expand Down
Loading