diff --git a/pubsub/batcher/batcher.go b/pubsub/batcher/batcher.go index 917cef822a..ad1667e0b9 100644 --- a/pubsub/batcher/batcher.go +++ b/pubsub/batcher/batcher.go @@ -200,26 +200,40 @@ func (b *Batcher) AddNoWait(item interface{}) <-chan error { if b.nHandlers < b.opts.MaxHandlers { // If we can start a handler, do so with the item just added and any others that are pending. batch := b.nextBatch() - if batch != nil { - b.wg.Add(1) - go func() { - b.callHandler(batch) - b.wg.Done() - }() - b.nHandlers++ - } + b.handleBatch(batch) } // If we can't start a handler, then one of the currently running handlers will // take our item. return c } +func (b *Batcher) handleBatch(batch []waiter) { + if batch == nil || len(batch) == 0 { + return + } + + b.wg.Add(1) + go func() { + b.callHandler(batch) + b.wg.Done() + }() + b.nHandlers++ +} + // nextBatch returns the batch to process, and updates b.pending. // It returns nil if there's no batch ready for processing. // b.mu must be held. func (b *Batcher) nextBatch() []waiter { if len(b.pending) < b.opts.MinBatchSize { - return nil + // We handle minimum batch sizes depending on specific + // situations. + // XXX: If we allow max batch lifetimes, handle that here. + if b.shutdown == false { + // If we're not shutting down, respect minimums. If we're + // shutting down, though, we ignore minimums to flush the + // entire batch. + return nil + } } if b.opts.MaxBatchByteSize == 0 && (b.opts.MaxBatchSize == 0 || len(b.pending) <= b.opts.MaxBatchSize) { @@ -283,5 +297,13 @@ func (b *Batcher) Shutdown() { b.mu.Lock() b.shutdown = true b.mu.Unlock() + + // On shutdown, ensure that we attempt to flush any pending items + // if there's a minimum batch size. + if b.nHandlers < b.opts.MaxHandlers { + batch := b.nextBatch() + b.handleBatch(batch) + } + b.wg.Wait() } diff --git a/pubsub/batcher/batcher_test.go b/pubsub/batcher/batcher_test.go index e7c5dd96c1..9b0ee3c055 100644 --- a/pubsub/batcher/batcher_test.go +++ b/pubsub/batcher/batcher_test.go @@ -171,6 +171,34 @@ func TestMinBatchSize(t *testing.T) { } } +// TestMinBatchSizeFlushesOnShutdown ensures that Shutdown() flushes batches, even if +// the pending count is less than the minimum batch size. +func TestMinBatchSizeFlushesOnShutdown(t *testing.T) { + var got [][]int + + batchSize := 3 + + b := batcher.New(reflect.TypeOf(int(0)), &batcher.Options{MinBatchSize: batchSize}, func(items interface{}) error { + got = append(got, items.([]int)) + return nil + }) + for i := 0; i < (batchSize - 1); i++ { + b.AddNoWait(i) + } + + // Ensure that we've received nothing + if len(got) > 0 { + t.Errorf("got batch unexpectedly: %+v", got) + } + + b.Shutdown() + + want := [][]int{{0, 1}} + if !cmp.Equal(got, want) { + t.Errorf("got %+v, want %+v on shutdown", got, want) + } +} + func TestSaturation(t *testing.T) { // Verify that under high load the maximum number of handlers are running. ctx := context.Background()