Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refresh SRV records and send out ServiceRemoved for expired SRV #180

Merged
merged 3 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ jobs:
- name: Build
run: cargo build
- name: Run clippy and fail if any warnings
run: cargo clippy -- -W clippy::all -D warnings
run: cargo clippy -- -D warnings
- name: Run tests
run: cargo test
87 changes: 82 additions & 5 deletions src/dns_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,14 @@ pub struct DnsRecord {
impl DnsRecord {
fn new(name: &str, ty: u16, class: u16, ttl: u32) -> Self {
let created = current_time_millis();

// From RFC 6762 section 5.2:
// "... The querier should plan to issue a query at 80% of the record
// lifetime, and then if no answer is received, at 85%, 90%, and 95%."
let refresh = get_expiration_time(created, ttl, 80);

let expires = get_expiration_time(created, ttl, 100);

Self {
entry: DnsEntry::new(name.to_string(), ty, class),
ttl,
Expand Down Expand Up @@ -136,6 +142,36 @@ impl DnsRecord {
self.refresh = get_expiration_time(self.created, self.ttl, 100);
}

/// Returns if this record is due for refresh. If yes, `refresh` time is updated.
pub(crate) fn refresh_maybe(&mut self, now: u64) -> bool {
if self.is_expired(now) || !self.refresh_due(now) {
return false;
}

debug!(
"{} qtype {} is due to refresh",
&self.entry.name, self.entry.ty
);

// From RFC 6762 section 5.2:
// "... The querier should plan to issue a query at 80% of the record
// lifetime, and then if no answer is received, at 85%, 90%, and 95%."
//
// If the answer is received in time, 'refresh' will be reset outside
// this function, back to 80% of the new TTL.
if self.refresh == get_expiration_time(self.created, self.ttl, 80) {
self.refresh = get_expiration_time(self.created, self.ttl, 85);
} else if self.refresh == get_expiration_time(self.created, self.ttl, 85) {
self.refresh = get_expiration_time(self.created, self.ttl, 90);
} else if self.refresh == get_expiration_time(self.created, self.ttl, 90) {
self.refresh = get_expiration_time(self.created, self.ttl, 95);
} else {
self.refresh_no_more();
}

true
}

/// Returns the remaining TTL in seconds
fn get_remaining_ttl(&self, now: u64) -> u32 {
let remaining_millis = get_expiration_time(self.created, self.ttl, 100) - now;
Expand Down Expand Up @@ -190,6 +226,8 @@ pub trait DnsRecordExt: fmt::Debug {
self.get_record().entry.ty
}

/// Resets TTL using `other` record.
/// `self.refresh` and `self.expires` are also reset.
fn reset_ttl(&mut self, other: &dyn DnsRecordExt) {
self.get_record_mut().reset_ttl(other.get_record());
}
Expand Down Expand Up @@ -1084,7 +1122,16 @@ impl DnsIncoming {

let ty = u16_from_be_slice(&slice[..2]);
let class = u16_from_be_slice(&slice[2..4]);
let ttl = u32_from_be_slice(&slice[4..8]);
let mut ttl = u32_from_be_slice(&slice[4..8]);
if ttl == 0 && self.is_response() {
// RFC 6762 section 10.1:
// "...Queriers receiving a Multicast DNS response with a TTL of zero SHOULD
// NOT immediately delete the record from the cache, but instead record
// a TTL of 1 and then delete the record one second later."
// See https://datatracker.ietf.org/doc/html/rfc6762#section-10.1

ttl = 1;
}
let length = u16_from_be_slice(&slice[8..10]) as usize;
self.offset += 10;
let next_offset = self.offset + length;
Expand Down Expand Up @@ -1354,19 +1401,22 @@ const fn u32_from_be_slice(s: &[u8]) -> u32 {
u32::from_be_bytes(u8_array)
}

/// Returns the time in millis at which this record will have expired
/// Returns the UNIX time in millis at which this record will have expired
/// by a certain percentage.
const fn get_expiration_time(created: u64, ttl: u32, percent: u32) -> u64 {
// 'created' is in millis, 'ttl' is in seconds, hence:
// ttl * 1000 * (percent / 100) => ttl * percent * 10
created + (ttl * percent * 10) as u64
}

#[cfg(test)]
mod tests {
use crate::dns_parser::{TYPE_A, TYPE_AAAA};
use crate::dns_parser::get_expiration_time;

use super::{
DnsIncoming, DnsNSec, DnsOutgoing, DnsSrv, CLASS_CACHE_FLUSH, CLASS_IN, FLAGS_QR_QUERY,
FLAGS_QR_RESPONSE, TYPE_PTR,
current_time_millis, DnsIncoming, DnsNSec, DnsOutgoing, DnsRecordExt, DnsSrv,
CLASS_CACHE_FLUSH, CLASS_IN, FLAGS_QR_QUERY, FLAGS_QR_RESPONSE, TYPE_A, TYPE_AAAA,
TYPE_PTR,
};

#[test]
Expand Down Expand Up @@ -1453,4 +1503,31 @@ mod tests {
assert_eq!(absent_types[0], TYPE_A);
assert_eq!(absent_types[1], TYPE_AAAA);
}

#[test]
fn test_refresh_maybe() {
let name = "test_refresh._udp.local.";
let ttl = 2;
let hostname = "instance1.local.";
let mut srv = DnsSrv::new(name, CLASS_IN, ttl, 0, 0, 0, hostname.to_string());

// refresh is not due yet.
let now = current_time_millis();
let refreshed = srv.get_record_mut().refresh_maybe(now);
assert!(!refreshed);

// sleep for 80 percent of TTL in millis to reach "refresh" time.
let sleep_in_mills = (ttl * 80 * 10) as u64;
std::thread::sleep(std::time::Duration::from_millis(sleep_in_mills));

// refresh is due.
let now = current_time_millis();
let refreshed = srv.get_record_mut().refresh_maybe(now);
assert!(refreshed);

// refresh time is updated.
let dns_record = srv.get_record();
let new_refresh = get_expiration_time(dns_record.get_created(), dns_record.ttl, 85);
assert_eq!(new_refresh, dns_record.get_refresh_time());
}
}
Loading