Skip to content

Commit

Permalink
Store ClientConfig per connection provider
Browse files Browse the repository at this point in the history
  • Loading branch information
daxpedda committed Jun 14, 2023
1 parent e4be254 commit b63c240
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 60 deletions.
21 changes: 4 additions & 17 deletions crates/resolver/src/https.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use std::future::Future;
use std::net::SocketAddr;

use crate::error::ResolveError;
use crate::tls::CLIENT_CONFIG;

use proto::https::{HttpsClientConnect, HttpsClientStream, HttpsClientStreamBuilder};
use proto::tcp::{Connect, DnsTcpStream};
Expand All @@ -24,18 +23,12 @@ pub(crate) fn new_https_stream<S>(
socket_addr: SocketAddr,
bind_addr: Option<SocketAddr>,
dns_name: String,
client_config: Option<TlsClientConfig>,
client_config: TlsClientConfig,
) -> Result<DnsExchangeConnect<HttpsClientConnect<S>, HttpsClientStream, TokioTime>, ResolveError>
where
S: Connect,
{
let client_config = if let Some(TlsClientConfig(client_config)) = client_config {
client_config
} else {
CLIENT_CONFIG.clone()?
};

let mut https_builder = HttpsClientStreamBuilder::with_client_config(client_config);
let mut https_builder = HttpsClientStreamBuilder::with_client_config(client_config.0);
if let Some(bind_addr) = bind_addr {
https_builder.bind_addr(bind_addr);
}
Expand All @@ -49,20 +42,14 @@ pub(crate) fn new_https_stream_with_future<S, F>(
future: F,
socket_addr: SocketAddr,
dns_name: String,
client_config: Option<TlsClientConfig>,
client_config: TlsClientConfig,
) -> Result<DnsExchangeConnect<HttpsClientConnect<S>, HttpsClientStream, TokioTime>, ResolveError>
where
S: DnsTcpStream,
F: Future<Output = std::io::Result<S>> + Send + Unpin + 'static,
{
let client_config = if let Some(TlsClientConfig(client_config)) = client_config {
client_config
} else {
CLIENT_CONFIG.clone()?
};

Ok(DnsExchange::connect(
HttpsClientStreamBuilder::build_with_future(future, client_config, socket_addr, dns_name),
HttpsClientStreamBuilder::build_with_future(future, client_config.0, socket_addr, dns_name),
))
}

Expand Down
59 changes: 53 additions & 6 deletions crates/resolver/src/name_server/connection_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ use std::task::{Context, Poll};
use futures_util::future::{Future, FutureExt};
use futures_util::ready;
use futures_util::stream::{Stream, StreamExt};
#[cfg(any(feature = "dns-over-https-rustls", feature = "dns-over-quic"))]
use once_cell::sync::Lazy;
#[cfg(feature = "tokio-runtime")]
use tokio::net::TcpStream as TokioTcpStream;
#[cfg(all(feature = "dns-over-native-tls", not(feature = "dns-over-rustls")))]
Expand All @@ -30,6 +32,8 @@ use tokio_openssl::SslStream as TokioTlsStream;
#[cfg(feature = "dns-over-rustls")]
use tokio_rustls::client::TlsStream as TokioTlsStream;

#[cfg(any(feature = "dns-over-https-rustls", feature = "dns-over-quic"))]
use crate::config::TlsClientConfig;
use crate::config::{NameServerConfig, Protocol, ResolverOpts};
#[cfg(feature = "dns-over-https")]
use proto::https::{HttpsClientConnect, HttpsClientStream};
Expand Down Expand Up @@ -246,19 +250,27 @@ impl DnsHandle for GenericConnection {
#[derive(Clone)]
pub struct GenericConnector<P: RuntimeProvider> {
runtime_provider: P,
#[cfg(any(feature = "dns-over-https-rustls", feature = "dns-over-quic"))]
client_config: Arc<Lazy<Result<Arc<rustls::ClientConfig>, ProtoError>>>,
}

impl<P: RuntimeProvider> GenericConnector<P> {
/// Create a new instance.
pub fn new(runtime_provider: P) -> Self {
Self { runtime_provider }
Self {
runtime_provider,
#[cfg(any(feature = "dns-over-https-rustls", feature = "dns-over-quic"))]
client_config: Arc::new(Lazy::new(crate::tls::client_config)),
}
}
}

impl<P: RuntimeProvider + Default> Default for GenericConnector<P> {
fn default() -> Self {
Self {
runtime_provider: P::default(),
#[cfg(any(feature = "dns-over-https-rustls", feature = "dns-over-quic"))]
client_config: Arc::new(Lazy::new(crate::tls::client_config)),
}
}
}
Expand Down Expand Up @@ -313,11 +325,22 @@ impl<P: RuntimeProvider> ConnectionProvider for GenericConnector<P> {
let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
let tcp_future = self.runtime_provider.connect_tcp(socket_addr);

#[cfg(feature = "dns-over-rustls")]
let client_config = config.tls_config.clone();

#[cfg(feature = "dns-over-rustls")]
let (stream, handle) = {
let client_config = if let Some(client_config) = config.tls_config.clone() {
client_config
} else {
match (*self.client_config).clone() {
Ok(client_config) => TlsClientConfig(client_config),
Err(err) => {
return ConnectionFuture::<P> {
connect: ConnectionConnect::Error(err.into()),
spawner: self.runtime_provider.create_handle(),
}
}
}
};

crate::tls::new_tls_stream_with_future(
tcp_future,
socket_addr,
Expand Down Expand Up @@ -345,7 +368,19 @@ impl<P: RuntimeProvider> ConnectionProvider for GenericConnector<P> {
let socket_addr = config.socket_addr;
let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
#[cfg(feature = "dns-over-rustls")]
let client_config = config.tls_config.clone();
let client_config = if let Some(client_config) = config.tls_config.clone() {
client_config
} else {
match (*self.client_config).clone() {
Ok(client_config) => TlsClientConfig(client_config),
Err(err) => {
return ConnectionFuture::<P> {
connect: ConnectionConnect::Error(err.into()),
spawner: self.runtime_provider.create_handle(),
}
}
}
};
let tcp_future = self.runtime_provider.connect_tcp(socket_addr);

crate::https::new_https_stream_with_future(
Expand All @@ -367,7 +402,19 @@ impl<P: RuntimeProvider> ConnectionProvider for GenericConnector<P> {
});
let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
#[cfg(feature = "dns-over-rustls")]
let client_config = config.tls_config.clone();
let client_config = if let Some(client_config) = config.tls_config.clone() {
client_config
} else {
match (*self.client_config).clone() {
Ok(client_config) => TlsClientConfig(client_config),
Err(err) => {
return ConnectionFuture::<P> {
connect: ConnectionConnect::Error(err.into()),
spawner: self.runtime_provider.create_handle(),
}
}
}
};
let udp_future = self.runtime_provider.bind_udp(bind_addr, socket_addr);

crate::quic::new_quic_stream_with_future(
Expand Down
21 changes: 4 additions & 17 deletions crates/resolver/src/quic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,19 @@ use trust_dns_proto::quic::{QuicClientConnect, QuicClientStream};

use crate::config::TlsClientConfig;
use crate::error::ResolveError;
use crate::tls::CLIENT_CONFIG;

#[allow(clippy::type_complexity)]
#[allow(unused)]
pub(crate) fn new_quic_stream(
socket_addr: SocketAddr,
bind_addr: Option<SocketAddr>,
dns_name: String,
client_config: Option<TlsClientConfig>,
client_config: TlsClientConfig,
) -> Result<DnsExchangeConnect<QuicClientConnect, QuicClientStream, TokioTime>, ResolveError> {
let client_config = if let Some(TlsClientConfig(client_config)) = client_config {
client_config
} else {
CLIENT_CONFIG.clone()?
};

let mut quic_builder = QuicClientStream::builder()?;

// TODO: normalize the crypto config settings, can we just use common ALPN settings?
let crypto_config: CryptoConfig = (*client_config).clone();
let crypto_config: CryptoConfig = (*client_config.0).clone();

quic_builder.crypto_config(crypto_config);
if let Some(bind_addr) = bind_addr {
Expand All @@ -52,22 +45,16 @@ pub(crate) fn new_quic_stream_with_future<S, F>(
future: F,
socket_addr: SocketAddr,
dns_name: String,
client_config: Option<TlsClientConfig>,
client_config: TlsClientConfig,
) -> Result<DnsExchangeConnect<QuicClientConnect, QuicClientStream, TokioTime>, ResolveError>
where
S: DnsUdpSocket + QuicLocalAddr + 'static,
F: Future<Output = std::io::Result<S>> + Send + 'static,
{
let client_config = if let Some(TlsClientConfig(client_config)) = client_config {
client_config
} else {
CLIENT_CONFIG.clone()?
};

let mut quic_builder = QuicClientStream::builder()?;

// TODO: normalize the crypto config settings, can we just use common ALPN settings?
let crypto_config: CryptoConfig = (*client_config).clone();
let crypto_config: CryptoConfig = (*client_config.0).clone();

quic_builder.crypto_config(crypto_config);
Ok(DnsExchange::connect(quic_builder.build_with_future(
Expand Down
36 changes: 17 additions & 19 deletions crates/resolver/src/tls/dns_over_rustls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@
#![cfg(feature = "dns-over-rustls")]
#![allow(dead_code)]

use std::future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;

use futures_util::future::Future;
use once_cell::sync::Lazy;
use rustls::{ClientConfig, RootCertStore};

use proto::error::ProtoError;
Expand All @@ -28,7 +26,20 @@ use crate::config::TlsClientConfig;

const ALPN_H2: &[u8] = b"h2";

pub(crate) static CLIENT_CONFIG: Lazy<Result<Arc<ClientConfig>, ProtoError>> = Lazy::new(|| {
pub(crate) fn client_config() -> Result<Arc<ClientConfig>, ProtoError> {
#[cfg(not(feature = "native-certs"))]
{
use once_cell::sync::Lazy;

static CONFIG: Lazy<Result<Arc<ClientConfig>, ProtoError>> =
Lazy::new(client_config_internal);
CONFIG.clone()
}
#[cfg(feature = "native-certs")]
client_config_internal()
}

fn client_config_internal() -> Result<Arc<ClientConfig>, ProtoError> {
#[cfg_attr(
not(any(feature = "native-certs", feature = "webpki-roots")),
allow(unused_mut)
Expand Down Expand Up @@ -73,14 +84,14 @@ pub(crate) static CLIENT_CONFIG: Lazy<Result<Arc<ClientConfig>, ProtoError>> = L
client_config.alpn_protocols.push(ALPN_H2.to_vec());

Ok(Arc::new(client_config))
});
}

#[allow(clippy::type_complexity)]
pub(crate) fn new_tls_stream_with_future<S, F>(
future: F,
socket_addr: SocketAddr,
dns_name: String,
client_config: Option<TlsClientConfig>,
client_config: TlsClientConfig,
) -> (
Pin<Box<dyn Future<Output = Result<TlsClientStream<S>, ProtoError>> + Send>>,
BufDnsStreamHandle,
Expand All @@ -89,20 +100,7 @@ where
S: DnsTcpStream,
F: Future<Output = io::Result<S>> + Send + Unpin + 'static,
{
let client_config = if let Some(TlsClientConfig(client_config)) = client_config {
client_config
} else {
match CLIENT_CONFIG.clone() {
Ok(client_config) => client_config,
Err(err) => {
return (
Box::pin(future::ready(Err(err))),
BufDnsStreamHandle::new(socket_addr).0,
)
}
}
};
let (stream, handle) =
tls_client_connect_with_future(future, socket_addr, dns_name, client_config);
tls_client_connect_with_future(future, socket_addr, dns_name, client_config.0);
(Box::pin(stream), handle)
}
2 changes: 1 addition & 1 deletion crates/resolver/src/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ cfg_if! {
if #[cfg(feature = "dns-over-rustls")] {
pub(crate) use self::dns_over_rustls::new_tls_stream_with_future;
#[cfg(any(feature = "dns-over-https-rustls", feature = "dns-over-quic"))]
pub(crate) use self::dns_over_rustls::CLIENT_CONFIG;
pub(crate) use self::dns_over_rustls::client_config;
} else if #[cfg(feature = "dns-over-native-tls")] {
pub(crate) use self::dns_over_native_tls::new_tls_stream_with_future;
} else if #[cfg(feature = "dns-over-openssl")] {
Expand Down

0 comments on commit b63c240

Please sign in to comment.