From 3b37c0990f509013e8c55807213b5cbcf4fde2f0 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 18 Nov 2025 09:34:36 +0100 Subject: [PATCH 1/7] Proxy client re-uses connection for multiple requests --- src/lib.rs | 240 ++++++++++++++++++++++++++++++----------------------- 1 file changed, 134 insertions(+), 106 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 26609fc..06cc45b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,7 @@ use hyper::service::service_fn; use hyper::Response; use hyper_util::rt::TokioIo; use thiserror::Error; +use tokio::sync::{mpsc, oneshot}; use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; #[cfg(test)] @@ -306,11 +307,15 @@ fn full>(chunk: T) -> BoxBody { pub struct ProxyClient { inner: Proxy, - connector: TlsConnector, - /// The host and port of the proxy server - target: String, - /// Certificate chain for client auth - cert_chain: Option>>, + // connector: TlsConnector, + // /// The host and port of the proxy server + // target: String, + // /// Certificate chain for client auth + // cert_chain: Option>> + requests_tx: mpsc::Sender<( + http::Request, + oneshot::Sender>, hyper::Error>>, + )>, } impl ProxyClient { @@ -377,78 +382,122 @@ impl ProxyClient { let inner = Proxy { listener, + local_quote_generator: local_quote_generator.clone(), + attestation_verifier: attestation_verifier.clone(), + }; + + let target = host_to_host_with_port(&target_name); + + // TODO connect to server and attest + // start run loop with channels + // return struct with sender + // + let (tls_stream, measurements, remote_attestation_type) = Self::setup_connection( + connector.clone(), + target, + cert_chain.clone(), local_quote_generator, attestation_verifier, - }; + ) + .await?; + + let outbound_io = TokioIo::new(tls_stream); + let (mut sender, conn) = hyper::client::conn::http1::Builder::new() + .handshake::<_, hyper::body::Incoming>(outbound_io) + .await?; + + // Drive the connection + tokio::spawn(async move { + if let Err(e) = conn.await { + eprintln!("Client connection error: {e}"); + } + }); + + // Channel for getting incoming requests from the source client + let (requests_tx, mut requests_rx) = mpsc::channel::<( + http::Request, + oneshot::Sender< + Result>, hyper::Error>, + >, + )>(1024); + + tokio::spawn(async move { + while let Some((req, response_tx)) = requests_rx.recv().await { + let resp = match sender.send_request(req).await { + Ok(mut resp) => { + if let Some(measurements) = measurements.clone() { + let headers = resp.headers_mut(); + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); + } + Err(e) => { + // This error is highly unlikely - that the measurement values fail to + // encode to JSON or fit in an HTTP header + eprintln!("Failed to encode measurement values: {e}"); + } + } + headers.insert( + ATTESTATION_TYPE_HEADER, + HeaderValue::from_str(remote_attestation_type.as_str()).unwrap(), + ); + } + Ok(resp.map(|b| b.boxed())) + } + Err(e) => { + eprintln!("send_request error: {e}"); + let mut resp = Response::new(full(format!("Request failed: {e}"))); + *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + Ok(resp) + } + }; + // Send the response back to the source client + response_tx.send(resp).unwrap(); + } + }); Ok(Self { inner, - connector, - target: host_to_host_with_port(&target_name), - cert_chain, + // connector, + // target: host_to_host_with_port(&target_name), + // cert_chain, + requests_tx, }) } - /// Accept an incoming connection and handle it + /// Helper to return the local socket address from the underlying TCP listener + pub fn local_addr(&self) -> std::io::Result { + self.inner.listener.local_addr() + } + + /// Accept an incoming connection and handle it in a separate task pub async fn accept(&self) -> io::Result<()> { let (inbound, _client_addr) = self.inner.listener.accept().await?; - let connector = self.connector.clone(); - let target = self.target.clone(); - let local_quote_generator = self.inner.local_quote_generator.clone(); - let attestation_verifier = self.inner.attestation_verifier.clone(); - let cert_chain = self.cert_chain.clone(); + let requests_tx = self.requests_tx.clone(); tokio::spawn(async move { - if let Err(err) = Self::handle_connection( - inbound, - connector, - target, - cert_chain, - local_quote_generator, - attestation_verifier, - ) - .await - { - eprintln!("Failed to handle connection: {err}"); + if let Err(err) = Self::handle_connection(inbound, requests_tx).await { + eprintln!("Failed to handle connection from source client: {err}"); } }); Ok(()) } - /// Helper to return the local socket address from the underlying TCP listener - pub fn local_addr(&self) -> std::io::Result { - self.inner.listener.local_addr() - } - /// Handle an incoming connection async fn handle_connection( inbound: TcpStream, - connector: TlsConnector, - target: String, - cert_chain: Option>>, - local_quote_generator: Arc, - attestation_verifier: AttestationVerifier, + requests_tx: mpsc::Sender<( + http::Request, + oneshot::Sender>, hyper::Error>>, + )>, ) -> Result<(), ProxyError> { let http = Builder::new(); let service = service_fn(move |req| { - let connector = connector.clone(); - let target = target.clone(); - let cert_chain = cert_chain.clone(); - let local_quote_generator = local_quote_generator.clone(); - let attestation_verifier = attestation_verifier.clone(); + let requests_tx = requests_tx.clone(); async move { - match Self::handle_http_request( - req, - connector, - target, - cert_chain, - local_quote_generator, - attestation_verifier, - ) - .await - { + match Self::handle_http_request(req, requests_tx).await { Ok(res) => { Ok::>, hyper::Error>(res) } @@ -537,62 +586,14 @@ impl ProxyClient { // Handle a request from the source client to the proxy server async fn handle_http_request( req: hyper::Request, - connector: TlsConnector, - target: String, - cert_chain: Option>>, - local_quote_generator: Arc, - attestation_verifier: AttestationVerifier, + requests_tx: mpsc::Sender<( + http::Request, + oneshot::Sender>, hyper::Error>>, + )>, ) -> Result>, ProxyError> { - let (tls_stream, measurements, remote_attestation_type) = Self::setup_connection( - connector, - target, - cert_chain, - local_quote_generator, - attestation_verifier, - ) - .await?; - - // Now the attestation is done, forward the request to the proxy server - let outbound_io = TokioIo::new(tls_stream); - let (mut sender, conn) = hyper::client::conn::http1::Builder::new() - .handshake::<_, hyper::body::Incoming>(outbound_io) - .await?; - - // Drive the connection - tokio::spawn(async move { - if let Err(e) = conn.await { - eprintln!("Client connection error: {e}"); - } - }); - - match sender.send_request(req).await { - Ok(mut resp) => { - if let Some(measurements) = measurements { - let headers = resp.headers_mut(); - match measurements.to_header_format() { - Ok(header_value) => { - headers.insert(MEASUREMENT_HEADER, header_value); - } - Err(e) => { - // This error is highly unlikely - that the measurement values fail to - // encode to JSON or fit in an HTTP header - eprintln!("Failed to encode measurement values: {e}"); - } - } - headers.insert( - ATTESTATION_TYPE_HEADER, - HeaderValue::from_str(remote_attestation_type.as_str()).unwrap(), - ); - } - Ok(resp.map(|b| b.boxed())) - } - Err(e) => { - eprintln!("send_request error: {e}"); - let mut resp = Response::new(full(format!("Request failed: {e}"))); - *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; - Ok(resp) - } - } + let (response_tx, response_rx) = oneshot::channel(); + requests_tx.send((req, response_tx)).await.unwrap(); + Ok(response_rx.await.unwrap()?) } } @@ -808,6 +809,7 @@ mod tests { let proxy_addr = proxy_server.local_addr().unwrap(); tokio::spawn(async move { + // Accept one connection, then finish proxy_server.accept().await.unwrap(); }); @@ -827,6 +829,8 @@ mod tests { let proxy_client_addr = proxy_client.local_addr().unwrap(); tokio::spawn(async move { + // Accept two connections, then finish + proxy_client.accept().await.unwrap(); proxy_client.accept().await.unwrap(); }); @@ -852,6 +856,30 @@ mod tests { // handler puts them there) let measurements = Measurements::from_header_format(&res_body).unwrap(); assert_eq!(measurements, default_measurements()); + + // Now do another request - to check that the connection has stayed open + let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + .await + .unwrap(); + + let headers = res.headers(); + let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); + let measurements = Measurements::from_header_format(measurements_json).unwrap(); + assert_eq!(measurements, default_measurements()); + + let attestation_type = headers + .get(ATTESTATION_TYPE_HEADER) + .unwrap() + .to_str() + .unwrap(); + assert_eq!(attestation_type, AttestationType::DcapTdx.as_str()); + + let res_body = res.text().await.unwrap(); + + // The response body shows us what was in the request header (as the test http server + // handler puts them there) + let measurements = Measurements::from_header_format(&res_body).unwrap(); + assert_eq!(measurements, default_measurements()); } #[tokio::test] From ec1a238830b26944757fc62934067a34c369a08a Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 18 Nov 2025 09:49:37 +0100 Subject: [PATCH 2/7] Use http2 for proxy-client to proxy-server connections --- Cargo.lock | 33 +++++++++++++++++++++++++++++++++ Cargo.toml | 2 +- src/lib.rs | 50 ++++++++++++++++++++++++++++---------------------- 3 files changed, 62 insertions(+), 23 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c1f3904..d0666eb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -902,6 +902,25 @@ dependencies = [ "subtle", ] +[[package]] +name = "h2" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "hashbrown" version = "0.16.0" @@ -1031,6 +1050,7 @@ dependencies = [ "bytes", "futures-channel", "futures-core", + "h2", "http", "http-body", "httparse", @@ -2410,6 +2430,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2efa149fe76073d6e8fd97ef4f4eca7b67f599660115591483572e406e165594" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml_datetime" version = "0.7.3" diff --git a/Cargo.toml b/Cargo.toml index 3140040..06b72fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ configfs-tsm = "0.0.2" rand_core = { version = "0.6.4", features = ["getrandom"] } dcap-qvl = "0.3.4" hex = "0.4.3" -hyper = { version = "1.7.0", features = ["server"] } +hyper = { version = "1.7.0", features = ["server", "http2"] } hyper-util = "0.1.17" http-body-util = "0.1.3" bytes = "1.10.1" diff --git a/src/lib.rs b/src/lib.rs index 06cc45b..63e64bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,6 @@ use bytes::Bytes; use http::HeaderValue; use http_body_util::combinators::BoxBody; use http_body_util::BodyExt; -use hyper::server::conn::http1::Builder; use hyper::service::service_fn; use hyper::Response; use hyper_util::rt::TokioIo; @@ -144,7 +143,7 @@ impl ProxyServer { let attestation_verifier = self.inner.attestation_verifier.clone(); tokio::spawn(async move { if let Err(err) = Self::handle_connection( - inbound, // TODO should be AttestationType + inbound, acceptor, target, cert_chain, @@ -160,6 +159,7 @@ impl ProxyServer { Ok(()) } + /// Helper to get the socket address of the underlying TCP listener pub fn local_addr(&self) -> std::io::Result { self.inner.listener.local_addr() } @@ -226,7 +226,7 @@ impl ProxyServer { (None, AttestationType::None) }; - let http = Builder::new(); + let http = hyper::server::conn::http2::Builder::new(TokioExecutor); let service = service_fn(move |mut req| { // If we have measurements, add them to the request header let measurements = measurements.clone(); @@ -245,7 +245,8 @@ impl ProxyServer { } headers.insert( ATTESTATION_TYPE_HEADER, - HeaderValue::from_str(remote_attestation_type.as_str()).unwrap(), + HeaderValue::from_str(remote_attestation_type.as_str()) + .expect("Attestation type should be able to be encoded as a header value"), ); } @@ -255,7 +256,7 @@ impl ProxyServer { Ok::>, hyper::Error>(res) } Err(e) => { - eprintln!("send_request error: {e}"); + eprintln!("Failed to handle a request from a proxy-client: {e}"); let mut resp = Response::new(full(format!("Request failed: {e}"))); *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; Ok(resp) @@ -306,12 +307,7 @@ fn full>(chunk: T) -> BoxBody { } pub struct ProxyClient { - inner: Proxy, - // connector: TlsConnector, - // /// The host and port of the proxy server - // target: String, - // /// Certificate chain for client auth - // cert_chain: Option>> + listener: TcpListener, requests_tx: mpsc::Sender<( http::Request, oneshot::Sender>, hyper::Error>>, @@ -380,12 +376,6 @@ impl ProxyClient { let listener = TcpListener::bind(local).await?; let connector = TlsConnector::from(client_config.clone()); - let inner = Proxy { - listener, - local_quote_generator: local_quote_generator.clone(), - attestation_verifier: attestation_verifier.clone(), - }; - let target = host_to_host_with_port(&target_name); // TODO connect to server and attest @@ -402,7 +392,7 @@ impl ProxyClient { .await?; let outbound_io = TokioIo::new(tls_stream); - let (mut sender, conn) = hyper::client::conn::http1::Builder::new() + let (mut sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) .handshake::<_, hyper::body::Incoming>(outbound_io) .await?; @@ -457,7 +447,7 @@ impl ProxyClient { }); Ok(Self { - inner, + listener, // connector, // target: host_to_host_with_port(&target_name), // cert_chain, @@ -467,12 +457,12 @@ impl ProxyClient { /// Helper to return the local socket address from the underlying TCP listener pub fn local_addr(&self) -> std::io::Result { - self.inner.listener.local_addr() + self.listener.local_addr() } /// Accept an incoming connection and handle it in a separate task pub async fn accept(&self) -> io::Result<()> { - let (inbound, _client_addr) = self.inner.listener.accept().await?; + let (inbound, _client_addr) = self.listener.accept().await?; let requests_tx = self.requests_tx.clone(); @@ -493,7 +483,7 @@ impl ProxyClient { oneshot::Sender>, hyper::Error>>, )>, ) -> Result<(), ProxyError> { - let http = Builder::new(); + let http = hyper::server::conn::http1::Builder::new(); let service = service_fn(move |req| { let requests_tx = requests_tx.clone(); async move { @@ -703,6 +693,22 @@ fn server_name_from_host( ServerName::try_from(host_part.to_string()) } +/// An Executor for hyper that uses the tokio runtime +#[derive(Clone)] +struct TokioExecutor; + +// Implement the `hyper::rt::Executor` trait for `TokioExecutor` so that it can be used to spawn +// tasks in the hyper runtime. +impl hyper::rt::Executor for TokioExecutor +where + F: std::future::Future + Send + 'static, + F::Output: Send + 'static, +{ + fn execute(&self, fut: F) { + tokio::task::spawn(fut); + } +} + #[cfg(test)] mod tests { use super::*; From 3952f7f109b1dc24f3349aa8882f53de646029a3 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 18 Nov 2025 10:08:11 +0100 Subject: [PATCH 3/7] Error handling --- src/lib.rs | 55 +++++++++++++++++++++++++++++------------------------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 63e64bf..3a8cf9a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,6 +37,11 @@ const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type"; /// The header name for giving measurements const MEASUREMENT_HEADER: &str = "X-Flashbots-Measurement"; +type RequestWithResponseSender = ( + http::Request, + oneshot::Sender>, hyper::Error>>, +); + /// TLS Credentials pub struct TlsCertAndKey { /// Der-encoded TLS certificate chain @@ -308,10 +313,7 @@ fn full>(chunk: T) -> BoxBody { pub struct ProxyClient { listener: TcpListener, - requests_tx: mpsc::Sender<( - http::Request, - oneshot::Sender>, hyper::Error>>, - )>, + requests_tx: mpsc::Sender, } impl ProxyClient { @@ -362,7 +364,7 @@ impl ProxyClient { .await } - /// Create a new proxy with given TLS configuration + /// Create a new proxy client with given TLS configuration /// /// This is private as it allows dangerous configuration but is used in tests async fn new_with_tls_config( @@ -378,10 +380,8 @@ impl ProxyClient { let target = host_to_host_with_port(&target_name); - // TODO connect to server and attest - // start run loop with channels - // return struct with sender - // + // Connect to the proxy server + // TODO this should run in a loop, reconnecting when the connection is lost let (tls_stream, measurements, remote_attestation_type) = Self::setup_connection( connector.clone(), target, @@ -429,28 +429,29 @@ impl ProxyClient { } headers.insert( ATTESTATION_TYPE_HEADER, - HeaderValue::from_str(remote_attestation_type.as_str()).unwrap(), + HeaderValue::from_str(remote_attestation_type.as_str()) + .expect("Attestation type should be able to be encoded as a header value"), ); } Ok(resp.map(|b| b.boxed())) } Err(e) => { - eprintln!("send_request error: {e}"); + eprintln!("Failed to send request to proxy-server: {e}"); let mut resp = Response::new(full(format!("Request failed: {e}"))); *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; Ok(resp) } }; + // Send the response back to the source client - response_tx.send(resp).unwrap(); + if response_tx.send(resp).is_err() { + eprintln!("Failed to forward response to source client, probably they dropped the connection"); + } } }); Ok(Self { listener, - // connector, - // target: host_to_host_with_port(&target_name), - // cert_chain, requests_tx, }) } @@ -478,10 +479,7 @@ impl ProxyClient { /// Handle an incoming connection async fn handle_connection( inbound: TcpStream, - requests_tx: mpsc::Sender<( - http::Request, - oneshot::Sender>, hyper::Error>>, - )>, + requests_tx: mpsc::Sender, ) -> Result<(), ProxyError> { let http = hyper::server::conn::http1::Builder::new(); let service = service_fn(move |req| { @@ -576,14 +574,11 @@ impl ProxyClient { // Handle a request from the source client to the proxy server async fn handle_http_request( req: hyper::Request, - requests_tx: mpsc::Sender<( - http::Request, - oneshot::Sender>, hyper::Error>>, - )>, + requests_tx: mpsc::Sender, ) -> Result>, ProxyError> { let (response_tx, response_rx) = oneshot::channel(); - requests_tx.send((req, response_tx)).await.unwrap(); - Ok(response_rx.await.unwrap()?) + requests_tx.send((req, response_tx)).await?; + Ok(response_rx.await??) } } @@ -664,6 +659,16 @@ pub enum ProxyError { Hyper(#[from] hyper::Error), #[error("JSON: {0}")] Json(#[from] serde_json::Error), + #[error("Could not forward response - sender was dropped")] + OneShotRecv(#[from] oneshot::error::RecvError), + #[error("Failed to send request, connection to proxy-server dropped")] + MpscSend, +} + +impl From> for ProxyError { + fn from(_err: mpsc::error::SendError) -> Self { + Self::MpscSend + } } /// Given a byte array, encode its length as a 4 byte big endian u32 From f3e00bf20e48ef574d904be51a3a8907d92afa66 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 18 Nov 2025 10:28:39 +0100 Subject: [PATCH 4/7] Reconnect on failed server connection --- src/lib.rs | 106 +++++++++++++++++++++++++++-------------------------- 1 file changed, 54 insertions(+), 52 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3a8cf9a..1ea817a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,7 @@ type RequestWithResponseSender = ( http::Request, oneshot::Sender>, hyper::Error>>, ); +type Http2Sender = hyper::client::conn::http2::SendRequest; /// TLS Credentials pub struct TlsCertAndKey { @@ -50,19 +51,14 @@ pub struct TlsCertAndKey { pub key: PrivateKeyDer<'static>, } -/// Inner struct used by [ProxyClient] and [ProxyServer] -struct Proxy { +/// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address +pub struct ProxyServer { /// The underlying TCP listener listener: TcpListener, /// Quote generation type to use (including none) local_quote_generator: Arc, /// Verifier for remote attestation (including none) attestation_verifier: AttestationVerifier, -} - -/// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address -pub struct ProxyServer { - inner: Proxy, /// The certificate chain cert_chain: Vec>, /// For accepting TLS connections @@ -123,29 +119,25 @@ impl ProxyServer { let acceptor = tokio_rustls::TlsAcceptor::from(server_config); let listener = TcpListener::bind(local).await?; - let inner = Proxy { + Ok(Self { listener, local_quote_generator, attestation_verifier, - }; - - Ok(Self { acceptor, target, - inner, cert_chain, }) } /// Accept an incoming connection pub async fn accept(&self) -> Result<(), ProxyError> { - let (inbound, _client_addr) = self.inner.listener.accept().await?; + let (inbound, _client_addr) = self.listener.accept().await?; let acceptor = self.acceptor.clone(); let target = self.target; let cert_chain = self.cert_chain.clone(); - let local_quote_generator = self.inner.local_quote_generator.clone(); - let attestation_verifier = self.inner.attestation_verifier.clone(); + let local_quote_generator = self.local_quote_generator.clone(); + let attestation_verifier = self.attestation_verifier.clone(); tokio::spawn(async move { if let Err(err) = Self::handle_connection( inbound, @@ -166,7 +158,7 @@ impl ProxyServer { /// Helper to get the socket address of the underlying TCP listener pub fn local_addr(&self) -> std::io::Result { - self.inner.listener.local_addr() + self.listener.local_addr() } async fn handle_connection( @@ -380,29 +372,6 @@ impl ProxyClient { let target = host_to_host_with_port(&target_name); - // Connect to the proxy server - // TODO this should run in a loop, reconnecting when the connection is lost - let (tls_stream, measurements, remote_attestation_type) = Self::setup_connection( - connector.clone(), - target, - cert_chain.clone(), - local_quote_generator, - attestation_verifier, - ) - .await?; - - let outbound_io = TokioIo::new(tls_stream); - let (mut sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) - .handshake::<_, hyper::body::Incoming>(outbound_io) - .await?; - - // Drive the connection - tokio::spawn(async move { - if let Err(e) = conn.await { - eprintln!("Client connection error: {e}"); - } - }); - // Channel for getting incoming requests from the source client let (requests_tx, mut requests_rx) = mpsc::channel::<( http::Request, @@ -411,9 +380,19 @@ impl ProxyClient { >, )>(1024); + // Connect to the proxy server + let (mut sender, mut measurements, mut remote_attestation_type) = Self::setup_connection( + connector.clone(), + target.clone(), + cert_chain.clone(), + local_quote_generator.clone(), + attestation_verifier.clone(), + ) + .await?; + tokio::spawn(async move { while let Some((req, response_tx)) = requests_rx.recv().await { - let resp = match sender.send_request(req).await { + let (response, should_reconnect) = match sender.send_request(req).await { Ok(mut resp) => { if let Some(measurements) = measurements.clone() { let headers = resp.headers_mut(); @@ -433,20 +412,36 @@ impl ProxyClient { .expect("Attestation type should be able to be encoded as a header value"), ); } - Ok(resp.map(|b| b.boxed())) + (Ok(resp.map(|b| b.boxed())), false) } Err(e) => { eprintln!("Failed to send request to proxy-server: {e}"); let mut resp = Response::new(full(format!("Request failed: {e}"))); *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; - Ok(resp) + + (Ok(resp), true) } }; // Send the response back to the source client - if response_tx.send(resp).is_err() { + if response_tx.send(response).is_err() { eprintln!("Failed to forward response to source client, probably they dropped the connection"); } + + // If the connection to the proxy server failed, reconnect + if should_reconnect { + // Reconnect to the server + // TODO the error should be handled in a backoff loop + (sender, measurements, remote_attestation_type) = Self::setup_connection( + connector.clone(), + target.clone(), + cert_chain.clone(), + local_quote_generator.clone(), + attestation_verifier.clone(), + ) + .await + .unwrap(); + } } }); @@ -511,14 +506,7 @@ impl ProxyClient { cert_chain: Option>>, local_quote_generator: Arc, attestation_verifier: AttestationVerifier, - ) -> Result< - ( - tokio_rustls::client::TlsStream, - Option, - AttestationType, - ), - ProxyError, - > { + ) -> Result<(Http2Sender, Option, AttestationType), ProxyError> { let out = TcpStream::connect(&target).await?; let mut tls_stream = connector .connect(server_name_from_host(&target)?, out) @@ -568,7 +556,21 @@ impl ProxyClient { tls_stream.write_all(&attestation).await?; - Ok((tls_stream, measurements, remote_attestation_type)) + // Attestation is complete - now seturn an HTTP client + + let outbound_io = TokioIo::new(tls_stream); + let (sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) + .handshake::<_, hyper::body::Incoming>(outbound_io) + .await?; + + // Drive the connection + tokio::spawn(async move { + if let Err(e) = conn.await { + eprintln!("Client connection error: {e}"); + } + }); + + Ok((sender, measurements, remote_attestation_type)) } // Handle a request from the source client to the proxy server From 4b9246d88f537bbb9ffdb1da60bdf82241c40d33 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 18 Nov 2025 11:34:10 +0100 Subject: [PATCH 5/7] Add backoff when reconnecting to the server --- src/lib.rs | 60 ++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 11 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 1ea817a..5d361ac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; mod test_helpers; use std::num::TryFromIntError; +use std::time::Duration; use std::{net::SocketAddr, sync::Arc}; use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; @@ -32,11 +33,15 @@ use crate::attestation::{AttesationPayload, AttestationVerifier}; /// The label used when exporting key material from a TLS session const EXPORTER_LABEL: &[u8; 24] = b"EXPORTER-Channel-Binding"; +/// The header name for giving attestation type const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type"; /// The header name for giving measurements const MEASUREMENT_HEADER: &str = "X-Flashbots-Measurement"; +/// The longest time in seconds to wait between reconnection attempts +const SERVER_RECONNECT_MAX_BACKOFF_SECS: u64 = 120; + type RequestWithResponseSender = ( http::Request, oneshot::Sender>, hyper::Error>>, @@ -430,17 +435,16 @@ impl ProxyClient { // If the connection to the proxy server failed, reconnect if should_reconnect { - // Reconnect to the server - // TODO the error should be handled in a backoff loop - (sender, measurements, remote_attestation_type) = Self::setup_connection( - connector.clone(), - target.clone(), - cert_chain.clone(), - local_quote_generator.clone(), - attestation_verifier.clone(), - ) - .await - .unwrap(); + // Reconnect to the server - retrying with a backoff + (sender, measurements, remote_attestation_type) = + Self::setup_connection_with_backoff( + connector.clone(), + target.clone(), + cert_chain.clone(), + local_quote_generator.clone(), + attestation_verifier.clone(), + ) + .await; } } }); @@ -500,6 +504,40 @@ impl ProxyClient { Ok(()) } + async fn setup_connection_with_backoff( + connector: TlsConnector, + target: String, + cert_chain: Option>>, + local_quote_generator: Arc, + attestation_verifier: AttestationVerifier, + ) -> (Http2Sender, Option, AttestationType) { + let mut delay = Duration::from_secs(1); + let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS); + + loop { + match Self::setup_connection( + connector.clone(), + target.clone(), + cert_chain.clone(), + local_quote_generator.clone(), + attestation_verifier.clone(), + ) + .await + { + Ok(output) => { + return output; + } + Err(e) => { + eprintln!("Reconnect failed: {e}. Retrying in {:#?}...", delay); + tokio::time::sleep(delay).await; + + // increase delay for next time (exponential), but clamp to max_delay + delay = std::cmp::min(delay * 2, max_delay); + } + } + } + } + async fn setup_connection( connector: TlsConnector, target: String, From 1f3fe4533c524c9afe4da7466803fb17cfc24afd Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 18 Nov 2025 11:36:31 +0100 Subject: [PATCH 6/7] Comments --- src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lib.rs b/src/lib.rs index 5d361ac..e87c9b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -538,6 +538,7 @@ impl ProxyClient { } } + /// Connect to the proxy-server, do TLS handshake and remote attestation async fn setup_connection( connector: TlsConnector, target: String, From 6e579234df9c770f64033916dbf4dc847d41ff17 Mon Sep 17 00:00:00 2001 From: peg Date: Wed, 19 Nov 2025 14:32:43 +0100 Subject: [PATCH 7/7] Add doccomments --- src/lib.rs | 59 +++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 49 insertions(+), 10 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index e87c9b0..5317d6b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -134,7 +134,7 @@ impl ProxyServer { }) } - /// Accept an incoming connection + /// Accept an incoming connection and handle it in a seperate task pub async fn accept(&self) -> Result<(), ProxyError> { let (inbound, _client_addr) = self.listener.accept().await?; @@ -166,6 +166,7 @@ impl ProxyServer { self.listener.local_addr() } + /// Handle an incoming connection from a proxy-client async fn handle_connection( inbound: TcpStream, acceptor: TlsAcceptor, @@ -174,9 +175,11 @@ impl ProxyServer { local_quote_generator: Arc, attestation_verifier: AttestationVerifier, ) -> Result<(), ProxyError> { + // Do TLS handshake let mut tls_stream = acceptor.accept(inbound).await?; let (_io, connection) = tls_stream.get_ref(); + // Compute an exporter unique to the session let mut exporter = [0u8; 32]; connection.export_keying_material( &mut exporter, @@ -184,8 +187,10 @@ impl ProxyServer { None, // context )?; + // Get the TLS certficate chain of the client, if there is one let remote_cert_chain = connection.peer_certificates().map(|c| c.to_owned()); + // If we are in a CVM, generate an attestation let attestation = if local_quote_generator.attestation_type() != AttestationType::None { serde_json::to_vec(&AttesationPayload::from_attestation_generator( &cert_chain, @@ -196,12 +201,13 @@ impl ProxyServer { Vec::new() }; + // Write our attestation to the channel, with length prefix let attestation_length_prefix = length_prefix(&attestation); - tls_stream.write_all(&attestation_length_prefix).await?; - tls_stream.write_all(&attestation).await?; + // Now read a length-prefixed attestation from the remote peer + // In the case of no client attestation this will be zero bytes let mut length_bytes = [0; 4]; tls_stream.read_exact(&mut length_bytes).await?; let length: usize = u32::from_be_bytes(length_bytes).try_into()?; @@ -209,6 +215,7 @@ impl ProxyServer { let mut buf = vec![0; length]; tls_stream.read_exact(&mut buf).await?; + // If we expect an attestaion from the client, verify it and get measurements let (measurements, remote_attestation_type) = if attestation_verifier.has_remote_attestion() { let remote_attestation_payload: AttesationPayload = serde_json::from_slice(&buf)?; @@ -228,9 +235,12 @@ impl ProxyServer { (None, AttestationType::None) }; + // Setup an HTTP server let http = hyper::server::conn::http2::Builder::new(TokioExecutor); + + // Setup a request handler let service = service_fn(move |mut req| { - // If we have measurements, add them to the request header + // If we have measurements, from the remote peer, add them to the request header let measurements = measurements.clone(); if let Some(measurements) = measurements { let headers = req.headers_mut(); @@ -267,6 +277,7 @@ impl ProxyServer { } }); + // Serve this connection using the request handler defined above let io = TokioIo::new(tls_stream); http.serve_connection(io, service).await?; @@ -278,11 +289,13 @@ impl ProxyServer { req: hyper::Request, target: SocketAddr, ) -> Result>, ProxyError> { + // Connect to the target server let outbound = TcpStream::connect(target).await?; let outbound_io = TokioIo::new(outbound); let (mut sender, conn) = hyper::client::conn::http1::Builder::new() .handshake::<_, hyper::body::Incoming>(outbound_io) .await?; + // Drive the connection tokio::spawn(async move { if let Err(e) = conn.await { @@ -290,6 +303,7 @@ impl ProxyServer { } }); + // Forward the request from the proxy-client to the target server match sender.send_request(req).await { Ok(resp) => Ok(resp.map(|b| b.boxed())), Err(e) => { @@ -302,18 +316,23 @@ impl ProxyServer { } } +/// Helper to create a binary http body fn full>(chunk: T) -> BoxBody { http_body_util::Full::new(chunk.into()) .map_err(|never| match never {}) .boxed() } +/// A proxy client which forwards http traffic to a proxy-server pub struct ProxyClient { + /// The underlying TCP listener listener: TcpListener, + /// A channel for sending requests to the connection to the proxy-server requests_tx: mpsc::Sender, } impl ProxyClient { + /// Start with optional TLS client auth pub async fn new( cert_and_key: Option, address: impl ToSocketAddrs, @@ -322,12 +341,14 @@ impl ProxyClient { attestation_verifier: AttestationVerifier, remote_certificate: Option>, ) -> Result { + // If we will provide attestation, we must also use client auth if local_quote_generator.attestation_type() != AttestationType::None && cert_and_key.is_none() { return Err(ProxyError::NoClientAuth); } + // If a remote CA cert was given, use it as the root store, otherwise use webpki_roots let root_store = match remote_certificate { Some(remote_certificate) => { let mut root_store = RootCertStore::empty(); @@ -337,6 +358,7 @@ impl ProxyClient { None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()), }; + // Setup TLS client configuration, with or without client auth let client_config = if let Some(ref cert_and_key) = cert_and_key { ClientConfig::builder() .with_root_certificates(root_store) @@ -372,9 +394,11 @@ impl ProxyClient { attestation_verifier: AttestationVerifier, cert_chain: Option>>, ) -> Result { + // Setup TCP server and TLS client let listener = TcpListener::bind(local).await?; let connector = TlsConnector::from(client_config.clone()); + // Process the hostname / port provided by the user let target = host_to_host_with_port(&target_name); // Channel for getting incoming requests from the source client @@ -385,7 +409,7 @@ impl ProxyClient { >, )>(1024); - // Connect to the proxy server + // Connect to the proxy server and provide / verify attestation let (mut sender, mut measurements, mut remote_attestation_type) = Self::setup_connection( connector.clone(), target.clone(), @@ -396,9 +420,13 @@ impl ProxyClient { .await?; tokio::spawn(async move { + // Read an incoming request from the channel (from the source client) while let Some((req, response_tx)) = requests_rx.recv().await { + // Attempt to forward it to the proxy server let (response, should_reconnect) = match sender.send_request(req).await { Ok(mut resp) => { + // If we have measurements from the proxy-server, inject them into the + // response header if let Some(measurements) = measurements.clone() { let headers = resp.headers_mut(); match measurements.to_header_format() { @@ -435,7 +463,7 @@ impl ProxyClient { // If the connection to the proxy server failed, reconnect if should_reconnect { - // Reconnect to the server - retrying with a backoff + // Reconnect to the server - retrying indefinately with a backoff (sender, measurements, remote_attestation_type) = Self::setup_connection_with_backoff( connector.clone(), @@ -475,11 +503,12 @@ impl ProxyClient { Ok(()) } - /// Handle an incoming connection + /// Handle an incoming connection from the source client async fn handle_connection( inbound: TcpStream, requests_tx: mpsc::Sender, ) -> Result<(), ProxyError> { + // Setup http server and handler let http = hyper::server::conn::http1::Builder::new(); let service = service_fn(move |req| { let requests_tx = requests_tx.clone(); @@ -504,6 +533,8 @@ impl ProxyClient { Ok(()) } + // Attempt connection and handshake with the proxy-server + // If it fails retry with a backoff (indefinately) async fn setup_connection_with_backoff( connector: TlsConnector, target: String, @@ -546,6 +577,7 @@ impl ProxyClient { local_quote_generator: Arc, attestation_verifier: AttestationVerifier, ) -> Result<(Http2Sender, Option, AttestationType), ProxyError> { + // Make a TCP client connection and TLS handshake let out = TcpStream::connect(&target).await?; let mut tls_stream = connector .connect(server_name_from_host(&target)?, out) @@ -553,6 +585,7 @@ impl ProxyClient { let (_io, server_connection) = tls_stream.get_ref(); + // Compute an exporter unique to the channel let mut exporter = [0u8; 32]; server_connection.export_keying_material( &mut exporter, @@ -560,11 +593,13 @@ impl ProxyClient { None, // context )?; + // Get the TLS certificate chain of the server let remote_cert_chain = server_connection .peer_certificates() .ok_or(ProxyError::NoCertificate)? .to_owned(); + // Read a length prefixed attestation from the proxy-server let mut length_bytes = [0; 4]; tls_stream.read_exact(&mut length_bytes).await?; let length: usize = u32::from_be_bytes(length_bytes).try_into()?; @@ -575,10 +610,12 @@ impl ProxyClient { let remote_attestation_payload: AttesationPayload = serde_json::from_slice(&buf)?; let remote_attestation_type = remote_attestation_payload.attestation_type; + // Verify the remote attestation against our accepted measurements let measurements = attestation_verifier .verify_attestation(remote_attestation_payload, &remote_cert_chain, exporter) .await?; + // If we are in a CVM, provide an attestation let attestation = if local_quote_generator.attestation_type() != AttestationType::None { serde_json::to_vec(&AttesationPayload::from_attestation_generator( &cert_chain.ok_or(ProxyError::NoClientAuth)?, @@ -589,13 +626,12 @@ impl ProxyClient { Vec::new() }; + // Send our attestation (or zero bytes) prefixed with length let attestation_length_prefix = length_prefix(&attestation); - tls_stream.write_all(&attestation_length_prefix).await?; - tls_stream.write_all(&attestation).await?; - // Attestation is complete - now seturn an HTTP client + // The attestation exchange is now complete - now setup an HTTP client let outbound_io = TokioIo::new(tls_stream); let (sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) @@ -609,6 +645,7 @@ impl ProxyClient { } }); + // Return the HTTP client, as well as remote measurements Ok((sender, measurements, remote_attestation_type)) } @@ -718,6 +755,7 @@ fn length_prefix(input: &[u8]) -> [u8; 4] { len.to_be_bytes() } +/// If no port was provided, default to 443 fn host_to_host_with_port(host: &str) -> String { if host.contains(':') { host.to_string() @@ -726,6 +764,7 @@ fn host_to_host_with_port(host: &str) -> String { } } +/// Given a hostname with or without port number, create a TLS [ServerName] with just the host part fn server_name_from_host( host: &str, ) -> Result, tokio_rustls::rustls::pki_types::InvalidDnsNameError> {