// httpproxy.go (103 lines, 86 loc, 2.31 KB)

package httpproxy
import (
"context"
"fmt"
"io"
"log"
"net"
"net/http"
"net/http/httputil"
)
// Proxy is an httputil.ReverseProxy that can also act as a forward proxy:
// its ServeHTTP tunnels CONNECT requests itself and hands every other request
// to the ReverseProxy implementation. Because Proxy and httputil.ReverseProxy
// share an underlying type, *Proxy and *httputil.ReverseProxy convert freely.
type (
	Proxy httputil.ReverseProxy
)
// IsProxyRequest reports whether req has the shape of a forward-proxy
// request: its URL is absolute (url.URL.IsAbs: it carries a scheme) or its
// method is CONNECT. The URL is inspected first, exactly as before.
func (p *Proxy) IsProxyRequest(req *http.Request) bool {
	if req.URL.IsAbs() {
		return true
	}
	return req.Method == http.MethodConnect
}
// ServeHTTP handles req as a proxy.
//
// Non-CONNECT requests are delegated to httputil.ReverseProxy. When Director
// is nil a no-op Director is used, so the request is forwarded unchanged.
//
// CONNECT requests are tunnelled: Director (if set) is applied, req.URL.Host
// is dialed via p.dial, the client connection is hijacked, a
// "200 Connection Established" line is written, and bytes are copied in both
// directions. The tunnel ends when either copy finishes or the request
// context is done.
func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	if req.Method != http.MethodConnect {
		rp := (*httputil.ReverseProxy)(p)
		if rp.Director == nil {
			// Use a per-request copy instead of assigning p.Director: many
			// handlers run concurrently, and writing the shared field is a
			// data race that also silently mutates the caller's config.
			// NOTE(review): on Go 1.20+ a proxy configured via
			// ReverseProxy.Rewrite still gets a Director here (as before),
			// which ReverseProxy rejects — confirm whether Rewrite matters.
			c := *rp
			c.Director = func(*http.Request) {}
			rp = &c
		}
		rp.ServeHTTP(rw, req)
		return
	}

	if p.Director != nil {
		p.Director(req)
	}
	hj, ok := rw.(http.Hijacker)
	if !ok {
		p.getErrorHandler()(rw, req, fmt.Errorf("can't %s using non-Hijacker ResponseWriter type %T", req.Method, rw))
		return
	}
	// TODO: check port
	target, err := p.dial(req.Context(), "tcp", req.URL.Host)
	if err != nil {
		p.getErrorHandler()(rw, req, fmt.Errorf("dial failed on %s: %w", req.Method, err))
		return
	}
	defer target.Close()
	source, brw, err := hj.Hijack()
	if err != nil {
		p.getErrorHandler()(rw, req, fmt.Errorf("hijack failed on %s: %w", req.Method, err))
		return
	}
	defer source.Close()

	// From here on rw belongs to a hijacked connection and can no longer
	// carry a response, so failures are logged rather than handed to the
	// error handler.
	if _, err := source.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
		p.logf("http: proxy error: write failed on %s: %v", req.Method, err)
		return
	}
	// Hijack's bufio.Reader may already hold bytes the client sent after the
	// CONNECT request head. Forward them first, or they would be lost when
	// the tunnel switches to reading the raw connection.
	if n := brw.Reader.Buffered(); n > 0 {
		buffered, _ := brw.Reader.Peek(n) // cannot fail: n bytes are already buffered
		if _, err := target.Write(buffered); err != nil {
			p.logf("http: proxy error: write failed on %s: %v", req.Method, err)
			return
		}
	}

	// Each copier receives from done once its io.Copy returns. The send below
	// completes when the first copier finishes (or the context ends first);
	// close(done) then releases the second copier, whose io.Copy returns once
	// the deferred Close calls tear down both connections.
	done := make(chan struct{})
	go copy(source, target, done)
	go copy(target, source, done)
	select {
	case <-req.Context().Done():
	case done <- struct{}{}:
	}
	close(done)
}
// copy streams src into dst until src is exhausted or either side fails.
// It then blocks until stop yields a value (a send or a close), which lets
// ServeHTTP detect that one tunnel direction has finished. The io.Copy result
// is discarded, as before. Presumably this is because tearing down the tunnel
// closes the connections and makes io.Copy fail anyway.
// NOTE(review): this name shadows the builtin copy throughout the package;
// renaming it also requires updating the callers in ServeHTTP.
func copy(dst io.Writer, src io.Reader, stop <-chan struct{}) {
	_, _ = io.Copy(dst, src)
	<-stop
}
// logf formats a message like fmt.Printf and writes it to p.ErrorLog when
// one is configured, otherwise to the standard library's default logger.
func (p *Proxy) logf(format string, args ...interface{}) {
	printf := log.Printf
	if p.ErrorLog != nil {
		printf = p.ErrorLog.Printf
	}
	printf(format, args...)
}
// defaultErrorHandler is used when ErrorHandler is nil. It logs err through
// logf and replies 502 Bad Gateway with no body. The request itself is not
// consulted.
func (p *Proxy) defaultErrorHandler(rw http.ResponseWriter, _ *http.Request, err error) {
	const status = http.StatusBadGateway
	p.logf("http: proxy error: %v", err)
	rw.WriteHeader(status)
}
// getErrorHandler returns the user-supplied ErrorHandler, falling back to
// defaultErrorHandler when none is set.
func (p *Proxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
	if p.ErrorHandler == nil {
		return p.defaultErrorHandler
	}
	return p.ErrorHandler
}
// zeroDialer is the last-resort dialer used by dial. As a zero net.Dialer it
// sets no dial timeout of its own, so only ctx (and any OS limit) bounds a
// dial through it.
var zeroDialer net.Dialer

// dial opens the upstream connection for a CONNECT tunnel. It selects the
// transport the same way ReverseProxy does for ordinary requests: p.Transport,
// or http.DefaultTransport when p.Transport is nil. If that transport is an
// *http.Transport with DialContext set, the dial goes through it so CONNECT
// and forwarded requests share dial settings. Otherwise zeroDialer is used.
func (p *Proxy) dial(ctx context.Context, network, addr string) (net.Conn, error) {
	transport := p.Transport
	if transport == nil {
		// Previously a nil Transport skipped straight to zeroDialer, while
		// ReverseProxy itself would have used http.DefaultTransport.
		transport = http.DefaultTransport
	}
	if t, ok := transport.(*http.Transport); ok && t.DialContext != nil {
		return t.DialContext(ctx, network, addr)
	}
	return zeroDialer.DialContext(ctx, network, addr)
}