Skip to content

Commit

Permalink
[v13] User groups in access requests will expand list of applications. (
Browse files Browse the repository at this point in the history
#28603)

* User groups in access requests will expand list of applications.

When requesting user groups in an access request, the applications associated
with the access request will now be expanded in the access request, requesting
both access to the group itself along with the associated applications.

* Don't duplicate apps if they're already being requested.

* Add clarifying comments to test, test equivalence against more fields.
  • Loading branch information
mdwn committed Jul 5, 2023
1 parent 1412860 commit c03f7d9
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 0 deletions.
43 changes: 43 additions & 0 deletions lib/auth/auth.go
Expand Up @@ -3860,6 +3860,49 @@ func (a *Server) CreateAccessRequest(ctx context.Context, req types.AccessReques
return trace.Wrap(err)
}

// Look for user groups and associated applications to the request.
requestedResourceIDs := req.GetRequestedResourceIDs()
var additionalResources []types.ResourceID

var userGroups []types.ResourceID
existingApps := map[string]struct{}{}
for _, resource := range requestedResourceIDs {
switch resource.Kind {
case types.KindApp:
existingApps[resource.Name] = struct{}{}
case types.KindUserGroup:
userGroups = append(userGroups, resource)
}
}

for _, resource := range userGroups {
if resource.Kind != types.KindUserGroup {
continue
}

userGroup, err := a.GetUserGroup(ctx, resource.Name)
if err != nil {
return trace.Wrap(err)
}

for _, app := range userGroup.GetApplications() {
// Only add to the request if we haven't already added it.
if _, ok := existingApps[app]; !ok {
additionalResources = append(additionalResources, types.ResourceID{
ClusterName: resource.ClusterName,
Kind: types.KindApp,
Name: app,
})
existingApps[app] = struct{}{}
}
}
}

if len(additionalResources) > 0 {
requestedResourceIDs = append(requestedResourceIDs, additionalResources...)
req.SetRequestedResourceIDs(requestedResourceIDs)
}

if req.GetDryRun() {
// Made it this far with no errors, return before creating the request
// if this is a dry run.
Expand Down
187 changes: 187 additions & 0 deletions lib/auth/auth_with_roles_test.go
Expand Up @@ -5519,3 +5519,190 @@ func TestSafeToSkipInventoryCheck(t *testing.T) {
safeToSkipInventoryCheck(*semver.New(tc.authVersion), *semver.New(tc.minRequiredVersion)))
}
}

func TestCreateAccessRequest(t *testing.T) {
t.Parallel()
ctx := context.Background()

srv := newTestTLSServer(t)
clock := srv.Clock()
alice, bob, admin := createSessionTestUsers(t, srv.Auth())

searchRole, err := types.NewRole("requestRole", types.RoleSpecV6{
Allow: types.RoleConditions{
Request: &types.AccessRequestConditions{
Roles: []string{"requestRole"},
SearchAsRoles: []string{"requestRole"},
},
},
})
require.NoError(t, err)

requestRole, err := types.NewRole("requestRole", types.RoleSpecV6{})
require.NoError(t, err)

srv.Auth().CreateRole(ctx, searchRole)
srv.Auth().CreateRole(ctx, requestRole)

user, err := srv.Auth().GetUser(alice, true)
require.NoError(t, err)

user.AddRole(searchRole.GetName())
require.NoError(t, srv.Auth().UpsertUser(user))

userGroup1, err := types.NewUserGroup(types.Metadata{
Name: "user-group1",
}, types.UserGroupSpecV1{
Applications: []string{"app1", "app2", "app3"},
})
require.NoError(t, err)
require.NoError(t, srv.Auth().CreateUserGroup(ctx, userGroup1))

userGroup2, err := types.NewUserGroup(types.Metadata{
Name: "user-group2",
}, types.UserGroupSpecV1{})
require.NoError(t, err)
require.NoError(t, srv.Auth().CreateUserGroup(ctx, userGroup2))

userGroup3, err := types.NewUserGroup(types.Metadata{
Name: "user-group3",
}, types.UserGroupSpecV1{
Applications: []string{"app1", "app4", "app5"},
})
require.NoError(t, err)
require.NoError(t, srv.Auth().CreateUserGroup(ctx, userGroup3))

tests := []struct {
name string
user string
accessRequest types.AccessRequest
errAssertionFunc require.ErrorAssertionFunc
expected types.AccessRequest
}{
{
name: "user creates own pending access request",
user: alice,
accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour),
[]string{requestRole.GetName()}, []types.ResourceID{
mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()),
}),
errAssertionFunc: require.NoError,
expected: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour),
[]string{requestRole.GetName()}, []types.ResourceID{
mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()),
}),
},
{
name: "admin creates a request for alice",
user: admin,
accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour),
[]string{requestRole.GetName()}, []types.ResourceID{
mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()),
}),
errAssertionFunc: require.NoError,
expected: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour),
[]string{requestRole.GetName()}, []types.ResourceID{
mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()),
}),
},
{
name: "bob fails to create a request for alice",
user: bob,
accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour),
[]string{requestRole.GetName()}, []types.ResourceID{
mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()),
}),
errAssertionFunc: require.Error,
},
{
name: "user creates own pending access request with user group needing app expansion",
user: alice,
accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour),
[]string{requestRole.GetName()}, []types.ResourceID{
mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()),
mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup1.GetName()),
mustResourceID(srv.ClusterName(), types.KindApp, "app1"),
mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup2.GetName()),
mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup3.GetName()),
}),
errAssertionFunc: require.NoError,
expected: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour),
[]string{requestRole.GetName()}, []types.ResourceID{
mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()),
mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup1.GetName()),
mustResourceID(srv.ClusterName(), types.KindApp, "app1"),
mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup2.GetName()),
mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup3.GetName()),
mustResourceID(srv.ClusterName(), types.KindApp, "app2"),
mustResourceID(srv.ClusterName(), types.KindApp, "app3"),
mustResourceID(srv.ClusterName(), types.KindApp, "app4"),
mustResourceID(srv.ClusterName(), types.KindApp, "app5"),
}),
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Make sure there are no access requests before we do anything. We'll clear out
// each time to save on the complexity of setting up the auth server and dependent
// users and roles.
ctx := context.Background()
require.NoError(t, srv.Auth().DeleteAllAccessRequests(ctx))

client, err := srv.NewClient(TestUser(test.user))
require.NoError(t, err)

test.errAssertionFunc(t, client.CreateAccessRequest(ctx, test.accessRequest))

accessRequests, err := srv.Auth().GetAccessRequests(ctx, types.AccessRequestFilter{
ID: test.accessRequest.GetName(),
})
require.NoError(t, err)

if test.expected == nil {
require.Empty(t, accessRequests)
return
}

require.Len(t, accessRequests, 1)

// We have to ignore the name here, as it's auto-generated by the underlying access request
// logic.
require.Empty(t, cmp.Diff(test.expected, accessRequests[0],
cmpopts.IgnoreFields(types.Metadata{}, "Name", "ID"),
cmpopts.IgnoreFields(types.AccessRequestSpecV3{}),
))
})
}
}

func mustAccessRequest(t *testing.T, user string, state types.RequestState, created, expires time.Time, roles []string, resourceIDs []types.ResourceID) types.AccessRequest {
t.Helper()

accessRequest, err := types.NewAccessRequest(uuid.NewString(), user, roles...)
require.NoError(t, err)

accessRequest.SetRequestedResourceIDs(resourceIDs)
accessRequest.SetState(state)
accessRequest.SetCreationTime(created)
accessRequest.SetExpiry(expires)
accessRequest.SetAccessExpiry(expires)
accessRequest.SetThresholds([]types.AccessReviewThreshold{{Name: "default", Approve: 1, Deny: 1}})
accessRequest.SetRoleThresholdMapping(map[string]types.ThresholdIndexSets{
"requestRole": {
Sets: []types.ThresholdIndexSet{
{Indexes: []uint32{0}},
},
},
})

return accessRequest
}

func mustResourceID(clusterName, kind, name string) types.ResourceID {
return types.ResourceID{
ClusterName: clusterName,
Kind: kind,
Name: name,
}
}

0 comments on commit c03f7d9

Please sign in to comment.