Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
atburke committed Apr 2, 2024
1 parent 31d9284 commit 76246a4
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions lib/services/access_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package services

import (
"context"
"fmt"
"slices"
"sort"
"strings"
Expand Down Expand Up @@ -224,6 +225,7 @@ func canFilterRequestableRolesByResource(a RequestValidatorGetter, req types.Acc
// CalculateAccessCapabilities aggregates the requested capabilities using the supplied getter
// to load relevant resources.
func CalculateAccessCapabilities(ctx context.Context, clock clockwork.Clock, clt RequestValidatorGetter, identity tlsca.Identity, req types.AccessCapabilitiesRequest) (*types.AccessCapabilities, error) {
fmt.Printf("CalculateAccessCapabilities %+v\n", req)
canFilter, err := canFilterRequestableRolesByResource(clt, req)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -1416,13 +1418,14 @@ func (m *RequestValidator) truncateTTL(ctx context.Context, identity tlsca.Ident

// getResourceViewingRoles gets the subset of the user's roles that could be used
// to view resources (i.e. base roles + search as roles).
func (m *RequestValidator) getResourceViewingRoles(ctx context.Context) ([]string, error) {
// No need to filter by resource IDs as that will be done later.
searchAsRoles, err := m.applicableSearchAsRoles(ctx, nil, "")
if err != nil {
return nil, trace.Wrap(err)
func (m *RequestValidator) getResourceViewingRoles() []string {
roles := slices.Clone(m.userState.GetRoles())
for _, role := range m.Roles.AllowSearch {
if m.CanSearchAsRole(role) {
roles = append(roles, role)
}
}
return apiutils.Deduplicate(slices.Concat(searchAsRoles, m.userState.GetRoles())), nil
return apiutils.Deduplicate(roles)
}

// GetRequestableRoles gets the list of all existent roles which the user is
Expand All @@ -1446,12 +1449,8 @@ func (m *RequestValidator) GetRequestableRoles(ctx context.Context, identity tls
if err != nil {
return nil, trace.Wrap(err)
}
roles, err := m.getResourceViewingRoles(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
accessChecker, err := NewAccessChecker(&AccessInfo{
Roles: roles,
Roles: m.getResourceViewingRoles(),
Traits: m.userState.GetTraits(),
Username: m.userState.GetName(),
AllowedResourceIDs: identity.AllowedResourceIDs,
Expand Down

0 comments on commit 76246a4

Please sign in to comment.