Skip to content

Commit

Permalink
feat(kotlin): detect network and dns changes
Browse files Browse the repository at this point in the history
  • Loading branch information
conectado committed Mar 18, 2024
1 parent a9dfe00 commit 021c116
Show file tree
Hide file tree
Showing 13 changed files with 146 additions and 23 deletions.
1 change: 1 addition & 0 deletions kotlin/android/app/src/main/AndroidManifest.xml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
<uses-permission android:name="android.permission.CHANGE_NETWORK_STATE" />
<uses-permission android:name="android.permission.ACCESS_WIFI_STATE" />
<uses-permission android:name="android.permission.FOREGROUND_SERVICE" />
<uses-permission android:name="android.permission.RECEIVE_BOOT_COMPLETED" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@ object ConnlibSession {
): Long

external fun disconnect(connlibSession: Long): Boolean

external fun networkUpdate(connlibSession: Long, dnsList: String): Boolean
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import android.net.ConnectivityManager
import android.net.LinkProperties
import android.net.Network
import android.net.NetworkCapabilities
import android.util.Log
import com.google.gson.Gson
import com.squareup.moshi.Moshi
import com.squareup.moshi.adapter
import dev.firezone.android.tunnel.ConnlibSession
import dev.firezone.android.tunnel.TunnelService
import java.net.InetAddress
import javax.inject.Inject

private const val TAG: String = "NetworkMonitor"

class NetworkMonitor(val connlibSessionPtr: Long) : ConnectivityManager.NetworkCallback() {
@Inject
internal lateinit var moshi: Moshi

override fun onLinkPropertiesChanged(network: Network, linkProperties: LinkProperties) {
ConnlibSession.networkUpdate(connlibSessionPtr, Gson().toJson(linkProperties.dnsServers))
super.onLinkPropertiesChanged(network, linkProperties)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import android.app.NotificationManager
import android.app.PendingIntent
import android.content.Context
import android.content.Intent
import android.net.ConnectivityManager
import android.net.NetworkCapabilities
import android.net.NetworkRequest
import android.net.VpnService
import android.os.Binder
import android.os.Build
Expand All @@ -28,11 +31,13 @@ import dev.firezone.android.tunnel.callback.ConnlibCallback
import dev.firezone.android.tunnel.model.Cidr
import dev.firezone.android.tunnel.model.Resource
import dev.firezone.android.tunnel.util.DnsServersDetector
import NetworkMonitor
import java.nio.file.Files
import java.nio.file.Paths
import java.util.UUID
import javax.inject.Inject


@AndroidEntryPoint
@OptIn(ExperimentalStdlibApi::class)
class TunnelService : VpnService() {
Expand Down Expand Up @@ -187,7 +192,7 @@ class TunnelService : VpnService() {
fun disconnect() {
Log.d(TAG, "disconnect")

// Connlib should call onDisconnect() when it's done, with no error.
// Connlib should call onDisconnect() when it's don"^Y&|l;e, with no error.
connlibSessionPtr!!.let {
ConnlibSession.disconnect(it)
}
Expand Down Expand Up @@ -223,6 +228,10 @@ class TunnelService : VpnService() {
logFilter = config.logFilter,
callback = callback,
)

val networkRequest = NetworkRequest.Builder().addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN).build()
val connectivityManager = getSystemService(ConnectivityManager::class.java) as ConnectivityManager
connectivityManager.requestNetwork(networkRequest, NetworkMonitor(connlibSessionPtr!!))
}
}

Expand Down
3 changes: 3 additions & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ futures-bounded = "0.2.1"
domain = { version = "0.9", features = ["serde"] }
dns-lookup = "2.0"
tokio-tungstenite = "0.21"
tokio-stream = { version = "0.1", features = ["sync"] }
rtnetlink = { version = "0.14.1", default-features = false, features = ["tokio_socket"] }

connlib-client-android = { path = "connlib/clients/android"}
Expand Down
28 changes: 27 additions & 1 deletion rust/connlib/clients/android/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,9 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_co
None => return std::ptr::null(),
};

Box::into_raw(Box::new(session))
let session_ptr = Box::into_raw(Box::new(session));
tracing::error!("Session pointer is: {:?}", session_ptr);
session_ptr
}

pub struct SessionWrapper {
Expand All @@ -503,3 +505,27 @@ pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_di
Box::from_raw(session).inner.disconnect();
});
}

/// # Safety
/// Pointers must be valid
#[allow(non_snake_case)]
#[no_mangle]
pub unsafe extern "system" fn Java_dev_firezone_android_tunnel_ConnlibSession_networkUpdate(
mut env: JNIEnv,
_: JClass,
session: *const SessionWrapper,
dns_list: JString,
) {
let dns = String::from(
env.get_string(&dns_list)
.map_err(|source| ConnectError::StringInvalid {
name: "dns_list",
source,
})
.unwrap(),
);
let dns: Vec<IpAddr> = serde_json::from_str(&dns).unwrap();
let session = &*session;
session.inner.set_dns(dns);
session.inner.reconnect();
}
2 changes: 2 additions & 0 deletions rust/connlib/clients/shared/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ parking_lot = "0.12"
bimap = "0.6"
ip_network = { version = "0.4", default-features = false }
phoenix-channel = { workspace = true }
tokio-stream = { workspace = true }
futures-util = { version = "0.3", default-features = false, features = ["std", "async-await", "async-await-macro"] }


[target.'cfg(target_os = "android")'.dependencies]
Expand Down
14 changes: 13 additions & 1 deletion rust/connlib/clients/shared/src/eventloop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@ use connlib_shared::{
Callbacks,
};
use firezone_tunnel::ClientTunnel;
use futures_util::StreamExt;
use phoenix_channel::{ErrorReply, OutboundRequestId, PhoenixChannel};
use std::{
collections::HashMap,
io,
net::IpAddr,
path::PathBuf,
task::{Context, Poll},
time::Duration,
};
use tokio::time::{Instant, Interval, MissedTickBehavior};
use tokio_stream::wrappers::WatchStream;
use url::Url;

pub struct Eventloop<C: Callbacks> {
Expand All @@ -28,6 +31,7 @@ pub struct Eventloop<C: Callbacks> {

portal: PhoenixChannel<(), IngressMessages, ReplyMessages>,
rx: tokio::sync::mpsc::Receiver<Command>,
dns_updated: WatchStream<Vec<IpAddr>>,

connection_intents: SentConnectionIntents,
log_upload_interval: tokio::time::Interval,
Expand All @@ -44,6 +48,7 @@ impl<C: Callbacks> Eventloop<C> {
tunnel: ClientTunnel<C>,
portal: PhoenixChannel<(), IngressMessages, ReplyMessages>,
rx: tokio::sync::mpsc::Receiver<Command>,
dns_updated: WatchStream<Vec<IpAddr>>,
) -> Self {
Self {
tunnel,
Expand All @@ -52,6 +57,7 @@ impl<C: Callbacks> Eventloop<C> {
connection_intents: SentConnectionIntents::default(),
log_upload_interval: upload_interval(),
rx,
dns_updated,
}
}
}
Expand All @@ -66,6 +72,7 @@ where
match self.rx.poll_recv(cx) {
Poll::Ready(Some(Command::Stop)) | Poll::Ready(None) => return Poll::Ready(Ok(())),
Poll::Ready(Some(Command::Reconnect)) => {
tracing::debug!("Reconnecting");
self.portal.reconnect();
self.tunnel.reconnect();

Expand All @@ -74,6 +81,11 @@ where
Poll::Pending => {}
}

if let Poll::Ready(Some(dns)) = self.dns_updated.poll_next_unpin(cx) {
tracing::debug!("New dns {dns:?}");
self.tunnel.set_dns(dns);
}

match self.tunnel.poll_next_event(cx) {
Poll::Ready(Ok(event)) => {
self.handle_tunnel_event(event);
Expand Down Expand Up @@ -180,7 +192,7 @@ where
resources,
}) => {
if !self.tunnel_init {
if let Err(e) = self.tunnel.set_interface(&interface) {
if let Err(e) = self.tunnel.set_interface(interface) {
tracing::warn!("Failed to set interface on tunnel: {e}");
return;
}
Expand Down
31 changes: 28 additions & 3 deletions rust/connlib/clients/shared/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ use backoff::ExponentialBackoffBuilder;
use connlib_shared::{get_user_agent, CallbackErrorFacade};
use firezone_tunnel::ClientTunnel;
use phoenix_channel::PhoenixChannel;
use std::collections::HashSet;
use std::net::IpAddr;
use std::time::Duration;
use tokio_stream::wrappers::WatchStream;

mod eventloop;
pub mod file_logger;
Expand All @@ -27,6 +30,7 @@ use tokio::task::JoinHandle;
/// A session is created using [Session::connect], then to stop a session we use [Session::disconnect].
pub struct Session {
channel: tokio::sync::mpsc::Sender<Command>,
dns_updater: tokio::sync::watch::Sender<Vec<IpAddr>>,
}

impl Session {
Expand All @@ -42,7 +46,10 @@ impl Session {
handle: tokio::runtime::Handle,
) -> connlib_shared::Result<Self> {
let callbacks = CallbackErrorFacade(callbacks);
// TODO: this is just for testing sake change me back to 1 and do things with dns propperly
let (tx, rx) = tokio::sync::mpsc::channel(1);
let (dns_updater, dns_updated) = tokio::sync::watch::channel(Vec::new());
let dns_updated = WatchStream::from_changes(dns_updated);

let connect_handle = handle.spawn(connect(
url,
Expand All @@ -51,10 +58,14 @@ impl Session {
callbacks.clone(),
max_partition_time,
rx,
dns_updated,
));
handle.spawn(connect_supervisor(connect_handle, callbacks));

Ok(Self { channel: tx })
Ok(Self {
channel: tx,
dns_updater,
})
}

/// Attempts to reconnect a [`Session`].
Expand All @@ -69,10 +80,23 @@ impl Session {
///
/// In case of destructive network state changes, i.e. the user switched from wifi to cellular,
/// reconnect allows connlib to re-establish connections faster because we don't have to wait for timeouts first.
pub fn reconnect(&mut self) {
pub fn reconnect(&self) {
let _ = self.channel.try_send(Command::Reconnect);
}

pub fn set_dns(&self, new_dns: Vec<IpAddr>) {
self.dns_updater.send_if_modified(|old_dns| {
if HashSet::<IpAddr>::from_iter(old_dns.clone().into_iter())
!= HashSet::from_iter(new_dns.clone().into_iter())
{
*old_dns = new_dns;
true
} else {
false
}
});
}

/// Disconnect a [`Session`].
///
/// This consumes [`Session`] which cleans up all state associated with it.
Expand All @@ -91,6 +115,7 @@ async fn connect<CB>(
callbacks: CB,
max_partition_time: Option<Duration>,
rx: tokio::sync::mpsc::Receiver<Command>,
dns_updated: WatchStream<Vec<IpAddr>>,
) -> Result<(), Error>
where
CB: Callbacks + 'static,
Expand All @@ -107,7 +132,7 @@ where
.build(),
);

let mut eventloop = Eventloop::new(tunnel, portal, rx);
let mut eventloop = Eventloop::new(tunnel, portal, rx, dns_updated);

std::future::poll_fn(|cx| eventloop.poll(cx))
.await
Expand Down
Loading

0 comments on commit 021c116

Please sign in to comment.