diff --git a/ticker.go b/ticker.go index 4110629..059d991 100644 --- a/ticker.go +++ b/ticker.go @@ -1,38 +1,80 @@ package rate import ( - "context" + "runtime" "sync/atomic" + "time" ) -// NewTicker returns a channel that sends a `struct{}{}` -// at most `*maxrate` times per second. +type Ticker struct { + C <-chan struct{} + ch chan struct{} + closing atomic.Bool + stopped atomic.Bool +} + +// Close stops the Ticker and frees resources. +// +// It is safe to call multiple times or concurrently. +func (ticker *Ticker) Close() { + if ticker.closing.CompareAndSwap(false, true) { + defer close(ticker.ch) + for !ticker.stopped.Load() { + select { + case <-ticker.ch: + default: + } + runtime.Gosched() + } + } +} + +// AddTick adds a single tick to the Ticker, retrying with the +// given interval until it succeeds or the Ticker is closed. +func (ticker *Ticker) AddTick(d time.Duration) { + for !ticker.stopped.Load() && !ticker.closing.Load() { + select { + case ticker.ch <- struct{}{}: + return + default: + } + time.Sleep(d) + } +} + +func (ticker *Ticker) run(parent <-chan struct{}, maxrate *int32, counter *uint64) { + defer func() { + ticker.stopped.Store(true) + ticker.Close() + }() + var rl Limiter + for !ticker.closing.Load() { + if parent != nil { + if _, ok := <-parent; !ok { + break + } + } + ticker.ch <- struct{}{} + if counter != nil { + atomic.AddUint64(counter, 1) + } + rl.Wait(maxrate) + } +} + +// NewTicker returns a Ticker that sends a `struct{}{}` +// at most `*maxrate` times per second on it's C channel. // // If counter is not nil, it is incremented every time a // send is successful. // // A nil `maxrate` or a `*maxrate` of zero or less sends // as quickly as possible. -// -// The channel is closed when the context is done. -func NewTicker(ctx context.Context, maxrate *int32, counter *uint64) chan struct{} { +func NewTicker(maxrate *int32, counter *uint64) (ticker *Ticker) { ch := make(chan struct{}) - go func() { - defer close(ch) - var rl Limiter - for { - select { - case <-ctx.Done(): - return - case ch <- struct{}{}: - } - if counter != nil { - atomic.AddUint64(counter, 1) - } - rl.Wait(maxrate) - } - }() - return ch + ticker = &Ticker{C: ch, ch: ch} + go ticker.run(nil, maxrate, counter) + return } // NewSubTicker returns a channel that reads from another struct{}{} @@ -44,19 +86,10 @@ func NewTicker(ctx context.Context, maxrate *int32, counter *uint64) chan struct // // Use this to make "background" tickers that are less prioritized. // -// The channel is closed when the parent channel is closed. -func NewSubTicker(parent <-chan struct{}, maxrate *int32, counter *uint64) chan struct{} { +// The Ticker is closed when the parent channel closes. +func NewSubTicker(parent <-chan struct{}, maxrate *int32, counter *uint64) (ticker *Ticker) { ch := make(chan struct{}) - go func() { - defer close(ch) - var rl Limiter - for range parent { - ch <- struct{}{} - if counter != nil { - atomic.AddUint64(counter, 1) - } - rl.Wait(maxrate) - } - }() - return ch + ticker = &Ticker{C: ch, ch: ch} + go ticker.run(parent, maxrate, counter) + return } diff --git a/ticker_test.go b/ticker_test.go index 263b1ae..0364b8c 100644 --- a/ticker_test.go +++ b/ticker_test.go @@ -1,18 +1,17 @@ package rate import ( - "context" + "sync" "sync/atomic" "testing" "time" ) -func TestTickerRespectsContext(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - ch := NewTicker(ctx, nil, nil) +func TestTickerClosing(t *testing.T) { + ticker := NewTicker(nil, nil) + ticker.Close() select { - case _, ok := <-ch: + case _, ok := <-ticker.C: if ok { t.Error("got a tick") } @@ -23,19 +22,14 @@ func TestTickerRespectsContext(t *testing.T) { func TestNewTicker(t *testing.T) { const n = 100 var counter uint64 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - ch := NewTicker(ctx, nil, &counter) now := time.Now() + ticker := NewTicker(nil, &counter) for i := 0; i < n; i++ { - _, ok := <-ch + _, ok := <-ticker.C if !ok { t.Error("ticker channel closed early") } } - if d := time.Since(now); d > variance { - t.Errorf("%v > %v", d, variance) - } for i := 0; i < 10; i++ { if atomic.LoadUint64(&counter) == n { break @@ -46,18 +40,19 @@ func TestNewTicker(t *testing.T) { if x := atomic.LoadUint64(&counter); x != n { t.Errorf("%v != %v", x, n) } + if d := time.Since(now); d > variance { + t.Errorf("%v > %v", d, variance) + } } func TestNewSubTicker(t *testing.T) { const n = 100 var counter uint64 - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - ch1 := NewTicker(ctx, nil, nil) - ch2 := NewSubTicker(ch1, nil, &counter) now := time.Now() + t1 := NewTicker(nil, nil) + t2 := NewSubTicker(t1.C, nil, &counter) for i := 0; i < n; i++ { - _, ok := <-ch2 + _, ok := <-t2.C if !ok { t.Error("ticker channel closed early") } @@ -71,8 +66,60 @@ func TestNewSubTicker(t *testing.T) { } time.Sleep(time.Millisecond) } - time.Sleep(time.Millisecond) if x := atomic.LoadUint64(&counter); x != n { t.Errorf("%v != %v", x, n) } + t1.Close() + // there can be at most one extra tick to read after t1.Close + if _, ok := <-t2.C; ok { + if _, ok := <-t2.C; ok { + t.Error("t2 should have been closed") + } + } + if d := time.Since(now); d > variance { + t.Errorf("%v > %v", d, variance) + } +} + +func TestAddingTick(t *testing.T) { + var counter uint64 + + now := time.Now() + maxrate := int32(time.Second / variance * 2) + ticker := NewTicker(&maxrate, &counter) + + select { + case <-ticker.C: + case <-time.NewTimer(variance).C: + t.Error("timed out waiting for tick") + } + + select { + case <-ticker.C: + t.Error("got an unexpected tick") + default: + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + select { + case <-ticker.C: + case <-time.NewTimer(variance).C: + t.Error("timed out waiting for tick") + } + }() + ticker.AddTick(time.Nanosecond) + wg.Wait() + if d := time.Since(now); d > variance { + t.Errorf("%v > %v", d, variance) + } + ticker.Close() + if counter != 1 { + t.Error("counter should be one, not", counter) + } + if d := time.Since(now); d > variance { + t.Errorf("%v > %v", d, variance) + } }