Skip to content

Commit

Permalink
Use the request context throughout the code, and add retries & timeou…
Browse files Browse the repository at this point in the history
…ts to more Azure auth provider calls (#358)

* Use the request context to cancel processing if request is canceled

Signed-off-by: Mark Drobnak <markdrobnak@microsoft.com>

* Use a retrying HTTP client wherever possible in azure auth provider

Also includes a 3 second timeout per request.

Signed-off-by: Mark Drobnak <markdrobnak@microsoft.com>

* Reformat

Signed-off-by: Mark Drobnak <markdrobnak@microsoft.com>

* Clarify a comment

Signed-off-by: Mark Drobnak <markdrobnak@microsoft.com>

---------

Signed-off-by: Mark Drobnak <markdrobnak@microsoft.com>
  • Loading branch information
AzureMarker committed Mar 16, 2023
1 parent 10c14ce commit c67ed7e
Show file tree
Hide file tree
Showing 30 changed files with 218 additions and 140 deletions.
65 changes: 43 additions & 22 deletions auth/providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/coreos/go-oidc"
"github.com/hashicorp/go-retryablehttp"
"github.com/pkg/errors"
"golang.org/x/oauth2"
authv1 "k8s.io/api/authentication/v1"
"k8s.io/klog/v2"
)
Expand Down Expand Up @@ -70,7 +71,6 @@ type Authenticator struct {
graphClient *graph.UserInfo
verifier *oidc.IDTokenVerifier
popTokenVerifier *PoPTokenVerifier
ctx context.Context
}

type authInfo struct {
Expand All @@ -79,19 +79,19 @@ type authInfo struct {
Issuer string
}

func New(opts Options) (auth.Interface, error) {
func New(ctx context.Context, opts Options) (auth.Interface, error) {
c := &Authenticator{
Options: opts,
ctx: context.Background(),
}
authInfoVal, err := getAuthInfo(c.Environment, c.TenantID, getMetadata)
authInfoVal, err := getAuthInfo(ctx, c.Environment, c.TenantID, getMetadata)
if err != nil {
return nil, err
}

klog.V(3).Infof("Using issuer url: %v", authInfoVal.Issuer)

provider, err := oidc.NewProvider(c.ctx, authInfoVal.Issuer)
ctx = withRetryableHttpClient(ctx)
provider, err := oidc.NewProvider(ctx, authInfoVal.Issuer)
if err != nil {
return nil, errors.Wrap(err, "failed to create provider for azure")
}
Expand All @@ -115,22 +115,16 @@ func New(opts Options) (auth.Interface, error) {
return c, nil
}

type metadataJSON struct {
Issuer string `json:"issuer"`
MsgraphHost string `json:"msgraph_host"`
}

// https://docs.microsoft.com/en-us/azure/active-directory/develop/howto-convert-app-to-be-multi-tenant
func getMetadata(aadEndpoint, tenantID string) (*metadataJSON, error) {
metadataURL := aadEndpoint + tenantID + "/.well-known/openid-configuration"

// makeRetryableHttpClient creates an HTTP client which attempts the request
// 3 times and has a 3 second timeout per attempt.
func makeRetryableHttpClient() retryablehttp.Client {
// Copy the default HTTP client so we can set a timeout.
// (It uses the same transport since the pointer gets copied)
httpClient := *httpclient.DefaultHTTPClient
httpClient.Timeout = 3 * time.Second

// Attempt the request up to 3 times
retryClient := retryablehttp.Client{
return retryablehttp.Client{
HTTPClient: &httpClient,
RetryWaitMin: 500 * time.Millisecond,
RetryWaitMax: 2 * time.Second,
Expand All @@ -139,8 +133,32 @@ func getMetadata(aadEndpoint, tenantID string) (*metadataJSON, error) {
Backoff: retryablehttp.DefaultBackoff,
Logger: log.Default(),
}
}

response, err := retryClient.Get(metadataURL)
// withRetryableHttpClient sets the oauth2.HTTPClient key of the context to an
// *http.Client made from makeRetryableHttpClient.
// Some of the libraries we use will take the client out of the context via
// oauth2.HTTPClient and use it, so this way we can add retries to external code.
func withRetryableHttpClient(ctx context.Context) context.Context {
retryClient := makeRetryableHttpClient()
return context.WithValue(ctx, oauth2.HTTPClient, retryClient.StandardClient())
}

type metadataJSON struct {
Issuer string `json:"issuer"`
MsgraphHost string `json:"msgraph_host"`
}

// https://docs.microsoft.com/en-us/azure/active-directory/develop/howto-convert-app-to-be-multi-tenant
func getMetadata(ctx context.Context, aadEndpoint, tenantID string) (*metadataJSON, error) {
metadataURL := aadEndpoint + tenantID + "/.well-known/openid-configuration"
retryClient := makeRetryableHttpClient()

request, err := retryablehttp.NewRequest("GET", metadataURL, nil)
if err != nil {
return nil, err
}
response, err := retryClient.Do(request.WithContext(ctx))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -168,7 +186,7 @@ func (s Authenticator) UID() string {
return OrgType
}

func (s Authenticator) Check(token string) (*authv1.UserInfo, error) {
func (s Authenticator) Check(ctx context.Context, token string) (*authv1.UserInfo, error) {
var err error

if s.EnablePOP {
Expand All @@ -178,7 +196,8 @@ func (s Authenticator) Check(token string) (*authv1.UserInfo, error) {
}
}

idToken, err := s.verifier.Verify(s.ctx, token)
ctx = withRetryableHttpClient(ctx)
idToken, err := s.verifier.Verify(ctx, token)
if err != nil {
return nil, errors.Wrap(err, "failed to verify token for azure")
}
Expand All @@ -204,10 +223,10 @@ func (s Authenticator) Check(token string) (*authv1.UserInfo, error) {
}
}
if !s.Options.SkipGroupMembershipResolution {
if err := s.graphClient.RefreshToken(token); err != nil {
if err := s.graphClient.RefreshToken(ctx, token); err != nil {
return nil, err
}
resp.Groups, err = s.graphClient.GetGroups(resp.Username)
resp.Groups, err = s.graphClient.GetGroups(ctx, resp.Username)
if err != nil {
return nil, errors.Wrap(err, "failed to get groups")
}
Expand Down Expand Up @@ -343,7 +362,9 @@ func (c claims) string(key string) (string, error) {
return s, nil
}

func getAuthInfo(environment, tenantID string, getMetadata func(string, string) (*metadataJSON, error)) (*authInfo, error) {
type getMetadataFunc = func(context.Context, string, string) (*metadataJSON, error)

func getAuthInfo(ctx context.Context, environment, tenantID string, getMetadata getMetadataFunc) (*authInfo, error) {
var err error
env := azure.PublicCloud
if environment != "" {
Expand All @@ -353,7 +374,7 @@ func getAuthInfo(environment, tenantID string, getMetadata func(string, string)
}
}

metadata, err := getMetadata(env.ActiveDirectoryEndpoint, tenantID)
metadata, err := getMetadata(ctx, env.ActiveDirectoryEndpoint, tenantID)
if err != nil {
return nil, errors.Wrap(err, "failed to get metadata for azure")
}
Expand Down
34 changes: 20 additions & 14 deletions auth/providers/azure/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ func clientSetup(clientID, clientSecret, tenantID, serverUrl string, useGroupUID
AKSTokenURL: "",
VerifyClientID: verifyClientID,
},
ctx: context.Background(),
}

p, err := oidc.NewProvider(c.ctx, serverUrl)
ctx := context.Background()
p, err := oidc.NewProvider(ctx, serverUrl)
if err != nil {
return nil, fmt.Errorf("failed to create provider for azure. Reason: %v", err)
}
Expand Down Expand Up @@ -285,6 +285,7 @@ func getServerAndClient(t *testing.T, signKey *signingKey, loginResp string, gro
}

func TestCheckAzureAuthenticationSuccess(t *testing.T) {
ctx := context.Background()
signKey, err := newRSAKey()
if err != nil {
t.Fatalf("Error when creating signing key. reason : %v", err)
Expand Down Expand Up @@ -319,7 +320,7 @@ func TestCheckAzureAuthenticationSuccess(t *testing.T) {
t.Fatalf("Error when signing token. reason: %v", err)
}

resp, err := client.Check(token)
resp, err := client.Check(ctx, token)
assert.Nil(t, err)
assertUserInfo(t, resp, test.groupSize, client.UseGroupUID)
})
Expand All @@ -336,14 +337,15 @@ func TestCheckAzureAuthenticationSuccess(t *testing.T) {
t.Fatalf("Error when signing token. reason: %v", err)
}

resp, err := client.Check(token)
resp, err := client.Check(ctx, token)
assert.Nil(t, err)
assertUserInfo(t, resp, test.groupSize, client.UseGroupUID)
})
}
}

func TestCheckAzureAuthenticationWithOverageCheckOption(t *testing.T) {
ctx := context.Background()
signKey, err := newRSAKey()
if err != nil {
t.Fatalf("Error when creating signing key. reason : %v", err)
Expand Down Expand Up @@ -373,7 +375,7 @@ func TestCheckAzureAuthenticationWithOverageCheckOption(t *testing.T) {
t.Fatalf("Error when signing token. reason: %v", err)
}

resp, err := client.Check(token)
resp, err := client.Check(ctx, token)
assert.Nil(t, err)
assertUserInfo(t, resp, test.groupSize, client.UseGroupUID)
})
Expand All @@ -382,6 +384,7 @@ func TestCheckAzureAuthenticationWithOverageCheckOption(t *testing.T) {
}

func TestCheckAzureAuthenticationFailed(t *testing.T) {
ctx := context.Background()
signKey, err := newRSAKey()
if err != nil {
t.Fatalf("Error when creating signing key. reason : %v", err)
Expand Down Expand Up @@ -430,7 +433,7 @@ func TestCheckAzureAuthenticationFailed(t *testing.T) {
token = test.token
}

resp, err := client.Check(token)
resp, err := client.Check(ctx, token)
assert.NotNil(t, err)
assert.Nil(t, resp)
})
Expand Down Expand Up @@ -495,32 +498,35 @@ func TestString(t *testing.T) {
}

func TestGetAuthInfo(t *testing.T) {
authInfo, err := getAuthInfo("AzurePublicCloud", "testTenant", localGetMetadata)
ctx := context.Background()
authInfo, err := getAuthInfo(ctx, "AzurePublicCloud", "testTenant", localGetMetadata)
assert.NoError(t, err)
assert.Contains(t, authInfo.AADEndpoint, "login.microsoftonline.com")

authInfo, err = getAuthInfo("AzureChinaCloud", "testTenant", localGetMetadata)
authInfo, err = getAuthInfo(ctx, "AzureChinaCloud", "testTenant", localGetMetadata)
assert.NoError(t, err)
assert.Contains(t, authInfo.AADEndpoint, "login.chinacloudapi.cn")
}

func localGetMetadata(string, string) (*metadataJSON, error) {
func localGetMetadata(context.Context, string, string) (*metadataJSON, error) {
return &metadataJSON{
Issuer: "testIssuer",
MsgraphHost: "testHost",
}, nil
}

func TestGetMetadata(t *testing.T) {
ctx := context.Background()

t.Run("sends request to AAD server and parses response", func(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(200)
_, _ = writer.Write([]byte(`{"issuer":"testIssuer","msgraph_host":"testHost"}`))
}))
defer testServer.Close()
expectedMetadata, _ := localGetMetadata("", "")
expectedMetadata, _ := localGetMetadata(ctx, "", "")

metadata, err := getMetadata(testServer.URL+"/", "testTenant")
metadata, err := getMetadata(ctx, testServer.URL+"/", "testTenant")
assert.NoError(t, err)
assert.Equal(t, expectedMetadata, metadata)
})
Expand All @@ -547,9 +553,9 @@ func TestGetMetadata(t *testing.T) {
}
}))
defer testServer.Close()
expectedMetadata, _ := localGetMetadata("", "")
expectedMetadata, _ := localGetMetadata(ctx, "", "")

metadata, err := getMetadata(testServer.URL+"/", "testTenant")
metadata, err := getMetadata(ctx, testServer.URL+"/", "testTenant")
assert.NoError(t, err)
assert.Equal(t, expectedMetadata, metadata)
})
Expand All @@ -560,7 +566,7 @@ func TestGetMetadata(t *testing.T) {
testServer.CloseClientConnections()
}))

metadata, err := getMetadata(testServer.URL+"/", "testTenant")
metadata, err := getMetadata(ctx, testServer.URL+"/", "testTenant")
assert.Error(t, err)
assert.Nil(t, metadata)
})
Expand Down
5 changes: 3 additions & 2 deletions auth/providers/azure/graph/aks_tokenprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package graph

import (
"bytes"
"context"
"io"
"net/http"

Expand Down Expand Up @@ -45,7 +46,7 @@ func NewAKSTokenProvider(tokenURL, tenantID string) TokenProvider {

func (u *aksTokenProvider) Name() string { return u.name }

func (u *aksTokenProvider) Acquire(token string) (AuthResponse, error) {
func (u *aksTokenProvider) Acquire(ctx context.Context, token string) (AuthResponse, error) {
authResp := AuthResponse{}
tokenReq := struct {
TenantID string `json:"tenantID,omitempty"`
Expand All @@ -65,7 +66,7 @@ func (u *aksTokenProvider) Acquire(token string) (AuthResponse, error) {
}
req.Header.Set("Content-Type", "application/json")

resp, err := u.client.Do(req)
resp, err := u.client.Do(req.WithContext(ctx))
if err != nil {
return authResp, errors.Wrap(err, "failed to send request")
}
Expand Down
7 changes: 5 additions & 2 deletions auth/providers/azure/graph/aks_tokenprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package graph

import (
"context"
"fmt"
"net/http"
"testing"
Expand Down Expand Up @@ -59,8 +60,9 @@ func TestAKSTokenProvider(t *testing.T) {

defer stopTestServer(t, s)

ctx := context.Background()
r := NewAKSTokenProvider(s.URL, tenantID)
resp, err := r.Acquire(inputAccessToken)
resp, err := r.Acquire(ctx, inputAccessToken)
if err != nil {
t.Fatalf("refresh should not return error: %s", err)
}
Expand Down Expand Up @@ -102,8 +104,9 @@ func TestAKSTokenProvider(t *testing.T) {

defer stopTestServer(t, s)

ctx := context.Background()
r := NewAKSTokenProvider(s.URL, tenantID)
resp, err := r.Acquire(inputAccessToken)
resp, err := r.Acquire(ctx, inputAccessToken)
if err == nil {
t.Error("refresh should return error")
}
Expand Down
5 changes: 3 additions & 2 deletions auth/providers/azure/graph/clientcredential_tokenprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package graph

import (
"context"
"io"
"net/http"
"net/url"
Expand Down Expand Up @@ -53,7 +54,7 @@ func NewClientCredentialTokenProvider(clientID, clientSecret, loginURL, scope st

func (u *clientCredentialTokenProvider) Name() string { return u.name }

func (u *clientCredentialTokenProvider) Acquire(token string) (AuthResponse, error) {
func (u *clientCredentialTokenProvider) Acquire(ctx context.Context, token string) (AuthResponse, error) {
authResp := AuthResponse{}
form := url.Values{}
form.Set("client_id", u.clientID)
Expand All @@ -71,7 +72,7 @@ func (u *clientCredentialTokenProvider) Acquire(token string) (AuthResponse, err
klog.V(10).Infoln(cmd)
}

resp, err := u.client.Do(req)
resp, err := u.client.Do(req.WithContext(ctx))
if err != nil {
return authResp, errors.Wrap(err, "fail to send request")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package graph

import (
"context"
"fmt"
"net/http"
"testing"
Expand Down Expand Up @@ -60,8 +61,9 @@ func TestClientCredentialTokenProvider(t *testing.T) {

defer stopTestServer(t, s)

ctx := context.Background()
r := NewClientCredentialTokenProvider(clientID, clientSecret, s.URL, scope)
resp, err := r.Acquire(inputAccessToken)
resp, err := r.Acquire(ctx, inputAccessToken)
if err != nil {
t.Fatalf("refresh should not return error: %s", err)
}
Expand Down Expand Up @@ -101,8 +103,9 @@ func TestClientCredentialTokenProvider(t *testing.T) {

defer stopTestServer(t, s)

ctx := context.Background()
r := NewClientCredentialTokenProvider(clientID, clientSecret, s.URL, scope)
resp, err := r.Acquire(inputAccessToken)
resp, err := r.Acquire(ctx, inputAccessToken)
if err == nil {
t.Error("refresh should return error")
}
Expand Down
Loading

0 comments on commit c67ed7e

Please sign in to comment.