diff --git a/Cargo.lock b/Cargo.lock index 8f7fdf0..60a6dd8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -132,15 +132,21 @@ version = "0.1.0" dependencies = [ "anyhow", "axum", + "bytes", "clap", "configfs-tsm", "dcap-qvl", "hex", + "http", + "http-body-util", + "hyper", + "hyper-util", "pem-rfc7468", "rand_core 0.6.4", "rcgen", "reqwest", "rustls-pemfile", + "serde_json", "sha2", "tdx-quote", "thiserror", diff --git a/Cargo.toml b/Cargo.toml index 739386a..24c5994 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,12 @@ configfs-tsm = "0.0.2" rand_core = { version = "0.6.4", features = ["getrandom"] } dcap-qvl = "0.3.4" hex = "0.4.3" +hyper = { version = "1.7.0", features = ["server"] } +hyper-util = "0.1.17" +http-body-util = "0.1.3" +bytes = "1.10.1" +http = "1.3.1" +serde_json = "1.0.145" [dev-dependencies] rcgen = "0.14.5" diff --git a/README.md b/README.md index 4aff9d1..8e9b9f3 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,30 @@ It has three commands: Unlike `cvm-reverse-proxy`, this uses post-handshake remote-attested TLS, meaning regular CA-signed TLS certificates can be used. -However attestation generation and verification is not yet implemented - there is a trait provided and mock attestation for testing purposes. +This repo shares some code with [ameba23/attested-channels](https://github.com/ameba23/attested-channels) and may eventually be merged with that crate. -This shares some code with [ameba23/attested-channels](https://github.com/ameba23/attested-channels) and may eventually be merged with that crate. +## Measurement headers + +When attestation is validated successfully, the following values are injected into the request / response headers: + +Header name: `X-Flashbots-Measurement` + +Header value: +```json +{ + "0": "48 byte MRTD value encoded as hex", + "1": "48 byte RTMR0 value encoded as hex", + "2": "48 byte RTMR1 value encoded as hex", + "3": "48 byte RTMR2 value encoded as hex", + "4": "48 byte RTMR3 value encoded as hex", +} +``` + +Header name: `X-Flashbots-Attestation-Type` + +Header value: + +One of `none`, `dummy`, `azure-tdx`, `qemu-tdx`, `gcp-tdx`. + +These aim to match the header formatting used by `cvm-reverse-proxy`. diff --git a/src/attestation.rs b/src/attestation.rs index 351449e..c4bb207 100644 --- a/src/attestation.rs +++ b/src/attestation.rs @@ -1,10 +1,15 @@ -use std::time::SystemTimeError; +use std::{ + collections::HashMap, + fmt::{self, Display, Formatter}, + time::SystemTimeError, +}; use configfs_tsm::QuoteGenerationError; use dcap_qvl::{ collateral::get_collateral_for_fmspc, quote::{Quote, Report}, }; +use http::{header::InvalidHeaderValue, HeaderValue}; use sha2::{Digest, Sha256}; use tdx_quote::QuoteParseError; use thiserror::Error; @@ -14,12 +19,105 @@ use x509_parser::prelude::*; /// For fetching collateral directly from intel, if no PCCS is specified const PCS_URL: &str = "https://api.trustedservices.intel.com"; +#[derive(Debug, Clone, PartialEq)] +pub struct Measurements { + pub platform: PlatformMeasurements, + pub cvm_image: CvmImageMeasurements, +} + +impl Measurements { + pub fn to_header_format(&self) -> Result { + let mut measurements_map = HashMap::new(); + measurements_map.insert(0, hex::encode(self.platform.mrtd)); + measurements_map.insert(1, hex::encode(self.platform.rtmr0)); + measurements_map.insert(2, hex::encode(self.cvm_image.rtmr1)); + measurements_map.insert(3, hex::encode(self.cvm_image.rtmr2)); + measurements_map.insert(4, hex::encode(self.cvm_image.rtmr3)); + Ok(HeaderValue::from_str(&serde_json::to_string( + &measurements_map, + )?)?) + } + + pub fn from_header_format(input: &str) -> Result { + let measurements_map: HashMap = serde_json::from_str(input)?; + let measurements_map: HashMap = measurements_map + .into_iter() + .map(|(k, v)| (k, hex::decode(v).unwrap().try_into().unwrap())) + .collect(); + + Ok(Self { + platform: PlatformMeasurements { + mrtd: *measurements_map + .get(&0) + .ok_or(MeasurementFormatError::MissingValue("MRTD".to_string()))?, + rtmr0: *measurements_map + .get(&1) + .ok_or(MeasurementFormatError::MissingValue("RTMR0".to_string()))?, + }, + cvm_image: CvmImageMeasurements { + rtmr1: *measurements_map + .get(&2) + .ok_or(MeasurementFormatError::MissingValue("RTMR1".to_string()))?, + rtmr2: *measurements_map + .get(&3) + .ok_or(MeasurementFormatError::MissingValue("RTMR2".to_string()))?, + rtmr3: *measurements_map + .get(&4) + .ok_or(MeasurementFormatError::MissingValue("RTMR3".to_string()))?, + }, + }) + } +} + +#[derive(Error, Debug)] +pub enum MeasurementFormatError { + #[error("JSON: {0}")] + Json(#[from] serde_json::Error), + #[error("Missing value: {0}")] + MissingValue(String), + #[error("Invalid header value: {0}")] + BadHeaderValue(#[from] InvalidHeaderValue), +} + +/// Type of attestaion used +/// Only supported (or soon-to-be supported) types are given +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum AttestationType { + /// No attestion + None, + /// Mock attestion + Dummy, + /// TDX on Google Cloud Platform + GcpTdx, + /// TDX on Azure, with MAA + AzureTdx, + /// TDX on Qemu (no cloud platform) + QemuTdx, +} + +impl AttestationType { + /// Matches the names used by Constellation aTLS + pub fn as_str(&self) -> &'static str { + match self { + AttestationType::None => "none", + AttestationType::Dummy => "dummy", + AttestationType::AzureTdx => "azure-tdx", + AttestationType::QemuTdx => "qemu-tdx", + AttestationType::GcpTdx => "gcp-tdx", + } + } +} + +impl Display for AttestationType { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + /// Defines how to generate a quote pub trait QuoteGenerator: Clone + Send + 'static { - /// Whether this is CVM attestation. This should always return true except for the [NoQuoteGenerator] case. - /// - /// When false, allows TLS client to be configured without client authentication - fn is_cvm(&self) -> bool; + /// Type of attestation used + fn attestation_type(&self) -> AttestationType; /// Generate an attestation fn create_attestation( @@ -31,10 +129,8 @@ pub trait QuoteGenerator: Clone + Send + 'static { /// Defines how to verify a quote pub trait QuoteVerifier: Clone + Send + 'static { - /// Whether this is CVM attestation. This should always return true except for the [NoQuoteVerifier] case. - /// - /// When false, allows TLS client to be configured without client authentication - fn is_cvm(&self) -> bool; + /// Type of attestation used + fn attestation_type(&self) -> AttestationType; /// Verify the given attestation payload fn verify_attestation( @@ -42,16 +138,19 @@ pub trait QuoteVerifier: Clone + Send + 'static { input: Vec, cert_chain: &[CertificateDer<'_>], exporter: [u8; 32], - ) -> impl Future> + Send; + ) -> impl Future, AttestationError>> + Send; } /// Quote generation using configfs_tsm #[derive(Clone)] -pub struct DcapTdxQuoteGenerator; +pub struct DcapTdxQuoteGenerator { + pub attestation_type: AttestationType, +} impl QuoteGenerator for DcapTdxQuoteGenerator { - fn is_cvm(&self) -> bool { - true + /// Type of attestation used + fn attestation_type(&self) -> AttestationType { + self.attestation_type } fn create_attestation( @@ -66,7 +165,7 @@ impl QuoteGenerator for DcapTdxQuoteGenerator { } /// Measurements determined by the CVM platform -#[derive(Clone, PartialEq)] +#[derive(Clone, PartialEq, Debug)] pub struct PlatformMeasurements { pub mrtd: [u8; 48], pub rtmr0: [u8; 48], @@ -96,7 +195,7 @@ impl PlatformMeasurements { } /// Measurements determined by the CVM image -#[derive(Clone, PartialEq)] +#[derive(Clone, PartialEq, Debug)] pub struct CvmImageMeasurements { pub rtmr1: [u8; 48], pub rtmr2: [u8; 48], @@ -132,6 +231,7 @@ impl CvmImageMeasurements { /// OS image specific measurements #[derive(Clone)] pub struct DcapTdxQuoteVerifier { + pub attestation_type: AttestationType, /// Platform specific allowed Measurements /// Currently an option as this may be determined internally on a per-platform basis (Eg: GCP) pub accepted_platform_measurements: Option>, @@ -142,8 +242,9 @@ pub struct DcapTdxQuoteVerifier { } impl QuoteVerifier for DcapTdxQuoteVerifier { - fn is_cvm(&self) -> bool { - true + /// Type of attestation used + fn attestation_type(&self) -> AttestationType { + self.attestation_type } async fn verify_attestation( @@ -151,7 +252,7 @@ impl QuoteVerifier for DcapTdxQuoteVerifier { input: Vec, cert_chain: &[CertificateDer<'_>], exporter: [u8; 32], - ) -> Result<(), AttestationError> { + ) -> Result, AttestationError> { let quote_input = compute_report_input(cert_chain, exporter)?; let (platform_measurements, image_measurements) = if cfg!(not(test)) { let now = std::time::SystemTime::now() @@ -205,7 +306,10 @@ impl QuoteVerifier for DcapTdxQuoteVerifier { return Err(AttestationError::UnacceptableOsImageMeasurements); } - Ok(()) + Ok(Some(Measurements { + platform: platform_measurements, + cvm_image: image_measurements, + })) } } @@ -236,8 +340,9 @@ pub fn compute_report_input( pub struct NoQuoteGenerator; impl QuoteGenerator for NoQuoteGenerator { - fn is_cvm(&self) -> bool { - false + /// Type of attestation used + fn attestation_type(&self) -> AttestationType { + AttestationType::None } /// Create an empty attestation @@ -255,18 +360,20 @@ impl QuoteGenerator for NoQuoteGenerator { pub struct NoQuoteVerifier; impl QuoteVerifier for NoQuoteVerifier { - fn is_cvm(&self) -> bool { - false + /// Type of attestation used + fn attestation_type(&self) -> AttestationType { + AttestationType::None } + /// Ensure that an empty attestation is given async fn verify_attestation( &self, input: Vec, _cert_chain: &[CertificateDer<'_>], _exporter: [u8; 32], - ) -> Result<(), AttestationError> { + ) -> Result, AttestationError> { if input.is_empty() { - Ok(()) + Ok(None) } else { Err(AttestationError::AttestationGivenWhenNoneExpected) } diff --git a/src/lib.rs b/src/lib.rs index 152a816..f5720e6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,18 @@ pub mod attestation; -use attestation::AttestationError; +use attestation::{AttestationError, AttestationType, Measurements}; pub use attestation::{ DcapTdxQuoteGenerator, DcapTdxQuoteVerifier, NoQuoteGenerator, NoQuoteVerifier, QuoteGenerator, QuoteVerifier, }; +use bytes::Bytes; +use http::HeaderValue; +use http_body_util::combinators::BoxBody; +use http_body_util::BodyExt; +use hyper::server::conn::http1::Builder; +use hyper::service::service_fn; +use hyper::Response; +use hyper_util::rt::TokioIo; use thiserror::Error; use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; @@ -25,11 +33,20 @@ use tokio_rustls::{ /// The label used when exporting key material from a TLS session const EXPORTER_LABEL: &[u8; 24] = b"EXPORTER-Channel-Binding"; +const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type"; + +/// The header name for giving measurements +const MEASUREMENT_HEADER: &str = "X-Flashbots-Measurement"; + +/// TLS Credentials pub struct TlsCertAndKey { + /// Der-encoded TLS certificate chain pub cert_chain: Vec>, + /// Der-encoded TLS private key pub key: PrivateKeyDer<'static>, } +/// Inner struct used by [ProxyClient] and [ProxyServer] struct Proxy where L: QuoteGenerator, @@ -37,10 +54,10 @@ where { /// The underlying TCP listener listener: TcpListener, - /// Type of CVM platform we run on (including none) - local_attestation_platform: L, - /// Type of CVM platform the remote party runs on (including none) - remote_attestation_platform: R, + /// Quote generation type to use (including none) + local_quote_generator: L, + /// Verifier for remote attestation (including none) + remote_quote_verifier: R, } /// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address @@ -63,11 +80,11 @@ impl ProxyServer { cert_and_key: TlsCertAndKey, local: impl ToSocketAddrs, target: SocketAddr, - local_attestation_platform: L, - remote_attestation_platform: R, + local_quote_generator: L, + remote_quote_verifier: R, client_auth: bool, ) -> Result { - if remote_attestation_platform.is_cvm() && !client_auth { + if remote_quote_verifier.attestation_type() != AttestationType::None && !client_auth { return Err(ProxyError::NoClientAuth); } @@ -90,8 +107,8 @@ impl ProxyServer { server_config.into(), local, target, - local_attestation_platform, - remote_attestation_platform, + local_quote_generator, + remote_quote_verifier, ) .await } @@ -104,16 +121,16 @@ impl ProxyServer { server_config: Arc, local: impl ToSocketAddrs, target: SocketAddr, - local_attestation_platform: L, - remote_attestation_platform: R, + local_quote_generator: L, + remote_quote_verifier: R, ) -> Result { let acceptor = tokio_rustls::TlsAcceptor::from(server_config); let listener = TcpListener::bind(local).await?; let inner = Proxy { listener, - local_attestation_platform, - remote_attestation_platform, + local_quote_generator, + remote_quote_verifier, }; Ok(Self { @@ -131,16 +148,16 @@ impl ProxyServer { let acceptor = self.acceptor.clone(); let target = self.target; let cert_chain = self.cert_chain.clone(); - let local_attestation_platform = self.inner.local_attestation_platform.clone(); - let remote_attestation_platform = self.inner.remote_attestation_platform.clone(); + let local_quote_generator = self.inner.local_quote_generator.clone(); + let remote_quote_verifier = self.inner.remote_quote_verifier.clone(); tokio::spawn(async move { if let Err(err) = Self::handle_connection( inbound, acceptor, target, cert_chain, - local_attestation_platform, - remote_attestation_platform, + local_quote_generator, + remote_quote_verifier, ) .await { @@ -160,8 +177,8 @@ impl ProxyServer { acceptor: TlsAcceptor, target: SocketAddr, cert_chain: Vec>, - local_attestation_platform: L, - remote_attestation_platform: R, + local_quote_generator: L, + remote_quote_verifier: R, ) -> Result<(), ProxyError> { let mut tls_stream = acceptor.accept(inbound).await?; let (_io, connection) = tls_stream.get_ref(); @@ -175,8 +192,8 @@ impl ProxyServer { let remote_cert_chain = connection.peer_certificates().map(|c| c.to_owned()); - let attestation = if local_attestation_platform.is_cvm() { - local_attestation_platform.create_attestation(&cert_chain, exporter)? + let attestation = if local_quote_generator.attestation_type() != AttestationType::None { + local_quote_generator.create_attestation(&cert_chain, exporter)? } else { Vec::new() }; @@ -194,26 +211,96 @@ impl ProxyServer { let mut buf = vec![0; length]; tls_stream.read_exact(&mut buf).await?; - if remote_attestation_platform.is_cvm() { - remote_attestation_platform + let measurements = if remote_quote_verifier.attestation_type() != AttestationType::None { + remote_quote_verifier .verify_attestation( buf, &remote_cert_chain.ok_or(ProxyError::NoClientAuth)?, exporter, ) - .await?; - } + .await? + } else { + None + }; + let remote_attestation_type = remote_quote_verifier.attestation_type(); + + let http = Builder::new(); + let service = service_fn(move |mut req| { + // If we have measurements, add them to the request header + let measurements = measurements.clone(); + if let Some(measurements) = measurements { + let headers = req.headers_mut(); + + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); + } + Err(e) => { + // This error is highly unlikely - that the measurement values fail to + // encode to JSON or fit in an HTTP header + eprintln!("Failed to encode measurement values: {e}"); + } + } + headers.insert( + ATTESTATION_TYPE_HEADER, + HeaderValue::from_str(remote_attestation_type.as_str()).unwrap(), + ); + } - let outbound = TcpStream::connect(target).await?; + async move { + match Self::handle_http_request(req, target).await { + Ok(res) => { + Ok::>, hyper::Error>(res) + } + Err(e) => { + eprintln!("send_request error: {e}"); + let mut resp = Response::new(full(format!("Request failed: {e}"))); + *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + Ok(resp) + } + } + } + }); - let (mut inbound_reader, mut inbound_writer) = tokio::io::split(tls_stream); - let (mut outbound_reader, mut outbound_writer) = outbound.into_split(); + let io = TokioIo::new(tls_stream); + http.serve_connection(io, service).await?; - let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer); - let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer); - tokio::try_join!(client_to_server, server_to_client)?; Ok(()) } + + // Handle a request from the proxy client to the target server + async fn handle_http_request( + req: hyper::Request, + target: SocketAddr, + ) -> Result>, ProxyError> { + let outbound = TcpStream::connect(target).await?; + let outbound_io = TokioIo::new(outbound); + let (mut sender, conn) = hyper::client::conn::http1::Builder::new() + .handshake::<_, hyper::body::Incoming>(outbound_io) + .await?; + // Drive the connection + tokio::spawn(async move { + if let Err(e) = conn.await { + eprintln!("Client connection error: {e}"); + } + }); + + match sender.send_request(req).await { + Ok(resp) => Ok(resp.map(|b| b.boxed())), + Err(e) => { + eprintln!("send_request error: {e}"); + let mut resp = Response::new(full(format!("Request failed: {e}"))); + *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + Ok(resp) + } + } + } +} + +fn full>(chunk: T) -> BoxBody { + http_body_util::Full::new(chunk.into()) + .map_err(|never| match never {}) + .boxed() } pub struct ProxyClient @@ -234,10 +321,12 @@ impl ProxyClient { cert_and_key: Option, address: impl ToSocketAddrs, server_name: String, - local_attestation_platform: L, - remote_attestation_platform: R, + local_quote_generator: L, + remote_quote_verifier: R, ) -> Result { - if local_attestation_platform.is_cvm() && cert_and_key.is_none() { + if local_quote_generator.attestation_type() != AttestationType::None + && cert_and_key.is_none() + { return Err(ProxyError::NoClientAuth); } @@ -260,8 +349,8 @@ impl ProxyClient { client_config.into(), address, server_name, - local_attestation_platform, - remote_attestation_platform, + local_quote_generator, + remote_quote_verifier, cert_and_key.map(|c| c.cert_chain), ) .await @@ -274,8 +363,8 @@ impl ProxyClient { client_config: Arc, local: impl ToSocketAddrs, target_name: String, - local_attestation_platform: L, - remote_attestation_platform: R, + local_quote_generator: L, + remote_quote_verifier: R, cert_chain: Option>>, ) -> Result { let listener = TcpListener::bind(local).await?; @@ -283,8 +372,8 @@ impl ProxyClient { let inner = Proxy { listener, - local_attestation_platform, - remote_attestation_platform, + local_quote_generator, + remote_quote_verifier, }; Ok(Self { @@ -301,8 +390,8 @@ impl ProxyClient { let connector = self.connector.clone(); let target = self.target.clone(); - let local_attestation_platform = self.inner.local_attestation_platform.clone(); - let remote_attestation_platform = self.inner.remote_attestation_platform.clone(); + let local_quote_generator = self.inner.local_quote_generator.clone(); + let remote_quote_verifier = self.inner.remote_quote_verifier.clone(); let cert_chain = self.cert_chain.clone(); tokio::spawn(async move { @@ -311,8 +400,8 @@ impl ProxyClient { connector, target, cert_chain, - local_attestation_platform, - remote_attestation_platform, + local_quote_generator, + remote_quote_verifier, ) .await { @@ -334,9 +423,59 @@ impl ProxyClient { connector: TlsConnector, target: String, cert_chain: Option>>, - local_attestation_platform: L, - remote_attestation_platform: R, + local_quote_generator: L, + remote_quote_verifier: R, ) -> Result<(), ProxyError> { + let http = Builder::new(); + let service = service_fn(move |req| { + let connector = connector.clone(); + let target = target.clone(); + let cert_chain = cert_chain.clone(); + let local_quote_generator = local_quote_generator.clone(); + let remote_quote_verifier = remote_quote_verifier.clone(); + async move { + match Self::handle_http_request( + req, + connector, + target, + cert_chain, + local_quote_generator, + remote_quote_verifier, + ) + .await + { + Ok(res) => { + Ok::>, hyper::Error>(res) + } + Err(e) => { + eprintln!("send_request error: {e}"); + let mut resp = Response::new(full(format!("Request failed: {e}"))); + *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + Ok(resp) + } + } + } + }); + + let io = TokioIo::new(inbound); + http.serve_connection(io, service).await?; + + Ok(()) + } + + async fn setup_connection( + connector: TlsConnector, + target: String, + cert_chain: Option>>, + local_quote_generator: L, + remote_quote_verifier: R, + ) -> Result< + ( + tokio_rustls::client::TlsStream, + Option, + ), + ProxyError, + > { let out = TcpStream::connect(&target).await?; let mut tls_stream = connector .connect(server_name_from_host(&target)?, out) @@ -363,14 +502,12 @@ impl ProxyClient { let mut buf = vec![0; length]; tls_stream.read_exact(&mut buf).await?; - if remote_attestation_platform.is_cvm() { - remote_attestation_platform - .verify_attestation(buf, &remote_cert_chain, exporter) - .await?; - } + let measurements = remote_quote_verifier + .verify_attestation(buf, &remote_cert_chain, exporter) + .await?; - let attestation = if local_attestation_platform.is_cvm() { - local_attestation_platform + let attestation = if local_quote_generator.attestation_type() != AttestationType::None { + local_quote_generator .create_attestation(&cert_chain.ok_or(ProxyError::NoClientAuth)?, exporter)? } else { Vec::new() @@ -382,36 +519,88 @@ impl ProxyClient { tls_stream.write_all(&attestation).await?; - let (mut inbound_reader, mut inbound_writer) = inbound.into_split(); - let (mut outbound_reader, mut outbound_writer) = tokio::io::split(tls_stream); + Ok((tls_stream, measurements)) + } - let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer); - let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer); - tokio::try_join!(client_to_server, server_to_client)?; - Ok(()) + // Handle a request from the source client to the proxy server + async fn handle_http_request( + req: hyper::Request, + connector: TlsConnector, + target: String, + cert_chain: Option>>, + local_quote_generator: L, + remote_quote_verifier: R, + ) -> Result>, ProxyError> { + let remote_attestation_type = remote_quote_verifier.attestation_type(); + + let (tls_stream, measurements) = Self::setup_connection( + connector, + target, + cert_chain, + local_quote_generator, + remote_quote_verifier, + ) + .await?; + + // Now the attestation is done, forward the request to the proxy server + let outbound_io = TokioIo::new(tls_stream); + let (mut sender, conn) = hyper::client::conn::http1::Builder::new() + .handshake::<_, hyper::body::Incoming>(outbound_io) + .await?; + + // Drive the connection + tokio::spawn(async move { + if let Err(e) = conn.await { + eprintln!("Client connection error: {e}"); + } + }); + + match sender.send_request(req).await { + Ok(mut resp) => { + if let Some(measurements) = measurements { + let headers = resp.headers_mut(); + match measurements.to_header_format() { + Ok(header_value) => { + headers.insert(MEASUREMENT_HEADER, header_value); + } + Err(e) => { + // This error is highly unlikely - that the measurement values fail to + // encode to JSON or fit in an HTTP header + eprintln!("Failed to encode measurement values: {e}"); + } + } + headers.insert( + ATTESTATION_TYPE_HEADER, + HeaderValue::from_str(remote_attestation_type.as_str()).unwrap(), + ); + } + Ok(resp.map(|b| b.boxed())) + } + Err(e) => { + eprintln!("send_request error: {e}"); + let mut resp = Response::new(full(format!("Request failed: {e}"))); + *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY; + Ok(resp) + } + } } } /// Just get the attested remote certificate, with no client authentication pub async fn get_tls_cert( server_name: String, - remote_attestation_platform: R, + remote_quote_verifier: R, ) -> Result>, ProxyError> { let root_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); let client_config = ClientConfig::builder() .with_root_certificates(root_store) .with_no_client_auth(); - get_tls_cert_with_config( - server_name, - remote_attestation_platform, - client_config.into(), - ) - .await + get_tls_cert_with_config(server_name, remote_quote_verifier, client_config.into()).await } async fn get_tls_cert_with_config( server_name: String, - remote_attestation_platform: R, + remote_quote_verifier: R, client_config: Arc, ) -> Result>, ProxyError> { let connector = TlsConnector::from(client_config); @@ -442,11 +631,9 @@ async fn get_tls_cert_with_config( let mut buf = vec![0; length]; tls_stream.read_exact(&mut buf).await?; - if remote_attestation_platform.is_cvm() { - remote_attestation_platform - .verify_attestation(buf, &remote_cert_chain, exporter) - .await?; - } + let _measurements = remote_quote_verifier + .verify_attestation(buf, &remote_cert_chain, exporter) + .await?; Ok(remote_cert_chain) } @@ -470,6 +657,8 @@ pub enum ProxyError { IntConversion(#[from] TryFromIntError), #[error("Bad host name: {0}")] BadDnsName(#[from] tokio_rustls::rustls::pki_types::InvalidDnsNameError), + #[error("HTTP: {0}")] + Hyper(#[from] hyper::Error), } /// Given a byte array, encode its length as a 4 byte big endian u32 @@ -505,8 +694,8 @@ mod tests { use super::*; use test_helpers::{ - example_http_service, example_service, generate_certificate_chain, generate_tls_config, - generate_tls_config_with_client_auth, + default_measurements, example_http_service, example_service, generate_certificate_chain, + generate_tls_config, generate_tls_config_with_client_auth, }; #[tokio::test] @@ -521,7 +710,9 @@ mod tests { server_config, "127.0.0.1:0", target_addr, - DcapTdxQuoteGenerator, + DcapTdxQuoteGenerator { + attestation_type: AttestationType::Dummy, + }, NoQuoteVerifier, ) .await @@ -530,6 +721,7 @@ mod tests { let proxy_addr = proxy_server.local_addr().unwrap(); let quote_verifier = DcapTdxQuoteVerifier { + attestation_type: AttestationType::Dummy, accepted_platform_measurements: None, accepted_cvm_image_measurements: vec![CvmImageMeasurements { rtmr1: [0u8; 48], @@ -562,12 +754,22 @@ mod tests { let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) .await + .unwrap(); + + let headers = res.headers(); + let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); + let measurements = Measurements::from_header_format(measurements_json).unwrap(); + assert_eq!(measurements, default_measurements()); + + let attestation_type = headers + .get(ATTESTATION_TYPE_HEADER) .unwrap() - .text() - .await + .to_str() .unwrap(); + assert_eq!(attestation_type, AttestationType::Dummy.as_str()); - assert_eq!(res, "foobar"); + let res_body = res.text().await.unwrap(); + assert_eq!(res_body, "No measurements"); } #[tokio::test] @@ -590,6 +792,7 @@ mod tests { ); let quote_verifier = DcapTdxQuoteVerifier { + attestation_type: AttestationType::Dummy, accepted_platform_measurements: None, accepted_cvm_image_measurements: vec![CvmImageMeasurements { rtmr1: [0u8; 48], @@ -604,7 +807,9 @@ mod tests { server_tls_server_config, "127.0.0.1:0", target_addr, - DcapTdxQuoteGenerator, + DcapTdxQuoteGenerator { + attestation_type: AttestationType::Dummy, + }, quote_verifier.clone(), ) .await @@ -620,7 +825,9 @@ mod tests { client_tls_client_config, "127.0.0.1:0", proxy_addr.to_string(), - DcapTdxQuoteGenerator, + DcapTdxQuoteGenerator { + attestation_type: AttestationType::Dummy, + }, quote_verifier, Some(client_cert_chain), ) @@ -634,72 +841,27 @@ mod tests { }); let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) - .await - .unwrap() - .text() .await .unwrap(); - assert_eq!(res, "foobar"); - } - - #[tokio::test] - async fn raw_tcp_proxy() { - let target_addr = example_service().await; - - let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); - let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - - let proxy_server = ProxyServer::new_with_tls_config( - cert_chain, - server_config, - "127.0.0.1:0", - target_addr, - DcapTdxQuoteGenerator, - NoQuoteVerifier, - ) - .await - .unwrap(); - - let proxy_server_addr = proxy_server.local_addr().unwrap(); + let headers = res.headers(); + let measurements_json = headers.get(MEASUREMENT_HEADER).unwrap().to_str().unwrap(); + let measurements = Measurements::from_header_format(measurements_json).unwrap(); + assert_eq!(measurements, default_measurements()); - tokio::spawn(async move { - proxy_server.accept().await.unwrap(); - }); - - let quote_verifier = DcapTdxQuoteVerifier { - accepted_platform_measurements: None, - accepted_cvm_image_measurements: vec![CvmImageMeasurements { - rtmr1: [0u8; 48], - rtmr2: [0u8; 48], - rtmr3: [0u8; 48], - }], - pccs_url: None, - }; - - let proxy_client = ProxyClient::new_with_tls_config( - client_config, - "127.0.0.1:0", - proxy_server_addr.to_string(), - NoQuoteGenerator, - quote_verifier, - None, - ) - .await - .unwrap(); - - let proxy_client_addr = proxy_client.local_addr().unwrap(); - - tokio::spawn(async move { - proxy_client.accept().await.unwrap(); - }); - - let mut out = TcpStream::connect(proxy_client_addr).await.unwrap(); + let attestation_type = headers + .get(ATTESTATION_TYPE_HEADER) + .unwrap() + .to_str() + .unwrap(); + assert_eq!(attestation_type, AttestationType::Dummy.as_str()); - let mut buf = [0; 9]; - out.read(&mut buf).await.unwrap(); + let res_body = res.text().await.unwrap(); - assert_eq!(buf[..], b"some data"[..]); + // The response body shows us what was in the request header (as the test http server + // handler puts them there) + let measurements = Measurements::from_header_format(&res_body).unwrap(); + assert_eq!(measurements, default_measurements()); } #[tokio::test] @@ -714,7 +876,9 @@ mod tests { server_config, "127.0.0.1:0", target_addr, - DcapTdxQuoteGenerator, + DcapTdxQuoteGenerator { + attestation_type: AttestationType::Dummy, + }, NoQuoteVerifier, ) .await @@ -727,6 +891,7 @@ mod tests { }); let quote_verifier = DcapTdxQuoteVerifier { + attestation_type: AttestationType::Dummy, accepted_platform_measurements: None, accepted_cvm_image_measurements: vec![CvmImageMeasurements { rtmr1: [0u8; 48], diff --git a/src/main.rs b/src/main.rs index f8c1009..e3ab26d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,8 +4,9 @@ use std::{fs::File, net::SocketAddr, path::PathBuf}; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; use attested_tls_proxy::{ - attestation::CvmImageMeasurements, get_tls_cert, DcapTdxQuoteGenerator, DcapTdxQuoteVerifier, - NoQuoteGenerator, NoQuoteVerifier, ProxyClient, ProxyServer, TlsCertAndKey, + attestation::{AttestationType, CvmImageMeasurements}, + get_tls_cert, DcapTdxQuoteGenerator, DcapTdxQuoteVerifier, NoQuoteGenerator, NoQuoteVerifier, + ProxyClient, ProxyServer, TlsCertAndKey, }; #[derive(Parser, Debug, Clone)] @@ -81,6 +82,7 @@ async fn main() -> anyhow::Result<()> { }; let quote_verifier = DcapTdxQuoteVerifier { + attestation_type: AttestationType::Dummy, accepted_platform_measurements: None, accepted_cvm_image_measurements: vec![CvmImageMeasurements { rtmr1: [0u8; 48], @@ -113,7 +115,9 @@ async fn main() -> anyhow::Result<()> { client_auth, } => { let tls_cert_and_chain = load_tls_cert_and_key(cert_chain, private_key)?; - let local_attestation = DcapTdxQuoteGenerator; + let local_attestation = DcapTdxQuoteGenerator { + attestation_type: AttestationType::Dummy, + }; let remote_attestation = NoQuoteVerifier; let server = ProxyServer::new( @@ -134,6 +138,7 @@ async fn main() -> anyhow::Result<()> { } CliCommand::GetTlsCert { server } => { let quote_verifier = DcapTdxQuoteVerifier { + attestation_type: AttestationType::Dummy, accepted_platform_measurements: None, accepted_cvm_image_measurements: vec![CvmImageMeasurements { rtmr1: [0u8; 48], diff --git a/src/test_helpers.rs b/src/test_helpers.rs index 5d4a349..14d3c34 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -1,3 +1,4 @@ +use axum::response::IntoResponse; use std::{ net::{IpAddr, SocketAddr}, sync::Arc, @@ -10,6 +11,11 @@ use tokio_rustls::rustls::{ ClientConfig, RootCertStore, ServerConfig, }; +use crate::{ + attestation::{CvmImageMeasurements, Measurements, PlatformMeasurements}, + MEASUREMENT_HEADER, +}; + /// Helper to generate a self-signed certificate for testing pub fn generate_certificate_chain( ip: IpAddr, @@ -111,7 +117,7 @@ pub async fn example_http_service() -> SocketAddr { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let app = axum::Router::new().route("/", axum::routing::get(|| async { "foobar" })); + let app = axum::Router::new().route("/", axum::routing::get(get_handler)); tokio::spawn(async move { axum::serve(listener, app).await.unwrap(); @@ -120,6 +126,14 @@ pub async fn example_http_service() -> SocketAddr { addr } +async fn get_handler(headers: http::HeaderMap) -> impl IntoResponse { + headers + .get(MEASUREMENT_HEADER) + .and_then(|v| v.to_str().ok()) + .unwrap_or("No measurements") + .to_string() +} + pub async fn example_service() -> SocketAddr { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); @@ -133,3 +147,17 @@ pub async fn example_service() -> SocketAddr { addr } + +pub fn default_measurements() -> Measurements { + Measurements { + platform: PlatformMeasurements { + mrtd: [0u8; 48], + rtmr0: [0u8; 48], + }, + cvm_image: CvmImageMeasurements { + rtmr1: [0u8; 48], + rtmr2: [0u8; 48], + rtmr3: [0u8; 48], + }, + } +}