/
otredis.go
77 lines (66 loc) · 2 KB
/
otredis.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
package otredis
import (
"context"
"fmt"
"github.com/go-redis/redis/v7"
)
// UniversalClient is redis.UniversalClient extended with the unexported
// withContext method. Values of this type are returned by WrapClient.
type UniversalClient interface {
redis.UniversalClient
// withContext injects ctx into the underlying client, adds this package's
// hook to it and returns the resulting UniversalClient.
// It is unexported and therefore only callable from inside this package.
withContext(ctx context.Context) UniversalClient
}
// redisClient adapts a *redis.Client to UniversalClient by embedding it.
type redisClient struct {
	*redis.Client
}

// withContext replaces the embedded client with the one returned by
// WithContext(ctx), adds a hook carrying the client's address and database
// number to it, and returns rc itself.
func (rc *redisClient) withContext(ctx context.Context) UniversalClient {
	// Options are read from the client that was wrapped, before it is swapped.
	options := rc.Client.Options()
	tracing := hook{
		addrs:    []string{options.Addr},
		database: options.DB,
	}

	rc.Client = rc.Client.WithContext(ctx)
	rc.Client.AddHook(tracing)
	return rc
}
// redisClusterClient adapts a *redis.ClusterClient to UniversalClient by
// embedding it.
type redisClusterClient struct {
	*redis.ClusterClient
}

// withContext replaces the embedded cluster client with the one returned by
// WithContext(ctx), adds a hook carrying the cluster addresses to it, and
// returns rc itself. The hook's database is always 0.
func (rc *redisClusterClient) withContext(ctx context.Context) UniversalClient {
	bound := rc.ClusterClient.WithContext(ctx)
	rc.ClusterClient = bound

	// Addresses are taken from the options of the context-bound client.
	bound.AddHook(hook{
		addrs:    bound.Options().Addrs,
		database: 0,
	})
	return rc
}
// redisRing adapts a *redis.Ring to UniversalClient by embedding it.
type redisRing struct {
	*redis.Ring
}

// withContext replaces the embedded ring with the one returned by
// WithContext(ctx), adds a hook carrying the ring's shard addresses and
// database number to it, and returns rc itself.
func (rc *redisRing) withContext(ctx context.Context) UniversalClient {
	// Options are read from the ring that was wrapped, before it is swapped.
	opts := rc.Ring.Options()
	rc.Ring = rc.Ring.WithContext(ctx)

	// Only the values of opts.Addrs are passed to the hook. In go-redis v7
	// Addrs is a map, so the order of the collected addresses is unspecified.
	addrs := make([]string, 0, len(opts.Addrs))
	for _, addr := range opts.Addrs {
		addrs = append(addrs, addr)
	}

	rc.AddHook(hook{addrs: addrs, database: opts.DB})
	return rc
}
// WrapClient binds client to ctx and adds this package's hook (used for
// opentracing) to it, returning the wrapped client.
//
// The supported concrete types are *redis.Client, *redis.ClusterClient and
// *redis.Ring. An error is returned when ctx or client is nil, when client
// holds a nil pointer of a supported type, or when client is of any other
// type.
func WrapClient(ctx context.Context, client redis.UniversalClient) (UniversalClient, error) {
	if ctx == nil || client == nil {
		return nil, fmt.Errorf("[err] WrapClient invalid params")
	}

	var wrapClient UniversalClient
	switch c := client.(type) {
	case *redis.Client:
		// A typed nil pointer is a non-nil interface value, so the guard
		// above does not catch it; reject it here instead of panicking later.
		if c == nil {
			return nil, fmt.Errorf("[err] WrapClient invalid params")
		}
		wrapClient = &redisClient{Client: c}
	case *redis.ClusterClient:
		if c == nil {
			return nil, fmt.Errorf("[err] WrapClient invalid params")
		}
		wrapClient = &redisClusterClient{ClusterClient: c}
	case *redis.Ring:
		if c == nil {
			return nil, fmt.Errorf("[err] WrapClient invalid params")
		}
		wrapClient = &redisRing{Ring: c}
	default:
		return nil, fmt.Errorf("[err] WrapClient not support client")
	}

	return wrapClient.withContext(ctx), nil
}