From 3fe9b5628ec477a2d94cb7e577eaea765556d374 Mon Sep 17 00:00:00 2001 From: keepsimple1 Date: Tue, 25 Apr 2023 18:04:27 -0700 Subject: [PATCH] Refactoring DnsCache and how to resolve Service Instance (#108) Simplify the logic of resolving an instance: always resolve from the cache. For any incoming updates, we update the cache first, then try to resolve. --- src/dns_parser.rs | 9 + src/service_daemon.rs | 446 ++++++++++++++++-------------------------- src/service_info.rs | 15 +- 3 files changed, 182 insertions(+), 288 deletions(-) diff --git a/src/dns_parser.rs b/src/dns_parser.rs index cd75365..3ded475 100644 --- a/src/dns_parser.rs +++ b/src/dns_parser.rs @@ -73,6 +73,9 @@ pub(crate) struct DnsRecord { pub(crate) entry: DnsEntry, ttl: u32, // in seconds, 0 means this record should not be cached created: u64, // UNIX time in millis + + /// Support re-query an instance before its PTR record expires. + /// See https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 refresh: u64, // UNIX time in millis } @@ -88,6 +91,10 @@ impl DnsRecord { } } + pub(crate) fn get_created(&self) -> u64 { + self.created + } + pub(crate) fn is_expired(&self, now: u64) -> bool { get_expiration_time(self.created, self.ttl, 100) <= now } @@ -796,6 +803,8 @@ pub(crate) struct DnsIncoming { offset: usize, data: Vec, pub(crate) questions: Vec, + /// This field includes records in the `answers` section + /// and in the `additionals` section. pub(crate) answers: Vec, pub(crate) id: u16, flags: u16, diff --git a/src/service_daemon.rs b/src/service_daemon.rs index d45fda3..ed90426 100644 --- a/src/service_daemon.rs +++ b/src/service_daemon.rs @@ -324,7 +324,7 @@ impl ServiceDaemon { // Send out additional queries for unresolved instances, where // the early responses did not have SRV records. - zc.query_unresolved(); + zc.query_missing_srv(); // process commands from the command channel match receiver.try_recv() { @@ -581,19 +581,6 @@ struct IntfSock { sock: Socket, } -/// Represents possible states for an instance. -#[derive(PartialEq, Eq)] -enum InstanceState { - Found, // Not resolved yet. - Resolved, -} - -/// Represents an instance found for a service type we are interested in. -struct Instance { - state: InstanceState, - info: ServiceInfo, -} - /// A struct holding the state. It was inspired by `zeroconf` package in Python. struct Zeroconf { /// Local interfaces with sockets to recv/send on these interfaces. @@ -610,9 +597,6 @@ struct Zeroconf { /// Active "Browse" commands. queriers: HashMap>, // - /// Active queriers interested instances - instances_found: HashMap, - /// All repeating transmissions. retransmissions: Vec, @@ -658,7 +642,6 @@ impl Zeroconf { broadcast_addr, cache: DnsCache::new(), queriers: HashMap::new(), - instances_found: HashMap::new(), retransmissions: Vec::new(), counters: HashMap::new(), poller, @@ -1058,31 +1041,21 @@ impl Zeroconf { true } - /// Sends query TYPE_ANY for instances that are unresolved for a while. - fn query_unresolved(&mut self) { + /// Sends TYPE_ANY query for instances that're missing SRV records. + fn query_missing_srv(&mut self) { let now = current_time_millis(); - let wait_in_millis = 800; - let unresolved: Vec = self - .instances_found - .iter() - .filter(|(name, instance)| { - valid_instance_name(name) - && instance.state == InstanceState::Found - && now > (instance.info.get_last_update() + wait_in_millis) - }) - .map(|(name, _)| name.to_string()) - .collect(); + let wait_in_millis = 800; // The threshold for deeming SRV missing. - for instance_name in unresolved.iter() { - debug!( - "{}: send query for unresolved instance {}", - &now, instance_name - ); - self.send_query(instance_name, TYPE_ANY); - - // update the info timestamp. - if let Some(instance) = self.instances_found.get_mut(instance_name) { - instance.info.set_last_update(now); + for records in self.cache.ptr.values() { + for record in records.iter() { + if let Some(ptr) = record.any().downcast_ref::() { + if !self.cache.srv.contains_key(&ptr.alias) + && valid_instance_name(&ptr.alias) + && now > ptr.get_record().get_created() + wait_in_millis + { + self.send_query(&ptr.alias, TYPE_ANY); + } + } } } } @@ -1090,7 +1063,7 @@ impl Zeroconf { /// Checks if `ty_domain` has records in the cache. If yes, sends the /// cached records via `sender`. fn query_cache(&mut self, ty_domain: &str, sender: Sender) { - if let Some(records) = self.cache.get_records_by_name(ty_domain) { + if let Some(records) = self.cache.ptr.get(ty_domain) { for record in records.iter() { if let Some(ptr) = record.any().downcast_ref::() { let info = self.create_service_info_from_cache(ty_domain, &ptr.alias); @@ -1118,14 +1091,6 @@ impl Zeroconf { Ok(()) => debug!("sent service resolved"), Err(e) => error!("failed to send service resolved: {}", e), } - } else if !self.instances_found.contains_key(info.get_fullname()) { - self.instances_found.insert( - ty_domain.to_string(), - Instance { - state: InstanceState::Found, - info, - }, - ); } } } @@ -1144,19 +1109,27 @@ impl Zeroconf { let mut info = ServiceInfo::new(ty_domain, &my_name, "", (), 0, None)?; - // resolve SRV and TXT records - if let Some(records) = self.cache.map.get(fullname) { - for answer in records.iter() { + // resolve SRV record + if let Some(records) = self.cache.srv.get(fullname) { + if let Some(answer) = records.get(0) { if let Some(dns_srv) = answer.any().downcast_ref::() { info.set_hostname(dns_srv.host.clone()); info.set_port(dns_srv.port); - } else if let Some(dns_txt) = answer.any().downcast_ref::() { + } + } + } + + // resolve TXT record + if let Some(records) = self.cache.txt.get(fullname) { + if let Some(record) = records.get(0) { + if let Some(dns_txt) = record.any().downcast_ref::() { info.set_properties_from_txt(&dns_txt.text); } } } - if let Some(records) = self.cache.map.get(info.get_hostname()) { + // resolve A records + if let Some(records) = self.cache.addr.get(info.get_hostname()) { for answer in records.iter() { if let Some(dns_a) = answer.any().downcast_ref::() { info.insert_ipv4addr(dns_a.address); @@ -1167,125 +1140,6 @@ impl Zeroconf { Ok(info) } - /// Try to resolve some instances based on a record (answer), - /// and return a list of instances that got resolved or updated. - fn resolve_by_answer( - instances_found: &mut HashMap, - answer: &DnsRecordBox, - ) -> Vec { - let mut resolved = Vec::new(); - if let Some(dns_srv) = answer.any().downcast_ref::() { - if let Some(instance) = instances_found.get_mut(answer.get_name()) { - if instance.info.get_hostname() != dns_srv.host - || instance.info.get_port() != dns_srv.port - { - debug!("setting server and port for service info"); - - instance.info.set_hostname(dns_srv.host.clone()); - instance.info.set_port(dns_srv.port); - if instance.info.is_ready() { - if instance.state == InstanceState::Found { - instance.state = InstanceState::Resolved; - } - resolved.push(answer.get_name().to_string()); - } - } - } - } else if let Some(dns_txt) = answer.any().downcast_ref::() { - if let Some(instance) = instances_found.get_mut(answer.get_name()) { - if instance.info.set_properties_from_txt(&dns_txt.text) { - debug!("setting TXT: {:?}", instance.info.get_properties()); - - if instance.info.is_ready() { - if instance.state == InstanceState::Found { - instance.state = InstanceState::Resolved; - } - resolved.push(answer.get_name().to_string()); - } - } - } - } else if let Some(dns_a) = answer.any().downcast_ref::() { - for (instance_name, instance) in instances_found.iter_mut() { - if instance.info.get_hostname() == answer.get_name() - && !instance.info.get_addresses().contains(&dns_a.address) - { - debug!( - "setting address in server {}: {}", - instance.info.get_hostname(), - &dns_a.address - ); - instance.info.insert_ipv4addr(dns_a.address); - if instance.info.is_ready() { - if instance.state == InstanceState::Found { - instance.state = InstanceState::Resolved; - } - resolved.push(instance_name.clone()); - } - } - } - } - resolved - } - - /// Returns a list of instances that have resolved by the answer. - fn handle_answer(&mut self, record: DnsRecordBox) -> Vec { - let (record_ext, existing) = self.cache.add_or_update(record); - let dns_entry = &record_ext.get_record().entry; - let mut resolved = Vec::new(); - debug!("add_or_update record name: {:?}", &dns_entry.name); - - if let Some(dns_ptr) = record_ext.any().downcast_ref::() { - let service_type = dns_entry.name.clone(); - let instance = dns_ptr.alias.clone(); - - if !self.queriers.contains_key(&service_type) { - debug!("Not interested for any querier"); - return resolved; - } - - // Insert into services_to_resolve if this is a new instance - if !self.instances_found.contains_key(&instance) { - if existing { - debug!("already knew: {}", &instance); - return resolved; - } - - let my_name = { - let name = instance.trim_end_matches(split_sub_domain(&service_type).0); - name.strip_suffix('.').unwrap_or(name).to_string() - }; - - let service_info = ServiceInfo::new(&service_type, &my_name, "", (), 0, None); - - match service_info { - Ok(service_info) => { - debug!("Inserting service info: {:?}", &service_info); - self.instances_found.insert( - instance.clone(), - Instance { - state: InstanceState::Found, - info: service_info, - }, - ); - } - Err(err) => { - error!("Malformed service info while inserting: {:?}", err); - } - } - } - - call_listener( - &self.queriers, - &dns_entry.name, - ServiceEvent::ServiceFound(service_type, instance), - ); - } else { - resolved = Self::resolve_by_answer(&mut self.instances_found, record_ext); - } - - resolved - } - /// Deal with incoming response packets. All answers /// are held in the cache, and listeners are notified. fn handle_response(&mut self, mut msg: DnsIncoming) { @@ -1320,61 +1174,80 @@ impl Zeroconf { false }); - let mut resolved = HashSet::new(); - - // process PTR records first as we create entries in cache based on PTR records. - // This code can be simplified when `drain_filter` is stablized. - let mut i = 0; - while i < msg.answers.len() { - if msg.answers[i].get_type() == TYPE_PTR { - let record = msg.answers.remove(i); - let newly_resolved = self.handle_answer(record); - resolved.extend(newly_resolved); - } else { - i += 1; - } + /// Represents a DNS record change that involves one service instance. + struct InstanceChange { + ty: u16, // The type of DNS record for the instance. + name: String, // The name of the record. } - // process other types of records. + // Go through all answers to get the new and updated records. + // For new PTR records, send out ServiceFound immediately. For others, + // collect them into `changes`. + // + // Note: we don't try to identify the update instances based on + // each record immediately as the answers are likely related to each + // other. + let mut changes = Vec::new(); for record in msg.answers { - let newly_resolved = self.handle_answer(record); - resolved.extend(newly_resolved); + if let Some((dns_record, true)) = self.cache.add_or_update(record) { + let ty = dns_record.get_type(); + let name = dns_record.get_name(); + if ty == TYPE_PTR { + // send ServiceFound + if let Some(dns_ptr) = dns_record.any().downcast_ref::() { + call_listener( + &self.queriers, + name, + ServiceEvent::ServiceFound(name.to_string(), dns_ptr.alias.clone()), + ); + } + } else { + changes.push(InstanceChange { + ty, + name: name.to_string(), + }); + } + } } - self.process_resolved(resolved); - } - - /// Process resolved instances and send out notifications. - /// It is OK to have duplicated instances in `resolved`. - fn process_resolved(&mut self, resolved: HashSet) { - for instance_name in resolved.iter() { - let instance = match self.instances_found.get(instance_name) { - Some(i) => i, - None => { - debug!("Instance {} was not found", instance_name); - continue; + // Identify the instances that need to be "resolved". + let mut updated_instances = HashSet::new(); + for update in changes { + match update.ty { + TYPE_SRV | TYPE_TXT => { + updated_instances.insert(update.name); } - }; - fn s(listener: &Sender, info: ServiceInfo) { - match listener.send(ServiceEvent::ServiceResolved(info)) { - Ok(()) => debug!("sent service info successfully"), - Err(e) => error!("failed to send service info: {}", e), + TYPE_A => { + let instances = self.cache.get_instances_on_host(&update.name); + updated_instances.extend(instances); } + _ => {} } - let sub_query = instance - .info - .get_subtype() - .as_ref() - .and_then(|s| self.queriers.get(s)); - let query = self.queriers.get(instance.info.get_type()); - match (sub_query, query) { - (Some(sub_listener), Some(listener)) => { - s(sub_listener, instance.info.clone()); - s(listener, instance.info.clone()); + } + + // Resolve the updated (including new) instances. + // + // Note: it is possible that more than 1 PTR pointing to the same + // instance. For example, a regular service type PTR and a sub-type + // service type PTR can both point to the same service instance. + // This loop automatically handles the sub-type PTRs. + for (ty_domain, records) in self.cache.ptr.iter() { + for record in records.iter() { + if let Some(dns_ptr) = record.any().downcast_ref::() { + if updated_instances.contains(&dns_ptr.alias) { + if let Ok(info) = + self.create_service_info_from_cache(ty_domain, &dns_ptr.alias) + { + if info.is_ready() { + call_listener( + &self.queriers, + ty_domain, + ServiceEvent::ServiceResolved(info), + ); + } + } + } } - (None, Some(listener)) => s(listener, instance.info.clone()), - (Some(listener), None) => s(listener, instance.info.clone()), - _ => {} } } } @@ -1595,50 +1468,79 @@ enum DaemonOption { } struct DnsCache { - /// - map: HashMap>, + ptr: HashMap>, + srv: HashMap>, + txt: HashMap>, + addr: HashMap>, } impl DnsCache { fn new() -> Self { Self { - map: HashMap::new(), + ptr: HashMap::new(), + srv: HashMap::new(), + txt: HashMap::new(), + addr: HashMap::new(), } } - fn get_records_by_name(&self, name: &str) -> Option<&Vec> { - self.map.get(name) + /// Returns the list of instances that has `host` as its hostname. + fn get_instances_on_host(&self, host: &str) -> Vec { + self.srv + .iter() + .filter_map(|(instance, srv_list)| { + if let Some(item) = srv_list.get(0) { + if let Some(dns_srv) = item.any().downcast_ref::() { + if dns_srv.host == host { + return Some(instance.clone()); + } + } + } + None + }) + .collect() } /// Update a DNSRecord if already exists, otherwise insert a new record - fn add_or_update(&mut self, incoming: DnsRecordBox) -> (&DnsRecordBox, bool) { - let record_vec = self.map.entry(incoming.get_name().to_string()).or_default(); - - let mut found = false; - let mut idx = record_vec.len(); + fn add_or_update(&mut self, incoming: DnsRecordBox) -> Option<(&DnsRecordBox, bool)> { + let entry_name = incoming.get_name().to_string(); + let record_vec = match incoming.get_type() { + TYPE_PTR => self.ptr.entry(entry_name).or_default(), + TYPE_SRV => self.srv.entry(entry_name).or_default(), + TYPE_TXT => self.txt.entry(entry_name).or_default(), + TYPE_A => self.addr.entry(entry_name).or_default(), + _ => return None, + }; - for i in 0..record_vec.len() { - let r = record_vec.get_mut(i).unwrap(); - if r.matches(incoming.as_ref()) { + let (idx, updated) = match record_vec + .iter_mut() + .enumerate() + .find(|(_idx, r)| r.matches(incoming.as_ref())) + { + Some((i, r)) => { r.reset_ttl(incoming.as_ref()); - found = true; - idx = i; - break; + (i, false) } - } - - if !found { - record_vec.insert(0, incoming); // we did not find it. - idx = 0; - } - - (record_vec.get(idx).unwrap(), found) + None => { + record_vec.insert(0, incoming); // A new record. + (0, true) + } + }; + Some((record_vec.get(idx).unwrap(), updated)) } /// Remove a record from the cache if exists, otherwise no-op fn remove(&mut self, record: &DnsRecordBox) -> bool { let mut found = false; - if let Some(record_vec) = self.map.get_mut(record.get_name()) { + let record_name = record.get_name(); + let record_vec = match record.get_type() { + TYPE_PTR => self.ptr.get_mut(record_name), + TYPE_SRV => self.srv.get_mut(record_name), + TYPE_TXT => self.txt.get_mut(record_name), + TYPE_A => self.addr.get_mut(record_name), + _ => return found, + }; + if let Some(record_vec) = record_vec { record_vec.retain(|x| match x.matches(record.as_ref()) { true => { found = true; @@ -1656,7 +1558,13 @@ impl DnsCache { where F: Fn(&DnsRecordBox), // Caller has a chance to do something with expired { - for records in self.map.values_mut() { + let all_records = self + .ptr + .values_mut() + .chain(self.srv.values_mut()) + .chain(self.txt.values_mut()) + .chain(self.addr.values_mut()); + for records in all_records { records.retain(|x| { let expired = x.get_record().is_expired(now); if expired { @@ -1667,19 +1575,6 @@ impl DnsCache { } } - /// Returns the list of full name of the instances for a `ty_domain`. - fn instance_names(&self, ty_domain: &str) -> Vec { - let mut result = Vec::new(); - if let Some(instances) = self.map.get(ty_domain) { - for instance_ptr in instances.iter() { - if let Some(dns_ptr) = instance_ptr.any().downcast_ref::() { - result.push(dns_ptr.alias.clone()); - } - } - } - result - } - /// Returns the list of instance names that are due for refresh /// for a `ty_domain`. /// @@ -1687,23 +1582,24 @@ impl DnsCache { /// they will not refresh again. fn refresh_due(&mut self, ty_domain: &str) -> Vec { let now = current_time_millis(); - let mut result = Vec::new(); - - for instance in self.instance_names(ty_domain).iter() { - if let Some(records) = self.map.get_mut(instance) { - for record in records.iter_mut() { - let rec = record.get_record_mut(); - if !rec.is_expired(now) && rec.refresh_due(now) { - result.push(instance.clone()); - - // Only refresh a record once, until it expires and resets. - rec.refresh_no_more(); - break; // for one instance, only query once - } + + self.ptr + .get_mut(ty_domain) + .into_iter() + .flatten() + .filter_map(|record| { + let rec = record.get_record_mut(); + if rec.is_expired(now) || !rec.refresh_due(now) { + return None; } - } - } - result + rec.refresh_no_more(); + + record + .any() + .downcast_ref::() + .map(|dns_ptr| dns_ptr.alias.clone()) + }) + .collect() } } diff --git a/src/service_info.rs b/src/service_info.rs index c50dccd..fcac23c 100644 --- a/src/service_info.rs +++ b/src/service_info.rs @@ -1,6 +1,6 @@ #[cfg(feature = "logging")] use crate::log::error; -use crate::{dns_parser::current_time_millis, Error, Result}; +use crate::{Error, Result}; use if_addrs::Ifv4Addr; use std::{ collections::{HashMap, HashSet}, @@ -31,8 +31,7 @@ pub struct ServiceInfo { priority: u16, weight: u16, txt_properties: TxtProperties, - last_update: u64, // UNIX time in millis - addr_auto: bool, // Let the system update addresses automatically. + addr_auto: bool, // Let the system update addresses automatically. } impl ServiceInfo { @@ -74,7 +73,6 @@ impl ServiceInfo { let server = host_name.to_string(); let addresses = host_ipv4.as_ipv4_addrs()?; let txt_properties = properties.into_txt_properties(); - let last_update = current_time_millis(); // RFC6763 section 6.4: https://www.rfc-editor.org/rfc/rfc6763#section-6.4 // The characters of a key MUST be printable US-ASCII values (0x20-0x7E) @@ -107,7 +105,6 @@ impl ServiceInfo { priority: 0, weight: 0, txt_properties, - last_update, addr_auto: false, }; @@ -276,14 +273,6 @@ impl ServiceInfo { false } } - - pub(crate) fn get_last_update(&self) -> u64 { - self.last_update - } - - pub(crate) fn set_last_update(&mut self, update: u64) { - self.last_update = update; - } } /// This trait allows for parsing an input into a set of one or multiple [`Ipv4Addr`].