Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Azure SDK "track 2": authentication and secretstore/azure/keyvault #1290

Merged
merged 14 commits into from
Nov 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 107 additions & 6 deletions authentication/azure/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ import (
"crypto/x509"
"errors"
"fmt"
"io/ioutil"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
Expand Down Expand Up @@ -69,10 +72,58 @@ func (s EnvironmentSettings) GetAzureEnvironment() (*azure.Environment, error) {
return &env, err
}

// GetTokenCredential returns an azcore.TokenCredential retrieved from, in order:
// 1. Client credentials
// 2. Client certificate
// 3. MSI
// This is used by the newer ("track 2") Azure SDKs.
func (s EnvironmentSettings) GetTokenCredential() (azcore.TokenCredential, error) {
// Create a chain
var creds []azcore.TokenCredential
errMsg := ""

// 1. Client credentials
if c, e := s.GetClientCredentials(); e == nil {
cred, err := c.GetTokenCredential()
if err == nil {
creds = append(creds, cred)
} else {
errMsg += err.Error() + "\n"
}
}

// 2. Client certificate
if c, e := s.GetClientCert(); e == nil {
cred, err := c.GetTokenCredential()
if err == nil {
creds = append(creds, cred)
} else {
errMsg += err.Error() + "\n"
}
}

// 3. MSI
{
c := s.GetMSI()
cred, err := c.GetTokenCredential()
if err == nil {
creds = append(creds, cred)
} else {
errMsg += err.Error() + "\n"
}
}

if len(creds) == 0 {
return nil, fmt.Errorf("no suitable token provider for Azure AD; errors: %v", errMsg)
}
return azidentity.NewChainedTokenCredential(creds, nil)
}

// GetAuthorizer creates an Authorizer retrieved from, in order:
// 1. Client credentials
// 2. Client certificate
// 3. MSI.
// 3. MSI
// This is used by the older Azure SDKs.
func (s EnvironmentSettings) GetAuthorizer() (autorest.Authorizer, error) {
spt, err := s.GetServicePrincipalToken()
if err != nil {
Expand All @@ -85,7 +136,8 @@ func (s EnvironmentSettings) GetAuthorizer() (autorest.Authorizer, error) {
// GetServicePrincipalToken returns a Service Principal Token retrieved from, in order:
// 1. Client credentials
// 2. Client certificate
// 3. MSI.
// 3. MSI
// This is used by the older Azure SDKs.
func (s EnvironmentSettings) GetServicePrincipalToken() (*adal.ServicePrincipalToken, error) {
// 1. Client credentials
if c, e := s.GetClientCredentials(); e == nil {
Expand Down Expand Up @@ -182,6 +234,13 @@ func (c CredentialsConfig) ServicePrincipalToken() (*adal.ServicePrincipalToken,
return adal.NewServicePrincipalToken(*oauthConfig, c.ClientID, c.ClientSecret, c.Resource)
}

// GetTokenCredential returns the azcore.TokenCredential object from the credentials.
func (c CredentialsConfig) GetTokenCredential() (token azcore.TokenCredential, err error) {
return azidentity.NewClientSecretCredential(c.TenantID, c.ClientID, c.ClientSecret, &azidentity.ClientSecretCredentialOptions{
AuthorityHost: azidentity.AuthorityHost(c.AADEndpoint),
})
}

// CertConfig provides the options to get a bearer authorizer from a client certificate.
type CertConfig struct {
*auth.ClientCertificateConfig
Expand Down Expand Up @@ -231,6 +290,39 @@ func (c CertConfig) ServicePrincipalTokenByCertBytes() (*adal.ServicePrincipalTo
return adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, c.ClientID, certificate, rsaPrivateKey, c.Resource)
}

// GetTokenCredential returns the azcore.TokenCredential object from client certificate.
func (c CertConfig) GetTokenCredential() (token azcore.TokenCredential, err error) {
ccc := c.ClientCertificateConfig

// Certificate data - may be empty here
data := c.CertificateData

// If we have a certificate path, load it
if c.ClientCertificateConfig.CertificatePath != "" {
var errB error
data, errB = ioutil.ReadFile(ccc.CertificatePath)
if errB != nil {
return nil, fmt.Errorf("failed to read the certificate file (%s): %v", ccc.CertificatePath, errB)
}
}
if len(data) == 0 {
return nil, fmt.Errorf("certificate is not given")
}

// Decode the PKCS#12 certificate
cert, key, err := c.decodePkcs12(data, c.CertificatePassword)
if err != nil || cert == nil {
return nil, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %v", err)
}

// Create the azcore.TokenCredential object
certs := []*x509.Certificate{cert}
opts := &azidentity.ClientCertificateCredentialOptions{
AuthorityHost: azidentity.AuthorityHost(c.AADEndpoint),
}
return azidentity.NewClientCertificateCredential(c.TenantID, c.ClientID, certs, key, opts)
}

func (c CertConfig) decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
privateKey, certificate, err := pkcs12.Decode(pkcs, password)
if err != nil {
Expand Down Expand Up @@ -259,20 +351,20 @@ func NewMSIConfig(resource string) MSIConfig {
}

// ServicePrincipalToken gets the ServicePrincipalToken object from MSI.
func (mc MSIConfig) ServicePrincipalToken() (*adal.ServicePrincipalToken, error) {
func (c MSIConfig) ServicePrincipalToken() (*adal.ServicePrincipalToken, error) {
msiEndpoint, err := adal.GetMSIEndpoint()
if err != nil {
return nil, err
}

var spToken *adal.ServicePrincipalToken
if mc.ClientID == "" {
spToken, err = adal.NewServicePrincipalTokenFromMSI(msiEndpoint, mc.Resource)
if c.ClientID == "" {
spToken, err = adal.NewServicePrincipalTokenFromMSI(msiEndpoint, c.Resource)
if err != nil {
return nil, fmt.Errorf("failed to get oauth token from MSI: %v", err)
}
} else {
spToken, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, mc.Resource, mc.ClientID)
spToken, err = adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, c.Resource, c.ClientID)
if err != nil {
return nil, fmt.Errorf("failed to get oauth token from MSI for user assigned identity: %v", err)
}
Expand All @@ -281,6 +373,15 @@ func (mc MSIConfig) ServicePrincipalToken() (*adal.ServicePrincipalToken, error)
return spToken, nil
}

// GetTokenCredential returns the azcore.TokenCredential object from MSI.
func (c MSIConfig) GetTokenCredential() (token azcore.TokenCredential, err error) {
opts := &azidentity.ManagedIdentityCredentialOptions{}
if c.ClientID != "" {
opts.ID = azidentity.ClientID(c.ClientID)
}
return azidentity.NewManagedIdentityCredential(opts)
}

// GetAzureEnvironment returns the Azure environment for a given name, supporting aliases too.
func (s EnvironmentSettings) GetEnvironment(key string) (string, bool) {
var (
Expand Down
6 changes: 6 additions & 0 deletions authentication/azure/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ func TestGetMSI(t *testing.T) {
}

func TestFallbackToMSI(t *testing.T) {
os.Setenv("MSI_ENDPOINT", "test")
berndverst marked this conversation as resolved.
Show resolved Hide resolved
defer os.Unsetenv("MSI_ENDPOINT")
settings, err := NewEnvironmentSettings(
"keyvault",
map[string]string{
Expand All @@ -162,6 +164,8 @@ func TestFallbackToMSI(t *testing.T) {
}

func TestAuthorizorWithMSI(t *testing.T) {
os.Setenv("MSI_ENDPOINT", "test")
defer os.Unsetenv("MSI_ENDPOINT")
settings, err := NewEnvironmentSettings(
"keyvault",
map[string]string{
Expand All @@ -180,6 +184,8 @@ func TestAuthorizorWithMSI(t *testing.T) {
}

func TestAuthorizorWithMSIAndUserAssignedID(t *testing.T) {
os.Setenv("MSI_ENDPOINT", "test")
defer os.Unsetenv("MSI_ENDPOINT")
settings, err := NewEnvironmentSettings(
"keyvault",
map[string]string{
Expand Down
2 changes: 1 addition & 1 deletion bindings/azure/eventgrid/eventgrid.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func (a *AzureEventGrid) createSubscription() error {
return err
}

res := result.Future.Response()
res := result.FutureAPI.Response()

if res.StatusCode != fasthttp.StatusCreated {
bodyBytes, err := ioutil.ReadAll(res.Body)
Expand Down
12 changes: 7 additions & 5 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@ require (
cloud.google.com/go/storage v1.10.0
github.com/Azure/azure-amqp-common-go/v3 v3.1.0
github.com/Azure/azure-event-hubs-go/v3 v3.3.10
github.com/Azure/azure-sdk-for-go v48.2.0+incompatible
github.com/Azure/azure-sdk-for-go v57.2.0+incompatible
github.com/Azure/azure-sdk-for-go/sdk/azcore v0.20.0
github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.12.0
github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets v0.3.0
github.com/Azure/azure-service-bus-go v0.10.10
github.com/Azure/azure-storage-blob-go v0.10.0
github.com/Azure/azure-storage-queue-go v0.0.0-20191125232315-636801874cdd
github.com/Azure/go-amqp v0.13.1
github.com/Azure/go-autorest/autorest v0.11.12
github.com/Azure/go-autorest/autorest/adal v0.9.5
github.com/Azure/go-autorest/autorest/azure/auth v0.4.2
github.com/Azure/go-autorest/autorest v0.11.21
github.com/Azure/go-autorest/autorest/adal v0.9.16
github.com/Azure/go-autorest/autorest/azure/auth v0.5.8
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/Shopify/sarama v1.23.1
github.com/StackExchange/wmi v0.0.0-20210224194228-fe8f1750fd46 // indirect
Expand Down Expand Up @@ -49,7 +52,6 @@ require (
github.com/dghubble/go-twitter v0.0.0-20190719072343-39e5462e111f
github.com/dghubble/oauth1 v0.6.0
github.com/didip/tollbooth v4.0.2+incompatible
github.com/dnaeon/go-vcr v1.1.0 // indirect
github.com/eapache/go-resiliency v1.2.0 // indirect
github.com/eclipse/paho.mqtt.golang v1.3.5
github.com/fasthttp-contrib/sessions v0.0.0-20160905201309-74f6ac73d5d5
Expand Down
Loading