From f04c7785b095f29c247003e185ae065c3663b35a Mon Sep 17 00:00:00 2001 From: Ben McNicholl Date: Thu, 9 Apr 2026 15:55:19 +1000 Subject: [PATCH 1/2] Adjust http client to support rate limits Changed the way that the CLI implements rate limits and the back off. We've removed the dependency on `roko`, which shouldn't be required with the availability of the `RateLimit-` headers that we can get from client calls. --- cmd/api/api.go | 19 +- cmd/preflight/preflight.go | 13 +- go.mod | 1 - go.sum | 4 - internal/http/client.go | 92 +-------- internal/http/client_test.go | 334 -------------------------------- internal/http/ratelimit.go | 123 ++++++++++++ internal/http/ratelimit_test.go | 298 ++++++++++++++++++++++++++++ pkg/cmd/factory/factory.go | 24 ++- 9 files changed, 462 insertions(+), 446 deletions(-) create mode 100644 internal/http/ratelimit.go create mode 100644 internal/http/ratelimit_test.go diff --git a/cmd/api/api.go b/cmd/api/api.go index f56b4f5e..b93e06ca 100644 --- a/cmd/api/api.go +++ b/cmd/api/api.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "os" "strings" "time" @@ -114,17 +115,19 @@ func (c *ApiCmd) Run(kongCtx *kong.Context, globals cli.GlobalFlags) error { fullEndpoint := buildFullEndpoint(c.Endpoint, f.Config.OrganizationSlug(), c.Analytics) - // Create an HTTP client with appropriate configuration + // Create an HTTP client with rate-limit retry via the shared transport. 
+ rl := httpClient.NewRateLimitTransport(nil) + rl.MaxRetryDelay = 60 * time.Second + rl.OnRateLimit = func(attempt int, delay time.Duration) { + if c.Verbose { + fmt.Fprintf(os.Stderr, "WARNING: Rate limit exceeded, retrying in %v @ %q (attempt %d)\n", delay, time.Now().Add(delay).Format(time.RFC3339), attempt) + } + } + client := httpClient.NewClient( f.Config.APIToken(), httpClient.WithBaseURL(f.RestAPIClient.BaseURL.String()), - httpClient.WithMaxRetries(3), - httpClient.WithMaxRetryDelay(60*time.Second), - httpClient.WithOnRetry(func(attempt int, delay time.Duration) { - if c.Verbose { - fmt.Fprintf(os.Stderr, "WARNING: Rate limit exceeded, retrying in %v @ %q (attempt %d)\n", delay, time.Now().Add(delay).Format(time.RFC3339), attempt) - } - }), + httpClient.WithHTTPClient(&http.Client{Transport: rl}), ) // Process custom headers diff --git a/cmd/preflight/preflight.go b/cmd/preflight/preflight.go index 219d1ef7..6241aec7 100644 --- a/cmd/preflight/preflight.go +++ b/cmd/preflight/preflight.go @@ -18,6 +18,7 @@ import ( "github.com/buildkite/cli/v3/internal/build/watch" "github.com/buildkite/cli/v3/internal/cli" bkErrors "github.com/buildkite/cli/v3/internal/errors" + bkhttp "github.com/buildkite/cli/v3/internal/http" "github.com/buildkite/cli/v3/internal/pipeline/resolver" "github.com/buildkite/cli/v3/internal/preflight" "github.com/buildkite/cli/v3/pkg/cmd/factory" @@ -43,7 +44,8 @@ func (c *PreflightCmd) Help() string { } func (c *PreflightCmd) Run(kongCtx *kong.Context, globals cli.GlobalFlags) error { - f, err := newFactory(factory.WithDebug(globals.EnableDebug())) + rlTransport := bkhttp.NewRateLimitTransport(http.DefaultTransport) + f, err := newFactory(factory.WithDebug(globals.EnableDebug()), factory.WithTransport(rlTransport)) if err != nil { return bkErrors.NewInternalError(err, "failed to initialize CLI", "This is likely a bug", "Report to Buildkite") } @@ -90,6 +92,15 @@ func (c *PreflightCmd) Run(kongCtx *kong.Context, globals cli.GlobalFlags) 
error renderer := newRenderer(os.Stdout, c.JSON, c.Text, stop) + rlTransport.OnRateLimit = func(attempt int, delay time.Duration) { + _ = renderer.Render(Event{ + Type: EventOperation, + Time: time.Now(), + PreflightID: preflightID.String(), + Title: fmt.Sprintf("Rate limited by API, waiting %s before retrying (attempt %d/%d)...", delay.Truncate(time.Second), attempt+1, rlTransport.MaxRetries), + }) + } + _ = renderer.Render(Event{Type: EventOperation, Time: time.Now(), PreflightID: preflightID.String(), Title: "Pushing snapshot of working tree..."}) var opts []preflight.SnapshotOption diff --git a/go.mod b/go.mod index c4901648..af674c71 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ require ( github.com/alecthomas/kong v1.15.0 github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/buildkite/go-buildkite/v4 v4.18.1-0.20260408232706-47eafe1749f2 - github.com/buildkite/roko v1.4.0 github.com/buildkite/termoji v0.0.0-20260330080310-c0aa4ebee0d1 github.com/charmbracelet/bubbles v1.0.0 github.com/charmbracelet/bubbletea v1.3.10 diff --git a/go.sum b/go.sum index b3a17124..0141b30a 100644 --- a/go.sum +++ b/go.sum @@ -35,8 +35,6 @@ github.com/bradleyjkemp/cupaloy/v2 v2.6.0 h1:knToPYa2xtfg42U3I6punFEjaGFKWQRXJwj github.com/bradleyjkemp/cupaloy/v2 v2.6.0/go.mod h1:bm7JXdkRd4BHJk9HpwqAI8BoAY1lps46Enkdqw6aRX0= github.com/buildkite/go-buildkite/v4 v4.18.1-0.20260408232706-47eafe1749f2 h1:NXO5ZhFPFGk2nD2H9m1+QqbGJqRs2LWlx4U+K0sSNp0= github.com/buildkite/go-buildkite/v4 v4.18.1-0.20260408232706-47eafe1749f2/go.mod h1:8+7GiWBKwEPAWoZnRU/kpNCt46j1iVH8kFMMbD4YDfc= -github.com/buildkite/roko v1.4.0 h1:DxixoCdpNqxu4/1lXrXbfsKbJSd7r1qoxtef/TT2J80= -github.com/buildkite/roko v1.4.0/go.mod h1:0vbODqUFEcVf4v2xVXRfZZRsqJVsCCHTG/TBRByGK4E= github.com/buildkite/termoji v0.0.0-20260330080310-c0aa4ebee0d1 h1:aaEl0QZURcwC+KOfFTzSp66xknw5eTmFZ1NgB87s2xk= github.com/buildkite/termoji v0.0.0-20260330080310-c0aa4ebee0d1/go.mod h1:ZTEvQlMN3+qzjROvjRb1p0X+xDQxxKpkMFhMSnaTrpw= 
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= @@ -225,5 +223,3 @@ gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= -gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= diff --git a/internal/http/client.go b/internal/http/client.go index a5150f0e..eef919c3 100644 --- a/internal/http/client.go +++ b/internal/http/client.go @@ -8,11 +8,7 @@ import ( "io" "net/http" "net/url" - "strconv" "strings" - "time" - - "github.com/buildkite/roko" ) // ErrorResponse represents an error response from the API @@ -44,10 +40,6 @@ type Client struct { token string userAgent string client *http.Client - - maxRetries int - maxRetryDelay time.Duration - onRetry OnRetryFunc } // ClientOption is a function that modifies a Client @@ -74,30 +66,6 @@ func WithHTTPClient(client *http.Client) ClientOption { } } -// WithMaxRetries sets the maximum number of retries for rate-limited requests. -func WithMaxRetries(n int) ClientOption { - return func(c *Client) { - c.maxRetries = n - } -} - -// WithMaxRetryDelay sets the maximum delay between retries -func WithMaxRetryDelay(d time.Duration) ClientOption { - return func(c *Client) { - c.maxRetryDelay = d - } -} - -// WithOnRetry sets a callback that is invoked before each retry sleep. -func WithOnRetry(f OnRetryFunc) ClientOption { - return func(c *Client) { - c.onRetry = f - } -} - -// OnRetryFunc is called before each retry sleep with the attempt number and delay duration. 
-type OnRetryFunc func(attempt int, delay time.Duration) - // NewClient creates a new HTTP client with the given token and options func NewClient(token string, opts ...ClientOption) *Client { c := &Client{ @@ -164,23 +132,6 @@ func (e *ErrorResponse) IsTooManyRequests() bool { return e.StatusCode == http.StatusTooManyRequests } -// RetryAfter returns the duration to wait before retrying, based on the RateLimit-Reset header. -// Returns 0 if the header is missing or invalid. -func (e *ErrorResponse) RetryAfter() time.Duration { - if e.Headers == nil { - return 0 - } - resetStr := e.Headers.Get("RateLimit-Reset") - if resetStr == "" { - return 0 - } - seconds, err := strconv.Atoi(resetStr) - if err != nil || seconds < 0 { - return 0 - } - return time.Duration(seconds) * time.Second -} - // Do performs an HTTP request with the given method, endpoint, and body. func (c *Client) Do(ctx context.Context, method, endpoint string, body interface{}, v interface{}) error { // Ensure endpoint starts with "/" @@ -215,21 +166,7 @@ func (c *Client) Do(ctx context.Context, method, endpoint string, body interface } } - r := roko.NewRetrier( - roko.WithMaxAttempts(c.maxRetries+1), - roko.WithStrategy(roko.Constant(0)), - ) - - respBody, err := roko.DoFunc(ctx, r, func(r *roko.Retrier) ([]byte, error) { - resp, err := c.send(ctx, method, reqURL, bodyBytes) - if err != nil { - if !c.handleRetry(r, err) { - r.Break() - } - return nil, err - } - return resp, nil - }) + respBody, err := c.send(ctx, method, reqURL, bodyBytes) if err != nil { return err } @@ -284,30 +221,3 @@ func (c *Client) send(ctx context.Context, method, reqURL string, body []byte) ( return respBody, nil } - -// handleRetry checks if an error is retryable and configures the retrier accordingly. -// Returns true if the request should be retried, false otherwise. 
-func (c *Client) handleRetry(r *roko.Retrier, err error) bool { - errResp, ok := err.(*ErrorResponse) - if !ok || !errResp.IsTooManyRequests() { - return false - } - - attempt := r.AttemptCount() - delay := errResp.RetryAfter() - if attempt > 0 { - // Got rate-limited again means contention - back off exponentially - delay *= time.Duration(1 << attempt) - } - - if c.maxRetryDelay > 0 { - delay = min(delay, c.maxRetryDelay) - } - - if c.onRetry != nil { - c.onRetry(attempt, delay) - } - - r.SetNextInterval(delay) - return true -} diff --git a/internal/http/client_test.go b/internal/http/client_test.go index 519f7867..b5187de2 100644 --- a/internal/http/client_test.go +++ b/internal/http/client_test.go @@ -3,13 +3,10 @@ package http import ( "context" "encoding/json" - "errors" "io" "net/http" "net/http/httptest" - "sync/atomic" "testing" - "time" ) type testResponse struct { @@ -307,335 +304,4 @@ func TestErrorResponse(t *testing.T) { t.Error("expected IsTooManyRequests to return false for 500") } }) - - t.Run("RetryAfter parses RateLimit-Reset header", func(t *testing.T) { - t.Parallel() - - headers := http.Header{} - headers.Set("RateLimit-Reset", "30") - err := &ErrorResponse{Headers: headers} - - if got := err.RetryAfter(); got != 30*time.Second { - t.Errorf("expected 30s, got %v", got) - } - }) - - t.Run("RetryAfter returns zero for missing header", func(t *testing.T) { - t.Parallel() - - err := &ErrorResponse{Headers: http.Header{}} - if got := err.RetryAfter(); got != 0 { - t.Errorf("expected 0, got %v", got) - } - - err = &ErrorResponse{Headers: nil} - if got := err.RetryAfter(); got != 0 { - t.Errorf("expected 0 for nil headers, got %v", got) - } - }) - - t.Run("RetryAfter returns zero for invalid header value", func(t *testing.T) { - t.Parallel() - - headers := http.Header{} - headers.Set("RateLimit-Reset", "not-a-number") - err := &ErrorResponse{Headers: headers} - - if got := err.RetryAfter(); got != 0 { - t.Errorf("expected 0 for invalid value, got 
%v", got) - } - }) -} - -func TestClientRetry(t *testing.T) { - t.Parallel() - - t.Run("retries on 429 with RateLimit-Reset header", func(t *testing.T) { - t.Parallel() - - var requestCount atomic.Int32 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - count := requestCount.Add(1) - if count == 1 { - w.Header().Set("RateLimit-Reset", "0") - w.WriteHeader(http.StatusTooManyRequests) - return - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(testResponse{Message: "success"}) - })) - defer server.Close() - - client := NewClient("test-token", - WithBaseURL(server.URL), - WithMaxRetries(3), - WithMaxRetryDelay(100*time.Millisecond), - ) - - var resp testResponse - err := client.Get(context.Background(), "/test", &resp) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if got := requestCount.Load(); got != 2 { - t.Errorf("expected 2 requests, got %d", got) - } - if resp.Message != "success" { - t.Errorf("expected success message, got %q", resp.Message) - } - }) - - t.Run("respects max retries limit", func(t *testing.T) { - t.Parallel() - - var requestCount atomic.Int32 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount.Add(1) - w.Header().Set("RateLimit-Reset", "0") - w.WriteHeader(http.StatusTooManyRequests) - })) - defer server.Close() - - client := NewClient("test-token", - WithBaseURL(server.URL), - WithMaxRetries(2), - WithMaxRetryDelay(1*time.Millisecond), - ) - - var resp testResponse - err := client.Get(context.Background(), "/test", &resp) - if err == nil { - t.Fatal("expected error, got nil") - } - - errResp, ok := err.(*ErrorResponse) - if !ok { - t.Fatalf("expected ErrorResponse, got %T", err) - } - if !errResp.IsTooManyRequests() { - t.Errorf("expected 429 error, got %d", errResp.StatusCode) - } - - if got := requestCount.Load(); got != 3 { - t.Errorf("expected 3 requests, got %d", got) - } - }) - - t.Run("does not retry 
non-429 errors", func(t *testing.T) { - t.Parallel() - - var requestCount atomic.Int32 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount.Add(1) - w.WriteHeader(http.StatusInternalServerError) - })) - defer server.Close() - - client := NewClient("test-token", - WithBaseURL(server.URL), - WithMaxRetries(3), - ) - - var resp testResponse - err := client.Get(context.Background(), "/test", &resp) - if err == nil { - t.Fatal("expected error, got nil") - } - - if got := requestCount.Load(); got != 1 { - t.Errorf("expected 1 request (no retries for 500), got %d", got) - } - }) - - t.Run("respects context cancellation during retry", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("RateLimit-Reset", "60") - w.WriteHeader(http.StatusTooManyRequests) - })) - defer server.Close() - - client := NewClient("test-token", - WithBaseURL(server.URL), - WithMaxRetries(3), - WithMaxRetryDelay(60*time.Second), - ) - - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - - start := time.Now() - var resp testResponse - err := client.Get(ctx, "/test", &resp) - - elapsed := time.Since(start) - if elapsed > 1*time.Second { - t.Errorf("expected quick cancellation, took %v", elapsed) - } - - if !errors.Is(err, context.DeadlineExceeded) { - t.Errorf("expected DeadlineExceeded, got %v", err) - } - }) - - t.Run("honors RateLimit-Reset header when no max retry delay is set", func(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("RateLimit-Reset", "1") - w.WriteHeader(http.StatusTooManyRequests) - })) - defer server.Close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - var observedDelay time.Duration - - client := NewClient("test-token", - WithBaseURL(server.URL), - 
WithMaxRetries(1), - WithOnRetry(func(attempt int, delay time.Duration) { - // Stop the request as soon as we see the computed delay so the test - // doesn't actually sleep for the full backoff. - observedDelay = delay - cancel() - }), - ) - - var resp testResponse - err := client.Get(ctx, "/test", &resp) - if !errors.Is(err, context.Canceled) { - t.Fatalf("expected context cancellation, got %v", err) - } - if observedDelay != time.Second { - t.Fatalf("expected retry delay from RateLimit-Reset to be 1s, got %v", observedDelay) - } - }) - - t.Run("caps retry delay at maxRetryDelay", func(t *testing.T) { - t.Parallel() - - var requestCount atomic.Int32 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - count := requestCount.Add(1) - if count == 1 { - w.Header().Set("RateLimit-Reset", "3600") - w.WriteHeader(http.StatusTooManyRequests) - return - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(testResponse{Message: "success"}) - })) - defer server.Close() - - client := NewClient("test-token", - WithBaseURL(server.URL), - WithMaxRetries(1), - WithMaxRetryDelay(10*time.Millisecond), - ) - - start := time.Now() - var resp testResponse - err := client.Get(context.Background(), "/test", &resp) - elapsed := time.Since(start) - - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if elapsed > 1*time.Second { - t.Errorf("expected delay to be capped, but took %v", elapsed) - } - }) - - t.Run("retries preserve request body", func(t *testing.T) { - t.Parallel() - - var requestCount atomic.Int32 - var lastBody []byte - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - count := requestCount.Add(1) - body, _ := io.ReadAll(r.Body) - lastBody = body - if count == 1 { - w.Header().Set("RateLimit-Reset", "0") - w.WriteHeader(http.StatusTooManyRequests) - return - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(testResponse{Message: "success"}) - })) - defer 
server.Close() - - client := NewClient("test-token", - WithBaseURL(server.URL), - WithMaxRetries(1), - WithMaxRetryDelay(1*time.Millisecond), - ) - - requestBody := map[string]string{"key": "value"} - var resp testResponse - err := client.Post(context.Background(), "/test", requestBody, &resp) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - var parsed map[string]string - if err := json.Unmarshal(lastBody, &parsed); err != nil { - t.Fatalf("failed to parse body: %v", err) - } - if parsed["key"] != "value" { - t.Errorf("expected body to be preserved on retry, got %v", parsed) - } - }) - - t.Run("invokes OnRetry callback before sleeping", func(t *testing.T) { - t.Parallel() - - var requestCount atomic.Int32 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - count := requestCount.Add(1) - if count <= 2 { - w.Header().Set("RateLimit-Reset", "1") - w.WriteHeader(http.StatusTooManyRequests) - return - } - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(testResponse{Message: "success"}) - })) - defer server.Close() - - type callback struct { - attempt int - delay time.Duration - } - var callbacks []callback - - client := NewClient("test-token", - WithBaseURL(server.URL), - WithMaxRetries(3), - WithMaxRetryDelay(10*time.Millisecond), - WithOnRetry(func(attempt int, delay time.Duration) { - callbacks = append(callbacks, callback{attempt, delay}) - }), - ) - - var resp testResponse - err := client.Get(context.Background(), "/test", &resp) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - if len(callbacks) != 2 { - t.Fatalf("expected 2 callbacks, got %d", len(callbacks)) - } - if callbacks[0].attempt != 0 { - t.Errorf("first callback attempt: expected 0, got %d", callbacks[0].attempt) - } - if callbacks[1].attempt != 1 { - t.Errorf("second callback attempt: expected 1, got %d", callbacks[1].attempt) - } - }) } diff --git a/internal/http/ratelimit.go b/internal/http/ratelimit.go new file 
mode 100644 index 00000000..6a589118 --- /dev/null +++ b/internal/http/ratelimit.go @@ -0,0 +1,123 @@ +package http + +import ( + "io" + "net/http" + "strconv" + "time" +) + +const ( + // DefaultMaxRateLimitRetries is the default number of times to retry a + // rate-limited request. + DefaultMaxRateLimitRetries = 3 + + // defaultFallbackDelay is used when the server returns 429 but the + // RateLimit-Reset header is missing or unparseable. + defaultFallbackDelay = 10 * time.Second +) + +// OnRateLimitFunc is called before sleeping for a rate-limit backoff. +// attempt is zero-indexed; delay is how long the transport will sleep. +type OnRateLimitFunc func(attempt int, delay time.Duration) + +// RateLimitTransport wraps an http.RoundTripper and automatically retries +// requests that receive an HTTP 429 response, sleeping for the duration +// indicated by the RateLimit-Reset header. +type RateLimitTransport struct { + // Transport is the underlying RoundTripper. If nil, http.DefaultTransport + // is used. + Transport http.RoundTripper + + // MaxRetries is the maximum number of retry attempts on 429. Zero means + // no retries; negative values are treated as zero. + MaxRetries int + + // MaxRetryDelay caps the sleep duration for any single retry. Zero means + // no cap is applied. + MaxRetryDelay time.Duration + + // OnRateLimit is an optional callback invoked before each backoff sleep. + OnRateLimit OnRateLimitFunc +} + +// NewRateLimitTransport returns a RateLimitTransport wrapping the given +// transport with sensible defaults. +func NewRateLimitTransport(transport http.RoundTripper) *RateLimitTransport { + if transport == nil { + transport = http.DefaultTransport + } + return &RateLimitTransport{ + Transport: transport, + MaxRetries: DefaultMaxRateLimitRetries, + } +} + +// RoundTrip implements http.RoundTripper. 
+// On a 429 response it reads the RateLimit-Reset header (seconds until the
+// rate-limit window resets) and sleeps for that duration before retrying, up
+// to MaxRetries times. If the header is missing or invalid,
+// defaultFallbackDelay is used; a non-zero MaxRetryDelay caps every sleep.
+//
+// A retry is only attempted when the request body can be replayed: either
+// the request has no body, or req.GetBody is set. Otherwise the 429 response
+// is returned to the caller untouched, because the body was already consumed
+// by the first attempt. The caller's *http.Request is never modified; each
+// retry is sent on a clone carrying a fresh body.
+//
+// It returns the final response (which may still be a 429 once retries are
+// exhausted), the underlying transport's error, an error from req.GetBody,
+// or the request context's error if it is cancelled during a backoff sleep.
+func (t *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+	transport := t.Transport
+	if transport == nil {
+		transport = http.DefaultTransport
+	}
+
+	// A body without GetBody is single-use: after the first attempt there is
+	// nothing left to send, so such requests must not be retried.
+	replayable := req.Body == nil || req.Body == http.NoBody || req.GetBody != nil
+
+	for attempt := 0; ; attempt++ {
+		outReq := req
+		if attempt > 0 && req.GetBody != nil {
+			body, err := req.GetBody()
+			if err != nil {
+				return nil, err
+			}
+			// The RoundTripper contract forbids mutating the caller's
+			// request, so attach the fresh body to a clone instead.
+			outReq = req.Clone(req.Context())
+			outReq.Body = body
+		}
+
+		resp, err := transport.RoundTrip(outReq)
+		if err != nil {
+			return nil, err
+		}
+
+		if resp.StatusCode != http.StatusTooManyRequests || attempt >= t.MaxRetries || !replayable {
+			return resp, nil
+		}
+
+		delay, ok := parseRateLimitReset(resp)
+		if !ok {
+			delay = defaultFallbackDelay
+		}
+		if t.MaxRetryDelay > 0 && delay > t.MaxRetryDelay {
+			delay = t.MaxRetryDelay
+		}
+
+		// Drain and close the 429 response body so the underlying connection
+		// can be reused for the retry. Failures here are not actionable.
+		_, _ = io.Copy(io.Discard, resp.Body)
+		_ = resp.Body.Close()
+
+		if t.OnRateLimit != nil {
+			t.OnRateLimit(attempt, delay)
+		}
+
+		// Sleep for the backoff duration, but honour context cancellation.
+		timer := time.NewTimer(delay)
+		select {
+		case <-req.Context().Done():
+			timer.Stop()
+			return nil, req.Context().Err()
+		case <-timer.C:
+		}
+	}
+}
+
+// parseRateLimitReset reads the RateLimit-Reset header and returns the
+// duration to wait plus a boolean indicating whether the value was valid.
+func parseRateLimitReset(resp *http.Response) (time.Duration, bool) { + s := resp.Header.Get("RateLimit-Reset") + if s == "" { + return 0, false + } + seconds, err := strconv.Atoi(s) + if err != nil || seconds < 0 { + return 0, false + } + return time.Duration(seconds) * time.Second, true +} diff --git a/internal/http/ratelimit_test.go b/internal/http/ratelimit_test.go new file mode 100644 index 00000000..f4816ddd --- /dev/null +++ b/internal/http/ratelimit_test.go @@ -0,0 +1,298 @@ +package http + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" +) + +func TestRateLimitTransport(t *testing.T) { + t.Run("passes through non-429 responses", func(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + })) + defer s.Close() + + rt := NewRateLimitTransport(http.DefaultTransport) + req, _ := http.NewRequest("GET", s.URL, nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + }) + + t.Run("retries on 429 and succeeds", func(t *testing.T) { + var attempts atomic.Int32 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := attempts.Add(1) + if n <= 2 { + w.Header().Set("RateLimit-Reset", "1") + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte("rate limited")) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + })) + defer s.Close() + + var callbackCalls int + rt := NewRateLimitTransport(http.DefaultTransport) + rt.MaxRetries = 3 + rt.OnRateLimit = func(attempt int, delay time.Duration) { + callbackCalls++ + } + + req, _ := http.NewRequest("GET", s.URL, nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer 
resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200 after retries, got %d", resp.StatusCode) + } + if got := attempts.Load(); got != 3 { + t.Errorf("expected 3 total attempts, got %d", got) + } + if callbackCalls != 2 { + t.Errorf("expected 2 callback calls, got %d", callbackCalls) + } + }) + + t.Run("returns 429 after exhausting retries", func(t *testing.T) { + var attempts atomic.Int32 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts.Add(1) + w.Header().Set("RateLimit-Reset", "1") + w.WriteHeader(http.StatusTooManyRequests) + })) + defer s.Close() + + rt := NewRateLimitTransport(http.DefaultTransport) + rt.MaxRetries = 2 + + req, _ := http.NewRequest("GET", s.URL, nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusTooManyRequests { + t.Errorf("expected 429 after exhausting retries, got %d", resp.StatusCode) + } + // 1 initial + 2 retries = 3 total + if got := attempts.Load(); got != 3 { + t.Errorf("expected 3 total attempts, got %d", got) + } + }) + + t.Run("respects context cancellation during backoff", func(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("RateLimit-Reset", "60") + w.WriteHeader(http.StatusTooManyRequests) + })) + defer s.Close() + + rt := NewRateLimitTransport(http.DefaultTransport) + rt.MaxRetries = 3 + + ctx, cancel := context.WithCancel(context.Background()) + // Cancel shortly after the first 429 is received. 
+ rt.OnRateLimit = func(attempt int, delay time.Duration) { + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + } + + req, _ := http.NewRequestWithContext(ctx, "GET", s.URL, nil) + _, err := rt.RoundTrip(req) + if err == nil { + t.Fatal("expected error from cancelled context, got nil") + } + if !strings.Contains(err.Error(), "context canceled") { + t.Errorf("expected context canceled error, got: %v", err) + } + }) + + t.Run("uses fallback delay when header missing", func(t *testing.T) { + var attempts atomic.Int32 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := attempts.Add(1) + if n == 1 { + // No RateLimit-Reset header + w.WriteHeader(http.StatusTooManyRequests) + return + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + var gotDelay time.Duration + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rt := NewRateLimitTransport(http.DefaultTransport) + rt.MaxRetries = 1 + // Cancel quickly to avoid waiting the full fallback delay. 
+ rt.OnRateLimit = func(attempt int, delay time.Duration) { + gotDelay = delay + cancel() + } + + req, _ := http.NewRequestWithContext(ctx, "GET", s.URL, nil) + rt.RoundTrip(req) + + if gotDelay != defaultFallbackDelay { + t.Errorf("expected fallback delay %v, got %v", defaultFallbackDelay, gotDelay) + } + }) + + t.Run("uses zero delay when header is zero", func(t *testing.T) { + var attempts atomic.Int32 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := attempts.Add(1) + if n == 1 { + w.Header().Set("RateLimit-Reset", "0") + w.WriteHeader(http.StatusTooManyRequests) + return + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + var gotDelay time.Duration + rt := NewRateLimitTransport(http.DefaultTransport) + rt.MaxRetries = 1 + rt.OnRateLimit = func(attempt int, delay time.Duration) { + gotDelay = delay + } + + req, _ := http.NewRequest("GET", s.URL, nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if gotDelay != 0 { + t.Errorf("expected zero delay, got %v", gotDelay) + } + }) + + t.Run("caps delay at MaxRetryDelay", func(t *testing.T) { + var attempts atomic.Int32 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := attempts.Add(1) + if n == 1 { + w.Header().Set("RateLimit-Reset", "3600") + w.WriteHeader(http.StatusTooManyRequests) + return + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + rt := NewRateLimitTransport(http.DefaultTransport) + rt.MaxRetries = 1 + rt.MaxRetryDelay = 10 * time.Millisecond + + req, _ := http.NewRequest("GET", s.URL, nil) + start := time.Now() + resp, err := rt.RoundTrip(req) + elapsed := time.Since(start) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + } + if elapsed > 1*time.Second { + t.Errorf("expected delay 
to be capped, but took %v", elapsed) + } + }) + + t.Run("replays request body on retry", func(t *testing.T) { + var attempts atomic.Int32 + var bodies []string + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, _ := io.ReadAll(r.Body) + bodies = append(bodies, string(b)) + n := attempts.Add(1) + if n == 1 { + w.Header().Set("RateLimit-Reset", "1") + w.WriteHeader(http.StatusTooManyRequests) + return + } + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + rt := NewRateLimitTransport(http.DefaultTransport) + rt.MaxRetries = 1 + + body := `{"key":"value"}` + req, _ := http.NewRequest("POST", s.URL, strings.NewReader(body)) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + if len(bodies) != 2 { + t.Fatalf("expected 2 requests, got %d", len(bodies)) + } + for i, got := range bodies { + if got != body { + t.Errorf("attempt %d: body = %q, want %q", i, got, body) + } + } + }) +} + +func TestParseRateLimitReset(t *testing.T) { + tests := []struct { + name string + header string + expected time.Duration + ok bool + }{ + {"valid seconds", "30", 30 * time.Second, true}, + {"one second", "1", 1 * time.Second, true}, + {"empty", "", 0, false}, + {"negative", "-1", 0, false}, + {"zero means retry now", "0", 0, true}, + {"non-numeric", "abc", 0, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := &http.Response{Header: http.Header{}} + if tt.header != "" { + resp.Header.Set("RateLimit-Reset", tt.header) + } + got, ok := parseRateLimitReset(resp) + if got != tt.expected { + t.Errorf("parseRateLimitReset(%q) = %v, want %v", tt.header, got, tt.expected) + } + if ok != tt.ok { + t.Errorf("parseRateLimitReset(%q) ok = %v, want %v", tt.header, ok, tt.ok) + } + }) + } +} diff --git a/pkg/cmd/factory/factory.go b/pkg/cmd/factory/factory.go index a1666198..56219fb8 100644 --- a/pkg/cmd/factory/factory.go +++ 
b/pkg/cmd/factory/factory.go @@ -37,6 +37,7 @@ type FactoryOpt func(*factoryConfig) type factoryConfig struct { debug bool orgOverride string + transport http.RoundTripper } // WithDebug enables debug output for REST API calls @@ -55,6 +56,14 @@ func WithOrgOverride(org string) FactoryOpt { } } +// WithTransport sets a custom http.RoundTripper for the REST API client. +// It is composed with the debug transport when debug mode is enabled. +func WithTransport(t http.RoundTripper) FactoryOpt { + return func(c *factoryConfig) { + c.transport = t + } +} + // debugTransport wraps an http.RoundTripper and logs requests/responses with sensitive headers redacted type debugTransport struct { transport http.RoundTripper @@ -163,15 +172,16 @@ func New(opts ...FactoryOpt) (*Factory, error) { buildkite.WithUserAgent(userAgent), } - // Use our own debug transport with redacted headers instead of go-buildkite's built-in debug + // Use custom transport if provided (caller is responsible for wrapping + // http.DefaultTransport if needed), then wrap with debug transport when enabled. + transport := http.RoundTripper(http.DefaultTransport) + if cfg.transport != nil { + transport = cfg.transport + } if cfg.debug { - httpClient := &http.Client{ - Transport: &debugTransport{ - transport: http.DefaultTransport, - }, - } - clientOpts = append(clientOpts, buildkite.WithHTTPClient(httpClient)) + transport = &debugTransport{transport: transport} } + clientOpts = append(clientOpts, buildkite.WithHTTPClient(&http.Client{Transport: transport})) buildkiteClient, err := buildkite.NewOpts(clientOpts...) 
if err != nil { From 901b59c018df77a39ab7e8a93ad329053835dbc3 Mon Sep 17 00:00:00 2001 From: Ben McNicholl Date: Fri, 10 Apr 2026 10:55:13 +1000 Subject: [PATCH 2/2] wrap api limit output in a conditional for debug flag use --- cmd/preflight/preflight.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/cmd/preflight/preflight.go b/cmd/preflight/preflight.go index 6241aec7..bb5f22e9 100644 --- a/cmd/preflight/preflight.go +++ b/cmd/preflight/preflight.go @@ -93,12 +93,14 @@ func (c *PreflightCmd) Run(kongCtx *kong.Context, globals cli.GlobalFlags) error renderer := newRenderer(os.Stdout, c.JSON, c.Text, stop) rlTransport.OnRateLimit = func(attempt int, delay time.Duration) { - _ = renderer.Render(Event{ - Type: EventOperation, - Time: time.Now(), - PreflightID: preflightID.String(), - Title: fmt.Sprintf("Rate limited by API, waiting %s before retrying (attempt %d/%d)...", delay.Truncate(time.Second), attempt+1, rlTransport.MaxRetries), - }) + if globals.EnableDebug() { + _ = renderer.Render(Event{ + Type: EventOperation, + Time: time.Now(), + PreflightID: preflightID.String(), + Title: fmt.Sprintf("Rate limited by API, waiting %s before retrying (attempt %d/%d)...", delay.Truncate(time.Second), attempt+1, rlTransport.MaxRetries), + }) + } } _ = renderer.Render(Event{Type: EventOperation, Time: time.Now(), PreflightID: preflightID.String(), Title: "Pushing snapshot of working tree..."})