Skip to content

Commit

Permalink
馃帹 Better dual stack selection
Browse files Browse the repository at this point in the history
  • Loading branch information
mokeyish committed May 26, 2024
1 parent f4ddd8b commit baf469c
Show file tree
Hide file tree
Showing 10 changed files with 335 additions and 136 deletions.
4 changes: 2 additions & 2 deletions src/config/domain_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub struct DomainRule {
pub cname: Option<CName>,

/// The mode of speed checking.
pub speed_check_mode: SpeedCheckModeList,
pub speed_check_mode: Option<SpeedCheckModeList>,

pub dualstack_ip_selection: Option<bool>,

Expand All @@ -36,7 +36,7 @@ impl std::ops::AddAssign for DomainRule {
self.address = rhs.address;
}

if !rhs.speed_check_mode.is_empty() {
if rhs.speed_check_mode.is_some() {
self.speed_check_mode = rhs.speed_check_mode;
}
if rhs.dualstack_ip_selection.is_some() {
Expand Down
2 changes: 1 addition & 1 deletion src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ pub struct Config {
/// speed-check-mode tcp:443,ping
/// speed-check-mode none
/// ```
pub speed_check_mode: SpeedCheckModeList,
pub speed_check_mode: Option<SpeedCheckModeList>,

/// force AAAA query return SOA
///
Expand Down
10 changes: 7 additions & 3 deletions src/config/parser/domain_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ impl NomParser for DomainRule {
SpeedCheckModeList::parse,
),
|v| {
rule.speed_check_mode.extend(v.0);
rule.speed_check_mode
.get_or_insert_with(|| SpeedCheckModeList(vec![]))
.extend(v.0);
},
),
map(
Expand Down Expand Up @@ -62,7 +64,7 @@ mod tests {
Ok((
"",
DomainRule {
speed_check_mode: vec![SpeedCheckMode::Ping].into(),
speed_check_mode: Some(vec![SpeedCheckMode::Ping].into()),
..Default::default()
}
))
Expand All @@ -72,7 +74,9 @@ mod tests {
Ok((
"",
DomainRule {
speed_check_mode: vec![SpeedCheckMode::Ping, SpeedCheckMode::Tcp(53)].into(),
speed_check_mode: Some(
vec![SpeedCheckMode::Ping, SpeedCheckMode::Tcp(53)].into()
),
..Default::default()
}
))
Expand Down
2 changes: 1 addition & 1 deletion src/config/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ pub enum OneConfig {
RrTtlMin(u64),
RrTtlMax(u64),
RrTtlReplyMax(u64),
SpeedMode(SpeedCheckModeList),
SpeedMode(Option<SpeedCheckModeList>),
TcpIdleTime(u64),
WhitelistIp(IpNet),
User(String),
Expand Down
20 changes: 13 additions & 7 deletions src/config/parser/speed_mode.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
use super::*;

impl NomParser for SpeedCheckModeList {
impl NomParser for Option<SpeedCheckModeList> {
fn parse(input: &str) -> IResult<&str, Self> {
alt((
value(Default::default(), tag_no_case("none")),
map(
separated_list1(delimited(space0, char(','), space0), NomParser::parse),
SpeedCheckModeList,
),
value(None, tag_no_case("none")),
map(SpeedCheckModeList::parse, Some),
))(input)
}
}

impl NomParser for SpeedCheckModeList {
fn parse(input: &str) -> IResult<&str, Self> {
map(
separated_list1(delimited(space0, char(','), space0), NomParser::parse),
SpeedCheckModeList,
)(input)
}
}

impl NomParser for SpeedCheckMode {
fn parse(input: &str) -> IResult<&str, Self> {
use SpeedCheckMode::*;
Expand Down Expand Up @@ -69,7 +75,7 @@ mod tests {
#[test]
fn test_speed_mode_none() {
assert_eq!(
SpeedCheckModeList::parse("none"),
Option::<SpeedCheckModeList>::parse("none"),
Ok(("", Default::default()))
);
}
Expand Down
58 changes: 54 additions & 4 deletions src/config/speed_mode.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,54 @@
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)]
use std::net::{IpAddr, SocketAddr};

use crate::infra::ping::PingAddr;

#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub enum SpeedCheckMode {
#[default]
None,
Ping,
Tcp(u16),
Http(u16),
Https(u16),
}

#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
impl SpeedCheckMode {
pub fn to_ping_addr(self, ip_addr: IpAddr) -> PingAddr {
match self {
SpeedCheckMode::Ping => PingAddr::Icmp(ip_addr),
SpeedCheckMode::Tcp(port) => PingAddr::Tcp(SocketAddr::new(ip_addr, port)),
SpeedCheckMode::Http(port) => PingAddr::Http(SocketAddr::new(ip_addr, port)),
SpeedCheckMode::Https(port) => PingAddr::Https(SocketAddr::new(ip_addr, port)),
}
}

pub fn to_ping_addrs(self, ip_addrs: &[IpAddr]) -> Vec<PingAddr> {
ip_addrs.iter().map(|ip| self.to_ping_addr(*ip)).collect()
}
}

impl std::fmt::Debug for SpeedCheckMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SpeedCheckMode::Ping => write!(f, "ICMP"),
SpeedCheckMode::Tcp(port) => write!(f, "TCP:{port}"),
SpeedCheckMode::Http(port) => {
if *port == 80 {
write!(f, "HTTP")
} else {
write!(f, "HTTP:{port}")
}
}
SpeedCheckMode::Https(port) => {
if *port == 443 {
write!(f, "HTTPS")
} else {
write!(f, "HTTPS:{port}")
}
}
}
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SpeedCheckModeList(pub Vec<SpeedCheckMode>);

impl SpeedCheckModeList {
Expand Down Expand Up @@ -45,3 +85,13 @@ impl std::ops::DerefMut for SpeedCheckModeList {
&mut self.0
}
}

impl std::default::Default for SpeedCheckModeList {
fn default() -> Self {
Self(vec![
SpeedCheckMode::Ping,
SpeedCheckMode::Http(80),
SpeedCheckMode::Https(443),
])
}
}
2 changes: 2 additions & 0 deletions src/dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use std::fmt::Debug;

use std::net::IpAddr;
use std::{str::FromStr, sync::Arc, time::Duration};

use crate::dns_error::LookupError;
Expand Down Expand Up @@ -535,6 +536,7 @@ mod response {
pub type DnsRequest = request::DnsRequest;
pub type DnsResponse = response::DnsResponse;
pub type DnsError = LookupError;
use ipnet::IpAdd;
pub use serial_message::SerialMessage;

#[derive(Debug, Clone, Copy, Default)]
Expand Down
27 changes: 15 additions & 12 deletions src/dns_conf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ impl RuntimeConfig {

/// speed check mode
#[inline]
pub fn speed_check_mode(&self) -> &SpeedCheckModeList {
&self.speed_check_mode
pub fn speed_check_mode(&self) -> Option<&SpeedCheckModeList> {
self.speed_check_mode.as_ref()
}

/// force AAAA query return SOA
Expand Down Expand Up @@ -758,7 +758,7 @@ impl RuntimeConfigBuilder {
ServerName(v) => self.server_name = Some(v),
NumWorkers(v) => self.num_workers = Some(v),
Domain(v) => self.domain = Some(v),
SpeedMode(v) => self.speed_check_mode.extend(v.0),
SpeedMode(v) => self.speed_check_mode = v,
ServeExpiredTtl(v) => self.cache.serve_expired_ttl = Some(v),
ServeExpiredReplyTtl(v) => self.cache.serve_expired_reply_ttl = Some(v),
CacheSize(v) => self.cache.size = Some(v),
Expand Down Expand Up @@ -1227,7 +1227,7 @@ mod tests {
);
assert_eq!(
domain_rule.speed_check_mode,
vec![SpeedCheckMode::Ping].into()
Some(vec![SpeedCheckMode::Ping].into())
);
assert_eq!(domain_rule.nameserver, Some("test".to_string()));
assert_eq!(domain_rule.dualstack_ip_selection, Some(true));
Expand All @@ -1248,7 +1248,7 @@ mod tests {
);
assert_eq!(
domain_rule.speed_check_mode,
vec![SpeedCheckMode::Ping].into()
Some(vec![SpeedCheckMode::Ping].into())
);
assert_eq!(domain_rule.nameserver, Some("test".to_string()));
assert_eq!(domain_rule.dualstack_ip_selection, Some(true));
Expand All @@ -1266,7 +1266,7 @@ mod tests {
assert_eq!(domain_rule.address, Some(DomainAddress::SOA));
assert_eq!(
domain_rule.speed_check_mode,
vec![SpeedCheckMode::Ping].into()
Some(vec![SpeedCheckMode::Ping].into())
);
assert_eq!(domain_rule.nameserver, Some("test".to_string()));
assert_eq!(domain_rule.dualstack_ip_selection, Some(true));
Expand All @@ -1287,11 +1287,14 @@ mod tests {
let mut cfg = RuntimeConfig::builder();
cfg.config("speed-check-mode ping,tcp:123");

assert_eq!(cfg.speed_check_mode.len(), 2);
assert_eq!(cfg.speed_check_mode.as_ref().unwrap().len(), 2);

assert_eq!(cfg.speed_check_mode.first().unwrap(), &SpeedCheckMode::Ping);
assert_eq!(
cfg.speed_check_mode.get(1).unwrap(),
cfg.speed_check_mode.as_ref().unwrap().first().unwrap(),
&SpeedCheckMode::Ping
);
assert_eq!(
cfg.speed_check_mode.as_ref().unwrap().get(1).unwrap(),
&SpeedCheckMode::Tcp(123)
);
}
Expand All @@ -1301,14 +1304,14 @@ mod tests {
let mut cfg = RuntimeConfig::builder();
cfg.config("speed-check-mode http,https");

assert_eq!(cfg.speed_check_mode.len(), 2);
assert_eq!(cfg.speed_check_mode.as_ref().unwrap().len(), 2);

assert_eq!(
cfg.speed_check_mode.first().unwrap(),
cfg.speed_check_mode.as_ref().unwrap().first().unwrap(),
&SpeedCheckMode::Http(80)
);
assert_eq!(
cfg.speed_check_mode.get(1).unwrap(),
cfg.speed_check_mode.as_ref().unwrap().get(1).unwrap(),
&SpeedCheckMode::Https(443)
);
}
Expand Down
Loading

0 comments on commit baf469c

Please sign in to comment.