From c03f7d9a45b6ad90d52b8bf69a7d69e414155a5e Mon Sep 17 00:00:00 2001 From: Michael Wilson Date: Wed, 5 Jul 2023 15:21:14 -0400 Subject: [PATCH] [v13] User groups in access requests will expand list of applications. (#28603) * User groups in access requests will expand list of applications. When requesting user groups in an access request, the applications associated with the access request will now be expanded in the access request, requesting both access to the group itself along with the associated applications. * Don't duplicate apps if they're already being requested. * Add clarifying comments to test, test equivalence against more fields. --- lib/auth/auth.go | 43 +++++++ lib/auth/auth_with_roles_test.go | 187 +++++++++++++++++++++++++++++++ 2 files changed, 230 insertions(+) diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 13aa37335a9fa..c3d276182ef96 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -3860,6 +3860,49 @@ func (a *Server) CreateAccessRequest(ctx context.Context, req types.AccessReques return trace.Wrap(err) } + // Look for user groups and associated applications to the request. + requestedResourceIDs := req.GetRequestedResourceIDs() + var additionalResources []types.ResourceID + + var userGroups []types.ResourceID + existingApps := map[string]struct{}{} + for _, resource := range requestedResourceIDs { + switch resource.Kind { + case types.KindApp: + existingApps[resource.Name] = struct{}{} + case types.KindUserGroup: + userGroups = append(userGroups, resource) + } + } + + for _, resource := range userGroups { + if resource.Kind != types.KindUserGroup { + continue + } + + userGroup, err := a.GetUserGroup(ctx, resource.Name) + if err != nil { + return trace.Wrap(err) + } + + for _, app := range userGroup.GetApplications() { + // Only add to the request if we haven't already added it. + if _, ok := existingApps[app]; !ok { + additionalResources = append(additionalResources, types.ResourceID{ + ClusterName: resource.ClusterName, + Kind: types.KindApp, + Name: app, + }) + existingApps[app] = struct{}{} + } + } + } + + if len(additionalResources) > 0 { + requestedResourceIDs = append(requestedResourceIDs, additionalResources...) + req.SetRequestedResourceIDs(requestedResourceIDs) + } + if req.GetDryRun() { // Made it this far with no errors, return before creating the request // if this is a dry run. diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index baf9ecfac6e31..012c9b2b0a132 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -5519,3 +5519,190 @@ func TestSafeToSkipInventoryCheck(t *testing.T) { safeToSkipInventoryCheck(*semver.New(tc.authVersion), *semver.New(tc.minRequiredVersion))) } } + +func TestCreateAccessRequest(t *testing.T) { + t.Parallel() + ctx := context.Background() + + srv := newTestTLSServer(t) + clock := srv.Clock() + alice, bob, admin := createSessionTestUsers(t, srv.Auth()) + + searchRole, err := types.NewRole("requestRole", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Request: &types.AccessRequestConditions{ + Roles: []string{"requestRole"}, + SearchAsRoles: []string{"requestRole"}, + }, + }, + }) + require.NoError(t, err) + + requestRole, err := types.NewRole("requestRole", types.RoleSpecV6{}) + require.NoError(t, err) + + srv.Auth().CreateRole(ctx, searchRole) + srv.Auth().CreateRole(ctx, requestRole) + + user, err := srv.Auth().GetUser(alice, true) + require.NoError(t, err) + + user.AddRole(searchRole.GetName()) + require.NoError(t, srv.Auth().UpsertUser(user)) + + userGroup1, err := types.NewUserGroup(types.Metadata{ + Name: "user-group1", + }, types.UserGroupSpecV1{ + Applications: []string{"app1", "app2", "app3"}, + }) + require.NoError(t, err) + require.NoError(t, srv.Auth().CreateUserGroup(ctx, userGroup1)) + + userGroup2, err := types.NewUserGroup(types.Metadata{ + Name: "user-group2", + }, types.UserGroupSpecV1{}) + require.NoError(t, err) + require.NoError(t, srv.Auth().CreateUserGroup(ctx, userGroup2)) + + userGroup3, err := types.NewUserGroup(types.Metadata{ + Name: "user-group3", + }, types.UserGroupSpecV1{ + Applications: []string{"app1", "app4", "app5"}, + }) + require.NoError(t, err) + require.NoError(t, srv.Auth().CreateUserGroup(ctx, userGroup3)) + + tests := []struct { + name string + user string + accessRequest types.AccessRequest + errAssertionFunc require.ErrorAssertionFunc + expected types.AccessRequest + }{ + { + name: "user creates own pending access request", + user: alice, + accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), + []string{requestRole.GetName()}, []types.ResourceID{ + mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + }), + errAssertionFunc: require.NoError, + expected: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), + []string{requestRole.GetName()}, []types.ResourceID{ + mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + }), + }, + { + name: "admin creates a request for alice", + user: admin, + accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), + []string{requestRole.GetName()}, []types.ResourceID{ + mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + }), + errAssertionFunc: require.NoError, + expected: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), + []string{requestRole.GetName()}, []types.ResourceID{ + mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + }), + }, + { + name: "bob fails to create a request for alice", + user: bob, + accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), + []string{requestRole.GetName()}, []types.ResourceID{ + mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + }), + errAssertionFunc: require.Error, + }, + { + name: "user creates own pending access request with user group needing app expansion", + user: alice, + accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), + []string{requestRole.GetName()}, []types.ResourceID{ + mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup1.GetName()), + mustResourceID(srv.ClusterName(), types.KindApp, "app1"), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup2.GetName()), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup3.GetName()), + }), + errAssertionFunc: require.NoError, + expected: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), + []string{requestRole.GetName()}, []types.ResourceID{ + mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup1.GetName()), + mustResourceID(srv.ClusterName(), types.KindApp, "app1"), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup2.GetName()), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup3.GetName()), + mustResourceID(srv.ClusterName(), types.KindApp, "app2"), + mustResourceID(srv.ClusterName(), types.KindApp, "app3"), + mustResourceID(srv.ClusterName(), types.KindApp, "app4"), + mustResourceID(srv.ClusterName(), types.KindApp, "app5"), + }), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Make sure there are no access requests before we do anything. We'll clear out + // each time to save on the complexity of setting up the auth server and dependent + // users and roles. + ctx := context.Background() + require.NoError(t, srv.Auth().DeleteAllAccessRequests(ctx)) + + client, err := srv.NewClient(TestUser(test.user)) + require.NoError(t, err) + + test.errAssertionFunc(t, client.CreateAccessRequest(ctx, test.accessRequest)) + + accessRequests, err := srv.Auth().GetAccessRequests(ctx, types.AccessRequestFilter{ + ID: test.accessRequest.GetName(), + }) + require.NoError(t, err) + + if test.expected == nil { + require.Empty(t, accessRequests) + return + } + + require.Len(t, accessRequests, 1) + + // We have to ignore the name here, as it's auto-generated by the underlying access request + // logic. + require.Empty(t, cmp.Diff(test.expected, accessRequests[0], + cmpopts.IgnoreFields(types.Metadata{}, "Name", "ID"), + cmpopts.IgnoreFields(types.AccessRequestSpecV3{}), + )) + }) + } +} + +func mustAccessRequest(t *testing.T, user string, state types.RequestState, created, expires time.Time, roles []string, resourceIDs []types.ResourceID) types.AccessRequest { + t.Helper() + + accessRequest, err := types.NewAccessRequest(uuid.NewString(), user, roles...) + require.NoError(t, err) + + accessRequest.SetRequestedResourceIDs(resourceIDs) + accessRequest.SetState(state) + accessRequest.SetCreationTime(created) + accessRequest.SetExpiry(expires) + accessRequest.SetAccessExpiry(expires) + accessRequest.SetThresholds([]types.AccessReviewThreshold{{Name: "default", Approve: 1, Deny: 1}}) + accessRequest.SetRoleThresholdMapping(map[string]types.ThresholdIndexSets{ + "requestRole": { + Sets: []types.ThresholdIndexSet{ + {Indexes: []uint32{0}}, + }, + }, + }) + + return accessRequest +} + +func mustResourceID(clusterName, kind, name string) types.ResourceID { + return types.ResourceID{ + ClusterName: clusterName, + Kind: kind, + Name: name, + } +}