Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to MS Graph for AAD operations #63

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 0 additions & 6 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,10 @@ func (b *azureSecretBackend) getClient(ctx context.Context, s logical.Storage) (
return nil, err
}

passwords := passwords{
policyGenerator: b.System(),
policyName: config.PasswordPolicy,
}

c := &client{
provider: p,
settings: b.settings,
expiration: time.Now().Add(clientLifetime),
passwords: passwords,
}
b.client = c

Expand Down
169 changes: 101 additions & 68 deletions backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@ import (
"time"

"github.com/Azure/azure-sdk-for-go/profiles/latest/compute/mgmt/compute"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/azure-sdk-for-go/services/preview/authorization/mgmt/2018-01-01-preview/authorization"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/to"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
"github.com/manicminer/hamilton/msgraph"
)

const (
defaultLeaseTTLHr = 1 * time.Hour
maxLeaseTTLHr = 12 * time.Hour
defaultTestTTL = 300
defaultTestMaxTTL = 3600
passwordLength = 12
)

func getTestBackend(t *testing.T, initConfig bool) (*azureSecretBackend, logical.Storage) {
Expand Down Expand Up @@ -68,8 +68,9 @@ func getTestBackend(t *testing.T, initConfig bool) (*azureSecretBackend, logical
// mockProvider is a Provider that provides stubs and simple, deterministic responses.
type mockProvider struct {
subscriptionID string
applications map[string]bool
passwords map[string]bool
applications map[string]string
passwords map[string]string
servicePrincipals map[string]string
failNextCreateApplication bool
}

Expand All @@ -88,32 +89,34 @@ func (e *errMockProvider) CreateRoleAssignment(ctx context.Context, scope string
// key is found, unlike mockProvider which returns the same application object
// id each time. Existing tests depend on the mockProvider behavior, which is
// why errMockProvider has it's own version.
func (e *errMockProvider) GetApplication(ctx context.Context, applicationObjectID string) (graphrbac.Application, error) {
for s := range e.applications {
if s == applicationObjectID {
return graphrbac.Application{
AppID: to.StringPtr(s),
}, nil
}
func (e *errMockProvider) GetApplication(ctx context.Context, applicationObjectID string) (*msgraph.Application, error) {
appID, ok := e.applications[applicationObjectID]
if ok {
return &msgraph.Application{
ID: to.StringPtr(applicationObjectID),
AppId: to.StringPtr(appID),
}, nil
}
return graphrbac.Application{}, errors.New("not found")
return &msgraph.Application{}, errors.New("not found")
}

func newErrMockProvider() AzureProvider {
return &errMockProvider{
mockProvider: &mockProvider{
subscriptionID: generateUUID(),
applications: make(map[string]bool),
passwords: make(map[string]bool),
subscriptionID: generateUUID(),
applications: make(map[string]string),
passwords: make(map[string]string),
servicePrincipals: make(map[string]string),
},
}
}

func newMockProvider() AzureProvider {
return &mockProvider{
subscriptionID: generateUUID(),
applications: make(map[string]bool),
passwords: make(map[string]bool),
subscriptionID: generateUUID(),
applications: make(map[string]string),
passwords: make(map[string]string),
servicePrincipals: make(map[string]string),
}
}

Expand Down Expand Up @@ -168,63 +171,93 @@ func (m *mockProvider) GetRoleByID(ctx context.Context, roleID string) (result a
return d, nil
}

func (m *mockProvider) CreateServicePrincipal(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) {
return graphrbac.ServicePrincipal{
ObjectID: to.StringPtr(generateUUID()),
func (m *mockProvider) CreateServicePrincipal(ctx context.Context, parameters msgraph.ServicePrincipal) (*msgraph.ServicePrincipal, error) {
id := generateUUID()
m.servicePrincipals[id] = *parameters.AppId
return &msgraph.ServicePrincipal{
ID: to.StringPtr(id),
AppId: parameters.AppId,
}, nil
}

func (m *mockProvider) GetServicePrincipal(ctx context.Context, objectID string) (*msgraph.ServicePrincipal, error) {
_, ok := m.servicePrincipals[objectID]
if !ok {
return nil, errors.New("not found")
}
return &msgraph.ServicePrincipal{
ID: to.StringPtr(objectID),
}, nil
}

func (m *mockProvider) CreateApplication(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) {
func (m *mockProvider) DeleteServicePrincipal(ctx context.Context, objectID string) error {
delete(m.servicePrincipals, objectID)
return nil
}

func (m *mockProvider) CreateApplication(ctx context.Context, parameters msgraph.Application) (*msgraph.Application, error) {
if m.failNextCreateApplication {
m.failNextCreateApplication = false
return graphrbac.Application{}, errors.New("Mock: fail to create application")
return &msgraph.Application{}, errors.New("Mock: fail to create application")
}
appObjID := generateUUID()
m.applications[appObjID] = true
appID := generateUUID()
m.applications[appObjID] = appID

return graphrbac.Application{
AppID: to.StringPtr(generateUUID()),
ObjectID: &appObjID,
return &msgraph.Application{
ID: to.StringPtr(appObjID),
AppId: to.StringPtr(appID),
}, nil
}

func (m *mockProvider) GetApplication(ctx context.Context, applicationObjectID string) (graphrbac.Application, error) {
return graphrbac.Application{
AppID: to.StringPtr("00000000-0000-0000-0000-000000000000"),
func (m *mockProvider) GetApplication(ctx context.Context, applicationObjectID string) (*msgraph.Application, error) {
creds := make([]msgraph.PasswordCredential, 0)
for keyId, i := range m.passwords {
if i == applicationObjectID {
creds = append(creds, msgraph.PasswordCredential{
KeyId: to.StringPtr(keyId),
})
}
}

appID, ok := m.applications[applicationObjectID]
if !ok {
appID = generateUUID()
}
return &msgraph.Application{
ID: to.StringPtr(applicationObjectID),
AppId: to.StringPtr(appID),
PasswordCredentials: &creds,
}, nil
}

func (m *mockProvider) DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) {
func (m *mockProvider) DeleteApplication(ctx context.Context, applicationObjectID string) error {
delete(m.applications, applicationObjectID)
return autorest.Response{}, nil
return nil
}

func (m *mockProvider) UpdateApplicationPasswordCredentials(ctx context.Context, applicationObjectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (result autorest.Response, err error) {
m.passwords = make(map[string]bool)
for _, v := range *parameters.Value {
m.passwords[*v.KeyID] = true
}

return autorest.Response{}, nil
func (m *mockProvider) AddApplicationPassword(ctx context.Context, applicationObjectID string, credential msgraph.PasswordCredential) (newCredential *msgraph.PasswordCredential, err error) {
keyId := generateUUID()
m.passwords[keyId] = applicationObjectID
return &msgraph.PasswordCredential{
KeyId: to.StringPtr(keyId),
SecretText: to.StringPtr("p@ssw0rd!23$"),
}, nil
}

func (m *mockProvider) ListApplicationPasswordCredentials(ctx context.Context, applicationObjectID string) (result graphrbac.PasswordCredentialListResult, err error) {
var creds []graphrbac.PasswordCredential
for keyID := range m.passwords {
creds = append(creds, graphrbac.PasswordCredential{KeyID: &keyID})
}

return graphrbac.PasswordCredentialListResult{
Value: &creds,
}, nil
func (m *mockProvider) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyId string) (err error) {
delete(m.passwords, keyId)
return nil
}

func (m *mockProvider) appExists(s string) bool {
return m.applications[s]
func (m *mockProvider) appExists(appObjID string) bool {
_, ok := m.applications[appObjID]
return ok
}

func (m *mockProvider) passwordExists(s string) bool {
return m.passwords[s]
func (m *mockProvider) passwordExists(keyID string) bool {
_, ok := m.passwords[keyID]
return ok
}

func (m *mockProvider) VMGet(ctx context.Context, resourceGroupName string, VMName string, expand compute.InstanceViewTypes) (result compute.VirtualMachine, err error) {
Expand All @@ -245,56 +278,56 @@ func (m *mockProvider) DeleteRoleAssignmentByID(ctx context.Context, roleID stri
return authorization.RoleAssignment{}, nil
}

// AddGroupMember adds a member to a AAD Group.
func (m *mockProvider) AddGroupMember(ctx context.Context, groupObjectID string, parameters graphrbac.GroupAddMemberParameters) (result autorest.Response, err error) {
return autorest.Response{}, nil
// AddGroupMembers adds members to an AAD Group.
func (m *mockProvider) AddGroupMember(ctx context.Context, group *msgraph.Group) (err error) {
return nil
}

// RemoveGroupMember removes a member from a AAD Group.
func (m *mockProvider) RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) (result autorest.Response, err error) {
return autorest.Response{}, nil
func (m *mockProvider) RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) (err error) {
return nil
}

// GetGroup gets group information from the directory.
func (m *mockProvider) GetGroup(ctx context.Context, objectID string) (result graphrbac.ADGroup, err error) {
g := graphrbac.ADGroup{
ObjectID: to.StringPtr(objectID),
func (m *mockProvider) GetGroup(ctx context.Context, objectID string) (result *msgraph.Group, err error) {
g := msgraph.Group{
ID: to.StringPtr(objectID),
}
s := strings.Split(objectID, "FAKE_GROUP-")
if len(s) > 1 {
g.DisplayName = to.StringPtr(s[1])
}

return g, nil
return &g, nil
}

// ListGroups gets list of groups for the current tenant.
func (m *mockProvider) ListGroups(ctx context.Context, filter string) (result []graphrbac.ADGroup, err error) {
func (m *mockProvider) ListGroups(ctx context.Context, filter string) (result *[]msgraph.Group, err error) {
reGroupName := regexp.MustCompile("displayName eq '(.*)'")

match := reGroupName.FindAllStringSubmatch(filter, -1)
if len(match) > 0 {
name := match[0][1]
if name == "multiple" {
return []graphrbac.ADGroup{
return &[]msgraph.Group{
{
ObjectID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s-1", name)),
ID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s-1", name)),
DisplayName: to.StringPtr(name),
},
{
ObjectID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s-2", name)),
ID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s-2", name)),
DisplayName: to.StringPtr(name),
},
}, nil
}

return []graphrbac.ADGroup{
return &[]msgraph.Group{
{
ObjectID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s", name)),
ID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s", name)),
DisplayName: to.StringPtr(name),
},
}, nil
}

return []graphrbac.ADGroup{}, nil
return &[]msgraph.Group{}, nil
}