Skip to content

Commit

Permalink
Merge pull request #11889 from mrkm4ntr/example-recover-from-snap
Browse files Browse the repository at this point in the history
raftexample: Fix recovery from snapshot
  • Loading branch information
ptabor committed Feb 10, 2021
2 parents 44c889a + be2167e commit e8ba375
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 43 deletions.
40 changes: 29 additions & 11 deletions contrib/raftexample/kvstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"log"
"sync"

"go.etcd.io/etcd/raft/v3/raftpb"
"go.etcd.io/etcd/server/v3/etcdserver/api/snap"
)

Expand All @@ -39,8 +40,16 @@ type kv struct {

func newKVStore(snapshotter *snap.Snapshotter, proposeC chan<- string, commitC <-chan *string, errorC <-chan error) *kvstore {
s := &kvstore{proposeC: proposeC, kvStore: make(map[string]string), snapshotter: snapshotter}
// replay log into key-value map
s.readCommits(commitC, errorC)
snapshot, err := s.loadSnapshot()
if err != nil {
log.Panic(err)
}
if snapshot != nil {
log.Printf("loading snapshot at term %d and index %d", snapshot.Metadata.Term, snapshot.Metadata.Index)
if err := s.recoverFromSnapshot(snapshot.Data); err != nil {
log.Panic(err)
}
}
// read commits from raft into kvStore map until error
go s.readCommits(commitC, errorC)
return s
Expand All @@ -64,18 +73,16 @@ func (s *kvstore) Propose(k string, v string) {
func (s *kvstore) readCommits(commitC <-chan *string, errorC <-chan error) {
for data := range commitC {
if data == nil {
// done replaying log; new data incoming
// OR signaled to load snapshot
snapshot, err := s.snapshotter.Load()
if err == snap.ErrNoSnapshot {
return
}
// signaled to load snapshot
snapshot, err := s.loadSnapshot()
if err != nil {
log.Panic(err)
}
log.Printf("loading snapshot at term %d and index %d", snapshot.Metadata.Term, snapshot.Metadata.Index)
if err := s.recoverFromSnapshot(snapshot.Data); err != nil {
log.Panic(err)
if snapshot != nil {
log.Printf("loading snapshot at term %d and index %d", snapshot.Metadata.Term, snapshot.Metadata.Index)
if err := s.recoverFromSnapshot(snapshot.Data); err != nil {
log.Panic(err)
}
}
continue
}
Expand All @@ -100,6 +107,17 @@ func (s *kvstore) getSnapshot() ([]byte, error) {
return json.Marshal(s.kvStore)
}

func (s *kvstore) loadSnapshot() (*raftpb.Snapshot, error) {
snapshot, err := s.snapshotter.Load()
if err == snap.ErrNoSnapshot {
return nil, nil
}
if err != nil {
return nil, err
}
return snapshot, nil
}

func (s *kvstore) recoverFromSnapshot(snapshot []byte) error {
var store map[string]string
if err := json.Unmarshal(snapshot, &store); err != nil {
Expand Down
30 changes: 13 additions & 17 deletions contrib/raftexample/raft.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ type raftNode struct {
waldir string // path to WAL directory
snapdir string // path to snapshot directory
getSnapshot func() ([]byte, error)
lastIndex uint64 // index of log at start

confState raftpb.ConfState
snapshotIndex uint64
Expand Down Expand Up @@ -175,15 +174,6 @@ func (rc *raftNode) publishEntries(ents []raftpb.Entry) bool {

// after commit, update appliedIndex
rc.appliedIndex = ents[i].Index

// special nil commit to signal replay has finished
if ents[i].Index == rc.lastIndex {
select {
case rc.commitC <- nil:
case <-rc.stopc:
return false
}
}
}
return true
}
Expand Down Expand Up @@ -240,12 +230,7 @@ func (rc *raftNode) replayWAL() *wal.WAL {

// append to storage so raft starts at the right place in log
rc.raftStorage.Append(ents)
// send nil once lastIndex is published so client knows commit channel is current
if len(ents) > 0 {
rc.lastIndex = ents[len(ents)-1].Index
} else {
rc.commitC <- nil
}

return w
}

Expand All @@ -264,11 +249,13 @@ func (rc *raftNode) startRaft() {
}
}
rc.snapshotter = snap.New(zap.NewExample(), rc.snapdir)
rc.snapshotterReady <- rc.snapshotter

oldwal := wal.Exist(rc.waldir)
rc.wal = rc.replayWAL()

// signal replay has finished
rc.snapshotterReady <- rc.snapshotter

rpeers := make([]raft.Peer, len(rc.peers))
for i := range rpeers {
rpeers[i] = raft.Peer{ID: uint64(i + 1)}
Expand Down Expand Up @@ -353,6 +340,15 @@ func (rc *raftNode) maybeTriggerSnapshot() {
return
}

// wait until all committed entries are applied
// commitC is synchronous channel, so consumption of the message signals
// full application of previous messages
select {
case rc.commitC <- nil:
case <-rc.stopc:
return
}

log.Printf("start snapshot [applied index: %d | last snapshot index: %d]", rc.appliedIndex, rc.snapshotIndex)
data, err := rc.getSnapshot()
if err != nil {
Expand Down
15 changes: 0 additions & 15 deletions contrib/raftexample/raftexample_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,6 @@ func newCluster(n int) *cluster {
return clus
}

// sinkReplay reads all commits in each node's local log.
func (clus *cluster) sinkReplay() {
for i := range clus.peers {
for s := range clus.commitC[i] {
if s == nil {
break
}
}
}
}

// Close closes all cluster nodes and returns an error if any failed.
func (clus *cluster) Close() (err error) {
for i := range clus.peers {
Expand Down Expand Up @@ -102,8 +91,6 @@ func TestProposeOnCommit(t *testing.T) {
clus := newCluster(3)
defer clus.closeNoErrors(t)

clus.sinkReplay()

donec := make(chan struct{})
for i := range clus.peers {
// feedback for "n" committed entries, then update donec
Expand Down Expand Up @@ -149,8 +136,6 @@ func TestCloseProposerInflight(t *testing.T) {
clus := newCluster(1)
defer clus.closeNoErrors(t)

clus.sinkReplay()

// some inflight ops
go func() {
clus.proposeC[0] <- "foo"
Expand Down

0 comments on commit e8ba375

Please sign in to comment.