forked from redis/go-redis
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tx.go
185 lines (160 loc) · 4.13 KB
/
tx.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
package redis
import (
"fmt"
"gopkg.in/redis.v5/internal"
"gopkg.in/redis.v5/internal/pool"
"gopkg.in/redis.v5/internal/proto"
)
// TxFailedErr is the error reported when a Redis transaction fails.
// execCmds substitutes it for Nil when reading the header of the EXEC
// reply yields Nil.
const TxFailedErr = internal.RedisError("redis: transaction failed")
// Tx implements Redis transactions as described in
// http://redis.io/topics/transactions. It's NOT safe for concurrent use
// by multiple goroutines, because Exec resets list of watched keys.
// If you don't need WATCH it is better to use Pipeline.
//
// Values are built by Client.newTx and handed to the callback passed to
// Client.Watch, which closes the transaction once the callback returns.
type Tx struct {
// cmdable and statefulCmdable are embedded command sets; newTx points
// the process hook of each at Tx.Process.
cmdable
statefulCmdable
// baseClient carries the options and the connection pool used by the
// transaction (newTx installs a sticky pool here).
baseClient
// closed is set by close; once true, exec returns pool.ErrClosed and
// further close calls are no-ops.
closed bool
}
// Compile-time assertion that *Tx satisfies the Cmdable interface.
var _ Cmdable = (*Tx)(nil)
// newTx builds a Tx that shares the client's options and wraps the client's
// connection pool in a sticky pool (constructed with its second argument set
// to true). Both embedded command sets are wired to dispatch through
// tx.Process.
//
// The client's pool must be a *pool.ConnPool; the type assertion panics
// otherwise.
func (c *Client) newTx() *Tx {
	sticky := pool.NewStickyConnPool(c.connPool.(*pool.ConnPool), true)

	tx := &Tx{
		baseClient: baseClient{
			opt:      c.opt,
			connPool: sticky,
		},
	}
	tx.cmdable.process = tx.Process
	tx.statefulCmdable.process = tx.Process
	return tx
}
// Watch creates a transaction, issues WATCH for keys (when any are given),
// and then calls fn with that transaction.
//
// The transaction is always closed before Watch returns. The close is
// deferred so that it also runs when fn panics; otherwise the transaction's
// connection would never be released.
//
// The returned error is, in order of precedence:
//   - the WATCH error, in which case fn is not called and any error from
//     closing the transaction is discarded;
//   - the error returned by fn;
//   - the error from closing the transaction.
func (c *Client) Watch(fn func(*Tx) error, keys ...string) (err error) {
	tx := c.newTx()

	// Runs on every exit path, including a panic inside fn. A close error is
	// only surfaced when nothing else has failed.
	defer func() {
		if closeErr := tx.close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()

	if len(keys) > 0 {
		if watchErr := tx.Watch(keys...).Err(); watchErr != nil {
			return watchErr
		}
	}
	return fn(tx)
}
// close releases the transaction. The first call marks the Tx as closed,
// sends UNWATCH (a failure there is only logged), and closes the embedded
// baseClient, returning the error from that Close. Subsequent calls do
// nothing and return nil.
func (c *Tx) close() error {
	if c.closed {
		return nil
	}
	c.closed = true

	unwatch := c.Unwatch()
	if unwatchErr := unwatch.Err(); unwatchErr != nil {
		internal.Logf("Unwatch failed: %s", unwatchErr)
	}
	return c.baseClient.Close()
}
// Watch sends a WATCH command for the given keys, marking them to be watched
// for conditional execution of a transaction. The command is run through
// c.Process and returned so the caller can inspect its result.
func (c *Tx) Watch(keys ...string) *StatusCmd {
	args := make([]interface{}, 0, 1+len(keys))
	args = append(args, "WATCH")
	for _, key := range keys {
		args = append(args, key)
	}

	watch := NewStatusCmd(args...)
	c.Process(watch)
	return watch
}
// Unwatch sends an UNWATCH command, flushing all the previously watched keys
// for a transaction. The command is run through c.Process and returned so
// the caller can inspect its result.
//
// Any keys passed are appended to the command as arguments.
// NOTE(review): the Redis command reference describes UNWATCH as taking no
// arguments, so a call with keys is presumably rejected by the server —
// confirm before relying on it.
func (c *Tx) Unwatch(keys ...string) *StatusCmd {
	args := make([]interface{}, 0, 1+len(keys))
	args = append(args, "UNWATCH")
	for _, key := range keys {
		args = append(args, key)
	}

	unwatch := NewStatusCmd(args...)
	c.Process(unwatch)
	return unwatch
}
// Pipeline returns a new Pipeline whose exec hook is this transaction's exec
// method, which wraps the commands it is given in MULTI/EXEC. The pipeline's
// embedded command sets dispatch through the pipeline's own Process.
func (c *Tx) Pipeline() *Pipeline {
	p := new(Pipeline)
	p.exec = c.exec
	p.cmdable.process = p.Process
	p.statefulCmdable.process = p.Process
	return p
}
// Pipelined runs fn against a fresh Pipeline obtained from c.Pipeline, so the
// commands fn queues are executed in a transaction, and restores the
// connection state to normal.
//
// When using WATCH, EXEC will execute commands only if the watched keys
// were not modified, allowing for a check-and-set mechanism.
//
// Pipelined always returns the list of commands. If the transaction fails,
// TxFailedErr is returned. Otherwise the error of the first failed command,
// or nil, is returned.
func (c *Tx) Pipelined(fn func(*Pipeline) error) ([]Cmder, error) {
	pipe := c.Pipeline()
	return pipe.pipelined(fn)
}
// exec sends cmds to the server as a single transaction: MULTI, the commands
// themselves, then EXEC.
//
// It returns pool.ErrClosed if the transaction has already been closed. If no
// connection can be obtained, that error is recorded on every command and
// returned. Otherwise the result of execCmds is returned, after the
// connection has been handed back via putConn together with that result.
func (c *Tx) exec(cmds []Cmder) error {
	if c.closed {
		return pool.ErrClosed
	}

	cn, _, err := c.conn()
	if err != nil {
		setCmdsErr(cmds, err)
		return err
	}

	// Frame the caller's commands: MULTI first, EXEC last.
	wrapped := make([]Cmder, len(cmds)+2)
	wrapped[0] = NewStatusCmd("MULTI")
	copy(wrapped[1:], cmds)
	wrapped[len(wrapped)-1] = NewSliceCmd("EXEC")

	err = c.execCmds(cn, wrapped)
	c.putConn(cn, err, false)
	return err
}
// execCmds writes cmds to cn and reads back the transaction's replies.
//
// cmds is expected to be framed as built by exec: cmds[0] is MULTI, the last
// element is EXEC, and everything in between is a user command.
//
// Reading proceeds in three steps:
//  1. one status reply for every command except EXEC (MULTI plus each user
//     command);
//  2. the header line of the EXEC reply, which must be an array reply; a Nil
//     result here is reported as TxFailedErr;
//  3. one reply per user command.
//
// A failure while writing or during steps 1-2 is recorded on every user
// command and returned immediately. In step 3 all replies are read and the
// first error encountered, if any, is returned.
func (c *Tx) execCmds(cn *pool.Conn, cmds []Cmder) error {
	// fail records err on every user command (everything between MULTI and
	// EXEC) and hands the same error back so it can be returned directly.
	fail := func(err error) error {
		setCmdsErr(cmds[1:len(cmds)-1], err)
		return err
	}

	cn.SetWriteTimeout(c.opt.WriteTimeout)
	if err := writeCmd(cn, cmds...); err != nil {
		return fail(err)
	}

	// A single read timeout is set for all of the replies.
	cn.SetReadTimeout(c.opt.ReadTimeout)

	// Number of commands excluding the trailing EXEC.
	cmdsLen := len(cmds) - 1

	// Step 1: the MULTI command is reused to parse each status reply.
	multi := cmds[0]
	for pending := cmdsLen; pending > 0; pending-- {
		if err := multi.readReply(cn); err != nil {
			return fail(err)
		}
	}

	// Step 2: header line of the EXEC reply.
	header, err := cn.Rd.ReadLine()
	switch {
	case err == Nil:
		return fail(TxFailedErr)
	case err != nil:
		return fail(err)
	}
	if header[0] != proto.ArrayReply {
		return fail(fmt.Errorf("redis: expected '*', but got line %q", header))
	}

	// Step 3: per-command replies. Index 0 (MULTI) is skipped; every reply is
	// read even after an error so that only the first error is reported.
	var firstErr error
	for idx := 1; idx < cmdsLen; idx++ {
		replyErr := cmds[idx].readReply(cn)
		if replyErr != nil && firstErr == nil {
			firstErr = replyErr
		}
	}
	return firstErr
}