Skip to content

Commit

Permalink
use functional options
Browse files Browse the repository at this point in the history
  • Loading branch information
da440dil committed Aug 1, 2019
1 parent faa43cd commit 31431c1
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 83 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ func main() {
client := redis.NewClient(&redis.Options{})
defer client.Close()

ctr := counter.NewCounter(
client,
counter.Params{TTL: time.Millisecond * 100, Limit: 2},
)
c, err := counter.NewCounter(client, 2, time.Millisecond*100)
if err != nil {
panic(err)
}
key := "key"
var wg sync.WaitGroup
count := func() {
wg.Add(1)
go func() {
v, err := ctr.Count(key)
v, err := c.Count(key)
if err == nil {
fmt.Printf("Counter has counted the key, remainder %v\n", v)
} else {
Expand Down
99 changes: 65 additions & 34 deletions counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,54 +17,79 @@ type Gateway interface {
Incr(key string, ttl int) (int, int, error)
}

// Params defines parameters for creating new Counter.
type Params struct {
TTL time.Duration // TTL of a key. Must be greater than or equal to 1 millisecond.
Limit int // Maximum key value. Must be greater than 0.
Prefix string // Prefix of a key. Optional.
// ErrInvalidTTL is the error returned when NewCounter receives invalid value of TTL.
var ErrInvalidTTL = errors.New("TTL must be greater than or equal to 1 millisecond")

// ErrInvalidLimit is the error returned when NewCounter receives invalid value of limit.
var ErrInvalidLimit = errors.New("Limit must be greater than zero")

// ErrInvaldKey is the error returned when key length is greater than 512 MB.
var ErrInvaldKey = errors.New("Key length must be less than or equal to 512 MB")

// Func is function returned by functions for setting options.
type Func func(c *Counter) error

// WithPrefix sets prefix of a key.
func WithPrefix(v string) Func {
return func(c *Counter) error {
if !isValidKey(v) {
return ErrInvaldKey
}
c.prefix = v
return nil
}
}

var errInvalidTTL = errors.New("TTL must be greater than or equal to 1 millisecond")
var errInvalidLimit = errors.New("Limit must be greater than zero")

func (p Params) validate() {
if p.TTL < time.Millisecond {
panic(errInvalidTTL)
}
if p.Limit < 1 {
panic(errInvalidLimit)
}
// Counter implements distributed rate limiting.
type Counter struct {
gateway Gateway
ttl int
limit int
prefix string
}

// NewCounterWithGateway creates new Counter using custom Gateway.
func NewCounterWithGateway(gateway Gateway, params Params) *Counter {
params.validate()
return &Counter{
// Limit is maximum key value, must be greater than 0.
// TTL is TTL of a key, must be greater than or equal to 1 millisecond.
// Options are functional options.
func NewCounterWithGateway(gateway Gateway, limit int, ttl time.Duration, options ...Func) (*Counter, error) {
if limit < 1 {
return nil, ErrInvalidLimit
}
if ttl < time.Millisecond {
return nil, ErrInvalidTTL
}
c := &Counter{
gateway: gateway,
ttl: durationToMilliseconds(params.TTL),
limit: params.Limit,
prefix: params.Prefix,
ttl: durationToMilliseconds(ttl),
limit: limit,
}
for _, fn := range options {
err := fn(c)
if err != nil {
return nil, err
}
}
return c, nil
}

// NewCounter creates new Counter using Redis Gateway.
func NewCounter(client *redis.Client, params Params) *Counter {
return NewCounterWithGateway(gw.NewGateway(client), params)
}

// Counter implements distributed rate limiting.
type Counter struct {
gateway Gateway
ttl int
limit int
prefix string
// Limit is maximum key value, must be greater than 0.
// TTL is TTL of a key, must be greater than or equal to 1 millisecond.
// Options are functional options.
func NewCounter(client *redis.Client, limit int, ttl time.Duration, options ...Func) (*Counter, error) {
return NewCounterWithGateway(gw.NewGateway(client), limit, ttl, options...)
}

// Count increments key value.
// Returns limit remainder.
// Returns TTLError if limit exceeded.
func (c *Counter) Count(key string) (int, error) {
value, ttl, err := c.gateway.Incr(c.prefix+key, c.ttl)
key = c.prefix + key
if !isValidKey(key) {
return -1, ErrInvaldKey
}
value, ttl, err := c.gateway.Incr(key, c.ttl)
if err != nil {
return -1, err
}
Expand All @@ -89,7 +114,7 @@ type TTLError interface {
TTL() time.Duration // Returns TTL of a key.
}

var errTooManyRequests = errors.New("Too Many Requests")
const ttlErrorMsg = "Too Many Requests"

type ttlError struct {
ttl time.Duration
Expand All @@ -100,9 +125,15 @@ func newTTLError(ttl int) *ttlError {
}

func (e *ttlError) Error() string {
return errTooManyRequests.Error()
return ttlErrorMsg
}

func (e *ttlError) TTL() time.Duration {
return e.ttl
}

const maxKeyLen = 512000000

func isValidKey(key string) bool {
return len([]byte(key)) <= maxKeyLen
}
121 changes: 82 additions & 39 deletions counter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"testing"
"time"
"unsafe"

"github.com/go-redis/redis"
"github.com/stretchr/testify/assert"
Expand All @@ -30,23 +31,91 @@ func TestNewCounter(t *testing.T) {
client := redis.NewClient(&redis.Options{Addr: Addr, DB: DB})
defer client.Close()

ctr := NewCounter(client, Params{TTL: TTL, Limit: Limit})
assert.IsType(t, &Counter{}, ctr)
t.Run("ErrInvalidLimit", func(t *testing.T) {
_, err := NewCounter(client, 0, time.Microsecond)
assert.Error(t, err)
assert.Equal(t, ErrInvalidLimit, err)
})

t.Run("ErrInvalidTTL", func(t *testing.T) {
_, err := NewCounter(client, Limit, time.Microsecond)
assert.Error(t, err)
assert.Equal(t, ErrInvalidTTL, err)
})

t.Run("success", func(t *testing.T) {
c, err := NewCounter(client, Limit, TTL)
assert.NoError(t, err)
assert.IsType(t, &Counter{}, c)
})
}

func TestCounter(t *testing.T) {
params := Params{TTL: TTL, Limit: Limit}
func TestNewCounterWithGateway(t *testing.T) {
gw := &gwMock{}

t.Run("ErrInvalidLimit", func(t *testing.T) {
_, err := NewCounterWithGateway(gw, 0, time.Microsecond)
assert.Error(t, err)
assert.Equal(t, ErrInvalidLimit, err)
})

t.Run("ErrInvalidTTL", func(t *testing.T) {
_, err := NewCounterWithGateway(gw, Limit, time.Microsecond)
assert.Error(t, err)
assert.Equal(t, ErrInvalidTTL, err)
})

t.Run("success", func(t *testing.T) {
c, err := NewCounterWithGateway(gw, Limit, TTL)
assert.NoError(t, err)
assert.IsType(t, &Counter{}, c)
})
}

func TestOptions(t *testing.T) {
gw := &gwMock{}

t.Run("ErrInvaldKey", func(t *testing.T) {
p := make([]byte, 512000001)
s := *(*string)(unsafe.Pointer(&p))
_, err := NewCounterWithGateway(gw, Limit, TTL, WithPrefix(s))
assert.Error(t, err)
assert.Equal(t, ErrInvaldKey, err)
})

t.Run("success", func(t *testing.T) {
c, err := NewCounterWithGateway(gw, Limit, TTL, WithPrefix(""))
assert.NoError(t, err)
assert.IsType(t, &Counter{}, c)
})
}

func TestCounter(t *testing.T) {
ttl := durationToMilliseconds(TTL)

t.Run("ErrInvaldKey", func(t *testing.T) {
gw := &gwMock{}

c, err := NewCounterWithGateway(gw, Limit, TTL)
assert.NoError(t, err)

p := make([]byte, 512000001)
s := *(*string)(unsafe.Pointer(&p))
v, err := c.Count(s)
assert.Equal(t, -1, v)
assert.Error(t, err)
assert.Equal(t, ErrInvaldKey, err)
})

t.Run("error", func(t *testing.T) {
e := errors.New("any")
gw := &gwMock{}
gw.On("Incr", Key, ttl).Return(-1, 42, e)

ctr := NewCounterWithGateway(gw, params)
c, err := NewCounterWithGateway(gw, Limit, TTL)
assert.NoError(t, err)

v, err := ctr.Count(Key)
v, err := c.Count(Key)
assert.Equal(t, -1, v)
assert.Error(t, err)
assert.Equal(t, e, err)
Expand All @@ -58,9 +127,10 @@ func TestCounter(t *testing.T) {
gw := &gwMock{}
gw.On("Incr", Key, ttl).Return(Limit+1, et, nil)

ctr := NewCounterWithGateway(gw, params)
c, err := NewCounterWithGateway(gw, Limit, TTL)
assert.NoError(t, err)

v, err := ctr.Count(Key)
v, err := c.Count(Key)
assert.Equal(t, -1, v)
assert.Error(t, err)
assert.Exactly(t, newTTLError(et), err)
Expand All @@ -71,46 +141,19 @@ func TestCounter(t *testing.T) {
gw := &gwMock{}
gw.On("Incr", Key, ttl).Return(Limit, 42, nil)

ctr := NewCounterWithGateway(gw, params)
c, err := NewCounterWithGateway(gw, Limit, TTL)
assert.NoError(t, err)

v, err := ctr.Count(Key)
v, err := c.Count(Key)
assert.Equal(t, 0, v)
assert.NoError(t, err)
gw.AssertExpectations(t)
})
}

func TestParams(t *testing.T) {
t.Run("invalid ttl", func(t *testing.T) {
defer func() {
r := recover()
assert.NotNil(t, r)
err, ok := r.(error)
assert.True(t, ok)
assert.Error(t, err)
assert.Equal(t, errInvalidTTL, err)
}()

Params{TTL: time.Microsecond}.validate()
})

t.Run("invalid limit", func(t *testing.T) {
defer func() {
r := recover()
assert.NotNil(t, r)
err, ok := r.(error)
assert.True(t, ok)
assert.Error(t, err)
assert.Equal(t, errInvalidLimit, err)
}()

Params{TTL: time.Millisecond}.validate()
})
}

func TestTTLError(t *testing.T) {
et := 42
err := newTTLError(et)
assert.EqualError(t, err, errTooManyRequests.Error())
assert.Equal(t, ttlErrorMsg, err.Error())
assert.Equal(t, millisecondsToDuration(et), err.TTL())
}
10 changes: 5 additions & 5 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@ func Example() {
client := redis.NewClient(&redis.Options{})
defer client.Close()

ctr := counter.NewCounter(
client,
counter.Params{TTL: time.Millisecond * 100, Limit: 2},
)
c, err := counter.NewCounter(client, 2, time.Millisecond*100)
if err != nil {
panic(err)
}
key := "key"
var wg sync.WaitGroup
count := func() {
wg.Add(1)
go func() {
v, err := ctr.Count(key)
v, err := c.Count(key)
if err == nil {
fmt.Printf("Counter has counted the key, remainder %v\n", v)
} else {
Expand Down

0 comments on commit 31431c1

Please sign in to comment.