diff --git a/crates/proto/src/serialize/binary/encoder.rs b/crates/proto/src/serialize/binary/encoder.rs index b92f02a29a..c54894bb12 100644 --- a/crates/proto/src/serialize/binary/encoder.rs +++ b/crates/proto/src/serialize/binary/encoder.rs @@ -132,7 +132,7 @@ impl<'a> BinEncoder<'a> { BinEncoder { offset: offset as usize, // TODO: add max_size to signature - buffer: private::MaximalBuf::new(u16::max_value(), buf), + buffer: private::MaximalBuf::new(u16::MAX, buf), name_pointers: Vec::new(), mode, canonical_names: false, @@ -240,8 +240,8 @@ impl<'a> BinEncoder<'a> { /// The location is the current position in the buffer /// implicitly, it is expected that the name will be written to the stream after the current index. pub fn store_label_pointer(&mut self, start: usize, end: usize) { - assert!(start <= (u16::max_value() as usize)); - assert!(end <= (u16::max_value() as usize)); + assert!(start <= (u16::MAX as usize)); + assert!(end <= (u16::MAX as usize)); assert!(start <= end); if self.offset < 0x3FFF_usize { self.name_pointers @@ -255,7 +255,7 @@ impl<'a> BinEncoder<'a> { for (match_start, matcher) in &self.name_pointers { if matcher.as_slice() == search { - assert!(match_start <= &(u16::max_value() as usize)); + assert!(match_start <= &(u16::MAX as usize)); return Some(*match_start as u16); } } diff --git a/crates/proto/src/udp/mod.rs b/crates/proto/src/udp/mod.rs index f3d06f57ee..7e76b4184e 100644 --- a/crates/proto/src/udp/mod.rs +++ b/crates/proto/src/udp/mod.rs @@ -21,3 +21,7 @@ mod udp_stream; pub use self::udp_client_stream::{UdpClientConnect, UdpClientStream}; pub use self::udp_stream::{DnsUdpSocket, QuicLocalAddr, UdpSocket, UdpStream}; + +/// Max size for the UDP receive buffer as recommended by +/// [RFC6891](https://datatracker.ietf.org/doc/html/rfc6891#section-6.2.5). +pub const MAX_RECEIVE_BUFFER_SIZE: usize = 4096; diff --git a/crates/proto/src/udp/udp_client_stream.rs b/crates/proto/src/udp/udp_client_stream.rs index 7ce469fb1d..99efb742fb 100644 --- a/crates/proto/src/udp/udp_client_stream.rs +++ b/crates/proto/src/udp/udp_client_stream.rs @@ -15,13 +15,13 @@ use std::task::{Context, Poll}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use futures_util::{future::Future, stream::Stream}; -use tracing::{debug, warn}; +use tracing::{debug, trace, warn}; use crate::error::ProtoError; use crate::op::message::NoopMessageFinalizer; use crate::op::{Message, MessageFinalizer, MessageVerifier}; use crate::udp::udp_stream::{NextRandomUdpSocket, UdpCreator, UdpSocket}; -use crate::udp::DnsUdpSocket; +use crate::udp::{DnsUdpSocket, MAX_RECEIVE_BUFFER_SIZE}; use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream, SerialMessage}; use crate::Time; @@ -212,6 +212,9 @@ impl DnsRequestSender } } + // Get an appropriate read buffer size. + let recv_buf_size = MAX_RECEIVE_BUFFER_SIZE.min(message.max_payload() as usize); + let bytes = match message.to_vec() { Ok(bytes) => bytes, Err(err) => { @@ -235,7 +238,8 @@ impl DnsRequestSender self.timeout, Box::pin(async move { let socket: S = NextRandomUdpSocket::new_with_closure(&addr, creator).await?; - send_serial_message_inner(message, message_id, verifier, socket).await + send_serial_message_inner(message, message_id, verifier, socket, recv_buf_size) + .await }), ) .into() @@ -298,6 +302,7 @@ async fn send_serial_message_inner( msg_id: u16, verifier: Option, socket: S, + recv_buf_size: usize, ) -> Result { let bytes = msg.bytes(); let addr = msg.addr(); @@ -311,13 +316,16 @@ async fn send_serial_message_inner( ))); } + // Create the receive buffer. + trace!("creating UDP receive buffer with size {recv_buf_size}"); + let mut recv_buf = vec![0; recv_buf_size]; + // TODO: limit the max number of attempted messages? this relies on a timeout to die... loop { - // TODO: consider making this heap based? need to verify it matches EDNS settings - let mut recv_buf = [0u8; 2048]; - let (len, src) = socket.recv_from(&mut recv_buf).await?; - let buffer: Vec<_> = recv_buf.iter().take(len).cloned().collect(); + + // Copy the slice of read bytes. + let buffer: Vec<_> = Vec::from(&recv_buf[0..len]); // compare expected src to received packet let request_target = msg.addr(); diff --git a/crates/proto/src/udp/udp_stream.rs b/crates/proto/src/udp/udp_stream.rs index 6ecced3a23..d30d48037c 100644 --- a/crates/proto/src/udp/udp_stream.rs +++ b/crates/proto/src/udp/udp_stream.rs @@ -19,6 +19,7 @@ use rand; use rand::distributions::{uniform::Uniform, Distribution}; use tracing::{debug, warn}; +use crate::udp::MAX_RECEIVE_BUFFER_SIZE; use crate::xfer::{BufDnsStreamHandle, SerialMessage, StreamReceiver}; use crate::Time; @@ -220,7 +221,7 @@ impl Stream for UdpStream { // receive all inbound messages // TODO: this should match edns settings - let mut buf = [0u8; 4096]; + let mut buf = [0u8; MAX_RECEIVE_BUFFER_SIZE]; let (len, src) = ready!(socket.poll_recv_from(cx, &mut buf))?; let serial_message = SerialMessage::new(buf.iter().take(len).cloned().collect(), src); diff --git a/crates/server/src/authority/catalog.rs b/crates/server/src/authority/catalog.rs index 791bf08555..b7457bbbe3 100644 --- a/crates/server/src/authority/catalog.rs +++ b/crates/server/src/authority/catalog.rs @@ -47,22 +47,23 @@ async fn send_response<'a, R: ResponseHandler>( >, mut response_handle: R, ) -> io::Result { - #[cfg(feature = "dnssec")] if let Some(mut resp_edns) = response_edns { - // set edns DAU and DHU - // send along the algorithms which are supported by this authority - let mut algorithms = SupportedAlgorithms::default(); - algorithms.set(Algorithm::RSASHA256); - algorithms.set(Algorithm::ECDSAP256SHA256); - algorithms.set(Algorithm::ECDSAP384SHA384); - algorithms.set(Algorithm::ED25519); - - let dau = EdnsOption::DAU(algorithms); - let dhu = EdnsOption::DHU(algorithms); - - resp_edns.options_mut().insert(dau); - resp_edns.options_mut().insert(dhu); - + #[cfg(feature = "dnssec")] + { + // set edns DAU and DHU + // send along the algorithms which are supported by this authority + let mut algorithms = SupportedAlgorithms::default(); + algorithms.set(Algorithm::RSASHA256); + algorithms.set(Algorithm::ECDSAP256SHA256); + algorithms.set(Algorithm::ECDSAP384SHA384); + algorithms.set(Algorithm::ED25519); + + let dau = EdnsOption::DAU(algorithms); + let dhu = EdnsOption::DHU(algorithms); + + resp_edns.options_mut().insert(dau); + resp_edns.options_mut().insert(dhu); + } response.set_edns(resp_edns); } diff --git a/crates/server/src/authority/message_response.rs b/crates/server/src/authority/message_response.rs index 457545801d..d8b9c23c40 100644 --- a/crates/server/src/authority/message_response.rs +++ b/crates/server/src/authority/message_response.rs @@ -96,6 +96,11 @@ where self } + /// Gets a reference to the EDNS options for the Response. + pub fn get_edns(&self) -> &Option { + &self.edns + } + /// Consumes self, and emits to the encoder. pub fn destructive_emit(mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult { // soa records are part of the nameserver section diff --git a/crates/server/src/server/response_handler.rs b/crates/server/src/server/response_handler.rs index eb5c5bd9c5..0b245a375a 100644 --- a/crates/server/src/server/response_handler.rs +++ b/crates/server/src/server/response_handler.rs @@ -10,6 +10,7 @@ use std::{io, net::SocketAddr}; use tracing::debug; use trust_dns_proto::rr::Record; +use crate::server::Protocol; use crate::{ authority::MessageResponse, proto::{ @@ -46,12 +47,44 @@ pub trait ResponseHandler: Clone + Send + Sync + Unpin + 'static { pub struct ResponseHandle { dst: SocketAddr, stream_handle: BufDnsStreamHandle, + protocol: Protocol, } impl ResponseHandle { /// Returns a new `ResponseHandle` for sending a response message - pub fn new(dst: SocketAddr, stream_handle: BufDnsStreamHandle) -> Self { - Self { dst, stream_handle } + pub fn new(dst: SocketAddr, stream_handle: BufDnsStreamHandle, protocol: Protocol) -> Self { + Self { + dst, + stream_handle, + protocol, + } + } + + /// Selects an appropriate maximum serialized size for the given response. + fn max_size_for_response<'a>( + &self, + response: &MessageResponse< + '_, + 'a, + impl Iterator + Send + 'a, + impl Iterator + Send + 'a, + impl Iterator + Send + 'a, + impl Iterator + Send + 'a, + >, + ) -> u16 { + // Use EDNS, if available. + if let Some(edns) = response.get_edns() { + edns.max_payload() + } else { + // No EDNS. Use an appropriate maximum for the protocol. + match self.protocol { + Protocol::Udp => { + // For UDP, we use the recommended max from RFC6891. + trust_dns_proto::udp::MAX_RECEIVE_BUFFER_SIZE as u16 + } + _ => u16::MAX, + } + } } } @@ -79,6 +112,11 @@ impl ResponseHandler for ResponseHandle { let mut buffer = Vec::with_capacity(512); let encode_result = { let mut encoder = BinEncoder::new(&mut buffer); + + // Set an appropriate maximum on the encoder. + let max_size = self.max_size_for_response(&response); + encoder.set_max_size(max_size); + response.destructive_emit(&mut encoder) }; diff --git a/crates/server/src/server/server_future.rs b/crates/server/src/server/server_future.rs index 596c54d619..0f75068747 100644 --- a/crates/server/src/server/server_future.rs +++ b/crates/server/src/server/server_future.rs @@ -716,7 +716,7 @@ pub(crate) async fn handle_raw_request( response_handler: BufDnsStreamHandle, ) { let src_addr = message.addr(); - let response_handler = ResponseHandle::new(message.addr(), response_handler); + let response_handler = ResponseHandle::new(message.addr(), response_handler, protocol); self::handle_request( message.bytes(), diff --git a/tests/integration-tests/tests/truncation_tests.rs b/tests/integration-tests/tests/truncation_tests.rs new file mode 100644 index 0000000000..6414446f1b --- /dev/null +++ b/tests/integration-tests/tests/truncation_tests.rs @@ -0,0 +1,118 @@ +use std::collections::BTreeMap; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::str::FromStr; +use std::sync::Arc; +use tokio::net::UdpSocket; +use trust_dns_client::client::AsyncClient; +use trust_dns_proto::op::{Edns, Message, MessageType, OpCode, Query}; +use trust_dns_proto::rr::rdata::{A, SOA}; +use trust_dns_proto::rr::{DNSClass, Name, RData, Record, RecordSet, RecordType, RrKey}; +use trust_dns_proto::udp::UdpClientStream; +use trust_dns_proto::xfer::FirstAnswer; +use trust_dns_proto::DnsHandle; +use trust_dns_server::authority::{Catalog, ZoneType}; +use trust_dns_server::store::in_memory::InMemoryAuthority; +use trust_dns_server::ServerFuture; + +#[tokio::test] +async fn test_truncation() { + let _guard = subscribe(); + + let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 0)); + let udp_socket = UdpSocket::bind(&addr).await.unwrap(); + + let nameserver = udp_socket.local_addr().unwrap(); + println!("udp_socket on port: {nameserver}"); + + // Create and start the server. + let mut server = ServerFuture::new(new_large_catalog(128)); + server.register_socket(udp_socket); + tokio::spawn(server.block_until_done()); + + // Create the UDP client. + let stream = UdpClientStream::::new(nameserver); + let (mut client, bg) = AsyncClient::connect(stream).await.unwrap(); + + // Run the client exchange in the background. + tokio::spawn(bg); + + // Build the query. + let max_payload = 512; + let mut msg = Message::new(); + msg.add_query({ + let mut query = Query::query(large_name(), RecordType::A); + query.set_query_class(DNSClass::IN); + query + }) + .set_id(rand::random::()) + .set_message_type(MessageType::Query) + .set_op_code(OpCode::Query) + .set_recursion_desired(true) + .set_edns({ + let mut edns = Edns::new(); + edns.set_max_payload(max_payload).set_version(0); + edns + }); + + let result = client.send(msg).first_answer().await.expect("query failed"); + + assert!(result.truncated()); + assert_eq!(max_payload, result.max_payload()); +} + +// TODO: should we do this for all of the integration tests? +fn subscribe() -> tracing::subscriber::DefaultGuard { + let sub = tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .finish(); + tracing::subscriber::set_default(sub) +} + +pub fn new_large_catalog(num_records: u32) -> Catalog { + // Create a large record set. + let name = large_name(); + let mut record_set = RecordSet::new(&name, RecordType::A, 0); + for i in 1..num_records + 1 { + let ip = Ipv4Addr::from(i); + let rdata = RData::A(A(ip)); + record_set.insert(Record::from_rdata(name.clone(), 86400, rdata), 0); + } + + let mut soa_record_set = RecordSet::new(&name, RecordType::SOA, 0); + soa_record_set.insert( + Record::from_rdata( + name.clone(), + 86400, + RData::SOA(SOA::new( + n("sns.dns.icann.org."), + n("noc.dns.icann.org."), + 2015082403, + 7200, + 3600, + 1209600, + 3600, + )), + ), + 0, + ); + + let mut records = BTreeMap::new(); + records.insert(RrKey::new(name.clone().into(), RecordType::A), record_set); + records.insert(RrKey::new(name.into(), RecordType::SOA), soa_record_set); + let authority = + InMemoryAuthority::new(Name::root(), records, ZoneType::Primary, false).unwrap(); + + let mut catalog: Catalog = Catalog::new(); + catalog.upsert(Name::root().into(), Box::new(Arc::new(authority))); + catalog +} + +const LARGE_NAME: &str = "large.com"; + +fn large_name() -> Name { + n(LARGE_NAME) +} + +pub fn n>(name: S) -> Name { + Name::from_str(name.as_ref()).unwrap() +}