Skip to content

Commit

Permalink
feat(phoenix-channel): fail on missing heartbeat after 5s (#4296)
Browse files Browse the repository at this point in the history
This PR fixes a bug and adds a missing feature to `phoenix-channel`.

1. Previously, we erroneously reset the heartbeat state on any empty
reply, not just the reply to the heartbeat itself.
2. We only detected a missing heartbeat reply when it was time to send
the next heartbeat.

With this PR, we correct the first bug and add a dedicated timeout of 5s
for the heartbeat reply.
  • Loading branch information
thomaseizinger committed Mar 25, 2024
1 parent b113a7c commit ecce024
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 58 deletions.
129 changes: 86 additions & 43 deletions rust/phoenix-channel/src/heartbeat.rs
Original file line number Diff line number Diff line change
@@ -1,81 +1,93 @@
use crate::{EgressControlMessage, OutboundRequestId};
use crate::OutboundRequestId;
use futures::FutureExt;
use std::{
pin::Pin,
sync::{atomic::AtomicU64, Arc},
task::{ready, Context, Poll},
time::Duration,
};
use tokio::time::MissedTickBehavior;

const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
pub const INTERVAL: Duration = Duration::from_secs(30);
pub const TIMEOUT: Duration = Duration::from_secs(5);

pub struct Heartbeat {
/// When to send the next heartbeat.
interval: Pin<Box<tokio::time::Interval>>,

timeout: Duration,

/// The ID of our heartbeat if we haven't received a reply yet.
id: Option<OutboundRequestId>,
pending: Option<(OutboundRequestId, Pin<Box<tokio::time::Sleep>>)>,

next_request_id: Arc<AtomicU64>,
}

impl Heartbeat {
pub fn maybe_handle_reply(&mut self, id: OutboundRequestId) -> bool {
let Some(pending) = self.id.take() else {
return false;
};
pub fn new(interval: Duration, timeout: Duration, next_request_id: Arc<AtomicU64>) -> Self {
let mut interval = tokio::time::interval(interval);
interval.set_missed_tick_behavior(MissedTickBehavior::Skip);

if pending != id {
return false;
Self {
interval: Box::pin(interval),
pending: Default::default(),
next_request_id,
timeout,
}

self.id = None;
true
}

pub fn set_id(&mut self, id: OutboundRequestId) {
self.id = Some(id);
/// Marks the pending heartbeat as answered if `id` belongs to it.
///
/// Returns `true` only when `id` equals the ID of the heartbeat we are currently waiting on
/// and its reply timeout has not elapsed yet; in that case the pending state is cleared.
/// For any other ID, for a reply that arrives after the timeout, or when no heartbeat is
/// pending, the state is left untouched and `false` is returned.
pub fn maybe_handle_reply(&mut self, id: OutboundRequestId) -> bool {
    match self.pending.as_ref() {
        // A late reply must not clear the pending state: `poll` still has to report the
        // missed heartbeat once it observes the elapsed timeout.
        Some((pending, timeout)) if pending == &id && !timeout.is_elapsed() => {
            self.pending = None;

            true
        }
        _ => false,
    }
}

pub fn poll(
&mut self,
cx: &mut Context,
) -> Poll<Result<EgressControlMessage<()>, MissedLastHeartbeat>> {
ready!(self.interval.poll_tick(cx));

if self.id.is_some() {
self.id = None;
) -> Poll<Result<OutboundRequestId, MissedLastHeartbeat>> {
if let Some((_, timeout)) = self.pending.as_mut() {
ready!(timeout.poll_unpin(cx));
self.pending = None;
return Poll::Ready(Err(MissedLastHeartbeat {}));
}

Poll::Ready(Ok(EgressControlMessage::Heartbeat(crate::Empty {})))
}
ready!(self.interval.poll_tick(cx));

fn new(interval: Duration) -> Self {
let mut interval = tokio::time::interval(interval);
interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
let next_id = self
.next_request_id
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
self.pending = Some((
OutboundRequestId(next_id),
Box::pin(tokio::time::sleep(self.timeout)),
));

Self {
interval: Box::pin(interval),
id: Default::default(),
}
Poll::Ready(Ok(OutboundRequestId(next_id)))
}
}

/// Error yielded by [`Heartbeat::poll`] when a heartbeat did not receive its reply.
#[derive(Debug)]
pub struct MissedLastHeartbeat {}

impl Default for Heartbeat {
fn default() -> Self {
Self::new(HEARTBEAT_INTERVAL)
}
}

#[cfg(test)]
mod tests {
use super::*;
use futures::future::Either;
use std::{future::poll_fn, time::Instant};

const INTERVAL: Duration = Duration::from_millis(30);
const TIMEOUT: Duration = Duration::from_millis(5);

#[tokio::test]
async fn returns_heartbeat_after_interval() {
let mut heartbeat = Heartbeat::new(Duration::from_millis(30));
let _ = poll_fn(|cx| heartbeat.poll(cx)).await; // Tick once at startup.
let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0)));
let id = poll_fn(|cx| heartbeat.poll(cx)).await.unwrap(); // Tick once at startup.
heartbeat.maybe_handle_reply(id);

let start = Instant::now();

Expand All @@ -84,29 +96,60 @@ mod tests {
let elapsed = start.elapsed();

assert!(result.is_ok());
assert!(elapsed >= Duration::from_millis(10));
assert!(elapsed >= INTERVAL);
}

#[tokio::test]
async fn fails_if_response_is_not_provided_before_next_poll() {
let mut heartbeat = Heartbeat::new(Duration::from_millis(10));
let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0)));

let _ = poll_fn(|cx| heartbeat.poll(cx)).await;
heartbeat.set_id(OutboundRequestId::for_test(1));

let result = poll_fn(|cx| heartbeat.poll(cx)).await;
assert!(result.is_err());
}

#[tokio::test]
async fn succeeds_if_response_is_provided_inbetween_polls() {
let mut heartbeat = Heartbeat::new(Duration::from_millis(10));
async fn ignores_other_ids() {
let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0)));

let _ = poll_fn(|cx| heartbeat.poll(cx)).await;
heartbeat.set_id(OutboundRequestId::for_test(1));
heartbeat.maybe_handle_reply(OutboundRequestId::for_test(1));
heartbeat.maybe_handle_reply(OutboundRequestId::for_test(2));

let result = poll_fn(|cx| heartbeat.poll(cx)).await;
assert!(result.is_err());
}

/// A reply carrying the ID handed out by `poll` clears the pending heartbeat, so the next
/// `poll` yields another heartbeat instead of an error.
#[tokio::test]
async fn succeeds_if_response_is_provided_inbetween_polls() {
let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0)));

// The first poll resolves with the ID of the heartbeat that is to be sent.
let id = poll_fn(|cx| heartbeat.poll(cx)).await.unwrap();
heartbeat.maybe_handle_reply(id);

let result = poll_fn(|cx| heartbeat.poll(cx)).await;
assert!(result.is_ok());
}

/// Without a reply, `poll` fails once the reply timeout elapses, and a late reply for that
/// heartbeat is no longer accepted.
#[tokio::test]
async fn fails_if_not_provided_within_timeout() {
let mut heartbeat = Heartbeat::new(INTERVAL, TIMEOUT, Arc::new(AtomicU64::new(0)));

let id = poll_fn(|cx| heartbeat.poll(cx)).await.unwrap();

// Race the next poll against twice the heartbeat timeout: the heartbeat error has to win.
let select = futures::future::select(
tokio::time::sleep(TIMEOUT * 2).boxed(),
poll_fn(|cx| heartbeat.poll(cx)),
)
.await;

match select {
Either::Left(((), _)) => panic!("timeout should not resolve"),
Either::Right((Ok(_), _)) => panic!("heartbeat should fail and not issue new ID"),
Either::Right((Err(_), _)) => {}
}

// The failed heartbeat was cleared by `poll`, so its ID no longer matches anything pending.
let handled = heartbeat.maybe_handle_reply(id);
assert!(!handled);
}
}
49 changes: 34 additions & 15 deletions rust/phoenix-channel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@ use tokio_tungstenite::{
};

pub use login_url::{LoginUrl, LoginUrlError};
use std::sync::atomic::AtomicU64;
use std::sync::Arc;

// TODO: Refactor this PhoenixChannel to be compatible with the needs of the client and gateway
// See https://github.com/firezone/firezone/issues/2158
pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes> {
state: State,
waker: Option<Waker>,
pending_messages: VecDeque<String>,
next_request_id: u64,
next_request_id: Arc<AtomicU64>,

heartbeat: Heartbeat,

Expand Down Expand Up @@ -208,6 +210,8 @@ where
init_req: TInitReq,
reconnect_backoff: ExponentialBackoff,
) -> Self {
let next_request_id = Arc::new(AtomicU64::new(0));

Self {
reconnect_backoff,
url: url.clone(),
Expand All @@ -222,8 +226,12 @@ where
waker: None,
pending_messages: Default::default(),
_phantom: PhantomData,
next_request_id: 0,
heartbeat: Default::default(),
heartbeat: Heartbeat::new(
heartbeat::INTERVAL,
heartbeat::TIMEOUT,
next_request_id.clone(),
),
next_request_id,
pending_join_requests: Default::default(),
login,
init_req: init_req.clone(),
Expand Down Expand Up @@ -447,10 +455,12 @@ where

// Priority 3: Handle heartbeats.
match self.heartbeat.poll(cx) {
Poll::Ready(Ok(msg)) => {
let (id, msg) = self.make_message("phoenix", msg);
self.pending_messages.push_back(msg);
self.heartbeat.set_id(id);
Poll::Ready(Ok(id)) => {
self.pending_messages.push_back(serialize_msg(
"phoenix",
EgressControlMessage::<()>::Heartbeat(Empty {}),
id.copy(),
));
return Poll::Ready(Ok(Event::HeartbeatSent));
}
Expand Down Expand Up @@ -492,19 +502,15 @@ where
let request_id = self.fetch_add_request_id();

// We don't care about the reply type when serializing
let msg = serde_json::to_string(&PhoenixMessage::<_, ()>::new_message(
topic,
payload,
Some(request_id.copy()),
))
.expect("we should always be able to serialize a join topic message");
let msg = serialize_msg(topic, payload, request_id.copy());

(request_id, msg)
}

fn fetch_add_request_id(&mut self) -> OutboundRequestId {
let next_id = self.next_request_id;
self.next_request_id += 1;
let next_id = self
.next_request_id
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);

OutboundRequestId(next_id)
}
Expand Down Expand Up @@ -685,6 +691,19 @@ enum EgressControlMessage<T> {
Heartbeat(Empty),
}

/// Serializes an outbound message for `topic` into a JSON string.
///
/// `payload` becomes the message body and `request_id` is embedded in the message.
///
/// # Panics
///
/// Panics if `serde_json` fails to serialize the message.
fn serialize_msg(
    topic: impl Into<String>,
    payload: impl Serialize,
    request_id: OutboundRequestId,
) -> String {
    // We don't care about the reply type when serializing, hence `()`.
    let message = PhoenixMessage::<_, ()>::new_message(topic, payload, Some(request_id));

    serde_json::to_string(&message)
        .expect("we should always be able to serialize an outbound message")
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit ecce024

Please sign in to comment.