diff --git a/internal/http_client.go b/internal/http_client.go index 47031cef..900645e3 100644 --- a/internal/http_client.go +++ b/internal/http_client.go @@ -41,9 +41,18 @@ import ( type HTTPClient struct { Client *http.Client RetryConfig *RetryConfig - ErrParser ErrorParser + ErrParser ErrorParser // Deprecated. Use CreateErrFn instead. + CreateErrFn CreateErrFn + SuccessFn SuccessFn + Opts []HTTPOption } +// SuccessFn is a function that checks if a Response indicates success. +type SuccessFn func(r *Response) bool + +// CreateErrFn is a function that creates an error from a given Response. +type CreateErrFn func(r *Response) error + // NewHTTPClient creates a new HTTPClient using the provided client options and the default // RetryConfig. // @@ -58,6 +67,7 @@ func NewHTTPClient(ctx context.Context, opts ...option.ClientOption) (*HTTPClien if err != nil { return nil, "", err } + twoMinutes := time.Duration(2) * time.Minute client := &HTTPClient{ Client: hc, @@ -74,9 +84,32 @@ func NewHTTPClient(ctx context.Context, opts ...option.ClientOption) (*HTTPClien return client, endpoint, nil } +// Request contains all the parameters required to construct an outgoing HTTP request. +type Request struct { + Method string + URL string + Body HTTPEntity + Opts []HTTPOption + SuccessFn SuccessFn + CreateErrFn CreateErrFn +} + +// Response contains information extracted from an HTTP response. +type Response struct { + Status int + Header http.Header + Body []byte + errParser ErrorParser +} + // Do executes the given Request, and returns a Response. // // If a RetryConfig is specified on the client, Do attempts to retry failing requests. +// +// If SuccessFn is set on the client or on the request, the response is validated against that +// function. If this validation fails, returns an error. These errors are created using the +// CreateErrFn on the client or on the request. If neither is set, CreatePlatformError is +// used as the default error function. func (c *HTTPClient) Do(ctx context.Context, req *Request) (*Response, error) { var result *attemptResult var err error @@ -93,42 +126,103 @@ func (c *HTTPClient) Do(ctx context.Context, req *Request) (*Response, error) { return nil, err } } - return result.handleResponse() + return c.handleResult(req, result) +} + +// DoAndUnmarshal behaves similar to Do, but additionally unmarshals the response payload into +// the given pointer. +// +// Unmarshal takes place only if the response does not represent an error (as determined by +// the Do function) and v is not nil. If the unmarshal fails, an error is returned even if the +// original response indicated success. +func (c *HTTPClient) DoAndUnmarshal(ctx context.Context, req *Request, v interface{}) (*Response, error) { + resp, err := c.Do(ctx, req) + if err != nil { + return nil, err + } + + if v != nil { + if err := json.Unmarshal(resp.Body, v); err != nil { + return nil, fmt.Errorf("error while parsing response: %v", err) + } + } + + return resp, nil } func (c *HTTPClient) attempt(ctx context.Context, req *Request, retries int) (*attemptResult, error) { - hr, err := req.buildHTTPRequest() + hr, err := req.buildHTTPRequest(c.Opts) if err != nil { return nil, err } resp, err := c.Client.Do(hr.WithContext(ctx)) - result := &attemptResult{ - Resp: resp, - Err: err, - ErrParser: c.ErrParser, + result := &attemptResult{} + if err != nil { + result.Err = err + } else { + // Read the response body here forcing any I/O errors to occur so that retry logic will + // cover them as well. + ir, err := newResponse(resp, c.ErrParser) + result.Resp = ir + result.Err = err } // If a RetryConfig is available, always consult it to determine if the request should be retried // or not. Even if there was a network error, we may not want to retry the request based on the // RetryConfig that is in effect. if c.RetryConfig != nil { - delay, retry := c.RetryConfig.retryDelay(retries, resp, err) + delay, retry := c.RetryConfig.retryDelay(retries, resp, result.Err) result.RetryAfter = delay result.Retry = retry - if retry && resp != nil { - defer resp.Body.Close() - } } return result, nil } +func (c *HTTPClient) handleResult(req *Request, result *attemptResult) (*Response, error) { + if result.Err != nil { + return nil, fmt.Errorf("error while making http call: %v", result.Err) + } + + if !c.success(req, result.Resp) { + return nil, c.newError(req, result.Resp) + } + + return result.Resp, nil +} + +func (c *HTTPClient) success(req *Request, resp *Response) bool { + var successFn SuccessFn + if req.SuccessFn != nil { + successFn = req.SuccessFn + } else if c.SuccessFn != nil { + successFn = c.SuccessFn + } + + if successFn != nil { + return successFn(resp) + } + + // TODO: Default to HasSuccessStatusCode + return true +} + +func (c *HTTPClient) newError(req *Request, resp *Response) error { + createErr := CreatePlatformError + if req.CreateErrFn != nil { + createErr = req.CreateErrFn + } else if c.CreateErrFn != nil { + createErr = c.CreateErrFn + } + + return createErr(resp) +} + type attemptResult struct { - Resp *http.Response + Resp *Response Err error Retry bool RetryAfter time.Duration - ErrParser ErrorParser } func (r *attemptResult) waitForRetry(ctx context.Context) error { @@ -141,23 +235,7 @@ func (r *attemptResult) waitForRetry(ctx context.Context) error { return ctx.Err() } -func (r *attemptResult) handleResponse() (*Response, error) { - if r.Err != nil { - return nil, r.Err - } - return newResponse(r.Resp, r.ErrParser) -} - -// Request contains all the parameters required to construct an outgoing HTTP request. -type Request struct { - Method string - URL string - Body HTTPEntity - Opts []HTTPOption -} - -func (r *Request) buildHTTPRequest() (*http.Request, error) { - var opts []HTTPOption +func (r *Request) buildHTTPRequest(opts []HTTPOption) (*http.Request, error) { var data io.Reader if r.Body != nil { b, err := r.Body.Bytes() @@ -203,14 +281,6 @@ func (e *jsonEntity) Mime() string { return "application/json" } -// Response contains information extracted from an HTTP response. -type Response struct { - Status int - Header http.Header - Body []byte - errParser ErrorParser -} - func newResponse(resp *http.Response, errParser ErrorParser) (*Response, error) { defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) @@ -229,6 +299,8 @@ func newResponse(resp *http.Response, errParser ErrorParser) (*Response, error) // // Returns an error if the status code does not match. If an ErrorParser is specified, uses that to // construct the returned error message. Otherwise includes the full response body in the error. +// +// Deprecated. Directly verify the Status field on the Response instead. func (r *Response) CheckStatus(want int) error { if r.Status == want { return nil @@ -249,6 +321,8 @@ func (r *Response) CheckStatus(want int) error { // // Unmarshal uses https://golang.org/pkg/encoding/json/#Unmarshal internally, and hence v has the // same requirements as the json package. +// +// Deprecated. Use DoAndUnmarshal function instead. func (r *Response) Unmarshal(want int, v interface{}) error { if err := r.CheckStatus(want); err != nil { return err @@ -257,6 +331,8 @@ func (r *Response) Unmarshal(want int, v interface{}) error { } // ErrorParser is a function that is used to construct custom error messages. +// +// Deprecated. Use SuccessFn and CreateErrFn instead. type ErrorParser func([]byte) string // HTTPOption is an additional parameter that can be specified to customize an outgoing request. @@ -290,6 +366,38 @@ func WithQueryParams(qp map[string]string) HTTPOption { } } +// HasSuccessStatus returns true if the response status code is in the 2xx range. +func HasSuccessStatus(r *Response) bool { + return r.Status >= http.StatusOK && r.Status < http.StatusNotModified +} + +// CreatePlatformError parses the response payload as a GCP error response +// and create an error from the details extracted. +// +// If the response failes to parse, or otherwise doesn't provide any useful details +// CreatePlatformError creates an error with some sensible defaults. +func CreatePlatformError(resp *Response) error { + var gcpError struct { + Error struct { + Status string `json:"status"` + Message string `json:"message"` + } `json:"error"` + } + json.Unmarshal(resp.Body, &gcpError) // ignore any json parse errors at this level + code := gcpError.Error.Status + if code == "" { + code = "UNKNOWN" + } + + message := gcpError.Error.Message + if message == "" { + message = fmt.Sprintf( + "unexpected http response with status: %d; body: %s", resp.Status, string(resp.Body)) + } + + return Error(code, message) +} + // RetryConfig specifies how the HTTPClient should retry failing HTTP requests. // // A request is never retried more than MaxRetries times. If CheckForRetry is nil, all network diff --git a/internal/http_client_test.go b/internal/http_client_test.go index fa3b5e20..e5db3a13 100644 --- a/internal/http_client_test.go +++ b/internal/http_client_test.go @@ -17,16 +17,20 @@ import ( "context" "encoding/json" "errors" + "fmt" "io/ioutil" "net/http" "net/http/httptest" "reflect" + "strings" "testing" "time" "google.golang.org/api/option" ) +const defaultMaxRetries = 4 + var ( testRetryConfig = RetryConfig{ MaxRetries: 4, @@ -184,6 +188,218 @@ func TestHTTPClient(t *testing.T) { } } +func TestDefaultOpts(t *testing.T) { + var header string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + header = r.Header.Get("Test-Header") + w.Write([]byte("{}")) + }) + server := httptest.NewServer(handler) + defer server.Close() + + client := &HTTPClient{ + Client: http.DefaultClient, + Opts: []HTTPOption{ + WithHeader("Test-Header", "test-value"), + }, + } + req := &Request{ + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", server.URL, wantURL), + } + + resp, err := client.Do(context.Background(), req) + if err != nil { + t.Fatal(err) + } + + if resp.Status != http.StatusOK { + t.Errorf("Status = %d; want = %d", resp.Status, http.StatusOK) + } + if header != "test-value" { + t.Errorf("Test-Header = %q; want = %q", header, "test-value") + } +} + +func TestSuccessFn(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("{}")) + }) + server := httptest.NewServer(handler) + defer server.Close() + + client := &HTTPClient{ + Client: http.DefaultClient, + SuccessFn: func(r *Response) bool { + return false + }, + } + get := &Request{ + Method: http.MethodGet, + URL: server.URL, + } + want := "unexpected http response with status: 200; body: {}" + + resp, err := client.Do(context.Background(), get) + if resp != nil || err == nil || err.Error() != want { + t.Fatalf("Do() = (%v, %v); want = (nil, %q)", resp, err, want) + } + + if !HasErrorCode(err, "UNKNOWN") { + t.Errorf("ErrorCode = %q; want = %q", err.(*FirebaseError).Code, "UNKNOWN") + } +} + +func TestSuccessFnOnRequest(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("{}")) + }) + server := httptest.NewServer(handler) + defer server.Close() + + client := &HTTPClient{ + Client: http.DefaultClient, + SuccessFn: HasSuccessStatus, + } + get := &Request{ + Method: http.MethodGet, + URL: server.URL, + SuccessFn: func(r *Response) bool { + return false + }, + } + want := "unexpected http response with status: 200; body: {}" + + resp, err := client.Do(context.Background(), get) + if resp != nil || err == nil || err.Error() != want { + t.Fatalf("Do() = (%v, %v); want = (nil, %q)", resp, err, want) + } + + if !HasErrorCode(err, "UNKNOWN") { + t.Errorf("ErrorCode = %q; want = %q", err.(*FirebaseError).Code, "UNKNOWN") + } +} + +func TestPlatformError(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := `{ + "error": { + "status": "NOT_FOUND", + "message": "Requested entity not found" + } + }` + + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(resp)) + }) + server := httptest.NewServer(handler) + defer server.Close() + + client := &HTTPClient{ + Client: http.DefaultClient, + SuccessFn: HasSuccessStatus, + } + get := &Request{ + Method: http.MethodGet, + URL: server.URL, + } + want := "Requested entity not found" + + resp, err := client.Do(context.Background(), get) + if resp != nil || err == nil || err.Error() != want { + t.Fatalf("Do() = (%v, %v); want = (nil, %q)", resp, err, want) + } + + if !HasErrorCode(err, "NOT_FOUND") { + t.Errorf("ErrorCode = %q; want = %q", err.(*FirebaseError).Code, "NOT_FOUND") + } +} + +func TestPlatformErrorWithoutDetails(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("{}")) + }) + server := httptest.NewServer(handler) + defer server.Close() + + client := &HTTPClient{ + Client: http.DefaultClient, + SuccessFn: HasSuccessStatus, + } + get := &Request{ + Method: http.MethodGet, + URL: server.URL, + } + want := "unexpected http response with status: 404; body: {}" + + resp, err := client.Do(context.Background(), get) + if resp != nil || err == nil || err.Error() != want { + t.Fatalf("Do() = (%v, %v); want = (nil, %q)", resp, err, want) + } + + if !HasErrorCode(err, "UNKNOWN") { + t.Errorf("ErrorCode = %q; want = %q", err.(*FirebaseError).Code, "UNKNOWN") + } +} + +func TestCreateErrFn(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("{}")) + }) + server := httptest.NewServer(handler) + defer server.Close() + + client := &HTTPClient{ + Client: http.DefaultClient, + CreateErrFn: func(r *Response) error { + return fmt.Errorf("custom error with status: %d", r.Status) + }, + SuccessFn: HasSuccessStatus, + } + get := &Request{ + Method: http.MethodGet, + URL: server.URL, + } + want := "custom error with status: 404" + + resp, err := client.Do(context.Background(), get) + if resp != nil || err == nil || err.Error() != want { + t.Fatalf("Do() = (%v, %v); want = (nil, %q)", resp, err, want) + } +} + +func TestCreateErrFnOnRequest(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("{}")) + }) + server := httptest.NewServer(handler) + defer server.Close() + + client := &HTTPClient{ + Client: http.DefaultClient, + CreateErrFn: func(r *Response) error { + return fmt.Errorf("custom error with status: %d", r.Status) + }, + SuccessFn: HasSuccessStatus, + } + get := &Request{ + Method: http.MethodGet, + URL: server.URL, + CreateErrFn: func(r *Response) error { + return fmt.Errorf("custom error from req with status: %d", r.Status) + }, + } + want := "custom error from req with status: 404" + + resp, err := client.Do(context.Background(), get) + if resp != nil || err == nil || err.Error() != want { + t.Fatalf("Do() = (%v, %v); want = (nil, %q)", resp, err, want) + } +} + func TestContext(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -381,6 +597,21 @@ func TestNoRetryOnRequestBuildError(t *testing.T) { } } +func TestNoRetryOnInvalidMethod(t *testing.T) { + client := &HTTPClient{ + Client: http.DefaultClient, + RetryConfig: &testRetryConfig, + } + + req := &Request{ + Method: "Invalid/Method", + URL: "https://firebase.google.com", + } + if _, err := client.Do(context.Background(), req); err == nil { + t.Errorf("Do() = nil; want = error") + } +} + func TestNoRetryOnHTTPSuccessCodes(t *testing.T) { for i := http.StatusOK; i < http.StatusBadRequest; i++ { resp := &http.Response{ @@ -596,7 +827,6 @@ func TestNewHTTPClientRetryOnNetworkErrors(t *testing.T) { t.Errorf("Do() = (%v, %v); want = (nil, error)", resp, err) } - const defaultMaxRetries = 4 wantRequests := 1 + defaultMaxRetries if tansport.RequestAttempts != wantRequests { t.Errorf("Total requests = %d; want = %d", tansport.RequestAttempts, wantRequests) @@ -620,7 +850,6 @@ func TestNewHTTPClientRetryOnHTTPErrors(t *testing.T) { t.Fatal(err) } client.RetryConfig.ExpBackoffFactor = 0 - const defaultMaxRetries = 4 for _, status = range []int{http.StatusInternalServerError, http.StatusServiceUnavailable} { requests = 0 req := &Request{Method: http.MethodGet, URL: server.URL} @@ -666,6 +895,35 @@ func TestNewHttpClientNoRetryOnNotFound(t *testing.T) { } } +func TestNewHttpClientRetryOnResponseReadError(t *testing.T) { + requests := 0 + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests++ + // Lie about the content-length forcing a read error on the client + w.Header().Set("Content-Length", "1") + }) + server := httptest.NewServer(handler) + defer server.Close() + + client, _, err := NewHTTPClient(context.Background(), tokenSourceOpt) + if err != nil { + t.Fatal(err) + } + client.RetryConfig.ExpBackoffFactor = 0 + wantPrefix := "error while making http call: " + + req := &Request{Method: http.MethodGet, URL: server.URL} + resp, err := client.Do(context.Background(), req) + if resp != nil || err == nil || !strings.HasPrefix(err.Error(), wantPrefix) { + t.Errorf("Do() = (%v, %v); want = (nil, %q)", resp, err, wantPrefix) + } + + wantRequests := 1 + defaultMaxRetries + if requests != wantRequests { + t.Errorf("Total requests = %d; want = %d", requests, wantRequests) + } +} + type faultyEntity struct { RequestAttempts int } diff --git a/internal/json_http_client_test.go b/internal/json_http_client_test.go new file mode 100644 index 00000000..53c9da72 --- /dev/null +++ b/internal/json_http_client_test.go @@ -0,0 +1,202 @@ +// Copyright 2019 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +const wantURL = "/test" + +func TestDoAndUnmarshalGet(t *testing.T) { + var req *http.Request + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + req = r + resp := `{ + "name": "test" + }` + w.Write([]byte(resp)) + }) + server := httptest.NewServer(handler) + defer server.Close() + + client := &HTTPClient{ + Client: http.DefaultClient, + } + get := &Request{ + Method: http.MethodGet, + URL: fmt.Sprintf("%s%s", server.URL, wantURL), + } + var data responseBody + + resp, err := client.DoAndUnmarshal(context.Background(), get, &data) + if err != nil { + t.Fatal(err) + } + + if resp.Status != http.StatusOK { + t.Errorf("Status = %d; want = %d", resp.Status, http.StatusOK) + } + if data.Name != "test" { + t.Errorf("Data = %v; want = {Name: %q}", data, "test") + } + if req.Method != http.MethodGet { + t.Errorf("Method = %q; want = %q", req.Method, http.MethodGet) + } + if req.URL.Path != wantURL { + t.Errorf("URL = %q; want = %q", req.URL.Path, wantURL) + } +} + +func TestDoAndUnmarshalPost(t *testing.T) { + var req *http.Request + var b []byte + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + req = r + b, _ = ioutil.ReadAll(r.Body) + resp := `{ + "name": "test" + }` + w.Write([]byte(resp)) + }) + server := httptest.NewServer(handler) + defer server.Close() + + client := &HTTPClient{ + Client: http.DefaultClient, + } + post := &Request{ + Method: http.MethodPost, + URL: fmt.Sprintf("%s%s", server.URL, wantURL), + Body: NewJSONEntity(map[string]string{"input": "test-input"}), + } + var data responseBody + + resp, err := client.DoAndUnmarshal(context.Background(), post, &data) + if err != nil { + t.Fatal(err) + } + + if resp.Status != http.StatusOK { + t.Errorf("Status = %d; want = %d", resp.Status, http.StatusOK) + } + if data.Name != "test" { + t.Errorf("Data = %v; want = {Name: %q}", data, "test") + } + if req.Method != http.MethodPost { + t.Errorf("Method = %q; want = %q", req.Method, http.MethodGet) + } + if req.URL.Path != wantURL { + t.Errorf("URL = %q; want = %q", req.URL.Path, wantURL) + } + + var parsed struct { + Input string `json:"input"` + } + if err := json.Unmarshal(b, &parsed); err != nil { + t.Fatal(err) + } + if parsed.Input != "test-input" { + t.Errorf("Request Body = %v; want = {Input: %q}", parsed, "test-input") + } +} + +func TestDoAndUnmarshalNotJSON(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("not json")) + }) + server := httptest.NewServer(handler) + defer server.Close() + + client := &HTTPClient{ + Client: http.DefaultClient, + } + get := &Request{ + Method: http.MethodGet, + URL: server.URL, + } + var data interface{} + wantPrefix := "error while parsing response: " + + resp, err := client.DoAndUnmarshal(context.Background(), get, &data) + if resp != nil || err == nil || !strings.HasPrefix(err.Error(), wantPrefix) { + t.Errorf("DoAndUnmarshal() = (%v, %v); want = (nil, %q)", resp, err, wantPrefix) + } + + if data != nil { + t.Errorf("Data = %v; want = nil", data) + } +} + +func TestDoAndUnmarshalNilPointer(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("not json")) + }) + server := httptest.NewServer(handler) + defer server.Close() + + client := &HTTPClient{ + Client: http.DefaultClient, + } + get := &Request{ + Method: http.MethodGet, + URL: server.URL, + } + + resp, err := client.DoAndUnmarshal(context.Background(), get, nil) + if err != nil { + t.Fatalf("DoAndUnmarshal() = %v; want = nil", err) + } + + if resp.Status != http.StatusOK { + t.Errorf("Status = %d; want = %d", resp.Status, http.StatusOK) + } +} + +func TestDoAndUnmarshalTransportError(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + server := httptest.NewServer(handler) + server.Close() + + client := &HTTPClient{ + Client: http.DefaultClient, + } + get := &Request{ + Method: http.MethodGet, + URL: server.URL, + } + var data interface{} + wantPrefix := "error while making http call: " + + resp, err := client.DoAndUnmarshal(context.Background(), get, &data) + if resp != nil || err == nil || !strings.HasPrefix(err.Error(), wantPrefix) { + t.Errorf("DoAndUnmarshal() = (%v, %v); want = (nil, %q)", resp, err, wantPrefix) + } + + if data != nil { + t.Errorf("Data = %v; want = nil", data) + } +} + +type responseBody struct { + Name string `json:"name"` +}