// File: control.go

package vpn
//
// OpenVPN control channel
//
import (
"bytes"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"math"
"net"
"sync"
)
// Sentinel errors used by the control channel code.
var (
// errBadReset presumably signals a malformed reset packet; it is not
// referenced in this part of the file.
errBadReset = errors.New("bad reset packet")
// errExpiredKey is returned by LocalPacketID and UpdateLastACK once the
// corresponding counter has reached math.MaxUint32.
errExpiredKey = errors.New("max packet id reached")
)
// Byte prefixes that are matched against payloads received from the server.
var (
// serverPushReply is the prefix checked by isPushReply.
serverPushReply = []byte("PUSH_REPLY")
// serverBadAuth is the prefix checked by isBadAuthReply.
serverBadAuth = []byte("AUTH_FAILED")
)
// session keeps mutable state related to an OpenVPN session.
type session struct {
// RemoteSessionID is the session ID of the remote peer; sendACK inspects it
// before acknowledging a packet.
RemoteSessionID sessionID
// LocalSessionID is our own session ID; newSession fills it with 8 random
// bytes and sendControlPacket copies it into outgoing packets.
LocalSessionID sessionID
// keys holds the data channel keys; newSession seeds it with a single key.
keys []*dataChannelKey
// keyID is the index into keys that ActiveKey returns.
keyID int
// localPacketID is the counter handed out (and incremented) by LocalPacketID.
localPacketID packetID
// lastACK is the packet ID most recently recorded by UpdateLastACK.
lastACK packetID
// ackQueue is created by newSession with a buffer of 100 packets; its
// consumers are not visible in this part of the file.
ackQueue chan *packet
// mu is held by LocalPacketID, UpdateLastACK and isNextPacket while they
// access localPacketID and lastACK.
mu sync.Mutex
// Log is a Logger; the functions visible here use the package-level logger
// instead, so it is presumably used elsewhere.
Log Logger
}
// newSession builds a session holding a single, empty data channel key and a
// buffered ACK queue, assigns it a random 8-byte local session ID, and stores
// a freshly generated local key source in the active key.
//
// On failure the error is returned together with the partially initialised
// session, exactly as far as construction got.
func newSession() (*session, error) {
	s := &session{
		keys:     []*dataChannelKey{&dataChannelKey{}},
		ackQueue: make(chan *packet, 100),
	}

	// The local session ID is the first 8 bytes handed back by randomFn.
	// (In go 1.17 a slice-to-array-pointer conversion could be used instead.)
	sid, err := randomFn(8)
	if err != nil {
		return s, err
	}
	copy(s.LocalSessionID[:], sid[:8])

	source, err := newKeySource()
	if err != nil {
		return s, err
	}

	active, err := s.ActiveKey()
	if err != nil {
		return s, err
	}
	active.local = source
	return s, nil
}
// ActiveKey returns the dataChannelKey that is actively being used, i.e. the
// entry of s.keys selected by s.keyID.
//
// It returns an error wrapping errDataChannelKey when keyID does not index an
// existing key (negative, or greater than or equal to the number of keys).
func (s *session) ActiveKey() (*dataChannelKey, error) {
	// keyID must be a valid index: the upper bound is exclusive, otherwise
	// keyID == len(keys) would slip through and panic on the lookup below.
	if s.keyID < 0 || s.keyID >= len(s.keys) {
		return nil, fmt.Errorf("%w: %s", errDataChannelKey, "no such key id")
	}
	return s.keys[s.keyID], nil
}
// LocalPacketID hands out the current local packet ID and advances the
// counter by one, all while holding the session mutex.
//
// Once the counter sits at math.MaxUint32 it is no longer advanced (the
// increment would overflow) and errExpiredKey is returned with an ID of 0.
// In the future, this call could detect (or warn us) when we're approaching
// the key end of life.
func (s *session) LocalPacketID() (packetID, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.localPacketID == math.MaxUint32 {
		return 0, errExpiredKey
	}
	current := s.localPacketID
	s.localPacketID = current + 1
	return current, nil
}
// UpdateLastACK records newPacketID as the last acknowledged packet ID while
// holding the session mutex.
//
// If the stored value has already reached math.MaxUint32 nothing is changed
// and errExpiredKey is returned. If the stored value is non-zero and
// newPacketID is not greater than it, a warning is logged, but the value is
// still overwritten.
// NOTE(review): a smaller-or-equal ID is only logged, not rejected; confirm
// that this is the intended behaviour.
func (s *session) UpdateLastACK(newPacketID packetID) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	previous := s.lastACK
	switch {
	case previous == math.MaxUint32:
		return errExpiredKey
	case previous != 0 && newPacketID <= previous:
		logger.Warnf("tried to write ack %d; last was %d", newPacketID, previous)
	}
	s.lastACK = newPacketID
	return nil
}
// isNextPacket reports whether p carries the packet ID that immediately
// follows the last acknowledged one. A nil packet is never the next packet.
// The session mutex is held for the duration of the check.
func (s *session) isNextPacket(p *packet) bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	return p != nil && p.id == s.lastACK+1
}
// control implements the controlHandler interface.
// Like for true pirates, there is no state in control: it is an empty struct
// and its methods work only with the arguments they are given.
type control struct{}
// SendHardReset writes a control packet carrying the HardResetClientv2 opcode
// and an empty payload to conn, on behalf of session s. Only the error of the
// underlying sendControlPacket call is reported; the byte count is dropped.
func (c *control) SendHardReset(conn net.Conn, s *session) error {
	if _, err := sendControlPacket(conn, s, pControlHardResetClientV2, 0, []byte("")); err != nil {
		return err
	}
	return nil
}
// ParseHardReset decodes b as a server hard-reset response and returns the
// session ID it carries. If b cannot be turned into a server hard-reset
// packet, the zero sessionID is returned along with the error; otherwise the
// result of parseServerHardResetPacket is passed through unchanged.
func (c *control) ParseHardReset(b []byte) (sessionID, error) {
	reset, err := newServerHardReset(b)
	if err != nil {
		var none sessionID
		return none, err
	}
	return parseServerHardResetPacket(reset)
}
// PushRequest returns the PUSH_REQUEST command as bytes, terminated by a
// single NUL byte.
func (c *control) PushRequest() []byte {
	req := bytes.NewBufferString("PUSH_REQUEST")
	req.WriteByte(0x00)
	return req.Bytes()
}
// ReadPushResponse reads a byte array returned from the server,
// as the response to a Push Request, and returns a string containing the
// tunnel IP.
// For now, this is a single string containing _only_ the tunnel ip,
// but we might want to pass a pointer to the tunnel struct in the
// future.
// The parsing itself is delegated to parsePushedOptions.
func (*control) ReadPushResponse(b []byte) string {
return parsePushedOptions(b)
}
// ControlMessage encodes the client control message for the given options,
// using the local key source of the session's active key.
// This is not a P_CONTROL packet, but a message that travels over the TLS
// encrypted channel.
//
// If the session has no active key, an empty (non-nil) byte slice is returned
// together with the error from ActiveKey.
func (c *control) ControlMessage(s *session, opt *Options) ([]byte, error) {
	active, err := s.ActiveKey()
	if err != nil {
		return []byte{}, err
	}
	localSource := active.local
	return encodeClientControlMessageAsBytes(localSource, opt)
}
// ReadControlMessage interprets b as a server control message carrying
// authentication result data. It returns the remote key source, the remote
// options string, and an error if the data cannot be parsed.
func (c *control) ReadControlMessage(b []byte) (*keySource, string, error) {
	return parseServerControlMessage(newServerControlMessageFromBytes(b))
}
// SendACK builds an ACK control packet for the given packetID, and writes it
// over the passed connection. It returns an error if the operation cannot be
// completed successfully.
// The work is delegated to the package-level sendACKFn variable, which is
// initialised to sendACK.
func (c *control) SendACK(conn net.Conn, s *session, pid packetID) error {
return sendACKFn(conn, s, pid)
}
// sendACK builds an ACK packet for pid, writes it to conn (size-framed when
// the transport requires it) and then records pid as the last acknowledged
// packet ID of the session. It is used by controlHandler.SendACK() and by
// TLSConn.Read().
//
// It returns an error wrapping errBadInput when s is nil or when the remote
// session ID has not been set yet, the error from conn.Write when the write
// fails, and otherwise whatever s.UpdateLastACK returns.
func sendACK(conn net.Conn, s *session, pid packetID) error {
	if s == nil {
		return fmt.Errorf("%w:%s", errBadInput, "nil session")
	}
	// sessionID is a fixed-size array, so len() of it is a constant and can
	// never be zero. Compare against the zero value instead, to detect a
	// remote session ID that was never filled in.
	if s.RemoteSessionID == (sessionID{}) {
		return fmt.Errorf("%w:%s", errBadInput, "tried to ack with null remote")
	}

	payload := maybeAddSizeFrame(conn, newACKPacket(pid, s).Bytes())
	if _, err := conn.Write(payload); err != nil {
		return err
	}
	logger.Debug(fmt.Sprintln("write ack:", pid))
	logger.Debug(fmt.Sprintln(hex.Dump(payload)))

	// Only a successfully written ACK updates the session bookkeeping.
	return s.UpdateLastACK(pid)
}
// sendACKFn is the function control.SendACK calls to do the actual work; it
// defaults to sendACK. Keeping it in a variable presumably lets tests swap in
// a replacement.
var sendACKFn = sendACK
// Compile-time assertion that *control satisfies controlHandler.
var _ controlHandler = &control{} // Ensure that we implement controlHandler
// sendControlPacket crafts a control packet with the given opcode and payload
// and writes it to conn. The packet is stamped with the session's local
// session ID and the next local packet ID, then size-framed when the transport
// requires it.
//
// It returns the number of bytes written by conn.Write. A nil session yields
// an error wrapping errBadInput; a failure to obtain a packet ID is returned
// as is. In both cases nothing is written.
// NOTE(review): the ack parameter is not used by this function.
func sendControlPacket(conn net.Conn, s *session, opcode int, ack int, payload []byte) (n int, err error) {
	if s == nil {
		return 0, fmt.Errorf("%w:%s", errBadInput, "nil session")
	}

	pkt := newPacketFromPayload(uint8(opcode), 0, payload)
	pkt.localSessionID = s.LocalSessionID
	if pkt.id, err = s.LocalPacketID(); err != nil {
		return 0, err
	}

	wire := maybeAddSizeFrame(conn, pkt.Bytes())
	logger.Debug(fmt.Sprintf("control write: (%d bytes)\n", len(wire)))
	logger.Debug(fmt.Sprintln(hex.Dump(wire)))
	return conn.Write(wire)
}
// isControlMessage reports whether the first four bytes of b equal
// controlMessageHeader. Payloads shorter than four bytes never match.
func isControlMessage(b []byte) bool {
	const headerLen = 4
	return len(b) >= headerLen && bytes.Equal(b[:headerLen], controlMessageHeader)
}
// maybeAddSizeFrame returns the bytes to put on the wire for payload,
// depending on the network reported by conn.LocalAddr():
//
//   - "udp", "udp4", "udp6": payload is returned untouched;
//   - "tcp", "tcp4", "tcp6": payload is returned prefixed with its length as a
//     two-byte big-endian integer;
//   - anything else: an empty slice is returned.
//
// NOTE(review): for an unrecognised network the payload is silently replaced
// by an empty slice, and lengths above 65535 are truncated by the uint16
// conversion; confirm both are acceptable to callers.
func maybeAddSizeFrame(conn net.Conn, payload []byte) []byte {
	switch conn.LocalAddr().Network() {
	case "udp", "udp4", "udp6":
		// Datagram transport: no framing needed.
		return payload
	case "tcp", "tcp4", "tcp6":
		var header [2]byte
		binary.BigEndian.PutUint16(header[:], uint16(len(payload)))
		return append(header[:], payload...)
	}
	return []byte{}
}
// isBadAuthReply reports whether b starts with the serverBadAuth marker,
// i.e. whether the payload is a "bad auth" response from the server.
func isBadAuthReply(b []byte) bool {
	return bytes.HasPrefix(b, serverBadAuth)
}
// isPushReply reports whether b starts with the serverPushReply marker,
// i.e. whether the payload is a "push reply" response from the server.
func isPushReply(b []byte) bool {
	return bytes.HasPrefix(b, serverPushReply)
}