diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 13aa37335a9fa..c3d276182ef96 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -3860,6 +3860,49 @@ func (a *Server) CreateAccessRequest(ctx context.Context, req types.AccessReques return trace.Wrap(err) } + // Look for user groups and associated applications to the request. + requestedResourceIDs := req.GetRequestedResourceIDs() + var additionalResources []types.ResourceID + + var userGroups []types.ResourceID + existingApps := map[string]struct{}{} + for _, resource := range requestedResourceIDs { + switch resource.Kind { + case types.KindApp: + existingApps[resource.Name] = struct{}{} + case types.KindUserGroup: + userGroups = append(userGroups, resource) + } + } + + for _, resource := range userGroups { + if resource.Kind != types.KindUserGroup { + continue + } + + userGroup, err := a.GetUserGroup(ctx, resource.Name) + if err != nil { + return trace.Wrap(err) + } + + for _, app := range userGroup.GetApplications() { + // Only add to the request if we haven't already added it. + if _, ok := existingApps[app]; !ok { + additionalResources = append(additionalResources, types.ResourceID{ + ClusterName: resource.ClusterName, + Kind: types.KindApp, + Name: app, + }) + existingApps[app] = struct{}{} + } + } + } + + if len(additionalResources) > 0 { + requestedResourceIDs = append(requestedResourceIDs, additionalResources...) + req.SetRequestedResourceIDs(requestedResourceIDs) + } + if req.GetDryRun() { // Made it this far with no errors, return before creating the request // if this is a dry run. diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index baf9ecfac6e31..012c9b2b0a132 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -5519,3 +5519,190 @@ func TestSafeToSkipInventoryCheck(t *testing.T) { safeToSkipInventoryCheck(*semver.New(tc.authVersion), *semver.New(tc.minRequiredVersion))) } } + +func TestCreateAccessRequest(t *testing.T) { + t.Parallel() + ctx := context.Background() + + srv := newTestTLSServer(t) + clock := srv.Clock() + alice, bob, admin := createSessionTestUsers(t, srv.Auth()) + + searchRole, err := types.NewRole("requestRole", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Request: &types.AccessRequestConditions{ + Roles: []string{"requestRole"}, + SearchAsRoles: []string{"requestRole"}, + }, + }, + }) + require.NoError(t, err) + + requestRole, err := types.NewRole("requestRole", types.RoleSpecV6{}) + require.NoError(t, err) + + srv.Auth().CreateRole(ctx, searchRole) + srv.Auth().CreateRole(ctx, requestRole) + + user, err := srv.Auth().GetUser(alice, true) + require.NoError(t, err) + + user.AddRole(searchRole.GetName()) + require.NoError(t, srv.Auth().UpsertUser(user)) + + userGroup1, err := types.NewUserGroup(types.Metadata{ + Name: "user-group1", + }, types.UserGroupSpecV1{ + Applications: []string{"app1", "app2", "app3"}, + }) + require.NoError(t, err) + require.NoError(t, srv.Auth().CreateUserGroup(ctx, userGroup1)) + + userGroup2, err := types.NewUserGroup(types.Metadata{ + Name: "user-group2", + }, types.UserGroupSpecV1{}) + require.NoError(t, err) + require.NoError(t, srv.Auth().CreateUserGroup(ctx, userGroup2)) + + userGroup3, err := types.NewUserGroup(types.Metadata{ + Name: "user-group3", + }, types.UserGroupSpecV1{ + Applications: []string{"app1", "app4", "app5"}, + }) + require.NoError(t, err) + require.NoError(t, srv.Auth().CreateUserGroup(ctx, userGroup3)) + + tests := []struct { + name string + user string + accessRequest types.AccessRequest + errAssertionFunc require.ErrorAssertionFunc + expected types.AccessRequest + }{ + { + name: "user creates own pending access request", + user: alice, + accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), + []string{requestRole.GetName()}, []types.ResourceID{ + mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + }), + errAssertionFunc: require.NoError, + expected: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), + []string{requestRole.GetName()}, []types.ResourceID{ + mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + }), + }, + { + name: "admin creates a request for alice", + user: admin, + accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), + []string{requestRole.GetName()}, []types.ResourceID{ + mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + }), + errAssertionFunc: require.NoError, + expected: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), + []string{requestRole.GetName()}, []types.ResourceID{ + mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + }), + }, + { + name: "bob fails to create a request for alice", + user: bob, + accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), + []string{requestRole.GetName()}, []types.ResourceID{ + mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + }), + errAssertionFunc: require.Error, + }, + { + name: "user creates own pending access request with user group needing app expansion", + user: alice, + accessRequest: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), + []string{requestRole.GetName()}, []types.ResourceID{ + mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup1.GetName()), + mustResourceID(srv.ClusterName(), types.KindApp, "app1"), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup2.GetName()), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup3.GetName()), + }), + errAssertionFunc: require.NoError, + expected: mustAccessRequest(t, alice, types.RequestState_PENDING, clock.Now(), clock.Now().Add(time.Hour), + []string{requestRole.GetName()}, []types.ResourceID{ + mustResourceID(srv.ClusterName(), types.KindRole, requestRole.GetName()), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup1.GetName()), + mustResourceID(srv.ClusterName(), types.KindApp, "app1"), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup2.GetName()), + mustResourceID(srv.ClusterName(), types.KindUserGroup, userGroup3.GetName()), + mustResourceID(srv.ClusterName(), types.KindApp, "app2"), + mustResourceID(srv.ClusterName(), types.KindApp, "app3"), + mustResourceID(srv.ClusterName(), types.KindApp, "app4"), + mustResourceID(srv.ClusterName(), types.KindApp, "app5"), + }), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Make sure there are no access requests before we do anything. We'll clear out + // each time to save on the complexity of setting up the auth server and dependent + // users and roles. + ctx := context.Background() + require.NoError(t, srv.Auth().DeleteAllAccessRequests(ctx)) + + client, err := srv.NewClient(TestUser(test.user)) + require.NoError(t, err) + + test.errAssertionFunc(t, client.CreateAccessRequest(ctx, test.accessRequest)) + + accessRequests, err := srv.Auth().GetAccessRequests(ctx, types.AccessRequestFilter{ + ID: test.accessRequest.GetName(), + }) + require.NoError(t, err) + + if test.expected == nil { + require.Empty(t, accessRequests) + return + } + + require.Len(t, accessRequests, 1) + + // We have to ignore the name here, as it's auto-generated by the underlying access request + // logic. + require.Empty(t, cmp.Diff(test.expected, accessRequests[0], + cmpopts.IgnoreFields(types.Metadata{}, "Name", "ID"), + cmpopts.IgnoreFields(types.AccessRequestSpecV3{}), + )) + }) + } +} + +func mustAccessRequest(t *testing.T, user string, state types.RequestState, created, expires time.Time, roles []string, resourceIDs []types.ResourceID) types.AccessRequest { + t.Helper() + + accessRequest, err := types.NewAccessRequest(uuid.NewString(), user, roles...) + require.NoError(t, err) + + accessRequest.SetRequestedResourceIDs(resourceIDs) + accessRequest.SetState(state) + accessRequest.SetCreationTime(created) + accessRequest.SetExpiry(expires) + accessRequest.SetAccessExpiry(expires) + accessRequest.SetThresholds([]types.AccessReviewThreshold{{Name: "default", Approve: 1, Deny: 1}}) + accessRequest.SetRoleThresholdMapping(map[string]types.ThresholdIndexSets{ + "requestRole": { + Sets: []types.ThresholdIndexSet{ + {Indexes: []uint32{0}}, + }, + }, + }) + + return accessRequest +} + +func mustResourceID(clusterName, kind, name string) types.ResourceID { + return types.ResourceID{ + ClusterName: clusterName, + Kind: kind, + Name: name, + } +}