/
device_kernel.go
134 lines (105 loc) · 2.79 KB
/
device_kernel.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// SPDX-FileCopyrightText: 2023 Steffen Vogel <post@steffenvogel.de>
// SPDX-License-Identifier: Apache-2.0
package device
import (
"errors"
"fmt"
"net"
"go.uber.org/zap"
wgconn "golang.zx2c4.com/wireguard/conn"
wgdevice "golang.zx2c4.com/wireguard/device"
"cunicu.li/cunicu/pkg/link"
"cunicu.li/cunicu/pkg/log"
"cunicu.li/cunicu/pkg/wg"
)
// errNotWireGuardLink is wrapped into the error returned by FindKernelDevice
// when a link with the requested name exists but its type is not
// link.TypeWireGuard.
var errNotWireGuardLink = errors.New("link is not a WireGuard link")

// KernelDevice couples a WireGuard network link (embedded link.Link) with a
// wg.Bind and a device-scoped logger. Received bind packets are written to
// the endpoint's kernel connection by the device's receive workers.
type KernelDevice struct {
	link.Link

	// ListenPort gates BindUpdate: while it is 0, BindUpdate only logs and
	// returns nil. NOTE(review): presumably the listen port of the kernel
	// WireGuard interface — confirm with the code that assigns it.
	ListenPort int

	// bind is created by NewKernelDevice/FindKernelDevice and is closed and
	// reopened by BindUpdate.
	bind *wg.Bind
	// logger is named "dev" and tagged with the device name and type "kernel".
	logger *log.Logger
}
// NewKernelDevice creates a new WireGuard link called name and returns a
// KernelDevice wrapping it, together with a fresh bind and a logger scoped
// to this device. Errors from link creation are wrapped and returned.
func NewKernelDevice(name string) (*KernelDevice, error) {
	devLogger := log.Global.Named("dev").With(
		zap.String("dev", name),
		zap.String("type", "kernel"),
	)

	wgLink, err := link.CreateWireGuardLink(name)
	if err != nil {
		return nil, fmt.Errorf("failed to create WireGuard link: %w", err)
	}

	dev := &KernelDevice{
		Link:   wgLink,
		logger: devLogger,
	}
	dev.bind = wg.NewBind(devLogger)

	return dev, nil
}
// FindKernelDevice looks up an existing link called name and returns a
// KernelDevice wrapping it, together with a fresh bind and a device-scoped
// logger.
//
// If the lookup fails, its error is wrapped and returned. If the link exists
// but is not of type link.TypeWireGuard, the returned error wraps
// errNotWireGuardLink.
func FindKernelDevice(name string) (*KernelDevice, error) {
	devLogger := log.Global.Named("dev").With(
		zap.String("dev", name),
		zap.String("type", "kernel"),
	)

	existing, err := link.FindLink(name)

	switch {
	case err != nil:
		return nil, fmt.Errorf("failed to find WireGuard link: %w", err)

	// TODO: Is this portable?
	case existing.Type() != link.TypeWireGuard:
		return nil, fmt.Errorf("%w: %s", errNotWireGuardLink, existing.Name())
	}

	dev := &KernelDevice{
		Link:   existing,
		logger: devLogger,
	}
	dev.bind = wg.NewBind(devLogger)

	return dev, nil
}
// Bind returns the wg.Bind owned by this device.
func (d *KernelDevice) Bind() *wg.Bind { return d.bind }
// BindUpdate closes and reopens the device's bind, then starts one receive
// worker goroutine (doReceive) per receive function returned by Open.
//
// It is a no-op returning nil while ListenPort is 0. Errors from closing or
// reopening the bind are wrapped and returned; if reopening fails, the bind
// is left closed.
func (d *KernelDevice) BindUpdate() error {
	if d.ListenPort == 0 {
		d.logger.Debug("Skip bind update as we have no listen port yet")
		return nil
	}

	// doReceive exits once its receive function reports net.ErrClosed.
	// NOTE(review): presumably closing the bind is what makes workers from a
	// previous call return — confirm in wg.Bind.Close.
	if err := d.bind.Close(); err != nil {
		return fmt.Errorf("failed to close bind: %w", err)
	}

	// NOTE(review): the bind is opened on port 0 rather than ListenPort, and
	// the second return value (presumably the actual port) is ignored —
	// confirm this is intentional.
	rcvFns, _, err := d.bind.Open(0)
	if err != nil {
		return fmt.Errorf("failed to open bind: %w", err)
	}

	for _, rcvFn := range rcvFns {
		go d.doReceive(rcvFn)
	}

	return nil
}
// doReceive is a receive worker. It repeatedly calls rcvFn and writes each
// non-empty received packet to the kernel connection of the endpoint it
// arrived on.
//
// The worker returns when rcvFn reports an error matching net.ErrClosed. Any
// other receive error is logged and the loop continues. Packets whose
// endpoint is not a non-nil *wg.BindEndpoint, or whose connection does not
// implement wg.BindKernelConn, are logged and dropped instead of panicking
// the process. Write errors are logged.
func (d *KernelDevice) doReceive(rcvFn wgconn.ReceiveFunc) {
	d.logger.Debug("Receive worker started")

	const batchSize = 1

	packets := make([][]byte, batchSize)
	sizes := make([]int, batchSize)
	eps := make([]wgconn.Endpoint, batchSize)

	for i := range packets {
		packets[i] = make([]byte, wgdevice.MaxMessageSize)
	}

	for {
		n, err := rcvFn(packets, sizes, eps)
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				break
			}

			// NOTE(review): a persistent non-closed error makes this loop
			// spin and log continuously — consider backing off.
			d.logger.Error("Failed to receive from bind", zap.Error(err))

			continue
		}

		for i := 0; i < n; i++ {
			if sizes[i] == 0 {
				continue
			}

			// A checked assertion: an unexpected or nil endpoint must not
			// crash the whole process from this background goroutine.
			ep, ok := eps[i].(*wg.BindEndpoint)
			if !ok || ep == nil {
				d.logger.Error("Received packet from unexpected endpoint",
					zap.String("type", fmt.Sprintf("%T", eps[i])))

				continue
			}

			kc, ok := ep.Conn.(wg.BindKernelConn)
			if !ok {
				d.logger.Error("No kernel connection found", zap.String("ep", ep.DstToString()))

				continue
			}

			if _, err := kc.WriteKernel(packets[i][:sizes[i]]); err != nil {
				d.logger.Error("Failed to write to kernel", zap.Error(err))
			}
		}
	}

	d.logger.Debug("Receive worker stopped")
}