diff --git a/rust/connlib/clients/android/src/lib.rs b/rust/connlib/clients/android/src/lib.rs index 94cdefd086..8b0a7eeb8d 100644 --- a/rust/connlib/clients/android/src/lib.rs +++ b/rust/connlib/clients/android/src/lib.rs @@ -5,7 +5,7 @@ use connlib_client_shared::{ file_logger, keypair, Callbacks, Cidrv4, Cidrv6, Error, LoginUrl, LoginUrlError, - ResourceDescription, Session, Sockets, + ResourceDescription, Session, Sockets, Tun, }; use jni::{ objects::{GlobalRef, JClass, JObject, JString, JValue}, @@ -34,6 +34,7 @@ pub struct CallbackHandler { vm: JavaVM, callback_handler: GlobalRef, handle: file_logger::Handle, + new_tun_sender: tokio::sync::mpsc::Sender, } impl Clone for CallbackHandler { @@ -47,6 +48,7 @@ impl Clone for CallbackHandler { vm: unsafe { std::ptr::read(&self.vm) }, callback_handler: self.callback_handler.clone(), handle: self.handle.clone(), + new_tun_sender: self.new_tun_sender.clone(), } } } @@ -72,6 +74,17 @@ pub enum CallbackError { } impl CallbackHandler { + fn set_new_tun(&self, new_fd: RawFd) { + match Tun::with_fd(new_fd) { + Ok(tun) => { + let _ = self.new_tun_sender.blocking_send(tun); // If this fails, connlib is shutting down so we don't care. + } + Err(e) => { + tracing::error!("Failed to make new `Tun`: {e}"); + } + }; + } + fn env( &self, f: impl FnOnce(JNIEnv) -> Result, @@ -157,75 +170,75 @@ impl Callbacks for CallbackHandler { tunnel_address_v4: Ipv4Addr, tunnel_address_v6: Ipv6Addr, dns_addresses: Vec, - ) -> Option { - self.env(|mut env| { - let tunnel_address_v4 = - env.new_string(tunnel_address_v4.to_string()) + ) { + let new_fd = self + .env(|mut env| { + let tunnel_address_v4 = + env.new_string(tunnel_address_v4.to_string()) + .map_err(|source| CallbackError::NewStringFailed { + name: "tunnel_address_v4", + source, + })?; + let tunnel_address_v6 = + env.new_string(tunnel_address_v6.to_string()) + .map_err(|source| CallbackError::NewStringFailed { + name: "tunnel_address_v6", + source, + })?; + let dns_addresses = env + .new_string(serde_json::to_string(&dns_addresses)?) .map_err(|source| CallbackError::NewStringFailed { - name: "tunnel_address_v4", + name: "dns_addresses", source, })?; - let tunnel_address_v6 = - env.new_string(tunnel_address_v6.to_string()) + let name = "onSetInterfaceConfig"; + env.call_method( + &self.callback_handler, + name, + "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)I", + &[ + JValue::from(&tunnel_address_v4), + JValue::from(&tunnel_address_v6), + JValue::from(&dns_addresses), + ], + ) + .and_then(|val| val.i()) + .map_err(|source| CallbackError::CallMethodFailed { name, source }) + }) + .expect("onSetInterfaceConfig callback failed"); + + self.set_new_tun(new_fd); + } + + fn on_update_routes(&self, route_list_4: Vec, route_list_6: Vec) { + let new_fd = self + .env(|mut env| { + let route_list_4 = env + .new_string(serde_json::to_string(&route_list_4)?) .map_err(|source| CallbackError::NewStringFailed { - name: "tunnel_address_v6", + name: "route_list_4", + source, + })?; + let route_list_6 = env + .new_string(serde_json::to_string(&route_list_6)?) + .map_err(|source| CallbackError::NewStringFailed { + name: "route_list_6", source, })?; - let dns_addresses = env - .new_string(serde_json::to_string(&dns_addresses)?) - .map_err(|source| CallbackError::NewStringFailed { - name: "dns_addresses", - source, - })?; - let name = "onSetInterfaceConfig"; - env.call_method( - &self.callback_handler, - name, - "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)I", - &[ - JValue::from(&tunnel_address_v4), - JValue::from(&tunnel_address_v6), - JValue::from(&dns_addresses), - ], - ) - .and_then(|val| val.i()) - .map(Some) - .map_err(|source| CallbackError::CallMethodFailed { name, source }) - }) - .expect("onSetInterfaceConfig callback failed") - } - fn on_update_routes( - &self, - route_list_4: Vec, - route_list_6: Vec, - ) -> Option { - self.env(|mut env| { - let route_list_4 = env - .new_string(serde_json::to_string(&route_list_4)?) - .map_err(|source| CallbackError::NewStringFailed { - name: "route_list_4", - source, - })?; - let route_list_6 = env - .new_string(serde_json::to_string(&route_list_6)?) - .map_err(|source| CallbackError::NewStringFailed { - name: "route_list_6", - source, - })?; + let name = "onUpdateRoutes"; + env.call_method( + &self.callback_handler, + name, + "(Ljava/lang/String;Ljava/lang/String;)I", + &[JValue::from(&route_list_4), JValue::from(&route_list_6)], + ) + .and_then(|val| val.i()) + .map_err(|source| CallbackError::CallMethodFailed { name, source }) + }) + .expect("onUpdateRoutes callback failed"); - let name = "onUpdateRoutes"; - env.call_method( - &self.callback_handler, - name, - "(Ljava/lang/String;Ljava/lang/String;)I", - &[JValue::from(&route_list_4), JValue::from(&route_list_6)], - ) - .and_then(|val| val.i()) - .map(Some) - .map_err(|source| CallbackError::CallMethodFailed { name, source }) - }) - .expect("onUpdateRoutes callback failed") + self.set_new_tun(new_fd); } fn on_update_resources(&self, resource_list: Vec) { @@ -346,10 +359,13 @@ fn connect( let handle = init_logging(&PathBuf::from(log_dir), log_filter); + let (new_tun_sender, mut new_tun_receiver) = tokio::sync::mpsc::channel(1); + let callback_handler = CallbackHandler { vm: env.get_java_vm().map_err(ConnectError::GetJavaVmFailed)?, callback_handler, handle, + new_tun_sender, }; let (private_key, public_key) = keypair(); @@ -386,6 +402,20 @@ fn connect( runtime.handle().clone(), ); + // This is annoyingly redundant because `Session` already has a `Sender` inside. + // It would be nice to just directly send into that channel. + // We cannot do that because we only construct the channel within `Session::connect` which already requires us to pass the `CallbackHandler`: Circular dependency! + // This will be resolve itself once we no longer have callbacks. + runtime.spawn({ + let session = session.clone(); + + async move { + while let Some(tun) = new_tun_receiver.recv().await { + session.set_tun(tun) + } + } + }); + Ok(SessionWrapper { inner: session, runtime, diff --git a/rust/connlib/clients/apple/src/lib.rs b/rust/connlib/clients/apple/src/lib.rs index 8df91c862e..b17ee09684 100644 --- a/rust/connlib/clients/apple/src/lib.rs +++ b/rust/connlib/clients/apple/src/lib.rs @@ -8,7 +8,6 @@ use connlib_client_shared::{ use secrecy::SecretString; use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, - os::fd::RawFd, path::PathBuf, sync::Arc, time::Duration, @@ -107,28 +106,20 @@ impl Callbacks for CallbackHandler { tunnel_address_v4: Ipv4Addr, tunnel_address_v6: Ipv6Addr, dns_addresses: Vec, - ) -> Option { + ) { self.inner.on_set_interface_config( tunnel_address_v4.to_string(), tunnel_address_v6.to_string(), serde_json::to_string(&dns_addresses) .expect("developer error: a list of ips should always be serializable"), ); - - None } - fn on_update_routes( - &self, - route_list_4: Vec, - route_list_6: Vec, - ) -> Option { + fn on_update_routes(&self, route_list_4: Vec, route_list_6: Vec) { self.inner.on_update_routes( serde_json::to_string(&route_list_4).unwrap(), serde_json::to_string(&route_list_6).unwrap(), ); - - None } fn on_update_resources(&self, resource_list: Vec) { diff --git a/rust/connlib/clients/shared/src/eventloop.rs b/rust/connlib/clients/shared/src/eventloop.rs index 758cdc7b35..13c89361bc 100644 --- a/rust/connlib/clients/shared/src/eventloop.rs +++ b/rust/connlib/clients/shared/src/eventloop.rs @@ -10,7 +10,7 @@ use connlib_shared::{ messages::{ConnectionAccepted, GatewayResponse, ResourceAccepted, ResourceId}, Callbacks, }; -use firezone_tunnel::ClientTunnel; +use firezone_tunnel::{ClientTunnel, Tun}; use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel}; use std::{ collections::HashMap, @@ -32,6 +32,7 @@ pub enum Command { Stop, Reconnect, SetDns(Vec), + SetTun(Tun), } impl Eventloop { @@ -62,6 +63,7 @@ where tracing::warn!("Failed to update DNS: {e}"); } } + Poll::Ready(Some(Command::SetTun(tun))) => self.tunnel.set_tun(tun), Poll::Ready(Some(Command::Reconnect)) => { self.portal.reconnect(); if let Err(e) = self.tunnel.reconnect() { diff --git a/rust/connlib/clients/shared/src/lib.rs b/rust/connlib/clients/shared/src/lib.rs index 9542d6cfd4..ac49995d9d 100644 --- a/rust/connlib/clients/shared/src/lib.rs +++ b/rust/connlib/clients/shared/src/lib.rs @@ -3,16 +3,21 @@ pub use connlib_shared::messages::ResourceDescription; pub use connlib_shared::{ keypair, Callbacks, Cidrv4, Cidrv6, Error, LoginUrl, LoginUrlError, StaticSecret, }; +pub use eventloop::Eventloop; pub use firezone_tunnel::Sockets; +pub use firezone_tunnel::Tun; pub use tracing_appender::non_blocking::WorkerGuard; use backoff::ExponentialBackoffBuilder; use connlib_shared::get_user_agent; +use eventloop::Command; use firezone_tunnel::ClientTunnel; use phoenix_channel::PhoenixChannel; +use secrecy::Secret; use std::net::IpAddr; use std::time::Duration; use tokio::sync::mpsc::UnboundedReceiver; +use tokio::task::JoinHandle; mod eventloop; pub mod file_logger; @@ -20,14 +25,10 @@ mod messages; const PHOENIX_TOPIC: &str = "client"; -use eventloop::Command; -pub use eventloop::Eventloop; -use secrecy::Secret; -use tokio::task::JoinHandle; - /// A session is the entry-point for connlib, maintains the runtime and the tunnel. /// /// A session is created using [Session::connect], then to stop a session we use [Session::disconnect]. +#[derive(Clone)] pub struct Session { channel: tokio::sync::mpsc::UnboundedSender, } @@ -95,6 +96,11 @@ impl Session { let _ = self.channel.send(Command::SetDns(new_dns)); } + /// Sets a new TUN device for this [`Session`]. + pub fn set_tun(&self, new_tun: Tun) { + let _ = self.channel.send(Command::SetTun(new_tun)); + } + /// Disconnect a [`Session`]. /// /// This consumes [`Session`] which cleans up all state associated with it. diff --git a/rust/connlib/shared/src/callbacks.rs b/rust/connlib/shared/src/callbacks.rs index 387c8d1e48..a92362cbe5 100644 --- a/rust/connlib/shared/src/callbacks.rs +++ b/rust/connlib/shared/src/callbacks.rs @@ -4,9 +4,6 @@ use serde::Serialize; use std::fmt::Debug; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -// Avoids having to map types for Windows -type RawFd = i32; - #[derive(Serialize, Clone, Copy, Debug)] /// Identical to `ip_network::Ipv4Network` except we implement `Serialize` on the Rust side and the equivalent of `Deserialize` on the Swift / Kotlin side to avoid manually serializing and deserializing. pub struct Cidrv4 { @@ -42,17 +39,10 @@ impl From for Cidrv6 { /// Traits that will be used by connlib to callback the client upper layers. pub trait Callbacks: Clone + Send + Sync { /// Called when the tunnel address is set. - /// - /// This should return a new `fd` if there is one. - /// (Only happens on android for now) - fn on_set_interface_config(&self, _: Ipv4Addr, _: Ipv6Addr, _: Vec) -> Option { - None - } + fn on_set_interface_config(&self, _: Ipv4Addr, _: Ipv6Addr, _: Vec) {} /// Called when the route list changes. - fn on_update_routes(&self, _: Vec, _: Vec) -> Option { - None - } + fn on_update_routes(&self, _: Vec, _: Vec) {} /// Called when the resource list changes. fn on_update_resources(&self, _: Vec) {} diff --git a/rust/connlib/tunnel/src/client.rs b/rust/connlib/tunnel/src/client.rs index ed234c2c06..f43cb9559d 100644 --- a/rust/connlib/tunnel/src/client.rs +++ b/rust/connlib/tunnel/src/client.rs @@ -16,7 +16,7 @@ use ip_network_table::IpNetworkTable; use itertools::Itertools; use crate::utils::{earliest, stun, turn}; -use crate::{ClientEvent, ClientTunnel}; +use crate::{ClientEvent, ClientTunnel, Tun}; use secrecy::{ExposeSecret as _, Secret}; use snownet::ClientNode; use std::collections::hash_map::Entry; @@ -124,6 +124,10 @@ where Ok(()) } + pub fn set_tun(&mut self, tun: Tun) { + self.io.device_mut().set_tun(tun); + } + #[tracing::instrument(level = "trace", skip(self))] pub fn set_new_interface_config( &mut self, diff --git a/rust/connlib/tunnel/src/device_channel.rs b/rust/connlib/tunnel/src/device_channel.rs index e98f9177c8..3395d7c03d 100644 --- a/rust/connlib/tunnel/src/device_channel.rs +++ b/rust/connlib/tunnel/src/device_channel.rs @@ -35,7 +35,8 @@ use std::io; use std::net::IpAddr; use std::task::{Context, Poll, Waker}; use std::time::{Duration, Instant}; -use tun::Tun; + +pub use tun::Tun; pub struct Device { mtu: usize, @@ -70,7 +71,27 @@ impl Device { } } - #[cfg(any(target_os = "android", target_os = "linux"))] + pub(crate) fn set_tun(&mut self, tun: Tun) { + self.tun = Some(tun); + + if let Some(waker) = self.waker.take() { + waker.wake(); + } + } + + #[cfg(target_os = "android")] + pub(crate) fn set_config( + &mut self, + config: &Interface, + dns_config: Vec, + callbacks: &impl Callbacks, + ) -> Result<(), ConnlibError> { + callbacks.on_set_interface_config(config.ipv4, config.ipv6, dns_config); + + Ok(()) + } + + #[cfg(target_os = "linux")] pub(crate) fn set_config( &mut self, config: &Interface, diff --git a/rust/connlib/tunnel/src/device_channel/tun_android.rs b/rust/connlib/tunnel/src/device_channel/tun_android.rs index 2331dd216f..84eea32fd3 100644 --- a/rust/connlib/tunnel/src/device_channel/tun_android.rs +++ b/rust/connlib/tunnel/src/device_channel/tun_android.rs @@ -1,8 +1,7 @@ use super::utils; use crate::device_channel::{ioctl, ipv4, ipv6}; -use connlib_shared::{messages::Interface as InterfaceConfig, Callbacks, Error, Result}; +use connlib_shared::{Callbacks, Result}; use ip_network::IpNetwork; -use std::net::IpAddr; use std::task::{Context, Poll}; use std::{ collections::HashSet, @@ -14,7 +13,7 @@ use tokio::io::unix::AsyncFd; pub(crate) const SIOCGIFMTU: libc::c_ulong = libc::SIOCGIFMTU; #[derive(Debug)] -pub(crate) struct Tun { +pub struct Tun { fd: AsyncFd, name: String, } @@ -38,14 +37,7 @@ impl Tun { utils::poll_raw_fd(&self.fd, |fd| read(fd, buf), cx) } - pub fn new( - config: &InterfaceConfig, - dns_config: Vec, - callbacks: &impl Callbacks, - ) -> Result { - let fd = callbacks - .on_set_interface_config(config.ipv4, config.ipv6, dns_config) - .ok_or(Error::NoFd)?; + pub fn with_fd(fd: RawFd) -> Result { // Safety: File descriptor is open. let name = unsafe { interface_name(fd)? }; @@ -64,26 +56,10 @@ impl Tun { routes: HashSet, callbacks: &impl Callbacks, ) -> Result<()> { - let fd = callbacks - .on_update_routes( - routes.iter().copied().filter_map(ipv4).collect(), - routes.iter().copied().filter_map(ipv6).collect(), - ) - .ok_or(Error::NoFd)?; - - // SAFETY: we expect the callback to return a valid file descriptor - unsafe { self.replace_fd(fd)? }; - - Ok(()) - } - - // SAFETY: must be called with a valid file descriptor - unsafe fn replace_fd(&mut self, fd: RawFd) -> Result<()> { - if self.fd.as_raw_fd() != fd { - unsafe { libc::close(self.fd.as_raw_fd()) }; - self.fd = AsyncFd::new(fd)?; - self.name = interface_name(fd)?; - } + callbacks.on_update_routes( + routes.iter().copied().filter_map(ipv4).collect(), + routes.iter().copied().filter_map(ipv6).collect(), + ); Ok(()) } diff --git a/rust/connlib/tunnel/src/device_channel/tun_darwin.rs b/rust/connlib/tunnel/src/device_channel/tun_darwin.rs index f0bd47e421..99ff153bae 100644 --- a/rust/connlib/tunnel/src/device_channel/tun_darwin.rs +++ b/rust/connlib/tunnel/src/device_channel/tun_darwin.rs @@ -21,7 +21,7 @@ const CTL_NAME: &[u8] = b"com.apple.net.utun_control"; pub(crate) const SIOCGIFMTU: u64 = 0x0000_0000_c020_6933; #[derive(Debug)] -pub(crate) struct Tun { +pub struct Tun { name: String, fd: AsyncFd, } diff --git a/rust/connlib/tunnel/src/device_channel/tun_linux.rs b/rust/connlib/tunnel/src/device_channel/tun_linux.rs index 5b42da76b0..a411dc66c0 100644 --- a/rust/connlib/tunnel/src/device_channel/tun_linux.rs +++ b/rust/connlib/tunnel/src/device_channel/tun_linux.rs @@ -102,6 +102,14 @@ impl Tun { utils::poll_raw_fd(&self.fd, |fd| read(fd, buf), cx) } + /// Stub implementation of the `with_fd` constructor that exists only for Android. + /// + /// This will eventually disappear once updating of routes and interface config no longer happens within `firezone-tunnel`. + /// At that point, `firezone-tunnel` will interact with a `Stream + Sink` of `IpPacket` and the corresponding implementation of it will sit in the client-crates. + pub fn with_fd(_: RawFd) -> Result { + unimplemented!("This API should never be called on Linux.") + } + pub fn new( config: &InterfaceConfig, dns_config: Vec, diff --git a/rust/connlib/tunnel/src/lib.rs b/rust/connlib/tunnel/src/lib.rs index e85d27e10d..2e79c9bcab 100644 --- a/rust/connlib/tunnel/src/lib.rs +++ b/rust/connlib/tunnel/src/lib.rs @@ -16,6 +16,7 @@ use std::{ }; pub use client::{ClientState, Request}; +pub use device_channel::Tun; pub use gateway::GatewayState; pub use sockets::Sockets; diff --git a/rust/gui-client/src-tauri/src/client/tunnel-wrapper/in_proc.rs b/rust/gui-client/src-tauri/src/client/tunnel-wrapper/in_proc.rs index 7ee2b1ee55..3277deafd1 100644 --- a/rust/gui-client/src-tauri/src/client/tunnel-wrapper/in_proc.rs +++ b/rust/gui-client/src-tauri/src/client/tunnel-wrapper/in_proc.rs @@ -98,12 +98,11 @@ impl connlib_client_shared::Callbacks for CallbackHandler { .expect("controller channel failed"); } - fn on_set_interface_config(&self, _: Ipv4Addr, _: Ipv6Addr, _: Vec) -> Option { + fn on_set_interface_config(&self, _: Ipv4Addr, _: Ipv6Addr, _: Vec) { tracing::info!("on_set_interface_config"); self.ctlr_tx .try_send(ControllerRequest::TunnelReady) .expect("controller channel failed"); - None } fn on_update_resources(&self, resources: Vec) {