Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,6 @@ iroh = { version = "0.95" }
iroh-base = { version = "0.95" }
quinn = { package = "iroh-quinn", version = "0.14.0", default-features = false }
futures-util = { version = "0.3", features = ["sink"] }

[patch.crates-io]
iroh = { git = "https://github.com/n0-computer/iroh.git", branch = "connection-state" }
67 changes: 10 additions & 57 deletions irpc-iroh/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
use std::{
fmt,
future::Future,
io,
fmt, io,
sync::{atomic::AtomicU64, Arc},
};

use iroh::{
endpoint::{
Accepting, ConnectingError, Connection, ConnectionError, IncomingZeroRttConnection,
OutgoingZeroRttConnection, RecvStream, RemoteEndpointIdError, SendStream, VarInt,
ZeroRttStatus,
Accepting, ConnectingError, Connection, ConnectionError, ConnectionState,
OutgoingZeroRttConnection, RecvStream, SendStream, ZeroRttStatus,
},
protocol::{AcceptError, ProtocolHandler},
EndpointId,
};
use irpc::{
channel::oneshot,
Expand Down Expand Up @@ -289,12 +285,13 @@ impl<R: DeserializeOwned + Send + 'static> ProtocolHandler for Iroh0RttProtocol<

/// Handles a single iroh connection with the provided `handler`.
pub async fn handle_connection<R: DeserializeOwned + 'static>(
connection: &impl IncomingRemoteConnection,
connection: &Connection<impl ConnectionState>,
handler: Handler<R>,
) -> io::Result<()> {
if let Ok(remote) = connection.remote_id() {
tracing::Span::current().record("remote", tracing::field::display(remote.fmt_short()));
}
// We might not have a handshaked connection yet, in which case we don't know the remote endpoint id.
// if let Ok(remote) = connection.remote_id() {
// tracing::Span::current().record("remote", tracing::field::display(remote.fmt_short()));
// }
debug!("connection accepted");
loop {
let Some((msg, rx, tx)) = read_request_raw(connection).await? else {
Expand All @@ -306,57 +303,13 @@ pub async fn handle_connection<R: DeserializeOwned + 'static>(

/// Reads a single request from a connection, and a message with channels.
pub async fn read_request<S: RemoteService>(
connection: &impl IncomingRemoteConnection,
connection: &Connection<impl ConnectionState>,
) -> std::io::Result<Option<S::Message>> {
Ok(read_request_raw::<S>(connection)
.await?
.map(|(msg, rx, tx)| S::with_remote_channels(msg, rx, tx)))
}

/// Abstracts over [`Connection`] and [`IncomingZeroRttConnection`].
///
/// You don't need to implement this trait yourself. It is used by [`read_request`] and
/// [`handle_connection`] to work with both fully authenticated connections and with
/// 0-RTT connections.
pub trait IncomingRemoteConnection {
/// Accepts a single bidirectional stream.
fn accept_bi(
&self,
) -> impl Future<Output = Result<(SendStream, RecvStream), ConnectionError>> + Send;
/// Close the connection.
fn close(&self, error_code: VarInt, reason: &[u8]);
/// Returns the remote's endpoint id.
///
/// This may only fail for 0-RTT connections.
fn remote_id(&self) -> Result<EndpointId, RemoteEndpointIdError>;
}

impl IncomingRemoteConnection for IncomingZeroRttConnection {
async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
self.accept_bi().await
}

fn close(&self, error_code: VarInt, reason: &[u8]) {
self.close(error_code, reason)
}
fn remote_id(&self) -> Result<EndpointId, RemoteEndpointIdError> {
self.remote_id()
}
}

impl IncomingRemoteConnection for Connection {
async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> {
self.accept_bi().await
}

fn close(&self, error_code: VarInt, reason: &[u8]) {
self.close(error_code, reason)
}
fn remote_id(&self) -> Result<EndpointId, RemoteEndpointIdError> {
Ok(self.remote_id())
}
}

/// Reads a single request from the connection.
///
/// This accepts a bi-directional stream from the connection and reads and parses the request.
Expand All @@ -365,7 +318,7 @@ impl IncomingRemoteConnection for Connection {
/// Returns None if the remote closed the connection with error code `0`.
/// Returns an error for all other failure cases.
pub async fn read_request_raw<R: DeserializeOwned + 'static>(
connection: &impl IncomingRemoteConnection,
connection: &Connection<impl ConnectionState>,
) -> std::io::Result<Option<(R, RecvStream, SendStream)>> {
let (send, mut recv) = match connection.accept_bi().await {
Ok((s, r)) => (s, r),
Expand Down
Loading