Permalink
Browse files

Merge branch 'master' of https://github.com/goraft/raft

  • Loading branch information...
2 parents ef3280c + bfee414 commit cff0a00545434990a9922585edd420d26cbd128b @benbjohnson benbjohnson committed Feb 11, 2014
Showing with 59 additions and 17 deletions.
  1. +1 −0 event.go
  2. +18 −0 event_dispatcher.go
  3. +19 −0 event_dispatcher_test.go
  4. +7 −7 log.go
  5. +2 −2 log_entry.go
  6. +4 −4 log_test.go
  7. +5 −1 server.go
  8. +3 −3 test.go
View
1 event.go
@@ -4,6 +4,7 @@ const (
StateChangeEventType = "stateChange"
LeaderChangeEventType = "leaderChange"
TermChangeEventType = "termChange"
+ CommitEventType = "commit"
AddPeerEventType = "addPeer"
RemovePeerEventType = "removePeer"
View
18 event_dispatcher.go
@@ -1,6 +1,7 @@
package raft
import (
+ "reflect"
"sync"
)
@@ -33,6 +34,23 @@ func (d *eventDispatcher) AddEventListener(typ string, listener EventListener) {
d.listeners[typ] = append(d.listeners[typ], listener)
}
+// RemoveEventListener removes a listener function for a given event type.
+func (d *eventDispatcher) RemoveEventListener(typ string, listener EventListener) {
+ d.Lock()
+ defer d.Unlock()
+
+ // Grab a reference to the function pointer once.
+ ptr := reflect.ValueOf(listener).Pointer()
+
+ // Find listener by pointer and remove it.
+ listeners := d.listeners[typ]
+ for i, l := range listeners {
+ if reflect.ValueOf(l).Pointer() == ptr {
+ d.listeners[typ] = append(listeners[:i], listeners[i+1:]...)
+ }
+ }
+}
+
// DispatchEvent dispatches an event.
func (d *eventDispatcher) DispatchEvent(e Event) {
d.RLock()
View
19 event_dispatcher_test.go
@@ -23,6 +23,25 @@ func TestDispatchEvent(t *testing.T) {
assert.Equal(t, 11, count)
}
+// Ensure that we can add and remove a listener.
+func TestRemoveEventListener(t *testing.T) {
+ var count int
+ f0 := func(e Event) {
+ count += 1
+ }
+ f1 := func(e Event) {
+ count += 10
+ }
+
+ dispatcher := newEventDispatcher(nil)
+ dispatcher.AddEventListener("foo", f0)
+ dispatcher.AddEventListener("foo", f1)
+ dispatcher.DispatchEvent(&event{typ: "foo"})
+ dispatcher.RemoveEventListener("foo", f0)
+ dispatcher.DispatchEvent(&event{typ: "foo"})
+ assert.Equal(t, 21, count)
+}
+
// Ensure that event is properly passed to listener.
func TestEventListener(t *testing.T) {
dispatcher := newEventDispatcher("X")
View
14 log.go
@@ -19,7 +19,7 @@ import (
// A log is a collection of log entries that are persisted to durable storage.
type Log struct {
- ApplyFunc func(Command) (interface{}, error)
+ ApplyFunc func(*LogEntry, Command) (interface{}, error)
file *os.File
path string
entries []*LogEntry
@@ -160,7 +160,7 @@ func (l *Log) open(path string) error {
entry, _ := newLogEntry(l, nil, 0, 0, nil)
entry.Position, _ = l.file.Seek(0, os.SEEK_CUR)
- n, err := entry.decode(l.file)
+ n, err := entry.Decode(l.file)
if err != nil {
if err == io.EOF {
debugln("open.log.append: finish ")
@@ -179,7 +179,7 @@ func (l *Log) open(path string) error {
if err != nil {
continue
}
- l.ApplyFunc(command)
+ l.ApplyFunc(entry, command)
}
debugln("open.log.append log index ", entry.Index())
}
@@ -368,7 +368,7 @@ func (l *Log) setCommitIndex(index uint64) error {
}
// Apply the changes to the state machine and store the error code.
- returnValue, err := l.ApplyFunc(command)
+ returnValue, err := l.ApplyFunc(entry, command)
debugf("setCommitIndex.set.result index: %v, entries index: %v", i, entryIndex)
if entry.event != nil {
@@ -517,7 +517,7 @@ func (l *Log) appendEntry(entry *LogEntry) error {
entry.Position = position
// Write to storage.
- if _, err := entry.encode(l.file); err != nil {
+ if _, err := entry.Encode(l.file); err != nil {
return err
}
@@ -544,7 +544,7 @@ func (l *Log) writeEntry(entry *LogEntry, w io.Writer) (int64, error) {
}
// Write to storage.
- size, err := entry.encode(w)
+ size, err := entry.Encode(w)
if err != nil {
return -1, err
}
@@ -589,7 +589,7 @@ func (l *Log) compact(index uint64, term uint64) error {
position, _ := l.file.Seek(0, os.SEEK_CUR)
entry.Position = position
- if _, err = entry.encode(file); err != nil {
+ if _, err = entry.Encode(file); err != nil {
file.Close()
os.Remove(new_file_path)
return err
View
4 log_entry.go
@@ -67,7 +67,7 @@ func (e *LogEntry) Command() []byte {
// Encodes the log entry to a buffer. Returns the number of bytes
// written and any error that may have occurred.
-func (e *LogEntry) encode(w io.Writer) (int, error) {
+func (e *LogEntry) Encode(w io.Writer) (int, error) {
b, err := proto.Marshal(e.pb)
if err != nil {
return -1, err
@@ -82,7 +82,7 @@ func (e *LogEntry) encode(w io.Writer) (int, error) {
// Decodes the log entry from a buffer. Returns the number of bytes read and
// any error that occurs.
-func (e *LogEntry) decode(r io.Reader) (int, error) {
+func (e *LogEntry) Decode(r io.Reader) (int, error) {
var length int
_, err := fmt.Fscanf(r, "%8x\n", &length)
View
8 log_test.go
@@ -21,7 +21,7 @@ import (
func TestLogNewLog(t *testing.T) {
path := getLogPath()
log := newLog()
- log.ApplyFunc = func(c Command) (interface{}, error) {
+ log.ApplyFunc = func(e *LogEntry, c Command) (interface{}, error) {
return nil, nil
}
if err := log.open(path); err != nil {
@@ -119,13 +119,13 @@ func TestLogRecovery(t *testing.T) {
e1, _ := newLogEntry(tmpLog, nil, 2, 1, &testCommand2{X: 100})
f, _ := ioutil.TempFile("", "raft-log-")
- e0.encode(f)
- e1.encode(f)
+ e0.Encode(f)
+ e1.Encode(f)
f.WriteString("CORRUPT!")
f.Close()
log := newLog()
- log.ApplyFunc = func(c Command) (interface{}, error) {
+ log.ApplyFunc = func(e *LogEntry, c Command) (interface{}, error) {
return nil, nil
}
if err := log.open(f.Name()); err != nil {
View
6 server.go
@@ -180,7 +180,11 @@ func NewServer(name string, path string, transporter Transporter, stateMachine S
s.eventDispatcher = newEventDispatcher(s)
// Setup apply function.
- s.log.ApplyFunc = func(c Command) (interface{}, error) {
+ s.log.ApplyFunc = func(e *LogEntry, c Command) (interface{}, error) {
+ // Dispatch commit event.
+ s.DispatchEvent(newEvent(CommitEventType, e, nil))
+
+ // Apply command to the state machine.
switch c := c.(type) {
case CommandApply:
return c.Apply(&context{
View
6 test.go
@@ -42,7 +42,7 @@ func setupLog(entries []*LogEntry) (*Log, string) {
f, _ := ioutil.TempFile("", "raft-log-")
for _, entry := range entries {
- entry.encode(f)
+ entry.Encode(f)
}
err := f.Close()
@@ -51,7 +51,7 @@ func setupLog(entries []*LogEntry) (*Log, string) {
}
log := newLog()
- log.ApplyFunc = func(c Command) (interface{}, error) {
+ log.ApplyFunc = func(e *LogEntry, c Command) (interface{}, error) {
return nil, nil
}
if err := log.open(f.Name()); err != nil {
@@ -95,7 +95,7 @@ func newTestServerWithLog(name string, transporter Transporter, entries []*LogEn
}
for _, entry := range entries {
- entry.encode(f)
+ entry.Encode(f)
}
f.Close()
return server

0 comments on commit cff0a00

Please sign in to comment.