Skip to content

Commit

Permalink
Merge branch 'rumenov/NET-1533' into 'master'
Browse files Browse the repository at this point in the history
refactor: NET-1533 remove the redundant AllowClients type

Closes NET-1533 

Closes NET-1533

See merge request dfinity-lab/public/ic!15600
  • Loading branch information
rumenov committed Oct 30, 2023
2 parents 5983771 + 88dc685 commit 7dbe6dc
Show file tree
Hide file tree
Showing 15 changed files with 44 additions and 90 deletions.
6 changes: 3 additions & 3 deletions rs/crypto/src/tls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::*;
use async_trait::async_trait;
use ic_crypto_internal_logmon::metrics::{MetricsDomain, MetricsResult, MetricsScope};
use ic_crypto_tls_interfaces::{
AllowedClients, AuthenticatedPeer, TlsClientHandshakeError, TlsConfig, TlsConfigError,
AuthenticatedPeer, SomeOrAllNodes, TlsClientHandshakeError, TlsConfig, TlsConfigError,
TlsHandshake, TlsPublicKeyCert, TlsServerHandshakeError, TlsStream,
};
use ic_logger::{debug, new_logger};
Expand All @@ -20,7 +20,7 @@ where
{
fn server_config(
&self,
allowed_clients: AllowedClients,
allowed_clients: SomeOrAllNodes,
registry_version: RegistryVersion,
) -> Result<tokio_rustls::rustls::ServerConfig, TlsConfigError> {
let log_id = get_log_id(&self.logger, module_path!());
Expand Down Expand Up @@ -142,7 +142,7 @@ where
async fn perform_tls_server_handshake(
&self,
tcp_stream: TcpStream,
allowed_clients: AllowedClients,
allowed_clients: SomeOrAllNodes,
registry_version: RegistryVersion,
) -> Result<(Box<dyn TlsStream>, AuthenticatedPeer), TlsServerHandshakeError> {
let log_id = get_log_id(&self.logger, module_path!());
Expand Down
8 changes: 4 additions & 4 deletions rs/crypto/src/tls/rustls/server_handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::tls::tls_cert_from_registry;
use ic_crypto_internal_csp::api::CspTlsHandshakeSignerProvider;
use ic_crypto_internal_csp::key_id::KeyId;
use ic_crypto_tls_interfaces::{
AllowedClients, AuthenticatedPeer, TlsConfigError, TlsPublicKeyCert, TlsServerHandshakeError,
AuthenticatedPeer, SomeOrAllNodes, TlsConfigError, TlsPublicKeyCert, TlsServerHandshakeError,
TlsStream,
};
use ic_crypto_utils_tls::{
Expand All @@ -31,7 +31,7 @@ pub fn server_config<P: CspTlsHandshakeSignerProvider>(
signer_provider: &P,
self_node_id: NodeId,
registry_client: Arc<dyn RegistryClient>,
allowed_clients: AllowedClients,
allowed_clients: SomeOrAllNodes,
registry_version: RegistryVersion,
) -> Result<ServerConfig, TlsConfigError> {
let self_tls_cert =
Expand All @@ -42,7 +42,7 @@ pub fn server_config<P: CspTlsHandshakeSignerProvider>(
}
})?;
let client_cert_verifier = NodeClientCertVerifier::new_with_mandatory_client_auth(
allowed_clients.nodes().clone(),
allowed_clients.clone(),
registry_client,
registry_version,
);
Expand Down Expand Up @@ -84,7 +84,7 @@ pub async fn perform_tls_server_handshake<P: CspTlsHandshakeSignerProvider>(
self_node_id: NodeId,
registry_client: Arc<dyn RegistryClient>,
tcp_stream: TcpStream,
allowed_clients: AllowedClients,
allowed_clients: SomeOrAllNodes,
registry_version: RegistryVersion,
) -> Result<(Box<dyn TlsStream>, AuthenticatedPeer), TlsServerHandshakeError> {
let config = server_config(
Expand Down
6 changes: 3 additions & 3 deletions rs/crypto/temp_crypto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub mod internal {
RemoteVaultEnvironment, TempCspVaultServer, TokioRuntimeOrHandle,
};
use ic_crypto_tls_interfaces::{
AllowedClients, AuthenticatedPeer, TlsClientHandshakeError, TlsConfig, TlsConfigError,
AuthenticatedPeer, SomeOrAllNodes, TlsClientHandshakeError, TlsConfig, TlsConfigError,
TlsHandshake, TlsPublicKeyCert, TlsServerHandshakeError, TlsStream,
};
use ic_crypto_utils_basic_sig::conversions::derive_node_id;
Expand Down Expand Up @@ -684,7 +684,7 @@ pub mod internal {
async fn perform_tls_server_handshake(
&self,
tcp_stream: TcpStream,
allowed_clients: AllowedClients,
allowed_clients: SomeOrAllNodes,
registry_version: RegistryVersion,
) -> Result<(Box<dyn TlsStream>, AuthenticatedPeer), TlsServerHandshakeError> {
self.crypto_component
Expand Down Expand Up @@ -719,7 +719,7 @@ pub mod internal {
{
fn server_config(
&self,
allowed_clients: AllowedClients,
allowed_clients: SomeOrAllNodes,
registry_version: RegistryVersion,
) -> Result<ServerConfig, TlsConfigError> {
self.crypto_component
Expand Down
13 changes: 6 additions & 7 deletions rs/crypto/tests/tls_utils/test_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::tls_utils::{temp_crypto_component_with_tls_keys, REG_V1};
use ic_crypto_temp_crypto::TempCryptoComponent;
use ic_crypto_tls_interfaces::TlsPublicKeyCert;
use ic_crypto_tls_interfaces::{
AllowedClients, AuthenticatedPeer, SomeOrAllNodes, TlsHandshake, TlsServerHandshakeError,
AuthenticatedPeer, SomeOrAllNodes, TlsHandshake, TlsServerHandshakeError,
};
use ic_protobuf::registry::crypto::v1::X509PublicKeyCert;
use ic_registry_client_fake::FakeRegistryClient;
Expand Down Expand Up @@ -68,10 +68,9 @@ impl ServerBuilder {
pub fn build(self, registry: Arc<FakeRegistryClient>) -> Server {
let listener = std::net::TcpListener::bind(("0.0.0.0", 0)).expect("failed to bind");
let (crypto, cert) = temp_crypto_component_with_tls_keys(registry, self.node_id);
let allowed_clients = AllowedClients::new(
self.allowed_nodes
.unwrap_or_else(|| SomeOrAllNodes::Some(BTreeSet::new())),
);
let allowed_clients = self
.allowed_nodes
.unwrap_or_else(|| SomeOrAllNodes::Some(BTreeSet::new()));
Server {
listener,
crypto,
Expand All @@ -88,7 +87,7 @@ impl ServerBuilder {
pub struct Server {
listener: std::net::TcpListener,
crypto: TempCryptoComponent,
allowed_clients: AllowedClients,
allowed_clients: SomeOrAllNodes,
msg_for_client: Option<String>,
msg_expected_from_client: Option<String>,
cert: TlsPublicKeyCert,
Expand Down Expand Up @@ -194,7 +193,7 @@ impl Server {
}

pub fn allowed_clients(&self) -> &BTreeSet<NodeId> {
match self.allowed_clients.nodes() {
match &self.allowed_clients {
SomeOrAllNodes::Some(nodes) => nodes,
SomeOrAllNodes::All => unimplemented!(),
}
Expand Down
4 changes: 2 additions & 2 deletions rs/crypto/tls_interfaces/mocks/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use async_trait::async_trait;
use ic_base_types::{NodeId, RegistryVersion};
use ic_crypto_tls_interfaces::{
AllowedClients, AuthenticatedPeer, TlsClientHandshakeError, TlsHandshake,
AuthenticatedPeer, SomeOrAllNodes, TlsClientHandshakeError, TlsHandshake,
TlsServerHandshakeError, TlsStream,
};
use mockall::*;
Expand All @@ -15,7 +15,7 @@ mock! {
async fn perform_tls_server_handshake(
&self,
tcp_stream: TcpStream,
allowed_clients: AllowedClients,
allowed_clients: SomeOrAllNodes,
registry_version: RegistryVersion,
) -> Result<(Box<dyn TlsStream>, AuthenticatedPeer), TlsServerHandshakeError>;

Expand Down
29 changes: 3 additions & 26 deletions rs/crypto/tls_interfaces/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ pub trait TlsHandshake {
async fn perform_tls_server_handshake(
&self,
tcp_stream: TcpStream,
allowed_clients: AllowedClients,
allowed_clients: SomeOrAllNodes,
registry_version: RegistryVersion,
) -> Result<(Box<dyn TlsStream>, AuthenticatedPeer), TlsServerHandshakeError>;

Expand Down Expand Up @@ -426,7 +426,7 @@ pub trait TlsConfig {
/// is an error in the setup of the node and registry.
fn server_config(
&self,
allowed_clients: AllowedClients,
allowed_clients: SomeOrAllNodes,
registry_version: RegistryVersion,
) -> Result<ServerConfig, TlsConfigError>;

Expand Down Expand Up @@ -547,31 +547,8 @@ impl From<TlsConfigError> for TlsServerHandshakeError {
}
}

#[derive(Clone, Debug)]
/// A list of allowed TLS peers, which can be `All` to allow any node to connect.
pub struct AllowedClients {
nodes: SomeOrAllNodes,
}

impl AllowedClients {
pub fn new(nodes: SomeOrAllNodes) -> Self {
Self { nodes }
}

/// Create an `AllowedClients` with a set of nodes.
pub fn new_with_nodes(node_ids: BTreeSet<NodeId>) -> Self {
Self::new(SomeOrAllNodes::Some(node_ids))
}

/// Access the allowed nodes.
pub fn nodes(&self) -> &SomeOrAllNodes {
&self.nodes
}
}

#[derive(Clone, Debug, PartialEq, Eq)]
/// A list of node IDs, or "all nodes"
// TODO: NET-1533
/// A list of node IDs or all nodes present in the registry.
pub enum SomeOrAllNodes {
Some(BTreeSet<NodeId>),
All,
Expand Down
20 changes: 1 addition & 19 deletions rs/crypto/tls_interfaces/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,29 +137,11 @@ mod tls_public_key_cert {
}

mod allowed_clients {
use crate::{AllowedClients, SomeOrAllNodes};
use crate::SomeOrAllNodes;
use assert_matches::assert_matches;
use ic_types::{NodeId, PrincipalId};
use maplit::btreeset;

#[test]
fn should_correctly_construct_with_new() {
let nodes = SomeOrAllNodes::Some(btreeset! {node_id(1)});

let allowed_clients = AllowedClients::new(nodes.clone());

assert_eq!(allowed_clients.nodes(), &nodes);
}

#[test]
fn should_correctly_construct_with_new_with_nodes() {
let nodes = btreeset! {node_id(1)};

let allowed_clients = AllowedClients::new_with_nodes(nodes.clone());

assert_eq!(allowed_clients.nodes(), &SomeOrAllNodes::Some(nodes));
}

#[test]
fn should_contain_any_node_in_all_nodes() {
let all_nodes = SomeOrAllNodes::All;
Expand Down
4 changes: 2 additions & 2 deletions rs/orchestrator/src/registration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -741,8 +741,8 @@ mod tests {
use super::*;
use async_trait::async_trait;
use ic_crypto_temp_crypto::EcdsaSubnetConfig;
use ic_crypto_tls_interfaces::AllowedClients;
use ic_crypto_tls_interfaces::AuthenticatedPeer;
use ic_crypto_tls_interfaces::SomeOrAllNodes;
use ic_crypto_tls_interfaces::TlsClientHandshakeError;
use ic_crypto_tls_interfaces::TlsHandshake;
use ic_crypto_tls_interfaces::TlsServerHandshakeError;
Expand Down Expand Up @@ -822,7 +822,7 @@ mod tests {
async fn perform_tls_server_handshake(
&self,
tcp_stream: TcpStream,
allowed_clients: AllowedClients,
allowed_clients: SomeOrAllNodes,
registry_version: RegistryVersion,
) -> Result<(Box<dyn TlsStream>, AuthenticatedPeer), TlsServerHandshakeError>;

Expand Down
15 changes: 6 additions & 9 deletions rs/p2p/quic_transport/src/connection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ use futures::StreamExt;
use ic_async_utils::JoinMap;
use ic_base_types::{NodeId, RegistryVersion};
use ic_crypto_tls_interfaces::{
AllowedClients, MalformedPeerCertificateError, SomeOrAllNodes, TlsConfig, TlsConfigError,
TlsStream,
MalformedPeerCertificateError, SomeOrAllNodes, TlsConfig, TlsConfigError, TlsStream,
};
use ic_crypto_utils_tls::{
node_id_from_cert_subject_common_name, tls_pubkey_cert_from_rustls_certs,
Expand Down Expand Up @@ -216,9 +215,7 @@ pub(crate) fn start_connection_manager(
let endpoint_config = EndpointConfig::default();
let rustls_server_config = tls_config
.server_config(
AllowedClients::new(ic_crypto_tls_interfaces::SomeOrAllNodes::Some(
BTreeSet::new(),
)),
SomeOrAllNodes::Some(BTreeSet::new()),
registry_client.get_latest_version(),
)
.unwrap();
Expand Down Expand Up @@ -419,10 +416,10 @@ impl ConnectionManager {
let subnet_nodes = SomeOrAllNodes::Some(subnet_node_set);

// Set new server config to only accept connections from the current set.
match self.tls_config.server_config(
AllowedClients::new(subnet_nodes),
self.topology.latest_registry_version(),
) {
match self
.tls_config
.server_config(subnet_nodes, self.topology.latest_registry_version())
{
Ok(rustls_server_config) => {
let mut server_config =
quinn::ServerConfig::with_crypto(Arc::new(rustls_server_config));
Expand Down
6 changes: 3 additions & 3 deletions rs/p2p/quic_transport/tests/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
};

use ic_base_types::{NodeId, RegistryVersion};
use ic_crypto_tls_interfaces::{AllowedClients, TlsConfig, TlsConfigError};
use ic_crypto_tls_interfaces::{SomeOrAllNodes, TlsConfig, TlsConfigError};
use ic_icos_sev_interfaces::{ValidateAttestationError, ValidateAttestedStream};
use ic_p2p_test_utils::{temp_crypto_component_with_tls_keys, RegistryConsensusHandle};
use tokio::io::{AsyncRead, AsyncWrite};
Expand Down Expand Up @@ -77,10 +77,10 @@ impl PeerRestrictedTlsConfig {
impl TlsConfig for PeerRestrictedTlsConfig {
fn server_config(
&self,
_allowed_clients: AllowedClients,
_allowed_clients: SomeOrAllNodes,
registry_version: RegistryVersion,
) -> Result<ServerConfig, TlsConfigError> {
let allowed_clients = AllowedClients::new_with_nodes(BTreeSet::from_iter(
let allowed_clients = SomeOrAllNodes::Some(BTreeSet::from_iter(
self.allowed_peers.lock().unwrap().clone().into_iter(),
));
self.crypto.server_config(allowed_clients, registry_version)
Expand Down
4 changes: 2 additions & 2 deletions rs/test_utilities/src/crypto.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
pub mod fake_tls_handshake;

pub use ic_crypto_test_utils::files as temp_dir;
use ic_crypto_tls_interfaces::{AllowedClients, TlsConfig, TlsConfigError};
use ic_crypto_tls_interfaces::{SomeOrAllNodes, TlsConfig, TlsConfigError};
use tokio_rustls::rustls::{ClientConfig, PrivateKey, RootCertStore, ServerConfig};

use crate::types::ids::node_test_id;
Expand Down Expand Up @@ -492,7 +492,7 @@ impl ThresholdEcdsaSigVerifier for CryptoReturningOk {
impl TlsConfig for CryptoReturningOk {
fn server_config(
&self,
_allowed_clients: AllowedClients,
_allowed_clients: SomeOrAllNodes,
_registry_version: RegistryVersion,
) -> Result<ServerConfig, TlsConfigError> {
Ok(ServerConfig::builder()
Expand Down
4 changes: 2 additions & 2 deletions rs/test_utilities/src/crypto/fake_tls_handshake.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use async_trait::async_trait;
use ic_crypto_tls_interfaces::{
AllowedClients, AuthenticatedPeer, TlsClientHandshakeError, TlsHandshake,
AuthenticatedPeer, SomeOrAllNodes, TlsClientHandshakeError, TlsHandshake,
TlsServerHandshakeError, TlsStream,
};
use ic_types::{NodeId, RegistryVersion};
Expand All @@ -27,7 +27,7 @@ impl TlsHandshake for FakeTlsHandshake {
async fn perform_tls_server_handshake(
&self,
_tcp_stream: TcpStream,
_allowed_clients: AllowedClients,
_allowed_clients: SomeOrAllNodes,
_registry_version: RegistryVersion,
) -> Result<(Box<dyn TlsStream>, AuthenticatedPeer), TlsServerHandshakeError> {
unimplemented!()
Expand Down
4 changes: 2 additions & 2 deletions rs/transport/src/control_plane.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
};
use ic_async_utils::start_tcp_listener;
use ic_base_types::{NodeId, RegistryVersion};
use ic_crypto_tls_interfaces::{AllowedClients, AuthenticatedPeer, TlsStream};
use ic_crypto_tls_interfaces::{AuthenticatedPeer, SomeOrAllNodes, TlsStream};
use ic_interfaces_transport::{TransportChannelId, TransportEvent, TransportEventHandler};
use ic_logger::{error, warn};
use std::{net::SocketAddr, time::Duration};
Expand Down Expand Up @@ -434,7 +434,7 @@ impl TransportImpl {
) -> Result<(NodeId, Box<dyn TlsStream>), TransportTlsHandshakeError> {
let latest_registry_version = *self.latest_registry_version.read().await;
let current_allowed_clients = self.allowed_clients.read().await.clone();
let allowed_clients = AllowedClients::new_with_nodes(current_allowed_clients);
let allowed_clients = SomeOrAllNodes::Some(current_allowed_clients);
let (tls_stream, authenticated_peer) = match tokio::time::timeout(
Duration::from_secs(TLS_HANDSHAKE_TIMEOUT_SECONDS),
self.crypto.perform_tls_server_handshake(
Expand Down
6 changes: 3 additions & 3 deletions rs/transport/tests/tls_tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use ic_base_types::{NodeId, RegistryVersion};
use ic_crypto_tls_interfaces::{
AllowedClients, TlsClientHandshakeError, TlsHandshake, TlsServerHandshakeError,
SomeOrAllNodes, TlsClientHandshakeError, TlsHandshake, TlsServerHandshakeError,
};
use ic_crypto_tls_interfaces_mocks::MockTlsHandshake;
use ic_interfaces_transport::TransportEvent;
Expand Down Expand Up @@ -181,7 +181,7 @@ fn test_single_transient_failure_of_tls_server_handshake_impl(use_h2: bool) {
.times(1)
.returning({
move |_tcp_stream: TcpStream,
_allowed_clients: AllowedClients,
_allowed_clients: SomeOrAllNodes,
_registry_version: RegistryVersion| {
Err(TlsServerHandshakeError::HandshakeError {
internal_error: "transient".to_string(),
Expand All @@ -194,7 +194,7 @@ fn test_single_transient_failure_of_tls_server_handshake_impl(use_h2: bool) {
.times(1)
.returning(
move |tcp_stream: TcpStream,
allowed_clients: AllowedClients,
allowed_clients: SomeOrAllNodes,
registry_version: RegistryVersion| {
let rt_handle = rt_handle.clone();
let crypto = crypto.clone();
Expand Down

0 comments on commit 7dbe6dc

Please sign in to comment.