Skip to content

Commit 09468fa

Browse files
authored
fix: retry Github 503 requests up to 3 times (#2650)
Fixes #2649
1 parent f918c5d commit 09468fa

File tree

2 files changed

+105
-0
lines changed

2 files changed

+105
-0
lines changed

internal/github/github.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,36 @@ import (
2424
"net/http"
2525
"net/url"
2626
"strings"
27+
"time"
2728

2829
"github.com/google/go-github/v69/github"
2930
)
3031

32+
const (
33+
maxRetries = 3
34+
retryDelay = 2 * time.Second
35+
)
36+
37+
type retryableTransport struct {
38+
transport http.RoundTripper
39+
}
40+
41+
// RoundTrip implements the http.RoundTripper interface and adds retry logic
42+
// for transient server errors.
43+
func (t *retryableTransport) RoundTrip(req *http.Request) (*http.Response, error) {
44+
var resp *http.Response
45+
var err error
46+
for i := 0; i < maxRetries; i++ {
47+
resp, err = t.transport.RoundTrip(req)
48+
if err == nil && resp.StatusCode != http.StatusServiceUnavailable {
49+
return resp, nil
50+
}
51+
slog.Warn("retrying due to error", "err", err, "status_code", resp.StatusCode)
52+
time.Sleep(retryDelay)
53+
}
54+
return resp, err
55+
}
56+
3157
// PullRequest is a type alias for the go-github type.
3258
type PullRequest = github.PullRequest
3359

@@ -57,6 +83,14 @@ func NewClient(accessToken string, repo *Repository) *Client {
5783
}
5884

5985
func newClientWithHTTP(accessToken string, repo *Repository, httpClient *http.Client) *Client {
86+
if httpClient == nil {
87+
httpClient = &http.Client{}
88+
}
89+
transport := httpClient.Transport
90+
if transport == nil {
91+
transport = http.DefaultTransport
92+
}
93+
httpClient.Transport = &retryableTransport{transport: transport}
6094
client := github.NewClient(httpClient)
6195
if repo != nil && repo.BaseURL != "" {
6296
baseURL, _ := url.Parse(repo.BaseURL)

internal/github/github_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,77 @@ func TestCreateTag(t *testing.T) {
10001000
}
10011001
}
10021002

1003+
func TestRetryableTransport(t *testing.T) {
1004+
t.Parallel()
1005+
for _, test := range []struct {
1006+
name string
1007+
handler func(w http.ResponseWriter, r *http.Request, requestCount int)
1008+
wantStatusCode int
1009+
wantErr bool
1010+
wantRequestCount int
1011+
}{
1012+
{
1013+
name: "Success after retries",
1014+
handler: func(w http.ResponseWriter, r *http.Request, requestCount int) {
1015+
if requestCount < 3 {
1016+
w.WriteHeader(http.StatusServiceUnavailable)
1017+
} else {
1018+
w.WriteHeader(http.StatusOK)
1019+
}
1020+
},
1021+
wantStatusCode: http.StatusOK,
1022+
wantErr: false,
1023+
wantRequestCount: 3,
1024+
},
1025+
{
1026+
name: "Failure after all retries",
1027+
handler: func(w http.ResponseWriter, r *http.Request, _ int) {
1028+
w.WriteHeader(http.StatusServiceUnavailable)
1029+
},
1030+
wantStatusCode: http.StatusServiceUnavailable,
1031+
wantErr: true,
1032+
wantRequestCount: 3,
1033+
},
1034+
} {
1035+
t.Run(test.name, func(t *testing.T) {
1036+
var requestCount int
1037+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1038+
requestCount++
1039+
test.handler(w, r, requestCount)
1040+
}))
1041+
defer server.Close()
1042+
1043+
repo := &Repository{Owner: "owner", Name: "repo"}
1044+
client := newClientWithHTTP("fake-token", repo, server.Client())
1045+
client.BaseURL, _ = url.Parse(server.URL + "/")
1046+
1047+
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, server.URL, nil)
1048+
if err != nil {
1049+
t.Fatalf("http.NewRequestWithContext() failed: %v", err)
1050+
}
1051+
resp, err := client.Do(t.Context(), req, nil)
1052+
1053+
if test.wantErr {
1054+
if err == nil {
1055+
t.Fatal("client.Do() err = nil, want error")
1056+
}
1057+
} else {
1058+
if err != nil {
1059+
t.Fatalf("client.Do() failed: %v", err)
1060+
}
1061+
defer resp.Body.Close()
1062+
}
1063+
1064+
if resp.StatusCode != test.wantStatusCode {
1065+
t.Errorf("client.Do() status = %d, want %d", resp.StatusCode, test.wantStatusCode)
1066+
}
1067+
if requestCount != test.wantRequestCount {
1068+
t.Errorf("requestCount = %d, want %d", requestCount, test.wantRequestCount)
1069+
}
1070+
})
1071+
}
1072+
}
1073+
10031074
func TestNewClient(t *testing.T) {
10041075

10051076
t.Parallel()

0 commit comments

Comments
 (0)