Skip to content

Commit

Permalink
Add tests + prevent resource leak
Browse files Browse the repository at this point in the history
  • Loading branch information
atburke committed Mar 27, 2024
1 parent c67bfcf commit b89d5cd
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 2 deletions.
40 changes: 38 additions & 2 deletions lib/services/access_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func CalculateAccessCapabilities(ctx context.Context, clock clockwork.Clock, clt
return nil, trace.Wrap(err)
}

if len(req.ResourceIDs) != 0 {
if len(req.ResourceIDs) != 0 && !req.FilterRequestableRolesByResource {
caps.ApplicableRolesForResources, err = v.applicableSearchAsRoles(ctx, req.ResourceIDs, req.Login)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -1414,6 +1414,17 @@ func (m *RequestValidator) truncateTTL(ctx context.Context, identity tlsca.Ident
return ttl, nil
}

// getResourceViewingRoles gets the subset of the user's roles that could be used
// to view resources (i.e. base roles + search as roles).
func (m *RequestValidator) getResourceViewingRoles(ctx context.Context) ([]string, error) {
// No need to filter by resource IDs as that will be done later.
searchAsRoles, err := m.applicableSearchAsRoles(ctx, nil, "")
if err != nil {
return nil, trace.Wrap(err)
}
return apiutils.Deduplicate(slices.Concat(searchAsRoles, m.userState.GetRoles())), nil
}

// GetRequestableRoles gets the list of all existent roles which the user is
// able to request. This operation is expensive since it loads all existent
// roles in order to determine the role list. Prefer calling CanRequestRole
Expand All @@ -1431,6 +1442,31 @@ func (m *RequestValidator) GetRequestableRoles(ctx context.Context, resourceIDs
return nil, trace.Wrap(err)
}

cluster, err := m.getter.GetClusterName()
if err != nil {
return nil, trace.Wrap(err)
}
roles, err := m.getResourceViewingRoles(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
accessChecker, err := NewAccessChecker(&AccessInfo{
Roles: roles,
Traits: m.userState.GetTraits(),
Username: m.userState.GetName(),
}, cluster.GetClusterName(), m.getter)
if err != nil {
return nil, trace.Wrap(err)
}

// Filter out resources the user requested but doesn't have access to.
filteredResources := make([]types.ResourceWithLabels, 0, len(resources))
for _, resource := range resources {
if err := accessChecker.CheckAccess(resource, AccessState{MFAVerified: true}); err == nil {
filteredResources = append(filteredResources, resource)
}
}

var expanded []string
for _, role := range allRoles {
n := role.GetName()
Expand All @@ -1439,7 +1475,7 @@ func (m *RequestValidator) GetRequestableRoles(ctx context.Context, resourceIDs
}

roleAllowsAccess := true
for _, resource := range resources {
for _, resource := range filteredResources {
access, err := m.roleAllowsResource(ctx, role, resource, loginHint)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
150 changes: 150 additions & 0 deletions lib/services/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package services

import (
"context"
"fmt"
"strconv"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -1608,6 +1610,154 @@ func TestPruneRequestRoles(t *testing.T) {
}
}

func TestGetRequestableRoles(t *testing.T) {
t.Parallel()
ctx := context.Background()

clusterName := "my-cluster"

g := &mockGetter{
roles: make(map[string]types.Role),
userStates: make(map[string]*userloginstate.UserLoginState),
nodes: make(map[string]types.Server),
clusterName: clusterName,
}

for i := 0; i < 10; i++ {
node, err := types.NewServerWithLabels(
fmt.Sprintf("node-%d", i),
types.KindNode,
types.ServerSpecV2{},
map[string]string{"index": strconv.Itoa(i)})
require.NoError(t, err)
g.nodes[node.GetName()] = node
}

getResourceID := func(i int) types.ResourceID {
return types.ResourceID{
ClusterName: clusterName,
Kind: types.KindNode,
Name: fmt.Sprintf("node-%d", i),
}
}

roleDesc := map[string]types.RoleSpecV6{
"partial-access": {
Allow: types.RoleConditions{
NodeLabels: types.Labels{
"index": {"0", "1", "2", "3", "4"},
},
Logins: []string{"{{internal.logins}}"},
},
},
"full-access": {
Allow: types.RoleConditions{
NodeLabels: types.Labels{
"index": {"*"},
},
Logins: []string{"{{internal.logins}}"},
},
},
"full-search": {
Allow: types.RoleConditions{
Request: &types.AccessRequestConditions{
Roles: []string{"partial-access", "full-access"},
SearchAsRoles: []string{"full-access"},
},
},
},
"partial-search": {
Allow: types.RoleConditions{
Request: &types.AccessRequestConditions{
Roles: []string{"partial-access", "full-access"},
SearchAsRoles: []string{"partial-access"},
},
},
},
"partial-roles": {
Allow: types.RoleConditions{
Request: &types.AccessRequestConditions{
Roles: []string{"partial-access"},
SearchAsRoles: []string{"full-access"},
},
},
},
}

for name, spec := range roleDesc {
role, err := types.NewRole(name, spec)
require.NoError(t, err)
g.roles[name] = role
}

user := g.user(t)

tests := []struct {
name string
userRole string
requestedResources []types.ResourceID
disableFilter bool
expectedRoles []string
}{
{
name: "no resources to filter by",
userRole: "full-search",
expectedRoles: []string{"partial-access", "full-access"},
},
{
name: "filtering disabled",
userRole: "full-search",
requestedResources: []types.ResourceID{getResourceID(9)},
disableFilter: true,
expectedRoles: []string{"partial-access", "full-access"},
},
{
name: "filter by resources",
userRole: "full-search",
requestedResources: []types.ResourceID{getResourceID(9)},
expectedRoles: []string{"full-access"},
},
{
name: "resource in another cluster",
userRole: "full-search",
requestedResources: []types.ResourceID{
getResourceID(9),
{
ClusterName: "some-other-cluster",
Kind: types.KindNode,
Name: "node-9",
},
},
expectedRoles: []string{"partial-access", "full-access"},
},
{
name: "resource user shouldn't know about",
userRole: "partial-search",
requestedResources: []types.ResourceID{getResourceID(9)},
expectedRoles: []string{"partial-access", "full-access"},
},
{
name: "can view resource but not assume role",
userRole: "partial-roles",
requestedResources: []types.ResourceID{getResourceID(9)},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
g.userStates[user].Spec.Roles = []string{tc.userRole}
accessCaps, err := CalculateAccessCapabilities(ctx, clockwork.NewFakeClock(), g, types.AccessCapabilitiesRequest{
User: user,
RequestableRoles: true,
ResourceIDs: tc.requestedResources,
FilterRequestableRolesByResource: !tc.disableFilter,
})
require.NoError(t, err)
require.ElementsMatch(t, tc.expectedRoles, accessCaps.RequestableRoles)
})
}
}

// TestCalculatePendingRequesTTL verifies that the TTL for the Access Request is capped to the
// request's access expiry or capped to the default const requestTTL, whichever is smaller.
func TestCalculatePendingRequesTTL(t *testing.T) {
Expand Down

0 comments on commit b89d5cd

Please sign in to comment.