Skip to content

Commit

Permalink
fix: improve binding and rebinding of sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
dignifiedquire committed Mar 20, 2023
1 parent d02f65b commit 156560a
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 240 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ os_info = "3.6.0"
portable-atomic = "1"
postcard = { version = "1", default-features = false, features = ["alloc", "use-std", "experimental-derive"] }
quic-rpc = { version = "0.5", default-features = false, features = ["quinn-transport", "flume-transport"] }
quinn = "0.9.3"
quinn-proto = "0.9.2"
quinn-udp = "0.3.2"
quinn = "0.9"
quinn-proto = "0.9"
quinn-udp = "0.3"
rand = "0.8"
rcgen = "0.10"
reqwest = "0.11.14"
Expand Down
250 changes: 87 additions & 163 deletions src/hp/magicsock/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,15 @@ use crate::{
cfg::{self, DERP_MAGIC_IP},
derp::{self, DerpMap},
disco, key,
magicsock::{rebinding_conn, SESSION_ACTIVE_TIMEOUT},
magicsock::SESSION_ACTIVE_TIMEOUT,
monitor, netcheck, netmap, portmapper, stun,
},
net::LocalAddresses,
};

use super::{
endpoint::PeerMap,
rebinding_conn::{RebindingUdpConn, UdpSocket},
Endpoint, Timer, DERP_CLEAN_STALE_INTERVAL, DERP_INACTIVE_CLEANUP_TIME,
ENDPOINTS_FRESH_ENOUGH_DURATION, SOCKET_BUFFER_SIZE,
endpoint::PeerMap, rebinding_conn::RebindingUdpConn, Endpoint, Timer,
DERP_CLEAN_STALE_INTERVAL, DERP_INACTIVE_CLEANUP_TIME, ENDPOINTS_FRESH_ENOUGH_DURATION,
};

/// How many packets writes can be queued up the DERP client to write on the wire before we start
Expand All @@ -50,7 +48,7 @@ use super::{
const BUFFERED_DERP_WRITES_BEFORE_DROP: usize = 32;

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum CurrentPortFate {
pub(super) enum CurrentPortFate {
Keep,
Drop,
}
Expand All @@ -62,7 +60,7 @@ pub(super) enum Network {
}

impl Network {
fn default_addr(&self) -> IpAddr {
pub(super) fn default_addr(&self) -> IpAddr {
match self {
Self::Ip4 => Ipv4Addr::UNSPECIFIED.into(),
Self::Ip6 => Ipv6Addr::UNSPECIFIED.into(),
Expand Down Expand Up @@ -161,17 +159,13 @@ pub struct Inner {
/// A callback that provides a `cfg::NetInfo` when discovered network conditions change.
on_net_info: Option<Box<dyn Fn(cfg::NetInfo) + Send + Sync + 'static>>,

// ================================================================
// No locking required to access these fields, either because
// they're static after construction, or are wholly owned by a single goroutine.

// TODO
// connCtx: context.Context, // closed on Conn.Close
// connCtxCancel: func(), // closes connCtx

// The underlying UDP sockets used to send/rcv packets for wireguard and other magicsock protocols.
pconn4: RebindingUdpConn,
pconn6: RebindingUdpConn,
pconn6: Option<RebindingUdpConn>,

// TODO:
// closeDisco4 and closeDisco6 are io.Closers to shut down the raw
Expand Down Expand Up @@ -246,7 +240,7 @@ impl EndpointUpdateState {
}

pub(super) struct ConnState {
/// Close was called
/// Close was called.
closed: bool,

/// A timer that fires to occasionally clean up idle DERP connections.
Expand Down Expand Up @@ -395,6 +389,10 @@ impl Conn {

let derp_recv_ch = flume::bounded(64);

let (pconn4, pconn6) = Self::bind(port).await?;
let port = pconn4.port().await;
port_mapper.set_local_port(port).await;

let c = Conn(Arc::new(Inner {
name,
on_endpoints,
Expand All @@ -410,8 +408,8 @@ impl Conn {
public_key: private_key.verifying_key().into(),
last_net_check_report: Default::default(),
no_v4_send: AtomicBool::new(false),
pconn4: RebindingUdpConn::default(),
pconn6: RebindingUdpConn::default(),
pconn4,
pconn6,
socket_endpoint4: SocketEndpointCache::default(),
socket_endpoint6: SocketEndpointCache::default(),
on_stun_receive: Default::default(),
Expand All @@ -424,8 +422,6 @@ impl Conn {
peer_map: Default::default(),
}));

c.rebind(CurrentPortFate::Keep).await?;

Ok(c)
}

Expand Down Expand Up @@ -972,8 +968,7 @@ impl Conn {

/// Returns the current IPv4 listener's port number.
pub async fn local_port(&self) -> u16 {
let laddr = self.pconn4.local_addr().await;
laddr.map(|l| l.port()).unwrap_or_default()
self.pconn4.port().await
}

fn network_down(&self) -> bool {
Expand Down Expand Up @@ -1046,7 +1041,16 @@ impl Conn {
let transmits = [transmit];
match transmits[0].destination {
SocketAddr::V4(_) => self.pconn4.poll_send(udp_state, cx, &transmits),
SocketAddr::V6(_) => self.pconn6.poll_send(udp_state, cx, &transmits),
SocketAddr::V6(_) => {
if let Some(ref conn) = self.pconn6 {
conn.poll_send(udp_state, cx, &transmits)
} else {
Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
"no IPv6 connection",
)))
}
}
}
}

Expand Down Expand Up @@ -2461,7 +2465,9 @@ impl Conn {
self.close_all_derp_locked(&mut state, "conn-close");
// Ignore errors from c.pconnN.Close.
// They will frequently have been closed already by a call to connBind.Close.
self.pconn6.close().await.ok();
if let Some(ref conn) = self.pconn6 {
conn.close().await.ok();
}
self.pconn4.close().await.ok();

// Wait on tasks updating right at the end, once everything is
Expand Down Expand Up @@ -2539,132 +2545,43 @@ impl Conn {
}
}

/// Opens a packet listener.
async fn listen_packet(&self, network: Network, port: u16) -> Result<UdpSocket> {
let addr = SocketAddr::new(network.default_addr(), port);
let socket = socket2::Socket::new(
network.into(),
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;

if let Err(err) = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE) {
info!(
"failed to set recv_buffer_size to {}: {:?}",
SOCKET_BUFFER_SIZE, err
);
}
if let Err(err) = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE) {
info!(
"failed to set send_buffer_size to {}: {:?}",
SOCKET_BUFFER_SIZE, err
);
/// Closes and re-binds the UDP sockets.
/// We consider it successful if we manage to bind the IPv4 socket.
async fn rebind(&self, cur_port_fate: CurrentPortFate) -> Result<()> {
let port = self.local_port().await;
if let Some(ref conn) = self.pconn6 {
// If we were not able to bind ipv6 at program start, dont retry
if let Err(err) = conn.rebind(port, Network::Ip6, cur_port_fate).await {
info!("rebind ignoring IPv6 bind failure: {:?}", err);
}
}
socket.set_nonblocking(true)?;
socket.bind(&addr.into())?;
let socket = UdpSocket::from_std(socket.into())?;
self.pconn4
.rebind(port, Network::Ip4, cur_port_fate)
.await
.context("rebind IPv4 failed")?;

debug!("bound to {}", socket.local_addr()?);
// reread, as it might have changed
let port = self.local_port().await;
self.port_mapper.set_local_port(port).await;

Ok(socket)
Ok(())
}

// bindSocket initializes rucPtr if necessary and binds a UDP socket to it.
// Network indicates the UDP socket type; it must be "udp4" or "udp6".
// If rucPtr had an existing UDP socket bound, it closes that socket.
// The caller is responsible for informing the portMapper of any changes.
// If curPortFate is set to dropCurrentPort, no attempt is made to reuse
// the current port.
async fn bind_socket(
&self,
ruc: &RebindingUdpConn,
network: Network,
cur_port_fate: CurrentPortFate,
) -> Result<()> {
debug!(
"bind_socket: network={:?} cur_port_fate={:?}",
network, cur_port_fate
);

// Hold the ruc lock the entire time, so that the close+bind is atomic from the perspective of ruc receive functions.
let mut ruc = ruc.inner.write().await;

// Build a list of preferred ports.
// - Best is the port that the user requested.
// - Second best is the port that is currently in use.
// - If those fail, fall back to 0.

let mut ports = Vec::new();
let port = self.port.load(Ordering::Relaxed);
if port != 0 {
ports.push(port);
}
if cur_port_fate == CurrentPortFate::Keep {
if let Ok(cur_addr) = ruc.local_addr() {
ports.push(cur_addr.port());
}
}
ports.push(0);
// Remove duplicates. (All duplicates are consecutive.)
ports.dedup();
debug!("bind_socket: candidate ports: {:?}", ports);

for port in &ports {
// Close the existing conn, in case it is sitting on the port we want.
if let Err(err) = ruc.close() {
if !matches!(rebinding_conn::Error::NoConn, err) {
info!("bind_socket {:?} close failed: {:?}", network, err);
}
}
// Open a new one with the desired port.
match self.listen_packet(network, *port).await {
Ok(pconn) => {
debug!(
"bind_socket: successfully listened {:?} port {}",
network, port
);
ruc.set_conn(pconn, network);
return Ok(());
}
Err(err) => {
info!(
"bind_socket: unable to bind {:?} port {}: {:?}",
network, port, err
);
continue;
}
/// Initial connection setup.
async fn bind(port: u16) -> Result<(RebindingUdpConn, Option<RebindingUdpConn>)> {
let pconn6 = match RebindingUdpConn::bind(port, Network::Ip6).await {
Ok(conn) => Some(conn),
Err(err) => {
info!("rebind ignoring IPv6 bind failure: {:?}", err);
None
}
}

// Failed to bind, including on port 0 (!).
// Set pconn to a dummy conn whose reads block until closed.
// This keeps the receive funcs alive for a future in which
// we get a link change and we can try binding again.

// TODO: is this needed?
// ruc.set_conn(newBlockForeverConn(), "");

bail!("failed to bind any ports (tried {:?})", ports);
}
};

/// Closes and re-binds the UDP sockets.
/// We consider it successful if we manage to bind the IPv4 socket.
async fn rebind(&self, cur_port_fate: CurrentPortFate) -> Result<()> {
if let Err(err) = self
.bind_socket(&self.pconn6, Network::Ip6, cur_port_fate)
.await
{
info!("rebind ignoring IPv6 bind failure: {:?}", err);
}
self.bind_socket(&self.pconn4, Network::Ip4, cur_port_fate)
let pconn4 = RebindingUdpConn::bind(port, Network::Ip4)
.await
.context("rebind IPv4 failed")?;

self.port_mapper
.set_local_port(self.local_port().await)
.await;

Ok(())
Ok((pconn4, pconn6))
}

/// Closes and re-binds the UDP sockets and resets the DERP connection.
Expand Down Expand Up @@ -2719,7 +2636,14 @@ impl Conn {
|a, b| a.destination.is_ipv6() == b.destination.is_ipv6(),
|group| {
let res = if group[0].destination.is_ipv6() {
self.pconn6.poll_send(state, cx, group)
if let Some(ref conn) = self.pconn6 {
conn.poll_send(state, cx, &transmits)
} else {
Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
"no IPv6 connection",
)))
}
} else {
self.pconn4.poll_send(state, cx, group)
};
Expand Down Expand Up @@ -2932,34 +2856,33 @@ impl AsyncUdpSocket for Conn {
}
// IPv6
if num_msgs_total < bufs.len() {
match self.pconn6.poll_recv(
cx,
&mut bufs[num_msgs_total..],
&mut meta[num_msgs_total..],
) {
Poll::Pending => {}
Poll::Ready(Err(err)) => {
return Poll::Ready(Err(err));
}
Poll::Ready(Ok(mut num_msgs)) => {
debug!("received {} msgs on IPv6", num_msgs);
debug_assert!(num_msgs + num_msgs_total < bufs.len());
let mut i = num_msgs_total;
while i < num_msgs + num_msgs_total {
if !self.receive_ip(&mut bufs[i], &mut meta[i], &self.socket_endpoint6) {
// move all following over
for k in i..num_msgs + num_msgs_total - 1 {
bufs.swap(k, k + 1);
meta.swap(k, k + 1);
if let Some(ref conn) = self.pconn6 {
match conn.poll_recv(cx, &mut bufs[num_msgs_total..], &mut meta[num_msgs_total..]) {
Poll::Pending => {}
Poll::Ready(Err(err)) => {
return Poll::Ready(Err(err));
}
Poll::Ready(Ok(mut num_msgs)) => {
debug!("received {} msgs on IPv6", num_msgs);
debug_assert!(num_msgs + num_msgs_total <= bufs.len());
let mut i = num_msgs_total;
while i < num_msgs + num_msgs_total {
if !self.receive_ip(&mut bufs[i], &mut meta[i], &self.socket_endpoint6)
{
// move all following over
for k in i..num_msgs + num_msgs_total - 1 {
bufs.swap(k, k + 1);
meta.swap(k, k + 1);
}

// reduce num_msgs
num_msgs -= 1;
}

// reduce num_msgs
num_msgs -= 1;
i += 1;
}

i += 1;
num_msgs_total += num_msgs;
}
num_msgs_total += num_msgs;
}
}
}
Expand Down Expand Up @@ -3495,7 +3418,8 @@ mod tests {
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr))
.with(EnvFilter::from_default_env())
.init();
.try_init()
.ok();

let devices = Devices {
m1_ip: "127.0.0.1".parse()?,
Expand Down
Loading

0 comments on commit 156560a

Please sign in to comment.