Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remotes/docker/authorizer.go: refresh OAuth tokens when they expire #8735

Merged
merged 1 commit into from Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 10 additions & 10 deletions core/remotes/docker/auth/fetch.go
Expand Up @@ -86,11 +86,11 @@ type TokenOptions struct {

// OAuthTokenResponse is response from fetching token with a OAuth POST request
type OAuthTokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"`
Scope string `json:"scope"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresInSeconds int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"`
Scope string `json:"scope"`
}

// FetchTokenWithOAuth fetches a token using a POST request
Expand Down Expand Up @@ -152,11 +152,11 @@ func FetchTokenWithOAuth(ctx context.Context, client *http.Client, headers http.

// FetchTokenResponse is response from fetching token with GET request
type FetchTokenResponse struct {
Token string `json:"token"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"`
RefreshToken string `json:"refresh_token"`
Token string `json:"token"`
AccessToken string `json:"access_token"`
ExpiresInSeconds int `json:"expires_in"`
IssuedAt time.Time `json:"issued_at"`
RefreshToken string `json:"refresh_token"`
}

// FetchToken fetches a token using a GET request
Expand Down
27 changes: 22 additions & 5 deletions core/remotes/docker/authorizer.go
Expand Up @@ -24,6 +24,7 @@ import (
"net/http"
"strings"
"sync"
"time"

"github.com/containerd/containerd/v2/core/remotes/docker/auth"
remoteerrors "github.com/containerd/containerd/v2/core/remotes/errors"
Expand Down Expand Up @@ -205,9 +206,10 @@ func (a *dockerAuthorizer) AddResponses(ctx context.Context, responses []*http.R
// authResult is used to control limit rate.
type authResult struct {
sync.WaitGroup
token string
refreshToken string
err error
token string
refreshToken string
expirationTime *time.Time
err error
}

// authHandler is used to handle auth request per registry server.
Expand Down Expand Up @@ -270,8 +272,12 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st
// Docs: https://docs.docker.com/registry/spec/auth/scope
scoped := strings.Join(to.Scopes, " ")

// Keep track of the expiration time of cached bearer tokens so they can be
// refreshed when they expire without a server roundtrip.
var expirationTime *time.Time

ah.Lock()
if r, exist := ah.scopedTokens[scoped]; exist {
if r, exist := ah.scopedTokens[scoped]; exist && (r.expirationTime == nil || r.expirationTime.After(time.Now())) {
ah.Unlock()
r.Wait()
return r.token, r.refreshToken, r.err
Expand All @@ -285,7 +291,7 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st

defer func() {
token = fmt.Sprintf("Bearer %s", token)
r.token, r.refreshToken, r.err = token, refreshToken, err
r.token, r.refreshToken, r.err, r.expirationTime = token, refreshToken, err, expirationTime
r.Done()
}()

Expand All @@ -311,6 +317,7 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st
if err != nil {
return "", "", err
}
expirationTime = getExpirationTime(resp.ExpiresInSeconds)
return resp.Token, resp.RefreshToken, nil
}
log.G(ctx).WithFields(log.Fields{
Expand All @@ -320,16 +327,26 @@ func (ah *authHandler) doBearerAuth(ctx context.Context) (token, refreshToken st
}
return "", "", err
}
expirationTime = getExpirationTime(resp.ExpiresInSeconds)
return resp.AccessToken, resp.RefreshToken, nil
}
// do request anonymously
resp, err := auth.FetchToken(ctx, ah.client, ah.header, to)
if err != nil {
return "", "", fmt.Errorf("failed to fetch anonymous token: %w", err)
}
expirationTime = getExpirationTime(resp.ExpiresInSeconds)
return resp.Token, resp.RefreshToken, nil
}

func getExpirationTime(expiresInSeconds int) *time.Time {
if expiresInSeconds <= 0 {
return nil
}
expirationTime := time.Now().Add(time.Duration(expiresInSeconds) * time.Second)
return &expirationTime
}

func invalidAuthorization(ctx context.Context, c auth.Challenge, responses []*http.Response) (retry bool, _ error) {
errStr := c.Parameters["error"]
if errStr == "" {
Expand Down