/
stub.go
198 lines (167 loc) · 4.17 KB
/
stub.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
package radix
import (
"bufio"
"bytes"
"context"
"errors"
"net"
"sync"
"time"
"github.com/mediocregopher/radix/v4/resp"
"github.com/mediocregopher/radix/v4/resp/resp2"
)
// bufferAddr is a fixed network/address pair. It is what a buffer hands out
// as its net.Addr.
type bufferAddr struct {
	network string
	addr    string
}

// Network returns the network name the bufferAddr was created with.
func (a bufferAddr) Network() string { return a.network }

// String returns the address the bufferAddr was created with.
func (a bufferAddr) String() string { return a.addr }
// buffer is an in-memory stand-in for a network connection: Encode marshals
// messages into buf and Decode unmarshals them back out, blocking while there
// is nothing to read. In the end this is really just a complicated stub of
// net.Conn.
type buffer struct {
net.Conn // always nil
// remoteAddr is returned from RemoteAddr and used as the Addr of the errors
// created by err.
remoteAddr bufferAddr
// bufL's lock is held by Encode, Decode and Close while they touch the
// fields below. It is broadcast on after a write and when the buffer is
// closed, which is what wakes up a blocked Decode.
bufL *sync.Cond
// buf holds the marshaled messages which have not yet been read.
buf *bytes.Buffer
// bufbr is a buffered reader on top of buf; Decode unmarshals from it.
bufbr *bufio.Reader
// closed is set by Close. Once set Encode fails, as does Decode when there
// is nothing left to read.
closed bool
}
// newBuffer returns an empty, open buffer whose RemoteAddr reports the given
// network and address.
func newBuffer(remoteNetwork, remoteAddr string) *buffer {
	backing := &bytes.Buffer{}
	b := &buffer{
		remoteAddr: bufferAddr{network: remoteNetwork, addr: remoteAddr},
		bufL:       sync.NewCond(&sync.Mutex{}),
		buf:        backing,
	}
	// The reader wraps the very same bytes.Buffer which Encode writes into.
	b.bufbr = bufio.NewReader(backing)
	return b
}
// Encode marshals m onto the end of the buffer and then wakes up anything
// blocked in Decode.
//
// If the buffer has been closed nothing is written and a "write" net.OpError
// wrapping errClosed is returned. An error from m.MarshalRESP is returned
// as-is, and in that case no wakeup is sent.
func (b *buffer) Encode(m resp.Marshaler) error {
	b.bufL.L.Lock()
	if b.closed {
		closedErr := b.err("write", errClosed)
		b.bufL.L.Unlock()
		return closedErr
	}
	marshalErr := m.MarshalRESP(b.buf)
	b.bufL.L.Unlock()

	if marshalErr != nil {
		return marshalErr
	}

	// The broadcast is sent only after the lock has been released.
	b.bufL.Broadcast()
	return nil
}
// Decode unmarshals the next message in the buffer into u, blocking while
// there is nothing to read.
//
// If the buffer is empty and has been closed a "read" net.OpError wrapping
// errClosed is returned. If ctx is done while the buffer is still empty then
// ctx.Err() is returned. Otherwise the result of u.UnmarshalRESP is returned.
func (b *buffer) Decode(ctx context.Context, u resp.Unmarshaler) error {
	b.bufL.L.Lock()
	defer b.bufL.L.Unlock()

	// done is created together with the waker goroutine the first time Decode
	// actually has to wait, and is closed on return so that the goroutine can
	// never outlive this call.
	var done chan struct{}
	defer func() {
		if done != nil {
			close(done)
		}
	}()

	for b.buf.Len() == 0 && b.bufbr.Buffered() == 0 {
		if b.closed {
			return b.err("read", errClosed)
		}

		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}

		// A sync.Cond cannot wait on a context, so we have to periodically
		// wakeup to check if the context is done. A single goroutine does
		// this for the whole call; spawning one per loop iteration would
		// strand every goroutine still blocked on the ticker once Decode
		// returns, since stopping a ticker does not close its channel.
		if done == nil {
			done = make(chan struct{})
			go func(done <-chan struct{}) {
				wakeupTicker := time.NewTicker(250 * time.Millisecond)
				defer wakeupTicker.Stop()
				for {
					select {
					case <-wakeupTicker.C:
						b.bufL.Broadcast()
					case <-done:
						return
					}
				}
			}(done)
		}

		b.bufL.Wait()
	}

	return u.UnmarshalRESP(b.bufbr)
}
// Close marks the buffer as closed and wakes up anything blocked in Decode.
// Calling Close on a buffer which is already closed returns a "close"
// net.OpError wrapping errClosed.
func (b *buffer) Close() error {
	mu := b.bufL.L
	mu.Lock()
	defer mu.Unlock()

	if !b.closed {
		b.closed = true
		b.bufL.Broadcast()
		return nil
	}
	return b.err("close", errClosed)
}
// RemoteAddr returns the network/address pair the buffer was created with.
func (b *buffer) RemoteAddr() net.Addr {
	var addr net.Addr = b.remoteAddr
	return addr
}
// err wraps the given error in a *net.OpError for the operation op. The
// OpError's Net is always "tcp", its Addr is the buffer's remote address and
// its Source is left nil.
func (b *buffer) err(op string, err error) error {
	opErr := new(net.OpError)
	opErr.Op = op
	opErr.Net = "tcp"
	opErr.Addr = b.remoteAddr
	opErr.Err = err
	return opErr
}
// errClosed is the error wrapped by the net.OpError which buffer's Encode,
// Decode and Close return once the buffer has been closed.
var errClosed = errors.New("use of closed network connection")
////////////////////////////////////////////////////////////////////////////////
// stub is the Conn returned by Stub. It services requests by passing them to
// fn and buffering whatever fn returns in the embedded buffer.
type stub struct {
*buffer
// fn is the callback given to Stub. EncodeDecode calls it with each request
// converted into a []string.
fn func([]string) interface{}
}
// Stub returns a (fake) Conn which pretends to be connected to a real redis
// instance, but which services requests using the given callback instead. It
// is primarily useful for writing tests.
//
// When EncodeDecode is called the value to be marshaled is converted into a
// []string and passed to the callback. The callback's return value is then
// marshaled into an internal buffer, and the value to be decoded is
// unmarshaled from that internal buffer. If the internal buffer is empty at
// that step then the call will block.
//
// remoteNetwork and remoteAddr can be empty, but if given will be used as the
// return from the RemoteAddr method.
func Stub(remoteNetwork, remoteAddr string, fn func([]string) interface{}) Conn {
	s := new(stub)
	s.buffer = newBuffer(remoteNetwork, remoteAddr)
	s.fn = fn
	return s
}
// Do performs the Action against the stub itself, returning whatever the
// Action's Perform method returns.
func (s *stub) Do(ctx context.Context, a Action) error {
	err := a.Perform(ctx, s)
	return err
}
// EncodeDecode runs a single write/read exchange against the stub's callback.
//
// If m is non-nil it is marshaled into a scratch buffer, and every message
// found in that scratch buffer is unmarshaled into a []string and passed to
// the callback. What is done with each return value depends on its type: a
// resp.Marshaler is always written to the internal buffer, any other error is
// returned directly, and everything else is written to the internal buffer
// wrapped in a resp2.Any.
//
// If u is non-nil it is then unmarshaled into from the internal buffer via
// the buffer's Decode method, using ctx.
func (s *stub) EncodeDecode(ctx context.Context, m resp.Marshaler, u resp.Unmarshaler) error {
	if m != nil {
		reqBuf := new(bytes.Buffer)
		if err := m.MarshalRESP(reqBuf); err != nil {
			return err
		}

		// Keep going for as long as there are unread bytes in either the
		// scratch buffer or the reader wrapped around it.
		reqReader := bufio.NewReader(reqBuf)
		for reqBuf.Len() > 0 || reqReader.Buffered() > 0 {
			var args []string
			if err := (resp2.Any{I: &args}).UnmarshalRESP(reqReader); err != nil {
				return err
			}

			// The resp.Marshaler case must come first: a value which is both
			// a resp.Marshaler and an error gets written, not returned.
			switch ret := s.fn(args).(type) {
			case resp.Marshaler:
				if err := s.buffer.Encode(ret); err != nil {
					return err
				}
			case error:
				return ret
			default:
				if err := s.buffer.Encode(resp2.Any{I: ret}); err != nil {
					return err
				}
			}
		}
	}

	if u == nil {
		return nil
	}
	return s.buffer.Decode(ctx, u)
}
// NetConn returns the stub's underlying buffer as a net.Conn.
func (s *stub) NetConn() net.Conn {
	var conn net.Conn = s.buffer
	return conn
}