/
watcher.go
314 lines (241 loc) · 6.53 KB
/
watcher.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
// SPDX-FileCopyrightText: 2023 Steffen Vogel <post@steffenvogel.de>
// SPDX-License-Identifier: Apache-2.0
// Package daemon keeps track of and monitors new, removed and modified WireGuard interfaces and peers.
package daemon
import (
"errors"
"fmt"
"strings"
"sync"
"time"
"go.uber.org/zap"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"cunicu.li/cunicu/pkg/crypto"
"cunicu.li/cunicu/pkg/log"
slicesx "cunicu.li/cunicu/pkg/types/slices"
)
// errNotSupported signals that a feature is not available on the current
// platform. Watch tolerates it when it is returned by watchKernelInterfaces.
var errNotSupported = errors.New("not supported on this platform")
// Operations which an InterfaceEvent can describe.
const (
	// InterfaceAdded marks an event for an interface which has been added.
	InterfaceAdded InterfaceEventOp = iota
	// InterfaceDeleted marks an event for an interface which has been deleted.
	InterfaceDeleted
)

// InterfaceFilterFunc decides by name whether an interface is taken into
// account (true) or ignored (false).
type InterfaceFilterFunc func(string) bool

// InterfaceEventOp is the kind of operation carried by an InterfaceEvent.
type InterfaceEventOp int

// InterfaceEvent describes a single change of the interface with the given name.
type InterfaceEvent struct {
	Op   InterfaceEventOp
	Name string
}

// String returns "added" or "deleted" for the known operations and an empty
// string for any other value.
func (op InterfaceEventOp) String() string {
	names := [...]string{
		InterfaceAdded:   "added",
		InterfaceDeleted: "deleted",
	}

	if op < 0 || int(op) >= len(names) {
		return ""
	}

	return names[op]
}

// String renders the event as "<name> <operation>".
func (e InterfaceEvent) String() string {
	return fmt.Sprintf("%s %s", e.Name, e.Op.String())
}
// Watcher monitors both userspace and kernel for changes to WireGuard
// interfaces and keeps a list of the interfaces it currently knows about.
type Watcher struct {
	// interfaces holds the known interfaces, keyed by interface name.
	interfaces InterfaceList
	// devices is the device list obtained by the most recent synchronization.
	devices []*wgtypes.Device
	// mu is read-locked by the lookup and iteration methods and
	// write-locked by syncInterfaces.
	mu sync.RWMutex

	// onInterface lists the handlers notified about added and removed interfaces.
	onInterface []InterfaceHandler

	// client is used to list the WireGuard devices.
	client *wgctrl.Client

	// events and errors are consumed by the Watch loop.
	events chan InterfaceEvent
	errors chan error

	// stop is closed by Close to terminate the Watch loop,
	// stopped is closed by Watch once the loop has terminated.
	stop    chan any
	stopped chan any

	// manualTrigger requests a synchronization, see Sync.
	manualTrigger chan any

	// Settings

	// filter selects the devices to track; a nil filter accepts all devices.
	filter InterfaceFilterFunc
	// interval is the period of the periodic synchronization.
	// Values <= 0 disable it.
	interval time.Duration

	logger *log.Logger
}
// NewWatcher creates a Watcher which lists devices via client, restricts them
// to those accepted by filter and re-synchronizes every interval.
//
// The returned watcher is idle until Watch is called. The returned error is
// currently always nil.
func NewWatcher(client *wgctrl.Client, interval time.Duration, filter InterfaceFilterFunc) (*Watcher, error) {
	w := &Watcher{
		// Dependencies and settings
		client:   client,
		filter:   filter,
		interval: interval,
		logger:   log.Global.Named("watcher"),

		// State
		interfaces:  InterfaceList{},
		devices:     []*wgtypes.Device{},
		onInterface: []InterfaceHandler{},

		// Buffered channels feeding the Watch loop
		events:        make(chan InterfaceEvent, 16),
		errors:        make(chan error, 16),
		manualTrigger: make(chan any, 16),

		// Unbuffered channels used for the shutdown handshake
		stop:    make(chan any),
		stopped: make(chan any),
	}

	return w, nil
}
// Close asks the Watch loop to terminate by closing the stop channel and then
// blocks until Watch has closed the stopped channel.
//
// It always returns nil.
func (w *Watcher) Close() error {
	// Signal the Watch loop to terminate ...
	close(w.stop)

	// ... and wait for its acknowledgement.
	<-w.stopped

	return nil
}
// Watch runs the main loop of the watcher until Close is called.
//
// It first calls watchUserInterfaces and watchKernelInterfaces. A failure of
// either is fatal, except for errNotSupported returned by the latter.
// Afterwards the interfaces are re-synchronized whenever a manual trigger,
// a tick of the periodic timer or an interface event is received.
// Before returning, the stopped channel is closed to unblock Close.
func (w *Watcher) Watch() {
	if err := w.watchUserInterfaces(); err != nil {
		w.logger.Fatal("Failed to watch userspace interfaces", zap.Error(err))
	}

	w.logger.Debug("Started watching for changes of WireGuard userspace interfaces")

	if err := w.watchKernelInterfaces(); err != nil && !errors.Is(err, errNotSupported) {
		w.logger.Fatal("Failed to watch kernel interfaces", zap.Error(err))
	}

	w.logger.Debug("Started watching for changes of WireGuard kernel interfaces")

	// TODO: Watch for kernel routing tables, assigned addresses, MTUs ...

	// A nil channel is never ready in a select. Leaving tick unset therefore
	// disables the periodic synchronization when no interval is configured.
	var tick <-chan time.Time
	if w.interval > 0 {
		ticker := time.NewTicker(w.interval)
		defer ticker.Stop()

		tick = ticker.C
	}

	for {
		select {
		case <-w.manualTrigger:
			w.logger.DebugV(10, "Start interface synchronization")

			if err := w.syncInterfaces(); err != nil {
				w.logger.Error("Synchronization failed", zap.Error(err))
			}

		// We still need a periodic sync as we cannot (yet) monitor WireGuard
		// interfaces for changes via a netlink socket (patch is pending).
		case <-tick:
			w.logger.DebugV(10, "Start periodic interface synchronization")

			if err := w.syncInterfaces(); err != nil {
				w.logger.Error("Synchronization failed", zap.Error(err))
			}

		case ev := <-w.events:
			w.logger.DebugV(10, "Received interface event", zap.Any("event", ev))

			if err := w.syncInterfaces(); err != nil {
				w.logger.Error("Synchronization failed", zap.Error(err))
			}

		case watchErr := <-w.errors:
			w.logger.Error("Failed to watch for interface changes", zap.Error(watchErr))

		case <-w.stop:
			// Acknowledge the shutdown towards Close.
			close(w.stopped)

			return
		}
	}
}
// Sync requests a synchronization of the interfaces by queueing a manual
// trigger for the Watch loop.
//
// The synchronization itself is performed asynchronously by Watch; its
// outcome is not reported here. The send blocks while the trigger queue
// is full. Sync always returns nil.
func (w *Watcher) Sync() error {
	// The value carries no information, only its arrival matters.
	var trigger any

	w.manualTrigger <- trigger

	return nil
}
// syncInterfaces reconciles the watcher's state with the WireGuard devices
// currently reported by the wgctrl client.
//
// Devices rejected by the filter are ignored. The remaining ones are compared
// by name against the device list of the previous run:
//
//   - removed devices: handlers are notified via OnInterfaceRemoved and the
//     interface is dropped from the interface list,
//   - added devices: a new Interface is created, handlers are notified via
//     OnInterfaceAdded, and the interface is synchronized and stored,
//   - kept devices: the existing interface is synchronized.
//
// An error is returned only if listing the devices fails.
//
// Every access to w.interfaces and w.devices is performed while holding w.mu,
// as the lookup methods of the Watcher read them concurrently. The lock is
// never held while handlers or Interface.syncInterface run, so those may call
// back into the locking methods of the Watcher without deadlocking.
func (w *Watcher) syncInterfaces() error {
	w.mu.Lock()

	newDevs, err := w.client.Devices()
	if err != nil {
		w.mu.Unlock()

		return fmt.Errorf("failed to list WireGuard interfaces: %w", err)
	}

	// Ignore devices which do not match the filter
	newDevs = slicesx.Filter(newDevs, func(d *wgtypes.Device) bool {
		return w.filter == nil || w.filter(d.Name)
	})

	// The previous device list is read under the lock as well
	added, removed, kept := slicesx.DiffFunc(w.devices, newDevs, func(a, b *wgtypes.Device) int {
		return strings.Compare(a.Name, b.Name)
	})

	w.mu.Unlock()

	for _, wgd := range removed {
		w.mu.RLock()
		i, ok := w.interfaces[wgd.Name]
		w.mu.RUnlock()

		if !ok {
			w.logger.Warn("Failed to find matching interface", zap.String("intf", wgd.Name))

			continue
		}

		w.logger.Info("Interface removed", zap.String("intf", wgd.Name))

		for _, h := range w.onInterface {
			h.OnInterfaceRemoved(i)
		}

		w.mu.Lock()
		delete(w.interfaces, wgd.Name)
		w.mu.Unlock()
	}

	for _, wgd := range added {
		w.logger.Info("Interface added", zap.String("intf", wgd.Name))

		i, err := NewInterface(wgd, w.client)
		if err != nil {
			w.logger.Fatal("Failed to create new interface",
				zap.Error(err),
				zap.String("intf", wgd.Name),
			)
		}

		for _, h := range w.onInterface {
			h.OnInterfaceAdded(i)
		}

		// We purposefully prune the peer list here to force full initial sync of all peers:
		// the device handed to NewInterface loses its peers, while the copy passed to
		// syncInterface still carries them.
		wgdCopy := *wgd
		wgd.Peers = nil

		i.syncInterface(&wgdCopy)

		w.mu.Lock()
		w.interfaces[wgd.Name] = i
		w.mu.Unlock()
	}

	for _, wgd := range kept {
		w.mu.RLock()
		i, ok := w.interfaces[wgd.Name]
		w.mu.RUnlock()

		if !ok {
			w.logger.Warn("Failed to find matching interface", zap.String("intf", wgd.Name))

			continue
		}

		i.syncInterface(wgd)
	}

	// Remember the device list as the baseline for the next run
	w.mu.Lock()
	w.devices = newDevs
	w.mu.Unlock()

	return nil
}
// Peer returns the peer with public key pk of the interface named intf.
//
// It returns nil if the interface is unknown or has no such peer.
// pk is dereferenced once the interface has been found and must not be nil.
func (w *Watcher) Peer(intf string, pk *crypto.Key) *Peer {
	if i := w.InterfaceByName(intf); i != nil {
		// Indexing the peer map with an unknown key yields a nil peer.
		return i.Peers[*pk]
	}

	return nil
}
// PeerByPublicKey searches all known interfaces for a peer with public key pk
// and returns the first match, or nil if no interface has such a peer.
//
// The interface list is read-locked for the duration of the search.
func (w *Watcher) PeerByPublicKey(pk *crypto.Key) *Peer {
	w.mu.RLock()
	defer w.mu.RUnlock()

	for _, intf := range w.interfaces {
		peer, found := intf.Peers[*pk]
		if !found {
			continue
		}

		return peer
	}

	return nil
}
// InterfaceByName returns the result of looking up name in the interface
// list via InterfaceList.ByName, holding the read lock during the lookup.
//
// Callers in this file treat a nil result as "no such interface".
func (w *Watcher) InterfaceByName(name string) *Interface {
	w.mu.RLock()
	defer w.mu.RUnlock()

	intf := w.interfaces.ByName(name)

	return intf
}
// InterfaceByPublicKey returns the result of looking up pk in the interface
// list via InterfaceList.ByPublicKey, holding the read lock during the lookup.
func (w *Watcher) InterfaceByPublicKey(pk crypto.Key) *Interface {
	w.mu.RLock()
	defer w.mu.RUnlock()

	intf := w.interfaces.ByPublicKey(pk)

	return intf
}
// InterfaceByIndex returns the result of looking up idx in the interface
// list via InterfaceList.ByIndex, holding the read lock during the lookup.
func (w *Watcher) InterfaceByIndex(idx int) *Interface {
	w.mu.RLock()
	defer w.mu.RUnlock()

	intf := w.interfaces.ByIndex(idx)

	return intf
}
// ForEachInterface calls cb for every known interface.
//
// The iteration stops at the first error returned by cb, and that error is
// returned to the caller. Otherwise nil is returned. The interfaces are
// visited in no particular order.
//
// cb runs while the read lock of the watcher is held.
func (w *Watcher) ForEachInterface(cb func(i *Interface) error) error {
	w.mu.RLock()
	defer w.mu.RUnlock()

	for _, intf := range w.interfaces {
		err := cb(intf)
		if err != nil {
			return err
		}
	}

	return nil
}
// ForEachPeer calls cb for every peer of every known interface.
//
// It is built on ForEachInterface: the iteration stops at the first error
// returned by cb, and that error is returned to the caller.
func (w *Watcher) ForEachPeer(cb func(p *Peer) error) error {
	// visitPeers applies cb to all peers of a single interface.
	visitPeers := func(i *Interface) error {
		for _, peer := range i.Peers {
			if err := cb(peer); err != nil {
				return err
			}
		}

		return nil
	}

	return w.ForEachInterface(visitPeers)
}