/
tlssniproxy.go
130 lines (107 loc) · 3.46 KB
/
tlssniproxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package testingx
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"github.com/ooni/netem"
"github.com/ooni/probe-engine/pkg/logx"
"github.com/ooni/probe-engine/pkg/model"
"github.com/ooni/probe-engine/pkg/runtimex"
)
// TLSSNIProxyNetx is the view of [*netxlite.Netx] that [TLSSNIProxy] needs:
// a way to listen for TCP connections, a way to build a resolver, and a way
// to build a dialer on top of such a resolver.
type TLSSNIProxyNetx interface {
	// ListenTCP creates a listener for the given network and address.
	ListenTCP(network string, addr *net.TCPAddr) (net.Listener, error)

	// NewStdlibResolver creates the resolver the proxy passes to NewDialerWithResolver.
	NewStdlibResolver(logger model.DebugLogger) model.Resolver

	// NewDialerWithResolver creates a dialer that resolves names using r.
	NewDialerWithResolver(dl model.DebugLogger, r model.Resolver, w ...model.DialerWrapper) model.Dialer
}
// TLSSNIProxy is a TCP proxy that reads the TLS ClientHello sent by each
// client, extracts the SNI, and connects to that name on port 443.
type TLSSNIProxy struct {
	// closeOnce ensures the body of Close runs at most once.
	closeOnce sync.Once

	// listener accepts the client connections.
	listener net.Listener

	// logger receives the proxy's log messages.
	logger model.Logger

	// netx provides listening, resolving, and dialing.
	netx TLSSNIProxyNetx

	// wg tracks the background accept loop so Close can wait for it.
	wg *sync.WaitGroup
}
// MustNewTLSSNIProxyEx creates a [*TLSSNIProxy] listening on tcpAddr through
// netx and starts its accept loop in a background goroutine. Listener
// creation errors are turned into a panic by runtimex.Try1. Log lines are
// prefixed with a padded "TLSPROXY" tag.
func MustNewTLSSNIProxyEx(
	logger model.Logger, netx TLSSNIProxyNetx, tcpAddr *net.TCPAddr) *TLSSNIProxy {
	ln := runtimex.Try1(netx.ListenTCP("tcp", tcpAddr))
	prefixed := &logx.PrefixLogger{
		Prefix: fmt.Sprintf("%-16s", "TLSPROXY"),
		Logger: logger,
	}
	loopGroup := &sync.WaitGroup{}

	// closeOnce is intentionally left at its (usable) zero value.
	proxy := &TLSSNIProxy{
		listener: ln,
		logger:   prefixed,
		netx:     netx,
		wg:       loopGroup,
	}

	// account for mainloop before it starts so Close can always wait on it
	loopGroup.Add(1)
	go proxy.mainloop()

	return proxy
}
// Close implements io.Closer. The first call closes the listener, waits for
// the accept loop to return, and reports the listener's Close error; later
// calls do nothing and return nil.
func (tp *TLSSNIProxy) Close() error {
	var closeErr error
	tp.closeOnce.Do(func() {
		closeErr = tp.listener.Close()
		// closing the listener makes Accept fail, which ends mainloop
		tp.wg.Wait()
	})
	return closeErr
}
// Endpoint returns the string form of the listener's address (for a TCP
// listener this is typically "host:port"), which clients can dial to reach
// the proxy.
//
// It does not check whether Close has been called: it always queries the
// underlying listener and never returns a sentinel to signal closure.
func (tp *TLSSNIProxy) Endpoint() string {
	return tp.listener.Addr().String()
}
// mainloop accepts client connections until the listener is closed and
// serves each one in its own goroutine. It marks tp.wg done on exit so that
// Close can wait for it.
func (tp *TLSSNIProxy) mainloop() {
	// panics must not crash the process; this defer runs last
	defer runtimex.CatchLogAndIgnorePanic(tp.logger, "TLSSNIProxy.mainloop")
	defer tp.wg.Done()

	for {
		clientConn, acceptErr := tp.listener.Accept()
		if acceptErr != nil {
			// a closed listener is the normal shutdown path
			if errors.Is(acceptErr, net.ErrClosed) {
				return
			}
			// any other failure panics: acceptable for code that only
			// supports testing, and it keeps the surface small
			runtimex.PanicOnError(acceptErr, "tp.listener.Accept() failed")
		}

		// one goroutine per connection is fine for a testing helper
		go tp.handle(clientConn)
	}
}
func (tp *TLSSNIProxy) handle(clientConn net.Conn) {
// make sure panics don't crash the process
defer runtimex.CatchLogAndIgnorePanic(tp.logger, "TLSSNIProxy.handle")
// make sure we close the client connection
defer clientConn.Close()
// read initial records
buffer := make([]byte, 1<<17)
count := runtimex.Try1(clientConn.Read(buffer))
rawRecords := buffer[:count]
// inspecty the raw records to find the SNI
sni := runtimex.Try1(netem.ExtractTLSServerName(rawRecords))
// connect to the remote host
tcpDialer := tp.netx.NewDialerWithResolver(tp.logger, tp.netx.NewStdlibResolver(tp.logger))
serverConn := runtimex.Try1(tcpDialer.DialContext(context.Background(), "tcp", net.JoinHostPort(sni, "443")))
defer serverConn.Close()
// forward the initial records to the server
_ = runtimex.Try1(serverConn.Write(rawRecords))
// route traffic between the conns
wg := &sync.WaitGroup{}
wg.Add(2)
go tp.forward(wg, clientConn, serverConn)
go tp.forward(wg, serverConn, clientConn)
wg.Wait()
}
// forward copies bytes from left to right until reading from left returns
// EOF or an error. It then signals end-of-stream to right and calls
// wg.Done.
func (tp *TLSSNIProxy) forward(wg *sync.WaitGroup, left, right net.Conn) {
	defer wg.Done()

	// The copy error is deliberately ignored: it only means one of the two
	// connections went away, which ends this direction either way.
	_, _ = io.Copy(right, left)

	// Propagate EOF to right. Otherwise the goroutine copying the opposite
	// direction may block forever reading from a peer that is itself waiting
	// for us to close. The caller's wg.Wait would then never return, leaking
	// both connections.
	//
	// Prefer a half-close, so data still flowing the other way is not cut
	// short. Fall back to a full close when right does not support it.
	if cw, ok := right.(interface{ CloseWrite() error }); ok {
		_ = cw.CloseWrite()
		return
	}
	_ = right.Close()
}