/
protocol.go
256 lines (221 loc) · 6.64 KB
/
protocol.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
package connection
import (
"fmt"
"hash/fnv"
"sync"
"time"
"github.com/rs/zerolog"
"github.com/khulnasoft/netscale/edgediscovery"
)
const (
	// AvailableProtocolFlagMessage describes the accepted --protocol values. It is
	// included in the error returned for an unrecognized protocol flag.
	AvailableProtocolFlagMessage = "Available protocols: 'auto' - automatically chooses the best protocol over time (the default; and also the recommended one); 'quic' - based on QUIC, relying on UDP egress to Khulnasoft edge; 'http2' - using Go's HTTP2 library, relying on TCP egress to Khulnasoft edge"
	// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge.
	// NOTE(review): h2mux requests are upgraded to http2 in NewProtocolSelector and
	// nothing in this file reads this constant; confirm whether other files still need it.
	edgeH2muxTLSServerName = "cftunnel.com"
	// edgeH2TLSServerName is the server name to establish http2 connection with edge.
	edgeH2TLSServerName = "h2.cftunnel.com"
	// edgeQUICServerName is the server name to establish quic connection with edge.
	edgeQUICServerName = "quic.cftunnel.com"
	// AutoSelectFlag is the --protocol value that requests automatic protocol selection.
	AutoSelectFlag = "auto"
	// ResolveTTL is the TTL for SRV and TXT record resolution.
	ResolveTTL = time.Hour
)
// ProtocolList represents a list of supported protocols for communication with the edge
// in order of precedence for remote percentage fetcher.
var ProtocolList = []Protocol{QUIC, HTTP2}

// Protocol identifies the transport used for edge connections.
type Protocol int64

const (
	// HTTP2 using golang HTTP2 library for edge connections.
	HTTP2 Protocol = iota
	// QUIC using quic-go for edge connections.
	QUIC
)

// fallback returns the protocol to use instead of p and whether p has one.
// Only QUIC has a fallback (HTTP2); HTTP2 and unknown values report (0, false).
func (p Protocol) fallback() (Protocol, bool) {
	switch p {
	case QUIC:
		return HTTP2, true
	default:
		return 0, false
	}
}

// String returns the protocol's name as accepted by the --protocol flag and used
// for remote percentage lookups, or "unknown protocol" for unrecognized values.
func (p Protocol) String() string {
	switch p {
	case HTTP2:
		return "http2"
	case QUIC:
		return "quic"
	default:
		// A plain literal: there is nothing to format here.
		return "unknown protocol"
	}
}
// TLSSettings returns the TLS parameters for connecting to the edge over p, or
// nil when p is not a recognized protocol. Each call allocates a new value.
func (p Protocol) TLSSettings() *TLSSettings {
	var settings *TLSSettings
	switch p {
	case QUIC:
		settings = &TLSSettings{
			ServerName: edgeQUICServerName,
			NextProtos: []string{"argotunnel"},
		}
	case HTTP2:
		settings = &TLSSettings{ServerName: edgeH2TLSServerName}
	}
	return settings
}

// TLSSettings groups the per-protocol TLS parameters for edge connections.
type TLSSettings struct {
	// ServerName is the edge server name for the protocol.
	ServerName string
	// NextProtos is only set for QUIC in this file; presumably consumed as the
	// TLS NextProtos/ALPN list by the dialer. TODO confirm at the call site.
	NextProtos []string
}
// ProtocolSelector chooses the protocol used for edge connections.
type ProtocolSelector interface {
	// Current returns the protocol to use now.
	Current() Protocol
	// Fallback returns an alternative to the current protocol and whether one
	// exists. The Protocol value is only meaningful when the bool is true.
	Fallback() (Protocol, bool)
}

// staticProtocolSelector will not provide a different protocol for Fallback:
// it always reports the protocol it was created with.
type staticProtocolSelector struct {
	current Protocol
}

// Compile-time check that staticProtocolSelector satisfies ProtocolSelector.
var _ ProtocolSelector = (*staticProtocolSelector)(nil)

// Current returns the fixed protocol.
func (s *staticProtocolSelector) Current() Protocol {
	return s.current
}

// Fallback returns the fixed protocol together with false: there is no fallback.
func (s *staticProtocolSelector) Fallback() (Protocol, bool) {
	return s.current, false
}
// remoteProtocolSelector chooses the protocol from remotely fetched percentages
// and re-evaluates that choice once its TTL has elapsed.
type remoteProtocolSelector struct {
	lock sync.RWMutex

	current Protocol

	// protocolPool is desired protocols in the order of priority they should be picked in.
	protocolPool []Protocol

	switchThreshold int32
	fetchFunc       edgediscovery.PercentageFetcher

	// refreshAfter is the earliest time Current re-runs the lookup; ttl is the
	// interval added after each successful lookup.
	refreshAfter time.Time
	ttl          time.Duration

	log *zerolog.Logger
}

// newRemoteProtocolSelector returns a selector that starts on current and first
// re-runs the percentage lookup once ttl has passed.
func newRemoteProtocolSelector(
	current Protocol,
	protocolPool []Protocol,
	switchThreshold int32,
	fetchFunc edgediscovery.PercentageFetcher,
	ttl time.Duration,
	log *zerolog.Logger,
) *remoteProtocolSelector {
	selector := &remoteProtocolSelector{
		current:         current,
		protocolPool:    protocolPool,
		switchThreshold: switchThreshold,
		fetchFunc:       fetchFunc,
		ttl:             ttl,
		log:             log,
	}
	selector.refreshAfter = time.Now().Add(ttl)
	return selector
}

// Current returns the selected protocol. When the TTL has elapsed it re-runs
// the lookup under the write lock. On success the new choice is cached for
// another ttl. On failure the error is logged and the previous choice is kept;
// refreshAfter is not advanced, so the next call tries the lookup again.
func (s *remoteProtocolSelector) Current() Protocol {
	s.lock.Lock()
	defer s.lock.Unlock()

	if !time.Now().Before(s.refreshAfter) {
		refreshed, err := getProtocol(s.protocolPool, s.fetchFunc, s.switchThreshold)
		if err == nil {
			s.current = refreshed
			s.refreshAfter = time.Now().Add(s.ttl)
		} else {
			s.log.Err(err).Msg("Failed to refresh protocol")
		}
	}
	return s.current
}

// Fallback reports the fallback of the currently selected protocol, if any.
func (s *remoteProtocolSelector) Fallback() (Protocol, bool) {
	s.lock.RLock()
	defer s.lock.RUnlock()
	selected := s.current
	return selected.fallback()
}
// getProtocol calls fetchFunc for the current protocol percentages and returns
// the first protocol in protocolPool whose percentage exceeds switchThreshold.
// If none does, the first entry of protocolPool is returned.
//
// It returns an error when protocolPool is empty or fetchFunc fails. In that
// case the Protocol result is the zero value and should not be relied on.
func getProtocol(protocolPool []Protocol, fetchFunc edgediscovery.PercentageFetcher, switchThreshold int32) (Protocol, error) {
	if len(protocolPool) == 0 {
		// Without this guard the default below would panic on protocolPool[0].
		return 0, fmt.Errorf("no protocols to choose from")
	}
	protocolPercentages, err := fetchFunc()
	if err != nil {
		return 0, fmt.Errorf("fetching protocol percentages: %w", err)
	}
	for _, protocol := range protocolPool {
		if protocolPercentages.GetPercentage(protocol.String()) > switchThreshold {
			return protocol, nil
		}
	}
	// Default to first index in protocolPool list
	return protocolPool[0], nil
}
// defaultProtocolSelector will allow for a protocol to have a fallback: it
// always reports the protocol it was created with but, unlike
// staticProtocolSelector, exposes that protocol's fallback when one exists.
type defaultProtocolSelector struct {
	lock    sync.RWMutex
	current Protocol
}

// newDefaultProtocolSelector returns a selector that starts on current.
func newDefaultProtocolSelector(
	current Protocol,
) *defaultProtocolSelector {
	return &defaultProtocolSelector{
		current: current,
	}
}

// Current returns the selected protocol. It only reads state, so the shared
// read lock is sufficient, the same as in Fallback.
func (s *defaultProtocolSelector) Current() Protocol {
	s.lock.RLock()
	defer s.lock.RUnlock()
	return s.current
}

// Fallback returns the fallback of the selected protocol and whether one exists.
func (s *defaultProtocolSelector) Fallback() (Protocol, bool) {
	s.lock.RLock()
	defer s.lock.RUnlock()
	return s.current.fallback()
}
// NewProtocolSelector builds the ProtocolSelector for a --protocol flag value.
//
// Selection rules:
//   - needPQ (--post-quantum) always yields a static QUIC selector.
//   - "quic" and "http2" pin that protocol with no fallback.
//   - "h2mux" is no longer supported; it logs a warning and pins HTTP2.
//   - "auto" with a tunnel token starts on QUIC and may fall back to HTTP2.
//   - "auto" without a token runs a percentage lookup through protocolFetcher,
//     using a threshold derived from accountTag. The result goes to a remote
//     selector that repeats the lookup after resolveTTL.
//
// Any other value returns an error that includes AvailableProtocolFlagMessage.
func NewProtocolSelector(
	protocolFlag string,
	accountTag string,
	tunnelTokenProvided bool,
	needPQ bool,
	protocolFetcher edgediscovery.PercentageFetcher,
	resolveTTL time.Duration,
	log *zerolog.Logger,
) (ProtocolSelector, error) {
	// With --post-quantum, we force quic
	if needPQ {
		return &staticProtocolSelector{current: QUIC}, nil
	}

	// If the user picks a protocol, then we stick to it no matter what.
	switch protocolFlag {
	case "h2mux":
		// Any users still requesting h2mux will be upgraded to http2 instead
		log.Warn().Msg("h2mux is no longer a supported protocol: upgrading edge connection to http2. Please remove '--protocol h2mux' from runtime arguments to remove this warning.")
		return &staticProtocolSelector{current: HTTP2}, nil
	case QUIC.String():
		return &staticProtocolSelector{current: QUIC}, nil
	case HTTP2.String():
		return &staticProtocolSelector{current: HTTP2}, nil
	case AutoSelectFlag:
		// When a --token is provided, we want to start with QUIC but have fallback to HTTP2
		if tunnelTokenProvided {
			return newDefaultProtocolSelector(QUIC), nil
		}
		// Only the remote selector uses the percentages. The lookup is therefore
		// done here rather than for every flag value.
		threshold := switchThreshold(accountTag)
		fetchedProtocol, err := getProtocol(ProtocolList, protocolFetcher, threshold)
		if err != nil {
			// A failed lookup is not fatal: the remote selector retries it after
			// resolveTTL. Until then start on HTTP2, the zero Protocol that
			// getProtocol reports on error. It is set explicitly so this does
			// not depend on the error-path return value.
			log.Warn().Err(err).Msg("Unable to lookup protocol percentage.")
			fetchedProtocol = HTTP2
		} else {
			log.Debug().Msgf("Fetched protocol: %s", fetchedProtocol)
		}
		return newRemoteProtocolSelector(fetchedProtocol, ProtocolList, threshold, protocolFetcher, resolveTTL, log), nil
	}
	return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
}
// switchThreshold maps accountTag to a stable bucket in [0, 100) using the
// 32-bit FNV-1a hash, so the same account always gets the same threshold.
func switchThreshold(accountTag string) int32 {
	hasher := fnv.New32a()
	// hash.Hash writes never fail, so both results are intentionally discarded.
	_, _ = hasher.Write([]byte(accountTag))
	digest := hasher.Sum32()
	bucket := digest % 100
	return int32(bucket)
}