Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance using session trackers in large clusters #12584

Merged
merged 11 commits into from
May 23, 2022
110 changes: 67 additions & 43 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,71 @@ func (a *ServerWithRoles) CreateSessionTracker(ctx context.Context, tracker type

}

func (a *ServerWithRoles) filterSessionTracker(ctx context.Context, joinerRoles []types.Role, tracker types.SessionTracker) bool {
evaluator := NewSessionAccessEvaluator(tracker.GetHostPolicySets(), tracker.GetSessionKind())
modes := evaluator.CanJoin(SessionAccessContext{Roles: joinerRoles})

if len(modes) == 0 {
return false
}

// Apply RFD 45 RBAC rules to the session if it's SSH.
// This is a bit of a hack. It converts to the old legacy format
// which we don't have all data for, luckily the fields we don't have aren't made available
// to the RBAC filter anyway.
if tracker.GetKind() == types.KindSSHSession {
ruleCtx := &services.Context{User: a.context.User}
ruleCtx.SSHSession = &session.Session{
ID: session.ID(tracker.GetSessionID()),
Namespace: apidefaults.Namespace,
Login: tracker.GetLogin(),
Created: tracker.GetCreated(),
LastActive: a.authServer.GetClock().Now(),
ServerID: tracker.GetAddress(),
ServerAddr: tracker.GetAddress(),
ServerHostname: tracker.GetHostname(),
ClusterName: tracker.GetClusterName(),
}

for _, participant := range tracker.GetParticipants() {
// We only need to fill in User here since other fields get discarded anyway.
ruleCtx.SSHSession.Parties = append(ruleCtx.SSHSession.Parties, session.Party{
User: participant.User,
})
}

// Skip past it if there's a deny rule in place blocking access.
if err := a.context.Checker.CheckAccessToRule(ruleCtx, apidefaults.Namespace, types.KindSSHSession, types.VerbList, true /* silent */); err != nil {
return false
}
}

return true
}

// GetSessionTracker returns the current state of a session tracker for an active session.
func (a *ServerWithRoles) GetSessionTracker(ctx context.Context, sessionID string) (types.SessionTracker, error) {
if err := a.serverAction(); err != nil {
tracker, err := a.authServer.GetSessionTracker(ctx, sessionID)
xacrimon marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, trace.Wrap(err)
}

xacrimon marked this conversation as resolved.
Show resolved Hide resolved
if err := a.serverAction(); err == nil {
return tracker, nil
}

user := a.context.User
joinerRoles, err := services.FetchRoles(user.GetRoles(), a.authServer, user.GetTraits())
if err != nil {
return nil, trace.Wrap(err)
}

return a.authServer.GetSessionTracker(ctx, sessionID)
ok := a.filterSessionTracker(ctx, joinerRoles, tracker)
if !ok {
return nil, trace.NotFound("session %v not found", sessionID)
}

return tracker, nil
}

// GetActiveSessionTrackers returns a list of active session trackers.
Expand All @@ -290,52 +348,18 @@ func (a *ServerWithRoles) GetActiveSessionTrackers(ctx context.Context) ([]types
}

xacrimon marked this conversation as resolved.
Show resolved Hide resolved
var filteredSessions []types.SessionTracker
joinerRoles, err := a.authServer.GetRoles(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
xacrimon marked this conversation as resolved.
Show resolved Hide resolved

for _, sess := range sessions {
evaluator := NewSessionAccessEvaluator(sess.GetHostPolicySets(), sess.GetSessionKind())
joinerRoles, err := a.authServer.GetRoles(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

modes, err := evaluator.CanJoin(SessionAccessContext{Roles: joinerRoles})
if err == nil || len(modes) > 0 {
// Apply RFD 45 RBAC rules to the session if it's SSH.
// This is a bit of a hack. It converts to the old legacy format
// which we don't have all data for, luckily the fields we don't have aren't made available
// to the RBAC filter anyway.
if sess.GetKind() == types.KindSSHSession {
ruleCtx := &services.Context{User: a.context.User}
ruleCtx.SSHSession = &session.Session{
ID: session.ID(sess.GetSessionID()),
Namespace: apidefaults.Namespace,
Login: sess.GetLogin(),
Created: sess.GetCreated(),
LastActive: a.authServer.GetClock().Now(),
ServerID: sess.GetAddress(),
ServerAddr: sess.GetAddress(),
ServerHostname: sess.GetHostname(),
ClusterName: sess.GetClusterName(),
}

for _, participant := range sess.GetParticipants() {
// We only need to fill in User here since other fields get discarded anyway.
ruleCtx.SSHSession.Parties = append(ruleCtx.SSHSession.Parties, session.Party{
User: participant.User,
})
}

// Skip past it if there's a deny rule in place blocking access.
if err := a.context.Checker.CheckAccessToRule(ruleCtx, apidefaults.Namespace, types.KindSSHSession, types.VerbList, true /* silent */); err != nil {
continue
}
}

ok := a.filterSessionTracker(ctx, joinerRoles, sess)
if ok {
filteredSessions = append(filteredSessions, sess)
} else {
log.Warnf("Session %v is not allowed to join: %v", sess.GetSessionID(), err)
}
}

return filteredSessions, nil
}

Expand Down
6 changes: 3 additions & 3 deletions lib/auth/session_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,10 @@ func HasV5Role(roles []types.Role) bool {

// CanJoin returns the modes a user has access to join a session with.
// If the list is empty, the user doesn't have access to join the session at all.
func (e *SessionAccessEvaluator) CanJoin(user SessionAccessContext) ([]types.SessionParticipantMode, error) {
func (e *SessionAccessEvaluator) CanJoin(user SessionAccessContext) []types.SessionParticipantMode {
// If we don't support session access controls, return the default mode set that was supported prior to Moderated Sessions.
if !HasV5Role(user.Roles) {
return preAccessControlsModes(e.kind), nil
return preAccessControlsModes(e.kind)
}

var modes []types.SessionParticipantMode
Expand All @@ -205,7 +205,7 @@ func (e *SessionAccessEvaluator) CanJoin(user SessionAccessContext) ([]types.Ses
}
}

return modes, nil
return modes
}

func SliceContainsMode(s []types.SessionParticipantMode, e types.SessionParticipantMode) bool {
Expand Down
3 changes: 1 addition & 2 deletions lib/auth/session_access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,7 @@ func TestSessionAccessJoin(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
policy := testCase.host.GetSessionPolicySet()
evaluator := NewSessionAccessEvaluator([]*types.SessionTrackerPolicySet{&policy}, testCase.sessionKind)
result, err := evaluator.CanJoin(testCase.participant)
require.NoError(t, err)
result := evaluator.CanJoin(testCase.participant)
require.Equal(t, testCase.expected, len(result) > 0)
})
}
Expand Down
23 changes: 2 additions & 21 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1565,18 +1565,11 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan
return trace.Wrap(err)
}

var session types.SessionTracker
sessions, err := site.GetActiveSessionTrackers(ctx)
if err != nil {
session, err := site.GetSessionTracker(ctx, string(sessionID))
if err != nil && !trace.IsNotFound(err) {
return trace.Wrap(err)
}

for _, sessionIter := range sessions {
if sessionIter.GetSessionID() == string(sessionID) {
session = sessionIter
}
}

if session == nil {
return trace.NotFound(notFoundErrorMessage)
}
Expand Down Expand Up @@ -3595,15 +3588,3 @@ func findActiveDatabases(key *Key) ([]tlsca.RouteToDatabase, error) {
}
return databases, nil
}

// GetActiveSessions fetches a list of all active sessions tracked by the SessionTracker resource
// that the user has access to.
func (tc *TeleportClient) GetActiveSessions(ctx context.Context) ([]types.SessionTracker, error) {
proxy, err := tc.ConnectToProxy(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

defer proxy.Close()
return proxy.GetActiveSessions(ctx)
}
15 changes: 0 additions & 15 deletions lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,6 @@ type NodeClient struct {
OnMFA func()
}

// GetActiveSessions returns a list of active session trackers.
func (proxy *ProxyClient) GetActiveSessions(ctx context.Context) ([]types.SessionTracker, error) {
auth, err := proxy.ConnectToCurrentCluster(ctx, false)
if err != nil {
return nil, trace.Wrap(err)
}
defer auth.Close()
sessions, err := auth.GetActiveSessionTrackers(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

return sessions, nil
}

// GetSites returns list of the "sites" (AKA teleport clusters) connected to the proxy
// Each site is returned as an instance of its auth server
//
Expand Down
6 changes: 1 addition & 5 deletions lib/kube/proxy/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -738,11 +738,7 @@ func (s *session) join(p *party) error {
Roles: roles,
}

modes, err := s.accessEvaluator.CanJoin(accessContext)
if err != nil {
return trace.Wrap(err)
}

modes := s.accessEvaluator.CanJoin(accessContext)
if !auth.SliceContainsMode(modes, p.Mode) {
return trace.AccessDenied("insufficient permissions to join session")
}
Expand Down
1 change: 0 additions & 1 deletion lib/services/local/sessiontracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ func (s *sessionTracker) GetActiveSessionTrackers(ctx context.Context) ([]types.
sessions = append(sessions, session)
case !after && item.Expires.IsZero():
// Clear item if expiry is not set on the backend.
// We currently don't set the expiry here but we will when #11551 is merged.
noExpiry = append(noExpiry, item)
default:
// If we take this branch, the expiry is set and the backend is responsible for cleaning up the item.
Expand Down
6 changes: 1 addition & 5 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -1509,11 +1509,7 @@ func (s *session) join(ch ssh.Channel, ctx *ServerContext, mode types.SessionPar
Roles: roles,
}

modes, err := s.access.CanJoin(accessContext)
if err != nil {
return nil, trace.Wrap(err)
}

modes := s.access.CanJoin(accessContext)
if !auth.SliceContainsMode(modes, mode) {
return nil, trace.AccessDenied("insufficient permissions to join session %v", s.id)
}
Expand Down
23 changes: 16 additions & 7 deletions tool/tsh/kube.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,18 +105,17 @@ func newKubeJoinCommand(parent *kingpin.CmdClause) *kubeJoinCommand {
}

func (c *kubeJoinCommand) getSessionMeta(ctx context.Context, tc *client.TeleportClient) (types.SessionTracker, error) {
sessions, err := tc.GetActiveSessions(ctx)
proxy, err := tc.ConnectToProxy(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

for _, session := range sessions {
if session.GetSessionID() == c.session {
return session, nil
}
site, err := proxy.ConnectToCurrentCluster(ctx, false)
if err != nil {
return nil, trace.Wrap(err)
}

return nil, trace.NotFound("session %q not found", c.session)
return site.GetSessionTracker(ctx, c.session)
}

func (c *kubeJoinCommand) run(cf *CLIConf) error {
Expand Down Expand Up @@ -489,7 +488,17 @@ func (c *kubeSessionsCommand) run(cf *CLIConf) error {
return trace.Wrap(err)
}

sessions, err := tc.GetActiveSessions(cf.Context)
proxy, err := tc.ConnectToProxy(cf.Context)
if err != nil {
return trace.Wrap(err)
}

site, err := proxy.ConnectToCurrentCluster(cf.Context, true)
if err != nil {
return trace.Wrap(err)
}

sessions, err := site.GetActiveSessionTrackers(cf.Context)
if err != nil {
return trace.Wrap(err)
}
Expand Down