From 2603c6032c2fc4756487faa0f23d65f0d70a4f52 Mon Sep 17 00:00:00 2001 From: YISH Date: Sun, 26 May 2024 10:20:14 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20Better=20dual=20stack=20selectio?= =?UTF-8?q?n?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/config/domain_rule.rs | 4 +- src/config/mod.rs | 2 +- src/config/parser/domain_rule.rs | 10 +- src/config/parser/mod.rs | 2 +- src/config/parser/speed_mode.rs | 20 ++- src/config/speed_mode.rs | 58 ++++++- src/dns.rs | 2 + src/dns_conf.rs | 27 ++-- src/dns_mw_dualstack.rs | 256 +++++++++++++++++++++++++------ src/dns_mw_ns.rs | 90 ++++------- 10 files changed, 335 insertions(+), 136 deletions(-) diff --git a/src/config/domain_rule.rs b/src/config/domain_rule.rs index a8c226bc..62dc186d 100644 --- a/src/config/domain_rule.rs +++ b/src/config/domain_rule.rs @@ -11,7 +11,7 @@ pub struct DomainRule { pub cname: Option, /// The mode of speed checking. - pub speed_check_mode: SpeedCheckModeList, + pub speed_check_mode: Option, pub dualstack_ip_selection: Option, @@ -36,7 +36,7 @@ impl std::ops::AddAssign for DomainRule { self.address = rhs.address; } - if !rhs.speed_check_mode.is_empty() { + if rhs.speed_check_mode.is_some() { self.speed_check_mode = rhs.speed_check_mode; } if rhs.dualstack_ip_selection.is_some() { diff --git a/src/config/mod.rs b/src/config/mod.rs index 6ab5c668..281d1708 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -130,7 +130,7 @@ pub struct Config { /// speed-check-mode tcp:443,ping /// speed-check-mode none /// ``` - pub speed_check_mode: SpeedCheckModeList, + pub speed_check_mode: Option, /// force AAAA query return SOA /// diff --git a/src/config/parser/domain_rule.rs b/src/config/parser/domain_rule.rs index 5b663658..b74c063a 100644 --- a/src/config/parser/domain_rule.rs +++ b/src/config/parser/domain_rule.rs @@ -12,7 +12,9 @@ impl NomParser for DomainRule { SpeedCheckModeList::parse, ), |v| { - rule.speed_check_mode.extend(v.0); + rule.speed_check_mode + .get_or_insert_with(|| SpeedCheckModeList(vec![])) + .extend(v.0); }, ), map( @@ -62,7 +64,7 @@ mod tests { Ok(( "", DomainRule { - speed_check_mode: vec![SpeedCheckMode::Ping].into(), + speed_check_mode: Some(vec![SpeedCheckMode::Ping].into()), ..Default::default() } )) @@ -72,7 +74,9 @@ mod tests { Ok(( "", DomainRule { - speed_check_mode: vec![SpeedCheckMode::Ping, SpeedCheckMode::Tcp(53)].into(), + speed_check_mode: Some( + vec![SpeedCheckMode::Ping, SpeedCheckMode::Tcp(53)].into() + ), ..Default::default() } )) diff --git a/src/config/parser/mod.rs b/src/config/parser/mod.rs index 6baa18f9..5db51026 100644 --- a/src/config/parser/mod.rs +++ b/src/config/parser/mod.rs @@ -128,7 +128,7 @@ pub enum OneConfig { RrTtlMin(u64), RrTtlMax(u64), RrTtlReplyMax(u64), - SpeedMode(SpeedCheckModeList), + SpeedMode(Option), TcpIdleTime(u64), WhitelistIp(IpNet), User(String), diff --git a/src/config/parser/speed_mode.rs b/src/config/parser/speed_mode.rs index c4a5ba3f..c773f0e1 100644 --- a/src/config/parser/speed_mode.rs +++ b/src/config/parser/speed_mode.rs @@ -1,17 +1,23 @@ use super::*; -impl NomParser for SpeedCheckModeList { +impl NomParser for Option { fn parse(input: &str) -> IResult<&str, Self> { alt(( - value(Default::default(), tag_no_case("none")), - map( - separated_list1(delimited(space0, char(','), space0), NomParser::parse), - SpeedCheckModeList, - ), + value(None, tag_no_case("none")), + map(SpeedCheckModeList::parse, Some), ))(input) } } +impl NomParser for SpeedCheckModeList { + fn parse(input: &str) -> IResult<&str, Self> { + map( + separated_list1(delimited(space0, char(','), space0), NomParser::parse), + SpeedCheckModeList, + )(input) + } +} + impl NomParser for SpeedCheckMode { fn parse(input: &str) -> IResult<&str, Self> { use SpeedCheckMode::*; @@ -69,7 +75,7 @@ mod tests { #[test] fn test_speed_mode_none() { assert_eq!( - SpeedCheckModeList::parse("none"), + Option::::parse("none"), Ok(("", Default::default())) ); } diff --git a/src/config/speed_mode.rs b/src/config/speed_mode.rs index 61fdb543..47c0f0d6 100644 --- a/src/config/speed_mode.rs +++ b/src/config/speed_mode.rs @@ -1,14 +1,54 @@ -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)] +use std::net::{IpAddr, SocketAddr}; + +use crate::infra::ping::PingAddr; + +#[derive(Copy, Clone, PartialEq, Eq, Hash)] pub enum SpeedCheckMode { - #[default] - None, Ping, Tcp(u16), Http(u16), Https(u16), } -#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +impl SpeedCheckMode { + pub fn to_ping_addr(self, ip_addr: IpAddr) -> PingAddr { + match self { + SpeedCheckMode::Ping => PingAddr::Icmp(ip_addr), + SpeedCheckMode::Tcp(port) => PingAddr::Tcp(SocketAddr::new(ip_addr, port)), + SpeedCheckMode::Http(port) => PingAddr::Http(SocketAddr::new(ip_addr, port)), + SpeedCheckMode::Https(port) => PingAddr::Https(SocketAddr::new(ip_addr, port)), + } + } + + pub fn to_ping_addrs(self, ip_addrs: &[IpAddr]) -> Vec { + ip_addrs.iter().map(|ip| self.to_ping_addr(*ip)).collect() + } +} + +impl std::fmt::Debug for SpeedCheckMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SpeedCheckMode::Ping => write!(f, "ICMP"), + SpeedCheckMode::Tcp(port) => write!(f, "TCP:{port}"), + SpeedCheckMode::Http(port) => { + if *port == 80 { + write!(f, "HTTP") + } else { + write!(f, "HTTP:{port}") + } + } + SpeedCheckMode::Https(port) => { + if *port == 443 { + write!(f, "HTTPS") + } else { + write!(f, "HTTPS:{port}") + } + } + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct SpeedCheckModeList(pub Vec); impl SpeedCheckModeList { @@ -45,3 +85,13 @@ impl std::ops::DerefMut for SpeedCheckModeList { &mut self.0 } } + +impl std::default::Default for SpeedCheckModeList { + fn default() -> Self { + Self(vec![ + SpeedCheckMode::Ping, + SpeedCheckMode::Http(80), + SpeedCheckMode::Https(443), + ]) + } +} diff --git a/src/dns.rs b/src/dns.rs index 0eaa0553..a7d1aa50 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -2,6 +2,7 @@ use std::fmt::Debug; +use std::net::IpAddr; use std::{str::FromStr, sync::Arc, time::Duration}; use crate::dns_error::LookupError; @@ -535,6 +536,7 @@ mod response { pub type DnsRequest = request::DnsRequest; pub type DnsResponse = response::DnsResponse; pub type DnsError = LookupError; +use ipnet::IpAdd; pub use serial_message::SerialMessage; #[derive(Debug, Clone, Copy, Default)] diff --git a/src/dns_conf.rs b/src/dns_conf.rs index ab53cfc6..ee0c11bd 100644 --- a/src/dns_conf.rs +++ b/src/dns_conf.rs @@ -289,8 +289,8 @@ impl RuntimeConfig { /// speed check mode #[inline] - pub fn speed_check_mode(&self) -> &SpeedCheckModeList { - &self.speed_check_mode + pub fn speed_check_mode(&self) -> Option<&SpeedCheckModeList> { + self.speed_check_mode.as_ref() } /// force AAAA query return SOA @@ -758,7 +758,7 @@ impl RuntimeConfigBuilder { ServerName(v) => self.server_name = Some(v), NumWorkers(v) => self.num_workers = Some(v), Domain(v) => self.domain = Some(v), - SpeedMode(v) => self.speed_check_mode.extend(v.0), + SpeedMode(v) => self.speed_check_mode = v, ServeExpiredTtl(v) => self.cache.serve_expired_ttl = Some(v), ServeExpiredReplyTtl(v) => self.cache.serve_expired_reply_ttl = Some(v), CacheSize(v) => self.cache.size = Some(v), @@ -1227,7 +1227,7 @@ mod tests { ); assert_eq!( domain_rule.speed_check_mode, - vec![SpeedCheckMode::Ping].into() + Some(vec![SpeedCheckMode::Ping].into()) ); assert_eq!(domain_rule.nameserver, Some("test".to_string())); assert_eq!(domain_rule.dualstack_ip_selection, Some(true)); @@ -1248,7 +1248,7 @@ mod tests { ); assert_eq!( domain_rule.speed_check_mode, - vec![SpeedCheckMode::Ping].into() + Some(vec![SpeedCheckMode::Ping].into()) ); assert_eq!(domain_rule.nameserver, Some("test".to_string())); assert_eq!(domain_rule.dualstack_ip_selection, Some(true)); @@ -1266,7 +1266,7 @@ mod tests { assert_eq!(domain_rule.address, Some(DomainAddress::SOA)); assert_eq!( domain_rule.speed_check_mode, - vec![SpeedCheckMode::Ping].into() + Some(vec![SpeedCheckMode::Ping].into()) ); assert_eq!(domain_rule.nameserver, Some("test".to_string())); assert_eq!(domain_rule.dualstack_ip_selection, Some(true)); @@ -1287,11 +1287,14 @@ mod tests { let mut cfg = RuntimeConfig::builder(); cfg.config("speed-check-mode ping,tcp:123"); - assert_eq!(cfg.speed_check_mode.len(), 2); + assert_eq!(cfg.speed_check_mode.as_ref().unwrap().len(), 2); - assert_eq!(cfg.speed_check_mode.first().unwrap(), &SpeedCheckMode::Ping); assert_eq!( - cfg.speed_check_mode.get(1).unwrap(), + cfg.speed_check_mode.as_ref().unwrap().first().unwrap(), + &SpeedCheckMode::Ping + ); + assert_eq!( + cfg.speed_check_mode.as_ref().unwrap().get(1).unwrap(), &SpeedCheckMode::Tcp(123) ); } @@ -1301,14 +1304,14 @@ mod tests { let mut cfg = RuntimeConfig::builder(); cfg.config("speed-check-mode http,https"); - assert_eq!(cfg.speed_check_mode.len(), 2); + assert_eq!(cfg.speed_check_mode.as_ref().unwrap().len(), 2); assert_eq!( - cfg.speed_check_mode.first().unwrap(), + cfg.speed_check_mode.as_ref().unwrap().first().unwrap(), &SpeedCheckMode::Http(80) ); assert_eq!( - cfg.speed_check_mode.get(1).unwrap(), + cfg.speed_check_mode.as_ref().unwrap().get(1).unwrap(), &SpeedCheckMode::Https(443) ); } diff --git a/src/dns_mw_dualstack.rs b/src/dns_mw_dualstack.rs index 90de9edb..a9cc4f2e 100644 --- a/src/dns_mw_dualstack.rs +++ b/src/dns_mw_dualstack.rs @@ -1,5 +1,15 @@ +use std::net::IpAddr; +use std::time::Duration; + +use futures::future::{select, Either}; +use futures::FutureExt; +use tokio::time::sleep; + +use crate::config::SpeedCheckMode; use crate::dns::*; +use crate::log::debug; use crate::middleware::*; +use crate::third_ext::FutureTimeoutExt; pub struct DnsDualStackIpSelectionMiddleware; @@ -13,50 +23,206 @@ impl Middleware req: &DnsRequest, next: Next<'_, DnsContext, DnsRequest, DnsResponse, DnsError>, ) -> Result { - // use RecordType::{A, AAAA}; - - // // highest priority - // if ctx.server_opts.no_dualstack_selection() { - // return next.run(ctx, req).await; - // } - - // let query_type = req.query().query_type(); - - // // must be ip query. - // if !matches!(query_type, A | AAAA) { - // return next.run(ctx, req).await; - // } - - // // read config - // let dualstack_ip_selection = ctx - // .domain_rule - // .as_ref() - // .map(|rule| rule.dualstack_ip_selection) - // .unwrap_or_default() - // .unwrap_or(ctx.cfg.dualstack_ip_selection()); - - // if !dualstack_ip_selection { - // return next.run(ctx, req).await; - // } - - // let (new_ctx, new_req, new_next) = { - // let mut new_req = req.clone(); - // new_req.set_query_type(match query_type { - // A => AAAA, - // AAAA => A, - // typ @ _ => typ, - // }); - - // (ctx.clone(), new_req, next.clone()) - // }; - - // let tasks = [ - // next.run(ctx, req).boxed(), - // move || async { new_next.run(&mut new_ctx, &new_req).await }.boxed(), - // ]; - - // todo!() - - next.run(ctx, req).await + use RecordType::{A, AAAA}; + + // highest priority + if ctx.server_opts.no_dualstack_selection() { + return next.run(ctx, req).await; + } + + let query_type = req.query().query_type(); + + // must be ip query. + if !query_type.is_ip_addr() { + return next.run(ctx, req).await; + } + + let mut prefer_that = false; // As long as it succeeds, there is no need to check the selection threshold. + + if matches!(query_type, A) { + if ctx.cfg().dualstack_ip_allow_force_aaaa() { + prefer_that = true; + } else { + return next.run(ctx, req).await; + } + } + + // read config + let dualstack_ip_selection = ctx + .domain_rule + .as_ref() + .map(|rule| rule.dualstack_ip_selection) + .unwrap_or_default() + .unwrap_or(ctx.cfg().dualstack_ip_selection()); + + if !dualstack_ip_selection { + return next.run(ctx, req).await; + } + + let selection_threshold = + Duration::from_millis(ctx.cfg().dualstack_ip_selection_threshold().into()); + + let speed_check_mode = ctx + .domain_rule + .as_ref() + .and_then(|r| r.speed_check_mode.as_ref()) + .cloned() + .unwrap_or_default(); + + let ttl = ctx.cfg().rr_ttl().unwrap_or_default() as u32; + + let that_type = match query_type { + A => AAAA, + AAAA => A, + typ => typ, + }; + + let mut that_ctx = ctx.clone(); + let that_req = { + let mut req = req.clone(); + req.set_query_type(that_type); + req + }; + + let that = next.clone().run(&mut that_ctx, &that_req); + let this = next.run(ctx, req); + + let dual_task = futures::future::select(this, that).await; + + let this_no_records = || { + debug!( + "dual stack IP selection: {} , choose {}", + req.query().name(), + that_type + ); + Err(DnsError::no_records_found( + req.query().original().to_owned(), + ttl, + )) + }; + + match dual_task { + Either::Left((res, that)) => match res { + Ok(this) => { + let that = that.timeout(selection_threshold).await; + + if let Ok(Ok(that)) = that { + if !prefer_that { + let that_faster = matches!( + which_faster(&this, &that, &speed_check_mode, selection_threshold) + .await, + Either::Right(_) + ); + + if that_faster { + return this_no_records(); + } + } + } + + Ok(this) + } + Err(err) => Err(err), + }, + Either::Right((res, this)) => match res { + Ok(that) => { + if !prefer_that { + match this.await { + Ok(this) => { + let that_faster = matches!( + which_faster( + &this, + &that, + &speed_check_mode, + selection_threshold + ) + .await, + Either::Right(_) + ); + + if that_faster { + return this_no_records(); + } + + Ok(this) + } + Err(err) => Err(err), + } + } else { + return this_no_records(); + } + } + Err(_) => this.await, + }, + } + } +} + +async fn which_faster( + this: &DnsResponse, + that: &DnsResponse, + modes: &[SpeedCheckMode], + selection_threshold: Duration, +) -> Either<(), ()> { + let this_ip_addrs = this.ip_addrs(); + let that_ip_addrs = that.ip_addrs(); + + let this_ping = multi_mode_ping_fastest(this_ip_addrs, modes.to_vec()).boxed(); + let that_ping = multi_mode_ping_fastest(that_ip_addrs, modes.to_vec()).boxed(); + + let which_faster = select(this_ping, that_ping).await; + + let that_faster = match which_faster { + Either::Right((Some((_, that_dura)), this_ping)) => match this_ping.await { + Some((_, this_dura)) => { + that_dura > this_dura && (that_dura - this_dura) > selection_threshold + } + None => true, + }, + _ => false, + }; + + if that_faster { + Either::Right(()) + } else { + Either::Left(()) } } + +async fn multi_mode_ping_fastest( + ip_addrs: Vec, + modes: Vec, +) -> Option<(IpAddr, Duration)> { + use crate::infra::ping::{ping_fastest, PingOptions}; + let duration = Duration::from_millis(200); + let ping_ops = PingOptions::default().with_timeout_secs(2); + + let mut fastest_ip = None; + + for mode in &modes { + let dests = mode.to_ping_addrs(&ip_addrs); + + let ping_task = ping_fastest(dests, ping_ops).boxed(); + let timeout_task = sleep(duration).boxed(); + match futures_util::future::select(ping_task, timeout_task).await { + futures::future::Either::Left((ping_res, _)) => { + match ping_res { + Ok(ping_out) => { + // ping success + let ip = ping_out.destination().ip(); + let duration = ping_out.duration(); + fastest_ip = Some((ip, duration)); + break; + } + Err(_) => continue, + } + } + futures::future::Either::Right((_, _)) => { + // timeout + continue; + } + } + } + + fastest_ip +} diff --git a/src/dns_mw_ns.rs b/src/dns_mw_ns.rs index cb3472b3..a5184161 100644 --- a/src/dns_mw_ns.rs +++ b/src/dns_mw_ns.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::net::SocketAddr; use std::ops::Deref; use std::sync::Arc; use std::{borrow::Borrow, net::IpAddr, time::Duration}; @@ -7,7 +6,6 @@ use std::{borrow::Borrow, net::IpAddr, time::Duration}; use crate::dns_client::{LookupOptions, NameServer}; use crate::infra::ipset::IpSet; -use crate::infra::ping::PingAddr; use crate::{ config::{ResponseMode, SpeedCheckMode, SpeedCheckModeList}, dns::*, @@ -99,10 +97,9 @@ impl Middleware for NameServerMid response_strategy: rule .get(|n| n.response_mode) .unwrap_or_else(|| cfg.response_mode()), - speed_check_mode: if rule.speed_check_mode.is_empty() { - cfg.speed_check_mode().clone() - } else { - rule.speed_check_mode.clone() + speed_check_mode: match rule.speed_check_mode.as_ref() { + Some(mode) => Some(mode.clone()), + None => cfg.speed_check_mode().cloned(), }, no_speed_check: ctx.server_opts.no_speed_check(), ignore_ip: cfg.ignore_ip().clone(), @@ -112,7 +109,7 @@ impl Middleware for NameServerMid }, None => LookupIpOptions { response_strategy: cfg.response_mode(), - speed_check_mode: cfg.speed_check_mode().clone(), + speed_check_mode: cfg.speed_check_mode().cloned(), no_speed_check: ctx.server_opts.no_speed_check(), ignore_ip: cfg.ignore_ip().clone(), blacklist_ip: cfg.blacklist_ip().clone(), @@ -135,7 +132,7 @@ impl Middleware for NameServerMid struct LookupIpOptions { response_strategy: ResponseMode, - speed_check_mode: SpeedCheckModeList, + speed_check_mode: Option, no_speed_check: bool, ignore_ip: Arc, whitelist_ip: Arc, @@ -170,6 +167,7 @@ async fn lookup_ip( ) -> Result { use crate::third_ext::FutureJoinAllExt; use futures_util::future::{select, select_all, Either}; + use ResponseMode::*; assert!(options.record_type.is_ip_addr()); @@ -182,23 +180,22 @@ async fn lookup_ip( return Err(ProtoErrorKind::NoConnections.into()); } - use ResponseMode::*; - // ignore - let response_strategy = if options.no_speed_check - || options.speed_check_mode.is_empty() - || options - .speed_check_mode - .iter() - .any(|m| *m == SpeedCheckMode::None) - { + let response_strategy = if options.no_speed_check || options.speed_check_mode.is_none() { FastestResponse } else { options.response_strategy }; + let speed_check_mode = options + .speed_check_mode + .as_ref() + .map(|m| m.as_slice()) + .unwrap_or_default(); + let mut ok_tasks = vec![]; let mut err_tasks = vec![]; + let selected_ip = match response_strategy { FirstPing => { let mut tasks = tasks; @@ -253,7 +250,7 @@ async fn lookup_ip( multi_mode_ping_fastest( name.clone(), ip_addrs, - options.speed_check_mode.to_vec(), + speed_check_mode.to_vec(), ) .boxed(), ); @@ -311,12 +308,9 @@ async fn lookup_ip( 0 => None, 1 => ip_addrs.pop(), _ => { - let fastest_ip = multi_mode_ping_fastest( - name.clone(), - ip_addrs, - options.speed_check_mode.to_vec(), - ) - .await; + let fastest_ip = + multi_mode_ping_fastest(name.clone(), ip_addrs, speed_check_mode.to_vec()) + .await; fastest_ip.or_else(|| { ip_addrs_map @@ -351,7 +345,13 @@ async fn lookup_ip( unreachable!() } - Ok(ok_tasks.into_iter().next().unwrap()) // There is definitely one. + match ok_tasks.into_iter().next() { + Some(lookup) => Ok(lookup), + None => match err_tasks.into_iter().next() { + Some(err) => Err(err), + None => unreachable!(), + }, + } } async fn multi_mode_ping_fastest( @@ -360,49 +360,17 @@ async fn multi_mode_ping_fastest( modes: Vec, ) -> Option { use crate::infra::ping::{ping_fastest, PingOptions}; - let duaration = Duration::from_millis(200); + let duration = Duration::from_millis(200); let ping_ops = PingOptions::default().with_timeout_secs(2); let mut fastest_ip = None; for mode in &modes { - let dests = match mode { - SpeedCheckMode::None => continue, - SpeedCheckMode::Ping => { - debug!("Speed test {} ping {:?}", name, ip_addrs); - ip_addrs - .iter() - .map(|ip| PingAddr::Icmp(*ip)) - .collect::>() - } - SpeedCheckMode::Tcp(port) => { - debug!("Speed test {} TCP ping {:?} port {}", name, ip_addrs, port); - ip_addrs - .iter() - .map(|ip| PingAddr::Tcp(SocketAddr::new(*ip, *port))) - .collect::>() - } - SpeedCheckMode::Http(port) => { - debug!("Speed test {} HTTP ping {:?} port {}", name, ip_addrs, port); - ip_addrs - .iter() - .map(|ip| PingAddr::Http(SocketAddr::new(*ip, *port))) - .collect::>() - } - SpeedCheckMode::Https(port) => { - debug!( - "Speed test {} HTTPS ping {:?} port {}", - name, ip_addrs, port - ); - ip_addrs - .iter() - .map(|ip| PingAddr::Https(SocketAddr::new(*ip, *port))) - .collect::>() - } - }; + debug!("Speed test {} {:?} ping {:?}", name, mode, ip_addrs); + let dests = mode.to_ping_addrs(&ip_addrs); let ping_task = ping_fastest(dests, ping_ops).boxed(); - let timeout_task = sleep(duaration).boxed(); + let timeout_task = sleep(duration).boxed(); match futures_util::future::select(ping_task, timeout_task).await { futures::future::Either::Left((ping_res, _)) => { match ping_res {