diff --git a/crates/proto/src/udp/mod.rs b/crates/proto/src/udp/mod.rs index f3d06f57ee..958d8337ce 100644 --- a/crates/proto/src/udp/mod.rs +++ b/crates/proto/src/udp/mod.rs @@ -21,3 +21,7 @@ mod udp_stream; pub use self::udp_client_stream::{UdpClientConnect, UdpClientStream}; pub use self::udp_stream::{DnsUdpSocket, QuicLocalAddr, UdpSocket, UdpStream}; + +/// Max size for the UDP receive buffer as recommended by +/// [RFC6891](https://datatracker.ietf.org/doc/html/rfc6891#section-6.2.5). +const MAX_RECEIVE_BUFFER_SIZE: usize = 4096; diff --git a/crates/proto/src/udp/udp_client_stream.rs b/crates/proto/src/udp/udp_client_stream.rs index 7ce469fb1d..99efb742fb 100644 --- a/crates/proto/src/udp/udp_client_stream.rs +++ b/crates/proto/src/udp/udp_client_stream.rs @@ -15,13 +15,13 @@ use std::task::{Context, Poll}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use futures_util::{future::Future, stream::Stream}; -use tracing::{debug, warn}; +use tracing::{debug, trace, warn}; use crate::error::ProtoError; use crate::op::message::NoopMessageFinalizer; use crate::op::{Message, MessageFinalizer, MessageVerifier}; use crate::udp::udp_stream::{NextRandomUdpSocket, UdpCreator, UdpSocket}; -use crate::udp::DnsUdpSocket; +use crate::udp::{DnsUdpSocket, MAX_RECEIVE_BUFFER_SIZE}; use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream, SerialMessage}; use crate::Time; @@ -212,6 +212,9 @@ impl DnsRequestSender } } + // Get an appropriate read buffer size. + let recv_buf_size = MAX_RECEIVE_BUFFER_SIZE.min(message.max_payload() as usize); + let bytes = match message.to_vec() { Ok(bytes) => bytes, Err(err) => { @@ -235,7 +238,8 @@ impl DnsRequestSender self.timeout, Box::pin(async move { let socket: S = NextRandomUdpSocket::new_with_closure(&addr, creator).await?; - send_serial_message_inner(message, message_id, verifier, socket).await + send_serial_message_inner(message, message_id, verifier, socket, recv_buf_size) + .await }), ) .into() @@ -298,6 +302,7 @@ async fn send_serial_message_inner( msg_id: u16, verifier: Option, socket: S, + recv_buf_size: usize, ) -> Result { let bytes = msg.bytes(); let addr = msg.addr(); @@ -311,13 +316,16 @@ async fn send_serial_message_inner( ))); } + // Create the receive buffer. + trace!("creating UDP receive buffer with size {recv_buf_size}"); + let mut recv_buf = vec![0; recv_buf_size]; + // TODO: limit the max number of attempted messages? this relies on a timeout to die... loop { - // TODO: consider making this heap based? need to verify it matches EDNS settings - let mut recv_buf = [0u8; 2048]; - let (len, src) = socket.recv_from(&mut recv_buf).await?; - let buffer: Vec<_> = recv_buf.iter().take(len).cloned().collect(); + + // Copy the slice of read bytes. + let buffer: Vec<_> = Vec::from(&recv_buf[0..len]); // compare expected src to received packet let request_target = msg.addr(); diff --git a/crates/proto/src/udp/udp_stream.rs b/crates/proto/src/udp/udp_stream.rs index 6ecced3a23..d30d48037c 100644 --- a/crates/proto/src/udp/udp_stream.rs +++ b/crates/proto/src/udp/udp_stream.rs @@ -19,6 +19,7 @@ use rand; use rand::distributions::{uniform::Uniform, Distribution}; use tracing::{debug, warn}; +use crate::udp::MAX_RECEIVE_BUFFER_SIZE; use crate::xfer::{BufDnsStreamHandle, SerialMessage, StreamReceiver}; use crate::Time; @@ -220,7 +221,7 @@ impl Stream for UdpStream { // receive all inbound messages // TODO: this should match edns settings - let mut buf = [0u8; 4096]; + let mut buf = [0u8; MAX_RECEIVE_BUFFER_SIZE]; let (len, src) = ready!(socket.poll_recv_from(cx, &mut buf))?; let serial_message = SerialMessage::new(buf.iter().take(len).cloned().collect(), src);