Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,12 @@ func (c *Client) bareDoUntilFound(ctx context.Context, req *http.Request, maxRed
return nil, nil, errInvalidLocation
}
newURL := c.BaseURL.ResolveReference(rerr.Location)
// Refuse to follow a permanent redirect to a different host:
// req.Clone preserves Authorization headers added by the auth
// transport, so a cross-host target would leak credentials.
if newURL.Host != c.BaseURL.Host {
return nil, response, fmt.Errorf("refusing to follow cross-host redirect from %q to %q", c.BaseURL.Host, newURL.Host)
}
newRequest := req.Clone(ctx)
newRequest.URL = newURL
return c.bareDoUntilFound(ctx, newRequest, maxRedirects-1)
Expand Down Expand Up @@ -1846,11 +1852,35 @@ func (c *Client) roundTripWithOptionalFollowRedirect(ctx context.Context, u stri
if maxRedirects > 0 && resp.StatusCode == http.StatusMovedPermanently {
_ = resp.Body.Close()
u = resp.Header.Get("Location")
if err := c.checkRedirectHost(u); err != nil {
return nil, err
}
resp, err = c.roundTripWithOptionalFollowRedirect(ctx, u, maxRedirects-1, opts...)
}
return resp, err
}

// checkRedirectHost returns an error if the redirect target is on a different
// host than the client's configured BaseURL. This prevents credentials attached
// by the auth transport from being sent to an attacker-controlled host when a
// compromised or malicious API response returns a cross-origin Location header.
// An empty Location is also rejected.
func (c *Client) checkRedirectHost(location string) error {
if location == "" {
return errInvalidLocation
}
target, err := url.Parse(location)
if err != nil {
return fmt.Errorf("invalid redirect location %q: %w", location, err)
}
// Resolve relative locations against BaseURL so relative paths are allowed.
target = c.BaseURL.ResolveReference(target)
if target.Host != c.BaseURL.Host {
return fmt.Errorf("refusing to follow cross-host redirect from %q to %q", c.BaseURL.Host, target.Host)
}
return nil
}

// Ptr is a helper routine that allocates a new T value
// to store v and returns a pointer to it.
func Ptr[T any](v T) *T {
Expand Down
74 changes: 74 additions & 0 deletions github/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2256,6 +2256,80 @@ func TestBareDoUntilFound_UnexpectedRedirection(t *testing.T) {
}
}

// TestBareDoUntilFound_RejectsCrossHostRedirect verifies that bareDoUntilFound
// refuses to follow a 301 redirect whose Location points to a different host,
// which would otherwise leak the Authorization header (added by the auth
// transport) to an attacker-controlled server.
func TestBareDoUntilFound_RejectsCrossHostRedirect(t *testing.T) {
t.Parallel()
client, mux, _ := setup(t)

mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Location", "https://evil.example.com/steal")
w.WriteHeader(http.StatusMovedPermanently)
})

req, _ := client.NewRequest("GET", ".", nil)
_, _, err := client.bareDoUntilFound(t.Context(), req, 1)
if err == nil {
t.Fatal("Expected cross-host redirect to be rejected, got nil error.")
}
if !strings.Contains(err.Error(), "cross-host redirect") {
t.Errorf("Expected cross-host redirect error, got: %v", err)
}
}

// TestRoundTripWithOptionalFollowRedirect_RejectsCrossHostRedirect verifies
// that roundTripWithOptionalFollowRedirect refuses to follow a 301 redirect to
// a different host, preventing Authorization-header leakage to attacker-
// controlled servers via a malicious or compromised API response.
func TestRoundTripWithOptionalFollowRedirect_RejectsCrossHostRedirect(t *testing.T) {
t.Parallel()
client, mux, _ := setup(t)

mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Location", "https://evil.example.com/steal")
w.WriteHeader(http.StatusMovedPermanently)
})

_, err := client.roundTripWithOptionalFollowRedirect(t.Context(), ".", 1)
if err == nil {
t.Fatal("Expected cross-host redirect to be rejected, got nil error.")
}
if !strings.Contains(err.Error(), "cross-host redirect") {
t.Errorf("Expected cross-host redirect error, got: %v", err)
}
}

// TestRoundTripWithOptionalFollowRedirect_AllowsSameHostRedirect ensures the
// cross-host check does not break legitimate same-host 301 follow behavior
// (the path that rate-limit redirection relies on).
func TestRoundTripWithOptionalFollowRedirect_AllowsSameHostRedirect(t *testing.T) {
t.Parallel()
client, mux, _ := setup(t)

var followed atomic.Bool
mux.HandleFunc("/archive", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Location", baseURLPath+"/archive-target")
w.WriteHeader(http.StatusMovedPermanently)
})
mux.HandleFunc("/archive-target", func(w http.ResponseWriter, _ *http.Request) {
followed.Store(true)
w.WriteHeader(http.StatusOK)
})

resp, err := client.roundTripWithOptionalFollowRedirect(t.Context(), "archive", 2)
if err != nil {
t.Fatalf("Unexpected error on same-host redirect: %v", err)
}
if resp != nil && resp.Body != nil {
resp.Body.Close()
}
if !followed.Load() {
t.Error("Expected same-host redirect to be followed.")
}
}

func TestSanitizeURL(t *testing.T) {
t.Parallel()
tests := []struct {
Expand Down
Loading