From cc8af8cd67d78118e0ea48dc5d1de3adf183e45a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E7=82=8E=E6=B3=BC?= Date: Sun, 8 Jan 2023 15:25:43 +0800 Subject: [PATCH] Fix: last_purged_log_id is not loaded correctly - Fix: `last_purged_log_id` should be `None`, but not `LogId{index=0, ..}` when raft startup with a store with log at index 0. This is fixed by adding another field `next_purge` to distinguish `last_purged_log_id` value `None` and `LogId{index=0, ..}`, because `RaftState.log_ids` stores `LogId` but not `Option`. - Add a wrapper `Valid` of `RaftState` to check if the state is valid every time accessing it. This check is done only when `debug_assertions` is turned on. --- openraft/src/engine/calc_purge_upto_test.rs | 7 +- openraft/src/engine/elect_test.rs | 4 +- openraft/src/engine/engine_impl.rs | 23 +- .../engine/follower_commit_entries_test.rs | 2 + .../engine/follower_do_append_entries_test.rs | 2 + .../engine/handle_append_entries_req_test.rs | 2 + openraft/src/engine/handle_vote_req_test.rs | 2 + openraft/src/engine/handle_vote_resp_test.rs | 5 +- .../src/engine/handler/snapshot_handler.rs | 1 + openraft/src/engine/initialize_test.rs | 4 + openraft/src/engine/install_snapshot_test.rs | 2 + .../engine/internal_handle_vote_req_test.rs | 2 + .../src/engine/leader_append_entries_test.rs | 2 + openraft/src/engine/purge_log_test.rs | 3 + openraft/src/engine/truncate_logs_test.rs | 2 + openraft/src/engine/update_progress_test.rs | 2 + openraft/src/lib.rs | 1 + openraft/src/raft.rs | 2 +- openraft/src/raft_state.rs | 35 +++ openraft/src/raft_state_test.rs | 2 + openraft/src/storage/helper.rs | 4 + openraft/src/storage/mod.rs | 5 + openraft/src/valid/bench/mod.rs | 1 + openraft/src/valid/bench/valid_deref.rs | 32 +++ openraft/src/valid/mod.rs | 8 + openraft/src/valid/valid_impl.rs | 234 ++++++++++++++++++ 26 files changed, 379 insertions(+), 10 deletions(-) create mode 100644 openraft/src/valid/bench/mod.rs create mode 100644 openraft/src/valid/bench/valid_deref.rs create mode 100644 openraft/src/valid/mod.rs create mode 100644 openraft/src/valid/valid_impl.rs diff --git a/openraft/src/engine/calc_purge_upto_test.rs b/openraft/src/engine/calc_purge_upto_test.rs index 2a6f0f120..17258b779 100644 --- a/openraft/src/engine/calc_purge_upto_test.rs +++ b/openraft/src/engine/calc_purge_upto_test.rs @@ -12,6 +12,8 @@ fn log_id(term: u64, index: u64) -> LogId { fn eng() -> Engine { let mut eng = Engine::default(); + eng.state.enable_validate = false; // Disable validation for incomplete state + eng.state.log_ids = LogIdList::new(vec![ // log_id(0, 0), @@ -32,7 +34,7 @@ fn test_calc_purge_upto() -> anyhow::Result<()> { (None, None, 1, None), // (None, Some(log_id(1, 1)), 0, Some(log_id(1, 1))), - (None, Some(log_id(1, 1)), 1, None), + (None, Some(log_id(1, 1)), 1, Some(log_id(0, 0))), (None, Some(log_id(1, 1)), 2, None), // (Some(log_id(0, 0)), Some(log_id(1, 1)), 0, Some(log_id(1, 1))), @@ -43,7 +45,7 @@ fn test_calc_purge_upto() -> anyhow::Result<()> { (None, Some(log_id(3, 4)), 1, Some(log_id(3, 3))), (None, Some(log_id(3, 4)), 2, Some(log_id(1, 2))), (None, Some(log_id(3, 4)), 3, Some(log_id(1, 1))), - (None, Some(log_id(3, 4)), 4, None), + (None, Some(log_id(3, 4)), 4, Some(log_id(0, 0))), (None, Some(log_id(3, 4)), 5, None), // (Some(log_id(1, 2)), Some(log_id(3, 4)), 0, Some(log_id(3, 4))), @@ -61,6 +63,7 @@ fn test_calc_purge_upto() -> anyhow::Result<()> { if let Some(last_purged) = last_purged { eng.state.log_ids.purge(&last_purged); + eng.state.next_purge = last_purged.index + 1; } eng.state.snapshot_meta.last_log_id = snapshot_last_log_id; let got = eng.calc_purge_upto(); diff --git a/openraft/src/engine/elect_test.rs b/openraft/src/engine/elect_test.rs index b96e8f9e2..8917dcc92 100644 --- a/openraft/src/engine/elect_test.rs +++ b/openraft/src/engine/elect_test.rs @@ -31,7 +31,9 @@ fn m12() -> Membership { } fn eng() -> Engine { - Engine::default() + let mut eng = Engine::default(); + eng.state.enable_validate = false; // Disable validation for incomplete state + eng } #[test] diff --git a/openraft/src/engine/engine_impl.rs b/openraft/src/engine/engine_impl.rs index 0d550bd5b..e54f8eb30 100644 --- a/openraft/src/engine/engine_impl.rs +++ b/openraft/src/engine/engine_impl.rs @@ -24,6 +24,7 @@ use crate::raft_state::LogStateReader; use crate::raft_state::RaftState; use crate::raft_types::RaftLogId; use crate::summary::MessageSummary; +use crate::valid::Valid; use crate::LogId; use crate::LogIdOptionExt; use crate::Membership; @@ -106,7 +107,7 @@ where pub(crate) config: EngineConfig, /// The state of this raft node. - pub(crate) state: RaftState, + pub(crate) state: Valid>, /// The internal server state used by Engine. pub(crate) internal_server_state: InternalServerState, @@ -120,10 +121,10 @@ where N: Node, NID: NodeId, { - pub(crate) fn new(init_state: &RaftState, config: EngineConfig) -> Self { + pub(crate) fn new(init_state: RaftState, config: EngineConfig) -> Self { Self { config, - state: init_state.clone(), + state: Valid::new(init_state), internal_server_state: InternalServerState::default(), output: EngineOutput::default(), } @@ -657,7 +658,7 @@ where return; } - st.log_ids.purge(&upto); + st.purge_log(&upto); self.output.push_command(Command::PurgeLog { upto }); } @@ -859,6 +860,11 @@ where tracing::info!("install_snapshot: meta:{:?}", meta); + // TODO: temp solution: committed is updated after snapshot_last_log_id. + // committed should be updated first or together with snapshot_last_log_id(i.e., extract `state` first). + let old_validate = self.state.enable_validate; + self.state.enable_validate = false; + let snap_last_log_id = meta.last_log_id; if snap_last_log_id <= self.state.committed { @@ -868,6 +874,8 @@ where self.state.committed.summary() ); self.output.push_command(Command::CancelSnapshot { snapshot_meta: meta }); + // TODO: temp solution: committed is updated after snapshot_last_log_id. + self.state.enable_validate = old_validate; return; } @@ -877,6 +885,8 @@ where let mut snap_handler = self.snapshot_handler(); let updated = snap_handler.update_snapshot(meta.clone()); if !updated { + // TODO: temp solution: committed is updated after snapshot_last_log_id. + self.state.enable_validate = old_validate; return; } @@ -924,7 +934,10 @@ where // In the second case, if local-last-log-id is smaller than snapshot-last-log-id, // and this node crashes after installing snapshot and before purging logs, // the log will be purged the next start up, in [`RaftState::get_initial_state`]. - self.purge_log(snap_last_log_id) + self.purge_log(snap_last_log_id); + + // TODO: temp solution: committed is updated after snapshot_last_log_id. + self.state.enable_validate = old_validate; } #[tracing::instrument(level = "debug", skip_all)] diff --git a/openraft/src/engine/follower_commit_entries_test.rs b/openraft/src/engine/follower_commit_entries_test.rs index 4ed0df432..7e4091331 100644 --- a/openraft/src/engine/follower_commit_entries_test.rs +++ b/openraft/src/engine/follower_commit_entries_test.rs @@ -41,6 +41,8 @@ fn m23() -> Membership { fn eng() -> Engine { let mut eng = Engine::default(); + eng.state.enable_validate = false; // Disable validation for incomplete state + eng.state.committed = Some(log_id(1, 1)); eng.state.membership_state.committed = Arc::new(EffectiveMembership::new(Some(log_id(1, 1)), m01())); eng.state.membership_state.effective = Arc::new(EffectiveMembership::new(Some(log_id(2, 3)), m23())); diff --git a/openraft/src/engine/follower_do_append_entries_test.rs b/openraft/src/engine/follower_do_append_entries_test.rs index 5e2db508c..62d8715cb 100644 --- a/openraft/src/engine/follower_do_append_entries_test.rs +++ b/openraft/src/engine/follower_do_append_entries_test.rs @@ -51,6 +51,8 @@ fn m45() -> Membership { fn eng() -> Engine { let mut eng = Engine::default(); + eng.state.enable_validate = false; // Disable validation for incomplete state + eng.config.id = 2; eng.state.log_ids.append(log_id(1, 1)); eng.state.log_ids.append(log_id(2, 3)); diff --git a/openraft/src/engine/handle_append_entries_req_test.rs b/openraft/src/engine/handle_append_entries_req_test.rs index 1cd094186..ead5609f7 100644 --- a/openraft/src/engine/handle_append_entries_req_test.rs +++ b/openraft/src/engine/handle_append_entries_req_test.rs @@ -50,6 +50,8 @@ fn m34() -> Membership { fn eng() -> Engine { let mut eng = Engine::default(); + eng.state.enable_validate = false; // Disable validation for incomplete state + eng.config.id = 2; eng.state.vote = Vote::new(2, 1); eng.state.log_ids.append(log_id(1, 1)); diff --git a/openraft/src/engine/handle_vote_req_test.rs b/openraft/src/engine/handle_vote_req_test.rs index c37932971..ce0fec56f 100644 --- a/openraft/src/engine/handle_vote_req_test.rs +++ b/openraft/src/engine/handle_vote_req_test.rs @@ -28,6 +28,8 @@ fn m01() -> Membership { fn eng() -> Engine { let mut eng = Engine::::default(); + eng.state.enable_validate = false; // Disable validation for incomplete state + eng.state.vote = Vote::new(2, 1); eng.state.server_state = ServerState::Candidate; eng.state.membership_state.effective = Arc::new(EffectiveMembership::new(Some(log_id(1, 1)), m01())); diff --git a/openraft/src/engine/handle_vote_resp_test.rs b/openraft/src/engine/handle_vote_resp_test.rs index 09d545114..f27338eb8 100644 --- a/openraft/src/engine/handle_vote_resp_test.rs +++ b/openraft/src/engine/handle_vote_resp_test.rs @@ -32,7 +32,10 @@ fn m1234() -> Membership { } fn eng() -> Engine { - Engine::::default() + let mut eng = Engine::::default(); + eng.state.enable_validate = false; // Disable validation for incomplete state + + eng } #[test] diff --git a/openraft/src/engine/handler/snapshot_handler.rs b/openraft/src/engine/handler/snapshot_handler.rs index e2b8ef14a..21fd5a51b 100644 --- a/openraft/src/engine/handler/snapshot_handler.rs +++ b/openraft/src/engine/handler/snapshot_handler.rs @@ -74,6 +74,7 @@ mod tests { fn eng() -> Engine { let mut eng = Engine:: { ..Default::default() }; + eng.state.enable_validate = false; // Disable validation for incomplete state eng.state.snapshot_meta = SnapshotMeta { last_log_id: Some(log_id(2, 2)), diff --git a/openraft/src/engine/initialize_test.rs b/openraft/src/engine/initialize_test.rs index 6c1f93c99..3aade6361 100644 --- a/openraft/src/engine/initialize_test.rs +++ b/openraft/src/engine/initialize_test.rs @@ -24,6 +24,8 @@ use crate::Vote; fn test_initialize_single_node() -> anyhow::Result<()> { let eng = || { let mut eng = Engine::::default(); + eng.state.enable_validate = false; // Disable validation for incomplete state + eng.state.server_state = eng.calc_server_state(); eng }; @@ -130,6 +132,8 @@ fn test_initialize_single_node() -> anyhow::Result<()> { fn test_initialize() -> anyhow::Result<()> { let eng = || { let mut eng = Engine::::default(); + eng.state.enable_validate = false; // Disable validation for incomplete state + eng.state.server_state = eng.calc_server_state(); eng }; diff --git a/openraft/src/engine/install_snapshot_test.rs b/openraft/src/engine/install_snapshot_test.rs index 2f61d438a..c4fe26467 100644 --- a/openraft/src/engine/install_snapshot_test.rs +++ b/openraft/src/engine/install_snapshot_test.rs @@ -30,6 +30,7 @@ fn m1234() -> Membership { fn eng() -> Engine { let mut eng = Engine:: { ..Default::default() }; + eng.state.enable_validate = false; // Disable validation for incomplete state eng.state.committed = Some(log_id(4, 5)); eng.state.log_ids = LogIdList::new(vec![ @@ -199,6 +200,7 @@ fn test_install_snapshot_conflict() -> anyhow::Result<()> { // And there should be no conflicting logs left. let mut eng = { let mut eng = Engine:: { ..Default::default() }; + eng.state.enable_validate = false; // Disable validation for incomplete state eng.state.committed = Some(log_id(2, 3)); eng.state.log_ids = LogIdList::new(vec![ diff --git a/openraft/src/engine/internal_handle_vote_req_test.rs b/openraft/src/engine/internal_handle_vote_req_test.rs index 2104f997e..25aea9d46 100644 --- a/openraft/src/engine/internal_handle_vote_req_test.rs +++ b/openraft/src/engine/internal_handle_vote_req_test.rs @@ -27,6 +27,8 @@ fn m01() -> Membership { fn eng() -> Engine { let mut eng = Engine::::default(); + eng.state.enable_validate = false; // Disable validation for incomplete state + eng.state.vote = Vote::new(2, 1); eng.state.server_state = ServerState::Candidate; eng.state.membership_state.effective = Arc::new(EffectiveMembership::new(Some(log_id(1, 1)), m01())); diff --git a/openraft/src/engine/leader_append_entries_test.rs b/openraft/src/engine/leader_append_entries_test.rs index d4e29e36f..bc9c961cf 100644 --- a/openraft/src/engine/leader_append_entries_test.rs +++ b/openraft/src/engine/leader_append_entries_test.rs @@ -64,6 +64,8 @@ fn m34() -> Membership { fn eng() -> Engine { let mut eng = Engine::default(); + eng.state.enable_validate = false; // Disable validation for incomplete state + eng.config.id = 1; eng.state.committed = Some(log_id(0, 0)); eng.state.vote = Vote::new_committed(3, 1); diff --git a/openraft/src/engine/purge_log_test.rs b/openraft/src/engine/purge_log_test.rs index 3fd5fff91..fba50ebe1 100644 --- a/openraft/src/engine/purge_log_test.rs +++ b/openraft/src/engine/purge_log_test.rs @@ -14,7 +14,10 @@ fn log_id(term: u64, index: u64) -> LogId { fn eng() -> Engine { let mut eng = Engine::::default(); + eng.state.enable_validate = false; // Disable validation for incomplete state + eng.state.log_ids = LogIdList::new(vec![log_id(2, 2), log_id(4, 4), log_id(4, 6)]); + eng.state.next_purge = 3; eng } diff --git a/openraft/src/engine/truncate_logs_test.rs b/openraft/src/engine/truncate_logs_test.rs index fd6f588c9..59cf27e08 100644 --- a/openraft/src/engine/truncate_logs_test.rs +++ b/openraft/src/engine/truncate_logs_test.rs @@ -35,6 +35,8 @@ fn m23() -> Membership { fn eng() -> Engine { let mut eng = Engine::default(); + eng.state.enable_validate = false; // Disable validation for incomplete state + eng.config.id = 2; eng.state.log_ids = LogIdList::new(vec![ log_id(2, 2), // diff --git a/openraft/src/engine/update_progress_test.rs b/openraft/src/engine/update_progress_test.rs index 20d8ef225..a58b1eb0e 100644 --- a/openraft/src/engine/update_progress_test.rs +++ b/openraft/src/engine/update_progress_test.rs @@ -27,6 +27,8 @@ fn m123() -> Membership { fn eng() -> Engine { let mut eng = Engine::default(); + eng.state.enable_validate = false; // Disable validation for incomplete state + eng.config.id = 2; eng.state.vote = Vote::new_committed(2, 1); eng.state.membership_state.committed = Arc::new(EffectiveMembership::new(Some(log_id(1, 1)), m01())); diff --git a/openraft/src/lib.rs b/openraft/src/lib.rs index 6dc734131..6fa7f6c2f 100644 --- a/openraft/src/lib.rs +++ b/openraft/src/lib.rs @@ -45,6 +45,7 @@ mod runtime; pub mod storage; pub mod testing; pub mod timer; +pub(crate) mod valid; pub mod versioned; #[cfg(test)] mod raft_state_test; diff --git a/openraft/src/raft.rs b/openraft/src/raft.rs index d88cd9cea..702d8813e 100644 --- a/openraft/src/raft.rs +++ b/openraft/src/raft.rs @@ -228,7 +228,7 @@ impl, S: RaftStorage> Raft>, + pub(crate) next_purge: u64, + /// All log ids this node has. pub log_ids: LogIdList, @@ -94,10 +101,33 @@ where } fn last_purged_log_id(&self) -> Option<&LogId> { + if self.next_purge == 0 { + return None; + } self.log_ids.first() } } +impl Validate for RaftState +where + NID: NodeId, + N: Node, +{ + fn validate(&self) -> Result<(), Box> { + if self.next_purge == 0 { + less_equal!(self.log_ids.first().index(), Some(0)); + } else { + equal!(self.next_purge, self.log_ids.first().next_index()); + } + + less_equal!(self.last_purged_log_id(), self.snapshot_last_log_id()); + less_equal!(self.snapshot_last_log_id(), self.committed()); + less_equal!(self.committed(), self.last_log_id()); + + Ok(()) + } +} + impl RaftState where NID: NodeId, @@ -156,4 +186,9 @@ where None } } + + pub(crate) fn purge_log(&mut self, upto: &LogId) { + self.next_purge = upto.index + 1; + self.log_ids.purge(upto); + } } diff --git a/openraft/src/raft_state_test.rs b/openraft/src/raft_state_test.rs index 8159eec23..0f2bd20d9 100644 --- a/openraft/src/raft_state_test.rs +++ b/openraft/src/raft_state_test.rs @@ -100,12 +100,14 @@ fn test_raft_state_last_purged_log_id() -> anyhow::Result<()> { let rs = RaftState:: { log_ids: LogIdList::new(vec![log_id(1, 2)]), + next_purge: 3, ..Default::default() }; assert_eq!(Some(log_id(1, 2)), rs.last_purged_log_id().copied()); let rs = RaftState:: { log_ids: LogIdList::new(vec![log_id(1, 2), log_id(3, 4)]), + next_purge: 3, ..Default::default() }; assert_eq!(Some(log_id(1, 2)), rs.last_purged_log_id().copied()); diff --git a/openraft/src/storage/helper.rs b/openraft/src/storage/helper.rs index 07a60ae59..37b823a0c 100644 --- a/openraft/src/storage/helper.rs +++ b/openraft/src/storage/helper.rs @@ -53,7 +53,10 @@ where last_purged_log_id = last_applied; } + println!("purged: {:?}", last_purged_log_id); + println!("last: {:?}", last_log_id); let log_ids = LogIdList::load_log_ids(last_purged_log_id, last_log_id, self).await?; + println!("log_ids: {:?}", log_ids); let snapshot_meta = self.sto.get_current_snapshot().await?.map(|x| x.meta).unwrap_or_default(); @@ -62,6 +65,7 @@ where // The initial value for `vote` is the minimal possible value. // See: [Conditions for initialization](https://datafuselabs.github.io/openraft/cluster-formation.html#conditions-for-initialization) vote: vote.unwrap_or_default(), + next_purge: last_purged_log_id.next_index(), log_ids, membership_state: mem_state, snapshot_meta, diff --git a/openraft/src/storage/mod.rs b/openraft/src/storage/mod.rs index 2bfc2105c..977e2dbe5 100644 --- a/openraft/src/storage/mod.rs +++ b/openraft/src/storage/mod.rs @@ -69,6 +69,11 @@ where snapshot_id: self.snapshot_id.clone(), } } + + /// Returns a ref to the id of the last log that is included in this snasphot. + pub fn last_log_id(&self) -> Option<&LogId> { + self.last_log_id.as_ref() + } } /// The data associated with the current snapshot. diff --git a/openraft/src/valid/bench/mod.rs b/openraft/src/valid/bench/mod.rs new file mode 100644 index 000000000..a0e6272f6 --- /dev/null +++ b/openraft/src/valid/bench/mod.rs @@ -0,0 +1 @@ +mod valid_deref; diff --git a/openraft/src/valid/bench/valid_deref.rs b/openraft/src/valid/bench/valid_deref.rs new file mode 100644 index 000000000..e2cd6d8ce --- /dev/null +++ b/openraft/src/valid/bench/valid_deref.rs @@ -0,0 +1,32 @@ +extern crate test; +use std::error::Error; + +use maplit::btreeset; +use test::black_box; +use test::Bencher; + +use crate::less_equal; +use crate::quorum::AsJoint; +use crate::quorum::QuorumSet; +use crate::valid::Valid; +use crate::valid::Validate; + +struct Foo { + a: u64, +} + +impl Validate for Foo { + fn validate(&self) -> Result<(), Box> { + less_equal!(self.a, 10); + Ok(()) + } +} + +#[bench] +fn valid_deref(b: &mut Bencher) { + let f = Valid::new(Foo { a: 5 }); + + b.iter(|| { + let _x = black_box(f.a); + }) +} diff --git a/openraft/src/valid/mod.rs b/openraft/src/valid/mod.rs new file mode 100644 index 000000000..fc51c85b6 --- /dev/null +++ b/openraft/src/valid/mod.rs @@ -0,0 +1,8 @@ +#[cfg(feature = "bench")] +#[cfg(test)] +mod bench; + +mod valid_impl; + +pub(crate) use valid_impl::Valid; +pub(crate) use valid_impl::Validate; diff --git a/openraft/src/valid/valid_impl.rs b/openraft/src/valid/valid_impl.rs new file mode 100644 index 000000000..66e4b98ab --- /dev/null +++ b/openraft/src/valid/valid_impl.rs @@ -0,0 +1,234 @@ +use std::error::Error; +use std::fmt::Debug; +use std::fmt::Formatter; +use std::ops::Deref; +use std::ops::DerefMut; + +#[macro_export] +macro_rules! less_equal { + ($a: expr, $b: expr) => {{ + let a = $a; + let b = $b; + if (a <= b) { + // Ok + } else { + Err(::anyerror::AnyError::error(format!( + "expect: {}({:?}) {} {}({:?})", + stringify!($a), + a, + "<=", + stringify!($b), + b, + )))?; + } + }}; +} + +#[macro_export] +macro_rules! equal { + ($a: expr, $b: expr) => {{ + let a = $a; + let b = $b; + if (a == b) { + // Ok + } else { + Err(::anyerror::AnyError::error(format!( + "expect: {}({:?}) {} {}({:?})", + stringify!($a), + a, + "==", + stringify!($b), + b, + )))?; + } + }}; +} + +/// A type that validates its internal state. +/// +/// An example of defining field `a` whose value must not exceed `10`. +/// ```ignore +/// # use std::error::Error; +/// # use openraft::less_equal; +/// struct Foo { a: u64 } +/// impl Validate for Foo { +/// fn validate(&self) -> Result<(), Box> { +/// less_equal!(self.a, 10); +/// Ok(()) +/// } +/// } +/// ``` +pub(crate) trait Validate { + fn validate(&self) -> Result<(), Box>; +} + +/// A wrapper of T that validate the state of T every time accessing it. +/// +/// - It validates the state before accessing it, i.e., if when a invalid state is written to it, it won't panic until +/// next time accessing it. +/// - The validation is turned on only when `debug_assertions` is enabled. +/// +/// An example of defining field `a` whose value must not exceed `10`. +/// ```ignore +/// # use std::error::Error; +/// # use openraft::less_equal; +/// struct Foo { a: u64 } +/// impl Validate for Foo { +/// fn validate(&self) -> Result<(), Box> { +/// less_equal!(self.a, 10); +/// Ok(()) +/// } +/// } +/// +/// let f = Valid::new(Foo { a: 20 }); +/// let _x = f.a; // panic: panicked at 'invalid state: expect: self.a(20) <= 10(10) ... +/// ``` +pub(crate) struct Valid +where T: Validate +{ + pub(crate) enable_validate: bool, + inner: T, +} + +impl Valid { + pub(crate) fn new(inner: T) -> Self { + Self { + enable_validate: true, + inner, + } + } +} + +impl Deref for Valid +where T: Validate +{ + type Target = T; + + fn deref(&self) -> &Self::Target { + #[cfg(debug_assertions)] + if self.enable_validate { + if let Err(e) = self.inner.validate() { + panic!("invalid state: {}", e); + } + } + + &self.inner + } +} + +impl DerefMut for Valid +where T: Validate +{ + fn deref_mut(&mut self) -> &mut Self::Target { + #[cfg(debug_assertions)] + if self.enable_validate { + if let Err(e) = self.inner.validate() { + panic!("invalid state: {}", e); + } + } + + &mut self.inner + } +} + +impl PartialEq for Valid +where T: Validate +{ + fn eq(&self, other: &Self) -> bool { + self.inner.eq(&other.inner) + } +} + +impl Eq for Valid where T: Validate {} + +impl Debug for Valid +where T: Validate +{ + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.inner.fmt(f) + } +} + +impl Clone for Valid +where T: Validate +{ + fn clone(&self) -> Self { + Self { + enable_validate: self.enable_validate, + inner: self.inner.clone(), + } + } +} + +impl Default for Valid +where T: Validate +{ + fn default() -> Self { + Self { + enable_validate: true, + inner: T::default(), + } + } +} + +#[cfg(test)] +mod tests { + use std::error::Error; + + use crate::valid::Valid; + use crate::valid::Validate; + + struct Foo { + a: u64, + } + + impl Validate for Foo { + fn validate(&self) -> Result<(), Box> { + less_equal!(self.a, 10); + Ok(()) + } + } + + #[test] + fn test_validate() { + // panic when reading an invalid state + let res = std::panic::catch_unwind(|| { + let f = Valid::new(Foo { a: 20 }); + let _x = f.a; + }); + tracing::info!("res: {:?}", res); + assert!(res.is_err()); + + // Disable validation + let res = std::panic::catch_unwind(|| { + let mut f = Valid::new(Foo { a: 20 }); + f.enable_validate = false; + let _x = f.a; + }); + tracing::info!("res: {:?}", res); + assert!(res.is_ok()); + + // valid state + let res = std::panic::catch_unwind(|| { + let f = Valid::new(Foo { a: 10 }); + let _x = f.a; + }); + assert!(res.is_ok()); + + // no panic when just becoming invalid + let res = std::panic::catch_unwind(|| { + let mut f = Valid::new(Foo { a: 10 }); + f.a += 3; + }); + assert!(res.is_ok()); + + // panic on next write access + let res = std::panic::catch_unwind(|| { + let mut f = Valid::new(Foo { a: 10 }); + f.a += 3; + f.a += 1; + }); + tracing::info!("res: {:?}", res); + assert!(res.is_err()); + } +}