Skip to content

Commit

Permalink
rewrite to use Ticker type
Browse files Browse the repository at this point in the history
  • Loading branch information
linkdata committed Apr 30, 2024
1 parent 022d17a commit 11cb3f8
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 55 deletions.
105 changes: 69 additions & 36 deletions ticker.go
Original file line number Diff line number Diff line change
@@ -1,38 +1,80 @@
package rate

import (
"context"
"runtime"
"sync/atomic"
"time"
)

// NewTicker returns a channel that sends a `struct{}{}`
// at most `*maxrate` times per second.
type Ticker struct {
C <-chan struct{}
ch chan struct{}
closing atomic.Bool

Check failure on line 12 in ticker.go

View workflow job for this annotation

GitHub Actions / build (1.16)

undefined: "sync/atomic".Bool
stopped atomic.Bool

Check failure on line 13 in ticker.go

View workflow job for this annotation

GitHub Actions / build (1.16)

undefined: "sync/atomic".Bool
}

// Close stops the Ticker and frees resources.
//
// It is safe to call multiple times or concurrently.
func (ticker *Ticker) Close() {
if ticker.closing.CompareAndSwap(false, true) {
defer close(ticker.ch)
for !ticker.stopped.Load() {
select {
case <-ticker.ch:
default:
}
runtime.Gosched()
}
}
}

// AddTick adds a single tick to the Ticker, retrying with the
// given interval until it succeeds or the Ticker is closed.
func (ticker *Ticker) AddTick(d time.Duration) {
for !ticker.stopped.Load() && !ticker.closing.Load() {
select {
case ticker.ch <- struct{}{}:
return
default:
}
time.Sleep(d)
}
}

func (ticker *Ticker) run(parent <-chan struct{}, maxrate *int32, counter *uint64) {
defer func() {
ticker.stopped.Store(true)
ticker.Close()
}()
var rl Limiter
for !ticker.closing.Load() {
if parent != nil {
if _, ok := <-parent; !ok {
break
}
}
ticker.ch <- struct{}{}
if counter != nil {
atomic.AddUint64(counter, 1)
}
rl.Wait(maxrate)
}
}

// NewTicker returns a Ticker that sends a `struct{}{}`
// at most `*maxrate` times per second on it's C channel.
//
// If counter is not nil, it is incremented every time a
// send is successful.
//
// A nil `maxrate` or a `*maxrate` of zero or less sends
// as quickly as possible.
//
// The channel is closed when the context is done.
func NewTicker(ctx context.Context, maxrate *int32, counter *uint64) chan struct{} {
func NewTicker(maxrate *int32, counter *uint64) (ticker *Ticker) {
ch := make(chan struct{})
go func() {
defer close(ch)
var rl Limiter
for {
select {
case <-ctx.Done():
return
case ch <- struct{}{}:
}
if counter != nil {
atomic.AddUint64(counter, 1)
}
rl.Wait(maxrate)
}
}()
return ch
ticker = &Ticker{C: ch, ch: ch}
go ticker.run(nil, maxrate, counter)
return
}

// NewSubTicker returns a channel that reads from another struct{}{}
Expand All @@ -44,19 +86,10 @@ func NewTicker(ctx context.Context, maxrate *int32, counter *uint64) chan struct
//
// Use this to make "background" tickers that are less prioritized.
//
// The channel is closed when the parent channel is closed.
func NewSubTicker(parent <-chan struct{}, maxrate *int32, counter *uint64) chan struct{} {
// The Ticker is closed when the parent channel closes.
func NewSubTicker(parent <-chan struct{}, maxrate *int32, counter *uint64) (ticker *Ticker) {
ch := make(chan struct{})
go func() {
defer close(ch)
var rl Limiter
for range parent {
ch <- struct{}{}
if counter != nil {
atomic.AddUint64(counter, 1)
}
rl.Wait(maxrate)
}
}()
return ch
ticker = &Ticker{C: ch, ch: ch}
go ticker.run(parent, maxrate, counter)
return
}
85 changes: 66 additions & 19 deletions ticker_test.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
package rate

import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
)

func TestTickerRespectsContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
ch := NewTicker(ctx, nil, nil)
func TestTickerClosing(t *testing.T) {
ticker := NewTicker(nil, nil)
ticker.Close()
select {
case _, ok := <-ch:
case _, ok := <-ticker.C:
if ok {
t.Error("got a tick")
}
Expand All @@ -23,19 +22,14 @@ func TestTickerRespectsContext(t *testing.T) {
func TestNewTicker(t *testing.T) {
const n = 100
var counter uint64
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ch := NewTicker(ctx, nil, &counter)
now := time.Now()
ticker := NewTicker(nil, &counter)
for i := 0; i < n; i++ {
_, ok := <-ch
_, ok := <-ticker.C
if !ok {
t.Error("ticker channel closed early")
}
}
if d := time.Since(now); d > variance {
t.Errorf("%v > %v", d, variance)
}
for i := 0; i < 10; i++ {
if atomic.LoadUint64(&counter) == n {
break
Expand All @@ -46,18 +40,19 @@ func TestNewTicker(t *testing.T) {
if x := atomic.LoadUint64(&counter); x != n {
t.Errorf("%v != %v", x, n)
}
if d := time.Since(now); d > variance {
t.Errorf("%v > %v", d, variance)
}
}

func TestNewSubTicker(t *testing.T) {
const n = 100
var counter uint64
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ch1 := NewTicker(ctx, nil, nil)
ch2 := NewSubTicker(ch1, nil, &counter)
now := time.Now()
t1 := NewTicker(nil, nil)
t2 := NewSubTicker(t1.C, nil, &counter)
for i := 0; i < n; i++ {
_, ok := <-ch2
_, ok := <-t2.C
if !ok {
t.Error("ticker channel closed early")
}
Expand All @@ -71,8 +66,60 @@ func TestNewSubTicker(t *testing.T) {
}
time.Sleep(time.Millisecond)
}
time.Sleep(time.Millisecond)
if x := atomic.LoadUint64(&counter); x != n {
t.Errorf("%v != %v", x, n)
}
t1.Close()
// there can be at most one extra tick to read after t1.Close
if _, ok := <-t2.C; ok {
if _, ok := <-t2.C; ok {
t.Error("t2 should have been closed")
}
}
if d := time.Since(now); d > variance {
t.Errorf("%v > %v", d, variance)
}
}

func TestAddingTick(t *testing.T) {
var counter uint64

now := time.Now()
maxrate := int32(time.Second / variance * 2)
ticker := NewTicker(&maxrate, &counter)

select {
case <-ticker.C:
case <-time.NewTimer(variance).C:
t.Error("timed out waiting for tick")
}

select {
case <-ticker.C:
t.Error("got an unexpected tick")
default:
}

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
select {
case <-ticker.C:
case <-time.NewTimer(variance).C:
t.Error("timed out waiting for tick")
}
}()
ticker.AddTick(time.Nanosecond)
wg.Wait()
if d := time.Since(now); d > variance {
t.Errorf("%v > %v", d, variance)
}
ticker.Close()
if counter != 1 {
t.Error("counter should be one, not", counter)
}
if d := time.Since(now); d > variance {
t.Errorf("%v > %v", d, variance)
}
}

0 comments on commit 11cb3f8

Please sign in to comment.