From 023cfdf78b4dea338b882a02519314e01804674b Mon Sep 17 00:00:00 2001 From: Julien Duchesne Date: Tue, 6 Jun 2023 09:42:17 -0400 Subject: [PATCH] Allow setting status codes to retry (#152) For https://github.com/grafana/terraform-provider-grafana/issues/893 We currently retry 429 and 5xx. This is retained as a default in this PR but there's a new `RetryStatusCodes` that, if set, overrides these statuses That way, the user that opened the issue above will be able to also retry 407 on their Azure instance **Had to do some fixes in dashboard tests, the httptest servers were not properly used, not closed and leaked out of the tests** --- client.go | 39 +++++++++++++- client_test.go | 50 +++++++++++++++++- dashboard_test.go | 111 ++++++++++++++++++++++------------------ mock.go => mock_test.go | 10 ++-- 4 files changed, 153 insertions(+), 57 deletions(-) rename mock.go => mock_test.go (90%) diff --git a/client.go b/client.go index e3f5aed3..5052e463 100644 --- a/client.go +++ b/client.go @@ -42,6 +42,8 @@ type Config struct { NumRetries int // RetryTimeout says how long to wait before retrying a request RetryTimeout time.Duration + // RetryStatusCodes contains the list of status codes to retry, use "x" as a wildcard for a single digit (default: [429, 5xx]) + RetryStatusCodes []string } // New creates a new Grafana client. @@ -80,6 +82,10 @@ func (c *Client) request(method, requestPath string, query url.Values, body []by err error bodyContents []byte ) + retryStatusCodes := c.config.RetryStatusCodes + if len(retryStatusCodes) == 0 { + retryStatusCodes = []string{"429", "5xx"} + } // retry logic for n := 0; n <= c.config.NumRetries; n++ { @@ -115,8 +121,11 @@ func (c *Client) request(method, requestPath string, query url.Values, body []by continue } - // Exit the loop if we have something final to return. This is anything < 500, if it's not a 429. - if resp.StatusCode < http.StatusInternalServerError && resp.StatusCode != http.StatusTooManyRequests { + shouldRetry, err := matchRetryCode(resp.StatusCode, retryStatusCodes) + if err != nil { + return err + } + if !shouldRetry { break } } @@ -179,3 +188,29 @@ func (c *Client) newRequest(method, requestPath string, query url.Values, body i req.Header.Add("Content-Type", "application/json") return req, err } + +// matchRetryCode checks if the status code matches any of the configured retry status codes. +func matchRetryCode(gottenCode int, retryCodes []string) (bool, error) { + gottenCodeStr := strconv.Itoa(gottenCode) + for _, retryCode := range retryCodes { + if len(retryCode) != 3 { + return false, fmt.Errorf("invalid retry status code: %s", retryCode) + } + matched := true + for i := range retryCode { + c := retryCode[i] + if c == 'x' { + continue + } + if gottenCodeStr[i] != c { + matched = false + break + } + } + if matched { + return true, nil + } + } + + return false, nil +} diff --git a/client_test.go b/client_test.go index 47d2dee3..5c5c3a86 100644 --- a/client_test.go +++ b/client_test.go @@ -8,6 +8,8 @@ import ( "net/http" "net/http/httptest" "net/url" + "strconv" + "strings" "testing" "time" ) @@ -209,8 +211,11 @@ func TestClient_requestWithRetries(t *testing.T) { case 2: http.Error(w, `{"error":"calm down"}`, http.StatusTooManyRequests) - default: + case 3: w.Write([]byte(`{"foo":"bar"}`)) //nolint:errcheck + + default: + t.Errorf("unexpected retry %d", try) } })) defer ts.Close() @@ -255,6 +260,49 @@ func TestClient_requestWithRetries(t *testing.T) { t.Logf("request successful after %d retries", try) } +func TestClient_CustomRetryStatusCode(t *testing.T) { + body := []byte(`lorem ipsum dolor sit amet`) + var try int + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + + try++ + + switch try { + case 1, 2: + http.Error(w, `{"error":"weird error"}`, http.StatusUpgradeRequired) + default: + http.Error(w, `{"error":"failed"}`, http.StatusInternalServerError) + } + })) + defer ts.Close() + + httpClient := &http.Client{ + Transport: &customRoundTripper{}, + } + + c, err := New(ts.URL, Config{ + NumRetries: 5, + Client: httpClient, + RetryTimeout: 50 * time.Millisecond, + RetryStatusCodes: []string{strconv.Itoa(http.StatusUpgradeRequired)}, + }) + if err != nil { + t.Fatalf("unexpected error creating client: %v", err) + } + + var got interface{} + err = c.request(http.MethodPost, "/", nil, body, &got) + expectedErr := "status: 500, body: {\"error\":\"failed\"}" // The 500 is not retried because it's not in RetryStatusCodes + if strings.TrimSpace(err.Error()) != expectedErr { + t.Fatalf("expected err: %s, got err: %v", expectedErr, err) + } + + if try != 3 { + t.Fatalf("unexpected number of tries: %d", try) + } +} + type customRoundTripper struct { try int } diff --git a/dashboard_test.go b/dashboard_test.go index aa38c17d..caf46a9c 100644 --- a/dashboard_test.go +++ b/dashboard_test.go @@ -1,6 +1,7 @@ package gapi import ( + "fmt" "strings" "testing" @@ -78,67 +79,79 @@ func TestDashboardCreateAndUpdate(t *testing.T) { } func TestDashboardGet(t *testing.T) { - client := gapiTestTools(t, 200, getDashboardResponse) + t.Run("By slug", func(t *testing.T) { + client := gapiTestTools(t, 200, getDashboardResponse) - resp, err := client.Dashboard("test") - if err != nil { - t.Error(err) - } - uid, ok := resp.Model["uid"] - if !ok || uid != "cIBgcSjkk" { - t.Errorf("Invalid uid - %s, Expected %s", uid, "cIBgcSjkk") - } - - client = gapiTestTools(t, 200, getDashboardResponse) + resp, err := client.Dashboard("test") + if err != nil { + t.Error(err) + } + uid, ok := resp.Model["uid"] + if !ok || uid != "cIBgcSjkk" { + t.Errorf("Invalid uid - %s, Expected %s", uid, "cIBgcSjkk") + } + }) - resp, err = client.DashboardByUID("cIBgcSjkk") - if err != nil { - t.Fatal(err) - } - uid, ok = resp.Model["uid"] - if !ok || uid != "cIBgcSjkk" { - t.Fatalf("Invalid UID - %s, Expected %s", uid, "cIBgcSjkk") - } + t.Run("By UID", func(t *testing.T) { + client := gapiTestTools(t, 200, getDashboardResponse) - for _, code := range []int{401, 403, 404} { - client = gapiTestTools(t, code, "error") - _, err = client.Dashboard("test") - if err == nil { - t.Errorf("%d not detected", code) + resp, err := client.DashboardByUID("cIBgcSjkk") + if err != nil { + t.Fatal(err) } - - _, err = client.DashboardByUID("cIBgcSjkk") - if err == nil { - t.Errorf("%d not detected", code) + uid, ok := resp.Model["uid"] + if !ok || uid != "cIBgcSjkk" { + t.Fatalf("Invalid UID - %s, Expected %s", uid, "cIBgcSjkk") } + }) + + for _, code := range []int{401, 403, 404} { + t.Run(fmt.Sprintf("Dashboard error: %d", code), func(t *testing.T) { + client := gapiTestToolsFromCalls(t, []mockServerCall{{code, "error"}, {code, "error"}}) + _, err := client.Dashboard("test") + if err == nil { + t.Errorf("%d not detected", code) + } + + _, err = client.DashboardByUID("cIBgcSjkk") + if err == nil { + t.Errorf("%d not detected", code) + } + }) } } func TestDashboardDelete(t *testing.T) { - client := gapiTestTools(t, 200, "") - err := client.DeleteDashboard("test") - if err != nil { - t.Error(err) - } - - client = gapiTestTools(t, 200, "") - err = client.DeleteDashboardByUID("cIBgcSjkk") - if err != nil { - t.Fatal(err) - } - - for _, code := range []int{401, 403, 404, 412} { - client = gapiTestTools(t, code, "error") - - err = client.DeleteDashboard("test") - if err == nil { - t.Errorf("%d not detected", code) + t.Run("By slug", func(t *testing.T) { + client := gapiTestTools(t, 200, "") + err := client.DeleteDashboard("test") + if err != nil { + t.Error(err) } + }) - err = client.DeleteDashboardByUID("cIBgcSjkk") - if err == nil { - t.Errorf("%d not detected", code) + t.Run("By UID", func(t *testing.T) { + client := gapiTestTools(t, 200, "") + err := client.DeleteDashboardByUID("cIBgcSjkk") + if err != nil { + t.Fatal(err) } + }) + + for _, code := range []int{401, 403, 404, 412} { + t.Run(fmt.Sprintf("Dashboard error: %d", code), func(t *testing.T) { + client := gapiTestToolsFromCalls(t, []mockServerCall{{code, "error"}, {code, "error"}}) + + err := client.DeleteDashboard("test") + if err == nil { + t.Errorf("%d not detected", code) + } + + err = client.DeleteDashboardByUID("cIBgcSjkk") + if err == nil { + t.Errorf("%d not detected", code) + } + }) } } diff --git a/mock.go b/mock_test.go similarity index 90% rename from mock.go rename to mock_test.go index ef460546..90053870 100644 --- a/mock.go +++ b/mock_test.go @@ -19,11 +19,8 @@ type mockServer struct { server *httptest.Server } -func (m *mockServer) Close() { - m.server.Close() -} - func gapiTestTools(t *testing.T, code int, body string) *Client { + t.Helper() return gapiTestToolsFromCalls(t, []mockServerCall{{code, body}}) } @@ -35,6 +32,9 @@ func gapiTestToolsFromCalls(t *testing.T, calls []mockServerCall) *Client { } mock.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(mock.upcomingCalls) == 0 { + t.Fatalf("unexpected call to %s %s", r.Method, r.URL) + } call := mock.upcomingCalls[0] if len(calls) > 1 { mock.upcomingCalls = mock.upcomingCalls[1:] @@ -61,7 +61,7 @@ func gapiTestToolsFromCalls(t *testing.T, calls []mockServerCall) *Client { } t.Cleanup(func() { - mock.Close() + mock.server.Close() }) return client