From c3d68450cd543078852611a33162d06463402c26 Mon Sep 17 00:00:00 2001 From: Nathan Baulch Date: Sat, 19 Sep 2020 22:36:02 +1000 Subject: [PATCH] Simplify recently added context support --- mutex.go | 8 ++++---- mutex_test.go | 13 ++++++------- redis/goredis/goredis.go | 36 ++++++++++++++++-------------------- redis/goredis/v7/goredis.go | 29 ++++++++++++++++------------- redis/goredis/v8/goredis.go | 29 ++++++++++++++++------------- redis/redigo/redigo.go | 21 ++++++++++++--------- redis/redis.go | 10 +++++----- 7 files changed, 75 insertions(+), 71 deletions(-) diff --git a/mutex.go b/mutex.go index 13ee434..25f2923 100644 --- a/mutex.go +++ b/mutex.go @@ -122,7 +122,7 @@ func (m *Mutex) valid(ctx context.Context, pool redis.Pool) (bool, error) { return false, err } defer conn.Close() - reply, err := conn.Get(ctx, m.name) + reply, err := conn.Get(m.name) if err != nil { return false, err } @@ -144,7 +144,7 @@ func (m *Mutex) acquire(ctx context.Context, pool redis.Pool, value string) (boo return false, err } defer conn.Close() - reply, err := conn.SetNX(ctx, m.name, value, m.expiry) + reply, err := conn.SetNX(m.name, value, m.expiry) if err != nil { return false, err } @@ -165,7 +165,7 @@ func (m *Mutex) release(ctx context.Context, pool redis.Pool, value string) (boo return false, err } defer conn.Close() - status, err := conn.Eval(ctx, deleteScript, m.name, value) + status, err := conn.Eval(deleteScript, m.name, value) if err != nil { return false, err } @@ -186,7 +186,7 @@ func (m *Mutex) touch(ctx context.Context, pool redis.Pool, value string, expiry return false, err } defer conn.Close() - status, err := conn.Eval(ctx, touchScript, m.name, value, expiry) + status, err := conn.Eval(touchScript, m.name, value, expiry) if err != nil { return false, err } diff --git a/mutex_test.go b/mutex_test.go index 362d593..9228c2d 100644 --- a/mutex_test.go +++ b/mutex_test.go @@ -1,7 +1,6 @@ package redsync import ( - "context" "strconv" "testing" "time" @@ -127,11 +126,11 @@ func TestValid(t *testing.T) { func getPoolValues(pools []redis.Pool, name string) []string { values := make([]string, len(pools)) for i, pool := range pools { - conn, err := pool.Get(context.TODO()) + conn, err := pool.Get(nil) if err != nil { panic(err) } - value, err := conn.Get(context.TODO(), name) + value, err := conn.Get(name) if err != nil { panic(err) } @@ -144,11 +143,11 @@ func getPoolValues(pools []redis.Pool, name string) []string { func getPoolExpiries(pools []redis.Pool, name string) []int { expiries := make([]int, len(pools)) for i, pool := range pools { - conn, err := pool.Get(context.TODO()) + conn, err := pool.Get(nil) if err != nil { panic(err) } - expiry, err := conn.PTTL(context.TODO(), name) + expiry, err := conn.PTTL(name) if err != nil { panic(err) } @@ -165,11 +164,11 @@ func clogPools(pools []redis.Pool, mask int, mutex *Mutex) int { n++ continue } - conn, err := pool.Get(context.TODO()) + conn, err := pool.Get(nil) if err != nil { panic(err) } - _, err = conn.Set(context.TODO(), mutex.name, "foobar") + _, err = conn.Set(mutex.name, "foobar") if err != nil { panic(err) } diff --git a/redis/goredis/goredis.go b/redis/goredis/goredis.go index d00aad9..7135457 100644 --- a/redis/goredis/goredis.go +++ b/redis/goredis/goredis.go @@ -14,7 +14,11 @@ type pool struct { } func (p *pool) Get(ctx context.Context) (redsyncredis.Conn, error) { - return &conn{p.delegate}, nil + c := p.delegate + if ctx != nil { + c = c.WithContext(ctx) + } + return &conn{c}, nil } func NewPool(delegate *redis.Client) redsyncredis.Pool { @@ -25,27 +29,27 @@ type conn struct { delegate *redis.Client } -func (c *conn) Get(ctx context.Context, name string) (string, error) { - value, err := c.client(ctx).Get(name).Result() +func (c *conn) Get(name string) (string, error) { + value, err := c.delegate.Get(name).Result() return value, noErrNil(err) } -func (c *conn) Set(ctx context.Context, name string, value string) (bool, error) { - reply, err := c.client(ctx).Set(name, value, 0).Result() +func (c *conn) Set(name string, value string) (bool, error) { + reply, err := c.delegate.Set(name, value, 0).Result() return reply == "OK", noErrNil(err) } -func (c *conn) SetNX(ctx context.Context, name string, value string, expiry time.Duration) (bool, error) { - ok, err := c.client(ctx).SetNX(name, value, expiry).Result() +func (c *conn) SetNX(name string, value string, expiry time.Duration) (bool, error) { + ok, err := c.delegate.SetNX(name, value, expiry).Result() return ok, noErrNil(err) } -func (c *conn) PTTL(ctx context.Context, name string) (time.Duration, error) { - expiry, err := c.client(ctx).PTTL(name).Result() +func (c *conn) PTTL(name string) (time.Duration, error) { + expiry, err := c.delegate.PTTL(name).Result() return expiry, noErrNil(err) } -func (c *conn) Eval(ctx context.Context, script *redsyncredis.Script, keysAndArgs ...interface{}) (interface{}, error) { +func (c *conn) Eval(script *redsyncredis.Script, keysAndArgs ...interface{}) (interface{}, error) { keys := make([]string, script.KeyCount) args := keysAndArgs @@ -57,10 +61,9 @@ func (c *conn) Eval(ctx context.Context, script *redsyncredis.Script, keysAndArg args = keysAndArgs[script.KeyCount:] } - cli := c.client(ctx) - v, err := cli.EvalSha(script.Hash, keys, args...).Result() + v, err := c.delegate.EvalSha(script.Hash, keys, args...).Result() if err != nil && strings.HasPrefix(err.Error(), "NOSCRIPT ") { - v, err = cli.Eval(script.Src, keys, args...).Result() + v, err = c.delegate.Eval(script.Src, keys, args...).Result() } return v, noErrNil(err) } @@ -70,13 +73,6 @@ func (c *conn) Close() error { return nil } -func (c *conn) client(ctx context.Context) *redis.Client { - if ctx != nil { - return c.delegate.WithContext(ctx) - } - return c.delegate -} - func noErrNil(err error) error { if err == redis.Nil { return nil diff --git a/redis/goredis/v7/goredis.go b/redis/goredis/v7/goredis.go index 7048c53..8defdca 100644 --- a/redis/goredis/v7/goredis.go +++ b/redis/goredis/v7/goredis.go @@ -14,7 +14,11 @@ type pool struct { } func (p *pool) Get(ctx context.Context) (redsyncredis.Conn, error) { - return &conn{p.delegate}, nil + c := p.delegate + if ctx != nil { + c = c.WithContext(ctx) + } + return &conn{c}, nil } func NewPool(delegate *redis.Client) redsyncredis.Pool { @@ -25,25 +29,25 @@ type conn struct { delegate *redis.Client } -func (c *conn) Get(ctx context.Context, name string) (string, error) { - value, err := c.client(ctx).Get(name).Result() +func (c *conn) Get(name string) (string, error) { + value, err := c.delegate.Get(name).Result() return value, noErrNil(err) } -func (c *conn) Set(ctx context.Context, name string, value string) (bool, error) { - reply, err := c.client(ctx).Set(name, value, 0).Result() +func (c *conn) Set(name string, value string) (bool, error) { + reply, err := c.delegate.Set(name, value, 0).Result() return reply == "OK", err } -func (c *conn) SetNX(ctx context.Context, name string, value string, expiry time.Duration) (bool, error) { - return c.client(ctx).SetNX(name, value, expiry).Result() +func (c *conn) SetNX(name string, value string, expiry time.Duration) (bool, error) { + return c.delegate.SetNX(name, value, expiry).Result() } -func (c *conn) PTTL(ctx context.Context, name string) (time.Duration, error) { - return c.client(ctx).PTTL(name).Result() +func (c *conn) PTTL(name string) (time.Duration, error) { + return c.delegate.PTTL(name).Result() } -func (c *conn) Eval(ctx context.Context, script *redsyncredis.Script, keysAndArgs ...interface{}) (interface{}, error) { +func (c *conn) Eval(script *redsyncredis.Script, keysAndArgs ...interface{}) (interface{}, error) { keys := make([]string, script.KeyCount) args := keysAndArgs @@ -54,10 +58,9 @@ func (c *conn) Eval(ctx context.Context, script *redsyncredis.Script, keysAndArg args = keysAndArgs[script.KeyCount:] } - cli := c.client(ctx) - v, err := cli.EvalSha(script.Hash, keys, args...).Result() + v, err := c.delegate.EvalSha(script.Hash, keys, args...).Result() if err != nil && strings.HasPrefix(err.Error(), "NOSCRIPT ") { - v, err = cli.Eval(script.Src, keys, args...).Result() + v, err = c.delegate.Eval(script.Src, keys, args...).Result() } return v, noErrNil(err) } diff --git a/redis/goredis/v8/goredis.go b/redis/goredis/v8/goredis.go index b70d145..096c4e0 100644 --- a/redis/goredis/v8/goredis.go +++ b/redis/goredis/v8/goredis.go @@ -14,7 +14,10 @@ type pool struct { } func (p *pool) Get(ctx context.Context) (redsyncredis.Conn, error) { - return &conn{p.delegate}, nil + if ctx == nil { + ctx = p.delegate.Context() + } + return &conn{p.delegate, ctx}, nil } func NewPool(delegate *redis.Client) redsyncredis.Pool { @@ -23,27 +26,28 @@ func NewPool(delegate *redis.Client) redsyncredis.Pool { type conn struct { delegate *redis.Client + ctx context.Context } -func (c *conn) Get(ctx context.Context, name string) (string, error) { - value, err := c.delegate.Get(c._context(ctx), name).Result() +func (c *conn) Get(name string) (string, error) { + value, err := c.delegate.Get(c.ctx, name).Result() return value, noErrNil(err) } -func (c *conn) Set(ctx context.Context, name string, value string) (bool, error) { - reply, err := c.delegate.Set(c._context(ctx), name, value, 0).Result() +func (c *conn) Set(name string, value string) (bool, error) { + reply, err := c.delegate.Set(c.ctx, name, value, 0).Result() return reply == "OK", err } -func (c *conn) SetNX(ctx context.Context, name string, value string, expiry time.Duration) (bool, error) { - return c.delegate.SetNX(c._context(ctx), name, value, expiry).Result() +func (c *conn) SetNX(name string, value string, expiry time.Duration) (bool, error) { + return c.delegate.SetNX(c.ctx, name, value, expiry).Result() } -func (c *conn) PTTL(ctx context.Context, name string) (time.Duration, error) { - return c.delegate.PTTL(c._context(ctx), name).Result() +func (c *conn) PTTL(name string) (time.Duration, error) { + return c.delegate.PTTL(c.ctx, name).Result() } -func (c *conn) Eval(ctx context.Context, script *redsyncredis.Script, keysAndArgs ...interface{}) (interface{}, error) { +func (c *conn) Eval(script *redsyncredis.Script, keysAndArgs ...interface{}) (interface{}, error) { keys := make([]string, script.KeyCount) args := keysAndArgs @@ -54,10 +58,9 @@ func (c *conn) Eval(ctx context.Context, script *redsyncredis.Script, keysAndArg args = keysAndArgs[script.KeyCount:] } - ctx = c._context(ctx) - v, err := c.delegate.EvalSha(ctx, script.Hash, keys, args...).Result() + v, err := c.delegate.EvalSha(c.ctx, script.Hash, keys, args...).Result() if err != nil && strings.HasPrefix(err.Error(), "NOSCRIPT ") { - v, err = c.delegate.Eval(ctx, script.Src, keys, args...).Result() + v, err = c.delegate.Eval(c.ctx, script.Src, keys, args...).Result() } return v, noErrNil(err) } diff --git a/redis/redigo/redigo.go b/redis/redigo/redigo.go index 1ffa900..3020324 100644 --- a/redis/redigo/redigo.go +++ b/redis/redigo/redigo.go @@ -14,11 +14,14 @@ type pool struct { } func (p *pool) Get(ctx context.Context) (redsyncredis.Conn, error) { - c, err := p.delegate.GetContext(ctx) - if err != nil { - return nil, err + if ctx != nil { + c, err := p.delegate.GetContext(ctx) + if err != nil { + return nil, err + } + return &conn{c}, nil } - return &conn{c}, nil + return &conn{p.delegate.Get()}, nil } func NewPool(delegate *redis.Pool) redsyncredis.Pool { @@ -29,27 +32,27 @@ type conn struct { delegate redis.Conn } -func (c *conn) Get(_ context.Context, name string) (string, error) { +func (c *conn) Get(name string) (string, error) { value, err := redis.String(c.delegate.Do("GET", name)) return value, noErrNil(err) } -func (c *conn) Set(_ context.Context, name string, value string) (bool, error) { +func (c *conn) Set(name string, value string) (bool, error) { reply, err := redis.String(c.delegate.Do("SET", name, value)) return reply == "OK", noErrNil(err) } -func (c *conn) SetNX(_ context.Context, name string, value string, expiry time.Duration) (bool, error) { +func (c *conn) SetNX(name string, value string, expiry time.Duration) (bool, error) { reply, err := redis.String(c.delegate.Do("SET", name, value, "NX", "PX", int(expiry/time.Millisecond))) return reply == "OK", noErrNil(err) } -func (c *conn) PTTL(_ context.Context, name string) (time.Duration, error) { +func (c *conn) PTTL(name string) (time.Duration, error) { expiry, err := redis.Int64(c.delegate.Do("PTTL", name)) return time.Duration(expiry) * time.Millisecond, noErrNil(err) } -func (c *conn) Eval(_ context.Context, script *redsyncredis.Script, keysAndArgs ...interface{}) (interface{}, error) { +func (c *conn) Eval(script *redsyncredis.Script, keysAndArgs ...interface{}) (interface{}, error) { v, err := c.delegate.Do("EVALSHA", args(script, script.Hash, keysAndArgs)...) if e, ok := err.(redis.Error); ok && strings.HasPrefix(string(e), "NOSCRIPT ") { v, err = c.delegate.Do("EVAL", args(script, script.Src, keysAndArgs)...) diff --git a/redis/redis.go b/redis/redis.go index f41eb08..fa199f1 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -14,11 +14,11 @@ type Pool interface { } type Conn interface { - Get(ctx context.Context, name string) (string, error) - Set(ctx context.Context, name string, value string) (bool, error) - SetNX(ctx context.Context, name string, value string, expiry time.Duration) (bool, error) - Eval(ctx context.Context, script *Script, keysAndArgs ...interface{}) (interface{}, error) - PTTL(ctx context.Context, name string) (time.Duration, error) + Get(name string) (string, error) + Set(name string, value string) (bool, error) + SetNX(name string, value string, expiry time.Duration) (bool, error) + Eval(script *Script, keysAndArgs ...interface{}) (interface{}, error) + PTTL(name string) (time.Duration, error) Close() error }