Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for AAD auth in Azure Storage Queues binding #1842

Merged
merged 3 commits into from
Jul 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
49 changes: 47 additions & 2 deletions authentication/azure/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"time"

"github.com/Azure/azure-storage-blob-go/azblob"
"github.com/Azure/azure-storage-queue-go/azqueue"
"github.com/Azure/go-autorest/autorest/azure"

mdutils "github.com/dapr/components-contrib/metadata"
Expand All @@ -33,9 +34,9 @@ var (
StorageEndpointKeys = []string{"endpoint", "storageEndpoint", "storageAccountEndpoint", "queueEndpointUrl"}
)

// GetAzureStorageCredentials returns a azblob.Credential object that can be used to authenticate an Azure Blob Storage SDK pipeline.
// GetAzureStorageBlobCredentials returns a azblob.Credential object that can be used to authenticate an Azure Blob Storage SDK pipeline ("track 1").
// First it tries to authenticate using shared key credentials (using an account key) if present. It falls back to attempting to use Azure AD (via a service principal or MSI).
func GetAzureStorageCredentials(log logger.Logger, accountName string, metadata map[string]string) (azblob.Credential, *azure.Environment, error) {
func GetAzureStorageBlobCredentials(log logger.Logger, accountName string, metadata map[string]string) (azblob.Credential, *azure.Environment, error) {
settings, err := NewEnvironmentSettings("storage", metadata)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -76,3 +77,47 @@ func GetAzureStorageCredentials(log logger.Logger, accountName string, metadata

return credential, settings.AzureEnvironment, nil
}

// GetAzureStorageQueueCredentials returns a azqueues.Credential object that can be used to authenticate an Azure Queue Storage SDK pipeline ("track 1").
// First it tries to authenticate using shared key credentials (using an account key) if present. It falls back to attempting to use Azure AD (via a service principal or MSI).
func GetAzureStorageQueueCredentials(log logger.Logger, accountName string, metadata map[string]string) (azqueue.Credential, *azure.Environment, error) {
settings, err := NewEnvironmentSettings("storage", metadata)
if err != nil {
return nil, nil, err
}

// Try using shared key credentials first
accountKey, ok := mdutils.GetMetadataProperty(metadata, StorageAccountKeyKeys...)
if ok && accountKey != "" {
credential, newSharedKeyErr := azqueue.NewSharedKeyCredential(accountName, accountKey)
if err != nil {
return nil, nil, fmt.Errorf("invalid credentials with error: %s", newSharedKeyErr.Error())
}

return credential, settings.AzureEnvironment, nil
}

// Fallback to using Azure AD
spt, err := settings.GetServicePrincipalToken()
if err != nil {
return nil, nil, err
}
var tokenRefresher azqueue.TokenRefresher = func(credential azqueue.TokenCredential) time.Duration {
log.Debug("Refreshing Azure Storage auth token")
err := spt.Refresh()
if err != nil {
panic(err)
}
token := spt.Token()
credential.SetToken(token.AccessToken)

// Make the token expire 2 minutes earlier to get some extra buffer
exp := token.Expires().Sub(time.Now().Add(2 * time.Minute))
log.Debug("Received new token, valid for", exp)

return exp
}
credential := azqueue.NewTokenCredential("", tokenRefresher)

return credential, settings.AzureEnvironment, nil
}
7 changes: 3 additions & 4 deletions bindings/azure/blobstorage/blobstorage.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ const (
// Specifies the maximum number of blobs to return, including all BlobPrefix elements. If the request does not
// specify maxresults the server will return up to 5,000 items.
// See: https://docs.microsoft.com/en-us/rest/api/storageservices/list-blobs#uri-parameters
maxResults = 5000
endpointKey = "endpoint"
maxResults = 5000
)

var ErrMissingBlobName = errors.New("blobName is a required attribute")
Expand Down Expand Up @@ -118,7 +117,7 @@ func (a *AzureBlobStorage) Init(metadata bindings.Metadata) error {
}
a.metadata = m

credential, env, err := azauth.GetAzureStorageCredentials(a.logger, m.AccountName, metadata.Properties)
credential, env, err := azauth.GetAzureStorageBlobCredentials(a.logger, m.AccountName, metadata.Properties)
if err != nil {
return fmt.Errorf("invalid credentials with error: %s", err.Error())
}
Expand All @@ -130,7 +129,7 @@ func (a *AzureBlobStorage) Init(metadata bindings.Metadata) error {
p := azblob.NewPipeline(credential, options)

var containerURL azblob.ContainerURL
customEndpoint, ok := metadata.Properties[endpointKey]
customEndpoint, ok := mdutils.GetMetadataProperty(metadata.Properties, azauth.StorageEndpointKeys...)
if ok && customEndpoint != "" {
URL, parseErr := url.Parse(fmt.Sprintf("%s/%s/%s", customEndpoint, m.AccountName, m.Container))
if parseErr != nil {
Expand Down
2 changes: 1 addition & 1 deletion bindings/azure/eventhubs/eventhubs.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func (a *AzureEventHubs) Init(metadata bindings.Metadata) error {
if m.storageAccountKey != "" {
metadata.Properties["accountKey"] = m.storageAccountKey
}
a.storageCredential, a.azureEnvironment, err = azauth.GetAzureStorageCredentials(a.logger, m.storageAccountName, metadata.Properties)
a.storageCredential, a.azureEnvironment, err = azauth.GetAzureStorageBlobCredentials(a.logger, m.storageAccountName, metadata.Properties)
if err != nil {
return fmt.Errorf("invalid credentials with error: %w", err)
}
Expand Down
86 changes: 34 additions & 52 deletions bindings/azure/storagequeues/storagequeues.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,64 +44,59 @@ type consumer struct {

// QueueHelper enables injection for testnig.
type QueueHelper interface {
Init(endpoint string, accountName string, accountKey string, queueName string, decodeBase64 bool) error
Init(metadata bindings.Metadata) (*storageQueuesMetadata, error)
Write(ctx context.Context, data []byte, ttl *time.Duration) error
Read(ctx context.Context, consumer *consumer) error
}

// AzureQueueHelper concrete impl of queue helper.
type AzureQueueHelper struct {
credential *azqueue.SharedKeyCredential
queueURL azqueue.QueueURL
reqURI string
logger logger.Logger
decodeBase64 bool
}

func getEndpoint(endpoint, reqURI, accountName, queueName string) (*url.URL, error) {
if endpoint != "" {
u, err := url.Parse(endpoint)
if err != nil {
return nil, err
}

p, err := url.Parse(queueName)
if err != nil {
return nil, err
}

return u.ResolveReference(p), nil
}

return url.Parse(fmt.Sprintf(reqURI, accountName, queueName))
}

// Init sets up this helper.
func (d *AzureQueueHelper) Init(endpoint string, accountName string, accountKey string, queueName string, decodeBase64 bool) error {
credential, err := azqueue.NewSharedKeyCredential(accountName, accountKey)
func (d *AzureQueueHelper) Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) {
m, err := parseMetadata(metadata)
if err != nil {
return err
return nil, err
}
d.credential = credential
d.decodeBase64 = decodeBase64
u, err := getEndpoint(endpoint, d.reqURI, accountName, queueName)

credential, env, err := azauth.GetAzureStorageQueueCredentials(d.logger, m.AccountName, metadata.Properties)
if err != nil {
return err
return nil, fmt.Errorf("invalid credentials with error: %s", err.Error())
}

userAgent := "dapr-" + logger.DaprVersion
pipelineOptions := azqueue.PipelineOptions{
Telemetry: azqueue.TelemetryOptions{
Value: userAgent,
},
}
d.queueURL = azqueue.NewQueueURL(*u, azqueue.NewPipeline(credential, pipelineOptions))
ctx := context.TODO()
p := azqueue.NewPipeline(credential, pipelineOptions)

d.decodeBase64 = m.DecodeBase64

if m.QueueEndpoint != "" {
URL, parseErr := url.Parse(fmt.Sprintf("%s/%s/%s", m.QueueEndpoint, m.AccountName, m.QueueName))
if parseErr != nil {
return nil, parseErr
}
d.queueURL = azqueue.NewQueueURL(*URL, p)
} else {
URL, _ := url.Parse(fmt.Sprintf("https://%s.queue.%s/%s", m.AccountName, env.StorageEndpointSuffix, m.QueueName))
d.queueURL = azqueue.NewQueueURL(*URL, p)
}

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
_, err = d.queueURL.Create(ctx, azqueue.Metadata{})
cancel()
if err != nil {
return err
return nil, err
}

return nil
return m, nil
}

func (d *AzureQueueHelper) Write(ctx context.Context, data []byte, ttl *time.Duration) error {
Expand Down Expand Up @@ -167,7 +162,6 @@ func (d *AzureQueueHelper) Read(ctx context.Context, consumer *consumer) error {
// NewAzureQueueHelper creates new helper.
func NewAzureQueueHelper(logger logger.Logger) QueueHelper {
return &AzureQueueHelper{
reqURI: "https://%s.queue.core.windows.net/%s",
logger: logger,
}
}
Expand All @@ -181,11 +175,10 @@ type AzureStorageQueues struct {
}

type storageQueuesMetadata struct {
AccountKey string `json:"storageAccessKey"`
QueueName string `json:"queue"`
QueueEndpoint string `json:"queueEndpointUrl"`
AccountName string `json:"storageAccount"`
DecodeBase64 bool `json:"decodeBase64"`
QueueName string
QueueEndpoint string
AccountName string
DecodeBase64 bool
ttl *time.Duration
}

Expand All @@ -195,27 +188,16 @@ func NewAzureStorageQueues(logger logger.Logger) *AzureStorageQueues {
}

// Init parses connection properties and creates a new Storage Queue client.
func (a *AzureStorageQueues) Init(metadata bindings.Metadata) error {
meta, err := a.parseMetadata(metadata)
if err != nil {
return err
}
a.metadata = meta

endpoint := ""
if a.metadata.QueueEndpoint != "" {
endpoint = a.metadata.QueueEndpoint
}

err = a.helper.Init(endpoint, a.metadata.AccountName, a.metadata.AccountKey, a.metadata.QueueName, a.metadata.DecodeBase64)
func (a *AzureStorageQueues) Init(metadata bindings.Metadata) (err error) {
a.metadata, err = a.helper.Init(metadata)
if err != nil {
return err
}

return nil
}

func (a *AzureStorageQueues) parseMetadata(metadata bindings.Metadata) (*storageQueuesMetadata, error) {
func parseMetadata(metadata bindings.Metadata) (*storageQueuesMetadata, error) {
var m storageQueuesMetadata
// AccountKey is parsed in azauth

Expand Down
28 changes: 9 additions & 19 deletions bindings/azure/storagequeues/storagequeues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ type MockHelper struct {
mock.Mock
}

func (m *MockHelper) Init(endpoint, accountName, accountKey, queueName string, decodeBase64 bool) error {
retvals := m.Called(endpoint, accountName, accountKey, queueName, decodeBase64)

return retvals.Error(0)
func (m *MockHelper) Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) {
return parseMetadata(metadata)
}

func (m *MockHelper) Write(ctx context.Context, data []byte, ttl *time.Duration) error {
Expand All @@ -50,7 +48,6 @@ func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error {

func TestWriteQueue(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", "", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool {
return in == nil
})).Return(nil)
Expand All @@ -65,14 +62,13 @@ func TestWriteQueue(t *testing.T) {

r := bindings.InvokeRequest{Data: []byte("This is my message")}

_, err = a.Invoke(context.TODO(), &r)
_, err = a.Invoke(context.Background(), &r)

assert.Nil(t, err)
}

func TestWriteWithTTLInQueue(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", "", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
mm.On("Write", mock.AnythingOfTypeArgument("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool {
return in != nil && *in == time.Second
})).Return(nil)
Expand All @@ -87,14 +83,13 @@ func TestWriteWithTTLInQueue(t *testing.T) {

r := bindings.InvokeRequest{Data: []byte("This is my message")}

_, err = a.Invoke(context.TODO(), &r)
_, err = a.Invoke(context.Background(), &r)

assert.Nil(t, err)
}

func TestWriteWithTTLInWrite(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", "", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
mm.On("Write", mock.AnythingOfTypeArgument("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool {
return in != nil && *in == time.Second
})).Return(nil)
Expand All @@ -112,7 +107,7 @@ func TestWriteWithTTLInWrite(t *testing.T) {
Metadata: map[string]string{metadata.TTLMetadataKey: "1"},
}

_, err = a.Invoke(context.TODO(), &r)
_, err = a.Invoke(context.Background(), &r)

assert.Nil(t, err)
}
Expand All @@ -137,7 +132,6 @@ func TestWriteWithTTLInWrite(t *testing.T) {

func TestReadQueue(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", "", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}

Expand All @@ -149,7 +143,7 @@ func TestReadQueue(t *testing.T) {

r := bindings.InvokeRequest{Data: []byte("This is my message")}

_, err = a.Invoke(context.TODO(), &r)
_, err = a.Invoke(context.Background(), &r)

assert.Nil(t, err)

Expand All @@ -171,7 +165,6 @@ func TestReadQueue(t *testing.T) {

func TestReadQueueDecode(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", "", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)

a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
Expand All @@ -184,7 +177,7 @@ func TestReadQueueDecode(t *testing.T) {

r := bindings.InvokeRequest{Data: []byte("VGhpcyBpcyBteSBtZXNzYWdl")}

_, err = a.Invoke(context.TODO(), &r)
_, err = a.Invoke(context.Background(), &r)

assert.Nil(t, err)

Expand Down Expand Up @@ -235,7 +228,6 @@ func TestReadQueueDecode(t *testing.T) {
*/
func TestReadQueueNoMessage(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", "", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), false).Return(nil)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)

a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
Expand Down Expand Up @@ -310,8 +302,7 @@ func TestParseMetadata(t *testing.T) {
m := bindings.Metadata{}
m.Properties = tt.properties

a := NewAzureStorageQueues(logger.NewLogger("test"))
meta, err := a.parseMetadata(m)
meta, err := parseMetadata(m)

assert.Nil(t, err)
// assert.Equal(t, tt.expectedAccountKey, meta.AccountKey)
Expand Down Expand Up @@ -346,8 +337,7 @@ func TestParseMetadataWithInvalidTTL(t *testing.T) {
m := bindings.Metadata{}
m.Properties = tt.properties

a := NewAzureStorageQueues(logger.NewLogger("test"))
_, err := a.parseMetadata(m)
_, err := parseMetadata(m)
assert.NotNil(t, err)
})
}
Expand Down
2 changes: 1 addition & 1 deletion pubsub/azure/eventhubs/eventhubs.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ func (aeh *AzureEventHubs) Init(metadata pubsub.Metadata) error {
metadata.Properties["accountKey"] = m.StorageAccountKey
}
var storageCredsErr error
aeh.storageCredential, aeh.azureEnvironment, storageCredsErr = azauth.GetAzureStorageCredentials(aeh.logger, m.StorageAccountName, metadata.Properties)
aeh.storageCredential, aeh.azureEnvironment, storageCredsErr = azauth.GetAzureStorageBlobCredentials(aeh.logger, m.StorageAccountName, metadata.Properties)
if storageCredsErr != nil {
return fmt.Errorf("invalid storage credentials with error: %w", storageCredsErr)
}
Expand Down
2 changes: 1 addition & 1 deletion state/azure/blobstorage/blobstorage.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (r *StateStore) Init(metadata state.Metadata) error {
return err
}

credential, env, err := azauth.GetAzureStorageCredentials(r.logger, meta.accountName, metadata.Properties)
credential, env, err := azauth.GetAzureStorageBlobCredentials(r.logger, meta.accountName, metadata.Properties)
if err != nil {
return fmt.Errorf("invalid credentials with error: %s", err.Error())
}
Expand Down