Skip to content

Commit

Permalink
refactor(network): use builder pattern to construct the Network
Browse files Browse the repository at this point in the history
- this allows us to enable feature flagged fields to be passed in
  • Loading branch information
RolandSherwin committed Sep 14, 2023
1 parent 604f894 commit 0a49270
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 99 deletions.
29 changes: 14 additions & 15 deletions sn_client/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ use libp2p::{identity::Keypair, kad::Record, Multiaddr};
#[cfg(feature = "open-metrics")]
use prometheus_client::registry::Registry;
use sn_dbc::{DbcId, PublicAddress, SignedSpend, Token};
use sn_networking::{
multiaddr_is_global, NetworkConfig, NetworkEvent, SwarmDriver, CLOSE_GROUP_SIZE,
};
use sn_networking::{multiaddr_is_global, NetworkBuilder, NetworkEvent, CLOSE_GROUP_SIZE};
use sn_protocol::{
error::Error as ProtocolError,
storage::{
Expand Down Expand Up @@ -59,18 +57,19 @@ impl Client {
info!("Startup a client with peers {peers:?} and local {local:?} flag");
info!("Starting Kad swarm in client mode...");

let network_cfg = NetworkConfig {
keypair: Keypair::generate_ed25519(),
local,
root_dir: std::env::temp_dir(),
listen_addr: None,
request_timeout: req_response_timeout,
concurrency_limit: Some(custom_concurrency_limit.unwrap_or(DEFAULT_CLIENT_CONCURRENCY)),
#[cfg(feature = "open-metrics")]
metrics_registry: Registry::default(),
};
let (network, mut network_event_receiver, swarm_driver) =
SwarmDriver::new_client(network_cfg)?;
let mut network_builder =
NetworkBuilder::new(Keypair::generate_ed25519(), local, std::env::temp_dir());

if let Some(request_timeout) = req_response_timeout {
network_builder.request_timeout(request_timeout);
}
network_builder
.concurrency_limit(custom_concurrency_limit.unwrap_or(DEFAULT_CLIENT_CONCURRENCY));

#[cfg(feature = "open-metrics")]
network_builder.metrics_registry(Registry::default());

let (network, mut network_event_receiver, swarm_driver) = network_builder.build_client()?;
info!("Client constructed network and swarm_driver");
let events_channel = ClientEventsChannel::default();

Expand Down
161 changes: 91 additions & 70 deletions sn_networking/src/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
// permissions and limitations relating to use of the SAFE Network Software.

#[cfg(feature = "open-metrics")]
use crate::metrics_service::metrics_server;
use crate::metrics_service::run_metrics_server;
use crate::{
circular_vec::CircularVec,
cmd::SwarmCmd,
Expand Down Expand Up @@ -55,6 +55,9 @@ use std::{
use tokio::sync::{mpsc, oneshot};
use tracing::warn;

/// Map from the `QueryId` of a "get closest peers" query to the `oneshot` sender that delivers the
/// `HashSet<PeerId>` answer, paired with a second `HashSet<PeerId>`.
// NOTE(review): the paired set presumably accumulates the peers reported so far - confirm in the event handling code.
type PendingGetClosest = HashMap<QueryId, (oneshot::Sender<HashSet<PeerId>>, HashSet<PeerId>)>;
/// Map from the `QueryId` of a "get record" query to the `oneshot` sender that delivers the
/// `Result<Record>` answer, paired with a `GetRecordResultMap`.
// NOTE(review): the paired map presumably tracks the per-peer results seen so far - confirm in the event handling code.
type PendingGetRecord = HashMap<QueryId, (oneshot::Sender<Result<Record>>, GetRecordResultMap)>;

/// What is the largest packet to send over the network.
/// Records larger than this will be rejected.
// TODO: revisit once utxo is in
Expand Down Expand Up @@ -89,45 +92,48 @@ pub(super) struct NodeBehaviour {
}

#[derive(Debug)]
pub struct NetworkConfig {
pub keypair: Keypair,
pub local: bool,
pub root_dir: PathBuf,
pub listen_addr: Option<SocketAddr>,
pub request_timeout: Option<Duration>,
pub concurrency_limit: Option<usize>,
pub struct NetworkBuilder {
keypair: Keypair,
local: bool,
root_dir: PathBuf,
listen_addr: Option<SocketAddr>,
request_timeout: Option<Duration>,
concurrency_limit: Option<usize>,
#[cfg(feature = "open-metrics")]
pub metrics_registry: Registry,
metrics_registry: Option<Registry>,
}

type PendingGetClosest = HashMap<QueryId, (oneshot::Sender<HashSet<PeerId>>, HashSet<PeerId>)>;
type PendingGetRecord = HashMap<QueryId, (oneshot::Sender<Result<Record>>, GetRecordResultMap)>;
impl NetworkBuilder {
/// Creates a builder from the mandatory settings: the `keypair`, the `local` flag and the
/// `root_dir`.
///
/// Every optional setting (`listen_addr`, `request_timeout`, `concurrency_limit` and, with the
/// `open-metrics` feature enabled, `metrics_registry`) starts out as `None`; call the setter of
/// the same name to provide a value.
pub fn new(keypair: Keypair, local: bool, root_dir: PathBuf) -> Self {
Self {
keypair,
local,
root_dir,
listen_addr: None,
request_timeout: None,
concurrency_limit: None,
#[cfg(feature = "open-metrics")]
metrics_registry: None,
}
}

pub struct SwarmDriver {
pub(crate) swarm: Swarm<NodeBehaviour>,
pub(crate) self_peer_id: PeerId,
pub(crate) local: bool,
pub(crate) is_client: bool,
pub(crate) bootstrap_ongoing: bool,
/// The peers that are closer to our PeerId. Includes self.
pub(crate) close_group: Vec<PeerId>,
pub(crate) replication_fetcher: ReplicationFetcher,
#[cfg(feature = "open-metrics")]
pub(crate) network_metrics: NetworkMetrics,
/// Stores the socket address in the builder, replacing any previously set value.
///
/// `build_node` reads this setting.
// NOTE(review): `build_client` is documented as not listening on any address, so this presumably has no effect there - confirm.
pub fn listen_addr(&mut self, listen_addr: SocketAddr) {
self.listen_addr = Some(listen_addr);
}

cmd_receiver: mpsc::Receiver<SwarmCmd>,
event_sender: mpsc::Sender<NetworkEvent>, // Use `self.send_event()` to send a NetworkEvent.
/// Overrides the timeout used for the request/response behaviour.
///
/// When this is never called, `build` falls back to `REQUEST_TIMEOUT_DEFAULT_S`.
pub fn request_timeout(&mut self, request_timeout: Duration) {
self.request_timeout = Some(request_timeout);
}

/// Trackers for underlying behaviour related events
pub(crate) pending_get_closest_peers: PendingGetClosest,
pub(crate) pending_requests: HashMap<RequestId, Option<oneshot::Sender<Result<Response>>>>,
pub(crate) pending_get_record: PendingGetRecord,
/// A list of the most recent peers we have dialed ourselves.
pub(crate) dialed_peers: CircularVec<PeerId>,
pub(crate) unroutable_peers: CircularVec<PeerId>,
}
/// Stores the concurrency limit in the builder, replacing any previously set value.
///
/// `build_client` reads this setting.
// NOTE(review): no use by `build_node` is visible in this view - confirm whether nodes honour it.
pub fn concurrency_limit(&mut self, concurrency_limit: usize) {
self.concurrency_limit = Some(concurrency_limit);
}

/// Supplies the `Registry` that the network metrics are registered into.
///
/// Only available with the `open-metrics` feature. When this is never called, `build` uses a
/// default-constructed registry instead.
#[cfg(feature = "open-metrics")]
pub fn metrics_registry(&mut self, metrics_registry: Registry) {
self.metrics_registry = Some(metrics_registry);
}

impl SwarmDriver {
/// Creates a new `SwarmDriver` instance, along with a `Network` handle
/// for sending commands and an `mpsc::Receiver<NetworkEvent>` for receiving
/// network events. It initializes the swarm, sets up the transport, and
Expand All @@ -141,9 +147,7 @@ impl SwarmDriver {
/// # Errors
///
/// Returns an error if there is a problem initializing the mDNS behaviour.
pub fn new(
network_cfg: NetworkConfig,
) -> Result<(Network, mpsc::Receiver<NetworkEvent>, Self)> {
pub fn build_node(self) -> Result<(Network, mpsc::Receiver<NetworkEvent>, SwarmDriver)> {
let mut kad_cfg = KademliaConfig::default();
let _ = kad_cfg
.set_kbucket_inserts(libp2p::kad::KademliaBucketInserts::Manual)
Expand Down Expand Up @@ -173,7 +177,7 @@ impl SwarmDriver {

let store_cfg = {
// Configures the disk_store to store records under the provided path and increase the max record size
let storage_dir_path = network_cfg.root_dir.join("record_store");
let storage_dir_path = self.root_dir.join("record_store");
if let Err(error) = std::fs::create_dir_all(&storage_dir_path) {
return Err(Error::FailedToCreateRecordStoreDir {
path: storage_dir_path,
Expand All @@ -187,10 +191,9 @@ impl SwarmDriver {
}
};

let listen_addr = network_cfg.listen_addr;
let listen_addr = self.listen_addr;

let (network, events_receiver, mut swarm_driver) = Self::with(
network_cfg,
let (network, events_receiver, mut swarm_driver) = self.build(
kad_cfg,
Some(store_cfg),
false,
Expand All @@ -216,10 +219,8 @@ impl SwarmDriver {
Ok((network, events_receiver, swarm_driver))
}

/// Same as `new` API but creates the network components in client mode
pub fn new_client(
network_cfg: NetworkConfig,
) -> Result<(Network, mpsc::Receiver<NetworkEvent>, Self)> {
/// Same as `build_node` API but creates the network components in client mode
pub fn build_client(self) -> Result<(Network, mpsc::Receiver<NetworkEvent>, SwarmDriver)> {
// Create a Kademlia behaviour for client mode, i.e. set req/resp protocol
// to outbound-only mode and don't listen on any address
let mut kad_cfg = KademliaConfig::default(); // default query timeout is 60 secs
Expand All @@ -234,9 +235,8 @@ impl SwarmDriver {
NonZeroUsize::new(CLOSE_GROUP_SIZE).ok_or_else(|| Error::InvalidCloseGroupSize)?,
);

let concurrency_limit = network_cfg.concurrency_limit;
let (mut network, net_event_recv, driver) = Self::with(
network_cfg,
let concurrency_limit = self.concurrency_limit;
let (mut network, net_event_recv, driver) = self.build(
kad_cfg,
None,
true,
Expand All @@ -252,26 +252,22 @@ impl SwarmDriver {
}

/// Private helper to create the network components with the provided config and req/res behaviour
fn with(
network_cfg: NetworkConfig,
fn build(
self,
kad_cfg: KademliaConfig,
record_store_cfg: Option<NodeRecordStoreConfig>,
is_client: bool,
req_res_protocol: ProtocolSupport,
identify_version: String,
) -> Result<(Network, mpsc::Receiver<NetworkEvent>, Self)> {
let peer_id = PeerId::from(network_cfg.keypair.public());
) -> Result<(Network, mpsc::Receiver<NetworkEvent>, SwarmDriver)> {
let peer_id = PeerId::from(self.keypair.public());
info!("Node (PID: {}) with PeerId: {peer_id}", std::process::id());

// RequestResponse Behaviour
let request_response = {
let mut cfg = RequestResponseConfig::default();
let _ = cfg
.set_request_timeout(
network_cfg
.request_timeout
.unwrap_or(REQUEST_TIMEOUT_DEFAULT_S),
)
.set_request_timeout(self.request_timeout.unwrap_or(REQUEST_TIMEOUT_DEFAULT_S))
.set_connection_keep_alive(CONNECTION_KEEP_ALIVE_TIMEOUT);

request_response::cbor::Behaviour::new(
Expand Down Expand Up @@ -323,7 +319,7 @@ impl SwarmDriver {
let identify = {
let cfg = libp2p::identify::Config::new(
IDENTIFY_PROTOCOL_STR.to_string(),
network_cfg.keypair.public(),
self.keypair.public(),
)
.with_agent_version(identify_version);
libp2p::identify::Behaviour::new(cfg)
Expand All @@ -334,26 +330,25 @@ impl SwarmDriver {
let mut transport = libp2p::tcp::tokio::Transport::new(libp2p::tcp::Config::default())
.upgrade(libp2p::core::upgrade::Version::V1)
.authenticate(
libp2p::noise::Config::new(&network_cfg.keypair)
libp2p::noise::Config::new(&self.keypair)
.expect("Signing libp2p-noise static DH keypair failed."),
)
.multiplex(libp2p::yamux::Config::default())
.boxed();

#[cfg(feature = "quic")]
let mut transport =
libp2p_quic::tokio::Transport::new(quic::Config::new(&network_cfg.keypair))
.map(|(peer_id, muxer), _| (peer_id, StreamMuxerBox::new(muxer)))
.boxed();
let mut transport = libp2p_quic::tokio::Transport::new(quic::Config::new(&self.keypair))
.map(|(peer_id, muxer), _| (peer_id, StreamMuxerBox::new(muxer)))
.boxed();

if !network_cfg.local {
if !self.local {
debug!("Preventing non-global dials");
// Wrap TCP or UDP in a transport that prevents dialing local addresses.
transport = libp2p::core::transport::global_only::Transport::new(transport).boxed();
}

// Disable AutoNAT if we are either running locally or a client.
let autonat = if !network_cfg.local && !is_client {
let autonat = if !self.local && !is_client {
let cfg = libp2p::autonat::Config {
// Defaults to 15. But we want to be a little quicker on checking for our NAT status.
boot_delay: Duration::from_secs(3),
Expand All @@ -374,9 +369,9 @@ impl SwarmDriver {

#[cfg(feature = "open-metrics")]
let network_metrics = {
let mut metrics_registry = network_cfg.metrics_registry;
let mut metrics_registry = self.metrics_registry.unwrap_or_default();
let metrics = NetworkMetrics::new(&mut metrics_registry);
metrics_server(metrics_registry);
run_metrics_server(metrics_registry);
metrics
};

Expand All @@ -391,10 +386,10 @@ impl SwarmDriver {
let swarm = SwarmBuilder::with_tokio_executor(transport, behaviour, peer_id).build();

let (swarm_cmd_sender, swarm_cmd_receiver) = mpsc::channel(NETWORKING_CHANNEL_SIZE);
let swarm_driver = Self {
let swarm_driver = SwarmDriver {
swarm,
self_peer_id: peer_id,
local: network_cfg.local,
local: self.local,
is_client,
bootstrap_ongoing: false,
close_group: Default::default(),
Expand All @@ -418,15 +413,41 @@ impl SwarmDriver {
Network {
swarm_cmd_sender,
peer_id,
root_dir_path: network_cfg.root_dir,
keypair: network_cfg.keypair,
root_dir_path: self.root_dir,
keypair: self.keypair,
concurrency_limiter: None,
},
network_event_receiver,
swarm_driver,
))
}
}

/// Owns the libp2p `Swarm` together with the channels and bookkeeping needed to drive it.
///
/// Instances are produced by `NetworkBuilder::build_node` / `NetworkBuilder::build_client`,
/// alongside the `Network` handle and the `NetworkEvent` receiver.
pub struct SwarmDriver {
/// The libp2p swarm running `NodeBehaviour`.
pub(crate) swarm: Swarm<NodeBehaviour>,
/// Our own `PeerId`, derived from the public key of the builder's keypair.
pub(crate) self_peer_id: PeerId,
/// Copy of the builder's `local` flag.
pub(crate) local: bool,
/// `true` when built through `build_client`, `false` when built through `build_node`.
pub(crate) is_client: bool,
/// Starts out as `false` when the driver is built.
// NOTE(review): presumably set while a bootstrap is in progress - confirm where it is toggled.
pub(crate) bootstrap_ongoing: bool,
/// The peers that are closer to our PeerId. Includes self.
pub(crate) close_group: Vec<PeerId>,
pub(crate) replication_fetcher: ReplicationFetcher,
/// Network metrics created from the builder's registry (only with the `open-metrics` feature).
#[cfg(feature = "open-metrics")]
pub(crate) network_metrics: NetworkMetrics,

// NOTE(review): presumably the receiving half of the channel whose sender is stored in
// `Network::swarm_cmd_sender` - confirm in `NetworkBuilder::build`.
cmd_receiver: mpsc::Receiver<SwarmCmd>,
event_sender: mpsc::Sender<NetworkEvent>, // Use `self.send_event()` to send a NetworkEvent.

/// Trackers for underlying behaviour related events
pub(crate) pending_get_closest_peers: PendingGetClosest,
pub(crate) pending_requests: HashMap<RequestId, Option<oneshot::Sender<Result<Response>>>>,
pub(crate) pending_get_record: PendingGetRecord,
/// A list of the most recent peers we have dialed ourselves.
pub(crate) dialed_peers: CircularVec<PeerId>,
// NOTE(review): presumably the most recent peers found to be unroutable - confirm against the code that pushes to it.
pub(crate) unroutable_peers: CircularVec<PeerId>,
}

impl SwarmDriver {
/// Asynchronously drives the swarm event loop, handling events from both
/// the swarm and command receiver. This function will run indefinitely,
/// until the command channel is closed.
Expand Down
2 changes: 1 addition & 1 deletion sn_networking/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ mod replication_fetcher;

pub use self::{
cmd::SwarmLocalState,
driver::{NetworkConfig, SwarmDriver},
driver::{NetworkBuilder, SwarmDriver},
error::Error,
event::{MsgResponder, NetworkEvent},
record_store::NodeRecordStore,
Expand Down
2 changes: 1 addition & 1 deletion sn_networking/src/metrics_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use std::{

const METRICS_CONTENT_TYPE: &str = "application/openmetrics-text;charset=utf-8;version=1.0.0";

pub(crate) fn metrics_server(registry: Registry) {
pub(crate) fn run_metrics_server(registry: Registry) {
// Serve on localhost.
let addr = ([127, 0, 0, 1], 0).into();

Expand Down
18 changes: 6 additions & 12 deletions sn_node/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use prometheus_client::registry::Registry;
use rand::{rngs::StdRng, Rng, SeedableRng};
use sn_dbc::MainKey;
use sn_networking::{
MsgResponder, NetworkConfig, NetworkEvent, SwarmDriver, SwarmLocalState, CLOSE_GROUP_SIZE,
MsgResponder, NetworkBuilder, NetworkEvent, SwarmLocalState, CLOSE_GROUP_SIZE,
};
use sn_protocol::{
messages::{Cmd, CmdResponse, Query, QueryResponse, Request, Response},
Expand Down Expand Up @@ -112,18 +112,12 @@ impl Node {
(metrics_registry, node_metrics)
};

let network_cfg = NetworkConfig {
keypair,
local,
root_dir,
listen_addr: Some(addr),
request_timeout: None,
concurrency_limit: None,
#[cfg(feature = "open-metrics")]
metrics_registry,
};
let mut network_builder = NetworkBuilder::new(keypair, local, root_dir);
network_builder.listen_addr(addr);
#[cfg(feature = "open-metrics")]
network_builder.metrics_registry(metrics_registry);

let (network, mut network_event_receiver, swarm_driver) = SwarmDriver::new(network_cfg)?;
let (network, mut network_event_receiver, swarm_driver) = network_builder.build_node()?;
let node_events_channel = NodeEventsChannel::default();

let node = Self {
Expand Down

0 comments on commit 0a49270

Please sign in to comment.