diff --git a/contrib/raftexample/kvstore.go b/contrib/raftexample/kvstore.go index 7273eb49ee1..ba49d00ee55 100644 --- a/contrib/raftexample/kvstore.go +++ b/contrib/raftexample/kvstore.go @@ -38,7 +38,7 @@ type kv struct { Val string } -func newKVStore(snapshotter *snap.Snapshotter, proposeC chan<- string, commitC <-chan *string, errorC <-chan error) *kvstore { +func newKVStore(snapshotter *snap.Snapshotter, proposeC chan<- string, commitC <-chan *commit, errorC <-chan error) *kvstore { s := &kvstore{proposeC: proposeC, kvStore: make(map[string]string), snapshotter: snapshotter} snapshot, err := s.loadSnapshot() if err != nil { @@ -70,9 +70,9 @@ func (s *kvstore) Propose(k string, v string) { s.proposeC <- buf.String() } -func (s *kvstore) readCommits(commitC <-chan *string, errorC <-chan error) { - for data := range commitC { - if data == nil { +func (s *kvstore) readCommits(commitC <-chan *commit, errorC <-chan error) { + for commit := range commitC { + if commit == nil { // signaled to load snapshot snapshot, err := s.loadSnapshot() if err != nil { @@ -87,14 +87,17 @@ func (s *kvstore) readCommits(commitC <-chan *string, errorC <-chan error) { continue } - var dataKv kv - dec := gob.NewDecoder(bytes.NewBufferString(*data)) - if err := dec.Decode(&dataKv); err != nil { - log.Fatalf("raftexample: could not decode message (%v)", err) + for _, data := range commit.data { + var dataKv kv + dec := gob.NewDecoder(bytes.NewBufferString(data)) + if err := dec.Decode(&dataKv); err != nil { + log.Fatalf("raftexample: could not decode message (%v)", err) + } + s.mu.Lock() + s.kvStore[dataKv.Key] = dataKv.Val + s.mu.Unlock() } - s.mu.Lock() - s.kvStore[dataKv.Key] = dataKv.Val - s.mu.Unlock() + close(commit.applyDoneC) } if err, ok := <-errorC; ok { log.Fatal(err) diff --git a/contrib/raftexample/raft.go b/contrib/raftexample/raft.go index 98a039b7b26..73b0d1e39a4 100644 --- a/contrib/raftexample/raft.go +++ b/contrib/raftexample/raft.go @@ -37,11 +37,16 @@ import ( "go.uber.org/zap" ) +type commit struct { + data []string + applyDoneC chan<- struct{} +} + // A key-value stream backed by raft type raftNode struct { proposeC <-chan string // proposed messages (k,v) confChangeC <-chan raftpb.ConfChange // proposed cluster config changes - commitC chan<- *string // entries committed to log (k,v) + commitC chan<- *commit // entries committed to log (k,v) errorC chan<- error // errors from raft session id int // client ID for raft session @@ -80,9 +85,9 @@ var defaultSnapshotCount uint64 = 10000 // commit channel, followed by a nil message (to indicate the channel is // current), then new log entries. To shutdown, close proposeC and read errorC. func newRaftNode(id int, peers []string, join bool, getSnapshot func() ([]byte, error), proposeC <-chan string, - confChangeC <-chan raftpb.ConfChange) (<-chan *string, <-chan error, <-chan *snap.Snapshotter) { + confChangeC <-chan raftpb.ConfChange) (<-chan *commit, <-chan error, <-chan *snap.Snapshotter) { - commitC := make(chan *string) + commitC := make(chan *commit) errorC := make(chan error) rc := &raftNode{ @@ -143,7 +148,12 @@ func (rc *raftNode) entriesToApply(ents []raftpb.Entry) (nents []raftpb.Entry) { // publishEntries writes committed log entries to commit channel and returns // whether all entries could be published. -func (rc *raftNode) publishEntries(ents []raftpb.Entry) bool { +func (rc *raftNode) publishEntries(ents []raftpb.Entry) (<-chan struct{}, bool) { + if len(ents) == 0 { + return nil, true + } + + data := make([]string, 0, len(ents)) for i := range ents { switch ents[i].Type { case raftpb.EntryNormal: @@ -152,12 +162,7 @@ func (rc *raftNode) publishEntries(ents []raftpb.Entry) bool { break } s := string(ents[i].Data) - select { - case rc.commitC <- &s: - case <-rc.stopc: - return false - } - + data = append(data, s) case raftpb.EntryConfChange: var cc raftpb.ConfChange cc.Unmarshal(ents[i].Data) @@ -170,16 +175,28 @@ func (rc *raftNode) publishEntries(ents []raftpb.Entry) bool { case raftpb.ConfChangeRemoveNode: if cc.NodeID == uint64(rc.id) { log.Println("I've been removed from the cluster! Shutting down.") - return false + return nil, false } rc.transport.RemovePeer(types.ID(cc.NodeID)) } } + } - // after commit, update appliedIndex - rc.appliedIndex = ents[i].Index + var applyDoneC chan struct{} + + if len(data) > 0 { + applyDoneC := make(chan struct{}, 1) + select { + case rc.commitC <- &commit{data, applyDoneC}: + case <-rc.stopc: + return nil, false + } } - return true + + // after commit, update appliedIndex + rc.appliedIndex = ents[len(ents)-1].Index + + return applyDoneC, true } func (rc *raftNode) loadSnapshot() *raftpb.Snapshot { @@ -346,18 +363,14 @@ func (rc *raftNode) publishSnapshot(snapshotToSave raftpb.Snapshot) { var snapshotCatchUpEntriesN uint64 = 10000 -func (rc *raftNode) maybeTriggerSnapshot() { +func (rc *raftNode) maybeTriggerSnapshot(applyDoneC <-chan struct{}) { if rc.appliedIndex-rc.snapshotIndex <= rc.snapCount { return } // wait until all committed entries are applied - // commitC is synchronous channel, so consumption of the message signals - // full application of previous messages - select { - case rc.commitC <- nil: - case <-rc.stopc: - return + if applyDoneC != nil { + <-applyDoneC } log.Printf("start snapshot [applied index: %d | last snapshot index: %d]", rc.appliedIndex, rc.snapshotIndex) @@ -443,11 +456,12 @@ func (rc *raftNode) serveChannels() { } rc.raftStorage.Append(rd.Entries) rc.transport.Send(rd.Messages) - if ok := rc.publishEntries(rc.entriesToApply(rd.CommittedEntries)); !ok { + applyDoneC, ok := rc.publishEntries(rc.entriesToApply(rd.CommittedEntries)) + if !ok { rc.stop() return } - rc.maybeTriggerSnapshot() + rc.maybeTriggerSnapshot(applyDoneC) rc.node.Advance() case err := <-rc.transport.ErrorC: diff --git a/contrib/raftexample/raftexample_test.go b/contrib/raftexample/raftexample_test.go index e2702f67e4e..7643f77fd94 100644 --- a/contrib/raftexample/raftexample_test.go +++ b/contrib/raftexample/raftexample_test.go @@ -29,7 +29,7 @@ import ( type cluster struct { peers []string - commitC []<-chan *string + commitC []<-chan *commit errorC []<-chan error proposeC []chan string confChangeC []chan raftpb.ConfChange @@ -44,7 +44,7 @@ func newCluster(n int) *cluster { clus := &cluster{ peers: peers, - commitC: make([]<-chan *string, len(peers)), + commitC: make([]<-chan *commit, len(peers)), errorC: make([]<-chan error, len(peers)), proposeC: make([]chan string, len(peers)), confChangeC: make([]chan raftpb.ConfChange, len(peers)), @@ -94,14 +94,14 @@ func TestProposeOnCommit(t *testing.T) { donec := make(chan struct{}) for i := range clus.peers { // feedback for "n" committed entries, then update donec - go func(pC chan<- string, cC <-chan *string, eC <-chan error) { + go func(pC chan<- string, cC <-chan *commit, eC <-chan error) { for n := 0; n < 100; n++ { - s, ok := <-cC + c, ok := <-cC if !ok { pC = nil } select { - case pC <- *s: + case pC <- c.data[0]: continue case err := <-eC: t.Errorf("eC message (%v)", err) @@ -143,7 +143,7 @@ func TestCloseProposerInflight(t *testing.T) { }() // wait for one message - if c, ok := <-clus.commitC[0]; *c != "foo" || !ok { + if c, ok := <-clus.commitC[0]; !ok || c.data[0] != "foo" { t.Fatalf("Commit failed") } }