From 26b842b7e983d83030d369373131826eae6ea94b Mon Sep 17 00:00:00 2001 From: Harry Cheng Date: Thu, 4 Feb 2021 03:26:39 +0800 Subject: [PATCH] Enable RuntimeProvider in DoT implementations (#1373) --- bin/tests/named_openssl_tests.rs | 8 +++-- bin/tests/named_rustls_tests.rs | 15 +++++++-- crates/native-tls/src/tests.rs | 4 ++- crates/native-tls/src/tls_client_stream.rs | 17 +++++----- crates/native-tls/src/tls_stream.rs | 32 +++++++++--------- crates/openssl/src/tls_client_stream.rs | 15 +++++---- crates/openssl/src/tls_stream.rs | 33 ++++++++++--------- crates/openssl/tests/openssl_tests.rs | 9 +++-- .../src/name_server/connection_provider.rs | 21 ++++++------ .../resolver/src/tls/dns_over_native_tls.rs | 6 ++-- crates/resolver/src/tls/dns_over_openssl.rs | 6 ++-- crates/resolver/src/tls/dns_over_rustls.rs | 5 +-- crates/rustls/src/tests.rs | 8 ++++- crates/rustls/src/tls_client_stream.rs | 12 +++---- crates/rustls/src/tls_stream.rs | 22 ++++++++----- .../src/tls_client_connection.rs | 17 ++++++---- .../tests/server_future_tests.rs | 6 +++- 17 files changed, 142 insertions(+), 94 deletions(-) diff --git a/bin/tests/named_openssl_tests.rs b/bin/tests/named_openssl_tests.rs index 4c4f68b263..fac1750fdd 100644 --- a/bin/tests/named_openssl_tests.rs +++ b/bin/tests/named_openssl_tests.rs @@ -22,12 +22,14 @@ use std::io::*; use std::net::*; use native_tls::Certificate; +use tokio::net::TcpStream as TokioTcpStream; use tokio::runtime::Runtime; use trust_dns_client::client::*; use trust_dns_native_tls::TlsClientStreamBuilder; use server_harness::{named_test_harness, query_a}; +use trust_dns_proto::iocompat::AsyncIoTokioAsStd; #[test] fn test_example_tls_toml_startup() { @@ -59,7 +61,8 @@ fn test_startup(toml: &'static str) { .unwrap() .next() .unwrap(); - let mut tls_conn_builder = TlsClientStreamBuilder::new(); + let mut tls_conn_builder = + TlsClientStreamBuilder::>::new(); let cert = to_trust_anchor(&cert_der); tls_conn_builder.add_ca(cert); let (stream, sender) = tls_conn_builder.build(addr, "ns.example.com".to_string()); @@ -74,7 +77,8 @@ fn test_startup(toml: &'static str) { .unwrap() .next() .unwrap(); - let mut tls_conn_builder = TlsClientStreamBuilder::new(); + let mut tls_conn_builder = + TlsClientStreamBuilder::>::new(); let cert = to_trust_anchor(&cert_der); tls_conn_builder.add_ca(cert); let (stream, sender) = tls_conn_builder.build(addr, "ns.example.com".to_string()); diff --git a/bin/tests/named_rustls_tests.rs b/bin/tests/named_rustls_tests.rs index 427b61c71c..d9a75c9d4b 100644 --- a/bin/tests/named_rustls_tests.rs +++ b/bin/tests/named_rustls_tests.rs @@ -21,9 +21,11 @@ use std::sync::Arc; use rustls::Certificate; use rustls::ClientConfig; +use tokio::net::TcpStream as TokioTcpStream; use tokio::runtime::Runtime; use trust_dns_client::client::*; +use trust_dns_proto::iocompat::AsyncIoTokioAsStd; use trust_dns_rustls::tls_client_connect; use server_harness::{named_test_harness, query_a}; @@ -57,8 +59,11 @@ fn test_example_tls_toml_startup() { config.root_store.add(&cert).expect("bad certificate"); let config = Arc::new(config); - let (stream, sender) = - tls_client_connect(addr, "ns.example.com".to_string(), config.clone()); + let (stream, sender) = tls_client_connect::>( + addr, + "ns.example.com".to_string(), + config.clone(), + ); let client = AsyncClient::new(stream, Box::new(sender), None); let (mut client, bg) = io_loop.block_on(client).expect("client failed to connect"); @@ -72,7 +77,11 @@ fn test_example_tls_toml_startup() { .unwrap() .next() .unwrap(); - let (stream, sender) = tls_client_connect(addr, "ns.example.com".to_string(), config); + let (stream, sender) = tls_client_connect::>( + addr, + "ns.example.com".to_string(), + config, + ); let client = AsyncClient::new(stream, Box::new(sender), None); let (mut client, bg) = io_loop.block_on(client).expect("client failed to connect"); diff --git a/crates/native-tls/src/tests.rs b/crates/native-tls/src/tests.rs index fd410bff18..5752491969 100644 --- a/crates/native-tls/src/tests.rs +++ b/crates/native-tls/src/tests.rs @@ -26,8 +26,10 @@ use std::{thread, time}; use futures_util::stream::StreamExt; use native_tls; use native_tls::{Certificate, TlsAcceptor}; +use tokio::net::TcpStream as TokioTcpStream; use tokio::runtime::Runtime; +use trust_dns_proto::iocompat::AsyncIoTokioAsStd; use trust_dns_proto::xfer::SerialMessage; #[allow(clippy::useless_attribute)] @@ -193,7 +195,7 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) { let trust_chain = Certificate::from_der(&root_cert_der).unwrap(); // barrier.wait(); - let mut builder = TlsStreamBuilder::new(); + let mut builder = TlsStreamBuilder::>::new(); builder.add_ca(trust_chain); // fix MTLS diff --git a/crates/native-tls/src/tls_client_stream.rs b/crates/native-tls/src/tls_client_stream.rs index dce0b7dcf7..13164e1b9b 100644 --- a/crates/native-tls/src/tls_client_stream.rs +++ b/crates/native-tls/src/tls_client_stream.rs @@ -15,12 +15,12 @@ use futures_util::TryFutureExt; use native_tls::Certificate; #[cfg(feature = "mtls")] use native_tls::Pkcs12; -use tokio::net::TcpStream as TokioTcpStream; use tokio_native_tls::TlsStream as TokioTlsStream; use trust_dns_proto::error::ProtoError; +use trust_dns_proto::iocompat::AsyncIoStdAsTokio; use trust_dns_proto::iocompat::AsyncIoTokioAsStd; -use trust_dns_proto::tcp::TcpClientStream; +use trust_dns_proto::tcp::{Connect, TcpClientStream}; use trust_dns_proto::xfer::BufDnsStreamHandle; use crate::TlsStreamBuilder; @@ -28,14 +28,15 @@ use crate::TlsStreamBuilder; /// TlsClientStream secure DNS over TCP stream /// /// See TlsClientStreamBuilder::new() -pub type TlsClientStream = TcpClientStream>>; +pub type TlsClientStream = + TcpClientStream>>>; /// Builder for TlsClientStream -pub struct TlsClientStreamBuilder(TlsStreamBuilder); +pub struct TlsClientStreamBuilder(TlsStreamBuilder); -impl TlsClientStreamBuilder { +impl TlsClientStreamBuilder { /// Creates a builder fo the construction of a TlsClientStream - pub fn new() -> TlsClientStreamBuilder { + pub fn new() -> TlsClientStreamBuilder { TlsClientStreamBuilder(TlsStreamBuilder::new()) } @@ -64,7 +65,7 @@ impl TlsClientStreamBuilder { name_server: SocketAddr, dns_name: String, ) -> ( - Pin> + Send>>, + Pin, ProtoError>> + Send>>, BufDnsStreamHandle, ) { let (stream_future, sender) = self.0.build(name_server, dns_name); @@ -81,7 +82,7 @@ impl TlsClientStreamBuilder { } } -impl Default for TlsClientStreamBuilder { +impl Default for TlsClientStreamBuilder { fn default() -> Self { Self::new() } diff --git a/crates/native-tls/src/tls_stream.rs b/crates/native-tls/src/tls_stream.rs index ecc84d97c1..59e7b14da4 100644 --- a/crates/native-tls/src/tls_stream.rs +++ b/crates/native-tls/src/tls_stream.rs @@ -7,23 +7,23 @@ //! Base TlsStream -use std::future::Future; use std::io; use std::net::SocketAddr; use std::pin::Pin; +use std::{future::Future, marker::PhantomData}; use futures_util::TryFutureExt; use native_tls::Protocol::Tlsv12; use native_tls::{Certificate, Identity, TlsConnector}; -use tokio::net::TcpStream as TokioTcpStream; use tokio_native_tls::{TlsConnector as TokioTlsConnector, TlsStream as TokioTlsStream}; -use trust_dns_proto::iocompat::AsyncIoTokioAsStd; -use trust_dns_proto::tcp::{self, TcpStream}; +use trust_dns_proto::iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd}; +use trust_dns_proto::tcp::Connect; +use trust_dns_proto::tcp::TcpStream; use trust_dns_proto::xfer::{BufStreamHandle, StreamReceiver}; /// A TlsStream counterpart to the TcpStream which embeds a secure TlsStream -pub type TlsStream = TcpStream>>; +pub type TlsStream = TcpStream>>>; fn tls_new(certs: Vec, pkcs12: Option) -> io::Result { let mut builder = TlsConnector::builder(); @@ -47,10 +47,10 @@ fn tls_new(certs: Vec, pkcs12: Option) -> io::Result, +pub fn tls_from_stream( + stream: TokioTlsStream>, peer_addr: SocketAddr, -) -> (TlsStream, BufStreamHandle) { +) -> (TlsStream, BufStreamHandle) { let (message_sender, outbound_messages) = BufStreamHandle::create(); let stream = TcpStream::from_stream_with_receiver( @@ -64,17 +64,19 @@ pub fn tls_from_stream( /// A builder for the TlsStream #[derive(Default)] -pub struct TlsStreamBuilder { +pub struct TlsStreamBuilder { ca_chain: Vec, identity: Option, + marker: PhantomData, } -impl TlsStreamBuilder { +impl TlsStreamBuilder { /// Constructs a new TlsStreamBuilder - pub fn new() -> TlsStreamBuilder { + pub fn new() -> TlsStreamBuilder { TlsStreamBuilder { ca_chain: vec![], identity: None, + marker: PhantomData, } } @@ -123,7 +125,7 @@ impl TlsStreamBuilder { dns_name: String, ) -> ( // TODO: change to impl? - Pin> + Send>>, + Pin, io::Error>> + Send>>, BufStreamHandle, ) { let (message_sender, outbound_messages) = BufStreamHandle::create(); @@ -137,17 +139,17 @@ impl TlsStreamBuilder { name_server: SocketAddr, dns_name: String, outbound_messages: StreamReceiver, - ) -> Result { + ) -> Result, io::Error> { use crate::tls_stream; let ca_chain = self.ca_chain.clone(); let identity = self.identity; - let tcp_stream = tcp::tokio::connect(&name_server).await; + let tcp_stream = S::connect(name_server).await; // TODO: for some reason the above wouldn't accept a ? let tcp_stream = match tcp_stream { - Ok(tcp_stream) => tcp_stream, + Ok(tcp_stream) => AsyncIoStdAsTokio(tcp_stream), Err(err) => return Err(err), }; diff --git a/crates/openssl/src/tls_client_stream.rs b/crates/openssl/src/tls_client_stream.rs index 9ddec8a937..a191158ee5 100644 --- a/crates/openssl/src/tls_client_stream.rs +++ b/crates/openssl/src/tls_client_stream.rs @@ -14,23 +14,24 @@ use futures_util::TryFutureExt; #[cfg(feature = "mtls")] use openssl::pkcs12::Pkcs12; use openssl::x509::X509; -use tokio::net::TcpStream as TokioTcpStream; use tokio_openssl::SslStream as TokioTlsStream; use trust_dns_proto::error::ProtoError; +use trust_dns_proto::iocompat::AsyncIoStdAsTokio; use trust_dns_proto::iocompat::AsyncIoTokioAsStd; -use trust_dns_proto::tcp::TcpClientStream; +use trust_dns_proto::tcp::{Connect, TcpClientStream}; use trust_dns_proto::xfer::BufDnsStreamHandle; use super::TlsStreamBuilder; /// A Type definition for the TLS stream -pub type TlsClientStream = TcpClientStream>>; +pub type TlsClientStream = + TcpClientStream>>>; /// A Builder for the TlsClientStream -pub struct TlsClientStreamBuilder(TlsStreamBuilder); +pub struct TlsClientStreamBuilder(TlsStreamBuilder); -impl TlsClientStreamBuilder { +impl TlsClientStreamBuilder { /// Creates a builder for the construction of a TlsClientStream. pub fn new() -> Self { TlsClientStreamBuilder(TlsStreamBuilder::new()) @@ -71,7 +72,7 @@ impl TlsClientStreamBuilder { name_server: SocketAddr, dns_name: String, ) -> ( - Pin> + Send>>, + Pin, ProtoError>> + Send>>, BufDnsStreamHandle, ) { let (stream_future, sender) = self.0.build(name_server, dns_name); @@ -88,7 +89,7 @@ impl TlsClientStreamBuilder { } } -impl Default for TlsClientStreamBuilder { +impl Default for TlsClientStreamBuilder { fn default() -> Self { Self::new() } diff --git a/crates/openssl/src/tls_stream.rs b/crates/openssl/src/tls_stream.rs index 0ed11db89d..5c7bdb8fe8 100644 --- a/crates/openssl/src/tls_stream.rs +++ b/crates/openssl/src/tls_stream.rs @@ -5,10 +5,10 @@ // http://opensource.org/licenses/MIT>, at your option. This file may not be // copied, modified, or distributed except according to those terms. -use std::future::Future; use std::io; use std::net::SocketAddr; use std::pin::Pin; +use std::{future::Future, marker::PhantomData}; use futures_util::{future, TryFutureExt}; use openssl::pkcs12::ParsedPkcs12; @@ -17,11 +17,11 @@ use openssl::ssl::{ConnectConfiguration, SslConnector, SslContextBuilder, SslMet use openssl::stack::Stack; use openssl::x509::store::X509StoreBuilder; use openssl::x509::{X509Ref, X509}; -use tokio::net::TcpStream as TokioTcpStream; use tokio_openssl::{self, SslStream as TokioTlsStream}; -use trust_dns_proto::iocompat::AsyncIoTokioAsStd; -use trust_dns_proto::tcp::{self, TcpStream}; +use trust_dns_proto::iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd}; +use trust_dns_proto::tcp::Connect; +use trust_dns_proto::tcp::TcpStream; use trust_dns_proto::xfer::BufStreamHandle; pub trait TlsIdentityExt { @@ -57,7 +57,8 @@ impl TlsIdentityExt for SslContextBuilder { } /// A TlsStream counterpart to the TcpStream which embeds a secure TlsStream -pub type TlsStream = TcpStream>>; +pub type TlsStream = TcpStream>>; +pub type CompatTlsStream = TlsStream>; fn new(certs: Vec, pkcs12: Option) -> io::Result { let mut tls = SslConnector::builder(SslMethod::tls()).map_err(|e| { @@ -115,21 +116,21 @@ fn new(certs: Vec, pkcs12: Option) -> io::Result>, +pub fn tls_stream_from_existing_tls_stream( + stream: AsyncIoTokioAsStd>>, peer_addr: SocketAddr, -) -> (TlsStream, BufStreamHandle) { +) -> (CompatTlsStream, BufStreamHandle) { let (message_sender, outbound_messages) = BufStreamHandle::create(); let stream = TcpStream::from_stream_with_receiver(stream, peer_addr, outbound_messages); (stream, message_sender) } -async fn connect_tls( +async fn connect_tls( tls_config: ConnectConfiguration, dns_name: String, name_server: SocketAddr, -) -> Result, io::Error> { - let tcp = tcp::tokio::connect(&name_server).await.map_err(|e| { +) -> Result>, io::Error> { + let tcp = S::connect(name_server).await.map_err(|e| { io::Error::new( io::ErrorKind::ConnectionRefused, format!("tls error: {}", e), @@ -137,7 +138,7 @@ async fn connect_tls( })?; let mut stream = tls_config .into_ssl(&dns_name) - .and_then(|ssl| TokioTlsStream::new(ssl, tcp)) + .and_then(|ssl| TokioTlsStream::new(ssl, AsyncIoStdAsTokio(tcp))) .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("tls error: {}", e)))?; Pin::new(&mut stream).connect().await.map_err(|e| { io::Error::new( @@ -150,17 +151,19 @@ async fn connect_tls( /// A builder for the TlsStream #[derive(Default)] -pub struct TlsStreamBuilder { +pub struct TlsStreamBuilder { ca_chain: Vec, identity: Option, + marker: PhantomData, } -impl TlsStreamBuilder { +impl TlsStreamBuilder { /// A builder for associating trust information to the `TlsStream`. pub fn new() -> Self { TlsStreamBuilder { ca_chain: vec![], identity: None, + marker: PhantomData, } } @@ -208,7 +211,7 @@ impl TlsStreamBuilder { name_server: SocketAddr, dns_name: String, ) -> ( - Pin> + Send>>, + Pin, io::Error>> + Send>>, BufStreamHandle, ) { let (message_sender, outbound_messages) = BufStreamHandle::create(); diff --git a/crates/openssl/tests/openssl_tests.rs b/crates/openssl/tests/openssl_tests.rs index bffd2015c7..3995820fb8 100644 --- a/crates/openssl/tests/openssl_tests.rs +++ b/crates/openssl/tests/openssl_tests.rs @@ -19,6 +19,7 @@ use openssl::pkey::*; use openssl::ssl::*; use openssl::x509::store::X509StoreBuilder; use openssl::x509::*; +use tokio::net::TcpStream as TokioTcpStream; use tokio::runtime::Runtime; use openssl::asn1::*; @@ -29,6 +30,8 @@ use openssl::pkcs12::*; use openssl::rsa::*; use openssl::x509::extension::*; +use trust_dns_proto::iocompat::AsyncIoTokioAsStd; +use trust_dns_proto::tcp::Connect; use trust_dns_proto::xfer::SerialMessage; use trust_dns_openssl::TlsStreamBuilder; @@ -198,7 +201,7 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) { let trust_chain = X509::from_der(&root_cert_der).unwrap(); // barrier.wait(); - let mut builder = TlsStreamBuilder::new(); + let mut builder = TlsStreamBuilder::>::new(); builder.add_ca(trust_chain); if mtls { @@ -228,11 +231,11 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) { } #[allow(unused_variables)] -fn config_mtls( +fn config_mtls( root_pkey: &PKey, root_name: &X509Name, root_cert: &X509, - builder: &mut TlsStreamBuilder, + builder: &mut TlsStreamBuilder, ) { #[cfg(feature = "mtls")] { diff --git a/crates/resolver/src/name_server/connection_provider.rs b/crates/resolver/src/name_server/connection_provider.rs index 9dd43e17f2..178c65bc73 100644 --- a/crates/resolver/src/name_server/connection_provider.rs +++ b/crates/resolver/src/name_server/connection_provider.rs @@ -153,9 +153,10 @@ where #[cfg(feature = "dns-over-rustls")] let (stream, handle) = - { crate::tls::new_tls_stream(socket_addr, tls_dns_name, client_config) }; + { crate::tls::new_tls_stream::(socket_addr, tls_dns_name, client_config) }; #[cfg(not(feature = "dns-over-rustls"))] - let (stream, handle) = { crate::tls::new_tls_stream(socket_addr, tls_dns_name) }; + let (stream, handle) = + { crate::tls::new_tls_stream::(socket_addr, tls_dns_name) }; let dns_conn = DnsMultiplexer::with_timeout( stream, @@ -205,6 +206,11 @@ where } } +#[cfg(feature = "dns-over-tls")] +/// Predefined type for TLS client stream +type TlsClientStream = + TcpClientStream>>>; + /// The variants of all supported connections for the Resolver #[allow(clippy::large_enum_variant, clippy::type_complexity)] pub(crate) enum ConnectionConnect { @@ -228,22 +234,17 @@ pub(crate) enum ConnectionConnect { Box< dyn Future< Output = Result< - TcpClientStream< - AsyncIoTokioAsStd>, - >, + TlsClientStream<::Tcp>, ProtoError, >, > + Send + 'static, >, >, - TcpClientStream>>, - NoopMessageFinalizer, - >, - DnsMultiplexer< - TcpClientStream>>, + TlsClientStream<::Tcp>, NoopMessageFinalizer, >, + DnsMultiplexer::Tcp>, NoopMessageFinalizer>, TokioTime, >, ), diff --git a/crates/resolver/src/tls/dns_over_native_tls.rs b/crates/resolver/src/tls/dns_over_native_tls.rs index e21fc52c90..b4d30fd319 100644 --- a/crates/resolver/src/tls/dns_over_native_tls.rs +++ b/crates/resolver/src/tls/dns_over_native_tls.rs @@ -17,12 +17,14 @@ use proto::error::ProtoError; use proto::BufDnsStreamHandle; use trust_dns_native_tls::{TlsClientStream, TlsClientStreamBuilder}; +use crate::name_server::RuntimeProvider; + #[allow(clippy::type_complexity)] -pub(crate) fn new_tls_stream( +pub(crate) fn new_tls_stream( socket_addr: SocketAddr, dns_name: String, ) -> ( - Pin> + Send>>, + Pin, ProtoError>> + Send>>, BufDnsStreamHandle, ) { let tls_builder = TlsClientStreamBuilder::new(); diff --git a/crates/resolver/src/tls/dns_over_openssl.rs b/crates/resolver/src/tls/dns_over_openssl.rs index 94aa9a404c..4289ce021c 100644 --- a/crates/resolver/src/tls/dns_over_openssl.rs +++ b/crates/resolver/src/tls/dns_over_openssl.rs @@ -17,12 +17,14 @@ use proto::error::ProtoError; use proto::BufDnsStreamHandle; use trust_dns_openssl::{TlsClientStream, TlsClientStreamBuilder}; +use crate::name_server::RuntimeProvider; + #[allow(clippy::type_complexity)] -pub(crate) fn new_tls_stream( +pub(crate) fn new_tls_stream( socket_addr: SocketAddr, dns_name: String, ) -> ( - Pin> + Send>>, + Pin, ProtoError>> + Send>>, BufDnsStreamHandle, ) { let tls_builder = TlsClientStreamBuilder::new(); diff --git a/crates/resolver/src/tls/dns_over_rustls.rs b/crates/resolver/src/tls/dns_over_rustls.rs index d977d284ae..e95d5db9f7 100644 --- a/crates/resolver/src/tls/dns_over_rustls.rs +++ b/crates/resolver/src/tls/dns_over_rustls.rs @@ -20,6 +20,7 @@ use proto::BufDnsStreamHandle; use trust_dns_rustls::{tls_client_connect, TlsClientStream}; use crate::config::TlsClientConfig; +use crate::name_server::RuntimeProvider; const ALPN_H2: &[u8] = b"h2"; @@ -40,12 +41,12 @@ lazy_static! { } #[allow(clippy::type_complexity)] -pub(crate) fn new_tls_stream( +pub(crate) fn new_tls_stream( socket_addr: SocketAddr, dns_name: String, client_config: Option, ) -> ( - Pin> + Send>>, + Pin, ProtoError>> + Send>>, BufDnsStreamHandle, ) { let client_config = client_config.map_or_else( diff --git a/crates/rustls/src/tests.rs b/crates/rustls/src/tests.rs index 8167ea4029..8e4b6f73e1 100644 --- a/crates/rustls/src/tests.rs +++ b/crates/rustls/src/tests.rs @@ -26,8 +26,10 @@ use openssl::x509::*; use futures_util::stream::StreamExt; use rustls::Certificate; use rustls::ClientConfig; +use tokio::net::TcpStream as TokioTcpStream; use tokio::runtime::Runtime; +use trust_dns_proto::iocompat::AsyncIoTokioAsStd; use trust_dns_proto::xfer::SerialMessage; use crate::tls_connect; @@ -214,7 +216,11 @@ fn tls_client_stream_test(server_addr: IpAddr, mtls: bool) { // config_mtls(&root_pkey, &root_name, &root_cert, &mut builder); // } - let (stream, mut sender) = tls_connect(server_addr, dns_name.to_string(), Arc::new(config)); + let (stream, mut sender) = tls_connect::>( + server_addr, + dns_name.to_string(), + Arc::new(config), + ); // TODO: there is a race failure here... a race with the server thread most likely... let mut stream = io_loop.block_on(stream).expect("run failed to get stream"); diff --git a/crates/rustls/src/tls_client_stream.rs b/crates/rustls/src/tls_client_stream.rs index ff75258ad2..db560235ea 100644 --- a/crates/rustls/src/tls_client_stream.rs +++ b/crates/rustls/src/tls_client_stream.rs @@ -14,18 +14,18 @@ use std::sync::Arc; use futures_util::TryFutureExt; use rustls::ClientConfig; -use tokio::net::TcpStream as TokioTcpStream; use trust_dns_proto::error::ProtoError; +use trust_dns_proto::iocompat::AsyncIoStdAsTokio; use trust_dns_proto::iocompat::AsyncIoTokioAsStd; -use trust_dns_proto::tcp::TcpClientStream; +use trust_dns_proto::tcp::{Connect, TcpClientStream}; use trust_dns_proto::xfer::BufDnsStreamHandle; use crate::tls_stream::tls_connect; /// Type of TlsClientStream used with Rustls -pub type TlsClientStream = - TcpClientStream>>; +pub type TlsClientStream = + TcpClientStream>>>; /// Creates a new TlsStream to the specified name_server /// @@ -34,12 +34,12 @@ pub type TlsClientStream = /// * `name_server` - IP and Port for the remote DNS resolver /// * `dns_name` - The DNS name, Subject Public Key Info (SPKI) name, as associated to a certificate #[allow(clippy::type_complexity)] -pub fn tls_client_connect( +pub fn tls_client_connect( name_server: SocketAddr, dns_name: String, client_config: Arc, ) -> ( - Pin> + Send + Unpin>>, + Pin, ProtoError>> + Send + Unpin>>, BufDnsStreamHandle, ) { let (stream_future, sender) = tls_connect(name_server, dns_name, client_config); diff --git a/crates/rustls/src/tls_stream.rs b/crates/rustls/src/tls_stream.rs index 0552dd3dbf..83677018ca 100644 --- a/crates/rustls/src/tls_stream.rs +++ b/crates/rustls/src/tls_stream.rs @@ -20,12 +20,13 @@ use tokio::net::TcpStream as TokioTcpStream; use tokio_rustls::TlsConnector; use webpki::{DNSName, DNSNameRef}; -use trust_dns_proto::iocompat::AsyncIoTokioAsStd; -use trust_dns_proto::tcp::{self, DnsTcpStream, TcpStream}; +use trust_dns_proto::iocompat::{AsyncIoStdAsTokio, AsyncIoTokioAsStd}; +use trust_dns_proto::tcp::Connect; +use trust_dns_proto::tcp::{DnsTcpStream, TcpStream}; use trust_dns_proto::xfer::{BufStreamHandle, StreamReceiver}; /// Predefined type for abstracting the TlsClientStream with TokioTls -pub type TokioTlsClientStream = tokio_rustls::client::TlsStream; +pub type TokioTlsClientStream = tokio_rustls::client::TlsStream>; /// Predefined type for abstracting the TlsServerStream with TokioTls pub type TokioTlsServerStream = tokio_rustls::server::TlsStream; @@ -71,7 +72,7 @@ pub fn tls_from_stream( /// * `name_server` - IP and Port for the remote DNS resolver /// * `dns_name` - The DNS name, Subject Public Key Info (SPKI) name, as associated to a certificate #[allow(clippy::type_complexity)] -pub fn tls_connect( +pub fn tls_connect( name_server: SocketAddr, dns_name: String, client_config: Arc, @@ -79,7 +80,10 @@ pub fn tls_connect( Pin< Box< dyn Future< - Output = Result>, io::Error>, + Output = Result< + TlsStream>>, + io::Error, + >, > + Send, >, >, @@ -101,20 +105,20 @@ pub fn tls_connect( (stream, message_sender) } -async fn connect_tls( +async fn connect_tls( tls_connector: TlsConnector, name_server: SocketAddr, dns_name: String, outbound_messages: StreamReceiver, -) -> io::Result>> { - let tcp = tcp::tokio::connect(&name_server).await?; +) -> io::Result>>> { + let tcp = S::connect(name_server).await?; let dns_name = DNSNameRef::try_from_ascii_str(&dns_name) .map(DNSName::from) .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "bad dns_name"))?; let s = tls_connector - .connect(dns_name.as_ref(), tcp) + .connect(dns_name.as_ref(), AsyncIoStdAsTokio(tcp)) .map_err(|e| { io::Error::new( io::ErrorKind::ConnectionRefused, diff --git a/tests/integration-tests/src/tls_client_connection.rs b/tests/integration-tests/src/tls_client_connection.rs index 59dd9cca70..098bbbf54c 100644 --- a/tests/integration-tests/src/tls_client_connection.rs +++ b/tests/integration-tests/src/tls_client_connection.rs @@ -8,15 +8,16 @@ //! TLS based DNS client connection for Client impls //! TODO: This modules was moved from trust-dns-rustls, it really doesn't need to exist if tests are refactored... -use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; +use std::{marker::PhantomData, net::SocketAddr}; use futures::Future; use trust_dns_client::client::ClientConnection; use trust_dns_client::rr::dnssec::Signer; use trust_dns_proto::error::ProtoError; +use trust_dns_proto::tcp::Connect; use trust_dns_proto::xfer::{DnsMultiplexer, DnsMultiplexerConnect}; use rustls::ClientConfig; @@ -25,14 +26,15 @@ use trust_dns_rustls::{tls_client_connect, TlsClientStream}; /// Tls client connection /// /// Use with `trust_dns_client::client::Client` impls -pub struct TlsClientConnection { +pub struct TlsClientConnection { name_server: SocketAddr, dns_name: String, client_config: Arc, + marker: PhantomData, } #[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))] -impl TlsClientConnection { +impl TlsClientConnection { pub fn new( name_server: SocketAddr, dns_name: String, @@ -42,16 +44,17 @@ impl TlsClientConnection { name_server, dns_name, client_config, + marker: PhantomData, } } } #[allow(clippy::type_complexity)] -impl ClientConnection for TlsClientConnection { - type Sender = DnsMultiplexer; +impl ClientConnection for TlsClientConnection { + type Sender = DnsMultiplexer, Signer>; type SenderFuture = DnsMultiplexerConnect< - Pin> + Send>>, - TlsClientStream, + Pin, ProtoError>> + Send>>, + TlsClientStream, Signer, >; diff --git a/tests/integration-tests/tests/server_future_tests.rs b/tests/integration-tests/tests/server_future_tests.rs index 8efdd9b27d..50e84bd77a 100644 --- a/tests/integration-tests/tests/server_future_tests.rs +++ b/tests/integration-tests/tests/server_future_tests.rs @@ -194,7 +194,11 @@ fn lazy_tcp_client(ipaddr: SocketAddr) -> TcpClientConnection { } #[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))] -fn lazy_tls_client(ipaddr: SocketAddr, dns_name: String, cert_der: Vec) -> TlsClientConnection { +fn lazy_tls_client( + ipaddr: SocketAddr, + dns_name: String, + cert_der: Vec, +) -> TlsClientConnection> { use rustls::{Certificate, ClientConfig}; let trust_chain = Certificate(cert_der);