-
Notifications
You must be signed in to change notification settings - Fork 3
/
client.go
254 lines (224 loc) · 7.88 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
package twirp
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"strconv"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)
// HTTPClient abstracts the one operation generated Twirp clients need from an
// HTTP transport: sending a request and returning its response.
//
// A *(net/http).Client already satisfies it and is enough for most callers;
// supply a custom implementation to add behavior such as special retry policies.
//
// Implementations should not follow redirects. When a *(net/http).Client is
// handed to the client constructors, redirects are disabled automatically; see
// the withoutRedirects function in this file for more details.
type HTTPClient interface {
	// Do sends r and returns the resulting response or a transport error.
	Do(r *http.Request) (*http.Response, error)
}
// DoProtobufRequest is common code to make a request to the remote twirp service
// using the binary protobuf encoding.
//
// It marshals in, POSTs it to url through client and, on an HTTP 200 response,
// unmarshals the body into out. Any other status is turned into a twirp Error
// by errorFromResponse. The context is checked before sending, after the
// response arrives and after the body is read, so cancellation is reported as
// such. Every other failure is wrapped with clientError; a failure to close the
// response body is reported only when the call otherwise succeeded.
func DoProtobufRequest(ctx context.Context, client HTTPClient, url string, in, out proto.Message) (err error) {
	// abortIfDone converts a finished context into a client error, or returns nil.
	abortIfDone := func() error {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return clientError("aborted because context was done", ctxErr)
		}
		return nil
	}

	payload, err := proto.Marshal(in)
	if err != nil {
		return clientError("failed to marshal proto request", err)
	}
	if err = abortIfDone(); err != nil {
		return err
	}

	httpReq, err := newRequest(ctx, url, bytes.NewBuffer(payload), "application/protobuf")
	if err != nil {
		return clientError("could not build request", err)
	}

	httpResp, err := client.Do(httpReq)
	if err != nil {
		return clientError("failed to do request", err)
	}
	defer func() {
		// Surface a close failure only if nothing else has already gone wrong.
		if closeErr := httpResp.Body.Close(); closeErr != nil && err == nil {
			err = clientError("failed to close response body", closeErr)
		}
	}()

	if err = abortIfDone(); err != nil {
		return err
	}
	if httpResp.StatusCode != 200 {
		return errorFromResponse(httpResp)
	}

	raw, err := ioutil.ReadAll(httpResp.Body)
	if err != nil {
		return clientError("failed to read response body", err)
	}
	if err = abortIfDone(); err != nil {
		return err
	}
	if err = proto.Unmarshal(raw, out); err != nil {
		return clientError("failed to unmarshal proto response", err)
	}
	return nil
}
// DoJSONRequest is common code to make a request to the remote twirp service
// using the JSON encoding.
//
// The request message in is marshaled with protojson using the original proto
// field names (UseProtoNames), POSTed to url through client, and an HTTP 200
// response body is unmarshaled into out. Unknown fields in the response are
// discarded, so clients keep working when the server adds fields to a message.
// This matches the tolerance of the binary protobuf path.
//
// Non-200 responses are converted to a twirp Error by errorFromResponse. The
// context is checked before sending, after the response arrives, and before
// decoding. A cancelled call therefore reports the cancellation and does not
// write into out. Every other failure is wrapped with clientError; a failure to
// close the response body is reported only when the call otherwise succeeded.
func DoJSONRequest(ctx context.Context, client HTTPClient, url string, in, out proto.Message) (err error) {
	marshaler := protojson.MarshalOptions{UseProtoNames: true}
	var buf []byte
	if buf, err = marshaler.Marshal(in); err != nil {
		return clientError("failed to marshal json request", err)
	}
	if err = ctx.Err(); err != nil {
		return clientError("aborted because context was done", err)
	}

	reqBody := bytes.NewReader(buf)
	req, err := newRequest(ctx, url, reqBody, "application/json")
	if err != nil {
		return clientError("could not build request", err)
	}

	resp, err := client.Do(req)
	if err != nil {
		return clientError("failed to do request", err)
	}
	defer func() {
		cerr := resp.Body.Close()
		// Only report a close failure if nothing else has already failed.
		if err == nil && cerr != nil {
			err = clientError("failed to close response body", cerr)
		}
	}()

	if err = ctx.Err(); err != nil {
		return clientError("aborted because context was done", err)
	}
	if resp.StatusCode != 200 {
		return errorFromResponse(resp)
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return clientError("failed to read response body", err)
	}
	// Check cancellation before decoding (as DoProtobufRequest does) so that
	// out is not modified when the call is reported as aborted.
	if err = ctx.Err(); err != nil {
		return clientError("aborted because context was done", err)
	}

	// DiscardUnknown keeps older clients compatible with servers whose
	// responses contain fields this client's generated code does not know.
	unmarshaler := protojson.UnmarshalOptions{DiscardUnknown: true}
	if err = unmarshaler.Unmarshal(body, out); err != nil {
		return clientError("failed to unmarshal json response", err)
	}
	return nil
}
// newRequest builds the POST request for a Twirp call to url, with reqBody as
// its payload, bound to ctx.
//
// Headers attached to ctx with WithHTTPRequestHeaders (as returned by
// getCustomHTTPReqHeaders) become the request's base header set. Accept and
// Content-Type are then set to contentType and Twirp-Version to "v5.5.0",
// replacing any custom values under those names.
func newRequest(ctx context.Context, url string, reqBody io.Reader, contentType string) (*http.Request, error) {
	plain, err := http.NewRequest("POST", url, reqBody)
	if err != nil {
		return nil, err
	}
	req := plain.WithContext(ctx)

	if extra := getCustomHTTPReqHeaders(ctx); extra != nil {
		req.Header = extra
	}

	hdr := req.Header
	hdr.Set("Accept", contentType)
	hdr.Set("Content-Type", contentType)
	hdr.Set("Twirp-Version", "v5.5.0")
	return req, nil
}
// getCustomHTTPReqHeaders returns a deep copy of the headers stored in ctx by
// WithHTTPRequestHeaders, so the caller can modify the result without touching
// the context's copy.
// It returns nil if no headers are set, or if HTTPRequestHeaders reports they
// have the wrong type.
// Keys whose value slice is nil stay nil in the copy; every other slice is
// duplicated into fresh storage.
func getCustomHTTPReqHeaders(ctx context.Context) http.Header {
	src, ok := HTTPRequestHeaders(ctx)
	if !ok || src == nil {
		return nil
	}

	dst := make(http.Header, len(src))
	for name, values := range src {
		if values == nil {
			dst[name] = nil
			continue
		}
		dst[name] = append(make([]string, 0, len(values)), values...)
	}
	return dst
}
// clientError builds the Internal twirp error used for every client-side
// failure. It prefixes err's message with desc (via wrapErr), so all client
// errors share the same "desc: cause" shape.
func clientError(desc string, err error) Error {
	wrapped := wrapErr(err, desc)
	return InternalErrorWith(wrapped)
}
// wrappedError pairs a descriptive message with the error that caused it.
// It implements the github.com/pkg/errors Causer interface (Cause). It also
// implements the standard library's Unwrap convention, so the root cause can be
// examined with either pkg/errors.Cause or errors.Is / errors.As.
type wrappedError struct {
	msg   string
	cause error
}

// wrapErr returns an error whose message is msg followed by err's message.
func wrapErr(err error, msg string) error { return &wrappedError{msg: msg, cause: err} }

// Cause returns the wrapped error (github.com/pkg/errors Causer interface).
func (e *wrappedError) Cause() error { return e.cause }

// Unwrap returns the wrapped error so that errors.Is and errors.As can see
// through this wrapper.
func (e *wrappedError) Unwrap() error { return e.cause }

// Error formats the error as "msg: cause". When there is no cause it returns
// msg alone instead of panicking on a nil error.
func (e *wrappedError) Error() string {
	if e.cause == nil {
		return e.msg
	}
	return e.msg + ": " + e.cause.Error()
}
// errorFromResponse converts a non-200 HTTP response into a twirp Error.
//
//   - 3xx: treated as coming from an intermediary, because Twirp clients do not
//     follow redirects. The Location header is recorded.
//   - Body that does not decode as a serialized Twirp error: treated as an
//     intermediary error that carries the raw body (see
//     twirpErrorFromIntermediary).
//   - Decoded error whose code is invalid: an Internal error.
//   - Otherwise: the decoded Twirp error itself.
//
// If the body cannot be read, a client error is returned instead.
func errorFromResponse(resp *http.Response) Error {
	code := resp.StatusCode
	text := http.StatusText(code)

	if isHTTPRedirect(code) {
		// Twirp only handles POST requests and redirects should only happen on
		// GET and HEAD, so an unexpected redirect must come from an intermediary.
		loc := resp.Header.Get("Location")
		msg := fmt.Sprintf("unexpected HTTP status code %d %q received, Location=%q", code, text, loc)
		return twirpErrorFromIntermediary(code, msg, loc)
	}

	body, readErr := ioutil.ReadAll(resp.Body)
	if readErr != nil {
		return clientError("failed to read server error response body", readErr)
	}

	var decoded twerr
	if jsonErr := json.Unmarshal(body, &decoded); jsonErr != nil {
		// Not a Twirp JSON error; it must come from an intermediary.
		msg := fmt.Sprintf("Error from intermediary with HTTP status code %d %q", code, text)
		return twirpErrorFromIntermediary(code, msg, string(body))
	}

	if errCode := decoded.Code(); !IsValidErrorCode(errCode) {
		return InternalError("invalid type returned from server error response: " + string(errCode))
	}
	return &decoded
}
// twirpErrorFromIntermediary maps HTTP errors from non-twirp sources to twirp errors.
// The mapping is similar to gRPC: https://github.com/grpc/grpc/blob/master/doc/http-grpc-status-mapping.md.
//
// Status codes map to twirp codes as follows:
//   - 3xx and 400: Internal
//   - 401: Unauthenticated
//   - 403: PermissionDenied
//   - 404: BadRoute
//   - 429, 502, 503, 504: Unavailable
//   - anything else: Unknown
//
// The returned error carries this metadata:
//   - "http_error_from_intermediary" = "true"
//   - "status_code" = the numeric status
//   - "location" for redirects, otherwise "body", holding bodyOrLocation
func twirpErrorFromIntermediary(status int, msg string, bodyOrLocation string) Error {
	redirect := isHTTPRedirect(status)

	var code ErrorCode
	switch {
	case redirect, status == 400: // 3xx, Bad Request
		code = Internal
	case status == 401: // Unauthorized
		code = Unauthenticated
	case status == 403: // Forbidden
		code = PermissionDenied
	case status == 404: // Not Found
		code = BadRoute
	case status == 429, status == 502, status == 503, status == 504:
		// Too Many Requests, Bad Gateway, Service Unavailable, Gateway Timeout
		code = Unavailable
	default:
		code = Unknown
	}

	detailKey := "body"
	if redirect {
		detailKey = "location"
	}

	result := NewError(code, msg)
	result = result.WithMeta("http_error_from_intermediary", "true") // lets callers tell intermediary errors apart
	result = result.WithMeta("status_code", strconv.Itoa(status))
	result = result.WithMeta(detailKey, bodyOrLocation)
	return result
}
// isHTTPRedirect reports whether status is in the 3xx (redirection) class.
func isHTTPRedirect(status int) bool {
	return 300 <= status && status < 400
}