From e088e09e5916fe8720545937177730e884f9f4b6 Mon Sep 17 00:00:00 2001 From: 0xffffharry <95022881+0xffffharry@users.noreply.github.com> Date: Tue, 16 Apr 2024 05:07:00 +0000 Subject: [PATCH] Make H3ClientStream Clonable --- crates/proto/src/h3/h3_client_stream.rs | 45 ++++++++++++++++++------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/crates/proto/src/h3/h3_client_stream.rs b/crates/proto/src/h3/h3_client_stream.rs index 591edc3c82..a4783b6409 100644 --- a/crates/proto/src/h3/h3_client_stream.rs +++ b/crates/proto/src/h3/h3_client_stream.rs @@ -6,7 +6,7 @@ // copied, modified, or distributed except according to those terms. use std::fmt::{self, Display}; -use std::future::Future; +use std::future::{self, Future}; use std::net::SocketAddr; use std::pin::Pin; use std::str::FromStr; @@ -16,12 +16,13 @@ use std::task::{Context, Poll}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use futures_util::future::FutureExt; use futures_util::stream::Stream; -use h3::client::{Connection, SendRequest}; +use h3::client::SendRequest; use h3_quinn::OpenStreams; use http::header::{self, CONTENT_LENGTH}; use quinn::{ClientConfig, Endpoint, EndpointConfig, TransportConfig}; use rustls::ClientConfig as TlsClientConfig; -use tracing::debug; +use tokio::sync::mpsc; +use tracing::{debug, warn}; use crate::error::ProtoError; use crate::http::Version; @@ -34,13 +35,14 @@ use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream}; use super::ALPN_H3; /// A DNS client connection for DNS-over-HTTP/3 +#[derive(Clone)] #[must_use = "futures do nothing unless polled"] pub struct H3ClientStream { // Corresponds to the dns-name of the HTTP/3 server name_server_name: Arc, name_server: SocketAddr, - driver: Connection, send_request: SendRequest, + shutdown_tx: mpsc::Sender<()>, is_shutdown: bool, } @@ -264,19 +266,19 @@ impl DnsRequestSender for H3ClientStream { impl Stream for H3ClientStream { type Item = Result<(), ProtoError>; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { if self.is_shutdown { return Poll::Ready(None); } // just checking if the connection is ok - match self.driver.poll_close(cx) { - Poll::Ready(Ok(())) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => Poll::Ready(Some(Err(ProtoError::from(format!( - "h3 stream errored: {e}", - ))))), + if self.shutdown_tx.is_closed() { + return Poll::Ready(Some(Err(ProtoError::from( + "h3 connection is already shutdown", + )))); } + + Poll::Ready(Some(Ok(()))) } } @@ -398,15 +400,32 @@ impl H3ClientStreamBuilder { }; let h3_connection = h3_quinn::Connection::new(quic_connection); - let (driver, send_request) = h3::client::new(h3_connection) + let (mut driver, send_request) = h3::client::new(h3_connection) .await .map_err(|e| ProtoError::from(format!("h3 connection failed: {e}")))?; + + let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); + + // TODO: hand this back for others to run rather than spawning here? + debug!("h3 connection is ready: {}", name_server); + tokio::spawn(async move { + tokio::select! { + res = future::poll_fn(|cx| driver.poll_close(cx)) => { + res.map_err(|e| warn!("h3 connection failed: {e}")) + } + _ = shutdown_rx.recv() => { + debug!("h3 connection is shutting down: {}", name_server); + Ok(()) + } + } + }); + Ok(H3ClientStream { name_server_name: Arc::from(dns_name), name_server, - driver, send_request, + shutdown_tx, is_shutdown: false, }) }