forked from kelseyhightower/confd
/
client.go
270 lines (244 loc) · 6.54 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
package etcdv3
import (
"crypto/tls"
"crypto/x509"
"io/ioutil"
"strings"
"time"
"golang.org/x/net/context"
"github.com/coreos/etcd/clientv3"
"github.com/kelseyhightower/confd/log"
"sync"
)
// Watch records the latest etcd revision seen for one watched key prefix
// and lets goroutines block (via WaitNext) until that revision moves past a
// given value. It only tracks the revision number, not the events.
type Watch struct {
// revision is the last revision recorded by update, which writes it while
// holding rwl.Lock; WaitNext reads it under rwl.RLock while waiting.
revision int64
// cond is closed by update whenever revision changes and then replaced by
// a fresh channel, so every waiter selecting on it is woken at once.
cond chan struct{}
// rwl guards both revision and cond (update modifies both under Lock).
rwl sync.RWMutex
}
// WaitNext blocks until the watch has recorded a revision greater than
// lastRevision and then sends that revision on notify.
//
// It returns without sending if ctx is done, either while waiting for a
// newer revision or while blocked on the send (e.g. because the receiver
// already took another watch's result and cancelled ctx).
func (w *Watch) WaitNext(ctx context.Context, lastRevision int64, notify chan<- int64) {
	var rev int64
	for {
		// Snapshot revision and cond together under the read lock; update()
		// changes both while holding the write lock.
		w.rwl.RLock()
		rev = w.revision
		cond := w.cond
		w.rwl.RUnlock()

		if rev > lastRevision {
			break
		}

		// cond is closed by update() when the revision changes; wake up and
		// re-check, or give up on cancellation.
		select {
		case <-cond:
		case <-ctx.Done():
			return
		}
	}

	// Send the value read under the lock. Reading w.revision here without
	// the lock would race with update(); any value > lastRevision is valid.
	select {
	case notify <- rev:
	case <-ctx.Done():
	}
}
// update stores newRevision as the latest revision and wakes every goroutine
// waiting in WaitNext by closing the current cond channel. A closed channel
// cannot be reused, so a fresh one is installed for future waiters.
//
// The caller must not already hold rwl, since update takes the write lock.
func (w *Watch) update(newRevision int64) {
	w.rwl.Lock()
	defer w.rwl.Unlock()

	w.revision = newRevision
	stale := w.cond
	close(stale) // broadcast: every pending <-cond receive returns
	w.cond = make(chan struct{})
}
// createWatch starts watching every key under prefix in a background
// goroutine and returns the Watch that goroutine keeps up to date.
//
// The goroutine never exits. Each time the etcd watch channel closes
// (disconnect or cancellation) it logs a warning, sleeps one second and
// opens a new watch. If a revision has been recorded, the new watch starts
// at the revision after it; otherwise it starts from the current revision
// with a created notification.
//
// The returned error is currently always nil.
func createWatch(client *clientv3.Client, prefix string) (*Watch, error) {
	watch := &Watch{0, make(chan struct{}), sync.RWMutex{}}

	go func() {
		// In this file, this goroutine is the only caller of watch.update, so
		// it can read watch.revision without taking the lock.
		events := client.Watch(context.Background(), prefix, clientv3.WithPrefix(),
			clientv3.WithCreatedNotify())
		log.Debug("Watch created on %s", prefix)

		for {
			for resp := range events {
				switch {
				case resp.CompactRevision > watch.revision:
					// A newer compaction revision takes precedence over the
					// header revision.
					watch.update(resp.CompactRevision)
					log.Debug("Watch to '%s' updated to %d by CompactRevision", prefix, resp.CompactRevision)
				case resp.Header.GetRevision() > watch.revision:
					// Watch created, or new events arrived.
					watch.update(resp.Header.GetRevision())
					log.Debug("Watch to '%s' updated to %d by header revision", prefix, resp.Header.GetRevision())
				}
				if err := resp.Err(); err != nil {
					log.Error("Watch error: %s", err.Error())
				}
			}

			// The channel closed: disconnected or cancelled.
			log.Warning("Watch to '%s' stopped at revision %d", prefix, watch.revision)
			// Pause so a persistent failure does not reconnect in a tight loop.
			time.Sleep(time.Second)

			if watch.revision <= 0 {
				// Nothing recorded yet: start again from the latest revision.
				events = client.Watch(context.Background(), prefix, clientv3.WithPrefix(),
					clientv3.WithCreatedNotify())
				continue
			}
			// Resume just past the last recorded revision so no change in
			// between is missed.
			events = client.Watch(context.Background(), prefix, clientv3.WithPrefix(),
				clientv3.WithRev(watch.revision+1))
		}
	}()

	return watch, nil
}
// Client is a wrapper around the etcd v3 client (clientv3.Client). Besides
// the connection it caches one Watch per watched key, so repeated
// WatchPrefix calls reuse the same background etcd watch.
type Client struct {
// client is the underlying etcd v3 connection used by GetValues (Txn) and
// by the watches WatchPrefix creates.
client *clientv3.Client
// watches maps a watched key to its shared Watch. In this file entries are
// only ever added (by WatchPrefix), never removed.
watches map[string]*Watch
// wm guards watches.
wm sync.Mutex
}
// caCertError reports a CA file that was read successfully but contained no
// certificate that could be parsed from PEM.
type caCertError struct {
	path string
}

func (e *caCertError) Error() string {
	return "etcdv3: no valid PEM certificates found in CA file " + e.path
}

// NewEtcdClient returns an *etcdv3.Client connected to the etcd endpoints in
// machines.
//
// If basicAuth is true, username and password are set on the etcd client
// configuration. TLS is enabled when caCert is non-empty (its certificates
// become the trusted root set) and/or when both cert and key are non-empty
// (the key pair is presented as the client certificate).
//
// An error is returned if the CA file cannot be read or holds no valid PEM
// certificates, if the client key pair cannot be loaded, or if the etcd
// client cannot be created.
func NewEtcdClient(machines []string, cert, key, caCert string, basicAuth bool, username string, password string) (*Client, error) {
	cfg := clientv3.Config{
		Endpoints:            machines,
		DialTimeout:          5 * time.Second,
		DialKeepAliveTime:    10 * time.Second,
		DialKeepAliveTimeout: 3 * time.Second,
	}
	if basicAuth {
		cfg.Username = username
		cfg.Password = password
	}

	tlsEnabled := false
	tlsConfig := &tls.Config{
		InsecureSkipVerify: false,
	}
	if caCert != "" {
		certBytes, err := ioutil.ReadFile(caCert)
		if err != nil {
			return &Client{}, err
		}
		caCertPool := x509.NewCertPool()
		if !caCertPool.AppendCertsFromPEM(certBytes) {
			// Without this check RootCAs would stay nil while TLS is still
			// enabled, silently trusting the system roots instead of the CA
			// the user configured.
			return &Client{}, &caCertError{path: caCert}
		}
		tlsConfig.RootCAs = caCertPool
		tlsEnabled = true
	}
	if cert != "" && key != "" {
		tlsCert, err := tls.LoadX509KeyPair(cert, key)
		if err != nil {
			return &Client{}, err
		}
		tlsConfig.Certificates = []tls.Certificate{tlsCert}
		tlsEnabled = true
	}
	if tlsEnabled {
		cfg.TLS = tlsConfig
	}

	client, err := clientv3.New(cfg)
	if err != nil {
		return &Client{}, err
	}
	return &Client{
		client:  client,
		watches: make(map[string]*Watch),
	}, nil
}
// GetValues fetches every key stored under each of the given keys (each is
// queried as a prefix) and returns them as a flat map from full key to value.
//
// A stored key k is kept for requested key p only if k == p or k starts with
// p + "/" (p itself if it already ends in "/"). So "/app" matches "/app" and
// "/app/x" but not "/application".
//
// Requests are batched into transactions of at most maxTxnOps Get operations.
// The first transaction reads the latest revision, and every later one is
// pinned to that same revision, so a multi-batch result is still a single
// consistent snapshot. Each transaction has its own 3-second timeout.
//
// On error, the returned map holds whatever was collected before the failure.
func (c *Client) GetValues(keys []string) (map[string]string, error) {
	// Default etcd v3 TXN operation limit. It is configurable since v3.3, so
	// this may deserve an option (possibly max-txn-ops=0 disables the limit?).
	const maxTxnOps = 128

	vars := make(map[string]string)
	// Revision seen by the first transaction; 0 means "latest".
	var firstRev int64

	doTxn := func(ops []string) error {
		ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
		defer cancel()

		txnOps := make([]clientv3.Op, 0, len(ops))
		for _, k := range ops {
			txnOps = append(txnOps, clientv3.OpGet(k,
				clientv3.WithPrefix(),
				clientv3.WithSort(clientv3.SortByKey, clientv3.SortDescend),
				clientv3.WithRev(firstRev)))
		}
		result, err := c.client.Txn(ctx).Then(txnOps...).Commit()
		if err != nil {
			return err
		}

		for i, r := range result.Responses {
			originKey := ops[i]
			// WithPrefix is a raw byte prefix; narrow it to the "/" hierarchy
			// so "/foo" does not also pick up "/foobar".
			dirPrefix := originKey
			if !strings.HasSuffix(dirPrefix, "/") {
				dirPrefix += "/"
			}
			for _, ev := range r.GetResponseRange().Kvs {
				k := string(ev.Key)
				if k == originKey || strings.HasPrefix(k, dirPrefix) {
					vars[k] = string(ev.Value)
				}
			}
		}

		if firstRev == 0 {
			// Pin all later batches to the revision of the first one.
			firstRev = result.Header.GetRevision()
		}
		return nil
	}

	for start := 0; start < len(keys); start += maxTxnOps {
		end := start + maxTxnOps
		if end > len(keys) {
			end = len(keys)
		}
		if err := doTxn(keys[start:end]); err != nil {
			return vars, err
		}
	}
	return vars, nil
}
// WatchPrefix blocks until the Watch for any of keys records a revision
// greater than waitIndex, and returns that revision.
//
// One Watch per key is created lazily (via createWatch) and cached on c, so
// the background etcd watches are shared across calls. prefix is not used by
// this implementation; each entry of keys is watched instead.
//
// If stopChan yields a value (or is closed) before any watch fires,
// WatchPrefix returns 0 and context.Canceled.
func (c *Client) WatchPrefix(prefix string, keys []string, waitIndex uint64, stopChan chan bool) (uint64, error) {
	// Look up, or lazily create, the shared Watch for each key.
	watches := make(map[string]*Watch, len(keys))
	c.wm.Lock()
	for _, k := range keys {
		watch, ok := c.watches[k]
		if !ok {
			var err error
			watch, err = createWatch(c.client, k)
			if err != nil {
				c.wm.Unlock()
				return 0, err
			}
			c.watches[k] = watch
		}
		watches[k] = watch
	}
	c.wm.Unlock()

	// Only the first result is received below; cancelling ctx on return
	// releases every other WaitNext goroutine (waiting or blocked on send).
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	notify := make(chan int64)
	for _, w := range watches {
		go w.WaitNext(ctx, int64(waitIndex), notify)
	}

	select {
	case nextRevision := <-notify:
		return uint64(nextRevision), nil
	case <-stopChan:
		return 0, context.Canceled
	}
}