Skip to content

Commit

Permalink
Fix listing of participant modes in UI
Browse files Browse the repository at this point in the history
  • Loading branch information
rudream committed Apr 4, 2023
1 parent d075222 commit 9dba7ae
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 8 deletions.
2 changes: 1 addition & 1 deletion lib/auth/session_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func (e *SessionAccessEvaluator) matchesJoin(allow *types.SessionJoinPolicy) boo

for _, allowRole := range allow.Roles {
// GlobToRegexp makes sure this is always a valid regexp.
expr := regexp.MustCompile(utils.GlobToRegexp(allowRole))
expr := regexp.MustCompile("^" + utils.GlobToRegexp(allowRole) + "$")

for _, policySet := range e.policySets {
if expr.MatchString(policySet.Name) {
Expand Down
30 changes: 30 additions & 0 deletions lib/auth/session_access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,35 @@ func failKindJoinTestCase(t *testing.T) joinTestCase {
}
}

// Tests to make sure that the regexp matching for roles only matches a full string
// match and not just any substring match.
// In this test case, we are making sure that having access to sessions hosted
// by someone with the role `test` doesn't also grant you access to sessions
// hosted by someone with the role `prod-test`.
func failJoinRoleNameInSubstringTestCase(t *testing.T) joinTestCase {
hostRole, err := types.NewRole("prod-test", types.RoleSpecV5{})
require.NoError(t, err)
participantRole, err := types.NewRole("participant", types.RoleSpecV5{})
require.NoError(t, err)

participantRole.SetSessionJoinPolicies([]*types.SessionJoinPolicy{{
Roles: []string{"test"},
Kinds: []string{string(types.SSHSessionKind), string(types.KubernetesSessionKind)},
Modes: []string{types.Wildcard},
}})

return joinTestCase{
name: "failRoleInSubstring",
host: hostRole,
sessionKinds: []types.SessionKind{types.SSHSessionKind, types.KubernetesSessionKind},
participant: SessionAccessContext{
Username: "participant",
Roles: []types.Role{participantRole},
},
expected: []bool{false, false},
}
}

func versionDefaultJoinTestCase(t *testing.T) joinTestCase {
hostRole, err := types.NewRole("host", types.RoleSpecV5{})
require.NoError(t, err)
Expand Down Expand Up @@ -486,6 +515,7 @@ func TestSessionAccessJoin(t *testing.T) {
successSameUserJoinTestCase(t),
failRoleJoinTestCase(t),
failKindJoinTestCase(t),
failJoinRoleNameInSubstringTestCase(t),
versionDefaultJoinTestCase(t),
}

Expand Down
8 changes: 1 addition & 7 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2801,12 +2801,6 @@ func (h *Handler) siteSessionsGet(w http.ResponseWriter, r *http.Request, p http
return nil, trace.Wrap(err)
}

var policySets []*types.SessionTrackerPolicySet
for _, role := range userRoles {
policySet := role.GetSessionPolicySet()
policySets = append(policySets, &policySet)
}

accessContext := auth.SessionAccessContext{
Username: sctx.GetUser(),
Roles: userRoles,
Expand All @@ -2817,7 +2811,7 @@ func (h *Handler) siteSessionsGet(w http.ResponseWriter, r *http.Request, p http
if tracker.GetState() != types.SessionState_SessionStateTerminated {
session := trackerToLegacySession(tracker, p.ByName("site"))
// Get the participant modes available to the user from their roles.
accessEvaluator := auth.NewSessionAccessEvaluator(policySets, types.SSHSessionKind, session.Owner)
accessEvaluator := auth.NewSessionAccessEvaluator(tracker.GetHostPolicySets(), types.SSHSessionKind, tracker.GetHostUser())
participantModes := accessEvaluator.CanJoin(accessContext)

sessions = append(sessions, siteSessionsGetResponseSession{Session: session, ParticipantModes: participantModes})
Expand Down
14 changes: 14 additions & 0 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1998,6 +1998,9 @@ func TestActiveSessions(t *testing.T) {
s := newWebSuite(t)
pack := s.authPack(t, "foo")

// Use enterprise license (required for moderated sessions).
modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise})

start := time.Now()
kinds := []types.SessionKind{
types.SSHSessionKind,
Expand Down Expand Up @@ -2025,6 +2028,17 @@ func TestActiveSessions(t *testing.T) {
Participants: []types.Participant{
{ID: "id", User: "user-1", LastActive: start},
},
HostPolicies: []*types.SessionTrackerPolicySet{
{
Name: "foo",
Version: "5",
RequireSessionJoin: []*types.SessionRequirePolicy{
{
Name: "foo",
},
},
},
},
})
require.NoError(t, err)
ids[tracker.GetSessionID()] = struct{}{}
Expand Down

0 comments on commit 9dba7ae

Please sign in to comment.