Skip to content

Commit

Permalink
add a flag in azure auth module to omit spn: prefix in audience claim
Browse files Browse the repository at this point in the history
  • Loading branch information
weinong committed Feb 5, 2020
1 parent 9f6f608 commit c08db5f
Show file tree
Hide file tree
Showing 4 changed files with 277 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ go_test(
name = "go_default_test",
srcs = ["azure_test.go"],
embed = [":go_default_library"],
deps = ["//vendor/github.com/Azure/go-autorest/autorest/adal:go_default_library"],
deps = [
"//vendor/github.com/Azure/go-autorest/autorest/adal:go_default_library",
"//vendor/github.com/Azure/go-autorest/autorest/azure:go_default_library",
],
)

go_library(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@ This plugin provides an integration with Azure Active Directory device flow. If
* Replace `APISERVER_APPLICATION_ID` with the application ID of your `apiserver` application ID
* Be sure to also (create and) select a context that uses above user

6. The access token is acquired when first `kubectl` command is executed
6. (Optionally) the AAD token has `aud` claim with `spn:` prefix. To omit that, add following auth configuration:

```
--auth-provider-arg=config-mode="1"
```

7. The access token is acquired when first `kubectl` command is executed

```
kubectl get pods
Expand Down
108 changes: 81 additions & 27 deletions staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"net/http"
"os"
"strconv"
"sync"

"github.com/Azure/go-autorest/autorest"
Expand All @@ -33,6 +34,8 @@ import (
restclient "k8s.io/client-go/rest"
)

type configMode int

const (
azureTokenKey = "azureTokenKey"
tokenType = "Bearer"
Expand All @@ -46,6 +49,10 @@ const (
cfgExpiresOn = "expires-on"
cfgEnvironment = "environment"
cfgApiserverID = "apiserver-id"
cfgConfigMode = "config-mode"

configModeDefault configMode = 0
configModeOmitSPNPrefix configMode = 1
)

func init() {
Expand Down Expand Up @@ -78,17 +85,37 @@ func (c *azureTokenCache) setToken(tokenKey string, token *azureToken) {
}

func newAzureAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) {
var ts tokenSource

environment, err := azure.EnvironmentFromName(cfg[cfgEnvironment])
var (
ts tokenSource
environment azure.Environment
err error
mode configMode
)

environment, err = azure.EnvironmentFromName(cfg[cfgEnvironment])
if err != nil {
environment = azure.PublicCloud
}
ts, err = newAzureTokenSourceDeviceCode(environment, cfg[cfgClientID], cfg[cfgTenantID], cfg[cfgApiserverID])

mode = configModeDefault
if cfg[cfgConfigMode] != "" {
configModeInt, err := strconv.Atoi(cfg[cfgConfigMode])
if err != nil {
return nil, fmt.Errorf("failed to parse %s, error: %s", cfgConfigMode, err)
}
mode = configMode(configModeInt)
switch mode {
case configModeOmitSPNPrefix:
case configModeDefault:
default:
return nil, fmt.Errorf("%s:%s is not a valid mode", cfgConfigMode, cfg[cfgConfigMode])
}
}
ts, err = newAzureTokenSourceDeviceCode(environment, cfg[cfgClientID], cfg[cfgTenantID], cfg[cfgApiserverID], mode)
if err != nil {
return nil, fmt.Errorf("creating a new azure token source for device code authentication: %v", err)
}
cacheSource := newAzureTokenSource(ts, cache, cfg, persister)
cacheSource := newAzureTokenSource(ts, cache, cfg, mode, persister)

return &azureAuthProvider{
tokenSource: cacheSource,
Expand Down Expand Up @@ -156,19 +183,21 @@ type tokenSource interface {
}

type azureTokenSource struct {
source tokenSource
cache *azureTokenCache
lock sync.Mutex
cfg map[string]string
persister restclient.AuthProviderConfigPersister
source tokenSource
cache *azureTokenCache
lock sync.Mutex
configMode configMode
cfg map[string]string
persister restclient.AuthProviderConfigPersister
}

func newAzureTokenSource(source tokenSource, cache *azureTokenCache, cfg map[string]string, persister restclient.AuthProviderConfigPersister) tokenSource {
func newAzureTokenSource(source tokenSource, cache *azureTokenCache, cfg map[string]string, configMode configMode, persister restclient.AuthProviderConfigPersister) tokenSource {
return &azureTokenSource{
source: source,
cache: cache,
cfg: cfg,
persister: persister,
source: source,
cache: cache,
cfg: cfg,
persister: persister,
configMode: configMode,
}
}

Expand Down Expand Up @@ -232,9 +261,9 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) {
if tenantID == "" {
return nil, fmt.Errorf("no tenant ID in cfg: %s", cfgTenantID)
}
apiserverID := ts.cfg[cfgApiserverID]
if apiserverID == "" {
return nil, fmt.Errorf("no apiserver ID in cfg: %s", apiserverID)
resourceID := ts.cfg[cfgApiserverID]
if resourceID == "" {
return nil, fmt.Errorf("no apiserver ID in cfg: %s", cfgApiserverID)
}
expiresIn := ts.cfg[cfgExpiresIn]
if expiresIn == "" {
Expand All @@ -244,6 +273,9 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) {
if expiresOn == "" {
return nil, fmt.Errorf("no expiresOn in cfg: %s", cfgExpiresOn)
}
if ts.configMode == configModeDefault {
resourceID = fmt.Sprintf("spn:%s", resourceID)
}

return &azureToken{
token: adal.Token{
Expand All @@ -252,13 +284,13 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) {
ExpiresIn: json.Number(expiresIn),
ExpiresOn: json.Number(expiresOn),
NotBefore: json.Number(expiresOn),
Resource: fmt.Sprintf("spn:%s", apiserverID),
Resource: resourceID,
Type: tokenType,
},
environment: environment,
clientID: clientID,
tenantID: tenantID,
apiserverID: apiserverID,
apiserverID: resourceID,
}, nil
}

Expand All @@ -272,6 +304,7 @@ func (ts *azureTokenSource) storeTokenInCfg(token *azureToken) error {
newCfg[cfgApiserverID] = token.apiserverID
newCfg[cfgExpiresIn] = string(token.token.ExpiresIn)
newCfg[cfgExpiresOn] = string(token.token.ExpiresOn)
newCfg[cfgConfigMode] = strconv.Itoa(int(ts.configMode))

err := ts.persister.Persist(newCfg)
if err != nil {
Expand All @@ -287,9 +320,17 @@ func (ts *azureTokenSource) refreshToken(token *azureToken) (*azureToken, error)
return nil, err
}

oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, token.tenantID)
if err != nil {
return nil, fmt.Errorf("building the OAuth configuration for token refresh: %v", err)
var oauthConfig *adal.OAuthConfig
if ts.configMode == configModeOmitSPNPrefix {
oauthConfig, err = adal.NewOAuthConfigWithAPIVersion(env.ActiveDirectoryEndpoint, token.tenantID, nil)
if err != nil {
return nil, fmt.Errorf("building the OAuth configuration without api-version for token refresh: %v", err)
}
} else {
oauthConfig, err = adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, token.tenantID)
if err != nil {
return nil, fmt.Errorf("building the OAuth configuration for token refresh: %v", err)
}
}

callback := func(t adal.Token) error {
Expand Down Expand Up @@ -323,9 +364,10 @@ type azureTokenSourceDeviceCode struct {
clientID string
tenantID string
apiserverID string
configMode configMode
}

func newAzureTokenSourceDeviceCode(environment azure.Environment, clientID string, tenantID string, apiserverID string) (tokenSource, error) {
func newAzureTokenSourceDeviceCode(environment azure.Environment, clientID string, tenantID string, apiserverID string, configMode configMode) (tokenSource, error) {
if clientID == "" {
return nil, errors.New("client-id is empty")
}
Expand All @@ -340,13 +382,25 @@ func newAzureTokenSourceDeviceCode(environment azure.Environment, clientID strin
clientID: clientID,
tenantID: tenantID,
apiserverID: apiserverID,
configMode: configMode,
}, nil
}

func (ts *azureTokenSourceDeviceCode) Token() (*azureToken, error) {
oauthConfig, err := adal.NewOAuthConfig(ts.environment.ActiveDirectoryEndpoint, ts.tenantID)
if err != nil {
return nil, fmt.Errorf("building the OAuth configuration for device code authentication: %v", err)
var (
oauthConfig *adal.OAuthConfig
err error
)
if ts.configMode == configModeOmitSPNPrefix {
oauthConfig, err = adal.NewOAuthConfigWithAPIVersion(ts.environment.ActiveDirectoryEndpoint, ts.tenantID, nil)
if err != nil {
return nil, fmt.Errorf("building the OAuth configuration without api-version for device code authentication: %v", err)
}
} else {
oauthConfig, err = adal.NewOAuthConfig(ts.environment.ActiveDirectoryEndpoint, ts.tenantID)
if err != nil {
return nil, fmt.Errorf("building the OAuth configuration for device code authentication: %v", err)
}
}
client := &autorest.Client{}
deviceCode, err := adal.InitiateDeviceAuth(client, *oauthConfig, ts.clientID, ts.apiserverID)
Expand Down
Loading

0 comments on commit c08db5f

Please sign in to comment.