Skip to content

Commit

Permalink
Merge pull request #3378 from cvvz/support-workload-identity
Browse files Browse the repository at this point in the history
feat: support workload identity
  • Loading branch information
k8s-ci-robot committed Apr 23, 2023
2 parents 8f56e37 + c62d043 commit 8d912fa
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 0 deletions.
12 changes: 12 additions & 0 deletions pkg/provider/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,18 @@ func ParseConfig(configReader io.Reader) (*Config, error) {
// The resource group name may be in different cases from different Azure APIs, hence it is converted to lower here.
// See more context at https://github.com/kubernetes/kubernetes/issues/71994.
config.ResourceGroup = strings.ToLower(config.ResourceGroup)

// these environment variables are injected by workload identity webhook
if tenantID := os.Getenv("AZURE_TENANT_ID"); tenantID != "" {
config.TenantID = tenantID
}
if clientID := os.Getenv("AZURE_CLIENT_ID"); clientID != "" {
config.AADClientID = clientID
}
if federatedTokenFile := os.Getenv("AZURE_FEDERATED_TOKEN_FILE"); federatedTokenFile != "" {
config.AADFederatedTokenFile = federatedTokenFile
config.UseFederatedWorkloadIdentityExtension = true
}
return &config, nil
}

Expand Down
15 changes: 15 additions & 0 deletions pkg/provider/azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"math"
"net/http"
"os"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -2214,6 +2215,10 @@ func TestNewCloudFromJSON(t *testing.T) {
"vmType": "vmss",
"disableAvailabilitySetNodes": true
}`
os.Setenv("AZURE_FEDERATED_TOKEN_FILE", "--aad-federated-token-file--")
defer func() {
os.Unsetenv("AZURE_FEDERATED_TOKEN_FILE")
}()
validateConfig(t, config)
}

Expand Down Expand Up @@ -2277,6 +2282,10 @@ plsCacheTTLInSeconds: 100
vmType: vmss
disableAvailabilitySetNodes: true
`
os.Setenv("AZURE_FEDERATED_TOKEN_FILE", "--aad-federated-token-file--")
defer func() {
os.Unsetenv("AZURE_FEDERATED_TOKEN_FILE")
}()
validateConfig(t, config)
}

Expand Down Expand Up @@ -2391,6 +2400,12 @@ func validateConfig(t *testing.T, config string) { //nolint
if !azureCloud.DisableAvailabilitySetNodes {
t.Errorf("got incorrect value for disableAvailabilitySetNodes")
}
if azureCloud.AADFederatedTokenFile != "--aad-federated-token-file--" {
t.Errorf("got incorrect value for AADFederatedTokenFile")
}
if !azureCloud.UseFederatedWorkloadIdentityExtension {
t.Errorf("got incorrect value for UseFederatedWorkloadIdentityExtension")
}
}

func getCloudFromConfig(t *testing.T, config string) *Cloud {
Expand Down
26 changes: 26 additions & 0 deletions pkg/provider/config/azure_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ type AzureAuthConfig struct {
NetworkResourceTenantID string `json:"networkResourceTenantID,omitempty" yaml:"networkResourceTenantID,omitempty"`
// The ID of the Azure Subscription that the network resources are deployed in
NetworkResourceSubscriptionID string `json:"networkResourceSubscriptionID,omitempty" yaml:"networkResourceSubscriptionID,omitempty"`
// The AAD federated token file
AADFederatedTokenFile string `json:"aadFederatedTokenFile,omitempty" yaml:"aadFederatedTokenFile,omitempty"`
// Use workload identity federation for the virtual machine to access Azure ARM APIs
UseFederatedWorkloadIdentityExtension bool `json:"useFederatedWorkloadIdentityExtension,omitempty" yaml:"useFederatedWorkloadIdentityExtension,omitempty"`
}

// GetServicePrincipalToken creates a new service principal token based on the configuration.
Expand All @@ -100,6 +104,28 @@ func GetServicePrincipalToken(config *AzureAuthConfig, env *azure.Environment, r
resource = env.ServiceManagementEndpoint
}

if config.UseFederatedWorkloadIdentityExtension {
klog.V(2).Infoln("azure: using workload identity extension to retrieve access token")
oauthConfig, err := adal.NewOAuthConfigWithAPIVersion(env.ActiveDirectoryEndpoint, config.TenantID, nil)
if err != nil {
return nil, fmt.Errorf("failed to create the OAuth config: %w", err)
}

jwtCallback := func() (string, error) {
jwt, err := os.ReadFile(config.AADFederatedTokenFile)
if err != nil {
return "", fmt.Errorf("failed to read a file with a federated token: %w", err)
}
return string(jwt), nil
}

token, err := adal.NewServicePrincipalTokenFromFederatedTokenCallback(*oauthConfig, config.AADClientID, jwtCallback, env.ResourceManagerEndpoint)
if err != nil {
return nil, fmt.Errorf("failed to create a workload identity token: %w", err)
}
return token, nil
}

if config.UseManagedIdentityExtension {
klog.V(2).Infoln("azure: using managed identity extension to retrieve access token")
msiEndpoint, err := adal.GetMSIVMEndpoint()
Expand Down
33 changes: 33 additions & 0 deletions pkg/provider/config/azure_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package config

import (
"fmt"
"net/http"
"net/http/httptest"
"os"
Expand Down Expand Up @@ -205,6 +206,38 @@ func TestGetServicePrincipalTokenFromMSI(t *testing.T) {

}

func TestGetServicePrincipalTokenFromWorkloadIdentity(t *testing.T) {
config := &AzureAuthConfig{
TenantID: "TenantID",
AADClientID: "AADClientID",
AADFederatedTokenFile: "/tmp/federated-token",
UseFederatedWorkloadIdentityExtension: true,
}
env := &azure.PublicCloud

token, err := GetServicePrincipalToken(config, env, "")
assert.NoError(t, err)
marshalToken, _ := token.MarshalJSON()

oauthConfig, err := adal.NewOAuthConfigWithAPIVersion(env.ActiveDirectoryEndpoint, config.TenantID, nil)
assert.NoError(t, err)

jwtCallback := func() (string, error) {
jwt, err := os.ReadFile(config.AADFederatedTokenFile)
if err != nil {
return "", fmt.Errorf("failed to read a file with a federated token: %w", err)
}
return string(jwt), nil
}

spt, err := adal.NewServicePrincipalTokenFromFederatedTokenCallback(*oauthConfig, config.AADClientID, jwtCallback, env.ResourceManagerEndpoint)
assert.NoError(t, err)

marshalSpt, _ := spt.MarshalJSON()

assert.Equal(t, marshalToken, marshalSpt)
}

func TestGetServicePrincipalToken(t *testing.T) {
config := &AzureAuthConfig{
TenantID: "TenantID",
Expand Down

0 comments on commit 8d912fa

Please sign in to comment.