-
Notifications
You must be signed in to change notification settings - Fork 0
/
http.go
209 lines (190 loc) · 6.2 KB
/
http.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
package rdv
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"net/url"
"strings"
"time"
)
// newRdvRequest builds the client's rdv HTTP/1.1 upgrade request.
//
// The request URL is addr with meta.Token joined onto its path (url.JoinPath),
// the method is meta.Method and there is no body. If header is non-nil, a copy
// of it becomes the request's base header set. The caller's map is never
// modified, so the same header value can safely be reused for several dials.
// The upgrade headers and the self-addrs header are then set on the request,
// overriding any values the supplied header had for those keys.
func newRdvRequest(meta *Meta, addr string, header http.Header) (*http.Request, error) {
	urlStr, err := url.JoinPath(addr, meta.Token)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequest(meta.Method, urlStr, nil)
	if err != nil {
		return nil, err
	}
	if header != nil {
		// Clone so the Set calls below don't write into (or race on) the caller's map.
		req.Header = header.Clone()
	}
	setUpgradeHeaders(req.Header, protocolName)
	req.Header.Set(hSelfAddrs, formatAddrPorts(meta.SelfAddrs))
	return req, nil
}
// newRdvResponse builds the server's "101 Switching Protocols" reply to an rdv
// upgrade request. The reply carries the upgrade headers and the peer addrs
// from meta. The observed-addr header is added only when meta.ObservedAddr is set.
func newRdvResponse(meta *Meta) *http.Response {
	r := newResponse(http.StatusSwitchingProtocols)
	h := r.Header
	setUpgradeHeaders(h, protocolName)
	h.Set(hPeerAddrs, formatAddrPorts(meta.PeerAddrs))
	if observed := meta.ObservedAddr; observed != nil {
		h.Set(hObservedAddr, observed.String())
	}
	return r
}
// parseRdvRequest validates an incoming rdv HTTP/1.1 upgrade request and
// extracts its Meta.
//
// Missing or wrong upgrade headers, and a protocol version other than HTTP/1.1,
// produce errors wrapping errUpgrade. Errors from newMeta are returned
// unchanged. An unparsable or oversized self-addrs header yields a plain error
// quoting the raw header value.
func parseRdvRequest(req *http.Request) (*Meta, error) {
	// The upgrade intent is checked before the HTTP version: a client that never
	// asked to upgrade gets a more useful error than a version mismatch.
	if err := checkUpgradeHeaders(req.Header, protocolName); err != nil {
		return nil, err
	}
	if strings.ToLower(req.Proto) != "http/1.1" {
		return nil, fmt.Errorf("%w: bad http version for upgrade %s", errUpgrade, req.Proto)
	}

	meta, err := newMeta(req.Method, strings.TrimPrefix(req.URL.Path, "/"))
	if err != nil {
		return nil, err
	}

	rawSelf := req.Header.Get(hSelfAddrs)
	if meta.SelfAddrs, err = parseAddrPorts(rawSelf); err != nil {
		return nil, fmt.Errorf("invalid self addrs [%s]", rawSelf)
	}
	// NOTE(review): the limit here is maxAddrs-1, one less than the peer-addrs
	// limit in parseRdvResponse; presumably this leaves room for one address the
	// server adds. Confirm.
	if len(meta.SelfAddrs) >= maxAddrs {
		return nil, fmt.Errorf("too many self addrs [%s]", rawSelf)
	}
	return meta, nil
}
// parseRdvResponse validates the server's reply to an rdv upgrade request and
// records the addresses it carries into meta: PeerAddrs always, and
// ObservedAddr when that header is present.
//
// A status other than 101 yields a plain error. Every other failure wraps
// ErrBadHandshake. On error, meta may already have been partially updated
// (e.g. PeerAddrs is assigned before its length is checked).
func parseRdvResponse(meta *Meta, resp *http.Response) (err error) {
	if resp.StatusCode != http.StatusSwitchingProtocols {
		return fmt.Errorf("unexpected http status %v", resp.Status)
	}
	hdr := resp.Header
	if err = checkUpgradeHeaders(hdr, protocolName); err != nil {
		return fmt.Errorf("%w: %v", ErrBadHandshake, err)
	}

	rawPeers := hdr.Get(hPeerAddrs)
	if meta.PeerAddrs, err = parseAddrPorts(rawPeers); err != nil {
		return fmt.Errorf("%w: invalid peer addrs %s", ErrBadHandshake, rawPeers)
	}
	if len(meta.PeerAddrs) > maxAddrs {
		return fmt.Errorf("%w: too many peer addrs %s", ErrBadHandshake, rawPeers)
	}

	rawObserved := hdr.Get(hObservedAddr)
	if rawObserved == "" {
		return nil
	}
	observed, parseErr := netip.ParseAddrPort(rawObserved)
	if parseErr != nil {
		return fmt.Errorf("%w: invalid observed addr %s", ErrBadHandshake, rawObserved)
	}
	meta.ObservedAddr = &observed
	return nil
}
// upgradeRdv turns an incoming rdv upgrade request into a server-side *Conn.
//
// If the request does not parse, an HTTP error is written to w and the parse
// error is returned. The status is 426 Upgrade Required when the error wraps
// errUpgrade, and 400 Bad Request otherwise. On success the connection has been
// hijacked from the http.Server, so w must not be written to afterwards.
func upgradeRdv(w http.ResponseWriter, req *http.Request) (*Conn, error) {
	meta, err := parseRdvRequest(req)
	switch {
	case errors.Is(err, errUpgrade):
		http.Error(w, err.Error(), http.StatusUpgradeRequired)
		return nil, err
	case err != nil:
		http.Error(w, err.Error(), http.StatusBadRequest)
		return nil, err
	}

	nc, brw, err := http.NewResponseController(w).Hijack()
	if err != nil {
		// The HTTP version was already validated during parsing, so a hijack
		// failure here indicates an internal problem, not a bad client request.
		http.Error(w, "", http.StatusInternalServerError)
		return nil, err
	}
	// NOTE(review): the body is dropped once the conn is hijacked, presumably so
	// nothing reads the raw conn through req.Body. Confirm.
	req.Body = nil
	// A zero time clears any deadline still set on the hijacked conn.
	nc.SetDeadline(time.Time{})
	// brw.Reader is handed over because it may hold bytes already read from nc.
	return newRelayConn(nc, brw.Reader, meta, req), nil
}
// doHttp performs one HTTP/1.1 round trip directly on a raw connection.
// It serializes req onto nc, then parses the response from br.
// NOTE(review): br is assumed to read from the same conn as nc. Confirm with callers.
// req's context is never consulted, so any timeout must come from elsewhere
// (e.g. deadlines on nc).
func doHttp(nc net.Conn, br *bufio.Reader, req *http.Request) (*http.Response, error) {
	if werr := req.Write(nc); werr != nil {
		return nil, werr
	}
	return http.ReadResponse(br, nil)
}
// checkUpgradeHeaders verifies that h asks for a connection upgrade to protocol.
//
// "Connection" is a comma-separated token list that may be split across several
// header lines (RFC 9110 §7.6.1), e.g. "keep-alive, Upgrade". The check passes
// if any token equals "upgrade", ignoring case and surrounding whitespace.
//
// "Upgrade" may also list several protocols, but rdv deliberately accepts only
// one. Its trimmed value must therefore equal protocol, ignoring case.
//
// Every returned error wraps errUpgrade.
func checkUpgradeHeaders(h http.Header, protocol string) error {
	hasUpgradeToken := false
	for _, value := range h.Values("Connection") {
		for _, token := range strings.Split(value, ",") {
			if strings.EqualFold(strings.TrimSpace(token), "upgrade") {
				hasUpgradeToken = true
			}
		}
	}
	if !hasUpgradeToken {
		return fmt.Errorf("%w: requires connection upgrade", errUpgrade)
	}

	upgrade := strings.TrimSpace(strings.ToLower(h.Get("Upgrade")))
	if upgrade == "" {
		return fmt.Errorf("%w: missing upgrade header", errUpgrade)
	}
	// EqualFold rather than ==: upgrade was lowercased, protocol may not be.
	if !strings.EqualFold(upgrade, protocol) {
		return fmt.Errorf("%w: bad upgrade %s", errUpgrade, upgrade)
	}
	return nil
}
// setUpgradeHeaders writes "Upgrade: <protocol>" and "Connection: upgrade" into
// h. Any values previously stored under either key are replaced.
func setUpgradeHeaders(h http.Header, protocol string) {
	h.Set("Upgrade", protocol)
	h.Set("Connection", "upgrade")
}
// slurp reads at most size bytes from resp.Body and replaces the body with an
// in-memory reader over just that prefix. This keeps a short excerpt of an
// unexpected response available for debugging before the response is closed.
//
// A nil body is replaced by http.NoBody, and a negative size is treated as 0,
// instead of panicking. The original body is not closed here.
// NOTE(review): presumably the caller closes the underlying connection. Confirm.
func slurp(resp *http.Response, size int) {
	if resp.Body == nil {
		resp.Body = http.NoBody
		return
	}
	if size < 0 {
		size = 0
	}
	buf := make([]byte, size)
	// Best effort: a body shorter than size (io.EOF / io.ErrUnexpectedEOF) or a
	// read error simply leaves whatever prefix was read.
	n, _ := io.ReadFull(resp.Body, buf)
	resp.Body = io.NopCloser(bytes.NewReader(buf[:n]))
}
// newResponse creates a bare HTTP/1.1 response with the given status code and
// an empty, non-nil header map. It is meant for writing directly onto a hijacked
// or raw conn. Status text, Proto and Body are left unset.
func newResponse(status int) *http.Response {
	resp := new(http.Response)
	resp.ProtoMajor, resp.ProtoMinor = 1, 1
	resp.StatusCode = status
	resp.Header = http.Header{}
	return resp
}
// writeResponseErr writes a plain-text HTTP/1.1 error response to nc and then
// closes nc. The response has status statusCode and reason as its body. The
// write is bounded by a deadline of shortWriteTimeout from now. nc is closed
// whether or not the write succeeds, and only the write error is returned.
func writeResponseErr(nc net.Conn, statusCode int, reason string) error {
	defer nc.Close()

	resp := newResponse(statusCode)
	resp.Body = io.NopCloser(strings.NewReader(reason))
	// The same two headers net/http's http.Error sets on plain-text errors.
	hdr := resp.Header
	hdr.Set("Content-Type", "text/plain; charset=utf-8")
	hdr.Set("X-Content-Type-Options", "nosniff")

	nc.SetDeadline(time.Now().Add(shortWriteTimeout))
	return resp.Write(nc)
}
// parseAddrPorts parses a comma-separated list of ip:port values, as produced
// by formatAddrPorts.
//
// Whitespace around each element is ignored. Empty list elements (e.g. a
// trailing comma or "a,,b") are skipped, as RFC 9110 §5.6.1 requires recipients
// of list-based fields to tolerate them. An empty or all-empty list yields a nil
// slice and no error. Any other element that is not a valid ip:port makes the
// whole parse fail with netip's error.
//
// TODO(https://github.com/golang/go/issues/41046): Structured http field parsing
func parseAddrPorts(addrStr string) ([]netip.AddrPort, error) {
	var addrs []netip.AddrPort
	for _, part := range strings.Split(addrStr, ",") {
		part = strings.TrimSpace(part)
		if part == "" {
			continue
		}
		addr, err := netip.ParseAddrPort(part)
		if err != nil {
			return nil, err
		}
		addrs = append(addrs, addr)
	}
	return addrs, nil
}
// formatAddrPorts renders addrs as ip:port strings joined by ", ", the format
// parseAddrPorts reads back. A nil or empty slice yields "".
func formatAddrPorts(addrs []netip.AddrPort) string {
	var sb strings.Builder
	for i, ap := range addrs {
		if i > 0 {
			sb.WriteString(", ")
		}
		sb.WriteString(ap.String())
	}
	return sb.String()
}