Skip to content

Commit

Permalink
refactor(connlib): explicitly set DNS from clients instead of request…
Browse files Browse the repository at this point in the history
…ing it via callback (#4240)

Extracted from #4163

Dependant PRs:
#4198
#4133
#4163
  • Loading branch information
conectado committed Mar 21, 2024
1 parent 7449c9b commit 40f5fa3
Show file tree
Hide file tree
Showing 11 changed files with 240 additions and 121 deletions.
33 changes: 18 additions & 15 deletions rust/connlib/clients/android/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,21 @@ impl CallbackHandler {
.map_err(CallbackError::AttachCurrentThreadFailed)
.and_then(f)
}

fn get_system_default_resolvers(&self) -> Vec<IpAddr> {
self.env(|mut env| {
let name = "getSystemDefaultResolvers";
let addrs = env
.call_method(&self.callback_handler, name, "()[[B", &[])
.and_then(JValueGen::l)
.and_then(|arr| convert_byte_array_array(&mut env, arr.into()))
.map_err(|source| CallbackError::CallMethodFailed { name, source })?;

Ok(Some(addrs.iter().filter_map(|v| to_ip(v)).collect()))
})
.expect("getSystemDefaultResolvers callback failed")
.unwrap_or_default()
}
}

fn call_method(
Expand Down Expand Up @@ -286,20 +301,6 @@ impl Callbacks for CallbackHandler {
None
})
}

fn get_system_default_resolvers(&self) -> Option<Vec<IpAddr>> {
self.env(|mut env| {
let name = "getSystemDefaultResolvers";
let addrs = env
.call_method(&self.callback_handler, name, "()[[B", &[])
.and_then(JValueGen::l)
.and_then(|arr| convert_byte_array_array(&mut env, arr.into()))
.map_err(|source| CallbackError::CallMethodFailed { name, source })?;

Ok(Some(addrs.iter().filter_map(|v| to_ip(v)).collect()))
})
.expect("getSystemDefaultResolvers callback failed")
}
}

fn to_ip(val: &[u8]) -> Option<IpAddr> {
Expand Down Expand Up @@ -427,11 +428,13 @@ fn connect(
login,
private_key,
Some(os_version),
callback_handler,
callback_handler.clone(),
Some(MAX_PARTITION_TIME),
runtime.handle().clone(),
)?;

session.set_dns(callback_handler.get_system_default_resolvers());

Ok(SessionWrapper {
inner: session,
runtime,
Expand Down
20 changes: 7 additions & 13 deletions rust/connlib/clients/apple/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ mod ffi {
#[swift_bridge(swift_name = "onDisconnect")]
fn on_disconnect(&self, error: String);

// TODO: remove in favor of set_dns
#[swift_bridge(swift_name = "getSystemDefaultResolvers")]
fn get_system_default_resolvers(&self) -> String;
}
Expand Down Expand Up @@ -141,19 +142,6 @@ impl Callbacks for CallbackHandler {
self.inner.on_disconnect(error.to_string());
}

fn get_system_default_resolvers(&self) -> Option<Vec<IpAddr>> {
let resolvers_json = self.inner.get_system_default_resolvers();
tracing::debug!(
"get_system_default_resolvers returned: {:?}",
resolvers_json
);

let resolvers: Vec<IpAddr> = serde_json::from_str(&resolvers_json)
.expect("developer error: failed to deserialize resolvers");

Some(resolvers)
}

fn roll_log_file(&self) -> Option<PathBuf> {
self.handle.roll_to_new_file().unwrap_or_else(|e| {
tracing::error!("Failed to roll over to new log file: {e}");
Expand Down Expand Up @@ -193,6 +181,10 @@ impl WrappedSession {
let handle = init_logging(log_dir.into(), log_filter).map_err(|e| e.to_string())?;
let secret = SecretString::from(token);

let resolvers_json = callback_handler.get_system_default_resolvers();
let resolvers: Vec<IpAddr> = serde_json::from_str(&resolvers_json)
.expect("developer error: failed to deserialize resolvers");

let (private_key, public_key) = keypair();
let login = LoginUrl::client(
api_url.as_str(),
Expand Down Expand Up @@ -223,6 +215,8 @@ impl WrappedSession {
)
.map_err(|err| err.to_string())?;

session.set_dns(resolvers);

Ok(Self {
inner: session,
runtime,
Expand Down
15 changes: 9 additions & 6 deletions rust/connlib/clients/shared/src/eventloop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,20 @@ use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel};
use std::{
collections::HashMap,
io,
net::IpAddr,
path::PathBuf,
task::{Context, Poll},
time::Duration,
time::{Duration, Instant},
};
use tokio::time::{Instant, Interval, MissedTickBehavior};
use tokio::time::{Interval, MissedTickBehavior};
use url::Url;

pub struct Eventloop<C: Callbacks> {
tunnel: ClientTunnel<C>,
tunnel_init: bool,

portal: PhoenixChannel<(), IngressMessages, ReplyMessages>,
rx: tokio::sync::mpsc::Receiver<Command>,
rx: tokio::sync::mpsc::UnboundedReceiver<Command>,

connection_intents: SentConnectionIntents,
log_upload_interval: tokio::time::Interval,
Expand All @@ -37,13 +38,14 @@ pub struct Eventloop<C: Callbacks> {
pub enum Command {
Stop,
Reconnect,
SetDns(Vec<IpAddr>),
}

impl<C: Callbacks> Eventloop<C> {
pub(crate) fn new(
tunnel: ClientTunnel<C>,
portal: PhoenixChannel<(), IngressMessages, ReplyMessages>,
rx: tokio::sync::mpsc::Receiver<Command>,
rx: tokio::sync::mpsc::UnboundedReceiver<Command>,
) -> Self {
Self {
tunnel,
Expand All @@ -65,6 +67,7 @@ where
loop {
match self.rx.poll_recv(cx) {
Poll::Ready(Some(Command::Stop)) | Poll::Ready(None) => return Poll::Ready(Ok(())),
Poll::Ready(Some(Command::SetDns(dns))) => self.tunnel.set_dns(dns, Instant::now()),
Poll::Ready(Some(Command::Reconnect)) => {
self.portal.reconnect();
self.tunnel.reconnect();
Expand Down Expand Up @@ -180,7 +183,7 @@ where
resources,
}) => {
if !self.tunnel_init {
if let Err(e) = self.tunnel.set_interface(&interface) {
if let Err(e) = self.tunnel.set_interface(interface) {
tracing::warn!("Failed to set interface on tunnel: {e}");
return;
}
Expand Down Expand Up @@ -364,7 +367,7 @@ async fn upload(_path: PathBuf, _url: Url) -> io::Result<()> {

fn upload_interval() -> Interval {
let duration = upload_interval_duration_from_env_or_default();
let mut interval = tokio::time::interval_at(Instant::now() + duration, duration);
let mut interval = tokio::time::interval_at(tokio::time::Instant::now() + duration, duration);
interval.set_missed_tick_behavior(MissedTickBehavior::Skip);

interval
Expand Down
18 changes: 12 additions & 6 deletions rust/connlib/clients/shared/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ pub use connlib_shared::messages::ResourceDescription;
pub use connlib_shared::{
keypair, Callbacks, Cidrv4, Cidrv6, Error, LoginUrl, LoginUrlError, StaticSecret,
};
use tokio::sync::mpsc::UnboundedReceiver;
pub use tracing_appender::non_blocking::WorkerGuard;

use backoff::ExponentialBackoffBuilder;
use connlib_shared::get_user_agent;
use firezone_tunnel::ClientTunnel;
use phoenix_channel::PhoenixChannel;
use std::net::IpAddr;
use std::time::Duration;

mod eventloop;
Expand All @@ -26,7 +28,7 @@ use tokio::task::JoinHandle;
///
/// A session is created using [Session::connect], then to stop a session we use [Session::disconnect].
pub struct Session {
channel: tokio::sync::mpsc::Sender<Command>,
channel: tokio::sync::mpsc::UnboundedSender<Command>,
}

impl Session {
Expand All @@ -41,7 +43,7 @@ impl Session {
max_partition_time: Option<Duration>,
handle: tokio::runtime::Handle,
) -> connlib_shared::Result<Self> {
let (tx, rx) = tokio::sync::mpsc::channel(1);
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

let connect_handle = handle.spawn(connect(
url,
Expand All @@ -68,15 +70,19 @@ impl Session {
///
/// In case of destructive network state changes, i.e. the user switched from wifi to cellular,
/// reconnect allows connlib to re-establish connections faster because we don't have to wait for timeouts first.
pub fn reconnect(&mut self) {
let _ = self.channel.try_send(Command::Reconnect);
pub fn reconnect(&self) {
let _ = self.channel.send(Command::Reconnect);
}

pub fn set_dns(&self, new_dns: Vec<IpAddr>) {
let _ = self.channel.send(Command::SetDns(new_dns));
}

/// Disconnect a [`Session`].
///
/// This consumes [`Session`] which cleans up all state associated with it.
pub fn disconnect(self) {
let _ = self.channel.try_send(Command::Stop);
let _ = self.channel.send(Command::Stop);
}
}

Expand All @@ -89,7 +95,7 @@ async fn connect<CB>(
os_version_override: Option<String>,
callbacks: CB,
max_partition_time: Option<Duration>,
rx: tokio::sync::mpsc::Receiver<Command>,
rx: UnboundedReceiver<Command>,
) -> Result<(), Error>
where
CB: Callbacks + 'static,
Expand Down
8 changes: 0 additions & 8 deletions rust/connlib/shared/src/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,6 @@ pub trait Callbacks: Clone + Send + Sync {
std::process::exit(0);
}

/// Returns the system's default resolver(s)
///
/// It's okay for clients to include Firezone's own DNS here, e.g. 100.100.111.1.
/// connlib internally filters them out.
fn get_system_default_resolvers(&self) -> Option<Vec<IpAddr>> {
None
}

/// Protects the socket file descriptor from routing loops.
#[cfg(target_os = "android")]
fn protect_file_descriptor(&self, file_descriptor: std::os::fd::RawFd);
Expand Down

0 comments on commit 40f5fa3

Please sign in to comment.