diff --git a/Cargo.lock b/Cargo.lock index 2af1c6b..5d758a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -655,6 +655,30 @@ dependencies = [ "syn", ] +[[package]] +name = "dummy-attestation-server" +version = "0.1.0" +dependencies = [ + "anyhow", + "attested-tls-proxy", + "axum", + "clap", + "configfs-tsm", + "hex", + "parity-scale-codec", + "rcgen", + "reqwest", + "rustls-pemfile", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-rustls", + "tracing", + "tracing-subscriber", + "webpki-roots", +] + [[package]] name = "ecdsa" version = "0.16.9" diff --git a/Cargo.toml b/Cargo.toml index 4c62518..a5e7085 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,6 @@ +[workspace] +members = [".", "dummy-attestation-server"] + [package] name = "attested-tls-proxy" version = "0.1.0" diff --git a/dummy-attestation-server/Cargo.toml b/dummy-attestation-server/Cargo.toml new file mode 100644 index 0000000..10c5379 --- /dev/null +++ b/dummy-attestation-server/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "dummy-attestation-server" +version = "0.1.0" +edition = "2024" +license = "MIT" +publish = false + +[dependencies] +attested-tls-proxy = { path = ".." } +tokio = { version = "1.48.0", features = ["full"] } +axum = "0.8.6" +tokio-rustls = { version = "0.26.4", default-features = false, features = ["ring"] } +thiserror = "2.0.17" +clap = { version = "4.5.51", features = ["derive", "env"] } +webpki-roots = "1.0.4" +rustls-pemfile = "2.2.0" +anyhow = "1.0.100" +configfs-tsm = "0.0.2" +hex = "0.4.3" +serde_json = "1.0.145" +serde = "1.0.228" +tracing = "0.1.41" +tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] } +rcgen = "0.14.5" +parity-scale-codec = "3.7.5" +reqwest = { version = "0.12.23", default-features = false } + +[dev-dependencies] diff --git a/dummy-attestation-server/src/lib.rs b/dummy-attestation-server/src/lib.rs new file mode 100644 index 0000000..9181b44 --- /dev/null +++ b/dummy-attestation-server/src/lib.rs @@ -0,0 +1,131 @@ +use std::{ + net::{IpAddr, SocketAddr}, + sync::Arc, +}; + +use attested_tls_proxy::{attestation::AttestationExchangeMessage, QuoteGenerator}; +use axum::{ + extract::{Path, State}, + http::StatusCode, + response::{IntoResponse, Response}, +}; +use parity_scale_codec::{Decode, Encode}; +use tokio::net::TcpListener; +use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; + +#[derive(Clone)] +struct SharedState { + attestation_generator: Arc, +} + +pub async fn dummy_attestation_server( + listener: TcpListener, + attestation_generator: Arc, +) -> anyhow::Result { + let addr = listener.local_addr()?; + + let app = axum::Router::new() + .route("/attest/{input_data}", axum::routing::get(get_attest)) + .with_state(SharedState { + attestation_generator, + }); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + Ok(addr) +} + +async fn get_attest( + State(shared_state): State, + Path(input_data): Path, +) -> Result<(StatusCode, Vec), ServerError> { + let (cert_chain, _) = generate_certificate_chain("0.0.0.0".parse().unwrap()); + let input_data: [u8; 64] = hex::decode(input_data).unwrap().try_into().unwrap(); + + let attestation = AttestationExchangeMessage::from_attestation_generator( + &cert_chain, + input_data[..32].try_into().unwrap(), + shared_state.attestation_generator, + )? + .encode(); + + Ok((StatusCode::OK, attestation)) +} + +pub async fn dummy_attestation_client(server_addr: SocketAddr) -> anyhow::Result<()> { + let input_data = [0; 64]; + let response = reqwest::get(format!( + "http://{server_addr}/attest/{}", + hex::encode(input_data) + )) + .await + .unwrap() + .bytes() + .await + .unwrap(); + + let remote_attestation_message = AttestationExchangeMessage::decode(&mut &response[..])?; + let remote_attestation_type = remote_attestation_message.attestation_type; + println!("{remote_attestation_type}"); + + // TODO validate the attestation + Ok(()) +} + +struct ServerError(pub anyhow::Error); + +impl From for ServerError +where + E: Into, +{ + fn from(err: E) -> Self { + ServerError(err.into()) + } +} + +impl IntoResponse for ServerError { + fn into_response(self) -> Response { + eprintln!("{:?}", self.0); + (StatusCode::INTERNAL_SERVER_ERROR, format!("{:?}", self.0)).into_response() + } +} + +/// Helper to generate a self-signed certificate for testing +fn generate_certificate_chain( + ip: IpAddr, +) -> (Vec>, PrivateKeyDer<'static>) { + let mut params = rcgen::CertificateParams::new(vec![]).unwrap(); + params.subject_alt_names.push(rcgen::SanType::IpAddress(ip)); + params + .distinguished_name + .push(rcgen::DnType::CommonName, ip.to_string()); + + let keypair = rcgen::KeyPair::generate().unwrap(); + let cert = params.self_signed(&keypair).unwrap(); + + let certs = vec![CertificateDer::from(cert)]; + let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(keypair.serialize_der())); + (certs, key) +} + +#[cfg(test)] +mod tests { + + use attested_tls_proxy::attestation::AttestationType; + + use super::*; + + #[tokio::test] + async fn test_dummy_server() { + let attestation_generator = AttestationType::None.get_quote_generator().unwrap(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let server_addr = listener.local_addr().unwrap(); + dummy_attestation_server(listener, attestation_generator) + .await + .unwrap(); + dummy_attestation_client(server_addr).await.unwrap(); + } +} diff --git a/dummy-attestation-server/src/main.rs b/dummy-attestation-server/src/main.rs new file mode 100644 index 0000000..fe1946d --- /dev/null +++ b/dummy-attestation-server/src/main.rs @@ -0,0 +1,76 @@ +use attested_tls_proxy::attestation::AttestationType; +use clap::{Parser, Subcommand}; +use dummy_attestation_server::{dummy_attestation_client, dummy_attestation_server}; +use std::net::SocketAddr; +use tokio::net::TcpListener; +use tracing::level_filters::LevelFilter; + +#[derive(Parser, Debug, Clone)] +#[clap(version, about, long_about = None)] +struct Cli { + #[clap(subcommand)] + command: CliCommand, + /// Log debug messages + #[arg(long, global = true)] + log_debug: bool, + /// Log in JSON format + #[arg(long, global = true)] + log_json: bool, +} +#[derive(Subcommand, Debug, Clone)] +enum CliCommand { + Server { + /// Socket address to listen on + #[arg(short, long, default_value = "0.0.0.0:0", env = "LISTEN_ADDR")] + listen_addr: SocketAddr, + /// Type of attestation to present (defaults to none) + #[arg(long)] + server_attestation_type: Option, + }, + Client { + /// Socket address of a dummy attestation server + server_addr: SocketAddr, + }, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let cli = Cli::parse(); + + let level_filter = if cli.log_debug { + LevelFilter::DEBUG + } else { + LevelFilter::WARN + }; + + let env_filter = tracing_subscriber::EnvFilter::builder() + .with_default_directive(level_filter.into()) + .from_env_lossy(); + + let subscriber = tracing_subscriber::fmt::Subscriber::builder().with_env_filter(env_filter); + + if cli.log_json { + subscriber.json().init(); + } else { + subscriber.pretty().init(); + } + + match cli.command { + CliCommand::Server { + listen_addr, + server_attestation_type, + } => { + let server_attestation_type: AttestationType = serde_json::from_value( + serde_json::Value::String(server_attestation_type.unwrap_or("none".to_string())), + )?; + + let attestation_generator = server_attestation_type.get_quote_generator()?; + + let listener = TcpListener::bind(listen_addr).await?; + dummy_attestation_server(listener, attestation_generator).await?; + } + CliCommand::Client { server_addr } => dummy_attestation_client(server_addr).await?, + } + + Ok(()) +}