Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Connect hook #1803

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions example_instrumentation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,20 @@ import (
"fmt"

"github.com/go-redis/redis/v8"
"github.com/go-redis/redis/v8/internal/pool"
)

type contextKey string

var key = contextKey("foo")

type redisHook struct{}
type redisFullHook struct {
redisHook
}

var _ redis.Hook = redisHook{}
var _ redis.Hook = redisFullHook{}

func (redisHook) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
fmt.Printf("starting processing: <%s>\n", cmd)
Expand All @@ -31,6 +40,27 @@ func (redisHook) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) e
return nil
}

func (redisFullHook) BeforeConnect(ctx context.Context) context.Context {
fmt.Printf("before connect")

if v := ctx.Value(key); v != nil {
fmt.Printf(" %v\n", v)
} else {
fmt.Printf("\n")
}

return ctx
}

func (redisFullHook) AfterConnect(ctx context.Context, event pool.ConnectEvent) {
fmt.Printf("after connect: %v", event.Err)
if v := ctx.Value(key); v != nil {
fmt.Printf(" %v\n", v)
} else {
fmt.Printf("\n")
}
}

func Example_instrumentation() {
rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
Expand Down Expand Up @@ -78,3 +108,67 @@ func ExampleClient_Watch_instrumentation() {
// starting processing: <unwatch: >
// finished processing: <unwatch: OK>
}

func ExamplePool_instrumentation() {
rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
})
rdb.AddHook(redisFullHook{})

rdb.Ping(ctx)
// Output:
// starting processing: <ping: >
// before connect
// after connect: <nil>
// finished processing: <ping: PONG>
}

func ExamplePool_instrumentation_connect_error() {
invalidAddr := "0.0.0.1:6379"
rdb := redis.NewClient(&redis.Options{
Addr: invalidAddr,
MaxRetries: -1,
})
rdb.AddHook(redisFullHook{})

rdb.Ping(ctx)
// Output:
// starting processing: <ping: >
// before connect
// after connect: dial tcp 0.0.0.1:6379: connect: no route to host
// finished processing: <ping: dial tcp 0.0.0.1:6379: connect: no route to host>
}

func ExamplePool_instrumentation_context_wiring() {
rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
})
rdb.AddHook(redisFullHook{})

ctx := context.WithValue(context.Background(), key, "bar")
rdb.Ping(ctx)
// Output:
// starting processing: <ping: >
// before connect bar
// after connect: <nil> bar
// finished processing: <ping: PONG>
}

func ExamplePool_instrumentation_mulitple_hooks() {
rdb := redis.NewClient(&redis.Options{
Addr: ":6379",
})
rdb.AddHook(redisFullHook{})
rdb.AddHook(redisFullHook{})

rdb.Ping(ctx)
// Output:
// starting processing: <ping: >
// starting processing: <ping: >
// before connect
// before connect
// after connect: <nil>
// after connect: <nil>
// finished processing: <ping: PONG>
// finished processing: <ping: PONG>
}
46 changes: 42 additions & 4 deletions internal/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,23 @@ var timers = sync.Pool{
},
}

type ConnectEvent struct {
Err error
}

type PoolHook interface {
BeforeConnect(ctx context.Context) context.Context
AfterConnect(ctx context.Context, event ConnectEvent)
}

type hooks struct {
hooks []PoolHook
}

func (hs *hooks) AddHook(hook PoolHook) {
hs.hooks = append(hs.hooks, hook)
}

// Stats contains pool state information and accumulated stats.
type Stats struct {
Hits uint32 // number of times free connection was found in the pool
Expand All @@ -39,6 +56,8 @@ type Stats struct {
}

type Pooler interface {
AddHook(hook PoolHook)

NewConn(context.Context) (*Conn, error)
CloseConn(*Conn) error

Expand Down Expand Up @@ -70,6 +89,8 @@ type lastDialErrorWrap struct {
}

type ConnPool struct {
hooks

opt *Options

dialErrorsNum uint32 // atomic
Expand Down Expand Up @@ -179,18 +200,35 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
return nil, p.getLastDialError()
}

netConn, err := p.opt.Dialer(ctx)
var netConn net.Conn
var cn* Conn
var hookIndex int
var err error

for ; hookIndex < len(p.hooks.hooks); hookIndex++ {
ctx = p.hooks.hooks[hookIndex].BeforeConnect(ctx)
}

netConn, err = p.opt.Dialer(ctx)
if err != nil {
p.setLastDialError(err)
if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
go p.tryDial()
}
} else {
internal.NewConnectionsCounter.Add(ctx, 1)
cn = NewConn(netConn)
cn.pooled = pooled
}
event := ConnectEvent{Err:err}
for hookIndex--; hookIndex >= 0; hookIndex-- {
p.hooks.hooks[hookIndex].AfterConnect(ctx, event);
}

if err != nil {
return nil, err
}

internal.NewConnectionsCounter.Add(ctx, 1)
cn := NewConn(netConn)
cn.pooled = pooled
return cn, nil
}

Expand Down
1 change: 1 addition & 0 deletions internal/pool/pool_single.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pool
import "context"

type SingleConnPool struct {
hooks
pool Pooler
cn *Conn
stickyErr error
Expand Down
1 change: 1 addition & 0 deletions internal/pool/pool_sticky.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func (e BadConnError) Unwrap() error {
//------------------------------------------------------------------------------

type StickyConnPool struct {
hooks
pool Pooler
shared int32 // atomic

Expand Down
16 changes: 12 additions & 4 deletions redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"sync/atomic"
"time"
"unsafe"

"github.com/go-redis/redis/v8/internal"
"github.com/go-redis/redis/v8/internal/pool"
Expand All @@ -29,6 +30,11 @@ type Hook interface {
AfterProcessPipeline(ctx context.Context, cmds []Cmder) error
}

type fullHook interface {
Hook
pool.PoolHook
}

type hooks struct {
hooks []Hook
}
Expand All @@ -45,6 +51,10 @@ func (hs hooks) clone() hooks {

func (hs *hooks) AddHook(hook Hook) {
hs.hooks = append(hs.hooks, hook)
if hook, ok := hook.(fullHook); ok {
client := *(*Client)(unsafe.Pointer(hs))
client.baseClient.connPool.AddHook(hook)
}
}

func (hs hooks) process(
Expand Down Expand Up @@ -137,8 +147,7 @@ func (hs hooks) withContext(ctx context.Context, fn func() error) error {
type baseClient struct {
opt *Options
connPool pool.Pooler

onClose func() error // hook called when client is closed
onClose func() error // hook called when client is closed
}

func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient {
Expand Down Expand Up @@ -548,9 +557,9 @@ func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder)
// underlying connections. It's safe for concurrent use by multiple
// goroutines.
type Client struct {
hooks
*baseClient
cmdable
hooks
ctx context.Context
}

Expand All @@ -563,7 +572,6 @@ func NewClient(opt *Options) *Client {
ctx: context.Background(),
}
c.cmdable = c.Process

return &c
}

Expand Down