From 0540ba2540fd19971e08ee97b3ff4e36745fe8c1 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 11 Nov 2025 08:25:55 +0100 Subject: [PATCH 01/12] Attestation returns measurements if successful --- src/attestation.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/attestation.rs b/src/attestation.rs index 351449e..84e1dcf 100644 --- a/src/attestation.rs +++ b/src/attestation.rs @@ -14,6 +14,8 @@ use x509_parser::prelude::*; /// For fetching collateral directly from intel, if no PCCS is specified const PCS_URL: &str = "https://api.trustedservices.intel.com"; +type Measurements = (PlatformMeasurements, CvmImageMeasurements); + /// Defines how to generate a quote pub trait QuoteGenerator: Clone + Send + 'static { /// Whether this is CVM attestation. This should always return true except for the [NoQuoteGenerator] case. @@ -42,7 +44,7 @@ pub trait QuoteVerifier: Clone + Send + 'static { input: Vec, cert_chain: &[CertificateDer<'_>], exporter: [u8; 32], - ) -> impl Future> + Send; + ) -> impl Future, AttestationError>> + Send; } /// Quote generation using configfs_tsm @@ -151,7 +153,7 @@ impl QuoteVerifier for DcapTdxQuoteVerifier { input: Vec, cert_chain: &[CertificateDer<'_>], exporter: [u8; 32], - ) -> Result<(), AttestationError> { + ) -> Result, AttestationError> { let quote_input = compute_report_input(cert_chain, exporter)?; let (platform_measurements, image_measurements) = if cfg!(not(test)) { let now = std::time::SystemTime::now() @@ -205,7 +207,7 @@ impl QuoteVerifier for DcapTdxQuoteVerifier { return Err(AttestationError::UnacceptableOsImageMeasurements); } - Ok(()) + Ok(Some((platform_measurements, image_measurements))) } } @@ -264,9 +266,9 @@ impl QuoteVerifier for NoQuoteVerifier { input: Vec, _cert_chain: &[CertificateDer<'_>], _exporter: [u8; 32], - ) -> Result<(), AttestationError> { + ) -> Result, AttestationError> { if input.is_empty() { - Ok(()) + Ok(None) } else { Err(AttestationError::AttestationGivenWhenNoneExpected) } From 2ab208679846e98e767bad4623bebc31353ce61a Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 11 Nov 2025 10:28:58 +0100 Subject: [PATCH 02/12] Switch to http proxy server (rather than raw TCP) --- Cargo.lock | 4 + Cargo.toml | 4 + src/lib.rs | 226 ++++++++++++++++++++++++++++++++--------------------- 3 files changed, 147 insertions(+), 87 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8f7fdf0..64a3a94 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -132,10 +132,14 @@ version = "0.1.0" dependencies = [ "anyhow", "axum", + "bytes", "clap", "configfs-tsm", "dcap-qvl", "hex", + "http-body-util", + "hyper", + "hyper-util", "pem-rfc7468", "rand_core 0.6.4", "rcgen", diff --git a/Cargo.toml b/Cargo.toml index 739386a..5dd4a7c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,10 @@ configfs-tsm = "0.0.2" rand_core = { version = "0.6.4", features = ["getrandom"] } dcap-qvl = "0.3.4" hex = "0.4.3" +hyper = { version = "1.7.0", features = ["server"] } +hyper-util = "0.1.17" +http-body-util = "0.1.3" +bytes = "1.10.1" [dev-dependencies] rcgen = "0.14.5" diff --git a/src/lib.rs b/src/lib.rs index 152a816..d0a51bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,10 @@ pub use attestation::{ DcapTdxQuoteGenerator, DcapTdxQuoteVerifier, NoQuoteGenerator, NoQuoteVerifier, QuoteGenerator, QuoteVerifier, }; +use hyper::server::conn::http1::Builder; +use hyper::service::service_fn; +use hyper::Response; +use hyper_util::rt::TokioIo; use thiserror::Error; use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; @@ -204,16 +208,52 @@ impl ProxyServer { .await?; } - let outbound = TcpStream::connect(target).await?; + let http = Builder::new(); + let service = + service_fn(move |req| async move { Self::handle_http_request(req, target).await }); - let (mut inbound_reader, mut inbound_writer) = tokio::io::split(tls_stream); - let (mut outbound_reader, mut outbound_writer) = outbound.into_split(); + let io = TokioIo::new(tls_stream); + http.serve_connection(io, service).await.unwrap(); - let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer); - let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer); - tokio::try_join!(client_to_server, server_to_client)?; + // let (mut inbound_reader, mut inbound_writer) = tokio::io::split(tls_stream); + // let (mut outbound_reader, mut outbound_writer) = outbound.into_split(); + // + // let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer); + // let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer); + // tokio::try_join!(client_to_server, server_to_client)?; Ok(()) } + + // Handle a request from the proxy client to the target server + async fn handle_http_request( + req: hyper::Request, + target: SocketAddr, + ) -> Result, hyper::Error> { + let outbound = TcpStream::connect(target).await.unwrap(); + let outbound_io = TokioIo::new(outbound); + let (mut sender, conn) = hyper::client::conn::http1::Builder::new() + .handshake::<_, hyper::body::Incoming>(outbound_io) + .await + .unwrap(); + + // Drive the connection + tokio::spawn(async move { + if let Err(e) = conn.await { + eprintln!("client conn error: {e}"); + } + }); + + match sender.send_request(req).await { + Ok(resp) => Ok(resp), + Err(e) => { + eprintln!("send_request error: {e}"); + // let mut resp = Response::new(hyper::body::Incoming::empty()); + // *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + // Ok(resp) + panic!("todo"); + } + } + } } pub struct ProxyClient @@ -337,58 +377,129 @@ impl ProxyClient { local_attestation_platform: L, remote_attestation_platform: R, ) -> Result<(), ProxyError> { - let out = TcpStream::connect(&target).await?; + let http = Builder::new(); + let service = service_fn(move |req| { + let connector = connector.clone(); + let target = target.clone(); + let cert_chain = cert_chain.clone(); + let local_attestation_platform = local_attestation_platform.clone(); + let remote_attestation_platform = remote_attestation_platform.clone(); + async move { + Self::handle_http_request( + req, + connector, + target, + cert_chain, + local_attestation_platform, + remote_attestation_platform, + ) + .await + } + }); + + let io = TokioIo::new(inbound); + http.serve_connection(io, service).await.unwrap(); + + // let (mut inbound_reader, mut inbound_writer) = inbound.into_split(); + // let (mut outbound_reader, mut outbound_writer) = tokio::io::split(tls_stream); + // + // let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer); + // let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer); + // tokio::try_join!(client_to_server, server_to_client)?; + Ok(()) + } + + // Handle a request from the source client to the proxy server + async fn handle_http_request( + req: hyper::Request, + connector: TlsConnector, + target: String, + cert_chain: Option>>, + local_attestation_platform: L, + remote_attestation_platform: R, + ) -> Result, hyper::Error> { + let out = TcpStream::connect(&target).await.unwrap(); let mut tls_stream = connector - .connect(server_name_from_host(&target)?, out) - .await?; + .connect(server_name_from_host(&target).unwrap(), out) + .await + .unwrap(); let (_io, server_connection) = tls_stream.get_ref(); let mut exporter = [0u8; 32]; - server_connection.export_keying_material( - &mut exporter, - EXPORTER_LABEL, - None, // context - )?; + server_connection + .export_keying_material( + &mut exporter, + EXPORTER_LABEL, + None, // context + ) + .unwrap(); let remote_cert_chain = server_connection .peer_certificates() - .ok_or(ProxyError::NoCertificate)? + .ok_or(ProxyError::NoCertificate) + .unwrap() .to_owned(); let mut length_bytes = [0; 4]; - tls_stream.read_exact(&mut length_bytes).await?; - let length: usize = u32::from_be_bytes(length_bytes).try_into()?; + tls_stream.read_exact(&mut length_bytes).await.unwrap(); + let length: usize = u32::from_be_bytes(length_bytes).try_into().unwrap(); let mut buf = vec![0; length]; - tls_stream.read_exact(&mut buf).await?; + tls_stream.read_exact(&mut buf).await.unwrap(); if remote_attestation_platform.is_cvm() { remote_attestation_platform .verify_attestation(buf, &remote_cert_chain, exporter) - .await?; + .await + .unwrap(); } let attestation = if local_attestation_platform.is_cvm() { local_attestation_platform - .create_attestation(&cert_chain.ok_or(ProxyError::NoClientAuth)?, exporter)? + .create_attestation( + &cert_chain.ok_or(ProxyError::NoClientAuth).unwrap(), + exporter, + ) + .unwrap() } else { Vec::new() }; let attestation_length_prefix = length_prefix(&attestation); - tls_stream.write_all(&attestation_length_prefix).await?; + tls_stream + .write_all(&attestation_length_prefix) + .await + .unwrap(); - tls_stream.write_all(&attestation).await?; + tls_stream.write_all(&attestation).await.unwrap(); - let (mut inbound_reader, mut inbound_writer) = inbound.into_split(); - let (mut outbound_reader, mut outbound_writer) = tokio::io::split(tls_stream); + // Now the attestation is done, forward the connection to the proxy server + // let outbound = TcpStream::connect(target).await.unwrap(); + let outbound_io = TokioIo::new(tls_stream); + let (mut sender, conn) = hyper::client::conn::http1::Builder::new() + .handshake::<_, hyper::body::Incoming>(outbound_io) + .await + .unwrap(); - let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer); - let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer); - tokio::try_join!(client_to_server, server_to_client)?; - Ok(()) + // Drive the connection + tokio::spawn(async move { + if let Err(e) = conn.await { + eprintln!("client conn error: {e}"); + } + }); + + match sender.send_request(req).await { + Ok(resp) => Ok(resp), + Err(e) => { + eprintln!("send_request error: {e}"); + // let mut resp = Response::new(hyper::body::Incoming::empty()); + // *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + // Ok(resp) + panic!("todo"); + } + } } } @@ -643,65 +754,6 @@ mod tests { assert_eq!(res, "foobar"); } - #[tokio::test] - async fn raw_tcp_proxy() { - let target_addr = example_service().await; - - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); - let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", - target_addr, - DcapTdxQuoteGenerator, - NoQuoteVerifier, - ) - .await - .unwrap(); - - let proxy_server_addr = proxy_server.local_addr().unwrap(); - - tokio::spawn(async move { - proxy_server.accept().await.unwrap(); - }); - - let quote_verifier = DcapTdxQuoteVerifier { - accepted_platform_measurements: None, - accepted_cvm_image_measurements: vec![CvmImageMeasurements { - rtmr1: [0u8; 48], - rtmr2: [0u8; 48], - rtmr3: [0u8; 48], - }], - pccs_url: None, - }; - - let proxy_client = ProxyClient::new_with_tls_config( - client_config, - "127.0.0.1:0", - proxy_server_addr.to_string(), - NoQuoteGenerator, - quote_verifier, - None, - ) - .await - .unwrap(); - - let proxy_client_addr = proxy_client.local_addr().unwrap(); - - tokio::spawn(async move { - proxy_client.accept().await.unwrap(); - }); - - let mut out = TcpStream::connect(proxy_client_addr).await.unwrap(); - - let mut buf = [0; 9]; - out.read(&mut buf).await.unwrap(); - - assert_eq!(buf[..], b"some data"[..]); - } - #[tokio::test] async fn test_get_tls_cert() { let target_addr = example_service().await; From 69cf7dc91950062a87db66a0736853f576392c59 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 11 Nov 2025 10:43:00 +0100 Subject: [PATCH 03/12] Error handling for connections --- src/lib.rs | 46 ++++++++++++++++++++-------------------------- 1 file changed, 20 insertions(+), 26 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d0a51bc..3634222 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,9 @@ pub use attestation::{ DcapTdxQuoteGenerator, DcapTdxQuoteVerifier, NoQuoteGenerator, NoQuoteVerifier, QuoteGenerator, QuoteVerifier, }; +use bytes::Bytes; +use http_body_util::combinators::BoxBody; +use http_body_util::BodyExt; use hyper::server::conn::http1::Builder; use hyper::service::service_fn; use hyper::Response; @@ -215,12 +218,6 @@ impl ProxyServer { let io = TokioIo::new(tls_stream); http.serve_connection(io, service).await.unwrap(); - // let (mut inbound_reader, mut inbound_writer) = tokio::io::split(tls_stream); - // let (mut outbound_reader, mut outbound_writer) = outbound.into_split(); - // - // let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer); - // let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer); - // tokio::try_join!(client_to_server, server_to_client)?; Ok(()) } @@ -228,7 +225,7 @@ impl ProxyServer { async fn handle_http_request( req: hyper::Request, target: SocketAddr, - ) -> Result, hyper::Error> { + ) -> Result>, hyper::Error> { let outbound = TcpStream::connect(target).await.unwrap(); let outbound_io = TokioIo::new(outbound); let (mut sender, conn) = hyper::client::conn::http1::Builder::new() @@ -244,18 +241,23 @@ impl ProxyServer { }); match sender.send_request(req).await { - Ok(resp) => Ok(resp), + Ok(resp) => Ok(resp.map(|b| b.boxed())), Err(e) => { eprintln!("send_request error: {e}"); - // let mut resp = Response::new(hyper::body::Incoming::empty()); - // *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; - // Ok(resp) - panic!("todo"); + let mut resp = Response::new(full(format!("Request failed: {e}"))); + *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + Ok(resp) } } } } +fn full>(chunk: T) -> BoxBody { + http_body_util::Full::new(chunk.into()) + .map_err(|never| match never {}) + .boxed() +} + pub struct ProxyClient where L: QuoteGenerator, @@ -400,12 +402,6 @@ impl ProxyClient { let io = TokioIo::new(inbound); http.serve_connection(io, service).await.unwrap(); - // let (mut inbound_reader, mut inbound_writer) = inbound.into_split(); - // let (mut outbound_reader, mut outbound_writer) = tokio::io::split(tls_stream); - // - // let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer); - // let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer); - // tokio::try_join!(client_to_server, server_to_client)?; Ok(()) } @@ -417,7 +413,7 @@ impl ProxyClient { cert_chain: Option>>, local_attestation_platform: L, remote_attestation_platform: R, - ) -> Result, hyper::Error> { + ) -> Result>, hyper::Error> { let out = TcpStream::connect(&target).await.unwrap(); let mut tls_stream = connector .connect(server_name_from_host(&target).unwrap(), out) @@ -475,8 +471,7 @@ impl ProxyClient { tls_stream.write_all(&attestation).await.unwrap(); - // Now the attestation is done, forward the connection to the proxy server - // let outbound = TcpStream::connect(target).await.unwrap(); + // Now the attestation is done, forward the request to the proxy server let outbound_io = TokioIo::new(tls_stream); let (mut sender, conn) = hyper::client::conn::http1::Builder::new() .handshake::<_, hyper::body::Incoming>(outbound_io) @@ -491,13 +486,12 @@ impl ProxyClient { }); match sender.send_request(req).await { - Ok(resp) => Ok(resp), + Ok(resp) => Ok(resp.map(|b| b.boxed())), Err(e) => { eprintln!("send_request error: {e}"); - // let mut resp = Response::new(hyper::body::Incoming::empty()); - // *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; - // Ok(resp) - panic!("todo"); + let mut resp = Response::new(full(format!("Request failed: {e}"))); + *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + Ok(resp) } } } From 0fe56d5a3991c6e3cd25fa0d53183b189f72a359 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 11 Nov 2025 10:57:29 +0100 Subject: [PATCH 04/12] Error handling --- src/lib.rs | 72 ++++++++++++++++++++++++++++++------------------------ 1 file changed, 40 insertions(+), 32 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3634222..dd5bef5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -405,71 +405,79 @@ impl ProxyClient { Ok(()) } - // Handle a request from the source client to the proxy server - async fn handle_http_request( - req: hyper::Request, + async fn setup_connection( connector: TlsConnector, target: String, cert_chain: Option>>, local_attestation_platform: L, remote_attestation_platform: R, - ) -> Result>, hyper::Error> { - let out = TcpStream::connect(&target).await.unwrap(); + ) -> Result, ProxyError> { + let out = TcpStream::connect(&target).await?; let mut tls_stream = connector - .connect(server_name_from_host(&target).unwrap(), out) - .await - .unwrap(); + .connect(server_name_from_host(&target)?, out) + .await?; let (_io, server_connection) = tls_stream.get_ref(); let mut exporter = [0u8; 32]; - server_connection - .export_keying_material( - &mut exporter, - EXPORTER_LABEL, - None, // context - ) - .unwrap(); + server_connection.export_keying_material( + &mut exporter, + EXPORTER_LABEL, + None, // context + )?; let remote_cert_chain = server_connection .peer_certificates() - .ok_or(ProxyError::NoCertificate) - .unwrap() + .ok_or(ProxyError::NoCertificate)? .to_owned(); let mut length_bytes = [0; 4]; - tls_stream.read_exact(&mut length_bytes).await.unwrap(); - let length: usize = u32::from_be_bytes(length_bytes).try_into().unwrap(); + tls_stream.read_exact(&mut length_bytes).await?; + let length: usize = u32::from_be_bytes(length_bytes).try_into()?; let mut buf = vec![0; length]; - tls_stream.read_exact(&mut buf).await.unwrap(); + tls_stream.read_exact(&mut buf).await?; if remote_attestation_platform.is_cvm() { remote_attestation_platform .verify_attestation(buf, &remote_cert_chain, exporter) - .await - .unwrap(); + .await?; } let attestation = if local_attestation_platform.is_cvm() { local_attestation_platform - .create_attestation( - &cert_chain.ok_or(ProxyError::NoClientAuth).unwrap(), - exporter, - ) - .unwrap() + .create_attestation(&cert_chain.ok_or(ProxyError::NoClientAuth)?, exporter)? } else { Vec::new() }; let attestation_length_prefix = length_prefix(&attestation); - tls_stream - .write_all(&attestation_length_prefix) - .await - .unwrap(); + tls_stream.write_all(&attestation_length_prefix).await?; + + tls_stream.write_all(&attestation).await?; + + Ok(tls_stream) + } - tls_stream.write_all(&attestation).await.unwrap(); + // Handle a request from the source client to the proxy server + async fn handle_http_request( + req: hyper::Request, + connector: TlsConnector, + target: String, + cert_chain: Option>>, + local_attestation_platform: L, + remote_attestation_platform: R, + ) -> Result>, hyper::Error> { + let tls_stream = Self::setup_connection( + connector, + target, + cert_chain, + local_attestation_platform, + remote_attestation_platform, + ) + .await + .unwrap(); // Now the attestation is done, forward the request to the proxy server let outbound_io = TokioIo::new(tls_stream); From e8ab3779577573243eb356604bf3927d921493ad Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 11 Nov 2025 13:01:01 +0100 Subject: [PATCH 05/12] Error handling --- src/lib.rs | 48 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index dd5bef5..5d17b30 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -212,11 +212,20 @@ impl ProxyServer { } let http = Builder::new(); - let service = - service_fn(move |req| async move { Self::handle_http_request(req, target).await }); + let service = service_fn(move |req| async move { + match Self::handle_http_request(req, target).await { + Ok(res) => Ok::>, hyper::Error>(res), + Err(e) => { + eprintln!("send_request error: {e}"); + let mut resp = Response::new(full(format!("Request failed: {e}"))); + *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + Ok(resp) + } + } + }); let io = TokioIo::new(tls_stream); - http.serve_connection(io, service).await.unwrap(); + http.serve_connection(io, service).await?; Ok(()) } @@ -225,14 +234,12 @@ impl ProxyServer { async fn handle_http_request( req: hyper::Request, target: SocketAddr, - ) -> Result>, hyper::Error> { - let outbound = TcpStream::connect(target).await.unwrap(); + ) -> Result>, ProxyError> { + let outbound = TcpStream::connect(target).await?; let outbound_io = TokioIo::new(outbound); let (mut sender, conn) = hyper::client::conn::http1::Builder::new() .handshake::<_, hyper::body::Incoming>(outbound_io) - .await - .unwrap(); - + .await?; // Drive the connection tokio::spawn(async move { if let Err(e) = conn.await { @@ -387,7 +394,7 @@ impl ProxyClient { let local_attestation_platform = local_attestation_platform.clone(); let remote_attestation_platform = remote_attestation_platform.clone(); async move { - Self::handle_http_request( + match Self::handle_http_request( req, connector, target, @@ -396,11 +403,22 @@ impl ProxyClient { remote_attestation_platform, ) .await + { + Ok(res) => { + Ok::>, hyper::Error>(res) + } + Err(e) => { + eprintln!("send_request error: {e}"); + let mut resp = Response::new(full(format!("Request failed: {e}"))); + *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + Ok(resp) + } + } } }); let io = TokioIo::new(inbound); - http.serve_connection(io, service).await.unwrap(); + http.serve_connection(io, service).await?; Ok(()) } @@ -468,7 +486,7 @@ impl ProxyClient { cert_chain: Option>>, local_attestation_platform: L, remote_attestation_platform: R, - ) -> Result>, hyper::Error> { + ) -> Result>, ProxyError> { let tls_stream = Self::setup_connection( connector, target, @@ -476,15 +494,13 @@ impl ProxyClient { local_attestation_platform, remote_attestation_platform, ) - .await - .unwrap(); + .await?; // Now the attestation is done, forward the request to the proxy server let outbound_io = TokioIo::new(tls_stream); let (mut sender, conn) = hyper::client::conn::http1::Builder::new() .handshake::<_, hyper::body::Incoming>(outbound_io) - .await - .unwrap(); + .await?; // Drive the connection tokio::spawn(async move { @@ -583,6 +599,8 @@ pub enum ProxyError { IntConversion(#[from] TryFromIntError), #[error("Bad host name: {0}")] BadDnsName(#[from] tokio_rustls::rustls::pki_types::InvalidDnsNameError), + #[error("HTTP: {0}")] + Hyper(#[from] hyper::Error), } /// Given a byte array, encode its length as a 4 byte big endian u32 From 9bd9bacc63ef8950265274d71be8eb6bd526fa46 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 11 Nov 2025 18:18:54 +0100 Subject: [PATCH 06/12] Insert measurements into headers --- :w | 386 ++++++++++++++++++++++++++++++++++++++++++++ Cargo.lock | 2 + Cargo.toml | 2 + src/attestation.rs | 72 ++++++++- src/lib.rs | 101 ++++++++---- src/test_helpers.rs | 30 +++- 6 files changed, 558 insertions(+), 35 deletions(-) create mode 100644 :w diff --git a/:w b/:w new file mode 100644 index 0000000..5a9b3d9 --- /dev/null +++ b/:w @@ -0,0 +1,386 @@ +use std::{collections::HashMap, time::SystemTimeError}; + +use configfs_tsm::QuoteGenerationError; +use dcap_qvl::{ + collateral::get_collateral_for_fmspc, + quote::{Quote, Report}, +}; +use sha2::{Digest, Sha256}; +use tdx_quote::QuoteParseError; +use thiserror::Error; +use tokio_rustls::rustls::pki_types::CertificateDer; +use x509_parser::prelude::*; + +/// For fetching collateral directly from intel, if no PCCS is specified +const PCS_URL: &str = "https://api.trustedservices.intel.com"; + +#[derive(Debug, Clone, PartialEq)] +pub struct Measurements { + pub platform: PlatformMeasurements, + pub cvm_image: CvmImageMeasurements, +} + +impl Measurements { + pub fn to_header_format(&self) -> Result { + let mut measurements_map = HashMap::new(); + measurements_map.insert(0, hex::encode(self.platform.mrtd)); + measurements_map.insert(1, hex::encode(self.platform.rtmr0)); + measurements_map.insert(2, hex::encode(self.cvm_image.rtmr1)); + measurements_map.insert(3, hex::encode(self.cvm_image.rtmr2)); + measurements_map.insert(4, hex::encode(self.cvm_image.rtmr3)); + Ok(serde_json::to_string(&measurements_map)?) + } + + pub fn from_header_format(input: &str) -> Result { + let measurements_map: HashMap = serde_json::from_str(input)?; + let measurements_map: HashMap = measurements_map + .into_iter() + .map(|(k, v)| (k, hex::decode(v).unwrap().try_into().unwrap())) + .collect(); + + Ok(Self { + platform: PlatformMeasurements { + mrtd: *measurements_map.get(&0).ok_or(MeasurementFormatError::MissingValue("MRTD".to_string())?, + rtmr0: *measurements_map.get(&1).unwrap(), + }, + cvm_image: CvmImageMeasurements { + rtmr1: *measurements_map.get(&2).unwrap(), + rtmr2: *measurements_map.get(&3).unwrap(), + rtmr3: *measurements_map.get(&4).unwrap(), + }, + }) + } +} + +#[derive(Error, Debug)] +pub enum MeasurementFormatError { + #[error("JSON: {0}")] + Json(#[from] serde_json::Error), + #[error("Missing value: {0}")] + MissingValue(String), +} + +/// Defines how to generate a quote +pub trait QuoteGenerator: Clone + Send + 'static { + /// Whether this is CVM attestation. This should always return true except for the [NoQuoteGenerator] case. + /// + /// When false, allows TLS client to be configured without client authentication + fn is_cvm(&self) -> bool; + + /// Generate an attestation + fn create_attestation( + &self, + cert_chain: &[CertificateDer<'_>], + exporter: [u8; 32], + ) -> Result, AttestationError>; +} + +/// Defines how to verify a quote +pub trait QuoteVerifier: Clone + Send + 'static { + /// Whether this is CVM attestation. This should always return true except for the [NoQuoteVerifier] case. + /// + /// When false, allows TLS client to be configured without client authentication + fn is_cvm(&self) -> bool; + + /// Verify the given attestation payload + fn verify_attestation( + &self, + input: Vec, + cert_chain: &[CertificateDer<'_>], + exporter: [u8; 32], + ) -> impl Future, AttestationError>> + Send; +} + +/// Quote generation using configfs_tsm +#[derive(Clone)] +pub struct DcapTdxQuoteGenerator; + +impl QuoteGenerator for DcapTdxQuoteGenerator { + fn is_cvm(&self) -> bool { + true + } + + fn create_attestation( + &self, + cert_chain: &[CertificateDer<'_>], + exporter: [u8; 32], + ) -> Result, AttestationError> { + let quote_input = compute_report_input(cert_chain, exporter)?; + + Ok(generate_quote(quote_input)?) + } +} + +/// Measurements determined by the CVM platform +#[derive(Clone, PartialEq, Debug)] +pub struct PlatformMeasurements { + pub mrtd: [u8; 48], + pub rtmr0: [u8; 48], +} + +impl PlatformMeasurements { + fn from_dcap_qvl_quote(quote: &dcap_qvl::quote::Quote) -> Result { + let report = match quote.report { + Report::TD10(report) => report, + Report::TD15(report) => report.base, + Report::SgxEnclave(_) => { + return Err(AttestationError::SgxNotSupported); + } + }; + Ok(Self { + mrtd: report.mr_td, + rtmr0: report.rt_mr0, + }) + } + + fn from_tdx_quote(quote: &tdx_quote::Quote) -> Self { + Self { + mrtd: quote.mrtd(), + rtmr0: quote.rtmr0(), + } + } +} + +/// Measurements determined by the CVM image +#[derive(Clone, PartialEq, Debug)] +pub struct CvmImageMeasurements { + pub rtmr1: [u8; 48], + pub rtmr2: [u8; 48], + pub rtmr3: [u8; 48], +} + +impl CvmImageMeasurements { + fn from_dcap_qvl_quote(quote: &dcap_qvl::quote::Quote) -> Result { + let report = match quote.report { + Report::TD10(report) => report, + Report::TD15(report) => report.base, + Report::SgxEnclave(_) => { + return Err(AttestationError::SgxNotSupported); + } + }; + Ok(Self { + rtmr1: report.rt_mr1, + rtmr2: report.rt_mr2, + rtmr3: report.rt_mr3, + }) + } + + fn from_tdx_quote(quote: &tdx_quote::Quote) -> Self { + Self { + rtmr1: quote.rtmr1(), + rtmr2: quote.rtmr2(), + rtmr3: quote.rtmr3(), + } + } +} + +/// Verify DCAP TDX quotes, allowing them if they have one of a given set of platform-specific and +/// OS image specific measurements +#[derive(Clone)] +pub struct DcapTdxQuoteVerifier { + /// Platform specific allowed Measurements + /// Currently an option as this may be determined internally on a per-platform basis (Eg: GCP) + pub accepted_platform_measurements: Option>, + /// OS-image specific allows measurement - this is effectively a list of allowed OS images + pub accepted_cvm_image_measurements: Vec, + /// URL of a PCCS (defaults to Intel PCS) + pub pccs_url: Option, +} + +impl QuoteVerifier for DcapTdxQuoteVerifier { + fn is_cvm(&self) -> bool { + true + } + + async fn verify_attestation( + &self, + input: Vec, + cert_chain: &[CertificateDer<'_>], + exporter: [u8; 32], + ) -> Result, AttestationError> { + let quote_input = compute_report_input(cert_chain, exporter)?; + let (platform_measurements, image_measurements) = if cfg!(not(test)) { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH)? + .as_secs(); + let quote = Quote::parse(&input)?; + + let ca = quote.ca()?; + let fmspc = hex::encode_upper(quote.fmspc()?); + let collateral = get_collateral_for_fmspc( + &self.pccs_url.clone().unwrap_or(PCS_URL.to_string()), + fmspc, + ca, + false, + ) + .await?; + + let _verified_report = dcap_qvl::verify::verify(&input, &collateral, now)?; + + let measurements = ( + PlatformMeasurements::from_dcap_qvl_quote("e)?, + CvmImageMeasurements::from_dcap_qvl_quote("e)?, + ); + if get_quote_input_data(quote.report) != quote_input { + return Err(AttestationError::InputMismatch); + } + measurements + } else { + // In tests we use mock quotes which will fail to verify + let quote = tdx_quote::Quote::from_bytes(&input)?; + if quote.report_input_data() != quote_input { + return Err(AttestationError::InputMismatch); + } + + ( + PlatformMeasurements::from_tdx_quote("e), + CvmImageMeasurements::from_tdx_quote("e), + ) + }; + + if let Some(accepted_platform_measurements) = &self.accepted_platform_measurements + && !accepted_platform_measurements.contains(&platform_measurements) + { + return Err(AttestationError::UnacceptablePlatformMeasurements); + } + + if !self + .accepted_cvm_image_measurements + .contains(&image_measurements) + { + return Err(AttestationError::UnacceptableOsImageMeasurements); + } + + Ok(Some(Measurements { + platform: platform_measurements, + cvm_image: image_measurements, + })) + } +} + +/// Given a [Report] get the input data regardless of report type +fn get_quote_input_data(report: Report) -> [u8; 64] { + match report { + Report::TD10(r) => r.report_data, + Report::TD15(r) => r.base.report_data, + Report::SgxEnclave(r) => r.report_data, + } +} + +/// Given a certificate chain and an exporter (session key material), build the quote input value +/// SHA256(pki) || exporter +pub fn compute_report_input( + cert_chain: &[CertificateDer<'_>], + exporter: [u8; 32], +) -> Result<[u8; 64], AttestationError> { + let mut quote_input = [0u8; 64]; + let pki_hash = get_pki_hash_from_certificate_chain(cert_chain)?; + quote_input[..32].copy_from_slice(&pki_hash); + quote_input[32..].copy_from_slice(&exporter); + Ok(quote_input) +} + +/// For no CVM platform (eg: for one-sided remote-attested TLS) +#[derive(Clone)] +pub struct NoQuoteGenerator; + +impl QuoteGenerator for NoQuoteGenerator { + fn is_cvm(&self) -> bool { + false + } + + /// Create an empty attestation + fn create_attestation( + &self, + _cert_chain: &[CertificateDer<'_>], + _exporter: [u8; 32], + ) -> Result, AttestationError> { + Ok(Vec::new()) + } +} + +/// For no CVM platform (eg: for one-sided remote-attested TLS) +#[derive(Clone)] +pub struct NoQuoteVerifier; + +impl QuoteVerifier for NoQuoteVerifier { + fn is_cvm(&self) -> bool { + false + } + /// Ensure that an empty attestation is given + async fn verify_attestation( + &self, + input: Vec, + _cert_chain: &[CertificateDer<'_>], + _exporter: [u8; 32], + ) -> Result, AttestationError> { + if input.is_empty() { + Ok(None) + } else { + Err(AttestationError::AttestationGivenWhenNoneExpected) + } + } +} + +/// Create a mock quote for testing on non-confidential hardware +#[cfg(test)] +fn generate_quote(input: [u8; 64]) -> Result, QuoteGenerationError> { + let attestation_key = tdx_quote::SigningKey::random(&mut rand_core::OsRng); + let provisioning_certification_key = tdx_quote::SigningKey::random(&mut rand_core::OsRng); + Ok(tdx_quote::Quote::mock( + attestation_key.clone(), + provisioning_certification_key.clone(), + input, + b"Mock cert chain".to_vec(), + ) + .as_bytes()) +} + +/// Create a quote +#[cfg(not(test))] +fn generate_quote(input: [u8; 64]) -> Result, QuoteGenerationError> { + configfs_tsm::create_quote(input) +} + +/// Given a certificate chain, get the [Sha256] hash of the public key of the leaf certificate +fn get_pki_hash_from_certificate_chain( + cert_chain: &[CertificateDer<'_>], +) -> Result<[u8; 32], AttestationError> { + let leaf_certificate = cert_chain.first().ok_or(AttestationError::NoCertificate)?; + let (_, cert) = parse_x509_certificate(leaf_certificate.as_ref())?; + let public_key = &cert.tbs_certificate.subject_pki; + let key_bytes = public_key.subject_public_key.as_ref(); + + let mut hasher = Sha256::new(); + hasher.update(key_bytes); + Ok(hasher.finalize().into()) +} + +/// An error when generating or verifying an attestation +#[derive(Error, Debug)] +pub enum AttestationError { + #[error("Certificate chain is empty")] + NoCertificate, + #[error("X509 parse: {0}")] + X509Parse(#[from] x509_parser::asn1_rs::Err), + #[error("X509: {0}")] + X509(#[from] x509_parser::error::X509Error), + #[error("Quote input is not as expected")] + InputMismatch, + #[error("Configuration mismatch - expected no remote attestation")] + AttestationGivenWhenNoneExpected, + #[error("Configfs-tsm quote generation: {0}")] + QuoteGeneration(#[from] configfs_tsm::QuoteGenerationError), + #[error("SGX quote given when TDX quote expected")] + SgxNotSupported, + #[error("Platform measurements do not match any accepted values")] + UnacceptablePlatformMeasurements, + #[error("OS image measurements do not match any accepted values")] + UnacceptableOsImageMeasurements, + #[error("System Time: {0}")] + SystemTime(#[from] SystemTimeError), + #[error("DCAP quote verification: {0}")] + DcapQvl(#[from] anyhow::Error), + #[error("Quote parse: {0}")] + QuoteParse(#[from] QuoteParseError), +} diff --git a/Cargo.lock b/Cargo.lock index 64a3a94..60a6dd8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -137,6 +137,7 @@ dependencies = [ "configfs-tsm", "dcap-qvl", "hex", + "http", "http-body-util", "hyper", "hyper-util", @@ -145,6 +146,7 @@ dependencies = [ "rcgen", "reqwest", "rustls-pemfile", + "serde_json", "sha2", "tdx-quote", "thiserror", diff --git a/Cargo.toml b/Cargo.toml index 5dd4a7c..24c5994 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,8 @@ hyper = { version = "1.7.0", features = ["server"] } hyper-util = "0.1.17" http-body-util = "0.1.3" bytes = "1.10.1" +http = "1.3.1" +serde_json = "1.0.145" [dev-dependencies] rcgen = "0.14.5" diff --git a/src/attestation.rs b/src/attestation.rs index 84e1dcf..bfb3ce2 100644 --- a/src/attestation.rs +++ b/src/attestation.rs @@ -1,10 +1,11 @@ -use std::time::SystemTimeError; +use std::{collections::HashMap, time::SystemTimeError}; use configfs_tsm::QuoteGenerationError; use dcap_qvl::{ collateral::get_collateral_for_fmspc, quote::{Quote, Report}, }; +use http::{header::InvalidHeaderValue, HeaderValue}; use sha2::{Digest, Sha256}; use tdx_quote::QuoteParseError; use thiserror::Error; @@ -14,7 +15,65 @@ use x509_parser::prelude::*; /// For fetching collateral directly from intel, if no PCCS is specified const PCS_URL: &str = "https://api.trustedservices.intel.com"; -type Measurements = (PlatformMeasurements, CvmImageMeasurements); +#[derive(Debug, Clone, PartialEq)] +pub struct Measurements { + pub platform: PlatformMeasurements, + pub cvm_image: CvmImageMeasurements, +} + +impl Measurements { + pub fn to_header_format(&self) -> Result { + let mut measurements_map = HashMap::new(); + measurements_map.insert(0, hex::encode(self.platform.mrtd)); + measurements_map.insert(1, hex::encode(self.platform.rtmr0)); + measurements_map.insert(2, hex::encode(self.cvm_image.rtmr1)); + measurements_map.insert(3, hex::encode(self.cvm_image.rtmr2)); + measurements_map.insert(4, hex::encode(self.cvm_image.rtmr3)); + Ok(HeaderValue::from_str(&serde_json::to_string( + &measurements_map, + )?)?) + } + + pub fn from_header_format(input: &str) -> Result { + let measurements_map: HashMap = serde_json::from_str(input)?; + let measurements_map: HashMap = measurements_map + .into_iter() + .map(|(k, v)| (k, hex::decode(v).unwrap().try_into().unwrap())) + .collect(); + + Ok(Self { + platform: PlatformMeasurements { + mrtd: *measurements_map + .get(&0) + .ok_or(MeasurementFormatError::MissingValue("MRTD".to_string()))?, + rtmr0: *measurements_map + .get(&1) + .ok_or(MeasurementFormatError::MissingValue("MRTD".to_string()))?, + }, + cvm_image: CvmImageMeasurements { + rtmr1: *measurements_map + .get(&2) + .ok_or(MeasurementFormatError::MissingValue("MRTD".to_string()))?, + rtmr2: *measurements_map + .get(&3) + .ok_or(MeasurementFormatError::MissingValue("MRTD".to_string()))?, + rtmr3: *measurements_map + .get(&4) + .ok_or(MeasurementFormatError::MissingValue("MRTD".to_string()))?, + }, + }) + } +} + +#[derive(Error, Debug)] +pub enum MeasurementFormatError { + #[error("JSON: {0}")] + Json(#[from] serde_json::Error), + #[error("Missing value: {0}")] + MissingValue(String), + #[error("Invalid header value: {0}")] + BadHeaderValue(#[from] InvalidHeaderValue), +} /// Defines how to generate a quote pub trait QuoteGenerator: Clone + Send + 'static { @@ -68,7 +127,7 @@ impl QuoteGenerator for DcapTdxQuoteGenerator { } /// Measurements determined by the CVM platform -#[derive(Clone, PartialEq)] +#[derive(Clone, PartialEq, Debug)] pub struct PlatformMeasurements { pub mrtd: [u8; 48], pub rtmr0: [u8; 48], @@ -98,7 +157,7 @@ impl PlatformMeasurements { } /// Measurements determined by the CVM image -#[derive(Clone, PartialEq)] +#[derive(Clone, PartialEq, Debug)] pub struct CvmImageMeasurements { pub rtmr1: [u8; 48], pub rtmr2: [u8; 48], @@ -207,7 +266,10 @@ impl QuoteVerifier for DcapTdxQuoteVerifier { return Err(AttestationError::UnacceptableOsImageMeasurements); } - Ok(Some((platform_measurements, image_measurements))) + Ok(Some(Measurements { + platform: platform_measurements, + cvm_image: image_measurements, + })) } } diff --git a/src/lib.rs b/src/lib.rs index 5d17b30..b16666d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,6 @@ pub mod attestation; -use attestation::AttestationError; +use attestation::{AttestationError, Measurements}; pub use attestation::{ DcapTdxQuoteGenerator, DcapTdxQuoteVerifier, NoQuoteGenerator, NoQuoteVerifier, QuoteGenerator, QuoteVerifier, @@ -32,6 +32,11 @@ use tokio_rustls::{ /// The label used when exporting key material from a TLS session const EXPORTER_LABEL: &[u8; 24] = b"EXPORTER-Channel-Binding"; +// const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type"; + +/// The header name for giving measurements +const MEASUREMENT_HEADER: &str = "X-Flashbots-Measurement"; + pub struct TlsCertAndKey { pub cert_chain: Vec>, pub key: PrivateKeyDer<'static>, @@ -201,25 +206,39 @@ impl ProxyServer { let mut buf = vec![0; length]; tls_stream.read_exact(&mut buf).await?; - if remote_attestation_platform.is_cvm() { + let measurements = if remote_attestation_platform.is_cvm() { remote_attestation_platform .verify_attestation( buf, &remote_cert_chain.ok_or(ProxyError::NoClientAuth)?, exporter, ) - .await?; - } + .await? + } else { + None + }; let http = Builder::new(); - let service = service_fn(move |req| async move { - match Self::handle_http_request(req, target).await { - Ok(res) => Ok::>, hyper::Error>(res), - Err(e) => { - eprintln!("send_request error: {e}"); - let mut resp = Response::new(full(format!("Request failed: {e}"))); - *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; - Ok(resp) + let service = service_fn(move |mut req| { + // If we have measurements, add them to the request header + let measurements = measurements.clone(); + if let Some(measurements) = measurements { + let headers = req.headers_mut(); + + headers.insert(MEASUREMENT_HEADER, measurements.to_header_format().unwrap()); + } + + async move { + match Self::handle_http_request(req, target).await { + Ok(res) => { + Ok::>, hyper::Error>(res) + } + Err(e) => { + eprintln!("send_request error: {e}"); + let mut resp = Response::new(full(format!("Request failed: {e}"))); + *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + Ok(resp) + } } } }); @@ -429,7 +448,13 @@ impl ProxyClient { cert_chain: Option>>, local_attestation_platform: L, remote_attestation_platform: R, - ) -> Result, ProxyError> { + ) -> Result< + ( + tokio_rustls::client::TlsStream, + Option, + ), + ProxyError, + > { let out = TcpStream::connect(&target).await?; let mut tls_stream = connector .connect(server_name_from_host(&target)?, out) @@ -456,11 +481,13 @@ impl ProxyClient { let mut buf = vec![0; length]; tls_stream.read_exact(&mut buf).await?; - if remote_attestation_platform.is_cvm() { + let measurements = if remote_attestation_platform.is_cvm() { remote_attestation_platform .verify_attestation(buf, &remote_cert_chain, exporter) - .await?; - } + .await? + } else { + None + }; let attestation = if local_attestation_platform.is_cvm() { local_attestation_platform @@ -475,7 +502,7 @@ impl ProxyClient { tls_stream.write_all(&attestation).await?; - Ok(tls_stream) + Ok((tls_stream, measurements)) } // Handle a request from the source client to the proxy server @@ -487,7 +514,7 @@ impl ProxyClient { local_attestation_platform: L, remote_attestation_platform: R, ) -> Result>, ProxyError> { - let tls_stream = Self::setup_connection( + let (tls_stream, measurements) = Self::setup_connection( connector, target, cert_chain, @@ -510,7 +537,13 @@ impl ProxyClient { }); match sender.send_request(req).await { - Ok(resp) => Ok(resp.map(|b| b.boxed())), + Ok(mut resp) => { + if let Some(measurements) = measurements { + let headers = resp.headers_mut(); + headers.insert(MEASUREMENT_HEADER, measurements.to_header_format().unwrap()); + } + Ok(resp.map(|b| b.boxed())) + } Err(e) => { eprintln!("send_request error: {e}"); let mut resp = Response::new(full(format!("Request failed: {e}"))); @@ -636,8 +669,8 @@ mod tests { use super::*; use test_helpers::{ - example_http_service, example_service, generate_certificate_chain, generate_tls_config, - generate_tls_config_with_client_auth, + default_measurements, example_http_service, example_service, generate_certificate_chain, + generate_tls_config, generate_tls_config_with_client_auth, }; #[tokio::test] @@ -692,13 +725,16 @@ mod tests { }); let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) - .await - .unwrap() - .text() .await .unwrap(); - assert_eq!(res, "foobar"); + let headers = res.headers(); + let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); + let measurements = Measurements::from_header_format(measurements_json).unwrap(); + assert_eq!(measurements, default_measurements()); + + let res_body = res.text().await.unwrap(); + assert_eq!(res_body, "No measurements"); } #[tokio::test] @@ -765,13 +801,20 @@ mod tests { }); let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) - .await - .unwrap() - .text() .await .unwrap(); - assert_eq!(res, "foobar"); + let headers = res.headers(); + let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); + let measurements = Measurements::from_header_format(measurements_json).unwrap(); + assert_eq!(measurements, default_measurements()); + + let res_body = res.text().await.unwrap(); + + // The response body shows us what was in the request header (as the test http server + // handler puts them there) + let measurements = Measurements::from_header_format(&res_body).unwrap(); + assert_eq!(measurements, default_measurements()); } #[tokio::test] diff --git a/src/test_helpers.rs b/src/test_helpers.rs index 5d4a349..14d3c34 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -1,3 +1,4 @@ +use axum::response::IntoResponse; use std::{ net::{IpAddr, SocketAddr}, sync::Arc, @@ -10,6 +11,11 @@ use tokio_rustls::rustls::{ ClientConfig, RootCertStore, ServerConfig, }; +use crate::{ + attestation::{CvmImageMeasurements, Measurements, PlatformMeasurements}, + MEASUREMENT_HEADER, +}; + /// Helper to generate a self-signed certificate for testing pub fn generate_certificate_chain( ip: IpAddr, @@ -111,7 +117,7 @@ pub async fn example_http_service() -> SocketAddr { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let app = axum::Router::new().route("/", axum::routing::get(|| async { "foobar" })); + let app = axum::Router::new().route("/", axum::routing::get(get_handler)); tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); @@ -120,6 +126,14 @@ pub async fn example_http_service() -> SocketAddr { addr } +async fn get_handler(headers: http::HeaderMap) -> impl IntoResponse { + headers + .get(MEASUREMENT_HEADER) + .and_then(|v| v.to_str().ok()) + .unwrap_or("No measurements") + .to_string() +} + pub async fn example_service() -> SocketAddr { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -133,3 +147,17 @@ pub async fn example_service() -> SocketAddr { addr } + +pub fn default_measurements() -> Measurements { + Measurements { + platform: PlatformMeasurements { + mrtd: [0u8; 48], + rtmr0: [0u8; 48], + }, + cvm_image: CvmImageMeasurements { + rtmr1: [0u8; 48], + rtmr2: [0u8; 48], + rtmr3: [0u8; 48], + }, + } +} From 8e60052b9b205c746ff0fbbc1fcebdd223998851 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 11 Nov 2025 18:20:04 +0100 Subject: [PATCH 07/12] Rn accidentally created file --- :w | 386 ------------------------------------------------------------- 1 file changed, 386 deletions(-) delete mode 100644 :w diff --git a/:w b/:w deleted file mode 100644 index 5a9b3d9..0000000 --- a/:w +++ /dev/null @@ -1,386 +0,0 @@ -use std::{collections::HashMap, time::SystemTimeError}; - -use configfs_tsm::QuoteGenerationError; -use dcap_qvl::{ - collateral::get_collateral_for_fmspc, - quote::{Quote, Report}, -}; -use sha2::{Digest, Sha256}; -use tdx_quote::QuoteParseError; -use thiserror::Error; -use tokio_rustls::rustls::pki_types::CertificateDer; -use x509_parser::prelude::*; - -/// For fetching collateral directly from intel, if no PCCS is specified -const PCS_URL: &str = "https://api.trustedservices.intel.com"; - -#[derive(Debug, Clone, PartialEq)] -pub struct Measurements { - pub platform: PlatformMeasurements, - pub cvm_image: CvmImageMeasurements, -} - -impl Measurements { - pub fn to_header_format(&self) -> Result { - let mut measurements_map = HashMap::new(); - measurements_map.insert(0, hex::encode(self.platform.mrtd)); - measurements_map.insert(1, hex::encode(self.platform.rtmr0)); - measurements_map.insert(2, hex::encode(self.cvm_image.rtmr1)); - measurements_map.insert(3, hex::encode(self.cvm_image.rtmr2)); - measurements_map.insert(4, hex::encode(self.cvm_image.rtmr3)); - Ok(serde_json::to_string(&measurements_map)?) - } - - pub fn from_header_format(input: &str) -> Result { - let measurements_map: HashMap = serde_json::from_str(input)?; - let measurements_map: HashMap = measurements_map - .into_iter() - .map(|(k, v)| (k, hex::decode(v).unwrap().try_into().unwrap())) - .collect(); - - Ok(Self { - platform: PlatformMeasurements { - mrtd: *measurements_map.get(&0).ok_or(MeasurementFormatError::MissingValue("MRTD".to_string())?, - rtmr0: *measurements_map.get(&1).unwrap(), - }, - cvm_image: CvmImageMeasurements { - rtmr1: *measurements_map.get(&2).unwrap(), - rtmr2: *measurements_map.get(&3).unwrap(), - rtmr3: *measurements_map.get(&4).unwrap(), - }, - }) - } -} - -#[derive(Error, Debug)] -pub enum MeasurementFormatError { - #[error("JSON: {0}")] - Json(#[from] serde_json::Error), - #[error("Missing value: {0}")] - MissingValue(String), -} - -/// Defines how to generate a quote -pub trait QuoteGenerator: Clone + Send + 'static { - /// Whether this is CVM attestation. This should always return true except for the [NoQuoteGenerator] case. - /// - /// When false, allows TLS client to be configured without client authentication - fn is_cvm(&self) -> bool; - - /// Generate an attestation - fn create_attestation( - &self, - cert_chain: &[CertificateDer<'_>], - exporter: [u8; 32], - ) -> Result, AttestationError>; -} - -/// Defines how to verify a quote -pub trait QuoteVerifier: Clone + Send + 'static { - /// Whether this is CVM attestation. This should always return true except for the [NoQuoteVerifier] case. - /// - /// When false, allows TLS client to be configured without client authentication - fn is_cvm(&self) -> bool; - - /// Verify the given attestation payload - fn verify_attestation( - &self, - input: Vec, - cert_chain: &[CertificateDer<'_>], - exporter: [u8; 32], - ) -> impl Future, AttestationError>> + Send; -} - -/// Quote generation using configfs_tsm -#[derive(Clone)] -pub struct DcapTdxQuoteGenerator; - -impl QuoteGenerator for DcapTdxQuoteGenerator { - fn is_cvm(&self) -> bool { - true - } - - fn create_attestation( - &self, - cert_chain: &[CertificateDer<'_>], - exporter: [u8; 32], - ) -> Result, AttestationError> { - let quote_input = compute_report_input(cert_chain, exporter)?; - - Ok(generate_quote(quote_input)?) - } -} - -/// Measurements determined by the CVM platform -#[derive(Clone, PartialEq, Debug)] -pub struct PlatformMeasurements { - pub mrtd: [u8; 48], - pub rtmr0: [u8; 48], -} - -impl PlatformMeasurements { - fn from_dcap_qvl_quote(quote: &dcap_qvl::quote::Quote) -> Result { - let report = match quote.report { - Report::TD10(report) => report, - Report::TD15(report) => report.base, - Report::SgxEnclave(_) => { - return Err(AttestationError::SgxNotSupported); - } - }; - Ok(Self { - mrtd: report.mr_td, - rtmr0: report.rt_mr0, - }) - } - - fn from_tdx_quote(quote: &tdx_quote::Quote) -> Self { - Self { - mrtd: quote.mrtd(), - rtmr0: quote.rtmr0(), - } - } -} - -/// Measurements determined by the CVM image -#[derive(Clone, PartialEq, Debug)] -pub struct CvmImageMeasurements { - pub rtmr1: [u8; 48], - pub rtmr2: [u8; 48], - pub rtmr3: [u8; 48], -} - -impl CvmImageMeasurements { - fn from_dcap_qvl_quote(quote: &dcap_qvl::quote::Quote) -> Result { - let report = match quote.report { - Report::TD10(report) => report, - Report::TD15(report) => report.base, - Report::SgxEnclave(_) => { - return Err(AttestationError::SgxNotSupported); - } - }; - Ok(Self { - rtmr1: report.rt_mr1, - rtmr2: report.rt_mr2, - rtmr3: report.rt_mr3, - }) - } - - fn from_tdx_quote(quote: &tdx_quote::Quote) -> Self { - Self { - rtmr1: quote.rtmr1(), - rtmr2: quote.rtmr2(), - rtmr3: quote.rtmr3(), - } - } -} - -/// Verify DCAP TDX quotes, allowing them if they have one of a given set of platform-specific and -/// OS image specific measurements -#[derive(Clone)] -pub struct DcapTdxQuoteVerifier { - /// Platform specific allowed Measurements - /// Currently an option as this may be determined internally on a per-platform basis (Eg: GCP) - pub accepted_platform_measurements: Option>, - /// OS-image specific allows measurement - this is effectively a list of allowed OS images - pub accepted_cvm_image_measurements: Vec, - /// URL of a PCCS (defaults to Intel PCS) - pub pccs_url: Option, -} - -impl QuoteVerifier for DcapTdxQuoteVerifier { - fn is_cvm(&self) -> bool { - true - } - - async fn verify_attestation( - &self, - input: Vec, - cert_chain: &[CertificateDer<'_>], - exporter: [u8; 32], - ) -> Result, AttestationError> { - let quote_input = compute_report_input(cert_chain, exporter)?; - let (platform_measurements, image_measurements) = if cfg!(not(test)) { - let now = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH)? - .as_secs(); - let quote = Quote::parse(&input)?; - - let ca = quote.ca()?; - let fmspc = hex::encode_upper(quote.fmspc()?); - let collateral = get_collateral_for_fmspc( - &self.pccs_url.clone().unwrap_or(PCS_URL.to_string()), - fmspc, - ca, - false, - ) - .await?; - - let _verified_report = dcap_qvl::verify::verify(&input, &collateral, now)?; - - let measurements = ( - PlatformMeasurements::from_dcap_qvl_quote("e)?, - CvmImageMeasurements::from_dcap_qvl_quote("e)?, - ); - if get_quote_input_data(quote.report) != quote_input { - return Err(AttestationError::InputMismatch); - } - measurements - } else { - // In tests we use mock quotes which will fail to verify - let quote = tdx_quote::Quote::from_bytes(&input)?; - if quote.report_input_data() != quote_input { - return Err(AttestationError::InputMismatch); - } - - ( - PlatformMeasurements::from_tdx_quote("e), - CvmImageMeasurements::from_tdx_quote("e), - ) - }; - - if let Some(accepted_platform_measurements) = &self.accepted_platform_measurements - && !accepted_platform_measurements.contains(&platform_measurements) - { - return Err(AttestationError::UnacceptablePlatformMeasurements); - } - - if !self - .accepted_cvm_image_measurements - .contains(&image_measurements) - { - return Err(AttestationError::UnacceptableOsImageMeasurements); - } - - Ok(Some(Measurements { - platform: platform_measurements, - cvm_image: image_measurements, - })) - } -} - -/// Given a [Report] get the input data regardless of report type -fn get_quote_input_data(report: Report) -> [u8; 64] { - match report { - Report::TD10(r) => r.report_data, - Report::TD15(r) => r.base.report_data, - Report::SgxEnclave(r) => r.report_data, - } -} - -/// Given a certificate chain and an exporter (session key material), build the quote input value -/// SHA256(pki) || exporter -pub fn compute_report_input( - cert_chain: &[CertificateDer<'_>], - exporter: [u8; 32], -) -> Result<[u8; 64], AttestationError> { - let mut quote_input = [0u8; 64]; - let pki_hash = get_pki_hash_from_certificate_chain(cert_chain)?; - quote_input[..32].copy_from_slice(&pki_hash); - quote_input[32..].copy_from_slice(&exporter); - Ok(quote_input) -} - -/// For no CVM platform (eg: for one-sided remote-attested TLS) -#[derive(Clone)] -pub struct NoQuoteGenerator; - -impl QuoteGenerator for NoQuoteGenerator { - fn is_cvm(&self) -> bool { - false - } - - /// Create an empty attestation - fn create_attestation( - &self, - _cert_chain: &[CertificateDer<'_>], - _exporter: [u8; 32], - ) -> Result, AttestationError> { - Ok(Vec::new()) - } -} - -/// For no CVM platform (eg: for one-sided remote-attested TLS) -#[derive(Clone)] -pub struct NoQuoteVerifier; - -impl QuoteVerifier for NoQuoteVerifier { - fn is_cvm(&self) -> bool { - false - } - /// Ensure that an empty attestation is given - async fn verify_attestation( - &self, - input: Vec, - _cert_chain: &[CertificateDer<'_>], - _exporter: [u8; 32], - ) -> Result, AttestationError> { - if input.is_empty() { - Ok(None) - } else { - Err(AttestationError::AttestationGivenWhenNoneExpected) - } - } -} - -/// Create a mock quote for testing on non-confidential hardware -#[cfg(test)] -fn generate_quote(input: [u8; 64]) -> Result, QuoteGenerationError> { - let attestation_key = tdx_quote::SigningKey::random(&mut rand_core::OsRng); - let provisioning_certification_key = tdx_quote::SigningKey::random(&mut rand_core::OsRng); - Ok(tdx_quote::Quote::mock( - attestation_key.clone(), - provisioning_certification_key.clone(), - input, - b"Mock cert chain".to_vec(), - ) - .as_bytes()) -} - -/// Create a quote -#[cfg(not(test))] -fn generate_quote(input: [u8; 64]) -> Result, QuoteGenerationError> { - configfs_tsm::create_quote(input) -} - -/// Given a certificate chain, get the [Sha256] hash of the public key of the leaf certificate -fn get_pki_hash_from_certificate_chain( - cert_chain: &[CertificateDer<'_>], -) -> Result<[u8; 32], AttestationError> { - let leaf_certificate = cert_chain.first().ok_or(AttestationError::NoCertificate)?; - let (_, cert) = parse_x509_certificate(leaf_certificate.as_ref())?; - let public_key = &cert.tbs_certificate.subject_pki; - let key_bytes = public_key.subject_public_key.as_ref(); - - let mut hasher = Sha256::new(); - hasher.update(key_bytes); - Ok(hasher.finalize().into()) -} - -/// An error when generating or verifying an attestation -#[derive(Error, Debug)] -pub enum AttestationError { - #[error("Certificate chain is empty")] - NoCertificate, - #[error("X509 parse: {0}")] - X509Parse(#[from] x509_parser::asn1_rs::Err), - #[error("X509: {0}")] - X509(#[from] x509_parser::error::X509Error), - #[error("Quote input is not as expected")] - InputMismatch, - #[error("Configuration mismatch - expected no remote attestation")] - AttestationGivenWhenNoneExpected, - #[error("Configfs-tsm quote generation: {0}")] - QuoteGeneration(#[from] configfs_tsm::QuoteGenerationError), - #[error("SGX quote given when TDX quote expected")] - SgxNotSupported, - #[error("Platform measurements do not match any accepted values")] - UnacceptablePlatformMeasurements, - #[error("OS image measurements do not match any accepted values")] - UnacceptableOsImageMeasurements, - #[error("System Time: {0}")] - SystemTime(#[from] SystemTimeError), - #[error("DCAP quote verification: {0}")] - DcapQvl(#[from] anyhow::Error), - #[error("Quote parse: {0}")] - QuoteParse(#[from] QuoteParseError), -} From e311deee06dcd149e8336d611f519c7284cf3fc6 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 11 Nov 2025 18:21:56 +0100 Subject: [PATCH 08/12] Improve error messages --- src/attestation.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/attestation.rs b/src/attestation.rs index bfb3ce2..d6c7eb7 100644 --- a/src/attestation.rs +++ b/src/attestation.rs @@ -48,18 +48,18 @@ impl Measurements { .ok_or(MeasurementFormatError::MissingValue("MRTD".to_string()))?, rtmr0: *measurements_map .get(&1) - .ok_or(MeasurementFormatError::MissingValue("MRTD".to_string()))?, + .ok_or(MeasurementFormatError::MissingValue("RTMR0".to_string()))?, }, cvm_image: CvmImageMeasurements { rtmr1: *measurements_map .get(&2) - .ok_or(MeasurementFormatError::MissingValue("MRTD".to_string()))?, + .ok_or(MeasurementFormatError::MissingValue("RTMR1".to_string()))?, rtmr2: *measurements_map .get(&3) - .ok_or(MeasurementFormatError::MissingValue("MRTD".to_string()))?, + .ok_or(MeasurementFormatError::MissingValue("RTMR2".to_string()))?, rtmr3: *measurements_map .get(&4) - .ok_or(MeasurementFormatError::MissingValue("MRTD".to_string()))?, + .ok_or(MeasurementFormatError::MissingValue("RTMR3".to_string()))?, }, }) } From 1905204c2a46f3f88032fabb4afbc6be1ae6c698 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 11 Nov 2025 18:34:40 +0100 Subject: [PATCH 09/12] Error handling --- src/lib.rs | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b16666d..815a9bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -225,7 +225,16 @@ impl ProxyServer { if let Some(measurements) = measurements { let headers = req.headers_mut(); - headers.insert(MEASUREMENT_HEADER, measurements.to_header_format().unwrap()); + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); + } + Err(e) => { + // This error is highly unlikely - that the measurement values fail to + // encode to JSON or fit in an HTTP header + eprintln!("Failed to encode measurement values: {e}"); + } + } } async move { @@ -262,7 +271,7 @@ impl ProxyServer { // Drive the connection tokio::spawn(async move { if let Err(e) = conn.await { - eprintln!("client conn error: {e}"); + eprintln!("Client connection error: {e}"); } }); @@ -532,7 +541,7 @@ impl ProxyClient { // Drive the connection tokio::spawn(async move { if let Err(e) = conn.await { - eprintln!("client conn error: {e}"); + eprintln!("Client connection error: {e}"); } }); @@ -540,7 +549,16 @@ impl ProxyClient { Ok(mut resp) => { if let Some(measurements) = measurements { let headers = resp.headers_mut(); - headers.insert(MEASUREMENT_HEADER, measurements.to_header_format().unwrap()); + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); + } + Err(e) => { + // This error is highly unlikely - that the measurement values fail to + // encode to JSON or fit in an HTTP header + eprintln!("Failed to encode measurement values: {e}"); + } + } } Ok(resp.map(|b| b.boxed())) } From 30a27f94ee711497b85a19e2e4d98358fc65d223 Mon Sep 17 00:00:00 2001 From: peg Date: Wed, 12 Nov 2025 08:35:05 +0100 Subject: [PATCH 10/12] Rename struct fields for clarity --- src/lib.rs | 123 ++++++++++++++++++++++++++--------------------------- 1 file changed, 61 insertions(+), 62 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 815a9bf..4e824cf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -37,11 +37,15 @@ const EXPORTER_LABEL: &[u8; 24] = b"EXPORTER-Channel-Binding"; /// The header name for giving measurements const MEASUREMENT_HEADER: &str = "X-Flashbots-Measurement"; +/// TLS Credentials pub struct TlsCertAndKey { + /// Der-encoded TLS certificate chain pub cert_chain: Vec>, + /// Der-encoded TLS private key pub key: PrivateKeyDer<'static>, } +/// Inner struct used by [ProxyClient] and [ProxyServer] struct Proxy where L: QuoteGenerator, @@ -49,10 +53,10 @@ where { /// The underlying TCP listener listener: TcpListener, - /// Type of CVM platform we run on (including none) - local_attestation_platform: L, - /// Type of CVM platform the remote party runs on (including none) - remote_attestation_platform: R, + /// Quote generation type to use (including none) + local_quote_generator: L, + /// Verifier for remote attestation (including none) + remote_quote_verifier: R, } /// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address @@ -75,11 +79,11 @@ impl ProxyServer { cert_and_key: TlsCertAndKey, local: impl ToSocketAddrs, target: SocketAddr, - local_attestation_platform: L, - remote_attestation_platform: R, + local_quote_generator: L, + remote_quote_verifier: R, client_auth: bool, ) -> Result { - if remote_attestation_platform.is_cvm() && !client_auth { + if remote_quote_verifier.is_cvm() && !client_auth { return Err(ProxyError::NoClientAuth); } @@ -102,8 +106,8 @@ impl ProxyServer { server_config.into(), local, target, - local_attestation_platform, - remote_attestation_platform, + local_quote_generator, + remote_quote_verifier, ) .await } @@ -116,16 +120,16 @@ impl ProxyServer { server_config: Arc, local: impl ToSocketAddrs, target: SocketAddr, - local_attestation_platform: L, - remote_attestation_platform: R, + local_quote_generator: L, + remote_quote_verifier: R, ) -> Result { let acceptor = tokio_rustls::TlsAcceptor::from(server_config); let listener = TcpListener::bind(local).await?; let inner = Proxy { listener, - local_attestation_platform, - remote_attestation_platform, + local_quote_generator, + remote_quote_verifier, }; Ok(Self { @@ -143,16 +147,16 @@ impl ProxyServer { let acceptor = self.acceptor.clone(); let target = self.target; let cert_chain = self.cert_chain.clone(); - let local_attestation_platform = self.inner.local_attestation_platform.clone(); - let remote_attestation_platform = self.inner.remote_attestation_platform.clone(); + let local_quote_generator = self.inner.local_quote_generator.clone(); + let remote_quote_verifier = self.inner.remote_quote_verifier.clone(); tokio::spawn(async move { if let Err(err) = Self::handle_connection( inbound, acceptor, target, cert_chain, - local_attestation_platform, - remote_attestation_platform, + local_quote_generator, + remote_quote_verifier, ) .await { @@ -172,8 +176,8 @@ impl ProxyServer { acceptor: TlsAcceptor, target: SocketAddr, cert_chain: Vec>, - local_attestation_platform: L, - remote_attestation_platform: R, + local_quote_generator: L, + remote_quote_verifier: R, ) -> Result<(), ProxyError> { let mut tls_stream = acceptor.accept(inbound).await?; let (_io, connection) = tls_stream.get_ref(); @@ -187,8 +191,8 @@ impl ProxyServer { let remote_cert_chain = connection.peer_certificates().map(|c| c.to_owned()); - let attestation = if local_attestation_platform.is_cvm() { - local_attestation_platform.create_attestation(&cert_chain, exporter)? + let attestation = if local_quote_generator.is_cvm() { + local_quote_generator.create_attestation(&cert_chain, exporter)? } else { Vec::new() }; @@ -206,8 +210,8 @@ impl ProxyServer { let mut buf = vec![0; length]; tls_stream.read_exact(&mut buf).await?; - let measurements = if remote_attestation_platform.is_cvm() { - remote_attestation_platform + let measurements = if remote_quote_verifier.is_cvm() { + remote_quote_verifier .verify_attestation( buf, &remote_cert_chain.ok_or(ProxyError::NoClientAuth)?, @@ -311,10 +315,10 @@ impl ProxyClient { cert_and_key: Option, address: impl ToSocketAddrs, server_name: String, - local_attestation_platform: L, - remote_attestation_platform: R, + local_quote_generator: L, + remote_quote_verifier: R, ) -> Result { - if local_attestation_platform.is_cvm() && cert_and_key.is_none() { + if local_quote_generator.is_cvm() && cert_and_key.is_none() { return Err(ProxyError::NoClientAuth); } @@ -337,8 +341,8 @@ impl ProxyClient { client_config.into(), address, server_name, - local_attestation_platform, - remote_attestation_platform, + local_quote_generator, + remote_quote_verifier, cert_and_key.map(|c| c.cert_chain), ) .await @@ -351,8 +355,8 @@ impl ProxyClient { client_config: Arc, local: impl ToSocketAddrs, target_name: String, - local_attestation_platform: L, - remote_attestation_platform: R, + local_quote_generator: L, + remote_quote_verifier: R, cert_chain: Option>>, ) -> Result { let listener = TcpListener::bind(local).await?; @@ -360,8 +364,8 @@ impl ProxyClient { let inner = Proxy { listener, - local_attestation_platform, - remote_attestation_platform, + local_quote_generator, + remote_quote_verifier, }; Ok(Self { @@ -378,8 +382,8 @@ impl ProxyClient { let connector = self.connector.clone(); let target = self.target.clone(); - let local_attestation_platform = self.inner.local_attestation_platform.clone(); - let remote_attestation_platform = self.inner.remote_attestation_platform.clone(); + let local_quote_generator = self.inner.local_quote_generator.clone(); + let remote_quote_verifier = self.inner.remote_quote_verifier.clone(); let cert_chain = self.cert_chain.clone(); tokio::spawn(async move { @@ -388,8 +392,8 @@ impl ProxyClient { connector, target, cert_chain, - local_attestation_platform, - remote_attestation_platform, + local_quote_generator, + remote_quote_verifier, ) .await { @@ -411,24 +415,24 @@ impl ProxyClient { connector: TlsConnector, target: String, cert_chain: Option>>, - local_attestation_platform: L, - remote_attestation_platform: R, + local_quote_generator: L, + remote_quote_verifier: R, ) -> Result<(), ProxyError> { let http = Builder::new(); let service = service_fn(move |req| { let connector = connector.clone(); let target = target.clone(); let cert_chain = cert_chain.clone(); - let local_attestation_platform = local_attestation_platform.clone(); - let remote_attestation_platform = remote_attestation_platform.clone(); + let local_quote_generator = local_quote_generator.clone(); + let remote_quote_verifier = remote_quote_verifier.clone(); async move { match Self::handle_http_request( req, connector, target, cert_chain, - local_attestation_platform, - remote_attestation_platform, + local_quote_generator, + remote_quote_verifier, ) .await { @@ -455,8 +459,8 @@ impl ProxyClient { connector: TlsConnector, target: String, cert_chain: Option>>, - local_attestation_platform: L, - remote_attestation_platform: R, + local_quote_generator: L, + remote_quote_verifier: R, ) -> Result< ( tokio_rustls::client::TlsStream, @@ -490,16 +494,16 @@ impl ProxyClient { let mut buf = vec![0; length]; tls_stream.read_exact(&mut buf).await?; - let measurements = if remote_attestation_platform.is_cvm() { - remote_attestation_platform + let measurements = if remote_quote_verifier.is_cvm() { + remote_quote_verifier .verify_attestation(buf, &remote_cert_chain, exporter) .await? } else { None }; - let attestation = if local_attestation_platform.is_cvm() { - local_attestation_platform + let attestation = if local_quote_generator.is_cvm() { + local_quote_generator .create_attestation(&cert_chain.ok_or(ProxyError::NoClientAuth)?, exporter)? } else { Vec::new() @@ -520,15 +524,15 @@ impl ProxyClient { connector: TlsConnector, target: String, cert_chain: Option>>, - local_attestation_platform: L, - remote_attestation_platform: R, + local_quote_generator: L, + remote_quote_verifier: R, ) -> Result>, ProxyError> { let (tls_stream, measurements) = Self::setup_connection( connector, target, cert_chain, - local_attestation_platform, - remote_attestation_platform, + local_quote_generator, + remote_quote_verifier, ) .await?; @@ -575,23 +579,18 @@ impl ProxyClient { /// Just get the attested remote certificate, with no client authentication pub async fn get_tls_cert( server_name: String, - remote_attestation_platform: R, + remote_quote_verifier: R, ) -> Result>, ProxyError> { let root_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); let client_config = ClientConfig::builder() .with_root_certificates(root_store) .with_no_client_auth(); - get_tls_cert_with_config( - server_name, - remote_attestation_platform, - client_config.into(), - ) - .await + get_tls_cert_with_config(server_name, remote_quote_verifier, client_config.into()).await } async fn get_tls_cert_with_config( server_name: String, - remote_attestation_platform: R, + remote_quote_verifier: R, client_config: Arc, ) -> Result>, ProxyError> { let connector = TlsConnector::from(client_config); @@ -622,8 +621,8 @@ async fn get_tls_cert_with_config( let mut buf = vec![0; length]; tls_stream.read_exact(&mut buf).await?; - if remote_attestation_platform.is_cvm() { - remote_attestation_platform + if remote_quote_verifier.is_cvm() { + remote_quote_verifier .verify_attestation(buf, &remote_cert_chain, exporter) .await?; } From daab94fd92c77bc613e0d5e33cab5b56ac661a83 Mon Sep 17 00:00:00 2001 From: peg Date: Wed, 12 Nov 2025 09:36:25 +0100 Subject: [PATCH 11/12] Also handle attestation type header --- src/attestation.rs | 79 +++++++++++++++++++++++++++++++++++----------- src/lib.rs | 79 ++++++++++++++++++++++++++++++++-------------- src/main.rs | 11 +++++-- 3 files changed, 125 insertions(+), 44 deletions(-) diff --git a/src/attestation.rs b/src/attestation.rs index d6c7eb7..c4bb207 100644 --- a/src/attestation.rs +++ b/src/attestation.rs @@ -1,4 +1,8 @@ -use std::{collections::HashMap, time::SystemTimeError}; +use std::{ + collections::HashMap, + fmt::{self, Display, Formatter}, + time::SystemTimeError, +}; use configfs_tsm::QuoteGenerationError; use dcap_qvl::{ @@ -75,12 +79,45 @@ pub enum MeasurementFormatError { BadHeaderValue(#[from] InvalidHeaderValue), } +/// Type of attestaion used +/// Only supported (or soon-to-be supported) types are given +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum AttestationType { + /// No attestion + None, + /// Mock attestion + Dummy, + /// TDX on Google Cloud Platform + GcpTdx, + /// TDX on Azure, with MAA + AzureTdx, + /// TDX on Qemu (no cloud platform) + QemuTdx, +} + +impl AttestationType { + /// Matches the names used by Constellation aTLS + pub fn as_str(&self) -> &'static str { + match self { + AttestationType::None => "none", + AttestationType::Dummy => "dummy", + AttestationType::AzureTdx => "azure-tdx", + AttestationType::QemuTdx => "qemu-tdx", + AttestationType::GcpTdx => "gcp-tdx", + } + } +} + +impl Display for AttestationType { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + /// Defines how to generate a quote pub trait QuoteGenerator: Clone + Send + 'static { - /// Whether this is CVM attestation. This should always return true except for the [NoQuoteGenerator] case. - /// - /// When false, allows TLS client to be configured without client authentication - fn is_cvm(&self) -> bool; + /// Type of attestation used + fn attestation_type(&self) -> AttestationType; /// Generate an attestation fn create_attestation( @@ -92,10 +129,8 @@ pub trait QuoteGenerator: Clone + Send + 'static { /// Defines how to verify a quote pub trait QuoteVerifier: Clone + Send + 'static { - /// Whether this is CVM attestation. This should always return true except for the [NoQuoteVerifier] case. - /// - /// When false, allows TLS client to be configured without client authentication - fn is_cvm(&self) -> bool; + /// Type of attestation used + fn attestation_type(&self) -> AttestationType; /// Verify the given attestation payload fn verify_attestation( @@ -108,11 +143,14 @@ pub trait QuoteVerifier: Clone + Send + 'static { /// Quote generation using configfs_tsm #[derive(Clone)] -pub struct DcapTdxQuoteGenerator; +pub struct DcapTdxQuoteGenerator { + pub attestation_type: AttestationType, +} impl QuoteGenerator for DcapTdxQuoteGenerator { - fn is_cvm(&self) -> bool { - true + /// Type of attestation used + fn attestation_type(&self) -> AttestationType { + self.attestation_type } fn create_attestation( @@ -193,6 +231,7 @@ impl CvmImageMeasurements { /// OS image specific measurements #[derive(Clone)] pub struct DcapTdxQuoteVerifier { + pub attestation_type: AttestationType, /// Platform specific allowed Measurements /// Currently an option as this may be determined internally on a per-platform basis (Eg: GCP) pub accepted_platform_measurements: Option>, @@ -203,8 +242,9 @@ pub struct DcapTdxQuoteVerifier { } impl QuoteVerifier for DcapTdxQuoteVerifier { - fn is_cvm(&self) -> bool { - true + /// Type of attestation used + fn attestation_type(&self) -> AttestationType { + self.attestation_type } async fn verify_attestation( @@ -300,8 +340,9 @@ pub fn compute_report_input( pub struct NoQuoteGenerator; impl QuoteGenerator for NoQuoteGenerator { - fn is_cvm(&self) -> bool { - false + /// Type of attestation used + fn attestation_type(&self) -> AttestationType { + AttestationType::None } /// Create an empty attestation @@ -319,9 +360,11 @@ impl QuoteGenerator for NoQuoteGenerator { pub struct NoQuoteVerifier; impl QuoteVerifier for NoQuoteVerifier { - fn is_cvm(&self) -> bool { - false + /// Type of attestation used + fn attestation_type(&self) -> AttestationType { + AttestationType::None } + /// Ensure that an empty attestation is given async fn verify_attestation( &self, diff --git a/src/lib.rs b/src/lib.rs index 4e824cf..f5720e6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,12 @@ pub mod attestation; -use attestation::{AttestationError, Measurements}; +use attestation::{AttestationError, AttestationType, Measurements}; pub use attestation::{ DcapTdxQuoteGenerator, DcapTdxQuoteVerifier, NoQuoteGenerator, NoQuoteVerifier, QuoteGenerator, QuoteVerifier, }; use bytes::Bytes; +use http::HeaderValue; use http_body_util::combinators::BoxBody; use http_body_util::BodyExt; use hyper::server::conn::http1::Builder; @@ -32,7 +33,7 @@ use tokio_rustls::{ /// The label used when exporting key material from a TLS session const EXPORTER_LABEL: &[u8; 24] = b"EXPORTER-Channel-Binding"; -// const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type"; +const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type"; /// The header name for giving measurements const MEASUREMENT_HEADER: &str = "X-Flashbots-Measurement"; @@ -83,7 +84,7 @@ impl ProxyServer { remote_quote_verifier: R, client_auth: bool, ) -> Result { - if remote_quote_verifier.is_cvm() && !client_auth { + if remote_quote_verifier.attestation_type() != AttestationType::None && !client_auth { return Err(ProxyError::NoClientAuth); } @@ -191,7 +192,7 @@ impl ProxyServer { let remote_cert_chain = connection.peer_certificates().map(|c| c.to_owned()); - let attestation = if local_quote_generator.is_cvm() { + let attestation = if local_quote_generator.attestation_type() != AttestationType::None { local_quote_generator.create_attestation(&cert_chain, exporter)? } else { Vec::new() @@ -210,7 +211,7 @@ impl ProxyServer { let mut buf = vec![0; length]; tls_stream.read_exact(&mut buf).await?; - let measurements = if remote_quote_verifier.is_cvm() { + let measurements = if remote_quote_verifier.attestation_type() != AttestationType::None { remote_quote_verifier .verify_attestation( buf, @@ -221,6 +222,7 @@ impl ProxyServer { } else { None }; + let remote_attestation_type = remote_quote_verifier.attestation_type(); let http = Builder::new(); let service = service_fn(move |mut req| { @@ -239,6 +241,10 @@ impl ProxyServer { eprintln!("Failed to encode measurement values: {e}"); } } + headers.insert( + ATTESTATION_TYPE_HEADER, + HeaderValue::from_str(remote_attestation_type.as_str()).unwrap(), + ); } async move { @@ -318,7 +324,9 @@ impl ProxyClient { local_quote_generator: L, remote_quote_verifier: R, ) -> Result { - if local_quote_generator.is_cvm() && cert_and_key.is_none() { + if local_quote_generator.attestation_type() != AttestationType::None + && cert_and_key.is_none() + { return Err(ProxyError::NoClientAuth); } @@ -494,15 +502,11 @@ impl ProxyClient { let mut buf = vec![0; length]; tls_stream.read_exact(&mut buf).await?; - let measurements = if remote_quote_verifier.is_cvm() { - remote_quote_verifier - .verify_attestation(buf, &remote_cert_chain, exporter) - .await? - } else { - None - }; + let measurements = remote_quote_verifier + .verify_attestation(buf, &remote_cert_chain, exporter) + .await?; - let attestation = if local_quote_generator.is_cvm() { + let attestation = if local_quote_generator.attestation_type() != AttestationType::None { local_quote_generator .create_attestation(&cert_chain.ok_or(ProxyError::NoClientAuth)?, exporter)? } else { @@ -527,6 +531,8 @@ impl ProxyClient { local_quote_generator: L, remote_quote_verifier: R, ) -> Result>, ProxyError> { + let remote_attestation_type = remote_quote_verifier.attestation_type(); + let (tls_stream, measurements) = Self::setup_connection( connector, target, @@ -563,6 +569,10 @@ impl ProxyClient { eprintln!("Failed to encode measurement values: {e}"); } } + headers.insert( + ATTESTATION_TYPE_HEADER, + HeaderValue::from_str(remote_attestation_type.as_str()).unwrap(), + ); } Ok(resp.map(|b| b.boxed())) } @@ -621,11 +631,9 @@ async fn get_tls_cert_with_config( let mut buf = vec![0; length]; tls_stream.read_exact(&mut buf).await?; - if remote_quote_verifier.is_cvm() { - remote_quote_verifier - .verify_attestation(buf, &remote_cert_chain, exporter) - .await?; - } + let _measurements = remote_quote_verifier + .verify_attestation(buf, &remote_cert_chain, exporter) + .await?; Ok(remote_cert_chain) } @@ -702,7 +710,9 @@ mod tests { server_config, "127.0.0.1:0", target_addr, - DcapTdxQuoteGenerator, + DcapTdxQuoteGenerator { + attestation_type: AttestationType::Dummy, + }, NoQuoteVerifier, ) .await @@ -711,6 +721,7 @@ mod tests { let proxy_addr = proxy_server.local_addr().unwrap(); let quote_verifier = DcapTdxQuoteVerifier { + attestation_type: AttestationType::Dummy, accepted_platform_measurements: None, accepted_cvm_image_measurements: vec![CvmImageMeasurements { rtmr1: [0u8; 48], @@ -750,6 +761,13 @@ mod tests { let measurements = Measurements::from_header_format(measurements_json).unwrap(); assert_eq!(measurements, default_measurements()); + let attestation_type = headers + .get(ATTESTATION_TYPE_HEADER) + .unwrap() + .to_str() + .unwrap(); + assert_eq!(attestation_type, AttestationType::Dummy.as_str()); + let res_body = res.text().await.unwrap(); assert_eq!(res_body, "No measurements"); } @@ -774,6 +792,7 @@ mod tests { ); let quote_verifier = DcapTdxQuoteVerifier { + attestation_type: AttestationType::Dummy, accepted_platform_measurements: None, accepted_cvm_image_measurements: vec![CvmImageMeasurements { rtmr1: [0u8; 48], @@ -788,7 +807,9 @@ mod tests { server_tls_server_config, "127.0.0.1:0", target_addr, - DcapTdxQuoteGenerator, + DcapTdxQuoteGenerator { + attestation_type: AttestationType::Dummy, + }, quote_verifier.clone(), ) .await @@ -804,7 +825,9 @@ mod tests { client_tls_client_config, "127.0.0.1:0", proxy_addr.to_string(), - DcapTdxQuoteGenerator, + DcapTdxQuoteGenerator { + attestation_type: AttestationType::Dummy, + }, quote_verifier, Some(client_cert_chain), ) @@ -826,6 +849,13 @@ mod tests { let measurements = Measurements::from_header_format(measurements_json).unwrap(); assert_eq!(measurements, default_measurements()); + let attestation_type = headers + .get(ATTESTATION_TYPE_HEADER) + .unwrap() + .to_str() + .unwrap(); + assert_eq!(attestation_type, AttestationType::Dummy.as_str()); + let res_body = res.text().await.unwrap(); // The response body shows us what was in the request header (as the test http server @@ -846,7 +876,9 @@ mod tests { server_config, "127.0.0.1:0", target_addr, - DcapTdxQuoteGenerator, + DcapTdxQuoteGenerator { + attestation_type: AttestationType::Dummy, + }, NoQuoteVerifier, ) .await @@ -859,6 +891,7 @@ mod tests { }); let quote_verifier = DcapTdxQuoteVerifier { + attestation_type: AttestationType::Dummy, accepted_platform_measurements: None, accepted_cvm_image_measurements: vec![CvmImageMeasurements { rtmr1: [0u8; 48], diff --git a/src/main.rs b/src/main.rs index f8c1009..e3ab26d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,8 +4,9 @@ use std::{fs::File, net::SocketAddr, path::PathBuf}; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; use attested_tls_proxy::{ - attestation::CvmImageMeasurements, get_tls_cert, DcapTdxQuoteGenerator, DcapTdxQuoteVerifier, - NoQuoteGenerator, NoQuoteVerifier, ProxyClient, ProxyServer, TlsCertAndKey, + attestation::{AttestationType, CvmImageMeasurements}, + get_tls_cert, DcapTdxQuoteGenerator, DcapTdxQuoteVerifier, NoQuoteGenerator, NoQuoteVerifier, + ProxyClient, ProxyServer, TlsCertAndKey, }; #[derive(Parser, Debug, Clone)] @@ -81,6 +82,7 @@ async fn main() -> anyhow::Result<()> { }; let quote_verifier = DcapTdxQuoteVerifier { + attestation_type: AttestationType::Dummy, accepted_platform_measurements: None, accepted_cvm_image_measurements: vec![CvmImageMeasurements { rtmr1: [0u8; 48], @@ -113,7 +115,9 @@ async fn main() -> anyhow::Result<()> { client_auth, } => { let tls_cert_and_chain = load_tls_cert_and_key(cert_chain, private_key)?; - let local_attestation = DcapTdxQuoteGenerator; + let local_attestation = DcapTdxQuoteGenerator { + attestation_type: AttestationType::Dummy, + }; let remote_attestation = NoQuoteVerifier; let server = ProxyServer::new( @@ -134,6 +138,7 @@ async fn main() -> anyhow::Result<()> { } CliCommand::GetTlsCert { server } => { let quote_verifier = DcapTdxQuoteVerifier { + attestation_type: AttestationType::Dummy, accepted_platform_measurements: None, accepted_cvm_image_measurements: vec![CvmImageMeasurements { rtmr1: [0u8; 48], From d48d121c18c8e74c9fdfacde87f851938002d199 Mon Sep 17 00:00:00 2001 From: peg Date: Wed, 12 Nov 2025 09:47:39 +0100 Subject: [PATCH 12/12] Update readme --- README.md | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4aff9d1..8e9b9f3 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,30 @@ It has three commands: Unlike `cvm-reverse-proxy`, this uses post-handshake remote-attested TLS, meaning regular CA-signed TLS certificates can be used. -However attestation generation and verification is not yet implemented - there is a trait provided and mock attestation for testing purposes. +This repo shares some code with [ameba23/attested-channels](https://github.com/ameba23/attested-channels) and may eventually be merged with that crate. -This shares some code with [ameba23/attested-channels](https://github.com/ameba23/attested-channels) and may eventually be merged with that crate. +## Measurement headers + +When attestation is validated successfully, the following values are injected into the request / response headers: + +Header name: `X-Flashbots-Measurement` + +Header value: +```json +{ + "0": "48 byte MRTD value encoded as hex", + "1": "48 byte RTMR0 value encoded as hex", + "2": "48 byte RTMR1 value encoded as hex", + "3": "48 byte RTMR2 value encoded as hex", + "4": "48 byte RTMR3 value encoded as hex", +} +``` + +Header name: `X-Flashbots-Attestation-Type` + +Header value: + +One of `none`, `dummy`, `azure-tdx`, `qemu-tdx`, `gcp-tdx`. + +These aim to match the header formatting used by `cvm-reverse-proxy`.