Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 49 additions & 38 deletions crates/core/src/node/network_bridge/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//! connection attempts to/from the event loop. Higher-level routing decisions now live inside
//! `ConnectOp`.

use std::net::{IpAddr, SocketAddr};
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -127,7 +127,7 @@ struct ExpectedInbound {

#[derive(Default)]
struct ExpectedInboundTracker {
entries: HashMap<IpAddr, Vec<ExpectedInbound>>,
entries: HashMap<SocketAddr, ExpectedInbound>,
}

impl ExpectedInboundTracker {
Expand All @@ -138,53 +138,38 @@ impl ExpectedInboundTracker {
tx = ?transaction,
"ExpectInbound: registering expectation"
);
let list = self.entries.entry(peer.addr.ip()).or_default();
// Replace any existing expectation for the same peer/port to ensure the newest registration wins.
list.retain(|entry| entry.peer.addr.port() != peer.addr.port());
list.push(ExpectedInbound {
peer,
transaction,
transient,
});
// Replace any existing expectation for the same socket to ensure the newest registration wins.
self.entries.insert(
peer.addr,
ExpectedInbound {
peer,
transaction,
transient,
},
);
}

fn drop_peer(&mut self, peer: &PeerId) {
if let Some(list) = self.entries.get_mut(&peer.addr.ip()) {
list.retain(|entry| entry.peer != *peer);
if list.is_empty() {
self.entries.remove(&peer.addr.ip());
}
}
self.entries.remove(&peer.addr);
}

fn consume(&mut self, addr: SocketAddr) -> Option<ExpectedInbound> {
let ip = addr.ip();
let list = self.entries.get_mut(&ip)?;
if let Some(pos) = list
.iter()
.position(|entry| entry.peer.addr.port() == addr.port())
{
let entry = list.remove(pos);
if list.is_empty() {
self.entries.remove(&ip);
}
tracing::debug!(remote = %addr, peer = %entry.peer.addr, transient = entry.transient, tx = ?entry.transaction, "ExpectInbound: matched by exact port");
return Some(entry);
}
let entry = list.pop();
if list.is_empty() {
self.entries.remove(&ip);
let entry = self.entries.remove(&addr);
if let Some(entry) = &entry {
tracing::debug!(
remote = %addr,
peer = %entry.peer.addr,
transient = entry.transient,
tx = ?entry.transaction,
"ExpectInbound: matched by socket address"
);
}
if let Some(entry) = entry {
tracing::debug!(remote = %addr, peer = %entry.peer.addr, transient = entry.transient, tx = ?entry.transaction, "ExpectInbound: matched by IP fallback");
return Some(entry);
}
None
entry
}

#[cfg(test)]
fn contains(&self, addr: SocketAddr) -> bool {
self.entries.contains_key(&addr.ip())
self.entries.contains_key(&addr)
}
}

Expand Down Expand Up @@ -225,6 +210,10 @@ async fn run_driver(
let (peer, transaction, transient) = if let Some(entry) = entry {
(Some(entry.peer), entry.transaction, entry.transient)
} else {
tracing::warn!(
remote = %remote_addr,
"Received unexpected inbound connection (no matching expectation)"
);
(None, None, false)
};

Expand Down Expand Up @@ -342,4 +331,26 @@ mod tests {
assert_eq!(entry.transaction, Some(new_tx));
assert!(entry.transient);
}

#[test]
fn tracker_keeps_peers_separate_on_same_ip() {
let mut tracker = ExpectedInboundTracker::default();
let peer_a = make_peer(4400);
let peer_b = make_peer(4401);

tracker.register(peer_a.clone(), None, false);
tracker.register(peer_b.clone(), None, true);

let first = tracker
.consume(peer_a.addr)
.expect("first peer should be matched by exact socket");
assert_eq!(first.peer, peer_a);
assert!(!first.transient);

let second = tracker
.consume(peer_b.addr)
.expect("second peer should still be tracked independently");
assert_eq!(second.peer, peer_b);
assert!(second.transient);
}
}
10 changes: 10 additions & 0 deletions crates/core/tests/connectivity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ use freenet_stdlib::{
use std::time::Duration;
use tokio_tungstenite::connect_async;

// Fixed ring locations for the three-node connectivity test.
fn fixed_three_node_locations() -> Vec<f64> {
vec![0.1, 0.5, 0.9]
}

/// Test gateway reconnection:
/// 1. Start a gateway and a peer connected to it
/// 2. Perform operations to verify connectivity
Expand Down Expand Up @@ -254,6 +259,7 @@ async fn test_basic_gateway_connectivity(ctx: &mut TestContext) -> TestResult {
auto_connect_peers = true,
timeout_secs = 180,
startup_wait_secs = 30,
node_locations_fn = fixed_three_node_locations,
aggregate_events = "always",
tokio_flavor = "multi_thread",
tokio_worker_threads = 4
Expand All @@ -272,6 +278,10 @@ async fn test_three_node_network_connectivity(ctx: &mut TestContext) -> TestResu
let gateway = ctx.node("gateway")?;
let peer1 = ctx.node("peer1")?;
let peer2 = ctx.node("peer2")?;
println!(
"Using deterministic node locations: gateway={:.3}, peer1={:.3}, peer2={:.3}",
gateway.location, peer1.location, peer2.location
);

let peer1_public_port = peer1.network_port.context(
"peer1 missing network port; auto_connect_peers requires public_port for mesh connectivity",
Expand Down
Loading