Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import (
func TestListBundleIDs(t *testing.T) {
keyID, issuerID, privateKey, enterpriseAccount := getAPIKey(t)

client := appstoreconnect.NewClient(appstoreconnect.NewRetryableHTTPClient(), keyID, issuerID, []byte(privateKey), enterpriseAccount, appstoreconnect.NoOpAnalyticsTracker{})
tracker := appstoreconnect.NoOpAnalyticsTracker{}
client := appstoreconnect.NewClient(appstoreconnect.NewRetryableHTTPClient(tracker), keyID, issuerID, []byte(privateKey), enterpriseAccount, tracker)

response, err := client.Provisioning.ListBundleIDs(&appstoreconnect.ListBundleIDsOptions{})
require.NoError(t, err)
Expand Down
24 changes: 17 additions & 7 deletions autocodesign/devportalclient/appstoreconnect/appstoreconnect.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,23 @@ type Client struct {
}

// NewRetryableHTTPClient create a new http client with retry settings.
func NewRetryableHTTPClient() *http.Client {
func NewRetryableHTTPClient(tracker Tracker) *http.Client {
client := retry.NewHTTPClient()

trackingTransport := newTrackingRoundTripper(client.HTTPClient.Transport, tracker)
client.HTTPClient.Transport = trackingTransport

// RequestLogHook is called before each retry (attemptNum > 0 for retries, 0 for initial request).
// We mark retry attempts in the request context so RoundTrip can track which attempts are retries.
// We use pointer dereference (*req = *...) to modify the request in-place because RequestLogHook
// doesn't return a value - it modifies the request through side effects. This updates the request's
// context field, which will be present when RoundTrip is called immediately after.
client.RequestLogHook = func(_ retryablehttp.Logger, req *http.Request, attemptNum int) {
if attemptNum > 0 {
*req = *trackingTransport.markAsRetry(req)
}
}

client.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) {
if resp != nil && resp.StatusCode == http.StatusUnauthorized {
log.Debugf("Received HTTP 401 (Unauthorized), retrying request...")
Expand Down Expand Up @@ -115,6 +130,7 @@ func NewRetryableHTTPClient() *http.Client {

return shouldRetry, err
}

return client.StandardClient()
}

Expand Down Expand Up @@ -245,8 +261,6 @@ func (c *Client) Debugf(format string, v ...interface{}) {

// Do ...
func (c *Client) Do(req *http.Request, v interface{}) (*http.Response, error) {
startTime := time.Now()

c.Debugf("Request:")
if c.EnableDebugLogs {
if err := httputil.PrintRequest(req); err != nil {
Expand All @@ -255,7 +269,6 @@ func (c *Client) Do(req *http.Request, v interface{}) (*http.Response, error) {
}

resp, err := c.client.Do(req)
duration := time.Since(startTime)

c.Debugf("Response:")
if c.EnableDebugLogs {
Expand All @@ -276,13 +289,10 @@ func (c *Client) Do(req *http.Request, v interface{}) (*http.Response, error) {
}()

if err := checkResponse(resp); err != nil {
c.tracker.TrackAPIRequest(req.Method, req.URL.Host, req.URL.Path, resp.StatusCode, duration)
c.tracker.TrackAPIError(req.Method, req.URL.Host, req.URL.Path, resp.StatusCode, err.Error())
return resp, err
}

c.tracker.TrackAPIRequest(req.Method, req.URL.Host, req.URL.Path, resp.StatusCode, duration)

if v != nil {
decErr := json.NewDecoder(resp.Body).Decode(v)
if decErr == io.EOF {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ import (
)

func TestNewClient(t *testing.T) {
got := NewClient(NewRetryableHTTPClient(), "keyID", "issuerID", []byte{}, false, NoOpAnalyticsTracker{})
tracker := NoOpAnalyticsTracker{}
got := NewClient(NewRetryableHTTPClient(tracker), "keyID", "issuerID", []byte{}, false, tracker)

require.Equal(t, "appstoreconnect-v1", got.audience)

Expand All @@ -27,7 +28,8 @@ func TestNewClient(t *testing.T) {
}

func TestNewEnterpriseClient(t *testing.T) {
got := NewClient(NewRetryableHTTPClient(), "keyID", "issuerID", []byte{}, true, NoOpAnalyticsTracker{})
tracker := NoOpAnalyticsTracker{}
got := NewClient(NewRetryableHTTPClient(tracker), "keyID", "issuerID", []byte{}, true, tracker)

require.Equal(t, "apple-developer-enterprise-v1", got.audience)

Expand All @@ -48,6 +50,7 @@ type apiRequestRecord struct {
endpoint string
statusCode int
duration time.Duration
isRetry bool
}

type apiErrorRecord struct {
Expand All @@ -58,13 +61,14 @@ type apiErrorRecord struct {
errorMessage string
}

func (m *mockAnalyticsTracker) TrackAPIRequest(method, host, endpoint string, statusCode int, duration time.Duration) {
func (m *mockAnalyticsTracker) TrackAPIRequest(method, host, endpoint string, statusCode int, duration time.Duration, isRetry bool) {
m.apiRequests = append(m.apiRequests, apiRequestRecord{
method: method,
host: host,
endpoint: endpoint,
statusCode: statusCode,
duration: duration,
isRetry: isRetry,
})
}

Expand Down Expand Up @@ -93,6 +97,11 @@ func (m *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
return m.resp, m.err
}

func (m *mockHTTPClient) RoundTrip(req *http.Request) (*http.Response, error) {
m.called = true
return m.resp, m.err
}

func TestTracking(t *testing.T) {
t.Run("successful request", func(t *testing.T) {
mockTracker := &mockAnalyticsTracker{}
Expand All @@ -103,8 +112,11 @@ func TestTracking(t *testing.T) {
}))
defer server.Close()

httpClient := &http.Client{}
httpClient.Transport = newTrackingRoundTripper(httpClient.Transport, mockTracker)

client := &Client{
client: &http.Client{},
client: httpClient,
tracker: mockTracker,
}

Expand Down Expand Up @@ -132,16 +144,19 @@ func TestTracking(t *testing.T) {

t.Run("error response", func(t *testing.T) {
mockTracker := &mockAnalyticsTracker{}
mockHTTPClient := &mockHTTPClient{
mockTransport := &mockHTTPClient{
resp: &http.Response{
StatusCode: 400,
Body: io.NopCloser(strings.NewReader(`{"errors": [{"code": "PARAMETER_ERROR.INVALID", "title": "Invalid parameter"}]}`)),
Header: http.Header{},
},
}

httpClient := &http.Client{}
httpClient.Transport = newTrackingRoundTripper(mockTransport, mockTracker)

client := &Client{
client: mockHTTPClient,
client: httpClient,
tracker: mockTracker,
}

Expand All @@ -150,7 +165,7 @@ func TestTracking(t *testing.T) {
_, err = client.Do(req, nil)
require.Error(t, err, "Expected error due to 400 Bad Request response")

require.True(t, mockHTTPClient.called, "Expected HTTP client to be called")
require.True(t, mockTransport.called, "Expected HTTP client to be called")

require.Len(t, mockTracker.apiRequests, 1, "Expected 1 (failed) API requests tracked")
require.Len(t, mockTracker.apiErrors, 1, "Expected 1 API error tracked")
Expand All @@ -163,12 +178,15 @@ func TestTracking(t *testing.T) {
t.Run("network error", func(t *testing.T) {
mockTracker := &mockAnalyticsTracker{}

mockHTTPClient := &mockHTTPClient{
mockTransport := &mockHTTPClient{
err: errors.New("network connection failed"),
}

httpClient := &http.Client{}
httpClient.Transport = newTrackingRoundTripper(mockTransport, mockTracker)

client := &Client{
client: mockHTTPClient,
client: httpClient,
tracker: mockTracker,
}

Expand All @@ -177,12 +195,12 @@ func TestTracking(t *testing.T) {
_, err = client.Do(req, nil)
require.Error(t, err)

require.Len(t, mockTracker.apiRequests, 0, "Expected 0 API requests tracked")
require.Len(t, mockTracker.apiRequests, 1, "Expected 1 API request tracked (even though it failed)")
require.Len(t, mockTracker.apiErrors, 1, "Expected 1 API error tracked")

record := mockTracker.apiErrors[0]
require.Equal(t, "GET", record.method)
require.Equal(t, 0, record.statusCode)
require.Equal(t, "network connection failed", record.errorMessage)
require.Contains(t, record.errorMessage, "network connection failed")
})
}
59 changes: 59 additions & 0 deletions autocodesign/devportalclient/appstoreconnect/roundtripper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package appstoreconnect

import (
"context"
"net/http"
"time"
)

// trackingRoundTripper wraps an http.RoundTripper and tracks metrics for each HTTP attempt,
// including retries. It measures per-attempt duration without retry wait times included,
// allowing accurate tracking even when requests are retried due to rate limits or other errors.
type trackingRoundTripper struct {
wrapped http.RoundTripper
tracker Tracker
}

// isRetryContextKey is used to store whether a request is a retry attempt in the request context.
type isRetryContextKey struct{}

func newTrackingRoundTripper(wrapped http.RoundTripper, tracker Tracker) *trackingRoundTripper {
if wrapped == nil {
wrapped = http.DefaultTransport
}
return &trackingRoundTripper{
wrapped: wrapped,
tracker: tracker,
}
}

// markAsRetry stores a flag in the request context indicating this is a retry attempt.
// It returns a new request with the updated context. This approach avoids shared state
// between concurrent requests (e.g., multiple POSTs to the same endpoint with different bodies).
func (t *trackingRoundTripper) markAsRetry(req *http.Request) *http.Request {
ctx := context.WithValue(req.Context(), isRetryContextKey{}, true)
return req.WithContext(ctx)
}

// RoundTrip executes an HTTP request and tracks its duration and retry status.
// Each HTTP attempt (including retries) generates a separate metric event, allowing
// accurate alerting based on individual response times rather than aggregate times
// that include retry backoff delays.
func (t *trackingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// Check if this request was marked as a retry by RequestLogHook
isRetry := req.Context().Value(isRetryContextKey{}) != nil

startTime := time.Now()
resp, err := t.wrapped.RoundTrip(req)
duration := time.Since(startTime)

statusCode := 0
if resp != nil {
statusCode = resp.StatusCode
}

// Track this attempt with its actual duration (no retry waits included)
t.tracker.TrackAPIRequest(req.Method, req.URL.Host, req.URL.Path, statusCode, duration, isRetry)

return resp, err
}
110 changes: 110 additions & 0 deletions autocodesign/devportalclient/appstoreconnect/roundtripper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package appstoreconnect

import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"
)

type attemptTracker struct {
mu sync.Mutex
attempts []attemptRecord
}

type attemptRecord struct {
method string
host string
endpoint string
statusCode int
duration time.Duration
isRetry bool
}

func (a *attemptTracker) TrackAPIRequest(method, host, endpoint string, statusCode int, duration time.Duration, isRetry bool) {
a.mu.Lock()
defer a.mu.Unlock()
a.attempts = append(a.attempts, attemptRecord{
method: method,
host: host,
endpoint: endpoint,
statusCode: statusCode,
duration: duration,
isRetry: isRetry,
})
}

func (a *attemptTracker) TrackAPIError(method, host, endpoint string, statusCode int, errorMessage string) {
}

func (a *attemptTracker) TrackAuthError(errorMessage string) {
}

func TestTrackingRoundTripper(t *testing.T) {
t.Run("tracks single successful request", func(t *testing.T) {
tracker := &attemptTracker{}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()

transport := newTrackingRoundTripper(http.DefaultTransport, tracker)
client := &http.Client{Transport: transport}

req, err := http.NewRequest("GET", server.URL+"/test", nil)
require.NoError(t, err)

resp, err := client.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)

tracker.mu.Lock()
defer tracker.mu.Unlock()

require.Len(t, tracker.attempts, 1)
require.False(t, tracker.attempts[0].isRetry)
require.Equal(t, http.StatusOK, tracker.attempts[0].statusCode)
})

t.Run("tracks multiple attempts for same request", func(t *testing.T) {
tracker := &attemptTracker{}
attemptCount := 0

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount++
if attemptCount < 3 {
w.WriteHeader(http.StatusTooManyRequests)
} else {
w.WriteHeader(http.StatusOK)
}
}))
defer server.Close()

retryableClient := NewRetryableHTTPClient(tracker)

req, err := http.NewRequest("GET", server.URL+"/test", nil)
require.NoError(t, err)

resp, err := retryableClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)

tracker.mu.Lock()
defer tracker.mu.Unlock()

require.Len(t, tracker.attempts, 3, "Expected 3 attempts to be tracked")

require.False(t, tracker.attempts[0].isRetry)
require.Equal(t, http.StatusTooManyRequests, tracker.attempts[0].statusCode)

require.True(t, tracker.attempts[1].isRetry)
require.Equal(t, http.StatusTooManyRequests, tracker.attempts[1].statusCode)

require.True(t, tracker.attempts[2].isRetry)
require.Equal(t, http.StatusOK, tracker.attempts[2].statusCode)
})
}
Loading
Loading