Skip to content

Commit

Permalink
cache authenticator retrieved when login to a provider
Browse files Browse the repository at this point in the history
Signed-off-by: Soule BA <bah.soule@gmail.com>
  • Loading branch information
souleb committed Jun 15, 2024
1 parent 61276f4 commit e6b2983
Show file tree
Hide file tree
Showing 13 changed files with 680 additions and 78 deletions.
63 changes: 49 additions & 14 deletions oci/auth/aws/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"regexp"
"strings"
"sync"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
Expand Down Expand Up @@ -78,11 +79,7 @@ func (c *Client) WithConfig(cfg *aws.Config) {
// be the case if it's running in EKS, and may need additional setup
// otherwise (visit https://aws.github.io/aws-sdk-go-v2/docs/configuring-sdk/
// as a starting point).
func (c *Client) getLoginAuth(ctx context.Context, awsEcrRegion string) (authn.AuthConfig, error) {
// No caching of tokens is attempted; the quota for getting an
// auth token is high enough that getting a token every time you
// scan an image is viable for O(500) images per region. See
// https://docs.aws.amazon.com/general/latest/gr/ecr.html.
func (c *Client) getLoginAuth(ctx context.Context, awsEcrRegion string) (authn.AuthConfig, *time.Time, error) {
var authConfig authn.AuthConfig
var cfg aws.Config

Expand All @@ -94,7 +91,7 @@ func (c *Client) getLoginAuth(ctx context.Context, awsEcrRegion string) (authn.A
cfg, err = config.LoadDefaultConfig(ctx, config.WithRegion(awsEcrRegion))
if err != nil {
c.mu.Unlock()
return authConfig, fmt.Errorf("failed to load default configuration: %w", err)
return authConfig, nil, fmt.Errorf("failed to load default configuration: %w", err)
}
c.config = &cfg
}
Expand All @@ -105,31 +102,52 @@ func (c *Client) getLoginAuth(ctx context.Context, awsEcrRegion string) (authn.A
// pass nil input.
ecrToken, err := ecrService.GetAuthorizationToken(ctx, nil)
if err != nil {
return authConfig, err
return authConfig, nil, err
}

// Validate the authorization data.
if len(ecrToken.AuthorizationData) == 0 {
return authConfig, errors.New("no authorization data")
return authConfig, nil, errors.New("no authorization data")
}
if ecrToken.AuthorizationData[0].AuthorizationToken == nil {
return authConfig, fmt.Errorf("no authorization token")
return authConfig, nil, fmt.Errorf("no authorization token")
}
token, err := base64.StdEncoding.DecodeString(*ecrToken.AuthorizationData[0].AuthorizationToken)
if err != nil {
return authConfig, err
return authConfig, nil, err
}

tokenSplit := strings.Split(string(token), ":")
// Validate the tokens.
if len(tokenSplit) != 2 {
return authConfig, fmt.Errorf("invalid authorization token, expected the token to have two parts separated by ':', got %d parts", len(tokenSplit))
return authConfig, nil, fmt.Errorf("invalid authorization token, expected the token to have two parts separated by ':', got %d parts", len(tokenSplit))
}
authConfig = authn.AuthConfig{
Username: tokenSplit[0],
Password: tokenSplit[1],
}
return authConfig, nil
return authConfig, ecrToken.AuthorizationData[0].ExpiresAt, nil
}

// LoginWithExpiry attempts to get the authentication material for ECR.
// It returns the authentication material and the expiry time of the token.
func (c *Client) LoginWithExpiry(ctx context.Context, autoLogin bool, image string) (authn.Authenticator, *time.Time, error) {
if autoLogin {
log.FromContext(ctx).Info("logging in to AWS ECR for " + image)
_, awsEcrRegion, ok := ParseRegistry(image)
if !ok {
return nil, nil, errors.New("failed to parse AWS ECR image, invalid ECR image")
}

authConfig, expiresAt, err := c.getLoginAuth(ctx, awsEcrRegion)
if err != nil {
return nil, nil, err
}

auth := authn.FromConfig(authConfig)
return auth, expiresAt, nil
}
return nil, nil, fmt.Errorf("ECR authentication failed: %w", oci.ErrUnconfiguredProvider)
}

// Login attempts to get the authentication material for ECR.
Expand All @@ -141,7 +159,7 @@ func (c *Client) Login(ctx context.Context, autoLogin bool, image string) (authn
return nil, errors.New("failed to parse AWS ECR image, invalid ECR image")
}

authConfig, err := c.getLoginAuth(ctx, awsEcrRegion)
authConfig, _, err := c.getLoginAuth(ctx, awsEcrRegion)
if err != nil {
return nil, err
}
Expand All @@ -152,14 +170,31 @@ func (c *Client) Login(ctx context.Context, autoLogin bool, image string) (authn
return nil, fmt.Errorf("ECR authentication failed: %w", oci.ErrUnconfiguredProvider)
}

// OIDCLoginWithExpiry attempts to get the authentication material for ECR.
// It returns the authentication material and the expiry time of the token.
func (c *Client) OIDCLoginWithExpiry(ctx context.Context, registryURL string) (authn.Authenticator, *time.Time, error) {
_, awsEcrRegion, ok := ParseRegistry(registryURL)
if !ok {
return nil, nil, errors.New("failed to parse AWS ECR image, invalid ECR image")
}

authConfig, expiresAt, err := c.getLoginAuth(ctx, awsEcrRegion)
if err != nil {
return nil, nil, err
}

auth := authn.FromConfig(authConfig)
return auth, expiresAt, nil
}

// OIDCLogin attempts to get the authentication material for ECR.
func (c *Client) OIDCLogin(ctx context.Context, registryURL string) (authn.Authenticator, error) {
_, awsEcrRegion, ok := ParseRegistry(registryURL)
if !ok {
return nil, errors.New("failed to parse AWS ECR image, invalid ECR image")
}

authConfig, err := c.getLoginAuth(ctx, awsEcrRegion)
authConfig, _, err := c.getLoginAuth(ctx, awsEcrRegion)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion oci/auth/aws/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func TestGetLoginAuth(t *testing.T) {
cfg.Credentials = credentials.NewStaticCredentialsProvider("x", "y", "z")
ec.WithConfig(cfg)

a, err := ec.getLoginAuth(context.TODO(), "us-east-1")
a, _, err := ec.getLoginAuth(context.TODO(), "us-east-1")
g.Expect(err != nil).To(Equal(tt.wantErr))
if tt.statusCode == http.StatusOK {
g.Expect(a).To(Equal(tt.wantAuthConfig))
Expand Down
59 changes: 52 additions & 7 deletions oci/auth/azure/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"fmt"
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
_ "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
Expand All @@ -33,6 +34,11 @@ import (
"github.com/fluxcd/pkg/oci"
)

// Default cache expiration time in seconds for ACR refresh token.
// TODO @souleb: This is copied from https://github.com/Azure/msi-acrpull/blob/0ca921a7740e561c7204d9c3b3b55c4e0b9bd7b9/pkg/authorizer/token_retriever.go#L21C2-L21C39
// as it not provided by the Azure SDK. See with the Azure SDK team to see if there is a better way to get this value.
const defaultCacheExpirationInSeconds = 600

// Client is an Azure ACR client which can log into the registry and return
// authorization information.
type Client struct {
Expand Down Expand Up @@ -60,7 +66,7 @@ func (c *Client) WithScheme(scheme string) *Client {
// getLoginAuth returns authentication for ACR. The details needed for authentication
// are gotten from environment variable so there is no need to mount a host path.
// The endpoint is the registry server and will be queried for OAuth authorization token.
func (c *Client) getLoginAuth(ctx context.Context, registryURL string) (authn.AuthConfig, error) {
func (c *Client) getLoginAuth(ctx context.Context, registryURL string) (authn.AuthConfig, *time.Time, error) {
var authConfig authn.AuthConfig

// Use default credentials if no token credential is provided.
Expand All @@ -69,7 +75,7 @@ func (c *Client) getLoginAuth(ctx context.Context, registryURL string) (authn.Au
if c.credential == nil {
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return authConfig, err
return authConfig, nil, err
}
c.credential = cred
}
Expand All @@ -80,22 +86,24 @@ func (c *Client) getLoginAuth(ctx context.Context, registryURL string) (authn.Au
Scopes: []string{configurationEnvironment.Services[cloud.ResourceManager].Endpoint + "/" + ".default"},
})
if err != nil {
return authConfig, err
return authConfig, nil, err
}

// Obtain ACR access token using exchanger.
ex := newExchanger(registryURL)
accessToken, err := ex.ExchangeACRAccessToken(string(armToken.Token))
if err != nil {
return authConfig, fmt.Errorf("error exchanging token: %w", err)
return authConfig, nil, fmt.Errorf("error exchanging token: %w", err)
}

expiresAt := time.Now().Add(defaultCacheExpirationInSeconds * time.Second)

return authn.AuthConfig{
// This is the acr username used by Azure
// See documentation: https://docs.microsoft.com/en-us/azure/container-registry/container-registry-authentication?tabs=azure-cli#az-acr-login-with---expose-token
Username: "00000000-0000-0000-0000-000000000000",
Password: accessToken,
}, nil
}, &expiresAt, nil
}

// getCloudConfiguration returns the cloud configuration based on the registry URL.
Expand All @@ -122,6 +130,27 @@ func ValidHost(host string) bool {
return false
}

// LoginWithExpiry attempts to get the authentication material for ACR.
// It returns the authentication material and the expiry time of the token.
// The caller can ensure that the passed image is a valid ACR image using ValidHost().
func (c *Client) LoginWithExpiry(ctx context.Context, autoLogin bool, image string, ref name.Reference) (authn.Authenticator, *time.Time, error) {
if autoLogin {
log.FromContext(ctx).Info("logging in to Azure ACR for " + image)
// get registry host from image
strArr := strings.SplitN(image, "/", 2)
endpoint := fmt.Sprintf("%s://%s", c.scheme, strArr[0])
authConfig, expiresAt, err := c.getLoginAuth(ctx, endpoint)
if err != nil {
log.FromContext(ctx).Info("error logging into ACR " + err.Error())
return nil, nil, err
}

auth := authn.FromConfig(authConfig)
return auth, expiresAt, nil
}
return nil, nil, fmt.Errorf("ACR authentication failed: %w", oci.ErrUnconfiguredProvider)
}

// Login attempts to get the authentication material for ACR. The caller can
// ensure that the passed image is a valid ACR image using ValidHost().
func (c *Client) Login(ctx context.Context, autoLogin bool, image string, ref name.Reference) (authn.Authenticator, error) {
Expand All @@ -130,7 +159,7 @@ func (c *Client) Login(ctx context.Context, autoLogin bool, image string, ref na
// get registry host from image
strArr := strings.SplitN(image, "/", 2)
endpoint := fmt.Sprintf("%s://%s", c.scheme, strArr[0])
authConfig, err := c.getLoginAuth(ctx, endpoint)
authConfig, _, err := c.getLoginAuth(ctx, endpoint)
if err != nil {
log.FromContext(ctx).Info("error logging into ACR " + err.Error())
return nil, err
Expand All @@ -142,12 +171,28 @@ func (c *Client) Login(ctx context.Context, autoLogin bool, image string, ref na
return nil, fmt.Errorf("ACR authentication failed: %w", oci.ErrUnconfiguredProvider)
}

// OIDCLoginWithExpiry attempts to get an Authenticator for the provided ACR registry URL endpoint.
// It returns the Authenticator and the expiry time of the token.
//
// If you want to construct an Authenticator based on an image reference,
// you may want to use Login instead.
func (c *Client) OIDCLoginWithExpiry(ctx context.Context, registryUrl string) (authn.Authenticator, *time.Time, error) {
authConfig, expiresAt, err := c.getLoginAuth(ctx, registryUrl)
if err != nil {
log.FromContext(ctx).Info("error logging into ACR " + err.Error())
return nil, nil, err
}

auth := authn.FromConfig(authConfig)
return auth, expiresAt, nil
}

// OIDCLogin attempts to get an Authenticator for the provided ACR registry URL endpoint.
//
// If you want to construct an Authenticator based on an image reference,
// you may want to use Login instead.
func (c *Client) OIDCLogin(ctx context.Context, registryUrl string) (authn.Authenticator, error) {
authConfig, err := c.getLoginAuth(ctx, registryUrl)
authConfig, _, err := c.getLoginAuth(ctx, registryUrl)
if err != nil {
log.FromContext(ctx).Info("error logging into ACR " + err.Error())
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion oci/auth/azure/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func TestGetAzureLoginAuth(t *testing.T) {
WithTokenCredential(tt.tokenCredential).
WithScheme("http")

auth, err := c.getLoginAuth(context.TODO(), srv.URL)
auth, _, err := c.getLoginAuth(context.TODO(), srv.URL)
g.Expect(err != nil).To(Equal(tt.wantErr))
if tt.statusCode == http.StatusOK {
g.Expect(auth).To(Equal(tt.wantAuthConfig))
Expand Down
52 changes: 44 additions & 8 deletions oci/auth/gcp/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"io"
"net/http"
"strings"
"time"

"github.com/google/go-containerregistry/pkg/authn"
"github.com/google/go-containerregistry/pkg/name"
Expand Down Expand Up @@ -66,47 +67,69 @@ func (c *Client) WithTokenURL(url string) *Client {
// on GCP. This assumes that the pod has right to pull the image which would be
// the case if it is hosted on GCP. It works with both service account and
// workload identity enabled clusters.
func (c *Client) getLoginAuth(ctx context.Context) (authn.AuthConfig, error) {
func (c *Client) getLoginAuth(ctx context.Context) (authn.AuthConfig, *time.Time, error) {
var authConfig authn.AuthConfig

request, err := http.NewRequestWithContext(ctx, http.MethodGet, c.tokenURL, nil)
if err != nil {
return authConfig, err
return authConfig, nil, err
}

request.Header.Add("Metadata-Flavor", "Google")

client := &http.Client{}
response, err := client.Do(request)
if err != nil {
return authConfig, err
return authConfig, nil, err
}
defer response.Body.Close()
defer io.Copy(io.Discard, response.Body)

if response.StatusCode != http.StatusOK {
return authConfig, fmt.Errorf("unexpected status from metadata service: %s", response.Status)
return authConfig, nil, fmt.Errorf("unexpected status from metadata service: %s", response.Status)
}

var accessToken gceToken
decoder := json.NewDecoder(response.Body)
if err := decoder.Decode(&accessToken); err != nil {
return authConfig, err
return authConfig, nil, err
}

authConfig = authn.AuthConfig{
Username: "oauth2accesstoken",
Password: accessToken.AccessToken,
}
return authConfig, nil

// add expires_in seconds to the current time to get the expiry time
expiresAt := time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)

return authConfig, &expiresAt, nil
}

// Login attempts to get the authentication material for GCR.
// It returns the authentication material and the expiry time of the token.
// The caller can ensure that the passed image is a valid GCR image using ValidHost().
func (c *Client) LoginWithExpiry(ctx context.Context, autoLogin bool, image string, ref name.Reference) (authn.Authenticator, *time.Time, error) {
if autoLogin {
log.FromContext(ctx).Info("logging in to GCP GCR for " + image)
authConfig, expiresAt, err := c.getLoginAuth(ctx)
if err != nil {
log.FromContext(ctx).Info("error logging into GCP " + err.Error())
return nil, nil, err
}

auth := authn.FromConfig(authConfig)
return auth, expiresAt, nil
}
return nil, nil, fmt.Errorf("GCR authentication failed: %w", oci.ErrUnconfiguredProvider)
}

// Login attempts to get the authentication material for GCR. The caller can
// ensure that the passed image is a valid GCR image using ValidHost().
func (c *Client) Login(ctx context.Context, autoLogin bool, image string, ref name.Reference) (authn.Authenticator, error) {
if autoLogin {
log.FromContext(ctx).Info("logging in to GCP GCR for " + image)
authConfig, err := c.getLoginAuth(ctx)
authConfig, _, err := c.getLoginAuth(ctx)
if err != nil {
log.FromContext(ctx).Info("error logging into GCP " + err.Error())
return nil, err
Expand All @@ -118,9 +141,22 @@ func (c *Client) Login(ctx context.Context, autoLogin bool, image string, ref na
return nil, fmt.Errorf("GCR authentication failed: %w", oci.ErrUnconfiguredProvider)
}

// OIDCLoginWithExpiry attempts to get the authentication material for GCR from the token url set in the client.
// It returns the authentication material and the expiry time of the token.
func (c *Client) OIDCLoginWithExpiry(ctx context.Context) (authn.Authenticator, *time.Time, error) {
authConfig, expiresAt, err := c.getLoginAuth(ctx)
if err != nil {
log.FromContext(ctx).Info("error logging into GCP " + err.Error())
return nil, nil, err
}

auth := authn.FromConfig(authConfig)
return auth, expiresAt, nil
}

// OIDCLogin attempts to get the authentication material for GCR from the token url set in the client.
func (c *Client) OIDCLogin(ctx context.Context) (authn.Authenticator, error) {
authConfig, err := c.getLoginAuth(ctx)
authConfig, _, err := c.getLoginAuth(ctx)
if err != nil {
log.FromContext(ctx).Info("error logging into GCP " + err.Error())
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion oci/auth/gcp/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func TestGetLoginAuth(t *testing.T) {
})

gc := NewClient().WithTokenURL(srv.URL)
a, err := gc.getLoginAuth(context.TODO())
a, _, err := gc.getLoginAuth(context.TODO())
g.Expect(err != nil).To(Equal(tt.wantErr))
if tt.statusCode == http.StatusOK {
g.Expect(a).To(Equal(tt.wantAuthConfig))
Expand Down
Loading

0 comments on commit e6b2983

Please sign in to comment.