Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 44 additions & 39 deletions crates/core/src/node/network_bridge/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,21 @@ pub(crate) enum Event {
transaction: Option<Transaction>,
peer: Option<PeerId>,
connection: PeerConnection,
courtesy: bool,
transient: bool,
},
/// An outbound connection attempt succeeded.
OutboundEstablished {
transaction: Transaction,
peer: PeerId,
connection: PeerConnection,
courtesy: bool,
transient: bool,
},
/// An outbound connection attempt failed.
OutboundFailed {
transaction: Transaction,
peer: PeerId,
error: ConnectionError,
courtesy: bool,
transient: bool,
},
}

Expand All @@ -56,13 +56,13 @@ pub(crate) enum Command {
Connect {
peer: PeerId,
transaction: Transaction,
courtesy: bool,
transient: bool,
},
/// Register expectation for an inbound connection from `peer`.
ExpectInbound {
peer: PeerId,
transaction: Option<Transaction>,
courtesy: bool,
transient: bool,
},
/// Remove state associated with `peer`.
DropConnection { peer: PeerId },
Expand Down Expand Up @@ -122,64 +122,69 @@ impl Stream for HandshakeHandler {
struct ExpectedInbound {
peer: PeerId,
transaction: Option<Transaction>,
courtesy: bool,
transient: bool, // TODO: rename to transient in protocol once we migrate terminology
}

#[derive(Default)]
struct ExpectedInboundTracker {
// Keyed by remote IP to tolerate port changes; multiple expectations per IP
// are tracked and deduped by port.
entries: HashMap<IpAddr, Vec<ExpectedInbound>>,
}

impl ExpectedInboundTracker {
fn register(&mut self, peer: PeerId, transaction: Option<Transaction>, courtesy: bool) {
fn register(&mut self, peer: PeerId, transaction: Option<Transaction>, transient: bool) {
tracing::debug!(
remote = %peer.addr,
courtesy,
transient,
tx = ?transaction,
"ExpectInbound: registering expectation"
);
let list = self.entries.entry(peer.addr.ip()).or_default();
// Replace any existing expectation for the same peer/port so the newest wins.
// Replace any existing expectation for the same peer/port to ensure the newest registration wins.
list.retain(|entry| entry.peer.addr.port() != peer.addr.port());
list.push(ExpectedInbound {
peer,
transaction,
courtesy,
transient,
});
}

fn drop_peer(&mut self, peer: &PeerId) {
if let Some(list) = self.entries.get_mut(&peer.addr.ip()) {
list.retain(|entry| entry.peer.addr.port() != peer.addr.port());
list.retain(|entry| entry.peer != *peer);
if list.is_empty() {
self.entries.remove(&peer.addr.ip());
}
}
}

fn consume(&mut self, addr: SocketAddr) -> Option<ExpectedInbound> {
let list = self.entries.get_mut(&addr.ip())?;
let pos = list
let ip = addr.ip();
let list = self.entries.get_mut(&ip)?;
if let Some(pos) = list
.iter()
.position(|entry| entry.peer.addr.port() == addr.port())?;
let entry = list.swap_remove(pos);
.position(|entry| entry.peer.addr.port() == addr.port())
{
let entry = list.remove(pos);
if list.is_empty() {
self.entries.remove(&ip);
}
tracing::debug!(remote = %addr, peer = %entry.peer.addr, transient = entry.transient, tx = ?entry.transaction, "ExpectInbound: matched by exact port");
return Some(entry);
}
let entry = list.pop();
if list.is_empty() {
self.entries.remove(&addr.ip());
self.entries.remove(&ip);
}
Some(entry)
if let Some(entry) = entry {
tracing::debug!(remote = %addr, peer = %entry.peer.addr, transient = entry.transient, tx = ?entry.transaction, "ExpectInbound: matched by IP fallback");
return Some(entry);
}
None
}

#[cfg(test)]
fn contains(&self, addr: SocketAddr) -> bool {
self.entries
.get(&addr.ip())
.map(|list| {
list.iter()
.any(|entry| entry.peer.addr.port() == addr.port())
})
.unwrap_or(false)
self.entries.contains_key(&addr.ip())
}
}

Expand All @@ -197,12 +202,12 @@ async fn run_driver(
loop {
select! {
command = commands_rx.recv() => match command {
Some(Command::Connect { peer, transaction, courtesy }) => {
spawn_outbound(outbound.clone(), events_tx.clone(), peer, transaction, courtesy, peer_ready.clone());
}
Some(Command::ExpectInbound { peer, transaction, courtesy }) => {
expected_inbound.register(peer, transaction, courtesy);
Some(Command::Connect { peer, transaction, transient }) => {
spawn_outbound(outbound.clone(), events_tx.clone(), peer, transaction, transient, peer_ready.clone());
}
Some(Command::ExpectInbound { peer, transaction, transient }) => {
expected_inbound.register(peer, transaction, transient /* transient */);
}
Some(Command::DropConnection { peer }) => {
expected_inbound.drop_peer(&peer);
}
Expand All @@ -217,8 +222,8 @@ async fn run_driver(

let remote_addr = conn.remote_addr();
let entry = expected_inbound.consume(remote_addr);
let (peer, transaction, courtesy) = if let Some(entry) = entry {
(Some(entry.peer), entry.transaction, entry.courtesy)
let (peer, transaction, transient) = if let Some(entry) = entry {
(Some(entry.peer), entry.transaction, entry.transient)
} else {
(None, None, false)
};
Expand All @@ -227,7 +232,7 @@ async fn run_driver(
transaction,
peer,
connection: conn,
courtesy,
transient,
}).await.is_err() {
break;
}
Expand All @@ -244,7 +249,7 @@ fn spawn_outbound(
events_tx: mpsc::Sender<Event>,
peer: PeerId,
transaction: Transaction,
courtesy: bool,
transient: bool,
peer_ready: Option<Arc<std::sync::atomic::AtomicBool>>,
) {
tokio::spawn(async move {
Expand All @@ -268,13 +273,13 @@ fn spawn_outbound(
transaction,
peer: peer.clone(),
connection,
courtesy,
transient,
},
Err(error) => Event::OutboundFailed {
transaction,
peer: peer.clone(),
error,
courtesy,
transient,
},
};

Expand Down Expand Up @@ -307,7 +312,7 @@ mod tests {
.expect("expected registered inbound entry");
assert_eq!(entry.peer, peer);
assert_eq!(entry.transaction, Some(tx));
assert!(entry.courtesy);
assert!(entry.transient);
assert!(tracker.consume(peer.addr).is_none());
}

Expand Down Expand Up @@ -335,6 +340,6 @@ mod tests {
.consume(peer.addr)
.expect("entry should be present after overwrite");
assert_eq!(entry.transaction, Some(new_tx));
assert!(entry.courtesy);
assert!(entry.transient);
}
}
Loading
Loading