From 76267e5a2b5c102b04eda15607ccebc703eeba30 Mon Sep 17 00:00:00 2001
From: Marco Primi
Date: Wed, 4 Oct 2023 16:07:46 -0700
Subject: [PATCH] Add RAFT tests

---
 server/raft_helpers_test.go | 298 ++++++++++++++++++++++++++
 server/raft_test.go         | 410 ++++++++++++++++++++++++++++++++++++
 2 files changed, 708 insertions(+)

diff --git a/server/raft_helpers_test.go b/server/raft_helpers_test.go
index e7a7ef881c..bff2c202ab 100644
--- a/server/raft_helpers_test.go
+++ b/server/raft_helpers_test.go
@@ -17,8 +17,12 @@
 package server
 
 import (
+	"encoding"
 	"encoding/binary"
+	"encoding/json"
 	"fmt"
+	"hash"
+	"hash/crc32"
 	"math/rand"
 	"sync"
 	"testing"
@@ -274,3 +278,297 @@ func (rg smGroup) waitOnTotal(t *testing.T, expected int64) {
 func newStateAdder(s *Server, cfg *RaftConfig, n RaftNode) stateMachine {
 	return &stateAdder{s: s, n: n, cfg: cfg}
 }
+
+// RaftChainOptions holds the package-level settings shared by all chain state machines.
+var RaftChainOptions = struct {
+	verbose      bool // when true, logDebug prints to stdout
+	maxBlockSize int  // largest payload, in bytes, generated by proposeBlock
+}{
+	false,
+	25,
+}
+
+// raftChainStateMachine is a simple replicated state machine on top of RAFT.
+// Each delivered block is hashed on top of the existing running hash, so
+// all replicas should go through the same sequence of block hashes.
+type raftChainStateMachine struct {
+	sync.Mutex
+	s                          *Server
+	n                          RaftNode
+	cfg                        *RaftConfig
+	leader                     bool
+	proposalSequence           uint64
+	rng                        *rand.Rand
+	hash                       hash.Hash
+	blocksApplied              uint64
+	blocksAppliedSinceSnapshot uint64
+}
+
+// ChainBlock is just a random array of bytes, plus a little extra metadata to track its source.
+type ChainBlock struct {
+	Proposer         string
+	ProposerSequence uint64
+	Data             []byte
+}
+
+// logDebug prints a message prefixed with the server name and node ID, but only in verbose mode.
+func (sm *raftChainStateMachine) logDebug(format string, args ...any) {
+	if RaftChainOptions.verbose {
+		fmt.Printf("["+sm.s.Name()+" ("+sm.n.ID()+")] "+format+"\n", args...)
+	}
+}
+
+// server returns the server associated with this state machine.
+func (sm *raftChainStateMachine) server() *Server {
+	return sm.s
+}
+
+// node returns the RAFT node currently backing this state machine.
+func (sm *raftChainStateMachine) node() RaftNode {
+	return sm.n
+}
+
+// propose forwards data to the RAFT node as a proposal; an error is only logged in verbose mode.
+func (sm *raftChainStateMachine) propose(data []byte) {
+	sm.Lock()
+	defer sm.Unlock()
+	err := sm.n.ForwardProposal(data)
+	if err != nil {
+		sm.logDebug("block proposal error: %s", err)
+	}
+}
+
+// applyEntry applies every entry of a committed batch (blocks and snapshots) and then
+// reports the batch index to the RAFT node as applied. It panics on an unknown entry type.
+func (sm *raftChainStateMachine) applyEntry(ce *CommittedEntry) {
+	sm.Lock()
+	defer sm.Unlock()
+	if ce == nil {
+		// Nothing to apply
+		return
+	}
+	sm.logDebug("Apply entries #%d (%d entries)", ce.Index, len(ce.Entries))
+	for _, entry := range ce.Entries {
+		if entry.Type == EntryNormal {
+			sm.applyBlock(entry.Data)
+		} else if entry.Type == EntrySnapshot {
+			sm.loadSnapshot(entry.Data)
+		} else {
+			panic(fmt.Sprintf("[%s] unknown entry type: %s", sm.s.Name(), entry.Type))
+		}
+	}
+	sm.n.Applied(ce.Index)
+}
+
+// leaderChange records the new leadership status, logging the transition in verbose mode.
+func (sm *raftChainStateMachine) leaderChange(isLeader bool) {
+	if sm.leader && !isLeader {
+		sm.logDebug("Leader change: no longer leader")
+	} else if sm.leader && isLeader {
+		sm.logDebug("Elected leader while already leader")
+	} else if !sm.leader && isLeader {
+		sm.logDebug("Leader change: i am leader")
+	} else {
+		sm.logDebug("Leader change")
+	}
+	sm.leader = isLeader
+}
+
+// stop stops the RAFT node and wipes the in-memory chain state, which restart
+// is expected to rebuild from a snapshot or from peers.
+func (sm *raftChainStateMachine) stop() {
+	sm.Lock()
+	defer sm.Unlock()
+	sm.n.Stop()
+
+	// Clear state (including the since-snapshot counter, so an empty state is never snapshotted)
+	sm.blocksApplied, sm.blocksAppliedSinceSnapshot = 0, 0
+	sm.hash.Reset()
+	sm.logDebug("Stopped")
+}
+
+// restart returns early unless the RAFT node is closed; otherwise it recreates the log store from
+// the old one's configuration, starts a new RAFT node and spawns a new smLoop driver goroutine.
+func (sm *raftChainStateMachine) restart() {
+	sm.Lock()
+	defer sm.Unlock()
+
+	sm.logDebug("Restarting")
+
+	if sm.n.State() != Closed {
+		return
+	}
+
+	// The filestore is stopped as well, so need to extract the parts to recreate it.
+	rn := sm.n.(*raft)
+	fs := rn.wal.(*fileStore)
+
+	var err error
+	sm.cfg.Log, err = newFileStore(fs.fcfg, fs.cfg.StreamConfig)
+	if err != nil {
+		panic(err)
+	}
+	sm.n, err = sm.s.startRaftNode(globalAccountName, sm.cfg, pprofLabels{})
+	if err != nil {
+		panic(err)
+	}
+	// Finally restart the driver.
+	go smLoop(sm)
+}
+
+// proposeBlock builds a block with a random payload of 1 to maxBlockSize bytes and proposes it.
+func (sm *raftChainStateMachine) proposeBlock() {
+	// Track how many blocks this replica proposed
+	sm.proposalSequence += 1
+	// Create a block
+	block := ChainBlock{
+		Proposer:         sm.s.Name(),
+		ProposerSequence: sm.proposalSequence,
+		Data:             make([]byte, sm.rng.Intn(RaftChainOptions.maxBlockSize)+1),
+	}
+	// Data is random bytes
+	sm.rng.Read(block.Data)
+	// Serialize as JSON
+	blockData, err := json.Marshal(block)
+	if err != nil {
+		panic(fmt.Sprintf("serialization error: %s", err))
+	}
+	sm.logDebug(
+		"Proposing block <%s, %d, [%dB]>",
+		block.Proposer,
+		block.ProposerSequence,
+		len(block.Data),
+	)
+
+	// Propose (may silently fail if this replica is not leader, or other reasons)
+	sm.propose(blockData)
+}
+
+// applyBlock decodes a JSON block and folds its payload into the running hash.
+func (sm *raftChainStateMachine) applyBlock(data []byte) {
+	// Deserialize block received in JSON format
+	var block ChainBlock
+	err := json.Unmarshal(data, &block)
+	if err != nil {
+		panic(fmt.Sprintf("deserialization error: %s", err))
+	}
+	sm.logDebug("Applying block <%s, %d>", block.Proposer, block.ProposerSequence)
+
+	// Hash the data on top of the existing running hash (report a write error before a short write)
+	n, err := sm.hash.Write(block.Data)
+	if err != nil {
+		panic(fmt.Sprintf("checksum error: %s", err))
+	} else if n != len(block.Data) {
+		panic(fmt.Sprintf("unexpected checksum written %d data block size: %d", n, len(block.Data)))
+	}
+
+	// Track block number
+	sm.blocksApplied += 1
+	sm.blocksAppliedSinceSnapshot += 1
+
+	sm.logDebug("Hash after %d blocks: %X ", sm.blocksApplied, sm.hash.Sum(nil))
+}
+
+// getCurrentHash returns the number of blocks applied and the running hash as upper-case hex.
+func (sm *raftChainStateMachine) getCurrentHash() (uint64, string) {
+	sm.Lock()
+	defer sm.Unlock()
+
+	// Return the number of blocks applied and the current running hash
+	return sm.blocksApplied, fmt.Sprintf("%X", sm.hash.Sum(nil))
+}
+
+// chainHashSnapshot is the JSON payload of a snapshot: the marshaled hash state plus the block count.
+type chainHashSnapshot struct {
+	SourceNode  string
+	HashData    []byte
+	BlocksCount uint64
+}
+
+// snapshot installs a snapshot of the current hash state and block count on the RAFT node.
+// It does nothing when no block was applied since the last snapshot, and panics on failure.
+func (sm *raftChainStateMachine) snapshot() {
+	sm.Lock()
+	defer sm.Unlock()
+
+	if sm.blocksAppliedSinceSnapshot == 0 {
+		sm.logDebug("Skip snapshot, no new entries")
+		return
+	}
+
+	sm.logDebug("Snapshot (with %d blocks applied)", sm.blocksApplied)
+
+	// Serialize the internal state of the hash block
+	serializedHash, err := sm.hash.(encoding.BinaryMarshaler).MarshalBinary()
+	if err != nil {
+		panic(fmt.Sprintf("failed to marshal hash: %s", err))
+	}
+
+	// Create snapshot
+	snapshot := chainHashSnapshot{
+		SourceNode:  fmt.Sprintf("%s (%s)", sm.s.Name(), sm.n.ID()),
+		HashData:    serializedHash,
+		BlocksCount: sm.blocksApplied,
+	}
+
+	// Serialize snapshot as JSON
+	snapshotData, err := json.Marshal(snapshot)
+	if err != nil {
+		panic(fmt.Sprintf("failed to marshal snapshot: %s", err))
+	}
+
+	// Install it as byte array
+	err = sm.n.InstallSnapshot(snapshotData)
+	if err != nil {
+		panic(fmt.Sprintf("failed to snapshot: %s", err))
+	}
+
+	// Reset counter since last snapshot
+	sm.blocksAppliedSinceSnapshot = 0
+}
+
+// loadSnapshot replaces the running hash and the block counter with those stored in a snapshot.
+func (sm *raftChainStateMachine) loadSnapshot(data []byte) {
+	// Deserialize snapshot from JSON
+	var snapshot chainHashSnapshot
+	err := json.Unmarshal(data, &snapshot)
+	if err != nil {
+		panic(fmt.Sprintf("failed to unmarshal snapshot: %s", err))
+	}
+
+	sm.logDebug(
+		"Applying snapshot (created by %s) taken after %d blocks",
+		snapshot.SourceNode,
+		snapshot.BlocksCount,
+	)
+
+	// Load internal hash block state
+	err = sm.hash.(encoding.BinaryUnmarshaler).UnmarshalBinary(snapshot.HashData)
+	if err != nil {
+		panic(fmt.Sprintf("failed to unmarshal hash data: %s", err))
+	}
+
+	// Load block counter
+	sm.blocksApplied = snapshot.BlocksCount
+	sm.blocksAppliedSinceSnapshot = 0
+
+	sm.logDebug("Hash after snapshot with %d blocks: %X ", sm.blocksApplied, sm.hash.Sum(nil))
+}
+
+// newRaftChainStateMachine creates a raftChainStateMachine on top of the given server/node.
+func newRaftChainStateMachine(s *Server, cfg *RaftConfig, n RaftNode) stateMachine {
+	// Create RNG seed based on server name and node id
+	var seed int64
+	for _, c := range []byte(s.Name()) {
+		seed += int64(c)
+	}
+	for _, c := range []byte(n.ID()) {
+		seed += int64(c)
+	}
+	rng := rand.New(rand.NewSource(seed))
+
+	// Initialize empty hash block
+	hashBlock := crc32.NewIEEE()
+
+	return &raftChainStateMachine{s: s, n: n, cfg: cfg, rng: rng, hash: hashBlock}
+}
diff --git a/server/raft_test.go b/server/raft_test.go
index bddce36fb6..f89abef137 100644
--- a/server/raft_test.go
+++ b/server/raft_test.go
@@ -14,8 +14,10 @@
 package server
 
 import (
+	"fmt"
 	"math"
 	"math/rand"
+	"strings"
 	"testing"
 	"time"
 )
@@ -138,3 +140,411 @@ func TestNRGAppendEntryDecode(t *testing.T) {
 		}
 	}
 }
+
+// TestRaftChainOneBlockInLockstep has the leader propose one block at a time and, after each
+// one, waits for every replica to report the same block count and the same chain hash.
+func TestRaftChainOneBlockInLockstep(t *testing.T) {
+	const iterations = 50
+	const timeout = 15 * time.Second
+	//RaftChainOptions.verbose = true
+
+	c := createJetStreamClusterExplicit(t, "R3S", 3)
+	defer c.shutdown()
+
+	rg := c.createRaftGroup("TEST", 3, newRaftChainStateMachine)
+	rg.waitOnLeader()
+
+	for iteration := uint64(1); iteration <= iterations; iteration++ {
+		rg.leader().(*raftChainStateMachine).proposeBlock()
+
+		// Wait on participants to converge
+		var previousNodeName string
+		var previousNodeHash string
+
+		for _, sm := range rg {
+			stateMachine := sm.(*raftChainStateMachine)
+			nodeName := fmt.Sprintf(
+				"%s/%s",
+				stateMachine.server().Name(),
+				stateMachine.node().ID(),
+			)
+			checkFor(t, timeout, 500*time.Millisecond, func() error {
+				blocksCount, currentHash := stateMachine.getCurrentHash()
+				if blocksCount != iteration {
+					return fmt.Errorf(
+						"node %s applied %d blocks out of %d expected",
+						nodeName,
+						blocksCount,
+						iteration,
+					)
+				}
+				// Make sure hash is not empty
+				if currentHash == "" {
+					return fmt.Errorf(
+						"node %s has empty hash after applying %d blocks",
+						nodeName,
+						blocksCount,
+					)
+				}
+				// Check against previous node hash, unless this is the first node to be checked
+				if previousNodeHash != "" && previousNodeHash != currentHash {
+					return fmt.Errorf(
+						"hash mismatch after %d blocks: %s hash: %s != %s hash: %s",
+						iteration,
+						nodeName,
+						currentHash,
+						previousNodeName,
+						previousNodeHash,
+					)
+				}
+				// Set node name and hash for next node to compare against
+				previousNodeName, previousNodeHash = nodeName, currentHash
+				// All is well
+				return nil
+			})
+		}
+		t.Logf(
+			"Verified chain hash %s for %d/%d nodes after %d/%d iterations",
+			previousNodeHash,
+			len(rg),
+			len(rg),
+			iteration,
+			iterations,
+		)
+	}
+}
+
+// TestRaftChainStopAndCatchUp stops a non-leader, proposes blocks through the leader, restarts
+// the stopped node and then waits for every replica to converge on the same count and hash.
+func TestRaftChainStopAndCatchUp(t *testing.T) {
+	const iterations = 50
+	const blocksPerIteration = 3
+	const timeout = 15 * time.Second
+	//RaftChainOptions.verbose = true
+
+	c := createJetStreamClusterExplicit(t, "R3S", 3)
+	defer c.shutdown()
+
+	rg := c.createRaftGroup("TEST", 3, newRaftChainStateMachine)
+	rg.waitOnLeader()
+
+	for iteration := uint64(1); iteration <= iterations; iteration++ {
+
+		// Stop a (non-leader) node
+		stoppedNode := rg.nonLeader()
+		stoppedNode.stop()
+
+		t.Logf(
+			"Iteration %d/%d: stopping node: %s/%s",
+			iteration,
+			iterations,
+			stoppedNode.server().Name(),
+			stoppedNode.node().ID(),
+		)
+
+		// Propose some new blocks
+		for i := 0; i < blocksPerIteration; i++ {
+			rg.leader().(*raftChainStateMachine).proposeBlock()
+		}
+
+		// Restart the stopped node
+		stoppedNode.restart()
+
+		// Wait on participants to converge
+		var previousNodeName string
+		var previousNodeHash string
+		expectedBlocks := iteration * blocksPerIteration
+		for _, sm := range rg {
+			stateMachine := sm.(*raftChainStateMachine)
+			nodeName := fmt.Sprintf(
+				"%s/%s",
+				stateMachine.server().Name(),
+				stateMachine.node().ID(),
+			)
+			checkFor(t, timeout, 500*time.Millisecond, func() error {
+				blocksCount, currentHash := stateMachine.getCurrentHash()
+				if blocksCount != expectedBlocks {
+					return fmt.Errorf(
+						"node %s applied %d blocks out of %d expected",
+						nodeName,
+						blocksCount,
+						expectedBlocks,
+					)
+				}
+				// Make sure hash is not empty
+				if currentHash == "" {
+					return fmt.Errorf(
+						"node %s has empty hash after applying %d blocks",
+						nodeName,
+						blocksCount,
+					)
+				}
+				// Check against previous node hash, unless this is the first node to be checked
+				if previousNodeHash != "" && previousNodeHash != currentHash {
+					return fmt.Errorf(
+						"hash mismatch after %d blocks: %s hash: %s != %s hash: %s",
+						expectedBlocks,
+						nodeName,
+						currentHash,
+						previousNodeName,
+						previousNodeHash,
+					)
+				}
+				// Set node name and hash for next node to compare against
+				previousNodeName, previousNodeHash = nodeName, currentHash
+				// All is well
+				return nil
+			})
+		}
+		t.Logf(
+			"Verified chain hash %s for %d/%d nodes after %d blocks, %d/%d iterations",
+			previousNodeHash,
+			len(rg),
+			len(rg),
+			expectedBlocks,
+			iteration,
+			iterations,
+		)
+	}
+}
+
+// FuzzRaftChain applies a pseudo-random sequence of operations (stop, restart, snapshot,
+// propose, pause) to a 3-node group and, on each check operation, restarts all stopped nodes
+// and waits for the replicas to converge without losing previously applied blocks.
+func FuzzRaftChain(f *testing.F) {
+	const (
+		groupName               = "FUZZ_TEST_RAFT_CHAIN"
+		numPeers                = 3
+		checkConvergenceTimeout = 30 * time.Second
+	)
+
+	//RaftChainOptions.verbose = true
+
+	// Cases to run when executed as unit test:
+	f.Add(100, int64(123456))
+	f.Add(1000, int64(123456))
+
+	// Run in Fuzz mode to repeat maximizing coverage and looking for failing cases
+	// notice that this test execution is not perfectly deterministic!
+	// The same seed may not fail on retry.
+	f.Fuzz(
+		func(t *testing.T, iterations int, rngSeed int64) {
+			rng := rand.New(rand.NewSource(rngSeed))
+
+			c := createJetStreamClusterExplicit(t, "R3S", numPeers)
+			defer c.shutdown()
+
+			rg := c.createRaftGroup(groupName, numPeers, newRaftChainStateMachine)
+			rg.waitOnLeader()
+
+			// Manually track active and stopped nodes
+			activeNodes := make([]stateMachine, 0, numPeers)
+			stoppedNodes := make([]stateMachine, 0, numPeers)
+
+			// Initially all are active
+			activeNodes = append(activeNodes, rg...)
+
+			// Available operations
+			type RaftFuzzTestOperation string
+
+			const (
+				StopOne       RaftFuzzTestOperation = "Stop one active node"
+				StopAll       RaftFuzzTestOperation = "Stop all active nodes"
+				RestartOne    RaftFuzzTestOperation = "Restart one stopped node"
+				RestartAll    RaftFuzzTestOperation = "Restart all stopped nodes"
+				Snapshot      RaftFuzzTestOperation = "Snapshot one active node"
+				Propose       RaftFuzzTestOperation = "Propose a value via one active node"
+				ProposeLeader RaftFuzzTestOperation = "Propose a value via leader"
+				Pause         RaftFuzzTestOperation = "Let things run undisturbed for a while"
+				Check         RaftFuzzTestOperation = "Wait for nodes to converge"
+			)
+
+			// Weighted distribution of operations, one is randomly chosen from this vector in each iteration
+			opsWeighted := []RaftFuzzTestOperation{
+				StopOne,
+				StopAll,
+				RestartOne,
+				RestartOne,
+				RestartAll,
+				RestartAll,
+				RestartAll,
+				Snapshot,
+				Snapshot,
+				Propose,
+				Propose,
+				Propose,
+				Propose,
+				Propose,
+				Propose,
+				ProposeLeader,
+				ProposeLeader,
+				ProposeLeader,
+				ProposeLeader,
+				ProposeLeader,
+				ProposeLeader,
+				Pause,
+				Pause,
+				Pause,
+				Pause,
+				Pause,
+				Pause,
+				Check,
+				Check,
+				Check,
+				Check,
+			}
+
+			pickRandomNode := func(nodes []stateMachine) ([]stateMachine, stateMachine) {
+				if len(nodes) == 0 {
+					// Input list is empty
+					return nodes, nil
+				}
+				// Pick random node
+				i := rng.Intn(len(nodes))
+				node := nodes[i]
+				// Move last element in its place
+				nodes[i] = nodes[len(nodes)-1]
+				// Return slice excluding last element
+				return nodes[:len(nodes)-1], node
+			}
+
+			chainStatusString := func() string {
+				b := strings.Builder{}
+				for _, sm := range rg {
+					csm := sm.(*raftChainStateMachine)
+					blocksCount, blockHash := csm.getCurrentHash()
+					b.WriteString(
+						fmt.Sprintf(
+							" [%s (%s): %d blocks, hash=%s],",
+							csm.server().Name(),
+							csm.node().ID(),
+							blocksCount,
+							blockHash,
+						),
+					)
+				}
+				return b.String()
+			}
+
+			// Track the highest number of blocks applied by any of the replicas
+			highestBlocksCount := uint64(0)
+
+			for iteration := 1; iteration <= iterations; iteration++ {
+				nextOperation := opsWeighted[rng.Intn(len(opsWeighted))]
+				t.Logf("State: %s", chainStatusString())
+				t.Logf("Iteration %d/%d: %s", iteration, iterations, nextOperation)
+
+				switch nextOperation {
+
+				case StopOne:
+					// Stop an active node (if any are left active)
+					var n stateMachine
+					activeNodes, n = pickRandomNode(activeNodes)
+					if n != nil {
+						n.stop()
+						stoppedNodes = append(stoppedNodes, n)
+					}
+
+				case StopAll:
+					// Stop any node which is active
+					for _, node := range activeNodes {
+						node.stop()
+					}
+					stoppedNodes = append(stoppedNodes, activeNodes...)
+					activeNodes = make([]stateMachine, 0, numPeers)
+
+				case RestartOne:
+					// Restart a stopped node (if any are stopped)
+					var n stateMachine
+					stoppedNodes, n = pickRandomNode(stoppedNodes)
+					if n != nil {
+						n.restart()
+						activeNodes = append(activeNodes, n)
+					}
+
+				case RestartAll:
+					// Restart any node which is stopped
+					for _, node := range stoppedNodes {
+						node.restart()
+					}
+					activeNodes = append(activeNodes, stoppedNodes...)
+					stoppedNodes = make([]stateMachine, 0, numPeers)
+
+				case Snapshot:
+					// Make an active node take a snapshot (if any nodes are active)
+					if len(activeNodes) > 0 {
+						n := activeNodes[rng.Intn(len(activeNodes))]
+						n.(*raftChainStateMachine).snapshot()
+					}
+
+				case Propose:
+					// Make an active node propose the next block (if any nodes are active)
+					if len(activeNodes) > 0 {
+						n := activeNodes[rng.Intn(len(activeNodes))]
+						n.(*raftChainStateMachine).proposeBlock()
+					}
+
+				case ProposeLeader:
+					// Make the leader propose the next block (if a leader is active)
+					leader := rg.leader()
+					if leader != nil {
+						leader.(*raftChainStateMachine).proposeBlock()
+					}
+
+				case Pause:
+					// Just sit for a while and let things happen
+					time.Sleep(time.Duration(rng.Intn(250)) * time.Millisecond)
+
+				case Check:
+					// Restart any stopped node
+					for _, node := range stoppedNodes {
+						node.restart()
+					}
+					activeNodes = append(activeNodes, stoppedNodes...)
+					stoppedNodes = make([]stateMachine, 0, numPeers)
+
+					// Ensure all nodes (eventually) converge
+					checkFor(
+						t,
+						checkConvergenceTimeout,
+						1*time.Second,
+						func() error {
+							referenceBlocksCount, referenceHash := rg[0].(*raftChainStateMachine).getCurrentHash()
+							for _, n := range rg {
+								sm := n.(*raftChainStateMachine)
+								blocksCount, blockHash := sm.getCurrentHash()
+								// Track the highest block seen
+								if blocksCount > highestBlocksCount {
+									t.Logf(
+										"New highest blocks count: %d (%s (%s))",
+										blocksCount,
+										sm.s.Name(),
+										sm.n.ID(),
+									)
+									highestBlocksCount = blocksCount
+								}
+								// Each replica must match the reference node (given enough time)
+								if blocksCount != referenceBlocksCount || blockHash != referenceHash {
+									return fmt.Errorf(
+										"nodes not converged: %s",
+										chainStatusString(),
+									)
+								}
+							}
+							// Replicas are in sync, but missing some blocks that were previously seen
+							if referenceBlocksCount < highestBlocksCount {
+								return fmt.Errorf(
+									"nodes converged below highest known block count: %d: %s",
+									highestBlocksCount,
+									chainStatusString(),
+								)
+							}
+							// All nodes reached the same state, check passed
+							return nil
+						},
+					)
+				}
+			}
+		},
+	)
+}