-
Notifications
You must be signed in to change notification settings - Fork 14
/
rdial.go
266 lines (237 loc) · 7.99 KB
/
rdial.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
// Copyright (c) 2023 RethinkDNS and its authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dialers
import (
"crypto/tls"
"errors"
"net"
"net/netip"
"strconv"
"time"
"github.com/celzero/firestack/intra/log"
"github.com/celzero/firestack/intra/protect"
"github.com/celzero/firestack/intra/settings"
)
// connectFunc dials ip:port over network using the given dialer; commondial
// invokes it once per candidate ip until one succeeds.
type connectFunc func(d *protect.RDial, network string, ip netip.Addr, port int) (net.Conn, error)

// dialRetryTimeout bounds how long commondial keeps moving on to further
// candidate ips for a single address.
const dialRetryTimeout = time.Minute
// maybeFilter drops invalid ips and alwaysExclude from ips, then partitions the
// rest by whether their address family matches the package-level ipProto
// setting. If no matching ip remains, all the mismatched ones are returned
// (fail open). Otherwise the matching ips are returned, plus at most one
// mismatched ip appended last. That covers the case where, say, ipProto is
// IP4 but the underlying network is IP6, so a filtered-out ip may still work.
func maybeFilter(ips []netip.Addr, alwaysExclude netip.Addr) []netip.Addr {
	preferred := make([]netip.Addr, 0, len(ips))
	mismatched := make([]netip.Addr, 0, len(ips))
	for _, candidate := range ips {
		wrongFamily := (candidate.Is4() && ipProto == settings.IP6) ||
			(candidate.Is6() && ipProto == settings.IP4)
		switch {
		case !candidate.IsValid() || candidate.Compare(alwaysExclude) == 0:
			// skipped entirely
		case wrongFamily:
			mismatched = append(mismatched, candidate)
		default:
			preferred = append(preferred, candidate)
		}
	}

	if len(preferred) == 0 {
		return mismatched
	}
	if len(mismatched) != 0 {
		preferred = append(preferred, mismatched[0])
	}
	return preferred
}
// ipConnect dials into ip:port using the provided dialer and returns a net.Conn
// net.Conn is guaranteed to be either net.UDPConn or net.TCPConn
func ipConnect(d *protect.RDial, proto string, ip netip.Addr, port int) (net.Conn, error) {
if d == nil {
log.E("rdial: ipConnect: nil dialer")
return nil, errNoDialer
} else if !ipok(ip) {
log.E("rdial: ipConnect: invalid ip", ip)
return nil, errNoIps
}
switch proto {
case "tcp", "tcp4", "tcp6":
return d.DialTCP(proto, nil, tcpaddr(ip, port))
case "udp", "udp4", "udp6":
return d.DialUDP(proto, nil, udpaddr(ip, port))
default:
return d.Dial(proto, addr(ip, port))
}
}
// ipConnect2 dials ip:port over proto using d.Dial for every network. The
// concrete type of the returned net.Conn is whatever d.Dial produces, so it is
// not guaranteed to be a *net.TCPConn or *net.UDPConn.
// It returns errNoDialer if d is nil and errNoIps if ip is not usable (see ipok).
func ipConnect2(d *protect.RDial, proto string, ip netip.Addr, port int) (net.Conn, error) {
	if d == nil {
		log.E("rdial: ipConnect2: nil dialer")
		return nil, errNoDialer
	} else if !ipok(ip) {
		// the log helpers are printf-style (see the %s uses in commondial)
		log.E("rdial: ipConnect2: invalid ip %s", ip)
		return nil, errNoIps
	}
	return d.Dial(proto, addr(ip, port))
}
// doSplit reports whether a TCP connection to port should have its TLS
// ClientHello split. Only 443 (HTTPS) and 853 (DNS over TLS) qualify.
func doSplit(port int) bool {
	switch port {
	case 443, 853:
		return true
	default:
		return false
	}
}
func splitIpConnect(d *protect.RDial, proto string, ip netip.Addr, port int) (net.Conn, error) {
if d == nil {
log.E("rdial: splitIpConnect: nil dialer")
return nil, errNoDialer
} else if !ipok(ip) {
log.E("rdial: splitIpConnect: invalid ip", ip)
return nil, errNoIps
}
switch proto {
case "tcp", "tcp4", "tcp6":
if doSplit(port) { // split tls client-hello for https requests
return DialWithSplitRetry(d, tcpaddr(ip, port))
}
return d.DialTCP(proto, nil, tcpaddr(ip, port))
case "udp", "udp4", "udp6":
return d.DialUDP(proto, nil, udpaddr(ip, port))
default:
return d.Dial(proto, addr(ip, port))
}
}
func splitIpConnect2(d *protect.RDial, proto string, ip netip.Addr, port int) (net.Conn, error) {
if d == nil {
log.E("rdial: splitIpConnect2: nil dialer")
return nil, errNoDialer
} else if !ipok(ip) {
log.E("rdial: splitIpConnect2: invalid ip", ip)
return nil, errNoIps
}
switch proto {
case "tcp", "tcp4", "tcp6":
if doSplit(port) { // split tls client-hello for https requests
return DialWithSplit(d, tcpaddr(ip, port))
}
return d.DialTCP(proto, nil, tcpaddr(ip, port))
case "udp", "udp4", "udp6":
return d.DialUDP(proto, nil, udpaddr(ip, port))
default:
return d.Dial(proto, addr(ip, port))
}
}
// commondial dials addr ("host:port") over network using d and connect.
// Candidate ips for host come from ipm:
//   - The previously confirmed ip, if usable, is tried first. If it fails, it
//     is disconfirmed.
//   - The remaining ips, as chosen by maybeFilter, are then tried in order.
//     They are renewed once if none are left.
//   - No further ip is tried once dialRetryTimeout has elapsed since the call
//     began.
//
// The first ip that connects is confirmed and its conn returned. When no conn
// is returned, the error is always non-nil:
//   - errNoIps if host has no ips, or none of its ips was actually dialed;
//   - otherwise, all dial errors joined.
func commondial(d *protect.RDial, network, addr string, connect connectFunc) (net.Conn, error) {
	start := time.Now()

	log.D("rdial: commondial: dialing (host:port) %s", addr)
	domain, portstr, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}
	// cannot dial into a wildcard address
	// while, listen is unsupported
	if len(domain) == 0 {
		return nil, net.InvalidAddrError(addr)
	}
	port, err := strconv.Atoi(portstr)
	if err != nil {
		return nil, err
	}

	var conn net.Conn
	var errs error
	ips := ipm.Get(domain)
	confirmed := ips.Confirmed() // may be zeroaddr
	if ipok(confirmed) {
		log.V("rdial: commondial: dialing confirmed ip %s for %s", confirmed, addr)
		if conn, err = connect(d, network, confirmed, port); err == nil {
			log.V("rdial: commondial: ip %s works for %s", confirmed, addr)
			return conn, nil
		}
		errs = errors.Join(errs, err)
		ips.Disconfirm(confirmed)
		log.D("rdial: commondial: confirmed ip %s for %s failed with err %v", confirmed, addr, err)
	}

	ipset := ips.Addrs()
	allips := maybeFilter(ipset, confirmed)
	if len(allips) == 0 {
		var ok bool
		if ips, ok = renew(domain, ips); ok {
			ipset = ips.Addrs()
			allips = maybeFilter(ipset, confirmed)
		}
		log.D("rdial: renew ips for %s; ok? %t", addr, ok)
	}

	log.D("rdial: commondial: trying all ips %d for %s", len(allips), addr)
	for _, ip := range allips {
		if end := time.Since(start); end > dialRetryTimeout {
			log.D("rdial: commondial: timeout %s for %s", end, addr)
			break
		}
		if !ipok(ip) {
			log.W("rdial: commondial: ip %s not ok for %s", ip, addr)
			continue
		}
		log.V("rdial: commondial: dialing ip %s for %s", ip, addr)
		if conn, err = connect(d, network, ip, port); err == nil {
			confirm(ips, ip)
			log.I("rdial: commondial: ip %s works for %s", ip, addr)
			return conn, nil
		}
		errs = errors.Join(errs, err)
		log.W("rdial: commondial: ip %s for %s failed with err %v", ip, addr, err)
	}

	dur := time.Since(start)
	log.D("rdial: commondial: duration: %s; failed %s", dur, addr)
	// errs is still nil if no ip was actually dialed. For example, every
	// candidate was unspecified (0.0.0.0, ::) and skipped by ipok, or the
	// timeout hit before the first attempt. Never return (nil, nil): callers
	// would then go on to use a nil conn.
	if len(ipset) == 0 || errs == nil {
		errs = errNoIps
	}
	return nil, errs
}
// ListenPacket listens for UDP packets on the local address using d, via
// d.AnnounceUDP. It returns errNoListener if d is nil.
// On error the returned net.PacketConn is always an untyped nil.
func ListenPacket(d *protect.RDial, network, local string) (net.PacketConn, error) {
	if d == nil {
		log.E("rdial: ListenPacket: nil dialer")
		return nil, errNoListener
	}
	pc, err := d.AnnounceUDP(network, local)
	if err != nil {
		// NOTE(review): AnnounceUDP presumably returns a *net.UDPConn. A nil
		// pointer stored in net.PacketConn would be a non-nil interface, so
		// return a literal nil instead.
		return nil, err
	}
	return pc, nil
}
// Listen listens for TCP connections on the local address using d, via
// d.AcceptTCP. It returns errNoListener if d is nil.
// On error the returned net.Listener is always an untyped nil.
func Listen(d *protect.RDial, network, local string) (net.Listener, error) {
	if d == nil {
		log.E("rdial: Listen: nil dialer")
		return nil, errNoListener
	}
	ln, err := d.AcceptTCP(network, local)
	if err != nil {
		// NOTE(review): AcceptTCP presumably returns a concrete listener
		// pointer. A nil pointer stored in net.Listener would be a non-nil
		// interface, so return a literal nil instead.
		return nil, err
	}
	return ln, nil
}
// Dial connects to addr ("host:port") over network with d, trying the host's
// known ips in turn (see commondial). Each ip is dialed by ipConnect, which
// uses d.DialTCP for tcp networks and d.DialUDP for udp networks.
func Dial(d *protect.RDial, network, addr string) (net.Conn, error) {
	conn, err := commondial(d, network, addr, ipConnect)
	return conn, err
}
// Dial2 is like Dial, but every ip is dialed through d.Dial (via ipConnect2).
// The concrete type of the returned net.Conn is whatever d.Dial produces.
func Dial2(d *protect.RDial, network, addr string) (net.Conn, error) {
	conn, err := commondial(d, network, addr, ipConnect2)
	return conn, err
}
// SplitDial is like Dial, except that TCP connections to ports 443 and 853 are
// made with DialWithSplitRetry. That function splits the TLS ClientHello only
// if the first connection attempt is unsuccessful. The returned net.Conn may
// therefore be neither a net.TCPConn nor a net.UDPConn.
func SplitDial(d *protect.RDial, network, addr string) (net.Conn, error) {
	conn, err := commondial(d, network, addr, splitIpConnect)
	return conn, err
}
// SplitDial2 is like SplitDial, but always splits the TLS ClientHello (via
// DialWithSplit) on TCP connections to ports 443 and 853, not just on retry.
func SplitDial2(d *protect.RDial, network, addr string) (net.Conn, error) {
	conn, err := commondial(d, network, addr, splitIpConnect2)
	return conn, err
}
// SplitDialWithTls dials addr ("host:port") over "tcp" using d (as SplitDial
// does) and then completes a TLS client handshake with cfg.
// On success it returns the *tls.Conn. On any error the returned net.Conn is
// nil, and any connection that was established is closed. Callers therefore
// have nothing to clean up.
func SplitDialWithTls(d *protect.RDial, cfg *tls.Config, addr string) (net.Conn, error) {
	c, err := commondial(d, "tcp", addr, splitIpConnect)
	if err != nil {
		return nil, err
	}
	tlsconn := tls.Client(c, cfg)
	if err := tlsconn.Handshake(); err != nil {
		// The caller receives no conn on error and so cannot close it. Release
		// the underlying connection here; a close error is secondary to err.
		_ = c.Close()
		return nil, err
	}
	return tlsconn, nil
}
// ipok reports whether ip can be dialed. It must be a valid address, and not
// the unspecified address (0.0.0.0 or ::).
func ipok(ip netip.Addr) bool {
	if !ip.IsValid() {
		return false
	}
	return !ip.IsUnspecified()
}