From bc4083985bfaaa90bdd36aba36202c6b00c7fcd6 Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Tue, 28 Mar 2023 19:10:56 -0400 Subject: [PATCH] Prevent races creating web api session context (#23691) (#23733) Closes #23533 and #20963. There was a race to create the `web.SessionContext` for a session when multiple Proxies are behind a load balancer. Only the Proxy that processes the login will have a `web.SessionContext` created for the session. Any subsequent requests to the other Proxies in the pool would create one if the request was authenticated. However, multiple requests in quick succession could cause a single Proxy to create multiple `web.SessionContext` for a single session. When that happens the most recently created `web.SessionContext` gets saved and the previous `web.SessionContext` gets closed. Closing causes the `auth.Client` to be closed, which causes any active requests for that client to return with a `grpc: the client connection is closing` error. This manifests as a single request from the web UI failing and, depending on the request, a banner being displayed with the error. Refreshing the page or navigating to another page would resolve the problem because the most recent `web.SessionContext` would be used with the still open `auth.Client`. To prevent `web.Handler.AuthenticateRequest` from racing to create the `web.SessionContext` a `singleflight.Group` was added to the `web.sessionCache`. When multiple requests come in for the same session they will now all use the first `web.SessionContext` to be created instead of each creating their own. 
--- lib/web/apiserver.go | 4 +-- lib/web/apiserver_test.go | 75 +++++++++++++++++++++++++++++++++++++++ lib/web/sessions.go | 55 +++++++++++++++++----------- 3 files changed, 111 insertions(+), 23 deletions(-) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 510b70404a555..710c9d94943c3 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -1871,7 +1871,7 @@ func (h *Handler) renewWebSession(w http.ResponseWriter, r *http.Request, params if err != nil { return nil, trace.Wrap(err) } - newContext, err := h.auth.newSessionContextFromSession(newSession) + newContext, err := h.auth.newSessionContextFromSession(r.Context(), newSession) if err != nil { return nil, trace.Wrap(err) } @@ -3456,7 +3456,7 @@ func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, ch logger.WithError(err).Warn("Failed to decode cookie.") return nil, trace.AccessDenied("failed to decode cookie") } - ctx, err := h.auth.validateSession(r.Context(), decodedCookie.User, decodedCookie.SID) + ctx, err := h.auth.getOrCreateSession(r.Context(), decodedCookie.User, decodedCookie.SID) if err != nil { logger.WithError(err).Warn("Invalid session.") ClearSession(w) diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 4fa9f86768036..329293df9c8c2 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -8308,3 +8308,78 @@ func (s *fakeKubeService) ListKubernetesResources(ctx context.Context, req *kube TotalCount: 2, }, nil } + +// TestSimultaneousAuthenticateRequest ensures that multiple authenticated +// requests do not race to create a SessionContext. This would happen when +// Proxies were deployed behind a round-robin load balancer. Only the Proxy +// that handled the login will have initially created a SessionContext for +// the particular user+session. 
All subsequent requests to the other Proxies +// in the load balancer pool attempt to create a SessionContext in +// [Handler.AuthenticateRequest] if one didn't already exist. If the web UI +// makes enough requests fast enough it can result in the Proxy trying to +// create multiple SessionContext for a user+session. Since only one SessionContext +// is stored in the sessionCache all previous SessionContext and their underlying +// auth client get closed, which results in an ugly and unfriendly +// `grpc: the client connection is closing` error banner on the web UI. +func TestSimultaneousAuthenticateRequest(t *testing.T) { + ctx := context.Background() + env := newWebPack(t, 1) + + proxy := env.proxies[0] + + // Authenticate to get a session token and cookies. + pack := proxy.authPack(t, "test-user@example.com", nil) + + // Reset the sessions so that all future requests will race to create + // a new SessionContext for the user + session pair to simulate multiple + // proxies behind a load balancer. + proxy.handler.handler.auth.sessions = map[string]*SessionContext{} + + // Create a request with the auth header and cookies for the session. + endpoint := pack.clt.Endpoint("webapi", "sites") + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + require.NoError(t, err) + + req.Header.Set("Authorization", "Bearer "+pack.session.Token) + for _, cookie := range pack.cookies { + req.AddCookie(cookie) + } + + // Spawn several requests in parallel and attempt to use the auth client. 
+ type res struct { + domain string + err error + } + const requests = 10 + respC := make(chan res, requests) + for i := 0; i < requests; i++ { + go func() { + sctx, err := proxy.handler.handler.AuthenticateRequest(httptest.NewRecorder(), req.Clone(ctx), false) + if err != nil { + respC <- res{err: err} + return + } + + clt, err := sctx.GetClient() + if err != nil { + respC <- res{err: err} + return + } + + domain, err := clt.GetDomainName(ctx) + respC <- res{domain: domain, err: err} + }() + } + + // Assert that all requests were successful and each one was able to + // get the domain name without its auth client being closed. + for i := 0; i < requests; i++ { + select { + case res := <-respC: + require.NoError(t, res.err) + require.Equal(t, "localhost", res.domain) + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for responses") + } + } +} diff --git a/lib/web/sessions.go b/lib/web/sessions.go index d4801699638d8..f7db29b16b431 100644 --- a/lib/web/sessions.go +++ b/lib/web/sessions.go @@ -653,6 +653,9 @@ type sessionCache struct { // sessions maps user/sessionID to an active web session value between renewals. // This is the client-facing session handle sessions map[string]*SessionContext + // sessionGroup ensures only a single SessionContext will exist for a + // user+session. + sessionGroup singleflight.Group // session cache maintains a list of resources per-user as long // as the user session is active even though individual session values @@ -832,17 +835,32 @@ func (s *sessionCache) ValidateTrustedCluster(ctx context.Context, validateReque return s.proxyClient.ValidateTrustedCluster(ctx, validateRequest) } -// validateSession validates the session given with user and session ID. -// Returns a new or existing session context. 
-func (s *sessionCache) validateSession(ctx context.Context, user, sessionID string) (*SessionContext, error) { - sessionCtx, err := s.getContext(user, sessionID) - if err == nil { - return sessionCtx, nil - } - if !trace.IsNotFound(err) { +// getOrCreateSession gets the SessionContext for the user and session ID. If one does +// not exist, then a new one is created. +func (s *sessionCache) getOrCreateSession(ctx context.Context, user, sessionID string) (*SessionContext, error) { + key := sessionKey(user, sessionID) + + // Use sessionGroup to prevent multiple requests from racing to create a SessionContext. + i, err, _ := s.sessionGroup.Do(key, func() (any, error) { + sessionCtx, ok := s.getContext(key) + if ok { + return sessionCtx, nil + } + + return s.newSessionContext(ctx, user, sessionID) + }) + + if err != nil { return nil, trace.Wrap(err) } - return s.newSessionContext(ctx, user, sessionID) + + sctx, ok := i.(*SessionContext) + if !ok { + return nil, trace.BadParameter("expected SessionContext, got %T", i) + } + + return sctx, nil + } func (s *sessionCache) invalidateSession(ctx context.Context, sctx *SessionContext) error { @@ -869,15 +887,11 @@ func (s *sessionCache) invalidateSession(ctx context.Context, sctx *SessionConte return nil } -func (s *sessionCache) getContext(user, sessionID string) (*SessionContext, error) { +func (s *sessionCache) getContext(key string) (*SessionContext, bool) { s.mu.Lock() defer s.mu.Unlock() - ctx, ok := s.sessions[user+sessionID] - if ok { - return ctx, nil - } - return nil, trace.NotFound("no context for user %v and session %v", - user, sessionID) + ctx, ok := s.sessions[key] + return ctx, ok } func (s *sessionCache) insertContext(user string, sctx *SessionContext) (exists bool) { @@ -959,11 +973,11 @@ func (s *sessionCache) newSessionContext(ctx context.Context, user, sessionID st // This will fail if the session has expired and was removed return nil, trace.Wrap(err) } - return s.newSessionContextFromSession(session) + 
return s.newSessionContextFromSession(ctx, session) } -func (s *sessionCache) newSessionContextFromSession(session types.WebSession) (*SessionContext, error) { - tlsConfig, err := s.tlsConfig(session.GetTLSCert(), session.GetPriv()) +func (s *sessionCache) newSessionContextFromSession(ctx context.Context, session types.WebSession) (*SessionContext, error) { + tlsConfig, err := s.tlsConfig(ctx, session.GetTLSCert(), session.GetPriv()) if err != nil { return nil, trace.Wrap(err) } @@ -1002,8 +1016,7 @@ func (s *sessionCache) newSessionContextFromSession(session types.WebSession) (* return sctx, nil } -func (s *sessionCache) tlsConfig(cert, privKey []byte) (*tls.Config, error) { - ctx := context.TODO() +func (s *sessionCache) tlsConfig(ctx context.Context, cert, privKey []byte) (*tls.Config, error) { ca, err := s.proxyClient.GetCertAuthority(ctx, types.CertAuthID{ Type: types.HostCA, DomainName: s.clusterName,