From bce9d210ef806bc30273a48a7bb8fa7a4f95046a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?William=20Mouch=C3=A8re?= Date: Tue, 31 Dec 2024 15:59:06 +0100 Subject: [PATCH 1/2] Add opt-in rate limit support on endpoints returning 302s Refactor `Client` to add a new `http.Client` dedicated to calls that expect a 302 success response. Reuse the existing `BareDo` rate limit logic with that new client. Add a feature flag to opt in to using this instead of the usual `roundTripWithOptionalFollowRedirect` logic. --- github/actions_artifacts.go | 28 ++ github/actions_artifacts_test.go | 248 +++++++++++++----- github/actions_workflow_jobs.go | 28 ++ github/actions_workflow_jobs_test.go | 199 +++++++++----- github/actions_workflow_runs.go | 56 ++++ github/actions_workflow_runs_test.go | 370 ++++++++++++++++++--------- github/github.go | 166 ++++++++++-- github/github_test.go | 32 +++ github/repos_contents.go | 29 +++ github/repos_contents_test.go | 185 +++++++++----- 10 files changed, 1005 insertions(+), 336 deletions(-) diff --git a/github/actions_artifacts.go b/github/actions_artifacts.go index e05a9a84024..fa3829899ab 100644 --- a/github/actions_artifacts.go +++ b/github/actions_artifacts.go @@ -142,6 +142,14 @@ func (s *ActionsService) GetArtifact(ctx context.Context, owner, repo string, ar func (s *ActionsService) DownloadArtifact(ctx context.Context, owner, repo string, artifactID int64, maxRedirects int) (*url.URL, *Response, error) { u := fmt.Sprintf("repos/%v/%v/actions/artifacts/%v/zip", owner, repo, artifactID) + if s.client.RateLimitRedirectionalEndpoints { + return s.downloadArtifactWithRateLimit(ctx, u, maxRedirects) + } + + return s.downloadArtifactWithoutRateLimit(ctx, u, maxRedirects) +} + +func (s *ActionsService) downloadArtifactWithoutRateLimit(ctx context.Context, u string, maxRedirects int) (*url.URL, *Response, error) { resp, err := s.client.roundTripWithOptionalFollowRedirect(ctx, u, maxRedirects) if err != nil { return nil, nil, err @@ -160,6 +168,26 @@ func (s *ActionsService) DownloadArtifact(ctx context.Context, owner, repo strin return parsedURL, newResponse(resp), nil } +func (s *ActionsService) downloadArtifactWithRateLimit(ctx context.Context, u string, maxRedirects int) (*url.URL, *Response, error) { + req, err := s.client.NewRequest("GET", u, nil) + if err != nil { + return nil, nil, err + } + + url, resp, err := s.client.bareDoUntilFound(ctx, req, maxRedirects) + if err != nil { + return nil, resp, err + } + defer resp.Body.Close() + + // If we received a valid Location in a 302 response + if url != nil { + return url, resp, nil + } + + return nil, resp, fmt.Errorf("unexpected status code: %s", resp.Status) +} + // DeleteArtifact deletes a workflow run artifact. // // GitHub API docs: https://docs.github.com/rest/actions/artifacts#delete-an-artifact diff --git a/github/actions_artifacts_test.go b/github/actions_artifacts_test.go index 5ac5a8f2bd1..ec1fb2a7216 100644 --- a/github/actions_artifacts_test.go +++ b/github/actions_artifacts_test.go @@ -272,102 +272,210 @@ func TestActionsService_GetArtifact_notFound(t *testing.T) { func TestActionsService_DownloadArtifact(t *testing.T) { t.Parallel() - client, mux, _ := setup(t) - - mux.HandleFunc("/repos/o/r/actions/artifacts/1/zip", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - http.Redirect(w, r, "https://github.com/artifact", http.StatusFound) - }) - ctx := context.Background() - url, resp, err := client.Actions.DownloadArtifact(ctx, "o", "r", 1, 1) - if err != nil { - t.Errorf("Actions.DownloadArtifact returned error: %v", err) - } - if resp.StatusCode != http.StatusFound { - t.Errorf("Actions.DownloadArtifact returned status: %d, want %d", resp.StatusCode, http.StatusFound) + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, } - want := "https://github.com/artifact" - if url.String() != want { - t.Errorf("Actions.DownloadArtifact returned %+v, want %+v", url.String(), want) - } + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + mux.HandleFunc("/repos/o/r/actions/artifacts/1/zip", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + http.Redirect(w, r, "https://github.com/artifact", http.StatusFound) + }) + + ctx := context.Background() + url, resp, err := client.Actions.DownloadArtifact(ctx, "o", "r", 1, 1) + if err != nil { + t.Errorf("Actions.DownloadArtifact returned error: %v", err) + } + if resp.StatusCode != http.StatusFound { + t.Errorf("Actions.DownloadArtifact returned status: %d, want %d", resp.StatusCode, http.StatusFound) + } - const methodName = "DownloadArtifact" - testBadOptions(t, methodName, func() (err error) { - _, _, err = client.Actions.DownloadArtifact(ctx, "\n", "\n", -1, 1) - return err - }) + want := "https://github.com/artifact" + if url.String() != want { + t.Errorf("Actions.DownloadArtifact returned %+v, want %+v", url.String(), want) + } - // Add custom round tripper - client.client.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { - return nil, errors.New("failed to download artifact") - }) - testBadOptions(t, methodName, func() (err error) { - _, _, err = client.Actions.DownloadArtifact(ctx, "o", "r", 1, 1) - return err - }) + const methodName = "DownloadArtifact" + testBadOptions(t, methodName, func() (err error) { + _, _, err = client.Actions.DownloadArtifact(ctx, "\n", "\n", -1, 1) + return err + }) + + // Add custom round tripper + client.client.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return nil, errors.New("failed to download artifact") + }) + // propagate custom round tripper to client without CheckRedirect + client.initialize() + testBadOptions(t, methodName, func() (err error) { + _, _, err = client.Actions.DownloadArtifact(ctx, "o", "r", 1, 1) + return err + }) + }) + } } func TestActionsService_DownloadArtifact_invalidOwner(t *testing.T) { t.Parallel() - client, _, _ := setup(t) + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, + } - ctx := context.Background() - _, _, err := client.Actions.DownloadArtifact(ctx, "%", "r", 1, 1) - testURLParseError(t, err) + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, _, _ := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + ctx := context.Background() + _, _, err := client.Actions.DownloadArtifact(ctx, "%", "r", 1, 1) + testURLParseError(t, err) + }) + } } func TestActionsService_DownloadArtifact_invalidRepo(t *testing.T) { t.Parallel() - client, _, _ := setup(t) + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, + } - ctx := context.Background() - _, _, err := client.Actions.DownloadArtifact(ctx, "o", "%", 1, 1) - testURLParseError(t, err) + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, _, _ := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + ctx := context.Background() + _, _, err := client.Actions.DownloadArtifact(ctx, "o", "%", 1, 1) + testURLParseError(t, err) + }) + } } func TestActionsService_DownloadArtifact_StatusMovedPermanently_dontFollowRedirects(t *testing.T) { t.Parallel() - client, mux, _ := setup(t) - - mux.HandleFunc("/repos/o/r/actions/artifacts/1/zip", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - http.Redirect(w, r, "https://github.com/artifact", http.StatusMovedPermanently) - }) + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, + } - ctx := context.Background() - _, resp, _ := client.Actions.DownloadArtifact(ctx, "o", "r", 1, 0) - if resp.StatusCode != http.StatusMovedPermanently { - t.Errorf("Actions.DownloadArtifact return status %d, want %d", resp.StatusCode, http.StatusMovedPermanently) + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + mux.HandleFunc("/repos/o/r/actions/artifacts/1/zip", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + http.Redirect(w, r, "https://github.com/artifact", http.StatusMovedPermanently) + }) + + ctx := context.Background() + _, resp, _ := client.Actions.DownloadArtifact(ctx, "o", "r", 1, 0) + if resp.StatusCode != http.StatusMovedPermanently { + t.Errorf("Actions.DownloadArtifact return status %d, want %d", resp.StatusCode, http.StatusMovedPermanently) + } + }) } } func TestActionsService_DownloadArtifact_StatusMovedPermanently_followRedirects(t *testing.T) { t.Parallel() - client, mux, serverURL := setup(t) - - mux.HandleFunc("/repos/o/r/actions/artifacts/1/zip", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - redirectURL, _ := url.Parse(serverURL + baseURLPath + "/redirect") - http.Redirect(w, r, redirectURL.String(), http.StatusMovedPermanently) - }) - mux.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - http.Redirect(w, r, "http://github.com/artifact", http.StatusFound) - }) - - ctx := context.Background() - url, resp, err := client.Actions.DownloadArtifact(ctx, "o", "r", 1, 1) - if err != nil { - t.Errorf("Actions.DownloadArtifact return error: %v", err) - } - if resp.StatusCode != http.StatusFound { - t.Errorf("Actions.DownloadArtifact return status %d, want %d", resp.StatusCode, http.StatusFound) + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, } - want := "http://github.com/artifact" - if url.String() != want { - t.Errorf("Actions.DownloadArtifact returned %+v, want %+v", url.String(), want) + + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, serverURL := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + mux.HandleFunc("/repos/o/r/actions/artifacts/1/zip", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + redirectURL, _ := url.Parse(serverURL + baseURLPath + "/redirect") + http.Redirect(w, r, redirectURL.String(), http.StatusMovedPermanently) + }) + mux.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + http.Redirect(w, r, "http://github.com/artifact", http.StatusFound) + }) + + ctx := context.Background() + url, resp, err := client.Actions.DownloadArtifact(ctx, "o", "r", 1, 1) + if err != nil { + t.Errorf("Actions.DownloadArtifact return error: %v", err) + } + if resp.StatusCode != http.StatusFound { + t.Errorf("Actions.DownloadArtifact return status %d, want %d", resp.StatusCode, http.StatusFound) + } + want := "http://github.com/artifact" + if url.String() != want { + t.Errorf("Actions.DownloadArtifact returned %+v, want %+v", url.String(), want) + } + }) } } diff --git a/github/actions_workflow_jobs.go b/github/actions_workflow_jobs.go index 84bbe5aa46d..0acbaa4916c 100644 --- a/github/actions_workflow_jobs.go +++ b/github/actions_workflow_jobs.go @@ -150,6 +150,14 @@ func (s *ActionsService) GetWorkflowJobByID(ctx context.Context, owner, repo str func (s *ActionsService) GetWorkflowJobLogs(ctx context.Context, owner, repo string, jobID int64, maxRedirects int) (*url.URL, *Response, error) { u := fmt.Sprintf("repos/%v/%v/actions/jobs/%v/logs", owner, repo, jobID) + if s.client.RateLimitRedirectionalEndpoints { + return s.getWorkflowJobLogsWithRateLimit(ctx, u, maxRedirects) + } + + return s.getWorkflowJobLogsWithoutRateLimit(ctx, u, maxRedirects) +} + +func (s *ActionsService) getWorkflowJobLogsWithoutRateLimit(ctx context.Context, u string, maxRedirects int) (*url.URL, *Response, error) { resp, err := s.client.roundTripWithOptionalFollowRedirect(ctx, u, maxRedirects) if err != nil { return nil, nil, err @@ -163,3 +171,23 @@ func (s *ActionsService) GetWorkflowJobLogs(ctx context.Context, owner, repo str parsedURL, err := url.Parse(resp.Header.Get("Location")) return parsedURL, newResponse(resp), err } + +func (s *ActionsService) getWorkflowJobLogsWithRateLimit(ctx context.Context, u string, maxRedirects int) (*url.URL, *Response, error) { + req, err := s.client.NewRequest("GET", u, nil) + if err != nil { + return nil, nil, err + } + + url, resp, err := s.client.bareDoUntilFound(ctx, req, maxRedirects) + if err != nil { + return nil, resp, err + } + defer resp.Body.Close() + + // If we received a valid Location in a 302 response + if url != nil { + return url, resp, nil + } + + return nil, resp, fmt.Errorf("unexpected status code: %s", resp.Status) +} diff --git a/github/actions_workflow_jobs_test.go b/github/actions_workflow_jobs_test.go index 3ef5295afde..527901eadf6 100644 --- a/github/actions_workflow_jobs_test.go +++ b/github/actions_workflow_jobs_test.go @@ -184,87 +184,152 @@ func TestActionsService_GetWorkflowJobByID(t *testing.T) { func TestActionsService_GetWorkflowJobLogs(t *testing.T) { t.Parallel() - client, mux, _ := setup(t) - - mux.HandleFunc("/repos/o/r/actions/jobs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - http.Redirect(w, r, "http://github.com/a", http.StatusFound) - }) - - ctx := context.Background() - url, resp, err := client.Actions.GetWorkflowJobLogs(ctx, "o", "r", 399444496, 1) - if err != nil { - t.Errorf("Actions.GetWorkflowJobLogs returned error: %v", err) - } - if resp.StatusCode != http.StatusFound { - t.Errorf("Actions.GetWorkflowJobLogs returned status: %d, want %d", resp.StatusCode, http.StatusFound) - } - want := "http://github.com/a" - if url.String() != want { - t.Errorf("Actions.GetWorkflowJobLogs returned %+v, want %+v", url.String(), want) + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, } - const methodName = "GetWorkflowJobLogs" - testBadOptions(t, methodName, func() (err error) { - _, _, err = client.Actions.GetWorkflowJobLogs(ctx, "\n", "\n", 399444496, 1) - return err - }) - - // Add custom round tripper - client.client.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { - return nil, errors.New("failed to get workflow logs") - }) - testBadOptions(t, methodName, func() (err error) { - _, _, err = client.Actions.GetWorkflowJobLogs(ctx, "o", "r", 399444496, 1) - return err - }) + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + mux.HandleFunc("/repos/o/r/actions/jobs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + http.Redirect(w, r, "http://github.com/a", http.StatusFound) + }) + + ctx := context.Background() + url, resp, err := client.Actions.GetWorkflowJobLogs(ctx, "o", "r", 399444496, 1) + if err != nil { + t.Errorf("Actions.GetWorkflowJobLogs returned error: %v", err) + } + if resp.StatusCode != http.StatusFound { + t.Errorf("Actions.GetWorkflowJobLogs returned status: %d, want %d", resp.StatusCode, http.StatusFound) + } + want := "http://github.com/a" + if url.String() != want { + t.Errorf("Actions.GetWorkflowJobLogs returned %+v, want %+v", url.String(), want) + } + + const methodName = "GetWorkflowJobLogs" + testBadOptions(t, methodName, func() (err error) { + _, _, err = client.Actions.GetWorkflowJobLogs(ctx, "\n", "\n", 399444496, 1) + return err + }) + + // Add custom round tripper + client.client.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return nil, errors.New("failed to get workflow logs") + }) + // propagate custom round tripper to client without CheckRedirect + client.initialize() + testBadOptions(t, methodName, func() (err error) { + _, _, err = client.Actions.GetWorkflowJobLogs(ctx, "o", "r", 399444496, 1) + return err + }) + }) + } } func TestActionsService_GetWorkflowJobLogs_StatusMovedPermanently_dontFollowRedirects(t *testing.T) { t.Parallel() - client, mux, _ := setup(t) - - mux.HandleFunc("/repos/o/r/actions/jobs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - http.Redirect(w, r, "http://github.com/a", http.StatusMovedPermanently) - }) + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, + } - ctx := context.Background() - _, resp, _ := client.Actions.GetWorkflowJobLogs(ctx, "o", "r", 399444496, 0) - if resp.StatusCode != http.StatusMovedPermanently { - t.Errorf("Actions.GetWorkflowJobLogs returned status: %d, want %d", resp.StatusCode, http.StatusMovedPermanently) + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + mux.HandleFunc("/repos/o/r/actions/jobs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + http.Redirect(w, r, "http://github.com/a", http.StatusMovedPermanently) + }) + + ctx := context.Background() + _, resp, _ := client.Actions.GetWorkflowJobLogs(ctx, "o", "r", 399444496, 0) + if resp.StatusCode != http.StatusMovedPermanently { + t.Errorf("Actions.GetWorkflowJobLogs returned status: %d, want %d", resp.StatusCode, http.StatusMovedPermanently) + } + }) } } func TestActionsService_GetWorkflowJobLogs_StatusMovedPermanently_followRedirects(t *testing.T) { t.Parallel() - client, mux, serverURL := setup(t) - - // Mock a redirect link, which leads to an archive link - mux.HandleFunc("/repos/o/r/actions/jobs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - redirectURL, _ := url.Parse(serverURL + baseURLPath + "/redirect") - http.Redirect(w, r, redirectURL.String(), http.StatusMovedPermanently) - }) - - mux.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - http.Redirect(w, r, "http://github.com/a", http.StatusFound) - }) - - ctx := context.Background() - url, resp, err := client.Actions.GetWorkflowJobLogs(ctx, "o", "r", 399444496, 1) - if err != nil { - t.Errorf("Actions.GetWorkflowJobLogs returned error: %v", err) - } - - if resp.StatusCode != http.StatusFound { - t.Errorf("Actions.GetWorkflowJobLogs returned status: %d, want %d", resp.StatusCode, http.StatusFound) + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, } - want := "http://github.com/a" - if url.String() != want { - t.Errorf("Actions.GetWorkflowJobLogs returned %+v, want %+v", url.String(), want) + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, serverURL := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + // Mock a redirect link, which leads to an archive link + mux.HandleFunc("/repos/o/r/actions/jobs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + redirectURL, _ := url.Parse(serverURL + baseURLPath + "/redirect") + http.Redirect(w, r, redirectURL.String(), http.StatusMovedPermanently) + }) + + mux.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + http.Redirect(w, r, "http://github.com/a", http.StatusFound) + }) + + ctx := context.Background() + url, resp, err := client.Actions.GetWorkflowJobLogs(ctx, "o", "r", 399444496, 1) + if err != nil { + t.Errorf("Actions.GetWorkflowJobLogs returned error: %v", err) + } + + if resp.StatusCode != http.StatusFound { + t.Errorf("Actions.GetWorkflowJobLogs returned status: %d, want %d", resp.StatusCode, http.StatusFound) + } + + want := "http://github.com/a" + if url.String() != want { + t.Errorf("Actions.GetWorkflowJobLogs returned %+v, want %+v", url.String(), want) + } + }) } } diff --git a/github/actions_workflow_runs.go b/github/actions_workflow_runs.go index 122ea1d0e2b..6c1ea92b56b 100644 --- a/github/actions_workflow_runs.go +++ b/github/actions_workflow_runs.go @@ -259,6 +259,14 @@ func (s *ActionsService) GetWorkflowRunAttempt(ctx context.Context, owner, repo func (s *ActionsService) GetWorkflowRunAttemptLogs(ctx context.Context, owner, repo string, runID int64, attemptNumber int, maxRedirects int) (*url.URL, *Response, error) { u := fmt.Sprintf("repos/%v/%v/actions/runs/%v/attempts/%v/logs", owner, repo, runID, attemptNumber) + if s.client.RateLimitRedirectionalEndpoints { + return s.getWorkflowRunAttemptLogsWithRateLimit(ctx, u, maxRedirects) + } + + return s.getWorkflowRunAttemptLogsWithoutRateLimit(ctx, u, maxRedirects) +} + +func (s *ActionsService) getWorkflowRunAttemptLogsWithoutRateLimit(ctx context.Context, u string, maxRedirects int) (*url.URL, *Response, error) { resp, err := s.client.roundTripWithOptionalFollowRedirect(ctx, u, maxRedirects) if err != nil { return nil, nil, err @@ -273,6 +281,26 @@ func (s *ActionsService) GetWorkflowRunAttemptLogs(ctx context.Context, owner, r return parsedURL, newResponse(resp), err } +func (s *ActionsService) getWorkflowRunAttemptLogsWithRateLimit(ctx context.Context, u string, maxRedirects int) (*url.URL, *Response, error) { + req, err := s.client.NewRequest("GET", u, nil) + if err != nil { + return nil, nil, err + } + + url, resp, err := s.client.bareDoUntilFound(ctx, req, maxRedirects) + if err != nil { + return nil, resp, err + } + defer resp.Body.Close() + + // If we received a valid Location in a 302 response + if url != nil { + return url, resp, nil + } + + return nil, resp, fmt.Errorf("unexpected status code: %s", resp.Status) +} + // RerunWorkflowByID re-runs a workflow by ID. // // GitHub API docs: https://docs.github.com/rest/actions/workflow-runs#re-run-a-workflow @@ -345,6 +373,14 @@ func (s *ActionsService) CancelWorkflowRunByID(ctx context.Context, owner, repo func (s *ActionsService) GetWorkflowRunLogs(ctx context.Context, owner, repo string, runID int64, maxRedirects int) (*url.URL, *Response, error) { u := fmt.Sprintf("repos/%v/%v/actions/runs/%v/logs", owner, repo, runID) + if s.client.RateLimitRedirectionalEndpoints { + return s.getWorkflowRunLogsWithRateLimit(ctx, u, maxRedirects) + } + + return s.getWorkflowRunLogsWithoutRateLimit(ctx, u, maxRedirects) +} + +func (s *ActionsService) getWorkflowRunLogsWithoutRateLimit(ctx context.Context, u string, maxRedirects int) (*url.URL, *Response, error) { resp, err := s.client.roundTripWithOptionalFollowRedirect(ctx, u, maxRedirects) if err != nil { return nil, nil, err @@ -359,6 +395,26 @@ func (s *ActionsService) GetWorkflowRunLogs(ctx context.Context, owner, repo str return parsedURL, newResponse(resp), err } +func (s *ActionsService) getWorkflowRunLogsWithRateLimit(ctx context.Context, u string, maxRedirects int) (*url.URL, *Response, error) { + req, err := s.client.NewRequest("GET", u, nil) + if err != nil { + return nil, nil, err + } + + url, resp, err := s.client.bareDoUntilFound(ctx, req, maxRedirects) + if err != nil { + return nil, resp, err + } + defer resp.Body.Close() + + // If we received a valid Location in a 302 response + if url != nil { + return url, resp, nil + } + + return nil, resp, fmt.Errorf("unexpected status code: %s", resp.Status) +} + // DeleteWorkflowRun deletes a workflow run by ID. // // GitHub API docs: https://docs.github.com/rest/actions/workflow-runs#delete-a-workflow-run diff --git a/github/actions_workflow_runs_test.go b/github/actions_workflow_runs_test.go index 4c91c5349ce..1f985429980 100644 --- a/github/actions_workflow_runs_test.go +++ b/github/actions_workflow_runs_test.go @@ -190,85 +190,148 @@ func TestActionsService_GetWorkflowRunAttempt(t *testing.T) { func TestActionsService_GetWorkflowRunAttemptLogs(t *testing.T) { t.Parallel() - client, mux, _ := setup(t) + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, + } - mux.HandleFunc("/repos/o/r/actions/runs/399444496/attempts/2/logs", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - http.Redirect(w, r, "http://github.com/a", http.StatusFound) - }) + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + mux.HandleFunc("/repos/o/r/actions/runs/399444496/attempts/2/logs", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + http.Redirect(w, r, "http://github.com/a", http.StatusFound) + }) + + ctx := context.Background() + url, resp, err := client.Actions.GetWorkflowRunAttemptLogs(ctx, "o", "r", 399444496, 2, 1) + if err != nil { + t.Errorf("Actions.GetWorkflowRunAttemptLogs returned error: %v", err) + } + if resp.StatusCode != http.StatusFound { + t.Errorf("Actions.GetWorkflowRunAttemptLogs returned status: %d, want %d", resp.StatusCode, http.StatusFound) + } + want := "http://github.com/a" + if url.String() != want { + t.Errorf("Actions.GetWorkflowRunAttemptLogs returned %+v, want %+v", url.String(), want) + } - ctx := context.Background() - url, resp, err := client.Actions.GetWorkflowRunAttemptLogs(ctx, "o", "r", 399444496, 2, 1) - if err != nil { - t.Errorf("Actions.GetWorkflowRunAttemptLogs returned error: %v", err) - } - if resp.StatusCode != http.StatusFound { - t.Errorf("Actions.GetWorkflowRunAttemptLogs returned status: %d, want %d", resp.StatusCode, http.StatusFound) + const methodName = "GetWorkflowRunAttemptLogs" + testBadOptions(t, methodName, func() (err error) { + _, _, err = client.Actions.GetWorkflowRunAttemptLogs(ctx, "\n", "\n", 399444496, 2, 1) + return err + }) + }) } - want := "http://github.com/a" - if url.String() != want { - t.Errorf("Actions.GetWorkflowRunAttemptLogs returned %+v, want %+v", url.String(), want) - } - - const methodName = "GetWorkflowRunAttemptLogs" - testBadOptions(t, methodName, func() (err error) { - _, _, err = client.Actions.GetWorkflowRunAttemptLogs(ctx, "\n", "\n", 399444496, 2, 1) - return err - }) } func TestActionsService_GetWorkflowRunAttemptLogs_StatusMovedPermanently_dontFollowRedirects(t *testing.T) { t.Parallel() - client, mux, _ := setup(t) - - mux.HandleFunc("/repos/o/r/actions/runs/399444496/attempts/2/logs", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - http.Redirect(w, r, "http://github.com/a", http.StatusMovedPermanently) - }) + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, + } - ctx := context.Background() - _, resp, _ := client.Actions.GetWorkflowRunAttemptLogs(ctx, "o", "r", 399444496, 2, 0) - if resp.StatusCode != http.StatusMovedPermanently { - t.Errorf("Actions.GetWorkflowRunAttemptLogs returned status: %d, want %d", resp.StatusCode, http.StatusMovedPermanently) + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + mux.HandleFunc("/repos/o/r/actions/runs/399444496/attempts/2/logs", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + http.Redirect(w, r, "http://github.com/a", http.StatusMovedPermanently) + }) + + ctx := context.Background() + _, resp, _ := client.Actions.GetWorkflowRunAttemptLogs(ctx, "o", "r", 399444496, 2, 0) + if resp.StatusCode != http.StatusMovedPermanently { + t.Errorf("Actions.GetWorkflowRunAttemptLogs returned status: %d, want %d", resp.StatusCode, http.StatusMovedPermanently) + } + }) } } func TestActionsService_GetWorkflowRunAttemptLogs_StatusMovedPermanently_followRedirects(t *testing.T) { t.Parallel() - client, mux, serverURL := setup(t) - - // Mock a redirect link, which leads to an archive link - mux.HandleFunc("/repos/o/r/actions/runs/399444496/attempts/2/logs", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - redirectURL, _ := url.Parse(serverURL + baseURLPath + "/redirect") - http.Redirect(w, r, redirectURL.String(), http.StatusMovedPermanently) - }) + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, + } - mux.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - http.Redirect(w, r, "http://github.com/a", http.StatusFound) - }) + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, serverURL := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + // Mock a redirect link, which leads to an archive link + mux.HandleFunc("/repos/o/r/actions/runs/399444496/attempts/2/logs", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + redirectURL, _ := url.Parse(serverURL + baseURLPath + "/redirect") + http.Redirect(w, r, redirectURL.String(), http.StatusMovedPermanently) + }) + + mux.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + http.Redirect(w, r, "http://github.com/a", http.StatusFound) + }) + + ctx := context.Background() + url, resp, err := client.Actions.GetWorkflowRunAttemptLogs(ctx, "o", "r", 399444496, 2, 1) + if err != nil { + t.Errorf("Actions.GetWorkflowRunAttemptLogs returned error: %v", err) + } - ctx := context.Background() - url, resp, err := client.Actions.GetWorkflowRunAttemptLogs(ctx, "o", "r", 399444496, 2, 1) - if err != nil { - t.Errorf("Actions.GetWorkflowRunAttemptLogs returned error: %v", err) - } + if resp.StatusCode != http.StatusFound { + t.Errorf("Actions.GetWorkflowRunAttemptLogs returned status: %d, want %d", resp.StatusCode, http.StatusFound) + } - if resp.StatusCode != http.StatusFound { - t.Errorf("Actions.GetWorkflowRunAttemptLogs returned status: %d, want %d", resp.StatusCode, http.StatusFound) - } + want := "http://github.com/a" + if url.String() != want { + t.Errorf("Actions.GetWorkflowRunAttemptLogs returned %+v, want %+v", url.String(), want) + } - want := "http://github.com/a" - if url.String() != want { - t.Errorf("Actions.GetWorkflowRunAttemptLogs returned %+v, want %+v", url.String(), want) + const methodName = "GetWorkflowRunAttemptLogs" + testBadOptions(t, methodName, func() (err error) { + _, _, err = client.Actions.GetWorkflowRunAttemptLogs(ctx, "\n", "\n", 399444496, 2, 1) + return err + }) + }) } - - const methodName = "GetWorkflowRunAttemptLogs" - testBadOptions(t, methodName, func() (err error) { - _, _, err = client.Actions.GetWorkflowRunAttemptLogs(ctx, "\n", "\n", 399444496, 2, 1) - return err - }) } func TestActionsService_RerunWorkflowRunByID(t *testing.T) { @@ -389,85 +452,148 @@ func TestActionsService_CancelWorkflowRunByID(t *testing.T) { func TestActionsService_GetWorkflowRunLogs(t *testing.T) { t.Parallel() - client, mux, _ := setup(t) + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, + } - mux.HandleFunc("/repos/o/r/actions/runs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - http.Redirect(w, r, "http://github.com/a", http.StatusFound) - }) + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + mux.HandleFunc("/repos/o/r/actions/runs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + http.Redirect(w, r, "http://github.com/a", http.StatusFound) + }) + + ctx := context.Background() + url, resp, err := client.Actions.GetWorkflowRunLogs(ctx, "o", "r", 399444496, 1) + if err != nil { + t.Errorf("Actions.GetWorkflowRunLogs returned error: %v", err) + } + if resp.StatusCode != http.StatusFound { + t.Errorf("Actions.GetWorkflowRunLogs returned status: %d, want %d", resp.StatusCode, http.StatusFound) + } + want := "http://github.com/a" + if url.String() != want { + t.Errorf("Actions.GetWorkflowRunLogs returned %+v, want %+v", url.String(), want) + } - ctx := context.Background() - url, resp, err := client.Actions.GetWorkflowRunLogs(ctx, "o", "r", 399444496, 1) - if err != nil { - t.Errorf("Actions.GetWorkflowRunLogs returned error: %v", err) - } - if resp.StatusCode != http.StatusFound { - t.Errorf("Actions.GetWorkflowRunLogs returned status: %d, want %d", resp.StatusCode, http.StatusFound) + const methodName = "GetWorkflowRunLogs" + testBadOptions(t, methodName, func() (err error) { + _, _, err = client.Actions.GetWorkflowRunLogs(ctx, "\n", "\n", 399444496, 1) + return err + }) + }) } - want := "http://github.com/a" - if url.String() != want { - t.Errorf("Actions.GetWorkflowRunLogs returned %+v, want %+v", url.String(), want) - } - - const methodName = "GetWorkflowRunLogs" - testBadOptions(t, methodName, func() (err error) { - _, _, err = client.Actions.GetWorkflowRunLogs(ctx, "\n", "\n", 399444496, 1) - return err - }) } func TestActionsService_GetWorkflowRunLogs_StatusMovedPermanently_dontFollowRedirects(t *testing.T) { t.Parallel() - client, mux, _ := setup(t) - - mux.HandleFunc("/repos/o/r/actions/runs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - http.Redirect(w, r, "http://github.com/a", http.StatusMovedPermanently) - }) + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, + } - ctx := context.Background() - _, resp, _ := client.Actions.GetWorkflowRunLogs(ctx, "o", "r", 399444496, 0) - if resp.StatusCode != http.StatusMovedPermanently { - t.Errorf("Actions.GetWorkflowJobLogs returned status: %d, want %d", resp.StatusCode, http.StatusMovedPermanently) + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + mux.HandleFunc("/repos/o/r/actions/runs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + http.Redirect(w, r, "http://github.com/a", http.StatusMovedPermanently) + }) + + ctx := context.Background() + _, resp, _ := client.Actions.GetWorkflowRunLogs(ctx, "o", "r", 399444496, 0) + if resp.StatusCode != http.StatusMovedPermanently { + t.Errorf("Actions.GetWorkflowJobLogs returned status: %d, want %d", resp.StatusCode, http.StatusMovedPermanently) + } + }) } } func TestActionsService_GetWorkflowRunLogs_StatusMovedPermanently_followRedirects(t *testing.T) { t.Parallel() - client, mux, serverURL := setup(t) - - // Mock a redirect link, which leads to an archive link - mux.HandleFunc("/repos/o/r/actions/runs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - redirectURL, _ := url.Parse(serverURL + baseURLPath + "/redirect") - http.Redirect(w, r, redirectURL.String(), http.StatusMovedPermanently) - }) + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, + } - mux.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - http.Redirect(w, r, "http://github.com/a", http.StatusFound) - }) + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, serverURL := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + // Mock a redirect link, which leads to an archive link + mux.HandleFunc("/repos/o/r/actions/runs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + redirectURL, _ := url.Parse(serverURL + baseURLPath + "/redirect") + http.Redirect(w, r, redirectURL.String(), http.StatusMovedPermanently) + }) + + mux.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + http.Redirect(w, r, "http://github.com/a", http.StatusFound) + }) + + ctx := context.Background() + url, resp, err := client.Actions.GetWorkflowRunLogs(ctx, "o", "r", 399444496, 1) + if err != nil { + t.Errorf("Actions.GetWorkflowJobLogs returned error: %v", err) + } - ctx := context.Background() - url, resp, err := client.Actions.GetWorkflowRunLogs(ctx, "o", "r", 399444496, 1) - if err != nil { - t.Errorf("Actions.GetWorkflowJobLogs returned error: %v", err) - } + if resp.StatusCode != http.StatusFound { + t.Errorf("Actions.GetWorkflowJobLogs returned status: %d, want %d", resp.StatusCode, http.StatusFound) + } - if resp.StatusCode != http.StatusFound { - t.Errorf("Actions.GetWorkflowJobLogs returned status: %d, want %d", resp.StatusCode, http.StatusFound) - } + want := "http://github.com/a" + if url.String() != want { + t.Errorf("Actions.GetWorkflowJobLogs returned %+v, want %+v", url.String(), want) + } - want := "http://github.com/a" - if url.String() != want { - t.Errorf("Actions.GetWorkflowJobLogs returned %+v, want %+v", url.String(), want) + const methodName = "GetWorkflowRunLogs" + testBadOptions(t, methodName, func() (err error) { + _, _, err = client.Actions.GetWorkflowRunLogs(ctx, "\n", "\n", 399444496, 1) + return err + }) + }) } - - const methodName = "GetWorkflowRunLogs" - testBadOptions(t, methodName, func() (err error) { - _, _, err = client.Actions.GetWorkflowRunLogs(ctx, "\n", "\n", 399444496, 1) - return err - }) } func TestActionService_ListRepositoryWorkflowRuns(t *testing.T) { diff --git a/github/github.go b/github/github.go index 4e5a33c67e4..fe74c75f7a7 100644 --- a/github/github.go +++ b/github/github.go @@ -155,8 +155,9 @@ var errNonNilContext = errors.New("context must be non-nil") // A Client manages communication with the GitHub API. type Client struct { - clientMu sync.Mutex // clientMu protects the client during calls that modify the CheckRedirect func. - client *http.Client // HTTP client used to communicate with the API. + clientMu sync.Mutex // clientMu protects the client during calls that modify the CheckRedirect func. + client *http.Client // HTTP client used to communicate with the API. + clientIgnoreRedirects *http.Client // HTTP client used to communicate with the API on endpoints where we don't want to follow redirects. // Base URL for API requests. Defaults to the public GitHub API, but can be // set to a domain endpoint to use with GitHub Enterprise. BaseURL should @@ -173,6 +174,9 @@ type Client struct { rateLimits [Categories]Rate // Rate limits for the client as determined by the most recent API calls. secondaryRateLimitReset time.Time // Secondary rate limit reset for the client as determined by the most recent API calls. + // Whether to respect rate limit headers on endpoints that return 302 redirections to artifacts + RateLimitRedirectionalEndpoints bool + common service // Reuse a single struct instead of allocating one for each service on the heap. // Services used for talking to different parts of the GitHub API. @@ -394,6 +398,14 @@ func (c *Client) initialize() { if c.client == nil { c.client = &http.Client{} } + // Copy the main http client into the IgnoreRedirects one, overriding the `CheckRedirect` func + c.clientIgnoreRedirects = &http.Client{} + c.clientIgnoreRedirects.Transport = c.client.Transport + c.clientIgnoreRedirects.Timeout = c.client.Timeout + c.clientIgnoreRedirects.Jar = c.client.Jar + c.clientIgnoreRedirects.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } if c.BaseURL == nil { c.BaseURL, _ = url.Parse(defaultBaseURL) } @@ -448,11 +460,12 @@ func (c *Client) copy() *Client { c.clientMu.Lock() // can't use *c here because that would copy mutexes by value. clone := Client{ - client: &http.Client{}, - UserAgent: c.UserAgent, - BaseURL: c.BaseURL, - UploadURL: c.UploadURL, - secondaryRateLimitReset: c.secondaryRateLimitReset, + client: &http.Client{}, + UserAgent: c.UserAgent, + BaseURL: c.BaseURL, + UploadURL: c.UploadURL, + RateLimitRedirectionalEndpoints: c.RateLimitRedirectionalEndpoints, + secondaryRateLimitReset: c.secondaryRateLimitReset, } c.clientMu.Unlock() if c.client != nil { @@ -805,15 +818,15 @@ const ( SleepUntilPrimaryRateLimitResetWhenRateLimited ) -// BareDo sends an API request and lets you handle the api response. If an error -// or API Error occurs, the error will contain more information. Otherwise you -// are supposed to read and close the response's Body. If rate limit is exceeded -// and reset time is in the future, BareDo returns *RateLimitError immediately -// without making a network API call. +// bareDo sends an API request using `caller` http.Client passed in the parameters +// and lets you handle the api response. If an error or API Error occurs, the error +// will contain more information. Otherwise you are supposed to read and close the +// response's Body. If rate limit is exceeded and reset time is in the future, +// bareDo returns *RateLimitError immediately without making a network API call. // // The provided ctx must be non-nil, if it is nil an error is returned. If it is // canceled or times out, ctx.Err() will be returned. -func (c *Client) BareDo(ctx context.Context, req *http.Request) (*Response, error) { +func (c *Client) bareDo(ctx context.Context, caller *http.Client, req *http.Request) (*Response, error) { if ctx == nil { return nil, errNonNilContext } @@ -838,7 +851,7 @@ func (c *Client) BareDo(ctx context.Context, req *http.Request) (*Response, erro } } - resp, err := c.client.Do(req) + resp, err := caller.Do(req) var response *Response if resp != nil { response = newResponse(resp) @@ -897,7 +910,7 @@ func (c *Client) BareDo(ctx context.Context, req *http.Request) (*Response, erro return response, err } // retry the request once when the rate limit has reset - return c.BareDo(context.WithValue(req.Context(), SleepUntilPrimaryRateLimitResetWhenRateLimited, nil), req) + return c.bareDo(context.WithValue(req.Context(), SleepUntilPrimaryRateLimitResetWhenRateLimited, nil), caller, req) } // Update the secondary rate limit if we hit it. @@ -911,6 +924,73 @@ func (c *Client) BareDo(ctx context.Context, req *http.Request) (*Response, erro return response, err } +// BareDo sends an API request and lets you handle the api response. If an error +// or API Error occurs, the error will contain more information. Otherwise you +// are supposed to read and close the response's Body. If rate limit is exceeded +// and reset time is in the future, BareDo returns *RateLimitError immediately +// without making a network API call. +// +// The provided ctx must be non-nil, if it is nil an error is returned. If it is +// canceled or times out, ctx.Err() will be returned. +func (c *Client) BareDo(ctx context.Context, req *http.Request) (*Response, error) { + return c.bareDo(ctx, c.client, req) +} + +// bareDoIgnoreRedirects has the exact same behavior as BareDo but stops at the first +// redirection code returned by the API. If a redirection is returned by the api, bareDoIgnoreRedirects +// returns a *RedirectionError. +// +// The provided ctx must be non-nil, if it is nil an error is returned. If it is +// canceled or times out, ctx.Err() will be returned. +func (c *Client) bareDoIgnoreRedirects(ctx context.Context, req *http.Request) (*Response, error) { + return c.bareDo(ctx, c.clientIgnoreRedirects, req) +} + +var errInvalidLocation = errors.New("invalid or empty Location header in redirection response") + +// bareDoUntilFound has the exact same behavior as BareDo but only follows 301s, up to maxRedirects times. If it receives +// a 302, it will parse the Location header into a *url.URL and return that. +// This is useful for endpoints that return a 302 in successful cases but still might return 301s for +// permanent redirections. +// +// The provided ctx must be non-nil, if it is nil an error is returned. If it is +// canceled or times out, ctx.Err() will be returned. +func (c *Client) bareDoUntilFound(ctx context.Context, req *http.Request, maxRedirects int) (*url.URL, *Response, error) { + response, err := c.bareDoIgnoreRedirects(ctx, req) + + if err != nil { + rerr, ok := err.(*RedirectionError) + if ok { + // If we receive a 302, transform potential relative locations into absolute and return it. + if rerr.StatusCode == http.StatusFound { + if rerr.Location == nil { + return nil, nil, errInvalidLocation + } + newURL := c.BaseURL.ResolveReference(rerr.Location) + return newURL, response, nil + } + // If permanent redirect response is returned, follow it + if maxRedirects > 0 && rerr.StatusCode == http.StatusMovedPermanently { + if rerr.Location == nil { + return nil, nil, errInvalidLocation + } + newURL := c.BaseURL.ResolveReference(rerr.Location) + newRequest := req.Clone(ctx) + newRequest.URL = newURL + return c.bareDoUntilFound(ctx, newRequest, maxRedirects-1) + } + // If we reached the maximum amount of redirections, return an error + if maxRedirects <= 0 && rerr.StatusCode == http.StatusMovedPermanently { + return nil, response, fmt.Errorf("reached the maximum amount of redirections: %w", err) + } + return nil, response, fmt.Errorf("unexepected redirection response: %w", err) + } + } + + // If we don't receive a redirection, forward the response and potential error + return nil, response, err +} + // Do sends an API request and returns the API response. The API response is // JSON decoded and stored in the value pointed to by v, or returned as an // error if an API error has occurred. If v implements the io.Writer interface, @@ -1196,6 +1276,40 @@ func (r *AbuseRateLimitError) Is(target error) bool { compareHTTPResponse(r.Response, v.Response) } +// RedirectionError represents a response that returned a redirect status code: +// +// 301 (Moved Permanently) +// 302 (Found) +// 303 (See Other) +// 307 (Temporary Redirect) +// 308 (Permanent Redirect) +// +// If there was a valid Location header included, it will be parsed to a URL. You should use +// `BaseURL.ResolveReference()` to enrich it with the correct hostname where needed. +type RedirectionError struct { + Response *http.Response // HTTP response that caused this error + StatusCode int + Location *url.URL // location header of the redirection if present +} + +func (r *RedirectionError) Error() string { + return fmt.Sprintf("%v %v: %d location %v", + r.Response.Request.Method, sanitizeURL(r.Response.Request.URL), + r.StatusCode, sanitizeURL(r.Location)) +} + +// Is returns whether the provided error equals this error. +func (r *RedirectionError) Is(target error) bool { + v, ok := target.(*RedirectionError) + if !ok { + return false + } + + return r.StatusCode == v.StatusCode && + (r.Location == v.Location || // either both locations are nil or exactly the same pointer + r.Location != nil && v.Location != nil && r.Location.String() == v.Location.String()) // or they are both not nil and marshalled identically +} + // sanitizeURL redacts the client_secret parameter from the URL which may be // exposed to the user. func sanitizeURL(uri *url.URL) *url.URL { @@ -1260,7 +1374,8 @@ func (e *Error) UnmarshalJSON(data []byte) error { // // The error type will be *RateLimitError for rate limit exceeded errors, // *AcceptedError for 202 Accepted status codes, -// and *TwoFactorAuthError for two-factor authentication errors. +// *TwoFactorAuthError for two-factor authentication errors, +// and *RedirectionError for redirect status codes (only happens when ignoring redirections). func CheckResponse(r *http.Response) error { if r.StatusCode == http.StatusAccepted { return &AcceptedError{} @@ -1302,6 +1417,25 @@ func CheckResponse(r *http.Response) error { abuseRateLimitError.RetryAfter = retryAfter } return abuseRateLimitError + // Check that the status code is a redirection and return a sentinel error that can be used to handle special cases + // where 302 is considered a successful result. + // This should never happen with the default `CheckRedirect`, because it would return a `url.Error` that should be handled upstream. + case r.StatusCode == http.StatusMovedPermanently || + r.StatusCode == http.StatusFound || + r.StatusCode == http.StatusSeeOther || + r.StatusCode == http.StatusTemporaryRedirect || + r.StatusCode == http.StatusPermanentRedirect: + + locationStr := r.Header.Get("Location") + var location *url.URL + if locationStr != "" { + location, _ = url.Parse(locationStr) + } + return &RedirectionError{ + Response: errorResponse.Response, + StatusCode: r.StatusCode, + Location: location, + } default: return errorResponse } diff --git a/github/github_test.go b/github/github_test.go index 43c9bde4f82..e77984d5ed6 100644 --- a/github/github_test.go +++ b/github/github_test.go @@ -1925,6 +1925,38 @@ func TestCheckResponse_AbuseRateLimit(t *testing.T) { } } +func TestCheckResponse_RedirectionError(t *testing.T) { + t.Parallel() + urlStr := "/foo/bar" + + res := &http.Response{ + Request: &http.Request{}, + StatusCode: http.StatusFound, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(``)), + } + res.Header.Set("Location", urlStr) + err := CheckResponse(res).(*RedirectionError) + + if err == nil { + t.Errorf("Expected error response.") + } + + wantedURL, parseErr := url.Parse(urlStr) + if parseErr != nil { + t.Errorf("Error parsing fixture url: %v", parseErr) + } + + want := &RedirectionError{ + Response: res, + StatusCode: http.StatusFound, + Location: wantedURL, + } + if !errors.Is(err, want) { + t.Errorf("Error = %#v, want %#v", err, want) + } +} + func TestCompareHttpResponse(t *testing.T) { t.Parallel() testcases := map[string]struct { diff --git a/github/repos_contents.go b/github/repos_contents.go index 3a0c266b5ee..735ea32d93f 100644 --- a/github/repos_contents.go +++ b/github/repos_contents.go @@ -348,6 +348,15 @@ func (s *RepositoriesService) GetArchiveLink(ctx context.Context, owner, repo st if opts != nil && opts.Ref != "" { u += fmt.Sprintf("/%s", opts.Ref) } + + if s.client.RateLimitRedirectionalEndpoints { + return s.getArchiveLinkWithRateLimit(ctx, u, maxRedirects) + } + + return s.getArchiveLinkWithoutRateLimit(ctx, u, maxRedirects) +} + +func (s *RepositoriesService) getArchiveLinkWithoutRateLimit(ctx context.Context, u string, maxRedirects int) (*url.URL, *Response, error) { resp, err := s.client.roundTripWithOptionalFollowRedirect(ctx, u, maxRedirects) if err != nil { return nil, nil, err @@ -365,3 +374,23 @@ func (s *RepositoriesService) GetArchiveLink(ctx context.Context, owner, repo st return parsedURL, newResponse(resp), nil } + +func (s *RepositoriesService) getArchiveLinkWithRateLimit(ctx context.Context, u string, maxRedirects int) (*url.URL, *Response, error) { + req, err := s.client.NewRequest("GET", u, nil) + if err != nil { + return nil, nil, err + } + + url, resp, err := s.client.bareDoUntilFound(ctx, req, maxRedirects) + if err != nil { + return nil, resp, err + } + defer resp.Body.Close() + + // If we received a valid Location in a 302 response + if url != nil { + return url, resp, nil + } + + return nil, resp, fmt.Errorf("unexpected status code: %s", resp.Status) +} diff --git a/github/repos_contents_test.go b/github/repos_contents_test.go index e75a00e6d5c..84f660eca15 100644 --- a/github/repos_contents_test.go +++ b/github/repos_contents_test.go @@ -707,81 +707,144 @@ func TestRepositoriesService_DeleteFile(t *testing.T) { func TestRepositoriesService_GetArchiveLink(t *testing.T) { t.Parallel() - client, mux, _ := setup(t) - - mux.HandleFunc("/repos/o/r/tarball/yo", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - http.Redirect(w, r, "http://github.com/a", http.StatusFound) - }) - ctx := context.Background() - url, resp, err := client.Repositories.GetArchiveLink(ctx, "o", "r", Tarball, &RepositoryContentGetOptions{Ref: "yo"}, 1) - if err != nil { - t.Errorf("Repositories.GetArchiveLink returned error: %v", err) - } - if resp.StatusCode != http.StatusFound { - t.Errorf("Repositories.GetArchiveLink returned status: %d, want %d", resp.StatusCode, http.StatusFound) - } - want := "http://github.com/a" - if url.String() != want { - t.Errorf("Repositories.GetArchiveLink returned %+v, want %+v", url.String(), want) + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, } - const methodName = "GetArchiveLink" - testBadOptions(t, methodName, func() (err error) { - _, _, err = client.Repositories.GetArchiveLink(ctx, "\n", "\n", Tarball, &RepositoryContentGetOptions{}, 1) - return err - }) + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + mux.HandleFunc("/repos/o/r/tarball/yo", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + http.Redirect(w, r, "http://github.com/a", http.StatusFound) + }) + ctx := context.Background() + url, resp, err := client.Repositories.GetArchiveLink(ctx, "o", "r", Tarball, &RepositoryContentGetOptions{Ref: "yo"}, 1) + if err != nil { + t.Errorf("Repositories.GetArchiveLink returned error: %v", err) + } + if resp.StatusCode != http.StatusFound { + t.Errorf("Repositories.GetArchiveLink returned status: %d, want %d", resp.StatusCode, http.StatusFound) + } + want := "http://github.com/a" + if url.String() != want { + t.Errorf("Repositories.GetArchiveLink returned %+v, want %+v", url.String(), want) + } - // Add custom round tripper - client.client.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { - return nil, errors.New("failed to get archive link") - }) - testBadOptions(t, methodName, func() (err error) { - _, _, err = client.Repositories.GetArchiveLink(ctx, "o", "r", Tarball, &RepositoryContentGetOptions{}, 1) - return err - }) + const methodName = "GetArchiveLink" + testBadOptions(t, methodName, func() (err error) { + _, _, err = client.Repositories.GetArchiveLink(ctx, "\n", "\n", Tarball, &RepositoryContentGetOptions{}, 1) + return err + }) + + // Add custom round tripper + client.client.Transport = roundTripperFunc(func(r *http.Request) (*http.Response, error) { + return nil, errors.New("failed to get archive link") + }) + testBadOptions(t, methodName, func() (err error) { + _, _, err = client.Repositories.GetArchiveLink(ctx, "o", "r", Tarball, &RepositoryContentGetOptions{}, 1) + return err + }) + }) + } } func TestRepositoriesService_GetArchiveLink_StatusMovedPermanently_dontFollowRedirects(t *testing.T) { t.Parallel() - client, mux, _ := setup(t) + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, + } - mux.HandleFunc("/repos/o/r/tarball", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - http.Redirect(w, r, "http://github.com/a", http.StatusMovedPermanently) - }) - ctx := context.Background() - _, resp, _ := client.Repositories.GetArchiveLink(ctx, "o", "r", Tarball, &RepositoryContentGetOptions{}, 0) - if resp.StatusCode != http.StatusMovedPermanently { - t.Errorf("Repositories.GetArchiveLink returned status: %d, want %d", resp.StatusCode, http.StatusMovedPermanently) + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + mux.HandleFunc("/repos/o/r/tarball", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + http.Redirect(w, r, "http://github.com/a", http.StatusMovedPermanently) + }) + ctx := context.Background() + _, resp, _ := client.Repositories.GetArchiveLink(ctx, "o", "r", Tarball, &RepositoryContentGetOptions{}, 0) + if resp.StatusCode != http.StatusMovedPermanently { + t.Errorf("Repositories.GetArchiveLink returned status: %d, want %d", resp.StatusCode, http.StatusMovedPermanently) + } + }) } } func TestRepositoriesService_GetArchiveLink_StatusMovedPermanently_followRedirects(t *testing.T) { t.Parallel() - client, mux, serverURL := setup(t) - - // Mock a redirect link, which leads to an archive link - mux.HandleFunc("/repos/o/r/tarball", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - redirectURL, _ := url.Parse(serverURL + baseURLPath + "/redirect") - http.Redirect(w, r, redirectURL.String(), http.StatusMovedPermanently) - }) - mux.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "GET") - http.Redirect(w, r, "http://github.com/a", http.StatusFound) - }) - ctx := context.Background() - url, resp, err := client.Repositories.GetArchiveLink(ctx, "o", "r", Tarball, &RepositoryContentGetOptions{}, 1) - if err != nil { - t.Errorf("Repositories.GetArchiveLink returned error: %v", err) - } - if resp.StatusCode != http.StatusFound { - t.Errorf("Repositories.GetArchiveLink returned status: %d, want %d", resp.StatusCode, http.StatusFound) + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, } - want := "http://github.com/a" - if url.String() != want { - t.Errorf("Repositories.GetArchiveLink returned %+v, want %+v", url.String(), want) + + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, serverURL := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + // Mock a redirect link, which leads to an archive link + mux.HandleFunc("/repos/o/r/tarball", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + redirectURL, _ := url.Parse(serverURL + baseURLPath + "/redirect") + http.Redirect(w, r, redirectURL.String(), http.StatusMovedPermanently) + }) + mux.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + http.Redirect(w, r, "http://github.com/a", http.StatusFound) + }) + ctx := context.Background() + url, resp, err := client.Repositories.GetArchiveLink(ctx, "o", "r", Tarball, &RepositoryContentGetOptions{}, 1) + if err != nil { + t.Errorf("Repositories.GetArchiveLink returned error: %v", err) + } + if resp.StatusCode != http.StatusFound { + t.Errorf("Repositories.GetArchiveLink returned status: %d, want %d", resp.StatusCode, http.StatusFound) + } + want := "http://github.com/a" + if url.String() != want { + t.Errorf("Repositories.GetArchiveLink returned %+v, want %+v", url.String(), want) + } + }) } } From 51820660927042868ad5cb45abd51cd7c2cd1879 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?William=20Mouch=C3=A8re?= Date: Tue, 31 Dec 2024 18:16:24 +0100 Subject: [PATCH 2/2] PR Review: make code more idiomatic and improve coverage Fix the ...WithRateLimit funcs so that the happy path flows to the end of the method. Add unit tests for bareDoUntilFound and ...WithRateLimit funcs to validate behavior when receiving unexpected http status codes. --- github/actions_artifacts.go | 10 +-- github/actions_artifacts_test.go | 52 +++++++++++++ github/actions_workflow_jobs.go | 10 +-- github/actions_workflow_jobs_test.go | 54 ++++++++++++++ github/actions_workflow_runs.go | 18 ++--- github/actions_workflow_runs_test.go | 107 +++++++++++++++++++++++++++ github/github_test.go | 42 +++++++++++ github/repos_contents.go | 10 +-- 8 files changed, 279 insertions(+), 24 deletions(-) diff --git a/github/actions_artifacts.go b/github/actions_artifacts.go index fa3829899ab..2b560fa05de 100644 --- a/github/actions_artifacts.go +++ b/github/actions_artifacts.go @@ -157,7 +157,7 @@ func (s *ActionsService) downloadArtifactWithoutRateLimit(ctx context.Context, u defer resp.Body.Close() if resp.StatusCode != http.StatusFound { - return nil, newResponse(resp), fmt.Errorf("unexpected status code: %s", resp.Status) + return nil, newResponse(resp), fmt.Errorf("unexpected status code: %v", resp.Status) } parsedURL, err := url.Parse(resp.Header.Get("Location")) @@ -180,12 +180,12 @@ func (s *ActionsService) downloadArtifactWithRateLimit(ctx context.Context, u st } defer resp.Body.Close() - // If we received a valid Location in a 302 response - if url != nil { - return url, resp, nil + // If we didn't receive a valid Location in a 302 response + if url == nil { + return nil, resp, fmt.Errorf("unexpected status code: %v", resp.Status) } - return nil, resp, fmt.Errorf("unexpected status code: %s", resp.Status) + return url, resp, nil } // DeleteArtifact deletes a workflow run artifact. diff --git a/github/actions_artifacts_test.go b/github/actions_artifacts_test.go index ec1fb2a7216..76a82355e1d 100644 --- a/github/actions_artifacts_test.go +++ b/github/actions_artifacts_test.go @@ -11,6 +11,7 @@ import ( "fmt" "net/http" "net/url" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -479,6 +480,57 @@ func TestActionsService_DownloadArtifact_StatusMovedPermanently_followRedirects( } } +func TestActionsService_DownloadArtifact_unexpectedCode(t *testing.T) { + t.Parallel() + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, serverURL := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + mux.HandleFunc("/repos/o/r/actions/artifacts/1/zip", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + redirectURL, _ := url.Parse(serverURL + baseURLPath + "/redirect") + http.Redirect(w, r, redirectURL.String(), http.StatusMovedPermanently) + }) + mux.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + w.WriteHeader(http.StatusNoContent) + }) + + ctx := context.Background() + url, resp, err := client.Actions.DownloadArtifact(ctx, "o", "r", 1, 1) + if err == nil { + t.Fatalf("Actions.DownloadArtifact should return error on unexpected code") + } + if !strings.Contains(err.Error(), "unexpected status code") { + t.Error("Actions.DownloadArtifact should return unexpected status code") + } + if got, want := resp.Response.StatusCode, http.StatusNoContent; got != want { + t.Errorf("Actions.DownloadArtifact return status %d, want %d", got, want) + } + if url != nil { + t.Errorf("Actions.DownloadArtifact return %+v, want nil", url) + } + }) + } +} + func TestActionsService_DeleteArtifact(t *testing.T) { t.Parallel() client, mux, _ := setup(t) diff --git a/github/actions_workflow_jobs.go b/github/actions_workflow_jobs.go index 0acbaa4916c..10067c8b260 100644 --- a/github/actions_workflow_jobs.go +++ b/github/actions_workflow_jobs.go @@ -165,7 +165,7 @@ func (s *ActionsService) getWorkflowJobLogsWithoutRateLimit(ctx context.Context, defer resp.Body.Close() if resp.StatusCode != http.StatusFound { - return nil, newResponse(resp), fmt.Errorf("unexpected status code: %s", resp.Status) + return nil, newResponse(resp), fmt.Errorf("unexpected status code: %v", resp.Status) } parsedURL, err := url.Parse(resp.Header.Get("Location")) @@ -184,10 +184,10 @@ func (s *ActionsService) getWorkflowJobLogsWithRateLimit(ctx context.Context, u } defer resp.Body.Close() - // If we received a valid Location in a 302 response - if url != nil { - return url, resp, nil + // If we didn't receive a valid Location in a 302 response + if url == nil { + return nil, resp, fmt.Errorf("unexpected status code: %v", resp.Status) } - return nil, resp, fmt.Errorf("unexpected status code: %s", resp.Status) + return url, resp, nil } diff --git a/github/actions_workflow_jobs_test.go b/github/actions_workflow_jobs_test.go index 527901eadf6..f414df0fbfb 100644 --- a/github/actions_workflow_jobs_test.go +++ b/github/actions_workflow_jobs_test.go @@ -11,6 +11,7 @@ import ( "fmt" "net/http" "net/url" + "strings" "testing" "time" @@ -333,6 +334,59 @@ func TestActionsService_GetWorkflowJobLogs_StatusMovedPermanently_followRedirect } } +func TestActionsService_GetWorkflowJobLogs_unexpectedCode(t *testing.T) { + t.Parallel() + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, serverURL := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + // Mock a redirect link, which leads to an archive link + mux.HandleFunc("/repos/o/r/actions/jobs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + redirectURL, _ := url.Parse(serverURL + baseURLPath + "/redirect") + http.Redirect(w, r, redirectURL.String(), http.StatusMovedPermanently) + }) + + mux.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + w.WriteHeader(http.StatusNoContent) + }) + + ctx := context.Background() + url, resp, err := client.Actions.GetWorkflowJobLogs(ctx, "o", "r", 399444496, 1) + if err == nil { + t.Fatalf("Actions.GetWorkflowJobLogs should return error on unexpected code") + } + if !strings.Contains(err.Error(), "unexpected status code") { + t.Error("Actions.GetWorkflowJobLogs should return unexpected status code") + } + if got, want := resp.Response.StatusCode, http.StatusNoContent; got != want { + t.Errorf("Actions.GetWorkflowJobLogs return status %d, want %d", got, want) + } + if url != nil { + t.Errorf("Actions.GetWorkflowJobLogs return %+v, want nil", url) + } + }) + } +} + func TestTaskStep_Marshal(t *testing.T) { t.Parallel() testJSONMarshal(t, &TaskStep{}, "{}") diff --git a/github/actions_workflow_runs.go b/github/actions_workflow_runs.go index 6c1ea92b56b..dddc56d2327 100644 --- a/github/actions_workflow_runs.go +++ b/github/actions_workflow_runs.go @@ -274,7 +274,7 @@ func (s *ActionsService) getWorkflowRunAttemptLogsWithoutRateLimit(ctx context.C defer resp.Body.Close() if resp.StatusCode != http.StatusFound { - return nil, newResponse(resp), fmt.Errorf("unexpected status code: %s", resp.Status) + return nil, newResponse(resp), fmt.Errorf("unexpected status code: %v", resp.Status) } parsedURL, err := url.Parse(resp.Header.Get("Location")) @@ -293,12 +293,12 @@ func (s *ActionsService) getWorkflowRunAttemptLogsWithRateLimit(ctx context.Cont } defer resp.Body.Close() - // If we received a valid Location in a 302 response - if url != nil { - return url, resp, nil + // If we didn't receive a valid Location in a 302 response + if url == nil { + return nil, resp, fmt.Errorf("unexpected status code: %v", resp.Status) } - return nil, resp, fmt.Errorf("unexpected status code: %s", resp.Status) + return url, resp, nil } // RerunWorkflowByID re-runs a workflow by ID. @@ -407,12 +407,12 @@ func (s *ActionsService) getWorkflowRunLogsWithRateLimit(ctx context.Context, u } defer resp.Body.Close() - // If we received a valid Location in a 302 response - if url != nil { - return url, resp, nil + // If we didn't receive a valid Location in a 302 response + if url == nil { + return nil, resp, fmt.Errorf("unexpected status code: %v", resp.Status) } - return nil, resp, fmt.Errorf("unexpected status code: %s", resp.Status) + return url, resp, nil } // DeleteWorkflowRun deletes a workflow run by ID. diff --git a/github/actions_workflow_runs_test.go b/github/actions_workflow_runs_test.go index 1f985429980..be400c27e8d 100644 --- a/github/actions_workflow_runs_test.go +++ b/github/actions_workflow_runs_test.go @@ -11,6 +11,7 @@ import ( "fmt" "net/http" "net/url" + "strings" "testing" "time" @@ -334,6 +335,59 @@ func TestActionsService_GetWorkflowRunAttemptLogs_StatusMovedPermanently_followR } } +func TestActionsService_GetWorkflowRunAttemptLogs_unexpectedCode(t *testing.T) { + t.Parallel() + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, serverURL := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + // Mock a redirect link, which leads to an archive link + mux.HandleFunc("/repos/o/r/actions/runs/399444496/attempts/2/logs", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + redirectURL, _ := url.Parse(serverURL + baseURLPath + "/redirect") + http.Redirect(w, r, redirectURL.String(), http.StatusMovedPermanently) + }) + + mux.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + w.WriteHeader(http.StatusNoContent) + }) + + ctx := context.Background() + url, resp, err := client.Actions.GetWorkflowRunAttemptLogs(ctx, "o", "r", 399444496, 2, 1) + if err == nil { + t.Fatalf("Actions.GetWorkflowRunAttemptLogs should return error on unexpected code") + } + if !strings.Contains(err.Error(), "unexpected status code") { + t.Error("Actions.GetWorkflowRunAttemptLogs should return unexpected status code") + } + if got, want := resp.Response.StatusCode, http.StatusNoContent; got != want { + t.Errorf("Actions.GetWorkflowRunAttemptLogs return status %d, want %d", got, want) + } + if url != nil { + t.Errorf("Actions.GetWorkflowRunAttemptLogs return %+v, want nil", url) + } + }) + } +} + func TestActionsService_RerunWorkflowRunByID(t *testing.T) { t.Parallel() client, mux, _ := setup(t) @@ -596,6 +650,59 @@ func TestActionsService_GetWorkflowRunLogs_StatusMovedPermanently_followRedirect } } +func TestActionsService_GetWorkflowRunLogs_unexpectedCode(t *testing.T) { + t.Parallel() + tcs := []struct { + name string + respectRateLimits bool + }{ + { + name: "withoutRateLimits", + respectRateLimits: false, + }, + { + name: "withRateLimits", + respectRateLimits: true, + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + client, mux, serverURL := setup(t) + client.RateLimitRedirectionalEndpoints = tc.respectRateLimits + + // Mock a redirect link, which leads to an archive link + mux.HandleFunc("/repos/o/r/actions/runs/399444496/logs", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + redirectURL, _ := url.Parse(serverURL + baseURLPath + "/redirect") + http.Redirect(w, r, redirectURL.String(), http.StatusMovedPermanently) + }) + + mux.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "GET") + w.WriteHeader(http.StatusNoContent) + }) + + ctx := context.Background() + url, resp, err := client.Actions.GetWorkflowRunLogs(ctx, "o", "r", 399444496, 1) + if err == nil { + t.Fatalf("Actions.GetWorkflowRunLogs should return error on unexpected code") + } + if !strings.Contains(err.Error(), "unexpected status code") { + t.Error("Actions.GetWorkflowRunLogs should return unexpected status code") + } + if got, want := resp.Response.StatusCode, http.StatusNoContent; got != want { + t.Errorf("Actions.GetWorkflowRunLogs return status %d, want %d", got, want) + } + if url != nil { + t.Errorf("Actions.GetWorkflowRunLogs return %+v, want nil", url) + } + }) + } +} + func TestActionService_ListRepositoryWorkflowRuns(t *testing.T) { t.Parallel() client, mux, _ := setup(t) diff --git a/github/github_test.go b/github/github_test.go index e77984d5ed6..16257b384c2 100644 --- a/github/github_test.go +++ b/github/github_test.go @@ -1824,6 +1824,48 @@ func TestDo_noContent(t *testing.T) { } } +func TestBareDoUntilFound_redirectLoop(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, baseURLPath, http.StatusMovedPermanently) + }) + + req, _ := client.NewRequest("GET", ".", nil) + ctx := context.Background() + _, _, err := client.bareDoUntilFound(ctx, req, 1) + + if err == nil { + t.Error("Expected error to be returned.") + } + var rerr *RedirectionError + if !errors.As(err, &rerr) { + t.Errorf("Expected a Redirection error; got %#v.", err) + } +} + +func TestBareDoUntilFound_UnexpectedRedirection(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, baseURLPath, http.StatusSeeOther) + }) + + req, _ := client.NewRequest("GET", ".", nil) + ctx := context.Background() + _, _, err := client.bareDoUntilFound(ctx, req, 1) + + if err == nil { + t.Error("Expected error to be returned.") + } + var rerr *RedirectionError + if !errors.As(err, &rerr) { + t.Errorf("Expected a Redirection error; got %#v.", err) + } +} + func TestSanitizeURL(t *testing.T) { t.Parallel() tests := []struct { diff --git a/github/repos_contents.go b/github/repos_contents.go index 735ea32d93f..383988bf296 100644 --- a/github/repos_contents.go +++ b/github/repos_contents.go @@ -364,7 +364,7 @@ func (s *RepositoriesService) getArchiveLinkWithoutRateLimit(ctx context.Context defer resp.Body.Close() if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusFound { - return nil, newResponse(resp), fmt.Errorf("unexpected status code: %s", resp.Status) + return nil, newResponse(resp), fmt.Errorf("unexpected status code: %v", resp.Status) } parsedURL, err := url.Parse(resp.Header.Get("Location")) @@ -387,10 +387,10 @@ func (s *RepositoriesService) getArchiveLinkWithRateLimit(ctx context.Context, u } defer resp.Body.Close() - // If we received a valid Location in a 302 response - if url != nil { - return url, resp, nil + // If we didn't receive a valid Location in a 302 response + if url == nil { + return nil, resp, fmt.Errorf("unexpected status code: %v", resp.Status) } - return nil, resp, fmt.Errorf("unexpected status code: %s", resp.Status) + return url, resp, nil }