Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 16 additions & 4 deletions sqlx-core/src/net/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,25 @@ impl std::fmt::Display for CertificateInput {
}
}

pub struct TlsConfig<'a> {
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum TlsConfig<'a> {
RawTlsConfig(RawTlsConfig<'a>),
#[cfg(feature = "_tls-rustls")]
PrebuiltRustls {
config: &'a rustls::ClientConfig,
hostname: &'a str,
},
}

#[derive(Debug, Clone)]
pub struct RawTlsConfig<'a> {
pub accept_invalid_certs: bool,
pub accept_invalid_hostnames: bool,
pub hostname: &'a str,
pub root_cert_path: Option<&'a CertificateInput>,
pub client_cert_path: Option<&'a CertificateInput>,
pub client_key_path: Option<&'a CertificateInput>,
pub root_cert: Option<&'a CertificateInput>,
pub client_cert: Option<&'a CertificateInput>,
pub client_key: Option<&'a CertificateInput>,
}

pub async fn handshake<S, Ws>(
Expand Down
73 changes: 47 additions & 26 deletions sqlx-core/src/net/tls/tls_native_tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::io::{self, Read, Write};

use crate::io::ReadBuf;
use crate::net::tls::util::StdSocket;
use crate::net::tls::RawTlsConfig;
use crate::net::tls::TlsConfig;
use crate::net::Socket;
use crate::rt;
Expand Down Expand Up @@ -39,36 +40,56 @@ impl<S: Socket> Socket for NativeTlsSocket<S> {
}
}

pub async fn handshake<S: Socket>(
socket: S,
config: TlsConfig<'_>,
) -> crate::Result<NativeTlsSocket<S>> {
let mut builder = native_tls::TlsConnector::builder();

builder
.danger_accept_invalid_certs(config.accept_invalid_certs)
.danger_accept_invalid_hostnames(config.accept_invalid_hostnames);
impl TlsConfig<'_> {
async fn native_tls_connector(&self) -> crate::Result<(native_tls::TlsConnector, &str), Error> {
#[allow(irrefutable_let_patterns)]
let TlsConfig::RawTlsConfig(RawTlsConfig {
root_cert,
client_cert,
client_key,
accept_invalid_certs,
accept_invalid_hostnames,
hostname,
}) = self
else {
unreachable!()
};
let mut builder = native_tls::TlsConnector::builder();

builder
.danger_accept_invalid_certs(*accept_invalid_certs)
.danger_accept_invalid_hostnames(*accept_invalid_hostnames);

if let Some(root_cert) = root_cert {
let data = root_cert.data().await?;
builder.add_root_certificate(
native_tls::Certificate::from_pem(&data).map_err(Error::tls)?,
);
}

if let Some(root_cert_path) = config.root_cert_path {
let data = root_cert_path.data().await?;
builder.add_root_certificate(native_tls::Certificate::from_pem(&data).map_err(Error::tls)?);
}
// authentication using user's key-file and its associated certificate
if let (Some(cert), Some(key)) = (client_cert, client_key) {
let cert = cert.data().await?;
let key = key.data().await?;
let identity = Identity::from_pkcs8(&cert, &key).map_err(Error::tls)?;
builder.identity(identity);
}

// authentication using user's key-file and its associated certificate
if let (Some(cert_path), Some(key_path)) = (config.client_cert_path, config.client_key_path) {
let cert_path = cert_path.data().await?;
let key_path = key_path.data().await?;
let identity = Identity::from_pkcs8(&cert_path, &key_path).map_err(Error::tls)?;
builder.identity(identity);
// The openssl TlsConnector synchronously loads certificates from files.
// Loading these files can block for tens of milliseconds.
let connector = rt::spawn_blocking(move || builder.build())
.await
.map_err(Error::tls)?;
Ok((connector, hostname))
}
}

// The openssl TlsConnector synchronously loads certificates from files.
// Loading these files can block for tens of milliseconds.
let connector = rt::spawn_blocking(move || builder.build())
.await
.map_err(Error::tls)?;

let mut mid_handshake = match connector.connect(config.hostname, StdSocket::new(socket)) {
pub async fn handshake<S: Socket>(
socket: S,
config: TlsConfig<'_>,
) -> crate::Result<NativeTlsSocket<S>> {
let (connector, hostname) = config.native_tls_connector().await?;
let mut mid_handshake = match connector.connect(hostname, StdSocket::new(socket)) {
Ok(tls_stream) => return Ok(NativeTlsSocket { stream: tls_stream }),
Err(HandshakeError::Failure(e)) => return Err(Error::tls(e)),
Err(HandshakeError::WouldBlock(mid_handshake)) => mid_handshake,
Expand Down
189 changes: 112 additions & 77 deletions sqlx-core/src/net/tls/tls_rustls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use rustls::{
use crate::error::Error;
use crate::io::ReadBuf;
use crate::net::tls::util::StdSocket;
use crate::net::tls::TlsConfig;
use crate::net::tls::{RawTlsConfig, TlsConfig};
use crate::net::Socket;

pub struct RustlsSocket<S: Socket> {
Expand Down Expand Up @@ -87,100 +87,135 @@ impl<S: Socket> Socket for RustlsSocket<S> {
}
}

pub async fn handshake<S>(socket: S, tls_config: TlsConfig<'_>) -> Result<RustlsSocket<S>, Error>
where
S: Socket,
{
#[cfg(all(
feature = "_tls-rustls-aws-lc-rs",
not(feature = "_tls-rustls-ring-webpki"),
not(feature = "_tls-rustls-ring-native-roots")
))]
let provider = Arc::new(rustls::crypto::aws_lc_rs::default_provider());
#[cfg(any(
feature = "_tls-rustls-ring-webpki",
feature = "_tls-rustls-ring-native-roots"
))]
let provider = Arc::new(rustls::crypto::ring::default_provider());

// Unwrapping is safe here because we use a default provider.
let config = ClientConfig::builder_with_provider(provider.clone())
impl TlsConfig<'_> {
async fn rustls_config(&self) -> crate::Result<(rustls::ClientConfig, &str), Error> {
let RawTlsConfig {
accept_invalid_certs,
accept_invalid_hostnames,
hostname,
root_cert,
client_cert,
client_key,
} = match self {
TlsConfig::RawTlsConfig(raw) => raw,
TlsConfig::PrebuiltRustls { config, hostname } => {
return Ok(((*config).to_owned(), hostname));
}
};

#[cfg(all(
feature = "_tls-rustls-aws-lc-rs",
not(feature = "_tls-rustls-ring-webpki"),
not(feature = "_tls-rustls-ring-native-roots")
))]
let config = ClientConfig::builder_with_provider(Arc::new(
rustls::crypto::aws_lc_rs::default_provider(),
))
.with_safe_default_protocol_versions()
.unwrap();
#[cfg(any(
feature = "_tls-rustls-ring-webpki",
feature = "_tls-rustls-ring-native-roots"
))]
let config =
ClientConfig::builder_with_provider(Arc::new(rustls::crypto::ring::default_provider()))
.with_safe_default_protocol_versions()
.unwrap();
#[cfg(all(
not(feature = "_tls-rustls-aws-lc-rs"),
not(feature = "_tls-rustls-ring-webpki"),
not(feature = "_tls-rustls-ring-native-roots")
))]
let config = ClientConfig::builder();

// authentication using user's key and its associated certificate
let user_auth = match (client_cert, client_key) {
(Some(cert), Some(key)) => {
let cert_chain = certs_from_pem(cert.data().await?)?;
let key_der = private_key_from_pem(key.data().await?)?;
Some((cert_chain, key_der))
}
(None, None) => None,
(_, _) => {
return Err(Error::Configuration(
"user auth key and certs must be given together".into(),
))
}
};

// authentication using user's key and its associated certificate
let user_auth = match (tls_config.client_cert_path, tls_config.client_key_path) {
(Some(cert_path), Some(key_path)) => {
let cert_chain = certs_from_pem(cert_path.data().await?)?;
let key_der = private_key_from_pem(key_path.data().await?)?;
Some((cert_chain, key_der))
}
(None, None) => None,
(_, _) => {
return Err(Error::Configuration(
"user auth key and certs must be given together".into(),
))
}
};
let provider = config.crypto_provider().clone();

let config = if tls_config.accept_invalid_certs {
if let Some(user_auth) = user_auth {
config
.dangerous()
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier { provider }))
.with_client_auth_cert(user_auth.0, user_auth.1)
.map_err(Error::tls)?
let config = if *accept_invalid_certs {
if let Some(user_auth) = user_auth {
config
.dangerous()
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier { provider }))
.with_client_auth_cert(user_auth.0, user_auth.1)
.map_err(Error::tls)?
} else {
config
.dangerous()
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier { provider }))
.with_no_client_auth()
}
} else {
config
.dangerous()
.with_custom_certificate_verifier(Arc::new(DummyTlsVerifier { provider }))
.with_no_client_auth()
}
} else {
let mut cert_store = import_root_certs();
let mut cert_store = import_root_certs();

if let Some(ca) = tls_config.root_cert_path {
let data = ca.data().await?;
if let Some(ca) = root_cert {
let data = ca.data().await?;

for result in CertificateDer::pem_slice_iter(&data) {
let Ok(cert) = result else {
return Err(Error::Tls(format!("Invalid certificate {ca}").into()));
};
for result in CertificateDer::pem_slice_iter(&data) {
let Ok(cert) = result else {
return Err(Error::Tls(format!("Invalid certificate {ca}").into()));
};

cert_store.add(cert).map_err(|err| Error::Tls(err.into()))?;
cert_store.add(cert).map_err(|err| Error::Tls(err.into()))?;
}
}
}

if tls_config.accept_invalid_hostnames {
let verifier = WebPkiServerVerifier::builder(Arc::new(cert_store))
.build()
.map_err(|err| Error::Tls(err.into()))?;

if let Some(user_auth) = user_auth {
if *accept_invalid_hostnames {
let verifier = WebPkiServerVerifier::builder(Arc::new(cert_store))
.build()
.map_err(|err| Error::Tls(err.into()))?;

if let Some(user_auth) = user_auth {
config
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier {
verifier,
}))
.with_client_auth_cert(user_auth.0, user_auth.1)
.map_err(Error::tls)?
} else {
config
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier {
verifier,
}))
.with_no_client_auth()
}
} else if let Some(user_auth) = user_auth {
config
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
.with_root_certificates(cert_store)
.with_client_auth_cert(user_auth.0, user_auth.1)
.map_err(Error::tls)?
} else {
config
.dangerous()
.with_custom_certificate_verifier(Arc::new(NoHostnameTlsVerifier { verifier }))
.with_root_certificates(cert_store)
.with_no_client_auth()
}
} else if let Some(user_auth) = user_auth {
config
.with_root_certificates(cert_store)
.with_client_auth_cert(user_auth.0, user_auth.1)
.map_err(Error::tls)?
} else {
config
.with_root_certificates(cert_store)
.with_no_client_auth()
}
};
};

Ok((config, hostname))
}
}

let host = ServerName::try_from(tls_config.hostname.to_owned()).map_err(Error::tls)?;
pub async fn handshake<S>(socket: S, tls_config: TlsConfig<'_>) -> Result<RustlsSocket<S>, Error>
where
S: Socket,
{
let (config, hostname) = tls_config.rustls_config().await?;
let host = ServerName::try_from(hostname.to_owned()).map_err(Error::tls)?;

let mut socket = RustlsSocket {
inner: StdSocket::new(socket),
Expand Down
12 changes: 7 additions & 5 deletions sqlx-mysql/src/connection/tls.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use sqlx_core::net::tls::RawTlsConfig;

use crate::connection::{MySqlStream, Waiting};
use crate::error::Error;
use crate::net::tls::TlsConfig;
Expand Down Expand Up @@ -53,17 +55,17 @@ pub(super) async fn maybe_upgrade<S: Socket>(
}
}

let tls_config = TlsConfig {
let tls_config = TlsConfig::RawTlsConfig(RawTlsConfig {
accept_invalid_certs: !matches!(
options.ssl_mode,
MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
),
accept_invalid_hostnames: !matches!(options.ssl_mode, MySqlSslMode::VerifyIdentity),
hostname: &options.host,
root_cert_path: options.ssl_ca.as_ref(),
client_cert_path: options.ssl_client_cert.as_ref(),
client_key_path: options.ssl_client_key.as_ref(),
};
root_cert: options.ssl_ca.as_ref(),
client_cert: options.ssl_client_cert.as_ref(),
client_key: options.ssl_client_key.as_ref(),
});

// Request TLS upgrade
stream.write_packet(SslRequest {
Expand Down
4 changes: 4 additions & 0 deletions sqlx-postgres/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ any = ["sqlx-core/any"]
json = ["sqlx-core/json"]
migrate = ["sqlx-core/migrate"]
offline = ["sqlx-core/offline"]
rustls = ["dep:rustls", "sqlx-core/_tls-rustls"]

# Type Integration features
bigdecimal = ["dep:bigdecimal", "dep:num-bigint", "sqlx-core/bigdecimal"]
Expand All @@ -27,6 +28,9 @@ time = ["dep:time", "sqlx-core/time"]
uuid = ["dep:uuid", "sqlx-core/uuid"]

[dependencies]
# TLS
rustls = { version = "0.23.24", default-features = false, features = ["std", "tls12"], optional = true }

# Futures crates
futures-channel = { version = "0.3.19", default-features = false, features = ["sink", "alloc", "std"] }
futures-core = { version = "0.3.19", default-features = false }
Expand Down
Loading