Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(connlib): remove Option<RawFd> return value from Callbacks #4471

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
154 changes: 92 additions & 62 deletions rust/connlib/clients/android/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use connlib_client_shared::{
file_logger, keypair, Callbacks, Cidrv4, Cidrv6, Error, LoginUrl, LoginUrlError,
ResourceDescription, Session, Sockets,
ResourceDescription, Session, Sockets, Tun,
};
use jni::{
objects::{GlobalRef, JClass, JObject, JString, JValue},
Expand Down Expand Up @@ -34,6 +34,7 @@ pub struct CallbackHandler {
vm: JavaVM,
callback_handler: GlobalRef,
handle: file_logger::Handle,
new_tun_sender: tokio::sync::mpsc::Sender<Tun>,
}

impl Clone for CallbackHandler {
Expand All @@ -47,6 +48,7 @@ impl Clone for CallbackHandler {
vm: unsafe { std::ptr::read(&self.vm) },
callback_handler: self.callback_handler.clone(),
handle: self.handle.clone(),
new_tun_sender: self.new_tun_sender.clone(),
}
}
}
Expand All @@ -72,6 +74,17 @@ pub enum CallbackError {
}

impl CallbackHandler {
fn set_new_tun(&self, new_fd: RawFd) {
match Tun::with_fd(new_fd) {
Ok(tun) => {
let _ = self.new_tun_sender.blocking_send(tun); // If this fails, connlib is shutting down so we don't care.
}
Err(e) => {
tracing::error!("Failed to make new `Tun`: {e}");
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, any error when creating a new Tun device is also only bubbled up

) -> Result<(), ConnlibError> {
and eventually logged
tracing::warn!("Tunnel error: {e}");

so this isn't a change in behaviour.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we fail to create the Tun for some reason, will the app silently go on without tunneling any traffic?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that is the behaviour today.

}
};
}

fn env<T>(
&self,
f: impl FnOnce(JNIEnv) -> Result<T, CallbackError>,
Expand Down Expand Up @@ -157,75 +170,75 @@ impl Callbacks for CallbackHandler {
tunnel_address_v4: Ipv4Addr,
tunnel_address_v6: Ipv6Addr,
dns_addresses: Vec<IpAddr>,
) -> Option<RawFd> {
self.env(|mut env| {
let tunnel_address_v4 =
env.new_string(tunnel_address_v4.to_string())
) {
let new_fd = self
.env(|mut env| {
let tunnel_address_v4 =
env.new_string(tunnel_address_v4.to_string())
.map_err(|source| CallbackError::NewStringFailed {
name: "tunnel_address_v4",
source,
})?;
let tunnel_address_v6 =
env.new_string(tunnel_address_v6.to_string())
.map_err(|source| CallbackError::NewStringFailed {
name: "tunnel_address_v6",
source,
})?;
let dns_addresses = env
.new_string(serde_json::to_string(&dns_addresses)?)
.map_err(|source| CallbackError::NewStringFailed {
name: "tunnel_address_v4",
name: "dns_addresses",
source,
})?;
let tunnel_address_v6 =
env.new_string(tunnel_address_v6.to_string())
let name = "onSetInterfaceConfig";
env.call_method(
&self.callback_handler,
name,
"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)I",
&[
JValue::from(&tunnel_address_v4),
JValue::from(&tunnel_address_v6),
JValue::from(&dns_addresses),
],
)
.and_then(|val| val.i())
.map_err(|source| CallbackError::CallMethodFailed { name, source })
})
.expect("onSetInterfaceConfig callback failed");

self.set_new_tun(new_fd);
}

fn on_update_routes(&self, route_list_4: Vec<Cidrv4>, route_list_6: Vec<Cidrv6>) {
let new_fd = self
.env(|mut env| {
let route_list_4 = env
.new_string(serde_json::to_string(&route_list_4)?)
.map_err(|source| CallbackError::NewStringFailed {
name: "tunnel_address_v6",
name: "route_list_4",
source,
})?;
let route_list_6 = env
.new_string(serde_json::to_string(&route_list_6)?)
.map_err(|source| CallbackError::NewStringFailed {
name: "route_list_6",
source,
})?;
let dns_addresses = env
.new_string(serde_json::to_string(&dns_addresses)?)
.map_err(|source| CallbackError::NewStringFailed {
name: "dns_addresses",
source,
})?;
let name = "onSetInterfaceConfig";
env.call_method(
&self.callback_handler,
name,
"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)I",
&[
JValue::from(&tunnel_address_v4),
JValue::from(&tunnel_address_v6),
JValue::from(&dns_addresses),
],
)
.and_then(|val| val.i())
.map(Some)
.map_err(|source| CallbackError::CallMethodFailed { name, source })
})
.expect("onSetInterfaceConfig callback failed")
}

fn on_update_routes(
&self,
route_list_4: Vec<Cidrv4>,
route_list_6: Vec<Cidrv6>,
) -> Option<RawFd> {
self.env(|mut env| {
let route_list_4 = env
.new_string(serde_json::to_string(&route_list_4)?)
.map_err(|source| CallbackError::NewStringFailed {
name: "route_list_4",
source,
})?;
let route_list_6 = env
.new_string(serde_json::to_string(&route_list_6)?)
.map_err(|source| CallbackError::NewStringFailed {
name: "route_list_6",
source,
})?;
let name = "onUpdateRoutes";
env.call_method(
&self.callback_handler,
name,
"(Ljava/lang/String;Ljava/lang/String;)I",
&[JValue::from(&route_list_4), JValue::from(&route_list_6)],
)
.and_then(|val| val.i())
.map_err(|source| CallbackError::CallMethodFailed { name, source })
})
.expect("onUpdateRoutes callback failed");

let name = "onUpdateRoutes";
env.call_method(
&self.callback_handler,
name,
"(Ljava/lang/String;Ljava/lang/String;)I",
&[JValue::from(&route_list_4), JValue::from(&route_list_6)],
)
.and_then(|val| val.i())
.map(Some)
.map_err(|source| CallbackError::CallMethodFailed { name, source })
})
.expect("onUpdateRoutes callback failed")
self.set_new_tun(new_fd);
}

fn on_update_resources(&self, resource_list: Vec<ResourceDescription>) {
Expand Down Expand Up @@ -346,10 +359,13 @@ fn connect(

let handle = init_logging(&PathBuf::from(log_dir), log_filter);

let (new_tun_sender, mut new_tun_receiver) = tokio::sync::mpsc::channel(1);

let callback_handler = CallbackHandler {
vm: env.get_java_vm().map_err(ConnectError::GetJavaVmFailed)?,
callback_handler,
handle,
new_tun_sender,
};

let (private_key, public_key) = keypair();
Expand Down Expand Up @@ -386,6 +402,20 @@ fn connect(
runtime.handle().clone(),
);

// This is annoyingly redundant because `Session` already has a `Sender` inside.
// It would be nice to just directly send into that channel.
// We cannot do that because we only construct the channel within `Session::connect` which already requires us to pass the `CallbackHandler`: Circular dependency!
// This will be resolve itself once we no longer have callbacks.
runtime.spawn({
let session = session.clone();

async move {
while let Some(tun) = new_tun_receiver.recv().await {
session.set_tun(tun)
}
}
});

Ok(SessionWrapper {
inner: session,
runtime,
Expand Down
13 changes: 2 additions & 11 deletions rust/connlib/clients/apple/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use connlib_client_shared::{
use secrecy::SecretString;
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
os::fd::RawFd,
path::PathBuf,
sync::Arc,
time::Duration,
Expand Down Expand Up @@ -107,28 +106,20 @@ impl Callbacks for CallbackHandler {
tunnel_address_v4: Ipv4Addr,
tunnel_address_v6: Ipv6Addr,
dns_addresses: Vec<IpAddr>,
) -> Option<RawFd> {
) {
self.inner.on_set_interface_config(
tunnel_address_v4.to_string(),
tunnel_address_v6.to_string(),
serde_json::to_string(&dns_addresses)
.expect("developer error: a list of ips should always be serializable"),
);

None
}

fn on_update_routes(
&self,
route_list_4: Vec<Cidrv4>,
route_list_6: Vec<Cidrv6>,
) -> Option<RawFd> {
fn on_update_routes(&self, route_list_4: Vec<Cidrv4>, route_list_6: Vec<Cidrv6>) {
self.inner.on_update_routes(
serde_json::to_string(&route_list_4).unwrap(),
serde_json::to_string(&route_list_6).unwrap(),
);

None
}

fn on_update_resources(&self, resource_list: Vec<ResourceDescription>) {
Expand Down
4 changes: 3 additions & 1 deletion rust/connlib/clients/shared/src/eventloop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use connlib_shared::{
messages::{ConnectionAccepted, GatewayResponse, ResourceAccepted, ResourceId},
Callbacks,
};
use firezone_tunnel::ClientTunnel;
use firezone_tunnel::{ClientTunnel, Tun};
use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel};
use std::{
collections::HashMap,
Expand All @@ -32,6 +32,7 @@ pub enum Command {
Stop,
Reconnect,
SetDns(Vec<IpAddr>),
SetTun(Tun),
}

impl<C: Callbacks> Eventloop<C> {
Expand Down Expand Up @@ -62,6 +63,7 @@ where
tracing::warn!("Failed to update DNS: {e}");
}
}
Poll::Ready(Some(Command::SetTun(tun))) => self.tunnel.set_tun(tun),
Poll::Ready(Some(Command::Reconnect)) => {
self.portal.reconnect();
if let Err(e) = self.tunnel.reconnect() {
Expand Down
16 changes: 11 additions & 5 deletions rust/connlib/clients/shared/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,32 @@ pub use connlib_shared::messages::ResourceDescription;
pub use connlib_shared::{
keypair, Callbacks, Cidrv4, Cidrv6, Error, LoginUrl, LoginUrlError, StaticSecret,
};
pub use eventloop::Eventloop;
pub use firezone_tunnel::Sockets;
pub use firezone_tunnel::Tun;
pub use tracing_appender::non_blocking::WorkerGuard;

use backoff::ExponentialBackoffBuilder;
use connlib_shared::get_user_agent;
use eventloop::Command;
use firezone_tunnel::ClientTunnel;
use phoenix_channel::PhoenixChannel;
use secrecy::Secret;
use std::net::IpAddr;
use std::time::Duration;
use tokio::sync::mpsc::UnboundedReceiver;
use tokio::task::JoinHandle;

mod eventloop;
pub mod file_logger;
mod messages;

const PHOENIX_TOPIC: &str = "client";

use eventloop::Command;
pub use eventloop::Eventloop;
use secrecy::Secret;
use tokio::task::JoinHandle;

/// A session is the entry-point for connlib, maintains the runtime and the tunnel.
///
/// A session is created using [Session::connect], then to stop a session we use [Session::disconnect].
#[derive(Clone)]
pub struct Session {
channel: tokio::sync::mpsc::UnboundedSender<Command>,
}
Expand Down Expand Up @@ -95,6 +96,11 @@ impl Session {
let _ = self.channel.send(Command::SetDns(new_dns));
}

/// Sets a new TUN device for this [`Session`].
pub fn set_tun(&self, new_tun: Tun) {
let _ = self.channel.send(Command::SetTun(new_tun));
}

/// Disconnect a [`Session`].
///
/// This consumes [`Session`] which cleans up all state associated with it.
Expand Down
14 changes: 2 additions & 12 deletions rust/connlib/shared/src/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ use serde::Serialize;
use std::fmt::Debug;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

// Avoids having to map types for Windows
type RawFd = i32;

#[derive(Serialize, Clone, Copy, Debug)]
/// Identical to `ip_network::Ipv4Network` except we implement `Serialize` on the Rust side and the equivalent of `Deserialize` on the Swift / Kotlin side to avoid manually serializing and deserializing.
pub struct Cidrv4 {
Expand Down Expand Up @@ -42,17 +39,10 @@ impl From<Ipv6Network> for Cidrv6 {
/// Traits that will be used by connlib to callback the client upper layers.
pub trait Callbacks: Clone + Send + Sync {
/// Called when the tunnel address is set.
///
/// This should return a new `fd` if there is one.
/// (Only happens on android for now)
fn on_set_interface_config(&self, _: Ipv4Addr, _: Ipv6Addr, _: Vec<IpAddr>) -> Option<RawFd> {
None
}
fn on_set_interface_config(&self, _: Ipv4Addr, _: Ipv6Addr, _: Vec<IpAddr>) {}

/// Called when the route list changes.
fn on_update_routes(&self, _: Vec<Cidrv4>, _: Vec<Cidrv6>) -> Option<RawFd> {
None
}
fn on_update_routes(&self, _: Vec<Cidrv4>, _: Vec<Cidrv6>) {}

/// Called when the resource list changes.
fn on_update_resources(&self, _: Vec<ResourceDescription>) {}
Expand Down
6 changes: 5 additions & 1 deletion rust/connlib/tunnel/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use ip_network_table::IpNetworkTable;
use itertools::Itertools;

use crate::utils::{earliest, stun, turn};
use crate::{ClientEvent, ClientTunnel};
use crate::{ClientEvent, ClientTunnel, Tun};
use secrecy::{ExposeSecret as _, Secret};
use snownet::ClientNode;
use std::collections::hash_map::Entry;
Expand Down Expand Up @@ -124,6 +124,10 @@ where
Ok(())
}

pub fn set_tun(&mut self, tun: Tun) {
self.io.device_mut().set_tun(tun);
}

#[tracing::instrument(level = "trace", skip(self))]
pub fn set_new_interface_config(
&mut self,
Expand Down