diff --git a/examples/limit.rs b/examples/limit.rs index 6aaa2921f..830e75836 100644 --- a/examples/limit.rs +++ b/examples/limit.rs @@ -88,13 +88,19 @@ fn limit_by_node_id(allowed_nodes: HashSet) -> EventSender { n0_future::task::spawn(async move { while let Some(msg) = rx.recv().await { if let ProviderMessage::ClientConnected(msg) = msg { - let node_id = msg.node_id; - let res = if allowed_nodes.contains(&node_id) { - println!("Client connected: {node_id}"); - Ok(()) - } else { - println!("Client rejected: {node_id}"); - Err(AbortReason::Permission) + let res = match msg.node_id { + Some(node_id) if allowed_nodes.contains(&node_id) => { + println!("Client connected: {node_id}"); + Ok(()) + } + Some(node_id) => { + println!("Client rejected: {node_id}"); + Err(AbortReason::Permission) + } + None => { + println!("Client rejected: no node id"); + Err(AbortReason::Permission) + } }; msg.tx.send(res).await.ok(); } @@ -202,7 +208,7 @@ fn limit_max_connections(max_connections: usize) -> EventSender { let connection_id = msg.connection_id; let node_id = msg.node_id; let res = if let Ok(n) = requests.inc() { - println!("Accepting connection {n}, node_id {node_id}, connection_id {connection_id}"); + println!("Accepting connection {n}, node_id {node_id:?}, connection_id {connection_id}"); Ok(()) } else { Err(AbortReason::RateLimited) diff --git a/src/provider.rs b/src/provider.rs index 0134169c6..ba415df41 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -16,7 +16,7 @@ use n0_future::StreamExt; use quinn::{ClosedStream, ConnectionError, ReadToEndError}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tokio::select; -use tracing::{debug, debug_span, warn, Instrument}; +use tracing::{debug, debug_span, Instrument}; use crate::{ api::{ @@ -319,14 +319,10 @@ pub async fn handle_connection( let connection_id = connection.stable_id() as u64; let span = debug_span!("connection", connection_id); async move { - let Ok(node_id) = connection.remote_node_id() else { - warn!("failed to get node id"); - return; - }; if let Err(cause) = progress .client_connected(|| ClientConnected { connection_id, - node_id, + node_id: connection.remote_node_id().ok(), }) .await { diff --git a/src/provider/events.rs b/src/provider/events.rs index e24e0efbb..40ec56f89 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -578,7 +578,7 @@ mod proto { #[derive(Debug, Serialize, Deserialize)] pub struct ClientConnected { pub connection_id: u64, - pub node_id: NodeId, + pub node_id: Option, } #[derive(Debug, Serialize, Deserialize)] diff --git a/src/tests.rs b/src/tests.rs index 0ef0c027c..09b2e5b33 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -348,10 +348,10 @@ fn event_handler( while let Some(event) = events_rx.recv().await { match event { ProviderMessage::ClientConnected(msg) => { - let res = if allowed_nodes.contains(&msg.inner.node_id) { - Ok(()) - } else { - Err(AbortReason::Permission) + let res = match msg.node_id { + Some(node_id) if allowed_nodes.contains(&node_id) => Ok(()), + Some(_) => Err(AbortReason::Permission), + None => Err(AbortReason::Permission), }; msg.tx.send(res).await.ok(); }