feat(connlib): introduce Session::reconnect #4116

Merged · 8 commits · Mar 14, 2024
7 changes: 7 additions & 0 deletions rust/connlib/clients/shared/src/eventloop.rs
@@ -36,6 +36,7 @@ pub struct Eventloop<C: Callbacks> {
/// Commands that can be sent to the [`Eventloop`].
pub enum Command {
Stop,
Reconnect,
}

impl<C: Callbacks> Eventloop<C> {
@@ -64,6 +65,12 @@ where
loop {
match self.rx.poll_recv(cx) {
Poll::Ready(Some(Command::Stop)) | Poll::Ready(None) => return Poll::Ready(Ok(())),
Poll::Ready(Some(Command::Reconnect)) => {
Collaborator:

I wonder if using a channel for this is the most convenient way to go about it.

We might want to use a different mechanism so that multiple reconnects aren't queued up. I was thinking we could use a Notify; that way we don't need to worry about the bounded channel, and there's no point in doing multiple reconnects in a row anyway, we just want to act on the latest one.

Collaborator:

Yeah, if they're guaranteed idempotent, that would keep the channel from filling up. We did a similar thing for on_update_resources in the Tauri Client:

self.notify_controller.notify_one();

If I only allowed 5 items in the channel and connlib rapidly sent on_update_resources events, it was possible for the channel to fill up and error out (since it's not allowed to block the callbacks) before the GUI got around to dealing with them.

So the channel was replaced with a Notify plus some state that the reader can poll when it's notified, the same as if it were a channel that dropped all but the most recent event.

I think we also considered tokio's watch and it wasn't a perfect fit. Notify has worked well.
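
A minimal sketch of that pattern (illustrative only; the struct and function names are assumptions, not the Tauri Client's actual code): the writer updates shared state and calls notify_one(), and the reader re-reads the latest state whenever it wakes, so a burst of events collapses into a single wake-up.

```rust
use std::sync::{Arc, Mutex};
use tokio::sync::Notify;

#[derive(Default)]
struct Shared {
    resources: Mutex<Vec<String>>, // hypothetical state updated by the callbacks
    notify: Notify,
}

// Callback side: store the newest data, then notify. Rapid successive calls
// coalesce into at most one stored notification permit.
fn on_update_resources(shared: &Shared, new: Vec<String>) {
    *shared.resources.lock().unwrap() = new;
    shared.notify.notify_one();
}

// Reader side: wait for a notification, then read whatever is newest.
async fn controller_loop(shared: Arc<Shared>) {
    loop {
        shared.notify.notified().await;
        let latest = shared.resources.lock().unwrap().clone();
        // react to `latest` here ...
        drop(latest);
    }
}
```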

Member Author:

Unfortunately, Notify doesn't have a poll API, so it would be a bit clunky to use. If debouncing is what we want, I can add a small delay to sending the command through the channel and cancel the in-flight send if we get another one.
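
To illustrate the clunky part (a hypothetical workaround, not something proposed in this PR): Notify only exposes the notified() future, so a hand-written poll function has to store that future somewhere, and since Notified<'_> borrows the Notify, the easy way out is an owned, boxed future that has to be re-created after every wake-up.

```rust
use std::{future::Future, pin::Pin, sync::Arc, task::{Context, Poll}};
use tokio::sync::Notify;

struct ReconnectSignal {
    notify: Arc<Notify>,
    // Boxed so it can live inside a struct; re-created after each completion.
    notified: Pin<Box<dyn Future<Output = ()> + Send>>,
}

impl ReconnectSignal {
    fn new(notify: Arc<Notify>) -> Self {
        let notified = armed(notify.clone());
        Self { notify, notified }
    }

    fn poll_reconnect(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        match self.notified.as_mut().poll(cx) {
            Poll::Ready(()) => {
                // Re-arm for the next notification before reporting readiness.
                self.notified = armed(self.notify.clone());
                Poll::Ready(())
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

fn armed(notify: Arc<Notify>) -> Pin<Box<dyn Future<Output = ()> + Send>> {
    Box::pin(async move { notify.notified().await })
}
```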

Collaborator:

Huh, yeah it doesn't. And it can't be replicated with AtomicWaker?

Collaborator:

Hm, a size-1 channel would also achieve the same effect as Notify.
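
For comparison, a rough sketch of the size-1 channel idea (a standalone, hypothetical example): try_send drops any command that doesn't fit, so repeated reconnect requests coalesce much like a Notify, at the cost of silently discarding a second, different command sent in the same tick.

```rust
use tokio::sync::mpsc;

#[derive(Debug)]
enum Command {
    Reconnect,
}

fn main() {
    // Bounded channel of capacity 1: at most one command can be pending.
    let (tx, mut rx) = mpsc::channel::<Command>(1);

    // Rapid-fire reconnect requests: the first one is queued, the rest fail
    // with `TrySendError::Full` and are ignored.
    for _ in 0..5 {
        let _ = tx.try_send(Command::Reconnect);
    }

    // The event loop only ever sees a single `Reconnect`.
    assert!(matches!(rx.try_recv(), Ok(Command::Reconnect)));
    assert!(rx.try_recv().is_err());
}
```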

Collaborator:

Hm, so it's a 1-sized channel that uses try_send for both commands; if I somehow did a reconnect and a stop in the same tick, the stop would be silently ignored?

Member Author:

Yes. The Stop isn't super critical though. If you drop Session, the Runtime gets dropped and with it, all tasks should be stopped.

Collaborator:

Even the Stop doesn't need a command so much as an "I want you to be running / not be running" flag and a way to notify when it changes.

With channels, we have to trade off between three problems: the sender may block, sends may fail silently, or sends may panic.

Member Author:

Are you suggesting using shared memory instead of channels and just notifying when to re-read the shared memory?

Collaborator:

Kinda. An AtomicBool, so it's not like it has to lock.
But I wrote my last comment at the same time as you wrote "The Stop isn't super critical", so this might just be something to file as an issue and merge this PR anyway.
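
A rough sketch of that flag-plus-notification idea (entirely hypothetical; not what this PR implements): the desired run state lives in an AtomicBool, and the event loop re-reads it whenever it is notified, so nothing can queue up, block, or panic.

```rust
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};
use tokio::sync::Notify;

#[derive(Default)]
struct Control {
    should_run: AtomicBool,
    changed: Notify,
}

// Caller side: flip the flag and wake the event loop. Writing the same value
// twice is harmless and nothing fills up.
fn set_running(control: &Control, run: bool) {
    control.should_run.store(run, Ordering::SeqCst);
    control.changed.notify_one();
}

// Event-loop side: on every wake-up, act on the latest desired state.
async fn event_loop(control: Arc<Control>) {
    loop {
        control.changed.notified().await;

        if !control.should_run.load(Ordering::SeqCst) {
            break; // the moral equivalent of `Command::Stop`
        }

        // still supposed to be running: treat the wake-up as a reconnect hint
    }
}
```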

self.portal.reconnect();
self.tunnel.reconnect();

continue;
}
Poll::Pending => {}
}

22 changes: 16 additions & 6 deletions rust/connlib/clients/shared/src/lib.rs
@@ -20,11 +20,6 @@ pub use eventloop::Eventloop;
use secrecy::Secret;
use tokio::task::JoinHandle;

/// Max interval to retry connections to the portal if it's down or the client has network
/// connectivity changes. Set this to something short so that the end-user experiences
/// minimal disruption to their Firezone resources when switching networks.
const MAX_RECONNECT_INTERVAL: Duration = Duration::from_secs(5);

/// A session is the entry-point for connlib, maintains the runtime and the tunnel.
///
/// A session is created using [Session::connect], then to stop a session we use [Session::disconnect].
@@ -60,6 +55,22 @@ impl Session {
Ok(Self { channel: tx })
}

/// Attempts to reconnect a [`Session`].
///
/// This can and should be called by client applications on any network state changes.
/// It is a signal to connlib to:
///
/// - validate all currently used network paths to relays and peers
/// - ensure we are connected to the portal
///
/// Reconnect is non-destructive and can be called several times in a row.
///
/// In case of destructive network state changes, e.g. the user switched from Wi-Fi to cellular,
/// reconnect allows connlib to re-establish connections faster because we don't have to wait for timeouts first.
pub fn reconnect(&mut self) {
let _ = self.channel.try_send(Command::Reconnect);
}

/// Disconnect a [`Session`].
///
/// This consumes [`Session`] which cleans up all state associated with it.
@@ -91,7 +102,6 @@ where
(),
ExponentialBackoffBuilder::default()
.with_max_elapsed_time(max_partition_time)
.with_max_interval(MAX_RECONNECT_INTERVAL)
.build(),
);

26 changes: 20 additions & 6 deletions rust/connlib/snownet/src/allocation.rs
@@ -129,16 +129,30 @@ impl Allocation {
.flatten()
}

/// Refresh this allocation.
/// Update the credentials of this [`Allocation`].
///
/// In case refreshing the allocation fails, we will attempt to make a new one.
pub fn refresh(&mut self, username: Username, password: &str, realm: Realm, now: Instant) {
self.update_now(now);

/// This will implicitly trigger a [`refresh`](Allocation::refresh) to ensure these credentials are valid.
pub fn update_credentials(
&mut self,
username: Username,
password: &str,
realm: Realm,
now: Instant,
) {
self.username = username;
self.realm = realm;
self.password = password.to_owned();

self.refresh(now);
}

/// Refresh this allocation.
///
/// In case refreshing the allocation fails, we will attempt to make a new one.
#[tracing::instrument(level = "debug", skip_all, fields(relay = %self.server))]
pub fn refresh(&mut self, now: Instant) {
self.update_now(now);

if !self.has_allocation() && self.allocate_in_flight() {
tracing::debug!("Not refreshing allocation because we are already making one");
return;
@@ -1998,7 +2012,7 @@ mod tests {
}

fn refresh_with_same_credentials(&mut self) {
self.refresh(
self.update_credentials(
Username::new("foobar".to_owned()).unwrap(),
"baz",
Realm::new("firezone".to_owned()).unwrap(),
12 changes: 11 additions & 1 deletion rust/connlib/snownet/src/node.rs
@@ -136,6 +136,16 @@ where
}
}

pub fn reconnect(&mut self, now: Instant) {
for binding in self.bindings.values_mut() {
binding.refresh(now);
}

for allocation in self.allocations.values_mut() {
allocation.refresh(now);
}
}

pub fn public_key(&self) -> PublicKey {
(&self.private_key).into()
}
@@ -902,7 +912,7 @@
};

if let Some(existing) = self.allocations.get_mut(server) {
existing.refresh(username, password, realm, now);
existing.update_credentials(username, password, realm, now);
continue;
}

15 changes: 15 additions & 0 deletions rust/connlib/snownet/src/stun_binding.rs
@@ -122,6 +122,21 @@ impl StunBinding {
true
}

pub(crate) fn refresh(&mut self, now: Instant) {
self.last_now = now;
self.backoff.clock.now = now;

self.backoff.reset();
let backoff = self
.backoff
.next_backoff()
.expect("to have backoff right after resetting");
Collaborator:

Suggested change:
-        .expect("to have backoff right after resetting");
+        .expect("should have backoff Instant right after resetting");


let (state, transmit) = new_binding_request(self.server, now, backoff);
self.state = state;
self.buffered_transmits.push_back(transmit);
}

pub fn handle_timeout(&mut self, now: Instant) {
self.last_now = now;
self.backoff.clock.now = now;
4 changes: 4 additions & 0 deletions rust/connlib/tunnel/src/lib.rs
@@ -71,6 +71,10 @@ impl<CB> Tunnel<CB, ClientState, snownet::Client, GatewayId>
where
CB: Callbacks + 'static,
{
pub fn reconnect(&mut self) {
self.connections_state.node.reconnect(Instant::now());
}

pub fn poll_next_event(&mut self, cx: &mut Context<'_>) -> Poll<Result<Event<GatewayId>>> {
match self.role_state.poll_next_event(cx) {
Poll::Ready(Event::SendPacket(packet)) => {
27 changes: 24 additions & 3 deletions rust/linux-client/src/main.rs
@@ -8,7 +8,8 @@ use connlib_shared::{
};
use firezone_cli_utils::{setup_global_subscriber, CommonArgs};
use secrecy::SecretString;
use std::{net::IpAddr, path::PathBuf, str::FromStr};
use std::{future, net::IpAddr, path::PathBuf, str::FromStr, task::Poll};
use tokio::signal::unix::SignalKind;

#[tokio::main]
async fn main() -> Result<()> {
@@ -39,7 +40,7 @@ async fn main() -> Result<()> {
public_key.to_bytes(),
)?;

let session = Session::connect(
let mut session = Session::connect(
login,
private_key,
None,
@@ -49,13 +50,33 @@
)
.unwrap();

tokio::signal::ctrl_c().await?;
let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())?;
let mut sighup = tokio::signal::unix::signal(SignalKind::hangup())?;

future::poll_fn(|cx| loop {
if sigint.poll_recv(cx).is_ready() {
tracing::debug!("Received SIGINT");

return Poll::Ready(());
}

if sighup.poll_recv(cx).is_ready() {
tracing::debug!("Received SIGHUP");

session.reconnect();
continue;
}

return Poll::Pending;
})
.await;

if let Some(DnsControlMethod::EtcResolvConf) = dns_control_method {
etc_resolv_conf::unconfigure_dns()?;
}

session.disconnect();

Ok(())
}

42 changes: 36 additions & 6 deletions rust/phoenix-channel/src/lib.rs
@@ -13,7 +13,7 @@ use heartbeat::{Heartbeat, MissedLastHeartbeat};
use rand_core::{OsRng, RngCore};
use secrecy::{ExposeSecret as _, Secret};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::task::{ready, Context, Poll};
use std::task::{Context, Poll, Waker};
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::http::StatusCode;
use tokio_tungstenite::{
@@ -28,6 +28,7 @@ pub use login_url::{LoginUrl, LoginUrlError};
// See https://github.com/firezone/firezone/issues/2158
pub struct PhoenixChannel<TInitReq, TInboundMsg, TOutboundRes> {
state: State,
waker: Option<Waker>,
pending_messages: VecDeque<String>,
next_request_id: u64,

@@ -227,6 +228,7 @@

Ok(stream)
})),
waker: None,
pending_messages: Default::default(),
_phantom: PhantomData,
next_request_id: 0,
@@ -251,6 +253,28 @@
self.send_message(topic, message)
}

/// Reconnects to the portal.
pub fn reconnect(&mut self) {
// 1. Reset the backoff.
self.reconnect_backoff.reset();

// 2. Set state to `Connecting` without a timer.
let url = self.url.clone();
let user_agent = self.user_agent.clone();
self.state = State::Connecting(Box::pin(async move {
let (stream, _) = connect_async(make_request(url, user_agent))
.await
.map_err(InternalError::WebSocket)?;

Ok(stream)
}));

// 3. In case we were already re-connecting, we need to wake the suspended task.
if let Some(waker) = self.waker.take() {
waker.wake();
}
}

pub fn poll(
&mut self,
cx: &mut Context,
@@ -259,8 +283,8 @@
// First, check if we are connected.
let stream = match &mut self.state {
State::Connected(stream) => stream,
State::Connecting(future) => match ready!(future.poll_unpin(cx)) {
Ok(stream) => {
State::Connecting(future) => match future.poll_unpin(cx) {
Poll::Ready(Ok(stream)) => {
self.reconnect_backoff.reset();
self.state = State::Connected(stream);

Expand All @@ -271,12 +295,12 @@ where

continue;
}
Err(InternalError::WebSocket(tokio_tungstenite::tungstenite::Error::Http(
r,
Poll::Ready(Err(InternalError::WebSocket(
tokio_tungstenite::tungstenite::Error::Http(r),
))) if r.status().is_client_error() => {
return Poll::Ready(Err(Error::ClientError(r.status())));
}
Err(e) => {
Poll::Ready(Err(e)) => {
let Some(backoff) = self.reconnect_backoff.next_backoff() else {
tracing::warn!("Reconnect backoff expired");
return Poll::Ready(Err(Error::MaxRetriesReached));
@@ -298,6 +322,11 @@
}));
continue;
}
Poll::Pending => {
// Save a waker in case we want to reset the `Connecting` state while we are waiting.
self.waker = Some(cx.waker().clone());
return Poll::Pending;
}
},
};

Expand Down Expand Up @@ -494,6 +523,7 @@ where
reconnect_backoff: self.reconnect_backoff,
login: self.login,
init_req: self.init_req,
waker: self.waker,
}
}
}