diff --git a/openraft/src/engine/handler/following_handler/mod.rs b/openraft/src/engine/handler/following_handler/mod.rs index aef56dbfb..b8b6d0540 100644 --- a/openraft/src/engine/handler/following_handler/mod.rs +++ b/openraft/src/engine/handler/following_handler/mod.rs @@ -31,6 +31,7 @@ use crate::StoredMembership; #[cfg(test)] mod commit_entries_test; #[cfg(test)] mod do_append_entries_test; #[cfg(test)] mod install_snapshot_test; +#[cfg(test)] mod receive_snapshot_chunk_test; #[cfg(test)] mod truncate_logs_test; #[cfg(test)] mod update_committed_membership_test; @@ -248,10 +249,9 @@ where C: RaftTypeConfig &mut self, req: InstallSnapshotRequest, ) -> Result<(), InstallSnapshotError> { - // TODO: add unit test tracing::info!(req = display(req.summary()), "{}", func_name!()); - let snapshot_meta = &req.meta; + let snapshot_id = &req.meta.snapshot_id; let curr_id = self.state.snapshot_streaming.as_ref().map(|s| &s.snapshot_id); @@ -260,11 +260,11 @@ where C: RaftTypeConfig if req.offset > 0 { let mismatch = SnapshotMismatch { expect: SnapshotSegmentId { - id: snapshot_meta.snapshot_id.clone(), + id: snapshot_id.clone(), offset: 0, }, got: SnapshotSegmentId { - id: snapshot_meta.snapshot_id.clone(), + id: snapshot_id.clone(), offset: req.offset, }, }; @@ -276,7 +276,7 @@ where C: RaftTypeConfig if req.offset == 0 { self.state.snapshot_streaming = Some(StreamingState { offset: 0, - snapshot_id: req.meta.snapshot_id.clone(), + snapshot_id: snapshot_id.clone(), }); } } diff --git a/openraft/src/engine/handler/following_handler/receive_snapshot_chunk_test.rs b/openraft/src/engine/handler/following_handler/receive_snapshot_chunk_test.rs new file mode 100644 index 000000000..a6736f3b5 --- /dev/null +++ b/openraft/src/engine/handler/following_handler/receive_snapshot_chunk_test.rs @@ -0,0 +1,164 @@ +use maplit::btreeset; +use pretty_assertions::assert_eq; + +use crate::core::sm; +use crate::engine::testing::UTConfig; +use crate::engine::Command; +use crate::engine::Engine; +use crate::error::InstallSnapshotError; +use crate::error::SnapshotMismatch; +use crate::raft::InstallSnapshotRequest; +use crate::raft_state::StreamingState; +use crate::testing::log_id1; +use crate::Membership; +use crate::SnapshotMeta; +use crate::SnapshotSegmentId; +use crate::StoredMembership; +use crate::Vote; + +fn m1234() -> Membership { + Membership::::new(vec![btreeset! {1,2,3,4}], None) +} + +fn eng() -> Engine { + let mut eng = Engine::default(); + eng.state.enable_validate = false; // Disable validation for incomplete state + + eng.state.vote.update(*eng.timer.now(), Vote::new_committed(2, 1)); + eng.state.server_state = eng.calc_server_state(); + + eng +} + +fn make_meta() -> SnapshotMeta { + SnapshotMeta { + last_log_id: Some(log_id1(2, 2)), + last_membership: StoredMembership::new(Some(log_id1(1, 1)), m1234()), + snapshot_id: "1-2-3-4".to_string(), + } +} + +fn make_req(offset: u64) -> InstallSnapshotRequest { + InstallSnapshotRequest { + vote: Vote::new_committed(2, 1), + meta: make_meta(), + offset, + data: vec![], + done: false, + } +} + +#[test] +fn test_receive_snapshot_chunk_new_chunk() -> anyhow::Result<()> { + let mut eng = eng(); + assert!(eng.state.snapshot_streaming.is_none()); + + eng.following_handler().receive_snapshot_chunk(make_req(0))?; + + assert_eq!( + Some(StreamingState { + offset: 0, + snapshot_id: "1-2-3-4".to_string(), + }), + eng.state.snapshot_streaming + ); + assert_eq!( + vec![Command::from(sm::Command::receive(make_req(0)).with_seq(1))], + eng.output.take_commands() + ); + + Ok(()) +} + +#[test] +fn test_receive_snapshot_chunk_continue_receive_chunk() -> anyhow::Result<()> { + let mut eng = eng(); + + eng.state.snapshot_streaming = Some(StreamingState { + offset: 0, + snapshot_id: "1-2-3-4".to_string(), + }); + + eng.following_handler().receive_snapshot_chunk(make_req(2))?; + + assert_eq!( + Some(StreamingState { + offset: 2, + snapshot_id: "1-2-3-4".to_string(), + }), + eng.state.snapshot_streaming + ); + assert_eq!( + vec![Command::from(sm::Command::receive(make_req(2)).with_seq(1))], + eng.output.take_commands() + ); + + Ok(()) +} + +#[test] +fn test_receive_snapshot_chunk_diff_id_offset_0() -> anyhow::Result<()> { + // When receiving a chunk with different snapshot id and offset 0, starts a new snapshot streaming. + let mut eng = eng(); + + eng.state.snapshot_streaming = Some(StreamingState { + offset: 2, + snapshot_id: "1-2-3-100".to_string(), + }); + + eng.following_handler().receive_snapshot_chunk(make_req(0))?; + + assert_eq!( + Some(StreamingState { + offset: 0, + snapshot_id: "1-2-3-4".to_string(), + }), + eng.state.snapshot_streaming + ); + assert_eq!( + vec![Command::from(sm::Command::receive(make_req(0)).with_seq(1))], + eng.output.take_commands() + ); + + Ok(()) +} + +#[test] +fn test_receive_snapshot_chunk_diff_id_offset_gt_0() -> anyhow::Result<()> { + // When receiving a chunk with different snapshot id and offset that is greater than 0, return an + // error. + let mut eng = eng(); + + eng.state.snapshot_streaming = Some(StreamingState { + offset: 2, + snapshot_id: "1-2-3-100".to_string(), + }); + + let res = eng.following_handler().receive_snapshot_chunk(make_req(3)); + + assert_eq!( + Err(InstallSnapshotError::from(SnapshotMismatch { + expect: SnapshotSegmentId { + id: "1-2-3-4".to_string(), + offset: 0 + }, + got: SnapshotSegmentId { + id: "1-2-3-4".to_string(), + offset: 3 + }, + })), + res + ); + + assert_eq!( + Some(StreamingState { + offset: 2, + snapshot_id: "1-2-3-100".to_string(), + }), + eng.state.snapshot_streaming, + "streaming state not changed" + ); + assert_eq!(true, eng.output.take_commands().is_empty()); + + Ok(()) +} diff --git a/openraft/src/lib.rs b/openraft/src/lib.rs index 29fb21d48..47c305a5c 100644 --- a/openraft/src/lib.rs +++ b/openraft/src/lib.rs @@ -69,7 +69,7 @@ pub(crate) mod validate; mod display_ext; mod internal_server_state; mod leader; -mod raft_state; +pub(crate) mod raft_state; mod runtime; mod try_as_ref; diff --git a/openraft/src/raft_state/mod.rs b/openraft/src/raft_state/mod.rs index 017da97e6..3a4da552b 100644 --- a/openraft/src/raft_state/mod.rs +++ b/openraft/src/raft_state/mod.rs @@ -23,7 +23,7 @@ mod accepted; pub(crate) mod io_state; mod log_state_reader; mod membership_state; -mod snapshot_streaming; +pub(crate) mod snapshot_streaming; mod vote_state_reader; pub(crate) use io_state::IOState;