Skip to content

Commit

Permalink
server: keep MessageResponse iterators generic
Browse files Browse the repository at this point in the history
  • Loading branch information
djc committed Jan 19, 2022
1 parent a30e694 commit d7aa80f
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 31 deletions.
12 changes: 10 additions & 2 deletions crates/server/src/authority/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::{borrow::Borrow, collections::HashMap, future::Future, io};

use cfg_if::cfg_if;
use log::{debug, error, info, trace, warn};
use trust_dns_proto::rr::Record;

#[cfg(feature = "dnssec")]
use crate::client::rr::{
Expand All @@ -37,9 +38,16 @@ pub struct Catalog {
}

#[allow(unused_mut, unused_variables)]
async fn send_response<R: ResponseHandler>(
async fn send_response<'a, R: ResponseHandler>(
response_edns: Option<Edns>,
mut response: MessageResponse<'_, '_>,
mut response: MessageResponse<
'_,
'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
>,
mut response_handle: R,
) -> io::Result<ResponseInfo> {
#[cfg(feature = "dnssec")]
Expand Down
49 changes: 30 additions & 19 deletions crates/server/src/authority/message_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,19 @@ use super::message_request::WireQuery;

/// A EncodableMessage with borrowed data for Responses in the Server
#[derive(Debug)]
pub struct MessageResponse<
'q,
'a,
A = Box<dyn Iterator<Item = &'a Record> + Send + 'a>,
N = Box<dyn Iterator<Item = &'a Record> + Send + 'a>,
S = Box<dyn Iterator<Item = &'a Record> + Send + 'a>,
D = Box<dyn Iterator<Item = &'a Record> + Send + 'a>,
> where
A: Iterator<Item = &'a Record> + Send + 'a,
N: Iterator<Item = &'a Record> + Send + 'a,
S: Iterator<Item = &'a Record> + Send + 'a,
D: Iterator<Item = &'a Record> + Send + 'a,
pub struct MessageResponse<'q, 'a, Answers, NameServers, Soa, Additionals>
where
Answers: Iterator<Item = &'a Record> + Send + 'a,
NameServers: Iterator<Item = &'a Record> + Send + 'a,
Soa: Iterator<Item = &'a Record> + Send + 'a,
Additionals: Iterator<Item = &'a Record> + Send + 'a,
{
header: Header,
query: Option<&'q WireQuery>,
answers: A,
name_servers: N,
soa: S,
additionals: D,
answers: Answers,
name_servers: NameServers,
soa: Soa,
additionals: Additionals,
sig0: Vec<Record>,
edns: Option<Edns>,
}
Expand Down Expand Up @@ -188,7 +182,17 @@ impl<'q> MessageResponseBuilder<'q> {
}

/// Construct a Response with no associated records
pub fn build_no_records(self, header: Header) -> MessageResponse<'q, 'static> {
pub fn build_no_records<'a>(
self,
header: Header,
) -> MessageResponse<
'q,
'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
> {
MessageResponse {
header,
query: self.query,
Expand All @@ -208,11 +212,18 @@ impl<'q> MessageResponseBuilder<'q> {
/// * `id` - request id to which this is a response
/// * `op_code` - operation for which this is a response
/// * `response_code` - the type of error
pub fn error_msg(
pub fn error_msg<'a>(
self,
request_header: &Header,
response_code: ResponseCode,
) -> MessageResponse<'q, 'static> {
) -> MessageResponse<
'q,
'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
> {
let mut header = Header::response_from_request(request_header);
header.set_response_code(response_code);

Expand Down
12 changes: 10 additions & 2 deletions crates/server/src/server/https_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use futures_util::lock::Mutex;
use h2::server;
use log::{debug, warn};
use tokio::io::{AsyncRead, AsyncWrite};
use trust_dns_proto::rr::Record;

use crate::{
authority::MessageResponse,
Expand Down Expand Up @@ -83,9 +84,16 @@ struct HttpsResponseHandle(Arc<Mutex<::h2::server::SendResponse<Bytes>>>);

#[async_trait::async_trait]
impl ResponseHandler for HttpsResponseHandle {
async fn send_response(
async fn send_response<'a>(
&mut self,
response: MessageResponse<'_, '_>,
response: MessageResponse<
'_,
'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
>,
) -> io::Result<ResponseInfo> {
use crate::proto::https::response;
use crate::proto::https::HttpsError;
Expand Down
23 changes: 19 additions & 4 deletions crates/server/src/server/response_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
use std::{io, net::SocketAddr};

use log::debug;
use trust_dns_proto::rr::Record;

use crate::{
authority::MessageResponse,
Expand All @@ -25,9 +26,16 @@ pub trait ResponseHandler: Clone + Send + Sync + Unpin + 'static {
/// Serializes and sends a message to to the wrapped handle
///
/// self is consumed as only one message should ever be sent in response to a Request
async fn send_response(
async fn send_response<'a>(
&mut self,
response: MessageResponse<'_, '_>,
response: MessageResponse<
'_,
'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
>,
) -> io::Result<ResponseInfo>;
}

Expand All @@ -51,9 +59,16 @@ impl ResponseHandler for ResponseHandle {
/// Serializes and sends a message to to the wrapped handle
///
/// self is consumed as only one message should ever be sent in response to a Request
async fn send_response(
async fn send_response<'a>(
&mut self,
response: MessageResponse<'_, '_>,
response: MessageResponse<
'_,
'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
>,
) -> io::Result<ResponseInfo> {
debug!(
"response: {} response_code: {}",
Expand Down
12 changes: 10 additions & 2 deletions crates/server/src/server/server_future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use log::{debug, info, warn};
#[cfg(feature = "dns-over-rustls")]
use rustls::{Certificate, PrivateKey};
use tokio::{net, task::JoinHandle};
use trust_dns_proto::rr::Record;

#[cfg(all(feature = "dns-over-openssl", not(feature = "dns-over-rustls")))]
use crate::proto::openssl::tls_server::*;
Expand Down Expand Up @@ -609,9 +610,16 @@ struct ReportingResponseHandler<R: ResponseHandler> {

#[async_trait::async_trait]
impl<R: ResponseHandler> ResponseHandler for ReportingResponseHandler<R> {
async fn send_response(
async fn send_response<'a>(
&mut self,
response: crate::authority::MessageResponse<'_, '_>,
response: crate::authority::MessageResponse<
'_,
'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
>,
) -> io::Result<super::ResponseInfo> {
let response_info = self.handler.send_response(response).await?;

Expand Down
12 changes: 10 additions & 2 deletions tests/integration-tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use trust_dns_client::{
};
use trust_dns_proto::{
error::ProtoError,
rr::Record,
xfer::{DnsClientStream, DnsMultiplexer, DnsMultiplexerConnect, SerialMessage, StreamReceiver},
BufDnsStreamHandle, TokioTime,
};
Expand Down Expand Up @@ -105,9 +106,16 @@ impl TestResponseHandler {

#[async_trait::async_trait]
impl ResponseHandler for TestResponseHandler {
async fn send_response(
async fn send_response<'a>(
&mut self,
response: MessageResponse<'_, '_>,
response: MessageResponse<
'_,
'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
impl Iterator<Item = &'a Record> + Send + 'a,
>,
) -> io::Result<ResponseInfo> {
let buf = &mut self.buf.lock().unwrap();
buf.clear();
Expand Down

0 comments on commit d7aa80f

Please sign in to comment.