Skip to content

Commit 02105e1

Browse files
committed
refactor: make the P2P/statesync impl generic over the StateSyncMessage type
1 parent b81a5d8 commit 02105e1

File tree

10 files changed

+60
-52
lines changed

10 files changed

+60
-52
lines changed

rs/interfaces/src/p2p/state_sync.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use ic_types::{
22
artifact::StateSyncArtifactId,
33
chunkable::{Chunk, ChunkId, Chunkable},
4-
state_sync::StateSyncMessage,
54
};
65

76
pub trait StateSyncClient: Send + Sync {
7+
type Message;
8+
89
/// Returns a list of all states available.
910
fn available_states(&self) -> Vec<StateSyncArtifactId>;
1011
/// Initiates new state sync for the specified Id. Returns None if the state should not be synced.
@@ -15,13 +16,13 @@ pub trait StateSyncClient: Send + Sync {
1516
fn start_state_sync(
1617
&self,
1718
id: &StateSyncArtifactId,
18-
) -> Option<Box<dyn Chunkable<StateSyncMessage> + Send + Sync>>;
19+
) -> Option<Box<dyn Chunkable<Self::Message> + Send>>;
1920
/// Returns true if a state sync with the specified Id can be cancelled because a newer state is available.
2021
/// The result of this function is only meaningful the Id refers to a active state sync started with `start_state_sync`.
2122
/// TODO: (NET-1469) In the future this API should be made safer by only allowing the id of a previously initiated state sync.
2223
fn should_cancel(&self, id: &StateSyncArtifactId) -> bool;
2324
/// Get a specific chunk from the specified state.
2425
fn chunk(&self, id: &StateSyncArtifactId, chunk_id: ChunkId) -> Option<Chunk>;
2526
/// Finish a state sync by delivering the `StateSyncMessage` returned in `Chunkable::add_chunks`.
26-
fn deliver_state_sync(&self, msg: StateSyncMessage);
27+
fn deliver_state_sync(&self, msg: Self::Message);
2728
}

rs/p2p/state_sync_manager/src/lib.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ const ADVERT_BROADCAST_INTERVAL: Duration = Duration::from_secs(5);
5151
const ADVERT_BROADCAST_TIMEOUT: Duration =
5252
ADVERT_BROADCAST_INTERVAL.saturating_sub(Duration::from_secs(2));
5353

54-
pub fn build_axum_router(
55-
state_sync: Arc<dyn StateSyncClient>,
54+
pub fn build_axum_router<T: 'static>(
55+
state_sync: Arc<dyn StateSyncClient<Message = T>>,
5656
log: ReplicaLogger,
5757
metrics_registry: &MetricsRegistry,
5858
) -> (
@@ -81,12 +81,12 @@ pub fn build_axum_router(
8181
(app, rx)
8282
}
8383

84-
pub fn start_state_sync_manager(
84+
pub fn start_state_sync_manager<T: Send + 'static>(
8585
log: ReplicaLogger,
8686
metrics: &MetricsRegistry,
8787
rt: &Handle,
8888
transport: Arc<dyn Transport>,
89-
state_sync: Arc<dyn StateSyncClient>,
89+
state_sync: Arc<dyn StateSyncClient<Message = T>>,
9090
advert_receiver: tokio::sync::mpsc::Receiver<(StateSyncArtifactId, NodeId)>,
9191
) -> JoinHandle<()> {
9292
let state_sync_manager_metrics = StateSyncManagerMetrics::new(metrics);
@@ -102,17 +102,17 @@ pub fn start_state_sync_manager(
102102
rt.spawn(manager.run())
103103
}
104104

105-
struct StateSyncManager {
105+
struct StateSyncManager<T> {
106106
log: ReplicaLogger,
107107
rt: Handle,
108108
metrics: StateSyncManagerMetrics,
109109
transport: Arc<dyn Transport>,
110-
state_sync: Arc<dyn StateSyncClient>,
110+
state_sync: Arc<dyn StateSyncClient<Message = T>>,
111111
advert_receiver: tokio::sync::mpsc::Receiver<(StateSyncArtifactId, NodeId)>,
112112
ongoing_state_sync: Option<OngoingStateSyncHandle>,
113113
}
114114

115-
impl StateSyncManager {
115+
impl<T: 'static + Send> StateSyncManager<T> {
116116
async fn run(mut self) {
117117
let mut interval = tokio::time::interval(ADVERT_BROADCAST_INTERVAL);
118118
let mut advertise_task = JoinSet::new();
@@ -186,7 +186,7 @@ impl StateSyncManager {
186186

187187
async fn send_state_adverts(
188188
rt: Handle,
189-
state_sync: Arc<dyn StateSyncClient>,
189+
state_sync: Arc<dyn StateSyncClient<Message = T>>,
190190
transport: Arc<dyn Transport>,
191191
metrics: StateSyncManagerMetrics,
192192
) {

rs/p2p/state_sync_manager/src/metrics.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use ic_metrics::{
22
buckets::decimal_buckets, tokio_metrics_collector::TokioTaskMetricsCollector, MetricsRegistry,
33
};
4-
use ic_types::state_sync::StateSyncMessage;
54
use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
65
use tokio_metrics::TaskMonitor;
76

@@ -119,10 +118,7 @@ impl OngoingStateSyncMetrics {
119118
}
120119

121120
/// Utility to record metrics for download result.
122-
pub fn record_chunk_download_result(
123-
&self,
124-
res: &Result<Option<StateSyncMessage>, DownloadChunkError>,
125-
) {
121+
pub fn record_chunk_download_result<T>(&self, res: &Result<Option<T>, DownloadChunkError>) {
126122
match res {
127123
// Received chunk
128124
Ok(Some(_)) => {

rs/p2p/state_sync_manager/src/ongoing.rs

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ use ic_types::{
2727
artifact::StateSyncArtifactId,
2828
chunkable::ChunkId,
2929
chunkable::{ArtifactErrorCode, Chunkable},
30-
state_sync::StateSyncMessage,
3130
NodeId,
3231
};
3332
use rand::{
@@ -53,7 +52,7 @@ const CHUNK_DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(10);
5352
/// to last for 1000 blocks (two checkpoint intervals) -> 1000b/0.1b/s = 10000s
5453
const STATE_SYNC_TIMEOUT: Duration = Duration::from_secs(10000);
5554

56-
struct OngoingStateSync {
55+
struct OngoingStateSync<T: Send> {
5756
log: ReplicaLogger,
5857
rt: Handle,
5958
artifact_id: StateSyncArtifactId,
@@ -67,10 +66,10 @@ struct OngoingStateSync {
6766
allowed_downloads: usize,
6867
chunks_to_download: Box<dyn Iterator<Item = ChunkId> + Send>,
6968
// Event tasks
70-
downloading_chunks: JoinMap<ChunkId, DownloadResult>,
69+
downloading_chunks: JoinMap<ChunkId, DownloadResult<T>>,
7170
// State sync
72-
state_sync: Arc<dyn StateSyncClient>,
73-
tracker: Arc<Mutex<Box<dyn Chunkable<StateSyncMessage> + Send + Sync>>>,
71+
state_sync: Arc<dyn StateSyncClient<Message = T>>,
72+
tracker: Arc<Mutex<Box<dyn Chunkable<T> + Send>>>,
7473
state_sync_finished: bool,
7574
}
7675

@@ -79,18 +78,18 @@ pub(crate) struct OngoingStateSyncHandle {
7978
pub jh: JoinHandle<()>,
8079
}
8180

82-
pub(crate) struct DownloadResult {
81+
pub(crate) struct DownloadResult<T> {
8382
peer_id: NodeId,
84-
result: Result<Option<StateSyncMessage>, DownloadChunkError>,
83+
result: Result<Option<T>, DownloadChunkError>,
8584
}
8685

87-
pub(crate) fn start_ongoing_state_sync(
86+
pub(crate) fn start_ongoing_state_sync<T: Send + 'static>(
8887
log: ReplicaLogger,
8988
rt: &Handle,
9089
metrics: OngoingStateSyncMetrics,
91-
tracker: Arc<Mutex<Box<dyn Chunkable<StateSyncMessage> + Send + Sync>>>,
90+
tracker: Arc<Mutex<Box<dyn Chunkable<T> + Send>>>,
9291
artifact_id: StateSyncArtifactId,
93-
state_sync: Arc<dyn StateSyncClient>,
92+
state_sync: Arc<dyn StateSyncClient<Message = T>>,
9493
transport: Arc<dyn Transport>,
9594
) -> OngoingStateSyncHandle {
9695
let (new_peers_tx, new_peers_rx) = tokio::sync::mpsc::channel(ONGOING_STATE_SYNC_CHANNEL_SIZE);
@@ -117,7 +116,7 @@ pub(crate) fn start_ongoing_state_sync(
117116
}
118117
}
119118

120-
impl OngoingStateSync {
119+
impl<T: 'static + Send> OngoingStateSync<T> {
121120
pub async fn run(mut self) {
122121
let state_sync_timeout = tokio::time::sleep(STATE_SYNC_TIMEOUT);
123122
tokio::pin!(state_sync_timeout);
@@ -193,7 +192,7 @@ impl OngoingStateSync {
193192

194193
async fn handle_downloaded_chunk_result(
195194
&mut self,
196-
DownloadResult { peer_id, result }: DownloadResult,
195+
DownloadResult { peer_id, result }: DownloadResult<T>,
197196
) {
198197
self.metrics.record_chunk_download_result(&result);
199198
match result {
@@ -296,11 +295,11 @@ impl OngoingStateSync {
296295
async fn download_chunk_task(
297296
peer_id: NodeId,
298297
client: Arc<dyn Transport>,
299-
tracker: Arc<Mutex<Box<dyn Chunkable<StateSyncMessage> + Send + Sync>>>,
298+
tracker: Arc<Mutex<Box<dyn Chunkable<T> + Send>>>,
300299
artifact_id: StateSyncArtifactId,
301300
chunk_id: ChunkId,
302301
metrics: OngoingStateSyncMetrics,
303-
) -> DownloadResult {
302+
) -> DownloadResult<T> {
304303
let _timer = metrics.chunk_download_duration.start_timer();
305304

306305
let response_result = tokio::time::timeout(
@@ -386,6 +385,9 @@ mod tests {
386385

387386
use super::*;
388387

388+
#[derive(Clone)]
389+
struct TestMessage;
390+
389391
fn compress_empty_bytes() -> Bytes {
390392
let mut raw = BytesMut::new();
391393
Bytes::new()
@@ -409,7 +411,7 @@ mod tests {
409411
.body(compress_empty_bytes())
410412
.unwrap())
411413
});
412-
let mut c = MockChunkable::<StateSyncMessage>::default();
414+
let mut c = MockChunkable::<TestMessage>::default();
413415
c.expect_chunks_to_download()
414416
.returning(|| Box::new(std::iter::once(ChunkId::from(1))));
415417

@@ -448,7 +450,7 @@ mod tests {
448450
.body(compress_empty_bytes())
449451
.unwrap())
450452
});
451-
let mut c = MockChunkable::<StateSyncMessage>::default();
453+
let mut c = MockChunkable::<TestMessage>::default();
452454
c.expect_chunks_to_download()
453455
.returning(|| Box::new(std::iter::once(ChunkId::from(1))));
454456
c.expect_add_chunk()
@@ -494,7 +496,7 @@ mod tests {
494496
.body(compress_empty_bytes())
495497
.unwrap())
496498
});
497-
let mut c = MockChunkable::<StateSyncMessage>::default();
499+
let mut c = MockChunkable::<TestMessage>::default();
498500
// Endless iterator
499501
c.expect_chunks_to_download()
500502
.returning(|| Box::new(std::iter::once(ChunkId::from(1))));

rs/p2p/state_sync_manager/src/routes/chunk.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,16 @@ pub const STATE_SYNC_CHUNK_PATH: &str = "/state-sync/chunk";
2222
/// State sync uses 1Mb chunks. To be safe we use 8Mib here same as transport.
2323
const MAX_CHUNK_SIZE: usize = 8 * 1024 * 1024;
2424

25-
pub(crate) struct StateSyncChunkHandler {
25+
pub(crate) struct StateSyncChunkHandler<T> {
2626
_log: ReplicaLogger,
27-
state_sync: Arc<dyn StateSyncClient>,
27+
state_sync: Arc<dyn StateSyncClient<Message = T>>,
2828
metrics: StateSyncManagerHandlerMetrics,
2929
}
3030

31-
impl StateSyncChunkHandler {
31+
impl<T> StateSyncChunkHandler<T> {
3232
pub fn new(
3333
log: ReplicaLogger,
34-
state_sync: Arc<dyn StateSyncClient>,
34+
state_sync: Arc<dyn StateSyncClient<Message = T>>,
3535
metrics: StateSyncManagerHandlerMetrics,
3636
) -> Self {
3737
Self {
@@ -42,8 +42,8 @@ impl StateSyncChunkHandler {
4242
}
4343
}
4444

45-
pub(crate) async fn state_sync_chunk_handler(
46-
State(state): State<Arc<StateSyncChunkHandler>>,
45+
pub(crate) async fn state_sync_chunk_handler<T: 'static>(
46+
State(state): State<Arc<StateSyncChunkHandler<T>>>,
4747
payload: Bytes,
4848
) -> Result<Bytes, StatusCode> {
4949
// Parse payload

rs/p2p/state_sync_manager/tests/common.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ impl FakeStateSync {
183183
}
184184

185185
impl StateSyncClient for FakeStateSync {
186+
type Message = StateSyncMessage;
187+
186188
fn available_states(&self) -> Vec<StateSyncArtifactId> {
187189
if self.disconnected.load(Ordering::SeqCst) {
188190
return vec![];
@@ -197,7 +199,7 @@ impl StateSyncClient for FakeStateSync {
197199
fn start_state_sync(
198200
&self,
199201
id: &StateSyncArtifactId,
200-
) -> Option<Box<dyn Chunkable<StateSyncMessage> + Send + Sync>> {
202+
) -> Option<Box<dyn Chunkable<StateSyncMessage> + Send>> {
201203
if !self.uses_global() && id.height > self.current_height() && !self.disconnected() {
202204
return Some(Box::new(FakeChunkable::new(
203205
self.local_state.clone(),
@@ -352,7 +354,7 @@ impl Chunkable<StateSyncMessage> for SharableMockChunkable {
352354

353355
#[derive(Clone, Default)]
354356
pub struct SharableMockStateSync {
355-
mock: Arc<Mutex<MockStateSync>>,
357+
mock: Arc<Mutex<MockStateSync<StateSyncMessage>>>,
356358
available_states_calls: Arc<AtomicUsize>,
357359
start_state_sync_calls: Arc<AtomicUsize>,
358360
should_cancel_calls: Arc<AtomicUsize>,
@@ -367,7 +369,7 @@ impl SharableMockStateSync {
367369
..Default::default()
368370
}
369371
}
370-
pub fn get_mut(&self) -> MutexGuard<'_, MockStateSync> {
372+
pub fn get_mut(&self) -> MutexGuard<'_, MockStateSync<StateSyncMessage>> {
371373
self.mock.lock().unwrap()
372374
}
373375
pub fn start_state_sync_calls(&self) -> usize {
@@ -383,14 +385,16 @@ impl SharableMockStateSync {
383385
}
384386

385387
impl StateSyncClient for SharableMockStateSync {
388+
type Message = StateSyncMessage;
389+
386390
fn available_states(&self) -> Vec<StateSyncArtifactId> {
387391
self.available_states_calls.fetch_add(1, Ordering::SeqCst);
388392
self.mock.lock().unwrap().available_states()
389393
}
390394
fn start_state_sync(
391395
&self,
392396
id: &StateSyncArtifactId,
393-
) -> Option<Box<dyn Chunkable<StateSyncMessage> + Send + Sync>> {
397+
) -> Option<Box<dyn Chunkable<StateSyncMessage> + Send>> {
394398
self.start_state_sync_calls.fetch_add(1, Ordering::SeqCst);
395399
self.mock.lock().unwrap().start_state_sync(id)
396400
}

rs/p2p/test_utils/src/mocks.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,28 @@ use ic_types::{
1212
artifact::StateSyncArtifactId,
1313
chunkable::{ArtifactErrorCode, Chunkable},
1414
chunkable::{Chunk, ChunkId},
15-
state_sync::StateSyncMessage,
1615
NodeId,
1716
};
1817
use mockall::mock;
1918

2019
mock! {
21-
pub StateSync {}
20+
pub StateSync<T: Send> {}
21+
22+
impl<T: Send + Sync> StateSyncClient for StateSync<T> {
23+
type Message = T;
2224

23-
impl StateSyncClient for StateSync {
2425
fn available_states(&self) -> Vec<StateSyncArtifactId>;
2526

2627
fn start_state_sync(
2728
&self,
2829
id: &StateSyncArtifactId,
29-
) -> Option<Box<dyn Chunkable<StateSyncMessage> + Send + Sync>>;
30+
) -> Option<Box<dyn Chunkable<T> + Send>>;
3031

3132
fn should_cancel(&self, id: &StateSyncArtifactId) -> bool;
3233

3334
fn chunk(&self, id: &StateSyncArtifactId, chunk_id: ChunkId) -> Option<Chunk>;
3435

35-
fn deliver_state_sync(&self, msg: StateSyncMessage);
36+
fn deliver_state_sync(&self, msg: T);
3637
}
3738
}
3839

rs/p2p/test_utils/src/turmoil.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use ic_logger::ReplicaLogger;
2727
use ic_metrics::MetricsRegistry;
2828
use ic_peer_manager::SubnetTopology;
2929
use ic_quic_transport::{QuicTransport, Transport};
30+
use ic_types::state_sync::StateSyncMessage;
3031
use ic_types::{artifact::UnvalidatedArtifactMutation, NodeId, RegistryVersion};
3132
use ic_types_test_utils::ids::SUBNET_1;
3233
use quinn::{
@@ -271,7 +272,7 @@ pub fn add_transport_to_sim<F>(
271272
conn_checker: Option<Router>,
272273
crypto: Option<Arc<dyn TlsConfig + Send + Sync>>,
273274
sev: Option<Arc<dyn ValidateAttestedStream<Box<dyn TlsStream>> + Send + Sync>>,
274-
state_sync_client: Option<Arc<dyn StateSyncClient>>,
275+
state_sync_client: Option<Arc<dyn StateSyncClient<Message = StateSyncMessage>>>,
275276
consensus_manager: Option<TestConsensus<U64Artifact>>,
276277
post_setup_future: F,
277278
) where

rs/replica/setup_ic_network/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ use ic_types::{
6565
malicious_flags::MaliciousFlags,
6666
p2p::GossipAdvert,
6767
replica_config::ReplicaConfig,
68+
state_sync::StateSyncMessage,
6869
NodeId, SubnetId,
6970
};
7071
use std::{
@@ -147,7 +148,7 @@ pub fn setup_consensus_and_p2p(
147148
state_reader: Arc<dyn StateReader<State = ReplicatedState>>,
148149
consensus_pool: Arc<RwLock<ConsensusPoolImpl>>,
149150
catch_up_package: CatchUpPackage,
150-
state_sync_client: Arc<dyn StateSyncClient>,
151+
state_sync_client: Arc<dyn StateSyncClient<Message = StateSyncMessage>>,
151152
xnet_payload_builder: Arc<dyn XNetPayloadBuilder>,
152153
self_validating_payload_builder: Arc<dyn SelfValidatingPayloadBuilder>,
153154
query_stats_payload_builder: Box<dyn BatchPayloadBuilder>,

0 commit comments

Comments
 (0)