Skip to content

Commit

Permalink
refactor config into dedicated config struct
Browse files Browse the repository at this point in the history
  • Loading branch information
dead10ck committed Apr 20, 2024
1 parent 4005e6a commit 3e24675
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 165 deletions.
63 changes: 27 additions & 36 deletions src/dns/client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{net::SocketAddr, pin::Pin, sync::Arc};
use std::{pin::Pin, sync::Arc};

use futures_util::{future, Stream, StreamExt};
use hickory_client::client::{AsyncClient, AsyncDnssecClient};
Expand All @@ -14,14 +14,11 @@ use hickory_proto::{
DnsHandle, DnsMultiplexer,
};
use hickory_resolver::config::Protocol;
use nu_plugin::EvaluatedCall;
use nu_protocol::{LabeledError, Span};
use rustls::{OwnedTrustAnchor, RootCertStore};
use tokio::{net::UdpSocket, task::JoinSet};

use crate::dns::{constants, serde};

use super::serde::DnssecMode;
use super::{config::Config, serde::DnssecMode};

type DnsHandleResponse =
Pin<Box<(dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send + 'static)>>;
Expand All @@ -43,28 +40,20 @@ type TokioTcpConnect = AsyncIoTokioAsStd<tokio::net::TcpStream>;

impl DnsClient {
pub async fn new(
addr: SocketAddr,
addr_span: Option<Span>,
protocol: Protocol,
call: &EvaluatedCall,
config: &Config,
) -> Result<(Self, JoinSet<Result<(), ProtoError>>), LabeledError> {
let connect_err = |err| {
LabeledError::new("connection error").with_label(
format!("Error creating client connection: {}", err),
addr_span.unwrap_or(Span::unknown()),
Span::unknown(),
)
};

let dnssec_mode = match call.get_flag_value(constants::flags::DNSSEC) {
Some(val) => serde::DnssecMode::try_from(val)?,
None => serde::DnssecMode::Opportunistic,
};

let mut join_set = JoinSet::new();

macro_rules! make_clients {
($conn:expr) => {{
let async_client = if dnssec_mode != DnssecMode::Strict {
let async_client = if config.dnssec_mode.item != DnssecMode::Strict {
let (async_client, bg) =
AsyncClient::connect($conn).await.map_err(connect_err)?;
join_set.spawn(bg);
Expand All @@ -73,7 +62,7 @@ impl DnsClient {
None
};

let dnssec_client = if dnssec_mode != DnssecMode::None {
let dnssec_client = if config.dnssec_mode.item != DnssecMode::None {
let (dnssec_client, bg) = AsyncDnssecClient::connect($conn)
.await
.map_err(connect_err)?;
Expand All @@ -86,13 +75,14 @@ impl DnsClient {
}};
}

let (async_client, dnssec_client) = match protocol {
let (async_client, dnssec_client) = match config.protocol.item {
Protocol::Udp => {
make_clients!(UdpClientStream::<UdpSocket>::new(addr))
make_clients!(UdpClientStream::<UdpSocket>::new(config.server.item))
}
Protocol::Tcp => {
make_clients!({
let (stream, sender) = TcpClientStream::<TokioTcpConnect>::new(addr);
let (stream, sender) =
TcpClientStream::<TokioTcpConnect>::new(config.server.item);
DnsMultiplexer::<_, NoopMessageFinalizer>::new(stream, sender, None)
})
}
Expand All @@ -111,44 +101,45 @@ impl DnsClient {
.with_root_certificates(root_store)
.with_no_client_auth();

let dns_name = call
.get_flag_value(constants::flags::DNS_NAME)
.ok_or_else(|| {
LabeledError::new("missing required argument")
.with_label("HTTPS requires a DNS name", call.head)
})?
.into_string()?;

match proto {
Protocol::Tls => {
let client_config = Arc::new(client_config);
make_clients!({
let (stream, sender) = hickory_proto::rustls::tls_client_connect::<
TokioTcpConnect,
>(
addr, dns_name.clone(), client_config.clone()
);
let (stream, sender) =
hickory_proto::rustls::tls_client_connect::<TokioTcpConnect>(
config.server.item,
// safe to unwrap because having a DNS name
// is enforced when constructing the config
config.dns_name.as_ref().unwrap().clone().item,
client_config.clone(),
);
DnsMultiplexer::<_, NoopMessageFinalizer>::new(stream, sender, None)
})
}
Protocol::Https => {
let client_config = Arc::new(client_config);
make_clients!({
HttpsClientStreamBuilder::with_client_config(client_config.clone())
.build::<TokioTcpConnect>(addr, dns_name.clone())
.build::<TokioTcpConnect>(
config.server.item,
config.dns_name.as_ref().unwrap().clone().item,
)
})
}
Protocol::Quic => make_clients!({
let mut builder = QuicClientStream::builder();
builder.crypto_config(client_config.clone());
builder.build(addr, dns_name.clone())
builder.build(
config.server.item,
config.dns_name.as_ref().unwrap().clone().item,
)
}),
_ => unreachable!(),
}
}
proto => {
return Err(LabeledError::new("unknown protocol")
.with_label(format!("Unknown protocol: {}", proto), call.head))
.with_label(format!("Unknown protocol: {}", proto), config.protocol.span))
}
};

Expand Down
90 changes: 22 additions & 68 deletions src/dns/mod.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,23 @@
use std::{
net::{IpAddr, SocketAddr},
str::FromStr,
sync::Arc,
};
use std::sync::Arc;

use futures_util::{stream::FuturesUnordered, Future, StreamExt};
use hickory_client::client::ClientHandle;

use nu_plugin::{EngineInterface, EvaluatedCall};
use nu_protocol::{LabeledError, ListStream, PipelineData, Span, Value};

use hickory_resolver::config::{Protocol, ResolverConfig};

use tokio::task::{JoinHandle, JoinSet};
use tracing_subscriber::prelude::*;

use self::{client::DnsClient, constants::flags, serde::Query};
use self::{client::DnsClient, config::Config, serde::Query};

mod client;
mod config;
mod constants;
mod nu;
mod serde;
#[macro_use]
mod util;

type DnsQueryJoinHandle = JoinHandle<Result<(), LabeledError>>;
type DnsQueryResult = FuturesUnordered<Result<Value, LabeledError>>;
Expand Down Expand Up @@ -48,7 +45,7 @@ impl Dns {
}
}

pub async fn dns_client(&self, call: &EvaluatedCall) -> Result<DnsClient, LabeledError> {
pub async fn dns_client(&self, config: &Config) -> Result<DnsClient, LabeledError> {
// we could use OnceLock once get_or_try_init is stable
if let Some((client, _)) = &*self.client.read().await {
return Ok(client.clone());
Expand All @@ -61,7 +58,7 @@ impl Dns {
match &mut *client_guard {
Some((client, _)) => Ok(client.clone()),
None => {
let (client, client_bg) = self.make_dns_client(call).await?;
let (client, client_bg) = self.make_dns_client(config).await?;
*client_guard = Some((client.clone(), client_bg));
Ok(client)
}
Expand All @@ -70,61 +67,16 @@ impl Dns {

async fn make_dns_client(
&self,
call: &EvaluatedCall,
config: &Config,
) -> Result<
(
DnsClient,
JoinSet<Result<(), hickory_proto::error::ProtoError>>,
),
LabeledError,
> {
let protocol = match call.get_flag_value(flags::PROTOCOL) {
None => None,
Some(val) => Some(serde::Protocol::try_from(val).map(|serde::Protocol(proto)| proto)?),
};

let (addr, addr_span, protocol) = match call.get_flag_value(flags::SERVER) {
Some(ref value @ Value::String { .. }) => {
let protocol = protocol.unwrap_or(Protocol::Udp);
let addr = SocketAddr::from_str(value.as_str().unwrap())
.or_else(|_| {
IpAddr::from_str(value.as_str().unwrap()).map(|ip| {
SocketAddr::new(ip, constants::config::default_port(protocol))
})
})
.map_err(|err| {
LabeledError::new("invalid server")
.with_label(err.to_string(), value.clone().span())
})?;

(addr, Some(value.span()), protocol)
}
None => {
let (config, _) =
hickory_resolver::system_conf::read_system_conf().unwrap_or_default();
tracing::debug!(?config);
match config.name_servers() {
[ns, ..] => (ns.socket_addr, None, ns.protocol),
[] => {
let config = ResolverConfig::default();
let ns = config.name_servers().first().unwrap();

// if protocol is explicitly configured, it should take
// precedence over the system config
(ns.socket_addr, None, protocol.unwrap_or(ns.protocol))
}
}
}
Some(val) => {
return Err(LabeledError::new("invalid server address")
.with_label("server address should be a string", val.span()));
}
};

let (client, bg) = DnsClient::new(addr, addr_span, protocol, call).await?;

tracing::debug!(client.addr = ?addr, client.protocol = ?protocol);

let (client, bg) = DnsClient::new(config).await?;
tracing::debug!(client.addr = ?config.server, client.protocol = ?config.protocol);
Ok((client, bg))
}

Expand Down Expand Up @@ -168,6 +120,9 @@ impl DnsQuery {
.with(tracing_subscriber::EnvFilter::from_default_env())
.try_init();

// [TODO]
// let config = engine.get_plugin_config()?.try_into()?;
let config = Config::try_from(call)?;
let arg_inputs: Value = call.nth(0).unwrap_or(Value::nothing(call.head));

let input: PipelineData = match input {
Expand All @@ -186,8 +141,8 @@ impl DnsQuery {
}
};

let client = plugin.dns_client(call).await?;
let call = Arc::new(call.clone());
let client = plugin.dns_client(&config).await?;
let config = Arc::new(config);

match input {
PipelineData::Value(val, _) => {
Expand All @@ -197,7 +152,7 @@ impl DnsQuery {
tracing::debug!(phase = "input", data.kind = "value");
}

let values = Self::query(call, val, client.clone()).await;
let values = Self::query(config, val, client.clone()).await;

let val = PipelineData::Value(
Value::list(
Expand All @@ -218,7 +173,6 @@ impl DnsQuery {
let (resp_tx, mut resp_rx) = tokio::sync::mpsc::channel(16);

plugin.spawn({
let call = call.clone();
let client = client.clone();

async move {
Expand All @@ -227,12 +181,12 @@ impl DnsQuery {
while let Some(val) = Box::pin(request_rx.recv()).await {
tracing::trace!(query = ?val, query.phase = "received");

let call = call.clone();
let config = config.clone();
let client = client.clone();
let resp_tx = resp_tx.clone();

let handle = tokio::spawn(async move {
let resps = Self::query(call, val, client).await;
let resps = Self::query(config, val, client).await;

for resp in resps.into_iter() {
resp_tx.send(resp).await.map_err(|send_err| {
Expand Down Expand Up @@ -311,9 +265,9 @@ impl DnsQuery {
Ok(())
}

async fn query(call: Arc<EvaluatedCall>, input: Value, client: DnsClient) -> DnsQueryResult {
async fn query(config: Arc<Config>, input: Value, client: DnsClient) -> DnsQueryResult {
let in_span = input.span();
let queries = match Query::try_from_value(&input, &call) {
let queries = match Query::try_from_value(&input, &config) {
Ok(queries) => queries,
Err(err) => {
return vec![Ok(Value::error(err.into(), in_span))]
Expand All @@ -327,7 +281,7 @@ impl DnsQuery {
futures_util::stream::iter(queries)
.then(|query| {
let mut client = client.clone();
let call = call.clone();
let config = config.clone();

async move {
let parts = query.0.into_parts();
Expand All @@ -347,7 +301,7 @@ impl DnsQuery {
})
.and_then(|resp: hickory_proto::xfer::DnsResponse| {
let msg = serde::Message::new(resp.into_message());
msg.into_value(&call)
msg.into_value(&config)
})
.inspect_err(
|err| tracing::debug!(query.phase = "finish", query.error = ?err),
Expand Down
Loading

0 comments on commit 3e24675

Please sign in to comment.