diff --git a/crates/server/src/server/https_handler.rs b/crates/server/src/server/https_handler.rs index ab5118a8a8..18cdbbe803 100644 --- a/crates/server/src/server/https_handler.rs +++ b/crates/server/src/server/https_handler.rs @@ -5,11 +5,11 @@ // http://opensource.org/licenses/MIT>, at your option. This file may not be // copied, modified, or distributed except according to those terms. -use std::{io, net::SocketAddr, sync::Arc}; - use bytes::{Bytes, BytesMut}; use futures_util::lock::Mutex; use h2::server; +use std::future::Future; +use std::{io, net::SocketAddr, sync::Arc}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, warn}; use trust_dns_proto::rr::Record; @@ -28,6 +28,7 @@ pub(crate) async fn h2_handler( io: I, src_addr: SocketAddr, dns_hostname: Option>, + drain: impl Future, ) where T: RequestHandler, I: AsyncRead + AsyncWrite + Unpin, @@ -45,13 +46,22 @@ pub(crate) async fn h2_handler( // Accept all inbound HTTP/2.0 streams sent over the // connection. - while let Some(next_request) = h2.accept().await { - let (request, respond) = match next_request { - Ok(next_request) => next_request, - Err(err) => { - warn!("error accepting request {}: {}", src_addr, err); - return; - } + loop { + let next_request = tokio::select! { + result = h2.accept() => match result { + Some(Ok(next_request)) => next_request, + Some(Err(err)) => { + warn!("error accepting request {}: {}", src_addr, err); + return; + } + None => { + return; + } + }, + _ = drain => { + // A graceful shutdown was initiated. + return + }, }; debug!("Received request: {:#?}", request); @@ -80,7 +90,7 @@ async fn handle_request( } #[derive(Clone)] -struct HttpsResponseHandle(Arc>>); +struct HttpsResponseHandle(Arc>>); #[async_trait::async_trait] impl ResponseHandler for HttpsResponseHandle { diff --git a/crates/server/src/server/quic_handler.rs b/crates/server/src/server/quic_handler.rs index abc4a3ab7c..547473fd29 100644 --- a/crates/server/src/server/quic_handler.rs +++ b/crates/server/src/server/quic_handler.rs @@ -5,10 +5,10 @@ // http://opensource.org/licenses/MIT>, at your option. This file may not be // copied, modified, or distributed except according to those terms. -use std::{io, net::SocketAddr, sync::Arc}; - use bytes::{Bytes, BytesMut}; use futures_util::lock::Mutex; +use std::future::Future; +use std::{io, net::SocketAddr, sync::Arc}; use tracing::{debug, warn}; use trust_dns_proto::{ error::ProtoError, @@ -30,6 +30,7 @@ pub(crate) async fn quic_handler( mut quic_streams: QuicStreams, src_addr: SocketAddr, _dns_hostname: Option>, + drain: impl Future, ) -> Result<(), ProtoError> where T: RequestHandler, @@ -38,13 +39,22 @@ where let mut max_requests = 100u32; // Accept all inbound quic streams sent over the connection. - while let Some(next_request) = quic_streams.next().await { - let mut request_stream = match next_request { - Ok(next_request) => next_request, - Err(err) => { - warn!("error accepting request {}: {}", src_addr, err); - return Err(err); - } + loop { + let next_request = tokio::select! { + result = quic_streams.next() => match result { + Some(Ok(next_request)) => next_request, + Some(Err(err)) => { + warn!("error accepting request {}: {}", src_addr, err); + return Err(err); + } + None => { + break; + } + }, + _ = drain => { + // A graceful shutdown was initiated. + break; + }, }; let request = request_stream.receive_bytes().await?; diff --git a/crates/server/src/server/server_future.rs b/crates/server/src/server/server_future.rs index 0f75068747..3a8d9b49b0 100644 --- a/crates/server/src/server/server_future.rs +++ b/crates/server/src/server/server_future.rs @@ -4,6 +4,10 @@ // http://apache.org/licenses/LICENSE-2.0> or the MIT license , at your option. This file may not be // copied, modified, or distributed except according to those terms. +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::AtomicBool; +use std::task::{Context, Poll}; use std::{ io, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, @@ -14,7 +18,7 @@ use std::{ use futures_util::{FutureExt, StreamExt}; #[cfg(feature = "dns-over-rustls")] use rustls::{Certificate, PrivateKey}; -use tokio::{net, task::JoinSet}; +use tokio::task::JoinSet; use tracing::{debug, info, warn}; use trust_dns_proto::{op::MessageType, rr::Record}; @@ -40,6 +44,7 @@ use crate::{ pub struct ServerFuture { handler: Arc, join_set: JoinSet>, + drain: Drain, } impl ServerFuture { @@ -48,25 +53,27 @@ impl ServerFuture { Self { handler: Arc::new(handler), join_set: JoinSet::new(), + drain: Drain::new(), } } /// Register a UDP socket. Should be bound before calling this function. - pub fn register_socket(&mut self, socket: net::UdpSocket) { + pub fn register_socket(&mut self, socket: tokio::net::UdpSocket) { debug!("registering udp: {:?}", socket); // create the new UdpStream, the IP address isn't relevant, and ideally goes essentially no where. // the address used is acquired from the inbound queries - let (mut buf_stream, stream_handle) = + let (stream, stream_handle) = UdpStream::with_bound(socket, ([127, 255, 255, 254], 0).into()); - //let request_stream = RequestStream::new(buf_stream, stream_handle); + let mut stream = stream.take_until(self.drain.clone()); let handler = self.handler.clone(); // this spawns a ForEach future which handles all the requests into a Handler. + let drain = self.drain.clone(); self.join_set.spawn({ async move { let mut inner_join_set = JoinSet::new(); - while let Some(message) = buf_stream.next().await { + while let Some(message) = stream.next().await { let message = match message { Err(e) => { warn!("error receiving message on udp_socket: {}", e); @@ -92,22 +99,25 @@ impl ServerFuture { let stream_handle = stream_handle.with_remote_addr(src_addr); inner_join_set.spawn(async move { - self::handle_raw_request(message, Protocol::Udp, handler, stream_handle) - .await; + handle_raw_request(message, Protocol::Udp, handler, stream_handle).await; }); - - reap_tasks(&mut inner_join_set); } - // TODO: let's consider capturing all the initial configuration details so that the socket could be recreated... - Err(ProtoError::from("unexpected close of UDP socket")) + reap_tasks(&mut inner_join_set); + + if drain.is_shutdown() { + Ok(()) + } else { + // TODO: let's consider capturing all the initial configuration details so that the socket could be recreated... + Err(ProtoError::from("unexpected close of UDP socket")) + } } }); } /// Register a UDP socket. Should be bound before calling this function. pub fn register_socket_std(&mut self, socket: std::net::UdpSocket) -> io::Result<()> { - self.register_socket(net::UdpSocket::from_std(socket)?); + self.register_socket(tokio::net::UdpSocket::from_std(socket)?); Ok(()) } @@ -123,73 +133,78 @@ impl ServerFuture { /// requests within this time period will be closed. In the future it should be /// possible to create long-lived queries, but these should be from trusted sources /// only, this would require some type of whitelisting. - pub fn register_listener(&mut self, listener: net::TcpListener, timeout: Duration) { + pub fn register_listener(&mut self, listener: tokio::net::TcpListener, timeout: Duration) { debug!("register tcp: {:?}", listener); let handler = self.handler.clone(); // for each incoming request... - self.join_set.spawn({ - async move { - let mut inner_join_set = JoinSet::new(); - loop { - let tcp_stream = listener.accept().await; - let (tcp_stream, src_addr) = match tcp_stream { + let drain = self.drain.clone(); + self.join_set.spawn(async move { + let mut inner_join_set = JoinSet::new(); + loop { + let (tcp_stream, src_addr) = tokio::select! { + tcp_stream = listener.accept() => match tcp_stream { Ok((t, s)) => (t, s), Err(e) => { debug!("error receiving TCP tcp_stream error: {}", e); continue; - } - }; + }, + }, + _ = drain.clone() => { + // A graceful shutdown was initiated. Break out of the loop. + break; + }, + }; + + // verify that the src address is safe for responses + if let Err(e) = sanitize_src_address(src_addr) { + warn!( + "address can not be responded to {src_addr}: {e}", + src_addr = src_addr, + e = e + ); + continue; + } - // verify that the src address is safe for responses - if let Err(e) = sanitize_src_address(src_addr) { - warn!( - "address can not be responded to {src_addr}: {e}", - src_addr = src_addr, - e = e - ); - continue; - } + let handler = handler.clone(); - let handler = handler.clone(); + // and spawn to the io_loop + inner_join_set.spawn(async move { + debug!("accepted request from: {}", src_addr); + // take the created stream... + let (buf_stream, stream_handle) = + TcpStream::from_stream(AsyncIoTokioAsStd(tcp_stream), src_addr); + let mut timeout_stream = TimeoutStream::new(buf_stream, timeout); - // and spawn to the io_loop - inner_join_set.spawn(async move { - debug!("accepted request from: {}", src_addr); - // take the created stream... - let (buf_stream, stream_handle) = - TcpStream::from_stream(AsyncIoTokioAsStd(tcp_stream), src_addr); - let mut timeout_stream = TimeoutStream::new(buf_stream, timeout); - //let request_stream = RequestStream::new(timeout_stream, stream_handle); - - while let Some(message) = timeout_stream.next().await { - let message = match message { - Ok(message) => message, - Err(e) => { - debug!( - "error in TCP request_stream src: {} error: {}", - src_addr, e - ); - // we're going to bail on this connection... - return; - } - }; - - // we don't spawn here to limit clients from getting too many resources - self::handle_raw_request( - message, - Protocol::Tcp, - handler.clone(), - stream_handle.clone(), - ) - .await; - } - }); + while let Some(message) = timeout_stream.next().await { + let message = match message { + Ok(message) => message, + Err(e) => { + debug!( + "error in TCP request_stream src: {} error: {}", + src_addr, e + ); + // we're going to bail on this connection... + return; + } + }; - reap_tasks(&mut inner_join_set); - } + // we don't spawn here to limit clients from getting too many resources + handle_raw_request( + message, + Protocol::Tcp, + handler.clone(), + stream_handle.clone(), + ) + .await; + } + }); } + + // We're exiting, reap any pending tasks. + reap_tasks(&mut inner_join_set); + Ok(()) }); } @@ -210,7 +225,7 @@ impl ServerFuture { listener: std::net::TcpListener, timeout: Duration, ) -> io::Result<()> { - self.register_listener(net::TcpListener::from_std(listener)?, timeout); + self.register_listener(tokio::net::TcpListener::from_std(listener)?, timeout); Ok(()) } @@ -251,84 +266,89 @@ impl ServerFuture { let tls_acceptor = Box::pin(tls_server::new_acceptor(cert, chain, key)?); // for each incoming request... - self.join_set.spawn({ - async move { - let mut inner_join_set = JoinSet::new(); - loop { - let tcp_stream = listener.accept().await; - let (tcp_stream, src_addr) = match tcp_stream { + let drain = self.drain.clone(); + self.join_set.spawn(async move { + let mut inner_join_set = JoinSet::new(); + loop { + let (tcp_stream, src_addr) = tokio::select! { + tcp_stream = listener.accept() => match tcp_stream { Ok((t, s)) => (t, s), Err(e) => { debug!("error receiving TLS tcp_stream error: {}", e); continue; - } - }; - - // verify that the src address is safe for responses - if let Err(e) = sanitize_src_address(src_addr) { - warn!( - "address can not be responded to {src_addr}: {e}", - src_addr = src_addr, - e = e - ); - continue; - } + }, + }, + _ = drain.clone() => { + // A graceful shutdown was initiated. Break out of the loop. + break; + }, + }; + + // verify that the src address is safe for responses + if let Err(e) = sanitize_src_address(src_addr) { + warn!( + "address can not be responded to {src_addr}: {e}", + src_addr = src_addr, + e = e + ); + continue; + } - let handler = handler.clone(); - let tls_acceptor = tls_acceptor.clone(); + let handler = handler.clone(); + let tls_acceptor = tls_acceptor.clone(); - // kick out to a different task immediately, let them do the TLS handshake - inner_join_set.spawn(async move { - debug!("starting TLS request from: {}", src_addr); + // kick out to a different task immediately, let them do the TLS handshake + inner_join_set.spawn(async move { + debug!("starting TLS request from: {}", src_addr); - // perform the TLS - let mut tls_stream = match Ssl::new(tls_acceptor.context()) - .and_then(|ssl| TokioSslStream::new(ssl, tcp_stream)) - { - Ok(tls_stream) => tls_stream, - Err(e) => { - debug!("tls handshake src: {} error: {}", src_addr, e); - return (); - } - }; - match Pin::new(&mut tls_stream).accept().await { - Ok(()) => {} + // perform the TLS + let mut tls_stream = match Ssl::new(tls_acceptor.context()) + .and_then(|ssl| TokioSslStream::new(ssl, tcp_stream)) + { + Ok(tls_stream) => tls_stream, + Err(e) => { + debug!("tls handshake src: {} error: {}", src_addr, e); + return (); + } + }; + match Pin::new(&mut tls_stream).accept().await { + Ok(()) => {} + Err(e) => { + debug!("tls handshake src: {} error: {}", src_addr, e); + return (); + } + }; + debug!("accepted TLS request from: {}", src_addr); + let (buf_stream, stream_handle) = + TlsStream::from_stream(AsyncIoTokioAsStd(tls_stream), src_addr); + let mut timeout_stream = TimeoutStream::new(buf_stream, timeout); + while let Some(message) = timeout_stream.next().await { + let message = match message { + Ok(message) => message, Err(e) => { - debug!("tls handshake src: {} error: {}", src_addr, e); + debug!( + "error in TLS request_stream src: {:?} error: {}", + src_addr, e + ); + + // kill this connection return (); } }; - debug!("accepted TLS request from: {}", src_addr); - let (buf_stream, stream_handle) = - TlsStream::from_stream(AsyncIoTokioAsStd(tls_stream), src_addr); - let mut timeout_stream = TimeoutStream::new(buf_stream, timeout); - while let Some(message) = timeout_stream.next().await { - let message = match message { - Ok(message) => message, - Err(e) => { - debug!( - "error in TLS request_stream src: {:?} error: {}", - src_addr, e - ); - - // kill this connection - return (); - } - }; - - self::handle_raw_request( - message, - Protocol::Tls, - handler.clone(), - stream_handle.clone(), - ) - .await; - } - }); - reap_tasks(&mut inner_join_set); - } + self::handle_raw_request( + message, + Protocol::Tls, + handler.clone(), + stream_handle.clone(), + ) + .await; + } + }); } + + reap_tasks(&mut inner_join_set); + Ok(()) }); Ok(()) @@ -382,7 +402,7 @@ impl ServerFuture { #[cfg_attr(docsrs, doc(cfg(feature = "dns-over-rustls")))] pub fn register_tls_listener( &mut self, - listener: net::TcpListener, + listener: tokio::net::TcpListener, timeout: Duration, certificate_and_key: (Vec, PrivateKey), ) -> io::Result<()> { @@ -403,76 +423,81 @@ impl ServerFuture { let tls_acceptor = TlsAcceptor::from(Arc::new(tls_acceptor)); // for each incoming request... - self.join_set.spawn({ - async move { - let mut inner_join_set = JoinSet::new(); - loop { - let tcp_stream = listener.accept().await; - let (tcp_stream, src_addr) = match tcp_stream { + let drain = self.drain.clone(); + self.join_set.spawn(async move { + let mut inner_join_set = JoinSet::new(); + loop { + let (tcp_stream, src_addr) = tokio::select! { + tcp_stream = listener.accept() => match tcp_stream { Ok((t, s)) => (t, s), Err(e) => { debug!("error receiving TLS tcp_stream error: {}", e); continue; - } - }; - - // verify that the src address is safe for responses - if let Err(e) = sanitize_src_address(src_addr) { - warn!( - "address can not be responded to {src_addr}: {e}", - src_addr = src_addr, - e = e - ); - continue; - } + }, + }, + _ = drain.clone() => { + // A graceful shutdown was initiated. Break out of the loop. + break; + }, + }; + + // verify that the src address is safe for responses + if let Err(e) = sanitize_src_address(src_addr) { + warn!( + "address can not be responded to {src_addr}: {e}", + src_addr = src_addr, + e = e + ); + continue; + } - let handler = handler.clone(); - let tls_acceptor = tls_acceptor.clone(); + let handler = handler.clone(); + let tls_acceptor = tls_acceptor.clone(); - // kick out to a different task immediately, let them do the TLS handshake - inner_join_set.spawn(async move { - debug!("starting TLS request from: {}", src_addr); + // kick out to a different task immediately, let them do the TLS handshake + inner_join_set.spawn(async move { + debug!("starting TLS request from: {}", src_addr); - // perform the TLS - let tls_stream = tls_acceptor.accept(tcp_stream).await; + // perform the TLS + let tls_stream = tls_acceptor.accept(tcp_stream).await; - let tls_stream = match tls_stream { - Ok(tls_stream) => AsyncIoTokioAsStd(tls_stream), + let tls_stream = match tls_stream { + Ok(tls_stream) => AsyncIoTokioAsStd(tls_stream), + Err(e) => { + debug!("tls handshake src: {} error: {}", src_addr, e); + return; + } + }; + debug!("accepted TLS request from: {}", src_addr); + let (buf_stream, stream_handle) = tls_from_stream(tls_stream, src_addr); + let mut timeout_stream = TimeoutStream::new(buf_stream, timeout); + while let Some(message) = timeout_stream.next().await { + let message = match message { + Ok(message) => message, Err(e) => { - debug!("tls handshake src: {} error: {}", src_addr, e); + debug!( + "error in TLS request_stream src: {:?} error: {}", + src_addr, e + ); + + // kill this connection return; } }; - debug!("accepted TLS request from: {}", src_addr); - let (buf_stream, stream_handle) = tls_from_stream(tls_stream, src_addr); - let mut timeout_stream = TimeoutStream::new(buf_stream, timeout); - while let Some(message) = timeout_stream.next().await { - let message = match message { - Ok(message) => message, - Err(e) => { - debug!( - "error in TLS request_stream src: {:?} error: {}", - src_addr, e - ); - - // kill this connection - return; - } - }; - - self::handle_raw_request( - message, - Protocol::Tls, - handler.clone(), - stream_handle.clone(), - ) - .await; - } - }); - reap_tasks(&mut inner_join_set); - } + handle_raw_request( + message, + Protocol::Tls, + handler.clone(), + stream_handle.clone(), + ) + .await; + } + }); } + + reap_tasks(&mut inner_join_set); + Ok(()) }); Ok(()) @@ -528,7 +553,7 @@ impl ServerFuture { #[cfg_attr(docsrs, doc(cfg(feature = "dns-over-https-rustls")))] pub fn register_https_listener( &mut self, - listener: net::TcpListener, + listener: tokio::net::TcpListener, // TODO: need to set a timeout between requests. _timeout: Duration, certificate_and_key: (Vec, PrivateKey), @@ -555,52 +580,57 @@ impl ServerFuture { // for each incoming request... let dns_hostname = dns_hostname; - self.join_set.spawn({ - async move { - let mut inner_join_set = JoinSet::new(); - let dns_hostname = dns_hostname; - loop { - let tcp_stream = listener.accept().await; - let (tcp_stream, src_addr) = match tcp_stream { + let drain = self.drain.clone(); + self.join_set.spawn(async move { + let mut inner_join_set = JoinSet::new(); + let dns_hostname = dns_hostname; + loop { + let (tcp_stream, src_addr) = tokio::select! { + tcp_stream = listener.accept() => match tcp_stream { Ok((t, s)) => (t, s), Err(e) => { - debug!("error receiving HTTPS tcp_stream error: {e}"); + debug!("error receiving HTTPS tcp_stream error: {}", e); continue; - } - }; - - // verify that the src address is safe for responses - if let Err(e) = sanitize_src_address(src_addr) { - warn!("address can not be responded to {src_addr}: {e}"); - continue; - } - - let handler = handler.clone(); - let tls_acceptor = tls_acceptor.clone(); - let dns_hostname = dns_hostname.clone(); + }, + }, + _ = drain.clone() => { + // A graceful shutdown was initiated. Break out of the loop. + break; + }, + }; + + // verify that the src address is safe for responses + if let Err(e) = sanitize_src_address(src_addr) { + warn!("address can not be responded to {src_addr}: {e}"); + continue; + } - inner_join_set.spawn(async move { - debug!("starting HTTPS request from: {src_addr}"); + let handler = handler.clone(); + let tls_acceptor = tls_acceptor.clone(); + let dns_hostname = dns_hostname.clone(); - // TODO: need to consider timeout of total connect... - // take the created stream... - let tls_stream = tls_acceptor.accept(tcp_stream).await; + inner_join_set.spawn(async move { + debug!("starting HTTPS request from: {src_addr}"); - let tls_stream = match tls_stream { - Ok(tls_stream) => tls_stream, - Err(e) => { - debug!("https handshake src: {src_addr} error: {e}"); - return; - } - }; - debug!("accepted HTTPS request from: {src_addr}"); + // TODO: need to consider timeout of total connect... + // take the created stream... + let tls_stream = tls_acceptor.accept(tcp_stream).await; - h2_handler(handler, tls_stream, src_addr, dns_hostname).await; - }); + let tls_stream = match tls_stream { + Ok(tls_stream) => tls_stream, + Err(e) => { + debug!("https handshake src: {src_addr} error: {e}"); + return; + } + }; + debug!("accepted HTTPS request from: {src_addr}"); - reap_tasks(&mut inner_join_set); - } + h2_handler(handler, tls_stream, src_addr, dns_hostname, drain.clone()).await; + }); } + + reap_tasks(&mut inner_join_set); + Ok(()) }); Ok(()) @@ -623,7 +653,7 @@ impl ServerFuture { #[cfg_attr(docsrs, doc(cfg(feature = "dns-over-quic")))] pub fn register_quic_listener( &mut self, - socket: net::UdpSocket, + socket: tokio::net::UdpSocket, // TODO: need to set a timeout between requests. _timeout: Duration, certificate_and_key: (Vec, PrivateKey), @@ -642,71 +672,141 @@ impl ServerFuture { // for each incoming request... let dns_hostname = dns_hostname; - self.join_set.spawn({ - async move { - let mut inner_join_set = JoinSet::new(); - let dns_hostname = dns_hostname; - loop { - let (streams, src_addr) = match server.next().await { + let drain = self.drain.clone(); + self.join_set.spawn(async move { + let mut inner_join_set = JoinSet::new(); + let dns_hostname = dns_hostname; + loop { + let (tcp_stream, src_addr) = tokio::select! { + result = server.next() => match result { Ok(Some(c)) => c, Ok(None) => continue, Err(e) => { debug!("error receiving quic connection: {e}"); continue; } - }; - - // verify that the src address is safe for responses - // TODO: we're relying the quinn library to actually validate responses before we get here, but this check is still worth doing - if let Err(e) = sanitize_src_address(src_addr) { - warn!( - "address can not be responded to {src_addr}: {e}", - src_addr = src_addr, - e = e - ); - continue; - } - - let handler = handler.clone(); - let dns_hostname = dns_hostname.clone(); + }, + _ = drain.clone() => { + // A graceful shutdown was initiated. Break out of the loop. + break; + }, + }; + + // verify that the src address is safe for responses + // TODO: we're relying the quinn library to actually validate responses before we get here, but this check is still worth doing + if let Err(e) = sanitize_src_address(src_addr) { + warn!( + "address can not be responded to {src_addr}: {e}", + src_addr = src_addr, + e = e + ); + continue; + } - inner_join_set.spawn(async move { - debug!("starting quic stream request from: {src_addr}"); + let handler = handler.clone(); + let dns_hostname = dns_hostname.clone(); - // TODO: need to consider timeout of total connect... - let result = quic_handler(handler, streams, src_addr, dns_hostname).await; + inner_join_set.spawn(async move { + debug!("starting quic stream request from: {src_addr}"); - if let Err(e) = result { - warn!("quic stream processing failed from {src_addr}: {e}") - } - }); + // TODO: need to consider timeout of total connect... + let result = + quic_handler(handler, streams, src_addr, dns_hostname, drain.clone()).await; - reap_tasks(&mut inner_join_set); - } + if let Err(e) = result { + warn!("quic stream processing failed from {src_addr}: {e}") + } + }); } + + reap_tasks(&mut inner_join_set); + Ok(()) }); Ok(()) } - /// This will run until a background task of the trust_dns_server ends. + /// Triggers a graceful shutdown the server. All background tasks will stop accepting + /// new connections and the returned future will complete once all tasks have terminated. + pub fn shutdown(self) -> impl Future> { + // Trigger the graceful shudown of all the server tasks. + self.drain.shutdown(); + + // Return the future that blocks until all tasks complete. + self.block_until_done() + } + + /// This will run until all background tasks complete. If one or more tasks return an error, + /// one will be chosen as the returned error for this future. pub async fn block_until_done(mut self) -> Result<(), ProtoError> { - let result = self.join_set.join_next().await; + if self.join_set.is_empty() { + warn!("block_until_done called with no pending tasks"); + return Ok(()); + } - match result { - None => { - tracing::warn!("block_until_done called with no pending tasks"); - Ok(()) + // Now wait for all of the tasks to complete. + let mut out = Ok(()); + while let Some(join_result) = self.join_set.join_next().await { + match join_result { + Ok(result) => { + match result { + Ok(_) => (), + Err(e) => { + // Save the last error. + out = Err(e); + } + } + } + Err(e) => return Err(ProtoError::from(format!("Internal error in spawn: {e}"))), } - Some(Ok(x)) => x, - Some(Err(e)) => Err(ProtoError::from(format!("Internal error in spawn: {e}"))), + } + out + } +} + +/// A Future that controls the graceful shutdown of the server. +#[derive(Clone)] +struct Drain { + is_shutdown: Arc, +} + +impl Drain { + /// Create a new server handle. + fn new() -> Self { + Self { + is_shutdown: Arc::new(AtomicBool::new(false)), + } + } + + /// Tell the server handlers to shutdown + fn shutdown(&self) { + self.is_shutdown + .store(true, std::sync::atomic::Ordering::Relaxed); + } + + /// Indicates whether the server is shutting down. + fn is_shutdown(&self) -> bool { + self.is_shutdown.load(std::sync::atomic::Ordering::Relaxed) + } +} + +impl Future for Drain { + type Output = (); + + fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { + if self.is_shutdown() { + Poll::Ready(()) + } else { + Poll::Pending } } } /// Reap finished tasks from a `JoinSet`, without awaiting or blocking. fn reap_tasks(join_set: &mut JoinSet<()>) { - while FutureExt::now_or_never(join_set.join_next()).is_some() {} + if !join_set.is_empty() { + while FutureExt::now_or_never(join_set.join_next()).is_some() {} + } } pub(crate) async fn handle_raw_request( @@ -718,7 +818,7 @@ pub(crate) async fn handle_raw_request( let src_addr = message.addr(); let response_handler = ResponseHandle::new(message.addr(), response_handler, protocol); - self::handle_request( + handle_request( message.bytes(), src_addr, protocol, @@ -939,31 +1039,37 @@ mod tests { use super::*; use crate::authority::Catalog; use futures_util::future; - use std::net::{Ipv4Addr, SocketAddr, UdpSocket}; + use std::net::{Ipv4Addr, SocketAddr}; + use tokio::net::{TcpListener, UdpSocket}; + use tokio::time::timeout; - #[test] - fn cleanup_after_shutdown() { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let random_port = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)) - .unwrap() - .local_addr() - .unwrap() - .port(); - let bind_addr = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), random_port); - - let (server_future, abort_handle) = future::abortable(async move { - let mut server_future = ServerFuture::new(Catalog::new()); - let udp_socket = tokio::net::UdpSocket::bind(bind_addr).await.unwrap(); - server_future.register_socket(udp_socket); - server_future.block_until_done().await - }); + #[tokio::test] + async fn abort() { + let mut server_future = ServerFuture::new(Catalog::new()); + let reg = Registrar::new(&mut server_future).await; + + let (abortable, abort_handle) = + future::abortable(async move { server_future.block_until_done().await }); abort_handle.abort(); - runtime.block_on(async move { - let _ = server_future.await; - }); + abortable.await.expect_err("expected abort"); + + // Rebind the same addresses to make sure they're available. + reg.rebind().await; + } + + #[tokio::test] + async fn shutdown() { + let mut server_future = ServerFuture::new(Catalog::new()); + let reg = Registrar::new(&mut server_future).await; - UdpSocket::bind(bind_addr).unwrap(); + timeout(Duration::from_secs(2), server_future.shutdown()) + .await + .expect("timed out waiting for the server to complete") + .expect("error while awaiting tasks"); + + // Rebind the same addresses to make sure they're available. + reg.rebind().await; } #[test] @@ -989,4 +1095,43 @@ mod tests { sanitize_src_address(SocketAddr::from(([0x20, 0, 0, 0, 0, 0, 0, 0x1], 0))).is_err() ); } + + struct Registrar { + udp_addr: SocketAddr, + tcp_addr: SocketAddr, + // TODO: implement the remaining endpoint types. + } + + impl Registrar { + async fn new(server: &mut ServerFuture) -> Self { + let udp_addr = bind_addr().await; + server.register_socket(UdpSocket::bind(udp_addr).await.unwrap()); + + let tcp_addr = bind_addr().await; + server.register_listener( + TcpListener::bind(tcp_addr).await.unwrap(), + Duration::from_secs(1), + ); + + Self { udp_addr, tcp_addr } + } + + async fn rebind(self) { + UdpSocket::bind(self.udp_addr).await.unwrap(); + TcpListener::bind(self.tcp_addr).await.unwrap(); + } + } + + async fn bind_addr() -> SocketAddr { + SocketAddr::new(Ipv4Addr::LOCALHOST.into(), random_port().await) + } + + async fn random_port() -> u16 { + UdpSocket::bind((Ipv4Addr::LOCALHOST, 0)) + .await + .unwrap() + .local_addr() + .unwrap() + .port() + } }