/
tlsdial.go
124 lines (111 loc) · 3.09 KB
/
tlsdial.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
// Copyright (c) 2023 RethinkDNS and its authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dialers
import (
"crypto/tls"
"errors"
"net"
"net/netip"
"strconv"
"time"
"github.com/celzero/firestack/intra/log"
)
// tlsConnectFunc establishes a TLS connection to a single ip:port.
// Arguments, in order: the TLS dialer to use, the network (proto),
// the server name (sni), the ip to connect to, and the port.
// tlsConnect is the implementation TlsDial passes to tlsdial.
type tlsConnectFunc func(*tls.Dialer, string, string, netip.Addr, int) (net.Conn, error)
// tlsConnect dials ip:port over proto using d and returns the resulting
// connection.
//
// sni is used as the TLS ServerName only when d.Config does not already
// carry one; a ServerName set by the caller always wins. The caller's
// dialer and its Config are never modified: the dial is done with a
// per-call copy, so a dialer shared across hosts (or goroutines) does
// not end up pinned to the sni of whichever call happened to run first.
//
// Returns errNoDialer if d is nil and errNoIps if ip fails ipok;
// otherwise returns whatever the underlying Dial returns.
func tlsConnect(d *tls.Dialer, proto, sni string, ip netip.Addr, port int) (net.Conn, error) {
	if d == nil {
		log.E("tlsdial: tlsConnect: nil dialer")
		return nil, errNoDialer
	}
	if !ipok(ip) {
		log.E("tlsdial: tlsConnect: invalid ip %s", ip)
		return nil, errNoIps
	}

	cfg := d.Config
	switch {
	case cfg == nil:
		cfg = &tls.Config{ServerName: sni}
	case len(cfg.ServerName) == 0:
		// tls.Config must not be copied by value; Clone is the supported way.
		cfg = cfg.Clone()
		cfg.ServerName = sni
	}

	// every proto is handed to Dial as-is; unsupported networks are
	// rejected by the dialer itself.
	dialer := &tls.Dialer{NetDialer: d.NetDialer, Config: cfg}
	return dialer.Dial(proto, addr(ip, port))
}
// tlsdial connects to addr ("host:port") over network by trying the ips
// that ipm holds for the host, one at a time, through connect.
//
// Order of attempts:
//  1. the confirmed ip, if it passes ipok; on failure it is disconfirmed.
//  2. the remaining ips (ips.Addrs() passed through filter with the
//     confirmed ip); if that list is empty, renew is asked for a new set.
//
// The first ip in step 2 that connects is confirmed and its connection
// is returned. Step 2 stops early once dialRetryTimeout has elapsed since
// the call started.
//
// On failure the returned error is never nil: it is errNoIps when there
// were no ips at all or when no attempt could be made (every ip failed
// ipok, or the retry timeout had already elapsed); otherwise it is the
// errors of all failed attempts joined together.
func tlsdial(d *tls.Dialer, network, addr string, connect tlsConnectFunc) (net.Conn, error) {
	start := time.Now()

	log.D("tlsdial: dialing %s", addr)
	domain, portstr, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}
	// an empty host is a wildcard address; that is only meaningful for
	// listening (which is unsupported here) and cannot be dialed.
	if len(domain) == 0 {
		return nil, net.InvalidAddrError(addr)
	}
	port, err := strconv.Atoi(portstr)
	if err != nil {
		return nil, err
	}

	var errs error
	ips := ipm.Get(domain)
	confirmed := ips.Confirmed()
	if ipok(confirmed) {
		log.V("tlsdial: confirmed ip %s for %s", confirmed, addr)
		conn, cerr := connect(d, network, domain, confirmed, port)
		if cerr == nil {
			log.V("tlsdial: found working ip %s for %s", confirmed, addr)
			return conn, nil
		}
		errs = errors.Join(errs, cerr)
		ips.Disconfirm(confirmed)
		log.D("tlsdial: confirmed ip %s for %s failed with err %v", confirmed, addr, cerr)
	}

	ipset := ips.Addrs()
	// NOTE(review): filter presumably drops the confirmed ip that was
	// tried above; confirm against its definition.
	allips := filter(ipset, confirmed)
	if len(allips) <= 0 {
		var ok bool
		if ips, ok = renew(domain, ips); ok {
			ipset = ips.Addrs()
			allips = filter(ipset, confirmed)
		}
		log.D("tlsdial: renew ips for %s; ok? %t", addr, ok)
	}

	log.D("tlsdial: trying all ips %d for %s", len(allips), addr)
	for _, ip := range allips {
		elapsed := time.Since(start)
		if elapsed > dialRetryTimeout {
			log.D("tlsdial: timeout %s for %s", elapsed, addr)
			break
		}
		if !ipok(ip) {
			log.D("tlsdial: ip %s for %s is not ok", ip, addr)
			continue
		}
		log.V("tlsdial: trying ip %s for %s", ip, addr)
		conn, cerr := connect(d, network, domain, ip, port)
		if cerr == nil {
			ips.Confirm(ip)
			log.I("tlsdial: found working ip %s for %s", ip, addr)
			return conn, nil
		}
		errs = errors.Join(errs, cerr)
		log.W("tlsdial: ip %s for %s failed with err %v", ip, addr, cerr)
	}

	dur := time.Since(start)
	log.D("tlsdial: duration: %s; failed %s", dur, addr)

	// errs is still nil when not a single connect was attempted (all ips
	// invalid, or the retry timeout elapsed before the first attempt);
	// a nil conn must never be returned alongside a nil error.
	if len(ipset) <= 0 || errs == nil {
		errs = errNoIps
	}
	return nil, errs
}
// TlsDial dials addr ("host:port") over network using the TLS dialer d.
// It delegates to tlsdial, with tlsConnect as the per-ip connect function.
func TlsDial(d *tls.Dialer, network, addr string) (net.Conn, error) {
return tlsdial(d, network, addr, tlsConnect)
}