Skip to content

Commit

Permalink
Migrate to MS Graph for AAD operations
Browse files Browse the repository at this point in the history
  • Loading branch information
manicminer committed Jul 29, 2021
1 parent 34e122b commit ea1f822
Show file tree
Hide file tree
Showing 12 changed files with 832 additions and 459 deletions.
6 changes: 0 additions & 6 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,10 @@ func (b *azureSecretBackend) getClient(ctx context.Context, s logical.Storage) (
return nil, err
}

passwords := passwords{
policyGenerator: b.System(),
policyName: config.PasswordPolicy,
}

c := &client{
provider: p,
settings: b.settings,
expiration: time.Now().Add(clientLifetime),
passwords: passwords,
}
b.client = c

Expand Down
169 changes: 101 additions & 68 deletions backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@ import (
"time"

"github.com/Azure/azure-sdk-for-go/profiles/latest/compute/mgmt/compute"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/azure-sdk-for-go/services/preview/authorization/mgmt/2018-01-01-preview/authorization"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/to"
log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/logical"
"github.com/manicminer/hamilton/msgraph"
)

const (
defaultLeaseTTLHr = 1 * time.Hour
maxLeaseTTLHr = 12 * time.Hour
defaultTestTTL = 300
defaultTestMaxTTL = 3600
passwordLength = 12
)

func getTestBackend(t *testing.T, initConfig bool) (*azureSecretBackend, logical.Storage) {
Expand Down Expand Up @@ -68,8 +68,9 @@ func getTestBackend(t *testing.T, initConfig bool) (*azureSecretBackend, logical
// mockProvider is a Provider that provides stubs and simple, deterministic responses.
type mockProvider struct {
subscriptionID string
applications map[string]bool
passwords map[string]bool
applications map[string]string
passwords map[string]string
servicePrincipals map[string]string
failNextCreateApplication bool
}

Expand All @@ -88,32 +89,34 @@ func (e *errMockProvider) CreateRoleAssignment(ctx context.Context, scope string
// key is found, unlike mockProvider which returns the same application object
// id each time. Existing tests depend on the mockProvider behavior, which is
// why errMockProvider has it's own version.
func (e *errMockProvider) GetApplication(ctx context.Context, applicationObjectID string) (graphrbac.Application, error) {
for s := range e.applications {
if s == applicationObjectID {
return graphrbac.Application{
AppID: to.StringPtr(s),
}, nil
}
func (e *errMockProvider) GetApplication(ctx context.Context, applicationObjectID string) (*msgraph.Application, error) {
appID, ok := e.applications[applicationObjectID]
if ok {
return &msgraph.Application{
ID: to.StringPtr(applicationObjectID),
AppId: to.StringPtr(appID),
}, nil
}
return graphrbac.Application{}, errors.New("not found")
return &msgraph.Application{}, errors.New("not found")
}

func newErrMockProvider() AzureProvider {
return &errMockProvider{
mockProvider: &mockProvider{
subscriptionID: generateUUID(),
applications: make(map[string]bool),
passwords: make(map[string]bool),
subscriptionID: generateUUID(),
applications: make(map[string]string),
passwords: make(map[string]string),
servicePrincipals: make(map[string]string),
},
}
}

func newMockProvider() AzureProvider {
return &mockProvider{
subscriptionID: generateUUID(),
applications: make(map[string]bool),
passwords: make(map[string]bool),
subscriptionID: generateUUID(),
applications: make(map[string]string),
passwords: make(map[string]string),
servicePrincipals: make(map[string]string),
}
}

Expand Down Expand Up @@ -168,63 +171,93 @@ func (m *mockProvider) GetRoleByID(ctx context.Context, roleID string) (result a
return d, nil
}

func (m *mockProvider) CreateServicePrincipal(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) {
return graphrbac.ServicePrincipal{
ObjectID: to.StringPtr(generateUUID()),
func (m *mockProvider) CreateServicePrincipal(ctx context.Context, parameters msgraph.ServicePrincipal) (*msgraph.ServicePrincipal, error) {
id := generateUUID()
m.servicePrincipals[id] = *parameters.AppId
return &msgraph.ServicePrincipal{
ID: to.StringPtr(id),
AppId: parameters.AppId,
}, nil
}

func (m *mockProvider) GetServicePrincipal(ctx context.Context, objectID string) (*msgraph.ServicePrincipal, error) {
_, ok := m.servicePrincipals[objectID]
if !ok {
return nil, errors.New("not found")
}
return &msgraph.ServicePrincipal{
ID: to.StringPtr(objectID),
}, nil
}

func (m *mockProvider) CreateApplication(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) {
func (m *mockProvider) DeleteServicePrincipal(ctx context.Context, objectID string) error {
delete(m.servicePrincipals, objectID)
return nil
}

func (m *mockProvider) CreateApplication(ctx context.Context, parameters msgraph.Application) (*msgraph.Application, error) {
if m.failNextCreateApplication {
m.failNextCreateApplication = false
return graphrbac.Application{}, errors.New("Mock: fail to create application")
return &msgraph.Application{}, errors.New("Mock: fail to create application")
}
appObjID := generateUUID()
m.applications[appObjID] = true
appID := generateUUID()
m.applications[appObjID] = appID

return graphrbac.Application{
AppID: to.StringPtr(generateUUID()),
ObjectID: &appObjID,
return &msgraph.Application{
ID: to.StringPtr(appObjID),
AppId: to.StringPtr(appID),
}, nil
}

func (m *mockProvider) GetApplication(ctx context.Context, applicationObjectID string) (graphrbac.Application, error) {
return graphrbac.Application{
AppID: to.StringPtr("00000000-0000-0000-0000-000000000000"),
func (m *mockProvider) GetApplication(ctx context.Context, applicationObjectID string) (*msgraph.Application, error) {
creds := make([]msgraph.PasswordCredential, 0)
for keyId, i := range m.passwords {
if i == applicationObjectID {
creds = append(creds, msgraph.PasswordCredential{
KeyId: to.StringPtr(keyId),
})
}
}

appID, ok := m.applications[applicationObjectID]
if !ok {
appID = generateUUID()
}
return &msgraph.Application{
ID: to.StringPtr(applicationObjectID),
AppId: to.StringPtr(appID),
PasswordCredentials: &creds,
}, nil
}

func (m *mockProvider) DeleteApplication(ctx context.Context, applicationObjectID string) (autorest.Response, error) {
func (m *mockProvider) DeleteApplication(ctx context.Context, applicationObjectID string) error {
delete(m.applications, applicationObjectID)
return autorest.Response{}, nil
return nil
}

func (m *mockProvider) UpdateApplicationPasswordCredentials(ctx context.Context, applicationObjectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (result autorest.Response, err error) {
m.passwords = make(map[string]bool)
for _, v := range *parameters.Value {
m.passwords[*v.KeyID] = true
}

return autorest.Response{}, nil
func (m *mockProvider) AddApplicationPassword(ctx context.Context, applicationObjectID string, credential msgraph.PasswordCredential) (newCredential *msgraph.PasswordCredential, err error) {
keyId := generateUUID()
m.passwords[keyId] = applicationObjectID
return &msgraph.PasswordCredential{
KeyId: to.StringPtr(keyId),
SecretText: to.StringPtr("p@ssw0rd!23$"),
}, nil
}

func (m *mockProvider) ListApplicationPasswordCredentials(ctx context.Context, applicationObjectID string) (result graphrbac.PasswordCredentialListResult, err error) {
var creds []graphrbac.PasswordCredential
for keyID := range m.passwords {
creds = append(creds, graphrbac.PasswordCredential{KeyID: &keyID})
}

return graphrbac.PasswordCredentialListResult{
Value: &creds,
}, nil
func (m *mockProvider) RemoveApplicationPassword(ctx context.Context, applicationObjectID string, keyId string) (err error) {
delete(m.passwords, keyId)
return nil
}

func (m *mockProvider) appExists(s string) bool {
return m.applications[s]
func (m *mockProvider) appExists(appObjID string) bool {
_, ok := m.applications[appObjID]
return ok
}

func (m *mockProvider) passwordExists(s string) bool {
return m.passwords[s]
func (m *mockProvider) passwordExists(keyID string) bool {
_, ok := m.passwords[keyID]
return ok
}

func (m *mockProvider) VMGet(ctx context.Context, resourceGroupName string, VMName string, expand compute.InstanceViewTypes) (result compute.VirtualMachine, err error) {
Expand All @@ -245,56 +278,56 @@ func (m *mockProvider) DeleteRoleAssignmentByID(ctx context.Context, roleID stri
return authorization.RoleAssignment{}, nil
}

// AddGroupMember adds a member to a AAD Group.
func (m *mockProvider) AddGroupMember(ctx context.Context, groupObjectID string, parameters graphrbac.GroupAddMemberParameters) (result autorest.Response, err error) {
return autorest.Response{}, nil
// AddGroupMembers adds members to an AAD Group.
func (m *mockProvider) AddGroupMember(ctx context.Context, group *msgraph.Group) (err error) {
return nil
}

// RemoveGroupMember removes a member from a AAD Group.
func (m *mockProvider) RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) (result autorest.Response, err error) {
return autorest.Response{}, nil
func (m *mockProvider) RemoveGroupMember(ctx context.Context, groupObjectID, memberObjectID string) (err error) {
return nil
}

// GetGroup gets group information from the directory.
func (m *mockProvider) GetGroup(ctx context.Context, objectID string) (result graphrbac.ADGroup, err error) {
g := graphrbac.ADGroup{
ObjectID: to.StringPtr(objectID),
func (m *mockProvider) GetGroup(ctx context.Context, objectID string) (result *msgraph.Group, err error) {
g := msgraph.Group{
ID: to.StringPtr(objectID),
}
s := strings.Split(objectID, "FAKE_GROUP-")
if len(s) > 1 {
g.DisplayName = to.StringPtr(s[1])
}

return g, nil
return &g, nil
}

// ListGroups gets list of groups for the current tenant.
func (m *mockProvider) ListGroups(ctx context.Context, filter string) (result []graphrbac.ADGroup, err error) {
func (m *mockProvider) ListGroups(ctx context.Context, filter string) (result *[]msgraph.Group, err error) {
reGroupName := regexp.MustCompile("displayName eq '(.*)'")

match := reGroupName.FindAllStringSubmatch(filter, -1)
if len(match) > 0 {
name := match[0][1]
if name == "multiple" {
return []graphrbac.ADGroup{
return &[]msgraph.Group{
{
ObjectID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s-1", name)),
ID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s-1", name)),
DisplayName: to.StringPtr(name),
},
{
ObjectID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s-2", name)),
ID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s-2", name)),
DisplayName: to.StringPtr(name),
},
}, nil
}

return []graphrbac.ADGroup{
return &[]msgraph.Group{
{
ObjectID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s", name)),
ID: to.StringPtr(fmt.Sprintf("00000000-1111-2222-3333-444444444444FAKE_GROUP-%s", name)),
DisplayName: to.StringPtr(name),
},
}, nil
}

return []graphrbac.ADGroup{}, nil
return &[]msgraph.Group{}, nil
}

0 comments on commit ea1f822

Please sign in to comment.