/
dialer.go
110 lines (96 loc) · 3.11 KB
/
dialer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
package chained
import (
"bufio"
"fmt"
"net"
"net/http"
"strings"
"github.com/getlantern/bandwidth"
"github.com/getlantern/errors"
)
// Config is a configuration for a Dialer.
type Config struct {
// DialServer: function that dials the upstream server proxy
DialServer func() (net.Conn, error)
// OnRequest: optional function that gets called on every CONNECT request to
// the server and is allowed to modify the http.Request before it passes to
// the server.
OnRequest func(req *http.Request)
// Label: a optional label for debugging.
Label string
}
// dialer provides an implementation of net.Dial that proxies traffic via an
// upstream server proxy. Its Dial function uses DialServer to dial the server
// proxy and then issues a CONNECT request to instruct the server to connect to
// the destination at the specified network and addr.
type dialer struct {
Config
}
// NewDialer returns an implementation of net.Dial() based on the given Config.
func NewDialer(cfg Config) func(network, addr string) (net.Conn, error) {
d := &dialer{Config: cfg}
return d.Dial
}
// Dial is a net.Dial-compatible function.
func (d *dialer) Dial(network, addr string) (net.Conn, error) {
conn, err := d.DialServer()
if err != nil {
return nil, errors.New("Unable to dial server %v: %s", d.Label, err)
}
// Look for our special hacked "connect" transport used to signal
// that we should send a CONNECT request and tunnel all traffic through
// that.
if network == "connect" {
log.Tracef("Sending CONNECT REQUEST")
if err := d.sendCONNECT("tcp", addr, conn); err != nil {
// We discard this error, since we are only interested in sendCONNECT
_ = conn.Close()
return nil, err
}
}
return conn, nil
}
func (d *dialer) sendCONNECT(network, addr string, conn net.Conn) error {
if !strings.Contains(network, "tcp") {
return fmt.Errorf("%s connections are not supported, only tcp is supported", network)
}
req, err := buildCONNECTRequest(addr, d.OnRequest)
if err != nil {
return fmt.Errorf("Unable to construct CONNECT request: %s", err)
}
err = req.Write(conn)
if err != nil {
return fmt.Errorf("Unable to write CONNECT request: %s", err)
}
r := bufio.NewReader(conn)
err = checkCONNECTResponse(r, req)
return err
}
func buildCONNECTRequest(addr string, onRequest func(req *http.Request)) (*http.Request, error) {
req, err := http.NewRequest(httpConnectMethod, addr, nil)
if err != nil {
return nil, err
}
req.Host = addr
if onRequest != nil {
onRequest(req)
}
return req, nil
}
func checkCONNECTResponse(r *bufio.Reader, req *http.Request) error {
resp, err := http.ReadResponse(r, req)
if err != nil {
return fmt.Errorf("Error reading CONNECT response: %s", err)
}
if !sameStatusCodeClass(http.StatusOK, resp.StatusCode) {
return fmt.Errorf("Bad status code on CONNECT response: %d", resp.StatusCode)
}
bandwidth.Track(resp)
return nil
}
func sameStatusCodeClass(statusCode1 int, statusCode2 int) bool {
// HTTP response status code "classes" come in ranges of 100.
const classRange = 100
// These are all integers, so division truncates.
return statusCode1/classRange == statusCode2/classRange
}