diff --git a/contrib/raftexample/kvstore.go b/contrib/raftexample/kvstore.go index 7f69ef8ca74..7273eb49ee1 100644 --- a/contrib/raftexample/kvstore.go +++ b/contrib/raftexample/kvstore.go @@ -21,6 +21,7 @@ import ( "log" "sync" + "go.etcd.io/etcd/raft/v3/raftpb" "go.etcd.io/etcd/server/v3/etcdserver/api/snap" ) @@ -39,8 +40,16 @@ type kv struct { func newKVStore(snapshotter *snap.Snapshotter, proposeC chan<- string, commitC <-chan *string, errorC <-chan error) *kvstore { s := &kvstore{proposeC: proposeC, kvStore: make(map[string]string), snapshotter: snapshotter} - // replay log into key-value map - s.readCommits(commitC, errorC) + snapshot, err := s.loadSnapshot() + if err != nil { + log.Panic(err) + } + if snapshot != nil { + log.Printf("loading snapshot at term %d and index %d", snapshot.Metadata.Term, snapshot.Metadata.Index) + if err := s.recoverFromSnapshot(snapshot.Data); err != nil { + log.Panic(err) + } + } // read commits from raft into kvStore map until error go s.readCommits(commitC, errorC) return s @@ -64,18 +73,16 @@ func (s *kvstore) Propose(k string, v string) { func (s *kvstore) readCommits(commitC <-chan *string, errorC <-chan error) { for data := range commitC { if data == nil { - // done replaying log; new data incoming - // OR signaled to load snapshot - snapshot, err := s.snapshotter.Load() - if err == snap.ErrNoSnapshot { - return - } + // signaled to load snapshot + snapshot, err := s.loadSnapshot() if err != nil { log.Panic(err) } - log.Printf("loading snapshot at term %d and index %d", snapshot.Metadata.Term, snapshot.Metadata.Index) - if err := s.recoverFromSnapshot(snapshot.Data); err != nil { - log.Panic(err) + if snapshot != nil { + log.Printf("loading snapshot at term %d and index %d", snapshot.Metadata.Term, snapshot.Metadata.Index) + if err := s.recoverFromSnapshot(snapshot.Data); err != nil { + log.Panic(err) + } } continue } @@ -100,6 +107,17 @@ func (s *kvstore) getSnapshot() ([]byte, error) { return json.Marshal(s.kvStore) } 
+func (s *kvstore) loadSnapshot() (*raftpb.Snapshot, error) { + snapshot, err := s.snapshotter.Load() + if err == snap.ErrNoSnapshot { + return nil, nil + } + if err != nil { + return nil, err + } + return snapshot, nil +} + func (s *kvstore) recoverFromSnapshot(snapshot []byte) error { var store map[string]string if err := json.Unmarshal(snapshot, &store); err != nil { diff --git a/contrib/raftexample/raft.go b/contrib/raftexample/raft.go index 399252fbd4a..711979c8fd7 100644 --- a/contrib/raftexample/raft.go +++ b/contrib/raftexample/raft.go @@ -50,7 +50,6 @@ type raftNode struct { waldir string // path to WAL directory snapdir string // path to snapshot directory getSnapshot func() ([]byte, error) - lastIndex uint64 // index of log at start confState raftpb.ConfState snapshotIndex uint64 @@ -175,15 +174,6 @@ func (rc *raftNode) publishEntries(ents []raftpb.Entry) bool { // after commit, update appliedIndex rc.appliedIndex = ents[i].Index - - // special nil commit to signal replay has finished - if ents[i].Index == rc.lastIndex { - select { - case rc.commitC <- nil: - case <-rc.stopc: - return false - } - } } return true } @@ -240,12 +230,7 @@ func (rc *raftNode) replayWAL() *wal.WAL { // append to storage so raft starts at the right place in log rc.raftStorage.Append(ents) - // send nil once lastIndex is published so client knows commit channel is current - if len(ents) > 0 { - rc.lastIndex = ents[len(ents)-1].Index - } else { - rc.commitC <- nil - } + return w } @@ -264,11 +249,13 @@ func (rc *raftNode) startRaft() { } } rc.snapshotter = snap.New(zap.NewExample(), rc.snapdir) - rc.snapshotterReady <- rc.snapshotter oldwal := wal.Exist(rc.waldir) rc.wal = rc.replayWAL() + // signal replay has finished + rc.snapshotterReady <- rc.snapshotter + rpeers := make([]raft.Peer, len(rc.peers)) for i := range rpeers { rpeers[i] = raft.Peer{ID: uint64(i + 1)} @@ -353,6 +340,15 @@ func (rc *raftNode) maybeTriggerSnapshot() { return } + // wait until all committed entries 
are applied + // commitC is a synchronous channel, so consumption of the message signals + // full application of previous messages + select { + case rc.commitC <- nil: + case <-rc.stopc: + return + } + log.Printf("start snapshot [applied index: %d | last snapshot index: %d]", rc.appliedIndex, rc.snapshotIndex) data, err := rc.getSnapshot() if err != nil { diff --git a/contrib/raftexample/raftexample_test.go b/contrib/raftexample/raftexample_test.go index 6c0b629c685..e2702f67e4e 100644 --- a/contrib/raftexample/raftexample_test.go +++ b/contrib/raftexample/raftexample_test.go @@ -61,17 +61,6 @@ func newCluster(n int) *cluster { return clus } -// sinkReplay reads all commits in each node's local log. -func (clus *cluster) sinkReplay() { - for i := range clus.peers { - for s := range clus.commitC[i] { - if s == nil { - break - } - } - } -} - // Close closes all cluster nodes and returns an error if any failed. func (clus *cluster) Close() (err error) { for i := range clus.peers { @@ -102,8 +91,6 @@ func TestProposeOnCommit(t *testing.T) { clus := newCluster(3) defer clus.closeNoErrors(t) - clus.sinkReplay() - donec := make(chan struct{}) for i := range clus.peers { // feedback for "n" committed entries, then update donec @@ -149,8 +136,6 @@ func TestCloseProposerInflight(t *testing.T) { clus := newCluster(1) defer clus.closeNoErrors(t) - clus.sinkReplay() - // some inflight ops go func() { clus.proposeC[0] <- "foo"