Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v12] Prevent races creating web api session context #23733

Merged
merged 1 commit into from Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/web/apiserver.go
Expand Up @@ -1871,7 +1871,7 @@ func (h *Handler) renewWebSession(w http.ResponseWriter, r *http.Request, params
if err != nil {
return nil, trace.Wrap(err)
}
newContext, err := h.auth.newSessionContextFromSession(newSession)
newContext, err := h.auth.newSessionContextFromSession(r.Context(), newSession)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -3456,7 +3456,7 @@ func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, ch
logger.WithError(err).Warn("Failed to decode cookie.")
return nil, trace.AccessDenied("failed to decode cookie")
}
ctx, err := h.auth.validateSession(r.Context(), decodedCookie.User, decodedCookie.SID)
ctx, err := h.auth.getOrCreateSession(r.Context(), decodedCookie.User, decodedCookie.SID)
if err != nil {
logger.WithError(err).Warn("Invalid session.")
ClearSession(w)
Expand Down
75 changes: 75 additions & 0 deletions lib/web/apiserver_test.go
Expand Up @@ -8308,3 +8308,78 @@ func (s *fakeKubeService) ListKubernetesResources(ctx context.Context, req *kube
TotalCount: 2,
}, nil
}

// TestSimultaneousAuthenticateRequest ensures that multiple authenticated
// requests do not race to create a SessionContext. This would happen when
// Proxies were deployed behind a round-robin load balancer. Only the Proxy
// that handled the login will have initially created a SessionContext for
// the particular user+session. All subsequent requests to the other Proxies
// in the load balancer pool attempt to create a SessionContext in
// [Handler.AuthenticateRequest] if one didn't already exist. If the web UI
// makes enough requests fast enough it can result in the Proxy trying to
// create multiple SessionContext for a user+session. Since only one SessionContext
// is stored in the sessionCache all previous SessionContext and their underlying
// auth client get closed, which results in an ugly and unfriendly
// `grpc: the client connection is closing` error banner on the web UI.
func TestSimultaneousAuthenticateRequest(t *testing.T) {
ctx := context.Background()
env := newWebPack(t, 1)

proxy := env.proxies[0]

// Authenticate to get a session token and cookies.
pack := proxy.authPack(t, "test-user@example.com", nil)

// Reset the sessions so that all future requests will race to create
// a new SessionContext for the user + session pair to simulate multiple
// proxies behind a load balancer.
proxy.handler.handler.auth.sessions = map[string]*SessionContext{}

// Create a request with the auth header and cookies for the session.
endpoint := pack.clt.Endpoint("webapi", "sites")
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
require.NoError(t, err)

req.Header.Set("Authorization", "Bearer "+pack.session.Token)
for _, cookie := range pack.cookies {
req.AddCookie(cookie)
}

// Spawn several requests in parallel and attempt to use the auth client.
type res struct {
domain string
err error
}
const requests = 10
respC := make(chan res, requests)
for i := 0; i < requests; i++ {
go func() {
sctx, err := proxy.handler.handler.AuthenticateRequest(httptest.NewRecorder(), req.Clone(ctx), false)
if err != nil {
respC <- res{err: err}
return
}

clt, err := sctx.GetClient()
if err != nil {
respC <- res{err: err}
return
}

domain, err := clt.GetDomainName(ctx)
respC <- res{domain: domain, err: err}
}()
}

// Assert that all requests were successful and each one was able to
// get the domain name without its auth client being closed.
for i := 0; i < requests; i++ {
select {
case res := <-respC:
require.NoError(t, res.err)
require.Equal(t, "localhost", res.domain)
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for responses")
}
}
}
55 changes: 34 additions & 21 deletions lib/web/sessions.go
Expand Up @@ -653,6 +653,9 @@ type sessionCache struct {
// sessions maps user/sessionID to an active web session value between renewals.
// This is the client-facing session handle
sessions map[string]*SessionContext
// sessionGroup ensures only a single SessionContext will exist for a
// user+session.
sessionGroup singleflight.Group

// session cache maintains a list of resources per-user as long
// as the user session is active even though individual session values
Expand Down Expand Up @@ -832,17 +835,32 @@ func (s *sessionCache) ValidateTrustedCluster(ctx context.Context, validateReque
return s.proxyClient.ValidateTrustedCluster(ctx, validateRequest)
}

// validateSession validates the session given with user and session ID.
// Returns a new or existing session context.
func (s *sessionCache) validateSession(ctx context.Context, user, sessionID string) (*SessionContext, error) {
sessionCtx, err := s.getContext(user, sessionID)
if err == nil {
return sessionCtx, nil
}
if !trace.IsNotFound(err) {
// getOrCreateSession gets the SessionContext for the user and session ID. If one does
// not exist, then a new one is created.
func (s *sessionCache) getOrCreateSession(ctx context.Context, user, sessionID string) (*SessionContext, error) {
key := sessionKey(user, sessionID)

// Use sessionGroup to prevent multiple requests from racing to create a SessionContext.
i, err, _ := s.sessionGroup.Do(key, func() (any, error) {
sessionCtx, ok := s.getContext(key)
if ok {
return sessionCtx, nil
}

return s.newSessionContext(ctx, user, sessionID)
})

if err != nil {
return nil, trace.Wrap(err)
}
return s.newSessionContext(ctx, user, sessionID)

sctx, ok := i.(*SessionContext)
if !ok {
return nil, trace.BadParameter("expected SessionContext, got %T", i)
}

return sctx, nil

}

func (s *sessionCache) invalidateSession(ctx context.Context, sctx *SessionContext) error {
Expand All @@ -869,15 +887,11 @@ func (s *sessionCache) invalidateSession(ctx context.Context, sctx *SessionConte
return nil
}

func (s *sessionCache) getContext(user, sessionID string) (*SessionContext, error) {
func (s *sessionCache) getContext(key string) (*SessionContext, bool) {
s.mu.Lock()
defer s.mu.Unlock()
ctx, ok := s.sessions[user+sessionID]
if ok {
return ctx, nil
}
return nil, trace.NotFound("no context for user %v and session %v",
user, sessionID)
ctx, ok := s.sessions[key]
return ctx, ok
}

func (s *sessionCache) insertContext(user string, sctx *SessionContext) (exists bool) {
Expand Down Expand Up @@ -959,11 +973,11 @@ func (s *sessionCache) newSessionContext(ctx context.Context, user, sessionID st
// This will fail if the session has expired and was removed
return nil, trace.Wrap(err)
}
return s.newSessionContextFromSession(session)
return s.newSessionContextFromSession(ctx, session)
}

func (s *sessionCache) newSessionContextFromSession(session types.WebSession) (*SessionContext, error) {
tlsConfig, err := s.tlsConfig(session.GetTLSCert(), session.GetPriv())
func (s *sessionCache) newSessionContextFromSession(ctx context.Context, session types.WebSession) (*SessionContext, error) {
tlsConfig, err := s.tlsConfig(ctx, session.GetTLSCert(), session.GetPriv())
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -1002,8 +1016,7 @@ func (s *sessionCache) newSessionContextFromSession(session types.WebSession) (*
return sctx, nil
}

func (s *sessionCache) tlsConfig(cert, privKey []byte) (*tls.Config, error) {
ctx := context.TODO()
func (s *sessionCache) tlsConfig(ctx context.Context, cert, privKey []byte) (*tls.Config, error) {
ca, err := s.proxyClient.GetCertAuthority(ctx, types.CertAuthID{
Type: types.HostCA,
DomainName: s.clusterName,
Expand Down