Skip to content

Commit

Permalink
Fix truncation for UDP
Browse files Browse the repository at this point in the history
This fixes a couple of issues for UDP on both the client and server:

* Previously, the UdpClientStream was using a fixed `2048` for the size of the receive buffer. This can cause problems on interfaces with a larger MTU. #1096 adjusted this value on the server side to 4096 (the maximum as recommended by RFC6891). This PR sets a constant that is shared by the UDP client and server. Additionally, the client uses EDNS in the request to further trim down the buffer size.
* The Server previously was not setting a maximum for the `BinEncoder`, which defaults to `u16::MAX` (i.e. effectively no truncation for UDP). This PR sets an appropriate maximum for the `BinEncoder` based on the response EDNS and protocol being used.

Fixes: #1973
  • Loading branch information
nmittler committed Jun 26, 2023
1 parent dc14427 commit bf2d99c
Show file tree
Hide file tree
Showing 9 changed files with 205 additions and 30 deletions.
8 changes: 4 additions & 4 deletions crates/proto/src/serialize/binary/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl<'a> BinEncoder<'a> {
BinEncoder {
offset: offset as usize,
// TODO: add max_size to signature
buffer: private::MaximalBuf::new(u16::max_value(), buf),
buffer: private::MaximalBuf::new(u16::MAX, buf),
name_pointers: Vec::new(),
mode,
canonical_names: false,
Expand Down Expand Up @@ -240,8 +240,8 @@ impl<'a> BinEncoder<'a> {
/// The location is the current position in the buffer
/// implicitly, it is expected that the name will be written to the stream after the current index.
pub fn store_label_pointer(&mut self, start: usize, end: usize) {
assert!(start <= (u16::max_value() as usize));
assert!(end <= (u16::max_value() as usize));
assert!(start <= (u16::MAX as usize));
assert!(end <= (u16::MAX as usize));
assert!(start <= end);
if self.offset < 0x3FFF_usize {
self.name_pointers
Expand All @@ -255,7 +255,7 @@ impl<'a> BinEncoder<'a> {

for (match_start, matcher) in &self.name_pointers {
if matcher.as_slice() == search {
assert!(match_start <= &(u16::max_value() as usize));
assert!(match_start <= &(u16::MAX as usize));
return Some(*match_start as u16);
}
}
Expand Down
4 changes: 4 additions & 0 deletions crates/proto/src/udp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ mod udp_stream;

pub use self::udp_client_stream::{UdpClientConnect, UdpClientStream};
pub use self::udp_stream::{DnsUdpSocket, QuicLocalAddr, UdpSocket, UdpStream};

/// Max size for the UDP receive buffer as recommended by
/// [RFC6891](https://datatracker.ietf.org/doc/html/rfc6891#section-6.2.5).
pub const MAX_RECEIVE_BUFFER_SIZE: usize = 4096;
22 changes: 15 additions & 7 deletions crates/proto/src/udp/udp_client_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ use std::task::{Context, Poll};
use std::time::{Duration, SystemTime, UNIX_EPOCH};

use futures_util::{future::Future, stream::Stream};
use tracing::{debug, warn};
use tracing::{debug, trace, warn};

use crate::error::ProtoError;
use crate::op::message::NoopMessageFinalizer;
use crate::op::{Message, MessageFinalizer, MessageVerifier};
use crate::udp::udp_stream::{NextRandomUdpSocket, UdpCreator, UdpSocket};
use crate::udp::DnsUdpSocket;
use crate::udp::{DnsUdpSocket, MAX_RECEIVE_BUFFER_SIZE};
use crate::xfer::{DnsRequest, DnsRequestSender, DnsResponse, DnsResponseStream, SerialMessage};
use crate::Time;

Expand Down Expand Up @@ -212,6 +212,9 @@ impl<S: DnsUdpSocket + Send + 'static, MF: MessageFinalizer> DnsRequestSender
}
}

// Get an appropriate read buffer size.
let recv_buf_size = MAX_RECEIVE_BUFFER_SIZE.min(message.max_payload() as usize);

let bytes = match message.to_vec() {
Ok(bytes) => bytes,
Err(err) => {
Expand All @@ -235,7 +238,8 @@ impl<S: DnsUdpSocket + Send + 'static, MF: MessageFinalizer> DnsRequestSender
self.timeout,
Box::pin(async move {
let socket: S = NextRandomUdpSocket::new_with_closure(&addr, creator).await?;
send_serial_message_inner(message, message_id, verifier, socket).await
send_serial_message_inner(message, message_id, verifier, socket, recv_buf_size)
.await
}),
)
.into()
Expand Down Expand Up @@ -298,6 +302,7 @@ async fn send_serial_message_inner<S: DnsUdpSocket + Send>(
msg_id: u16,
verifier: Option<MessageVerifier>,
socket: S,
recv_buf_size: usize,
) -> Result<DnsResponse, ProtoError> {
let bytes = msg.bytes();
let addr = msg.addr();
Expand All @@ -311,13 +316,16 @@ async fn send_serial_message_inner<S: DnsUdpSocket + Send>(
)));
}

// Create the receive buffer.
trace!("creating UDP receive buffer with size {recv_buf_size}");
let mut recv_buf = vec![0; recv_buf_size];

// TODO: limit the max number of attempted messages? this relies on a timeout to die...
loop {
// TODO: consider making this heap based? need to verify it matches EDNS settings
let mut recv_buf = [0u8; 2048];

let (len, src) = socket.recv_from(&mut recv_buf).await?;
let buffer: Vec<_> = recv_buf.iter().take(len).cloned().collect();

// Copy the slice of read bytes.
let buffer: Vec<_> = Vec::from(&recv_buf[0..len]);

// compare expected src to received packet
let request_target = msg.addr();
Expand Down
3 changes: 2 additions & 1 deletion crates/proto/src/udp/udp_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use rand;
use rand::distributions::{uniform::Uniform, Distribution};
use tracing::{debug, warn};

use crate::udp::MAX_RECEIVE_BUFFER_SIZE;
use crate::xfer::{BufDnsStreamHandle, SerialMessage, StreamReceiver};
use crate::Time;

Expand Down Expand Up @@ -220,7 +221,7 @@ impl<S: DnsUdpSocket + Send + 'static> Stream for UdpStream<S> {
// receive all inbound messages

// TODO: this should match edns settings
let mut buf = [0u8; 4096];
let mut buf = [0u8; MAX_RECEIVE_BUFFER_SIZE];
let (len, src) = ready!(socket.poll_recv_from(cx, &mut buf))?;

let serial_message = SerialMessage::new(buf.iter().take(len).cloned().collect(), src);
Expand Down
31 changes: 16 additions & 15 deletions crates/server/src/authority/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,23 @@ async fn send_response<'a, R: ResponseHandler>(
>,
mut response_handle: R,
) -> io::Result<ResponseInfo> {
#[cfg(feature = "dnssec")]
if let Some(mut resp_edns) = response_edns {
// set edns DAU and DHU
// send along the algorithms which are supported by this authority
let mut algorithms = SupportedAlgorithms::default();
algorithms.set(Algorithm::RSASHA256);
algorithms.set(Algorithm::ECDSAP256SHA256);
algorithms.set(Algorithm::ECDSAP384SHA384);
algorithms.set(Algorithm::ED25519);

let dau = EdnsOption::DAU(algorithms);
let dhu = EdnsOption::DHU(algorithms);

resp_edns.options_mut().insert(dau);
resp_edns.options_mut().insert(dhu);

#[cfg(feature = "dnssec")]
{
// set edns DAU and DHU
// send along the algorithms which are supported by this authority
let mut algorithms = SupportedAlgorithms::default();
algorithms.set(Algorithm::RSASHA256);
algorithms.set(Algorithm::ECDSAP256SHA256);
algorithms.set(Algorithm::ECDSAP384SHA384);
algorithms.set(Algorithm::ED25519);

let dau = EdnsOption::DAU(algorithms);
let dhu = EdnsOption::DHU(algorithms);

resp_edns.options_mut().insert(dau);
resp_edns.options_mut().insert(dhu);
}
response.set_edns(resp_edns);
}

Expand Down
5 changes: 5 additions & 0 deletions crates/server/src/authority/message_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ where
self
}

/// Gets a reference to the EDNS options for the Response.
pub fn get_edns(&self) -> &Option<Edns> {
&self.edns
}

/// Consumes self, and emits to the encoder.
pub fn destructive_emit(mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<ResponseInfo> {
// soa records are part of the nameserver section
Expand Down
42 changes: 40 additions & 2 deletions crates/server/src/server/response_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::{io, net::SocketAddr};
use tracing::debug;
use trust_dns_proto::rr::Record;

use crate::server::Protocol;
use crate::{
authority::MessageResponse,
proto::{
Expand Down Expand Up @@ -46,12 +47,44 @@ pub trait ResponseHandler: Clone + Send + Sync + Unpin + 'static {
pub struct ResponseHandle {
dst: SocketAddr,
stream_handle: BufDnsStreamHandle,
protocol: Protocol,
}

impl ResponseHandle {
/// Returns a new `ResponseHandle` for sending a response message
pub fn new(dst: SocketAddr, stream_handle: BufDnsStreamHandle) -> Self {
Self { dst, stream_handle }
pub fn new(dst: SocketAddr, stream_handle: BufDnsStreamHandle, protocol: Protocol) -> Self {
Self {
dst,
stream_handle,
protocol,
}
}

/// Selects an appropriate maximum serialized size for the given response.
fn max_size_for_response<'a>(
&self,
response: &MessageResponse<
'_,
'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
>,
) -> u16 {
// Use EDNS, if available.
if let Some(edns) = response.get_edns() {
edns.max_payload()
} else {
// No EDNS. Use an appropriate maximum for the protocol.
match self.protocol {
Protocol::Udp => {
// For UDP, we use the recommended max from RFC6891.
trust_dns_proto::udp::MAX_RECEIVE_BUFFER_SIZE as u16
}
_ => u16::MAX,
}
}
}
}

Expand Down Expand Up @@ -79,6 +112,11 @@ impl ResponseHandler for ResponseHandle {
let mut buffer = Vec::with_capacity(512);
let encode_result = {
let mut encoder = BinEncoder::new(&mut buffer);

// Set an appropriate maximum on the encoder.
let max_size = self.max_size_for_response(&response);
encoder.set_max_size(max_size);

response.destructive_emit(&mut encoder)
};

Expand Down
2 changes: 1 addition & 1 deletion crates/server/src/server/server_future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ pub(crate) async fn handle_raw_request<T: RequestHandler>(
response_handler: BufDnsStreamHandle,
) {
let src_addr = message.addr();
let response_handler = ResponseHandle::new(message.addr(), response_handler);
let response_handler = ResponseHandle::new(message.addr(), response_handler, protocol);

self::handle_request(
message.bytes(),
Expand Down
118 changes: 118 additions & 0 deletions tests/integration-tests/tests/truncation_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
use std::collections::BTreeMap;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::str::FromStr;
use std::sync::Arc;
use tokio::net::UdpSocket;
use trust_dns_client::client::AsyncClient;
use trust_dns_proto::op::{Edns, Message, MessageType, OpCode, Query};
use trust_dns_proto::rr::rdata::{A, SOA};
use trust_dns_proto::rr::{DNSClass, Name, RData, Record, RecordSet, RecordType, RrKey};
use trust_dns_proto::udp::UdpClientStream;
use trust_dns_proto::xfer::FirstAnswer;
use trust_dns_proto::DnsHandle;
use trust_dns_server::authority::{Catalog, ZoneType};
use trust_dns_server::store::in_memory::InMemoryAuthority;
use trust_dns_server::ServerFuture;

#[tokio::test]
async fn test_truncation() {
let _guard = subscribe();

let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 0));
let udp_socket = UdpSocket::bind(&addr).await.unwrap();

let nameserver = udp_socket.local_addr().unwrap();
println!("udp_socket on port: {nameserver}");

// Create and start the server.
let mut server = ServerFuture::new(new_large_catalog(128));
server.register_socket(udp_socket);
tokio::spawn(server.block_until_done());

// Create the UDP client.
let stream = UdpClientStream::<UdpSocket>::new(nameserver);
let (mut client, bg) = AsyncClient::connect(stream).await.unwrap();

// Run the client exchange in the background.
tokio::spawn(bg);

// Build the query.
let max_payload = 512;
let mut msg = Message::new();
msg.add_query({
let mut query = Query::query(large_name(), RecordType::A);
query.set_query_class(DNSClass::IN);
query
})
.set_id(rand::random::<u16>())
.set_message_type(MessageType::Query)
.set_op_code(OpCode::Query)
.set_recursion_desired(true)
.set_edns({
let mut edns = Edns::new();
edns.set_max_payload(max_payload).set_version(0);
edns
});

let result = client.send(msg).first_answer().await.expect("query failed");

assert!(result.truncated());
assert_eq!(max_payload, result.max_payload());
}

// TODO: should we do this for all of the integration tests?
fn subscribe() -> tracing::subscriber::DefaultGuard {
let sub = tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.finish();
tracing::subscriber::set_default(sub)
}

pub fn new_large_catalog(num_records: u32) -> Catalog {
// Create a large record set.
let name = large_name();
let mut record_set = RecordSet::new(&name, RecordType::A, 0);
for i in 1..num_records + 1 {
let ip = Ipv4Addr::from(i);
let rdata = RData::A(A(ip));
record_set.insert(Record::from_rdata(name.clone(), 86400, rdata), 0);
}

let mut soa_record_set = RecordSet::new(&name, RecordType::SOA, 0);
soa_record_set.insert(
Record::from_rdata(
name.clone(),
86400,
RData::SOA(SOA::new(
n("sns.dns.icann.org."),
n("noc.dns.icann.org."),
2015082403,
7200,
3600,
1209600,
3600,
)),
),
0,
);

let mut records = BTreeMap::new();
records.insert(RrKey::new(name.clone().into(), RecordType::A), record_set);
records.insert(RrKey::new(name.into(), RecordType::SOA), soa_record_set);
let authority =
InMemoryAuthority::new(Name::root(), records, ZoneType::Primary, false).unwrap();

let mut catalog: Catalog = Catalog::new();
catalog.upsert(Name::root().into(), Box::new(Arc::new(authority)));
catalog
}

const LARGE_NAME: &str = "large.com";

fn large_name() -> Name {
n(LARGE_NAME)
}

pub fn n<S: AsRef<str>>(name: S) -> Name {
Name::from_str(name.as_ref()).unwrap()
}

0 comments on commit bf2d99c

Please sign in to comment.