Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the request context throughout the code, and add retries & timeouts to more Azure auth provider calls #358

Merged
merged 7 commits into from
Mar 16, 2023
65 changes: 43 additions & 22 deletions auth/providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/coreos/go-oidc"
"github.com/hashicorp/go-retryablehttp"
"github.com/pkg/errors"
"golang.org/x/oauth2"
authv1 "k8s.io/api/authentication/v1"
"k8s.io/klog/v2"
)
Expand Down Expand Up @@ -70,7 +71,6 @@ type Authenticator struct {
graphClient *graph.UserInfo
verifier *oidc.IDTokenVerifier
popTokenVerifier *PoPTokenVerifier
ctx context.Context
}

type authInfo struct {
Expand All @@ -79,19 +79,19 @@ type authInfo struct {
Issuer string
}

func New(opts Options) (auth.Interface, error) {
func New(ctx context.Context, opts Options) (auth.Interface, error) {
c := &Authenticator{
Options: opts,
ctx: context.Background(),
}
authInfoVal, err := getAuthInfo(c.Environment, c.TenantID, getMetadata)
authInfoVal, err := getAuthInfo(ctx, c.Environment, c.TenantID, getMetadata)
if err != nil {
return nil, err
}

klog.V(3).Infof("Using issuer url: %v", authInfoVal.Issuer)

provider, err := oidc.NewProvider(c.ctx, authInfoVal.Issuer)
ctx = withRetryableHttpClient(ctx)
provider, err := oidc.NewProvider(ctx, authInfoVal.Issuer)
if err != nil {
return nil, errors.Wrap(err, "failed to create provider for azure")
}
Expand All @@ -115,22 +115,16 @@ func New(opts Options) (auth.Interface, error) {
return c, nil
}

type metadataJSON struct {
Issuer string `json:"issuer"`
MsgraphHost string `json:"msgraph_host"`
}

// https://docs.microsoft.com/en-us/azure/active-directory/develop/howto-convert-app-to-be-multi-tenant
func getMetadata(aadEndpoint, tenantID string) (*metadataJSON, error) {
metadataURL := aadEndpoint + tenantID + "/.well-known/openid-configuration"

// makeRetryableHttpClient creates an HTTP client which attempts the request
// 3 times and has a 3 second timeout per attempt.
func makeRetryableHttpClient() retryablehttp.Client {
// Copy the default HTTP client so we can set a timeout.
// (It uses the same transport since the pointer gets copied)
httpClient := *httpclient.DefaultHTTPClient
httpClient.Timeout = 3 * time.Second

// Attempt the request up to 3 times
retryClient := retryablehttp.Client{
return retryablehttp.Client{
HTTPClient: &httpClient,
RetryWaitMin: 500 * time.Millisecond,
RetryWaitMax: 2 * time.Second,
Expand All @@ -139,8 +133,32 @@ func getMetadata(aadEndpoint, tenantID string) (*metadataJSON, error) {
Backoff: retryablehttp.DefaultBackoff,
AzureMarker marked this conversation as resolved.
Show resolved Hide resolved
Logger: log.Default(),
}
}

response, err := retryClient.Get(metadataURL)
// withRetryableHttpClient sets the oauth2.HTTPClient key of the context to an
// *http.Client made from makeRetryableHttpClient.
// Some of the libraries we use will take the client out of the context via
// oauth2.HTTPClient and use it, so this way we can add retries to external code.
func withRetryableHttpClient(ctx context.Context) context.Context {
retryClient := makeRetryableHttpClient()
return context.WithValue(ctx, oauth2.HTTPClient, retryClient.StandardClient())
AzureMarker marked this conversation as resolved.
Show resolved Hide resolved
}

type metadataJSON struct {
Issuer string `json:"issuer"`
MsgraphHost string `json:"msgraph_host"`
}

// https://docs.microsoft.com/en-us/azure/active-directory/develop/howto-convert-app-to-be-multi-tenant
func getMetadata(ctx context.Context, aadEndpoint, tenantID string) (*metadataJSON, error) {
metadataURL := aadEndpoint + tenantID + "/.well-known/openid-configuration"
retryClient := makeRetryableHttpClient()

request, err := retryablehttp.NewRequest("GET", metadataURL, nil)
if err != nil {
return nil, err
}
response, err := retryClient.Do(request.WithContext(ctx))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -168,7 +186,7 @@ func (s Authenticator) UID() string {
return OrgType
}

func (s Authenticator) Check(token string) (*authv1.UserInfo, error) {
func (s Authenticator) Check(ctx context.Context, token string) (*authv1.UserInfo, error) {
var err error

if s.EnablePOP {
Expand All @@ -178,7 +196,8 @@ func (s Authenticator) Check(token string) (*authv1.UserInfo, error) {
}
}

idToken, err := s.verifier.Verify(s.ctx, token)
ctx = withRetryableHttpClient(ctx)
idToken, err := s.verifier.Verify(ctx, token)
if err != nil {
return nil, errors.Wrap(err, "failed to verify token for azure")
}
Expand All @@ -204,10 +223,10 @@ func (s Authenticator) Check(token string) (*authv1.UserInfo, error) {
}
}
if !s.Options.SkipGroupMembershipResolution {
if err := s.graphClient.RefreshToken(token); err != nil {
if err := s.graphClient.RefreshToken(ctx, token); err != nil {
return nil, err
}
resp.Groups, err = s.graphClient.GetGroups(resp.Username)
resp.Groups, err = s.graphClient.GetGroups(ctx, resp.Username)
if err != nil {
return nil, errors.Wrap(err, "failed to get groups")
}
Expand Down Expand Up @@ -343,7 +362,9 @@ func (c claims) string(key string) (string, error) {
return s, nil
}

func getAuthInfo(environment, tenantID string, getMetadata func(string, string) (*metadataJSON, error)) (*authInfo, error) {
type getMetadataFunc = func(context.Context, string, string) (*metadataJSON, error)

func getAuthInfo(ctx context.Context, environment, tenantID string, getMetadata getMetadataFunc) (*authInfo, error) {
var err error
env := azure.PublicCloud
if environment != "" {
Expand All @@ -353,7 +374,7 @@ func getAuthInfo(environment, tenantID string, getMetadata func(string, string)
}
}

metadata, err := getMetadata(env.ActiveDirectoryEndpoint, tenantID)
metadata, err := getMetadata(ctx, env.ActiveDirectoryEndpoint, tenantID)
if err != nil {
return nil, errors.Wrap(err, "failed to get metadata for azure")
}
Expand Down
34 changes: 20 additions & 14 deletions auth/providers/azure/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ func clientSetup(clientID, clientSecret, tenantID, serverUrl string, useGroupUID
AKSTokenURL: "",
VerifyClientID: verifyClientID,
},
ctx: context.Background(),
}

p, err := oidc.NewProvider(c.ctx, serverUrl)
ctx := context.Background()
p, err := oidc.NewProvider(ctx, serverUrl)
if err != nil {
return nil, fmt.Errorf("failed to create provider for azure. Reason: %v", err)
}
Expand Down Expand Up @@ -285,6 +285,7 @@ func getServerAndClient(t *testing.T, signKey *signingKey, loginResp string, gro
}

func TestCheckAzureAuthenticationSuccess(t *testing.T) {
ctx := context.Background()
signKey, err := newRSAKey()
if err != nil {
t.Fatalf("Error when creating signing key. reason : %v", err)
Expand Down Expand Up @@ -319,7 +320,7 @@ func TestCheckAzureAuthenticationSuccess(t *testing.T) {
t.Fatalf("Error when signing token. reason: %v", err)
}

resp, err := client.Check(token)
resp, err := client.Check(ctx, token)
assert.Nil(t, err)
assertUserInfo(t, resp, test.groupSize, client.UseGroupUID)
})
Expand All @@ -336,14 +337,15 @@ func TestCheckAzureAuthenticationSuccess(t *testing.T) {
t.Fatalf("Error when signing token. reason: %v", err)
}

resp, err := client.Check(token)
resp, err := client.Check(ctx, token)
assert.Nil(t, err)
assertUserInfo(t, resp, test.groupSize, client.UseGroupUID)
})
}
}

func TestCheckAzureAuthenticationWithOverageCheckOption(t *testing.T) {
ctx := context.Background()
signKey, err := newRSAKey()
if err != nil {
t.Fatalf("Error when creating signing key. reason : %v", err)
Expand Down Expand Up @@ -373,7 +375,7 @@ func TestCheckAzureAuthenticationWithOverageCheckOption(t *testing.T) {
t.Fatalf("Error when signing token. reason: %v", err)
}

resp, err := client.Check(token)
resp, err := client.Check(ctx, token)
assert.Nil(t, err)
assertUserInfo(t, resp, test.groupSize, client.UseGroupUID)
})
Expand All @@ -382,6 +384,7 @@ func TestCheckAzureAuthenticationWithOverageCheckOption(t *testing.T) {
}

func TestCheckAzureAuthenticationFailed(t *testing.T) {
ctx := context.Background()
signKey, err := newRSAKey()
if err != nil {
t.Fatalf("Error when creating signing key. reason : %v", err)
Expand Down Expand Up @@ -430,7 +433,7 @@ func TestCheckAzureAuthenticationFailed(t *testing.T) {
token = test.token
}

resp, err := client.Check(token)
resp, err := client.Check(ctx, token)
assert.NotNil(t, err)
assert.Nil(t, resp)
})
Expand Down Expand Up @@ -495,32 +498,35 @@ func TestString(t *testing.T) {
}

func TestGetAuthInfo(t *testing.T) {
authInfo, err := getAuthInfo("AzurePublicCloud", "testTenant", localGetMetadata)
ctx := context.Background()
authInfo, err := getAuthInfo(ctx, "AzurePublicCloud", "testTenant", localGetMetadata)
assert.NoError(t, err)
assert.Contains(t, authInfo.AADEndpoint, "login.microsoftonline.com")

authInfo, err = getAuthInfo("AzureChinaCloud", "testTenant", localGetMetadata)
authInfo, err = getAuthInfo(ctx, "AzureChinaCloud", "testTenant", localGetMetadata)
assert.NoError(t, err)
assert.Contains(t, authInfo.AADEndpoint, "login.chinacloudapi.cn")
}

func localGetMetadata(string, string) (*metadataJSON, error) {
func localGetMetadata(context.Context, string, string) (*metadataJSON, error) {
return &metadataJSON{
Issuer: "testIssuer",
MsgraphHost: "testHost",
}, nil
}

func TestGetMetadata(t *testing.T) {
ctx := context.Background()

t.Run("sends request to AAD server and parses response", func(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(200)
_, _ = writer.Write([]byte(`{"issuer":"testIssuer","msgraph_host":"testHost"}`))
}))
defer testServer.Close()
expectedMetadata, _ := localGetMetadata("", "")
expectedMetadata, _ := localGetMetadata(ctx, "", "")

metadata, err := getMetadata(testServer.URL+"/", "testTenant")
metadata, err := getMetadata(ctx, testServer.URL+"/", "testTenant")
assert.NoError(t, err)
assert.Equal(t, expectedMetadata, metadata)
})
Expand All @@ -547,9 +553,9 @@ func TestGetMetadata(t *testing.T) {
}
}))
defer testServer.Close()
expectedMetadata, _ := localGetMetadata("", "")
expectedMetadata, _ := localGetMetadata(ctx, "", "")

metadata, err := getMetadata(testServer.URL+"/", "testTenant")
metadata, err := getMetadata(ctx, testServer.URL+"/", "testTenant")
assert.NoError(t, err)
assert.Equal(t, expectedMetadata, metadata)
})
Expand All @@ -560,7 +566,7 @@ func TestGetMetadata(t *testing.T) {
testServer.CloseClientConnections()
}))

metadata, err := getMetadata(testServer.URL+"/", "testTenant")
metadata, err := getMetadata(ctx, testServer.URL+"/", "testTenant")
assert.Error(t, err)
assert.Nil(t, metadata)
})
Expand Down
5 changes: 3 additions & 2 deletions auth/providers/azure/graph/aks_tokenprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package graph

import (
"bytes"
"context"
"io"
"net/http"

Expand Down Expand Up @@ -45,7 +46,7 @@ func NewAKSTokenProvider(tokenURL, tenantID string) TokenProvider {

func (u *aksTokenProvider) Name() string { return u.name }

func (u *aksTokenProvider) Acquire(token string) (AuthResponse, error) {
func (u *aksTokenProvider) Acquire(ctx context.Context, token string) (AuthResponse, error) {
authResp := AuthResponse{}
tokenReq := struct {
TenantID string `json:"tenantID,omitempty"`
Expand All @@ -65,7 +66,7 @@ func (u *aksTokenProvider) Acquire(token string) (AuthResponse, error) {
}
req.Header.Set("Content-Type", "application/json")

resp, err := u.client.Do(req)
resp, err := u.client.Do(req.WithContext(ctx))
if err != nil {
return authResp, errors.Wrap(err, "failed to send request")
}
Expand Down
7 changes: 5 additions & 2 deletions auth/providers/azure/graph/aks_tokenprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package graph

import (
"context"
"fmt"
"net/http"
"testing"
Expand Down Expand Up @@ -59,8 +60,9 @@ func TestAKSTokenProvider(t *testing.T) {

defer stopTestServer(t, s)

ctx := context.Background()
r := NewAKSTokenProvider(s.URL, tenantID)
resp, err := r.Acquire(inputAccessToken)
resp, err := r.Acquire(ctx, inputAccessToken)
if err != nil {
t.Fatalf("refresh should not return error: %s", err)
}
Expand Down Expand Up @@ -102,8 +104,9 @@ func TestAKSTokenProvider(t *testing.T) {

defer stopTestServer(t, s)

ctx := context.Background()
r := NewAKSTokenProvider(s.URL, tenantID)
resp, err := r.Acquire(inputAccessToken)
resp, err := r.Acquire(ctx, inputAccessToken)
if err == nil {
t.Error("refresh should return error")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package graph

import (
"context"
"io"
"net/http"
"net/url"
Expand Down Expand Up @@ -53,7 +54,7 @@ func NewClientCredentialTokenProvider(clientID, clientSecret, loginURL, scope st

func (u *clientCredentialTokenProvider) Name() string { return u.name }

func (u *clientCredentialTokenProvider) Acquire(token string) (AuthResponse, error) {
func (u *clientCredentialTokenProvider) Acquire(ctx context.Context, token string) (AuthResponse, error) {
authResp := AuthResponse{}
form := url.Values{}
form.Set("client_id", u.clientID)
Expand All @@ -71,7 +72,7 @@ func (u *clientCredentialTokenProvider) Acquire(token string) (AuthResponse, err
klog.V(10).Infoln(cmd)
}

resp, err := u.client.Do(req)
resp, err := u.client.Do(req.WithContext(ctx))
if err != nil {
return authResp, errors.Wrap(err, "fail to send request")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package graph

import (
"context"
"fmt"
"net/http"
"testing"
Expand Down Expand Up @@ -60,8 +61,9 @@ func TestClientCredentialTokenProvider(t *testing.T) {

defer stopTestServer(t, s)

ctx := context.Background()
r := NewClientCredentialTokenProvider(clientID, clientSecret, s.URL, scope)
resp, err := r.Acquire(inputAccessToken)
resp, err := r.Acquire(ctx, inputAccessToken)
if err != nil {
t.Fatalf("refresh should not return error: %s", err)
}
Expand Down Expand Up @@ -101,8 +103,9 @@ func TestClientCredentialTokenProvider(t *testing.T) {

defer stopTestServer(t, s)

ctx := context.Background()
r := NewClientCredentialTokenProvider(clientID, clientSecret, s.URL, scope)
resp, err := r.Acquire(inputAccessToken)
resp, err := r.Acquire(ctx, inputAccessToken)
if err == nil {
t.Error("refresh should return error")
}
Expand Down
Loading