diff --git a/db.go b/db.go index f5910b988..31ab7b664 100644 --- a/db.go +++ b/db.go @@ -1850,67 +1850,120 @@ type KVList = pb.KVList // This function blocks until the given context is done or an error occurs. // The given function will be called with a new KVList containing the modified keys and the // corresponding values. +// Due to the blocking nature of this function, it's impossible to know when subscription actually +// takes place, other than waiting for your first cb notification. If you need to wait for +// confirmation of subscription, but can't wait until your first cb notification, consider +// SubscribeAsync instead. func (db *DB) Subscribe(ctx context.Context, cb func(kv *KVList) error, matches []pb.Match) error { if cb == nil { return ErrNilCallback } + events, unsubscribe, err := db.SubscribeAsync(matches) + if err != nil { + return err + } + defer unsubscribe() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case event, ok := <-events: + if !ok { + return nil + } + if err := cb(event); err != nil { + return err + } + } + } +} + +// An UnsubscribeFunc tells SubscribeAsync that no further events will be consumed from the returned +// channel. Any pending event deliveries, either queued or in-progress, will be dropped. +// An UnsubscribeFunc can be called by multiple goroutines simultaneously. +// After the first call, subsequent calls to an UnsubscribeFunc do nothing. +type UnsubscribeFunc func() + +// SubscribeAsync can be used to watch key changes for the given key prefixes and the ignore string. +// At least one prefix should be passed, or an error will be returned. +// You can use an empty prefix to monitor all changes to the DB. +// Ignore string is the byte ranges for which prefix matching will be ignored. +// For example: ignore = "2-3", and prefix = "abc" will match for keys "abxxc", "abdfc" etc. +// The returned channel can be listened on for KVList events containing the modified keys and the +// corresponding values. The returned UnsubscribeFunc should be called as soon as you are done +// listening for events. Failure to do so will block other subscribers from receiving event +// deliveries of their own. +func (db *DB) SubscribeAsync(matches []pb.Match) (<-chan *KVList, UnsubscribeFunc, error) { c := z.NewCloser(1) s, err := db.pub.newSubscriber(c, matches) if err != nil { - return y.Wrapf(err, "while creating a new subscriber") + return nil, nil, y.Wrapf(err, "while creating a new subscriber") } - slurp := func(batch *pb.KVList) error { + + outChan := make(chan *KVList) + slurp := func(batch *pb.KVList) *pb.KVList { for { select { case kvs := <-s.sendCh: batch.Kv = append(batch.Kv, kvs.Kv...) default: - if len(batch.GetKv()) > 0 { - return cb(batch) - } - return nil + return batch } } } - drain := func() { + unsubscribeChan := make(chan bool) + unsubscribed := int32(0) + unsubscribe := func() { + if atomic.CompareAndSwapInt32(&unsubscribed, 0, 1) { + close(unsubscribeChan) + } + } + + go func() { + defer close(outChan) + for { select { - case <-s.sendCh: - default: + case <-c.HasBeenClosed(): + // No need to delete here. Closer will be called only while + // closing DB. Subscriber will be deleted by cleanSubscribers. + // Drain if any pending updates. + if batch := slurp(new(pb.KVList)); len(batch.Kv) > 0 { + select { + // Send event. + case outChan <- batch: + // Unless consumer unsubscribes in the middle. + case <-unsubscribeChan: + } + } + c.Done() return - } - } - } - for { - select { - case <-c.HasBeenClosed(): - // No need to delete here. Closer will be called only while - // closing DB. Subscriber will be deleted by cleanSubscribers. - err := slurp(new(pb.KVList)) - // Drain if any pending updates. - c.Done() - return err - case <-ctx.Done(): - c.Done() - s.active.Store(0) - drain() - db.pub.deleteSubscriber(s.id) - // Delete the subscriber to avoid further updates. - return ctx.Err() - case batch := <-s.sendCh: - err := slurp(batch) - if err != nil { + case <-unsubscribeChan: c.Done() + // Deactivate subscriber. s.active.Store(0) - drain() - // Delete the subscriber if there is an error by the callback. + // Drain and discard if any pending updates. + slurp(new(pb.KVList)) + // Delete the subscriber to avoid further updates. db.pub.deleteSubscriber(s.id) - return err + return + case batch := <-s.sendCh: + if batch = slurp(batch); len(batch.Kv) > 0 { + select { + // Send event. + case outChan <- batch: + // Unless consumer unsubscribes in the middle. + case <-unsubscribeChan: + } + } } } - } + }() + + return outChan, unsubscribe, nil } func (db *DB) syncDir(dir string) error { diff --git a/publisher_test.go b/publisher_test.go index 1113ba2e8..c9a4d9ce3 100644 --- a/publisher_test.go +++ b/publisher_test.go @@ -85,6 +85,48 @@ func TestPublisherDeadlock(t *testing.T) { }) } +func TestPublisherAsyncDeadlock(t *testing.T) { + runBadgerTest(t, nil, func(t *testing.T, db *DB) { + match := pb.Match{Prefix: []byte("ke"), IgnoreBytes: ""} + events, unsubscribe, err := db.SubscribeAsync([]pb.Match{match}) + require.NoError(t, err) + + firstUpdate := sync.WaitGroup{} + firstUpdate.Add(1) + + go func() { + <-events + firstUpdate.Done() + time.Sleep(time.Second * 20) + unsubscribe() + }() + + err = db.Update(func(txn *Txn) error { + e := NewEntry([]byte(fmt.Sprintf("key%d", 0)), []byte(fmt.Sprintf("value%d", 0))) + return txn.SetEntry(e) + }) + require.NoError(t, err) + + firstUpdate.Wait() + req := int64(0) + for i := 1; i < 1110; i++ { + time.Sleep(time.Millisecond * 10) + err := db.Update(func(txn *Txn) error { + e := NewEntry([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i))) + return txn.SetEntry(e) + }) + require.NoError(t, err) + atomic.AddInt64(&req, 1) + } + for atomic.LoadInt64(&req) != 1109 { + time.Sleep(time.Second) + } + for range events { + // wait for events to be closed + } + }) +} + func TestPublisherOrdering(t *testing.T) { runBadgerTest(t, nil, func(t *testing.T, db *DB) { order := []string{} @@ -124,6 +166,30 @@ func TestPublisherOrdering(t *testing.T) { }) } +func TestPublisherAsyncOrdering(t *testing.T) { + runBadgerTest(t, nil, func(t *testing.T, db *DB) { + match := pb.Match{Prefix: []byte("ke"), IgnoreBytes: ""} + events, unsubscribe, err := db.SubscribeAsync([]pb.Match{match}) + require.NoError(t, err) + defer unsubscribe() + + for i := 0; i < 5; i++ { + err := db.Update(func(txn *Txn) error { + e := NewEntry([]byte(fmt.Sprintf("key%d", i)), []byte(fmt.Sprintf("value%d", i))) + return txn.SetEntry(e) + }) + require.NoError(t, err) + } + + for i := 0; i < 5; { + for _, kv := range (<-events).Kv { + require.Equal(t, fmt.Sprintf("value%d", i), string(kv.Value)) + i++ + } + } + }) +} + func TestMultiplePrefix(t *testing.T) { runBadgerTest(t, nil, func(t *testing.T, db *DB) { var wg sync.WaitGroup @@ -163,3 +229,33 @@ func TestMultiplePrefix(t *testing.T) { wg.Wait() }) } + +func TestMultiplePrefixAsync(t *testing.T) { + runBadgerTest(t, nil, func(t *testing.T, db *DB) { + match1 := pb.Match{Prefix: []byte("ke"), IgnoreBytes: ""} + match2 := pb.Match{Prefix: []byte("hel"), IgnoreBytes: ""} + events, unsubscribe, err := db.SubscribeAsync([]pb.Match{match1, match2}) + require.NoError(t, err) + defer unsubscribe() + + err = db.Update(func(txn *Txn) error { + return txn.SetEntry(NewEntry([]byte("key"), []byte("value"))) + }) + require.NoError(t, err) + err = db.Update(func(txn *Txn) error { + return txn.SetEntry(NewEntry([]byte("hello"), []byte("badger"))) + }) + require.NoError(t, err) + + for i := 0; i < 2; { + for _, kv := range (<-events).Kv { + if string(kv.Key) == "key" { + require.Equal(t, string(kv.Value), "value") + } else { + require.Equal(t, string(kv.Value), "badger") + } + i++ + } + } + }) +}