Skip to content

Commit

Permalink
Add support for AAD auth in Azure Storage Queues binding (dapr#1842)
Browse files Browse the repository at this point in the history
* Add support for AAD auth in Azure Storage Queues binding

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

* 馃Ч

Signed-off-by: ItalyPaleAle <43508+ItalyPaleAle@users.noreply.github.com>

Co-authored-by: Bernd Verst <4535280+berndverst@users.noreply.github.com>
Signed-off-by: Andrew Duss <andy.duss@storable.com>
  • Loading branch information
2 people authored and Andrew Duss committed Jul 6, 2022
1 parent fb653f2 commit 6629ff9
Show file tree
Hide file tree
Showing 21 changed files with 127 additions and 80 deletions.
49 changes: 47 additions & 2 deletions authentication/azure/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"time"

"github.com/Azure/azure-storage-blob-go/azblob"
"github.com/Azure/azure-storage-queue-go/azqueue"
"github.com/Azure/go-autorest/autorest/azure"

mdutils "github.com/dapr/components-contrib/metadata"
Expand All @@ -33,9 +34,9 @@ var (
StorageEndpointKeys = []string{"endpoint", "storageEndpoint", "storageAccountEndpoint", "queueEndpointUrl"}
)

// GetAzureStorageCredentials returns a azblob.Credential object that can be used to authenticate an Azure Blob Storage SDK pipeline.
// GetAzureStorageBlobCredentials returns a azblob.Credential object that can be used to authenticate an Azure Blob Storage SDK pipeline ("track 1").
// First it tries to authenticate using shared key credentials (using an account key) if present. It falls back to attempting to use Azure AD (via a service principal or MSI).
func GetAzureStorageCredentials(log logger.Logger, accountName string, metadata map[string]string) (azblob.Credential, *azure.Environment, error) {
func GetAzureStorageBlobCredentials(log logger.Logger, accountName string, metadata map[string]string) (azblob.Credential, *azure.Environment, error) {
settings, err := NewEnvironmentSettings("storage", metadata)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -76,3 +77,47 @@ func GetAzureStorageCredentials(log logger.Logger, accountName string, metadata

return credential, settings.AzureEnvironment, nil
}

// GetAzureStorageQueueCredentials returns a azqueues.Credential object that can be used to authenticate an Azure Queue Storage SDK pipeline ("track 1").
// First it tries to authenticate using shared key credentials (using an account key) if present. It falls back to attempting to use Azure AD (via a service principal or MSI).
func GetAzureStorageQueueCredentials(log logger.Logger, accountName string, metadata map[string]string) (azqueue.Credential, *azure.Environment, error) {
settings, err := NewEnvironmentSettings("storage", metadata)
if err != nil {
return nil, nil, err
}

// Try using shared key credentials first
accountKey, ok := mdutils.GetMetadataProperty(metadata, StorageAccountKeyKeys...)
if ok && accountKey != "" {
credential, newSharedKeyErr := azqueue.NewSharedKeyCredential(accountName, accountKey)
if err != nil {
return nil, nil, fmt.Errorf("invalid credentials with error: %s", newSharedKeyErr.Error())
}

return credential, settings.AzureEnvironment, nil
}

// Fallback to using Azure AD
spt, err := settings.GetServicePrincipalToken()
if err != nil {
return nil, nil, err
}
var tokenRefresher azqueue.TokenRefresher = func(credential azqueue.TokenCredential) time.Duration {
log.Debug("Refreshing Azure Storage auth token")
err := spt.Refresh()
if err != nil {
panic(err)
}
token := spt.Token()
credential.SetToken(token.AccessToken)

// Make the token expire 2 minutes earlier to get some extra buffer
exp := token.Expires().Sub(time.Now().Add(2 * time.Minute))
log.Debug("Received new token, valid for", exp)

return exp
}
credential := azqueue.NewTokenCredential("", tokenRefresher)

return credential, settings.AzureEnvironment, nil
}
7 changes: 3 additions & 4 deletions bindings/azure/blobstorage/blobstorage.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ const (
// Specifies the maximum number of blobs to return, including all BlobPrefix elements. If the request does not
// specify maxresults the server will return up to 5,000 items.
// See: https://docs.microsoft.com/en-us/rest/api/storageservices/list-blobs#uri-parameters
maxResults = 5000
endpointKey = "endpoint"
maxResults = 5000
)

var ErrMissingBlobName = errors.New("blobName is a required attribute")
Expand Down Expand Up @@ -118,7 +117,7 @@ func (a *AzureBlobStorage) Init(metadata bindings.Metadata) error {
}
a.metadata = m

credential, env, err := azauth.GetAzureStorageCredentials(a.logger, m.AccountName, metadata.Properties)
credential, env, err := azauth.GetAzureStorageBlobCredentials(a.logger, m.AccountName, metadata.Properties)
if err != nil {
return fmt.Errorf("invalid credentials with error: %s", err.Error())
}
Expand All @@ -130,7 +129,7 @@ func (a *AzureBlobStorage) Init(metadata bindings.Metadata) error {
p := azblob.NewPipeline(credential, options)

var containerURL azblob.ContainerURL
customEndpoint, ok := metadata.Properties[endpointKey]
customEndpoint, ok := mdutils.GetMetadataProperty(metadata.Properties, azauth.StorageEndpointKeys...)
if ok && customEndpoint != "" {
URL, parseErr := url.Parse(fmt.Sprintf("%s/%s/%s", customEndpoint, m.AccountName, m.Container))
if parseErr != nil {
Expand Down
2 changes: 1 addition & 1 deletion bindings/azure/eventhubs/eventhubs.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func (a *AzureEventHubs) Init(metadata bindings.Metadata) error {
if m.storageAccountKey != "" {
metadata.Properties["accountKey"] = m.storageAccountKey
}
a.storageCredential, a.azureEnvironment, err = azauth.GetAzureStorageCredentials(a.logger, m.storageAccountName, metadata.Properties)
a.storageCredential, a.azureEnvironment, err = azauth.GetAzureStorageBlobCredentials(a.logger, m.storageAccountName, metadata.Properties)
if err != nil {
return fmt.Errorf("invalid credentials with error: %w", err)
}
Expand Down
86 changes: 34 additions & 52 deletions bindings/azure/storagequeues/storagequeues.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,64 +44,59 @@ type consumer struct {

// QueueHelper enables injection for testnig.
type QueueHelper interface {
Init(endpoint string, accountName string, accountKey string, queueName string, decodeBase64 bool) error
Init(metadata bindings.Metadata) (*storageQueuesMetadata, error)
Write(ctx context.Context, data []byte, ttl *time.Duration) error
Read(ctx context.Context, consumer *consumer) error
}

// AzureQueueHelper concrete impl of queue helper.
type AzureQueueHelper struct {
credential *azqueue.SharedKeyCredential
queueURL azqueue.QueueURL
reqURI string
logger logger.Logger
decodeBase64 bool
}

func getEndpoint(endpoint, reqURI, accountName, queueName string) (*url.URL, error) {
if endpoint != "" {
u, err := url.Parse(endpoint)
if err != nil {
return nil, err
}

p, err := url.Parse(queueName)
if err != nil {
return nil, err
}

return u.ResolveReference(p), nil
}

return url.Parse(fmt.Sprintf(reqURI, accountName, queueName))
}

// Init sets up this helper.
func (d *AzureQueueHelper) Init(endpoint string, accountName string, accountKey string, queueName string, decodeBase64 bool) error {
credential, err := azqueue.NewSharedKeyCredential(accountName, accountKey)
func (d *AzureQueueHelper) Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) {
m, err := parseMetadata(metadata)
if err != nil {
return err
return nil, err
}
d.credential = credential
d.decodeBase64 = decodeBase64
u, err := getEndpoint(endpoint, d.reqURI, accountName, queueName)

credential, env, err := azauth.GetAzureStorageQueueCredentials(d.logger, m.AccountName, metadata.Properties)
if err != nil {
return err
return nil, fmt.Errorf("invalid credentials with error: %s", err.Error())
}

userAgent := "dapr-" + logger.DaprVersion
pipelineOptions := azqueue.PipelineOptions{
Telemetry: azqueue.TelemetryOptions{
Value: userAgent,
},
}
d.queueURL = azqueue.NewQueueURL(*u, azqueue.NewPipeline(credential, pipelineOptions))
ctx := context.TODO()
p := azqueue.NewPipeline(credential, pipelineOptions)

d.decodeBase64 = m.DecodeBase64

if m.QueueEndpoint != "" {
URL, parseErr := url.Parse(fmt.Sprintf("%s/%s/%s", m.QueueEndpoint, m.AccountName, m.QueueName))
if parseErr != nil {
return nil, parseErr
}
d.queueURL = azqueue.NewQueueURL(*URL, p)
} else {
URL, _ := url.Parse(fmt.Sprintf("https://%s.queue.%s/%s", m.AccountName, env.StorageEndpointSuffix, m.QueueName))
d.queueURL = azqueue.NewQueueURL(*URL, p)
}

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
_, err = d.queueURL.Create(ctx, azqueue.Metadata{})
cancel()
if err != nil {
return err
return nil, err
}

return nil
return m, nil
}

func (d *AzureQueueHelper) Write(ctx context.Context, data []byte, ttl *time.Duration) error {
Expand Down Expand Up @@ -167,7 +162,6 @@ func (d *AzureQueueHelper) Read(ctx context.Context, consumer *consumer) error {
// NewAzureQueueHelper creates new helper.
func NewAzureQueueHelper(logger logger.Logger) QueueHelper {
return &AzureQueueHelper{
reqURI: "https://%s.queue.core.windows.net/%s",
logger: logger,
}
}
Expand All @@ -181,11 +175,10 @@ type AzureStorageQueues struct {
}

type storageQueuesMetadata struct {
AccountKey string `json:"storageAccessKey"`
QueueName string `json:"queue"`
QueueEndpoint string `json:"queueEndpointUrl"`
AccountName string `json:"storageAccount"`
DecodeBase64 bool `json:"decodeBase64"`
QueueName string
QueueEndpoint string
AccountName string
DecodeBase64 bool
ttl *time.Duration
}

Expand All @@ -195,27 +188,16 @@ func NewAzureStorageQueues(logger logger.Logger) *AzureStorageQueues {
}

// Init parses connection properties and creates a new Storage Queue client.
func (a *AzureStorageQueues) Init(metadata bindings.Metadata) error {
meta, err := a.parseMetadata(metadata)
if err != nil {
return err
}
a.metadata = meta

endpoint := ""
if a.metadata.QueueEndpoint != "" {
endpoint = a.metadata.QueueEndpoint
}

err = a.helper.Init(endpoint, a.metadata.AccountName, a.metadata.AccountKey, a.metadata.QueueName, a.metadata.DecodeBase64)
func (a *AzureStorageQueues) Init(metadata bindings.Metadata) (err error) {
a.metadata, err = a.helper.Init(metadata)
if err != nil {
return err
}

return nil
}

func (a *AzureStorageQueues) parseMetadata(metadata bindings.Metadata) (*storageQueuesMetadata, error) {
func parseMetadata(metadata bindings.Metadata) (*storageQueuesMetadata, error) {
var m storageQueuesMetadata
// AccountKey is parsed in azauth

Expand Down
28 changes: 9 additions & 19 deletions bindings/azure/storagequeues/storagequeues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ type MockHelper struct {
mock.Mock
}

func (m *MockHelper) Init(endpoint, accountName, accountKey, queueName string, decodeBase64 bool) error {
retvals := m.Called(endpoint, accountName, accountKey, queueName, decodeBase64)

return retvals.Error(0)
func (m *MockHelper) Init(metadata bindings.Metadata) (*storageQueuesMetadata, error) {
return parseMetadata(metadata)
}

func (m *MockHelper) Write(ctx context.Context, data []byte, ttl *time.Duration) error {
Expand All @@ -50,7 +48,6 @@ func (m *MockHelper) Read(ctx context.Context, consumer *consumer) error {

func TestWriteQueue(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", "", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool {
return in == nil
})).Return(nil)
Expand All @@ -65,14 +62,13 @@ func TestWriteQueue(t *testing.T) {

r := bindings.InvokeRequest{Data: []byte("This is my message")}

_, err = a.Invoke(context.TODO(), &r)
_, err = a.Invoke(context.Background(), &r)

assert.Nil(t, err)
}

func TestWriteWithTTLInQueue(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", "", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
mm.On("Write", mock.AnythingOfTypeArgument("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool {
return in != nil && *in == time.Second
})).Return(nil)
Expand All @@ -87,14 +83,13 @@ func TestWriteWithTTLInQueue(t *testing.T) {

r := bindings.InvokeRequest{Data: []byte("This is my message")}

_, err = a.Invoke(context.TODO(), &r)
_, err = a.Invoke(context.Background(), &r)

assert.Nil(t, err)
}

func TestWriteWithTTLInWrite(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", "", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
mm.On("Write", mock.AnythingOfTypeArgument("[]uint8"), mock.MatchedBy(func(in *time.Duration) bool {
return in != nil && *in == time.Second
})).Return(nil)
Expand All @@ -112,7 +107,7 @@ func TestWriteWithTTLInWrite(t *testing.T) {
Metadata: map[string]string{metadata.TTLMetadataKey: "1"},
}

_, err = a.Invoke(context.TODO(), &r)
_, err = a.Invoke(context.Background(), &r)

assert.Nil(t, err)
}
Expand All @@ -137,7 +132,6 @@ func TestWriteWithTTLInWrite(t *testing.T) {

func TestReadQueue(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", "", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)
a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}

Expand All @@ -149,7 +143,7 @@ func TestReadQueue(t *testing.T) {

r := bindings.InvokeRequest{Data: []byte("This is my message")}

_, err = a.Invoke(context.TODO(), &r)
_, err = a.Invoke(context.Background(), &r)

assert.Nil(t, err)

Expand All @@ -171,7 +165,6 @@ func TestReadQueue(t *testing.T) {

func TestReadQueueDecode(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", "", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")).Return(nil)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)

a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
Expand All @@ -184,7 +177,7 @@ func TestReadQueueDecode(t *testing.T) {

r := bindings.InvokeRequest{Data: []byte("VGhpcyBpcyBteSBtZXNzYWdl")}

_, err = a.Invoke(context.TODO(), &r)
_, err = a.Invoke(context.Background(), &r)

assert.Nil(t, err)

Expand Down Expand Up @@ -235,7 +228,6 @@ func TestReadQueueDecode(t *testing.T) {
*/
func TestReadQueueNoMessage(t *testing.T) {
mm := new(MockHelper)
mm.On("Init", "", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), false).Return(nil)
mm.On("Write", mock.AnythingOfType("[]uint8"), mock.AnythingOfType("*time.Duration")).Return(nil)

a := AzureStorageQueues{helper: mm, logger: logger.NewLogger("test")}
Expand Down Expand Up @@ -310,8 +302,7 @@ func TestParseMetadata(t *testing.T) {
m := bindings.Metadata{}
m.Properties = tt.properties

a := NewAzureStorageQueues(logger.NewLogger("test"))
meta, err := a.parseMetadata(m)
meta, err := parseMetadata(m)

assert.Nil(t, err)
// assert.Equal(t, tt.expectedAccountKey, meta.AccountKey)
Expand Down Expand Up @@ -346,8 +337,7 @@ func TestParseMetadataWithInvalidTTL(t *testing.T) {
m := bindings.Metadata{}
m.Properties = tt.properties

a := NewAzureStorageQueues(logger.NewLogger("test"))
_, err := a.parseMetadata(m)
_, err := parseMetadata(m)
assert.NotNil(t, err)
})
}
Expand Down
2 changes: 1 addition & 1 deletion pubsub/azure/eventhubs/eventhubs.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ func (aeh *AzureEventHubs) Init(metadata pubsub.Metadata) error {
metadata.Properties["accountKey"] = m.StorageAccountKey
}
var storageCredsErr error
aeh.storageCredential, aeh.azureEnvironment, storageCredsErr = azauth.GetAzureStorageCredentials(aeh.logger, m.StorageAccountName, metadata.Properties)
aeh.storageCredential, aeh.azureEnvironment, storageCredsErr = azauth.GetAzureStorageBlobCredentials(aeh.logger, m.StorageAccountName, metadata.Properties)
if storageCredsErr != nil {
return fmt.Errorf("invalid storage credentials with error: %w", storageCredsErr)
}
Expand Down
2 changes: 1 addition & 1 deletion state/azure/blobstorage/blobstorage.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (r *StateStore) Init(metadata state.Metadata) error {
return err
}

credential, env, err := azauth.GetAzureStorageCredentials(r.logger, meta.accountName, metadata.Properties)
credential, env, err := azauth.GetAzureStorageBlobCredentials(r.logger, meta.accountName, metadata.Properties)
if err != nil {
return fmt.Errorf("invalid credentials with error: %s", err.Error())
}
Expand Down
Loading

0 comments on commit 6629ff9

Please sign in to comment.