Skip to content

Commit

Permalink
add unit tests for dns updates
Browse files Browse the repository at this point in the history
  • Loading branch information
conectado committed Mar 27, 2024
1 parent f175a02 commit 5334184
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 69 deletions.
4 changes: 2 additions & 2 deletions rust/connlib/shared/src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ pub struct ResourceDescriptionCidr {
pub name: String,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq, Hash)]
#[serde(tag = "protocol", rename_all = "snake_case")]
pub enum DnsServer {
IpPort(IpDnsServer),
Expand Down Expand Up @@ -328,7 +328,7 @@ where
}
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq, Hash)]
pub struct IpDnsServer {
pub address: SocketAddr,
}
Expand Down
255 changes: 188 additions & 67 deletions rust/connlib/tunnel/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,15 @@ where

/// Updates the system's dns
pub fn set_dns(&mut self, new_dns: Vec<IpAddr>) -> connlib_shared::Result<()> {
// We store the sentinel dns both in the config and in the system's resolvers
// but once calculated the dns mapping those are ignored.
self.role_state.update_system_resolvers(new_dns);

let dns_changed = self.update_dns_mapping();
if !dns_changed {
return Ok(());
}

self.update_interface()?;

Ok(())
Expand Down Expand Up @@ -274,33 +278,12 @@ where
}

fn update_dns_mapping(&mut self) -> bool {
let Some(config) = &self.role_state.interface_config else {
return false;
};

let effective_dns_servers = effective_dns_servers(
config.upstream_dns.clone(),
self.role_state.system_resolvers.clone(),
);

if HashSet::<&DnsServer>::from_iter(effective_dns_servers.iter())
== HashSet::from_iter(self.role_state.dns_mapping.right_values())
{
if !self.role_state.update_dns_mapping() {
return false;
}

let dns_mapping = sentinel_dns_mapping(
&effective_dns_servers,
self.role_state
.dns_mapping()
.left_values()
.copied()
.map(Into::into)
.collect_vec(),
);

self.role_state.set_dns_mapping(dns_mapping.clone());
self.io.set_upstream_dns_servers(dns_mapping.clone());
self.io
.set_upstream_dns_servers(self.role_state.dns_mapping());

tracing::info!("Setting new DNS resolvers");
true
Expand Down Expand Up @@ -932,38 +915,71 @@ impl ClientState {
pub(crate) fn poll_transmit(&mut self) -> Option<snownet::Transmit<'_>> {
self.node.poll_transmit()
}
}

fn dns_updated(old_dns: &[IpAddr], new_dns: &[IpAddr]) -> bool {
HashSet::<&IpAddr>::from_iter(old_dns.iter()) != HashSet::<&IpAddr>::from_iter(new_dns.iter())
fn update_dns_mapping(&mut self) -> bool {
let Some(config) = &self.interface_config else {
return false;
};

let effective_dns_servers =
effective_dns_servers(config.upstream_dns.clone(), self.system_resolvers.clone());

if HashSet::<&DnsServer>::from_iter(effective_dns_servers.iter())
== HashSet::from_iter(self.dns_mapping.right_values())
{
return false;
}

let dns_mapping = sentinel_dns_mapping(
&effective_dns_servers,
self.dns_mapping()
.left_values()
.copied()
.map(Into::into)
.collect_vec(),
);

self.set_dns_mapping(dns_mapping.clone());

true
}
}

fn effective_dns_servers(
upstream_dns: Vec<DnsServer>,
default_resolvers: Vec<IpAddr>,
) -> Vec<DnsServer> {
if !upstream_dns.is_empty() {
return upstream_dns;
let mut upstream_dns = upstream_dns.into_iter().filter_map(not_sentinel).peekable();
if upstream_dns.peek().is_some() {
return upstream_dns.collect();
}

let mut dns_servers = default_resolvers
.into_iter()
.filter(|ip| !IpNetwork::from_str(DNS_SENTINELS_V4).unwrap().contains(*ip))
.filter(|ip| !IpNetwork::from_str(DNS_SENTINELS_V6).unwrap().contains(*ip))
.map(|ip| {
DnsServer::IpPort(IpDnsServer {
address: (ip, DNS_PORT).into(),
})
})
.filter_map(not_sentinel)
.peekable();

if dns_servers.peek().is_none() {
tracing::error!("No system default DNS servers available! Can't initialize resolver. DNS interception will be disabled.");
return Vec::new();
}

dns_servers
.map(|ip| {
DnsServer::IpPort(IpDnsServer {
address: (ip, DNS_PORT).into(),
})
})
.collect()
dns_servers.collect()
}

fn not_sentinel(srv: DnsServer) -> Option<DnsServer> {
(!IpNetwork::from_str(DNS_SENTINELS_V4)
.unwrap()
.contains(srv.ip())
&& !IpNetwork::from_str(DNS_SENTINELS_V6)
.unwrap()
.contains(srv.ip()))
.then_some(srv)
}

fn sentinel_dns_mapping(
Expand Down Expand Up @@ -1082,50 +1098,132 @@ mod tests {
}

#[test]
fn dns_updated_when_dns_changes() {
assert!(dns_updated(&[ip("1.0.0.1")], &[ip("1.1.1.1")]))
fn update_system_dns_works() {
let mut client_state = ClientState::for_test();
client_state.interface_config = Some(interface_config_without_dns());

client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
assert!(client_state.update_dns_mapping());
dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.1.1.1:53")]);
}

#[test]
fn dns_not_updated_when_dns_remains_the_same() {
assert!(!dns_updated(&[ip("1.1.1.1")], &[ip("1.1.1.1")]))
fn update_system_dns_without_change_is_a_no_op() {
let mut client_state = ClientState::for_test();
client_state.interface_config = Some(interface_config_without_dns());

client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
client_state.update_dns_mapping();

client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
assert!(!client_state.update_dns_mapping());
dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.1.1.1:53")]);
}

#[test]
fn dns_updated_ignores_order() {
assert!(!dns_updated(
&[ip("1.0.0.1"), ip("1.1.1.1")],
&[ip("1.1.1.1"), ip("1.0.0.1")]
))
fn update_system_dns_with_change_works() {
let mut client_state = ClientState::for_test();
client_state.interface_config = Some(interface_config_without_dns());

client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
client_state.update_dns_mapping();

client_state.update_system_resolvers(vec![ip("1.0.0.1")]);
assert!(client_state.update_dns_mapping());
dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.0.0.1:53")]);
}

#[test]
fn update_system_dns_works() {
fn update_to_system_with_sentinels_are_ignored() {
let mut client_state = ClientState::for_test();
client_state.interface_config = Some(interface_config_without_dns());

let changed = client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
client_state.update_dns_mapping();

assert!(changed);
client_state.update_system_resolvers(vec![
ip("1.1.1.1"),
ip("100.100.111.1"),
ip("fd00:2021:1111:8000:100:100:111:0"),
]);
assert!(!client_state.update_dns_mapping());
dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.1.1.1:53")]);
}

#[test]
fn update_system_dns_without_change_is_a_no_op() {
fn upstream_dns_wins_over_system() {
let mut client_state = ClientState::for_test();

client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
let changed = client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
client_state.interface_config = Some(interface_config_with_dns());

assert!(client_state.update_dns_mapping());

assert!(!changed)
client_state.update_system_resolvers(vec![ip("1.0.0.1")]);
assert!(!client_state.update_dns_mapping());
dns_mapping_is_exactly(client_state.dns_mapping(), dns_list());
}

#[test]
fn update_system_dns_with_change_works() {
fn upstream_dns_change_updates() {
let mut client_state = ClientState::for_test();

client_state.update_system_resolvers(vec![ip("1.1.1.1")]);
let changed = client_state.update_system_resolvers(vec![ip("1.0.0.1")]);
client_state.interface_config = Some(interface_config_with_dns());

assert!(client_state.update_dns_mapping());

let mut new_config = interface_config_without_dns();
new_config.upstream_dns = vec![dns("8.8.8.8:53")];
client_state.interface_config = Some(new_config);

assert!(client_state.update_dns_mapping());
dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("8.8.8.8:53")]);
}

#[test]
fn upstream_dns_no_change_is_a_no_op() {
let mut client_state = ClientState::for_test();

client_state.interface_config = Some(interface_config_with_dns());
client_state.update_system_resolvers(vec![ip("1.0.0.1")]);

assert!(client_state.update_dns_mapping());

client_state.interface_config = Some(interface_config_with_dns());
assert!(!client_state.update_dns_mapping());
dns_mapping_is_exactly(client_state.dns_mapping(), dns_list());
}

#[test]
fn upstream_dns_sentinels_are_ignored() {
let mut client_state = ClientState::for_test();

let mut config = interface_config_with_dns();
client_state.interface_config = Some(config.clone());

client_state.update_dns_mapping();

assert!(changed)
config.upstream_dns.push(dns("100.100.111.1:53"));
config
.upstream_dns
.push(dns("[fd00:2021:1111:8000:100:100:111:0]:53"));
client_state.interface_config = Some(config);
assert!(!client_state.update_dns_mapping());
dns_mapping_is_exactly(client_state.dns_mapping(), dns_list())
}

#[test]
fn system_dns_takes_over_when_upstream_are_unset() {
let mut client_state = ClientState::for_test();

client_state.interface_config = Some(interface_config_with_dns());
client_state.update_dns_mapping();

client_state.update_system_resolvers(vec![ip("1.0.0.1")]);
client_state.update_dns_mapping();

client_state.interface_config = Some(interface_config_without_dns());
assert!(client_state.update_dns_mapping());
dns_mapping_is_exactly(client_state.dns_mapping(), vec![dns("1.0.0.1:53")]);
}

#[test]
Expand Down Expand Up @@ -1165,6 +1263,29 @@ mod tests {
}
}

fn dns_mapping_is_exactly(mapping: BiMap<IpAddr, DnsServer>, servers: Vec<DnsServer>) {
assert_eq!(
HashSet::<&DnsServer>::from_iter(mapping.right_values()),
HashSet::from_iter(servers.iter())
)
}

fn interface_config_without_dns() -> InterfaceConfig {
InterfaceConfig {
ipv4: "10.0.0.1".parse().unwrap(),
ipv6: "fe80::".parse().unwrap(),
upstream_dns: Vec::new(),
}
}

fn interface_config_with_dns() -> InterfaceConfig {
InterfaceConfig {
ipv4: "10.0.0.1".parse().unwrap(),
ipv6: "fe80::".parse().unwrap(),
upstream_dns: dns_list(),
}
}

fn sentinel_ranges() -> Vec<IpNetwork> {
vec![
IpNetwork::from_str(DNS_SENTINELS_V4).unwrap(),
Expand All @@ -1174,18 +1295,18 @@ mod tests {

fn dns_list() -> Vec<DnsServer> {
vec![
DnsServer::IpPort(IpDnsServer {
address: "1.1.1.1:53".parse().unwrap(),
}),
DnsServer::IpPort(IpDnsServer {
address: "1.0.0.1:53".parse().unwrap(),
}),
DnsServer::IpPort(IpDnsServer {
address: "[2606:4700:4700::1111]:53".parse().unwrap(),
}),
dns("1.1.1.1:53"),
dns("1.0.0.1:53"),
dns("[2606:4700:4700::1111]:53"),
]
}

fn dns(address: &str) -> DnsServer {
DnsServer::IpPort(IpDnsServer {
address: address.parse().unwrap(),
})
}

fn ip(addr: &str) -> IpAddr {
addr.parse().unwrap()
}
Expand Down

0 comments on commit 5334184

Please sign in to comment.