Skip to content

Commit

Permalink
Prevent races creating web api session context (#23691) (#23733)
Browse files Browse the repository at this point in the history
Closes #23533 and #20963.

There was a race to create the `web.SessionContext` for a session
when multiple Proxies are behind a load balancer. Only the Proxy
that processes the login will have a `web.SessionContext` created
for the session. Any subsequent requests to the other Proxies in
the pool would create one if the request was authenticated. However,
multiple requests in quick succession could cause a
single Proxy to create multiple `web.SessionContext` for a single
session. When that happens the most recently created `web.SessionContext`
gets saved and the previous `web.SessionContext` gets closed. Closing
causes the `auth.Client` to be closed, which causes any active requests
for that client to return with a `grpc: the client connection is closing`
error. This manifests as a single request from the web UI failing
and, depending on the request, a banner being displayed with the
error. Refreshing the page or navigating to another page would
resolve the problem because the most recent `web.SessionContext`
would be used with the still open `auth.Client`.

To prevent `web.Handler.AuthenticateRequest` from racing to create
the `web.SessionContext`, a `singleflight.Group` was added to the
`web.sessionCache`. When multiple requests come in for the same
session they now will all use the first `web.SessionContext` to
be created instead of each creating their own.
  • Loading branch information
rosstimothy committed Mar 28, 2023
1 parent b93bc13 commit bc40839
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 23 deletions.
4 changes: 2 additions & 2 deletions lib/web/apiserver.go
Expand Up @@ -1871,7 +1871,7 @@ func (h *Handler) renewWebSession(w http.ResponseWriter, r *http.Request, params
if err != nil {
return nil, trace.Wrap(err)
}
newContext, err := h.auth.newSessionContextFromSession(newSession)
newContext, err := h.auth.newSessionContextFromSession(r.Context(), newSession)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -3456,7 +3456,7 @@ func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, ch
logger.WithError(err).Warn("Failed to decode cookie.")
return nil, trace.AccessDenied("failed to decode cookie")
}
ctx, err := h.auth.validateSession(r.Context(), decodedCookie.User, decodedCookie.SID)
ctx, err := h.auth.getOrCreateSession(r.Context(), decodedCookie.User, decodedCookie.SID)
if err != nil {
logger.WithError(err).Warn("Invalid session.")
ClearSession(w)
Expand Down
75 changes: 75 additions & 0 deletions lib/web/apiserver_test.go
Expand Up @@ -8308,3 +8308,78 @@ func (s *fakeKubeService) ListKubernetesResources(ctx context.Context, req *kube
TotalCount: 2,
}, nil
}

// TestSimultaneousAuthenticateRequest ensures that multiple authenticated
// requests do not race to create a SessionContext. This would happen when
// Proxies were deployed behind a round-robin load balancer. Only the Proxy
// that handled the login will have initially created a SessionContext for
// the particular user+session. All subsequent requests to the other Proxies
// in the load balancer pool attempt to create a SessionContext in
// [Handler.AuthenticateRequest] if one didn't already exist. If the web UI
// makes enough requests fast enough it can result in the Proxy trying to
// create multiple SessionContext for a user+session. Since only one SessionContext
// is stored in the sessionCache all previous SessionContext and their underlying
// auth client get closed, which results in an ugly and unfriendly
// `grpc: the client connection is closing` error banner on the web UI.
//
// The test logs in once, empties the proxy's session map, and then issues
// several concurrent AuthenticateRequest calls for the same session. Each
// call must succeed and return a SessionContext whose client can still
// fetch the domain name.
func TestSimultaneousAuthenticateRequest(t *testing.T) {
ctx := context.Background()
// NOTE(review): presumably the second argument is the number of proxies
// to start; only env.proxies[0] is used below — confirm against newWebPack.
env := newWebPack(t, 1)

proxy := env.proxies[0]

// Authenticate to get a session token and cookies.
pack := proxy.authPack(t, "test-user@example.com", nil)

// Reset the sessions so that all future requests will race to create
// a new SessionContext for the user + session pair to simulate multiple
// proxies behind a load balancer.
proxy.handler.handler.auth.sessions = map[string]*SessionContext{}

// Create a request with the auth header and cookies for the session.
endpoint := pack.clt.Endpoint("webapi", "sites")
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
require.NoError(t, err)

req.Header.Set("Authorization", "Bearer "+pack.session.Token)
for _, cookie := range pack.cookies {
req.AddCookie(cookie)
}

// Spawn several requests in parallel and attempt to use the auth client.
// res carries the outcome of a single goroutine back to the test goroutine.
type res struct {
domain string
err error
}
const requests = 10
// The channel is buffered to hold one result per goroutine, so a sender
// never blocks even if the receiving loop below stops early.
respC := make(chan res, requests)
for i := 0; i < requests; i++ {
go func() {
// Each goroutine authenticates with its own clone of the prepared
// request and a fresh response recorder.
sctx, err := proxy.handler.handler.AuthenticateRequest(httptest.NewRecorder(), req.Clone(ctx), false)
if err != nil {
respC <- res{err: err}
return
}

clt, err := sctx.GetClient()
if err != nil {
respC <- res{err: err}
return
}

// Use the returned client; the result (including any error) is
// reported to the test goroutine for assertion.
domain, err := clt.GetDomainName(ctx)
respC <- res{domain: domain, err: err}
}()
}

// Assert that all requests were successful and each one was able to
// get the domain name without its auth client being closed.
// Assertions run here on the test goroutine, not in the spawned ones.
// The 5 second timeout is re-armed on every iteration, so it bounds the
// wait for each individual response rather than the loop as a whole.
for i := 0; i < requests; i++ {
select {
// The variable res shadows the res type inside this case.
case res := <-respC:
require.NoError(t, res.err)
// NOTE(review): "localhost" is presumably the cluster name configured
// by the test environment — confirm against newWebPack.
require.Equal(t, "localhost", res.domain)
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for responses")
}
}
}
55 changes: 34 additions & 21 deletions lib/web/sessions.go
Expand Up @@ -653,6 +653,9 @@ type sessionCache struct {
// sessions maps user/sessionID to an active web session value between renewals.
// This is the client-facing session handle
sessions map[string]*SessionContext
// sessionGroup ensures only a single SessionContext will exist for a
// user+session.
sessionGroup singleflight.Group

// session cache maintains a list of resources per-user as long
// as the user session is active even though individual session values
Expand Down Expand Up @@ -832,17 +835,32 @@ func (s *sessionCache) ValidateTrustedCluster(ctx context.Context, validateReque
return s.proxyClient.ValidateTrustedCluster(ctx, validateRequest)
}

// validateSession validates the session given with user and session ID.
// Returns a new or existing session context.
func (s *sessionCache) validateSession(ctx context.Context, user, sessionID string) (*SessionContext, error) {
sessionCtx, err := s.getContext(user, sessionID)
if err == nil {
return sessionCtx, nil
}
if !trace.IsNotFound(err) {
// getOrCreateSession gets the SessionContext for the user and session ID. If one does
// not exist, then a new one is created.
func (s *sessionCache) getOrCreateSession(ctx context.Context, user, sessionID string) (*SessionContext, error) {
key := sessionKey(user, sessionID)

// Use sessionGroup to prevent multiple requests from racing to create a SessionContext.
i, err, _ := s.sessionGroup.Do(key, func() (any, error) {
sessionCtx, ok := s.getContext(key)
if ok {
return sessionCtx, nil
}

return s.newSessionContext(ctx, user, sessionID)
})

if err != nil {
return nil, trace.Wrap(err)
}
return s.newSessionContext(ctx, user, sessionID)

sctx, ok := i.(*SessionContext)
if !ok {
return nil, trace.BadParameter("expected SessionContext, got %T", i)
}

return sctx, nil

}

func (s *sessionCache) invalidateSession(ctx context.Context, sctx *SessionContext) error {
Expand All @@ -869,15 +887,11 @@ func (s *sessionCache) invalidateSession(ctx context.Context, sctx *SessionConte
return nil
}

func (s *sessionCache) getContext(user, sessionID string) (*SessionContext, error) {
func (s *sessionCache) getContext(key string) (*SessionContext, bool) {
s.mu.Lock()
defer s.mu.Unlock()
ctx, ok := s.sessions[user+sessionID]
if ok {
return ctx, nil
}
return nil, trace.NotFound("no context for user %v and session %v",
user, sessionID)
ctx, ok := s.sessions[key]
return ctx, ok
}

func (s *sessionCache) insertContext(user string, sctx *SessionContext) (exists bool) {
Expand Down Expand Up @@ -959,11 +973,11 @@ func (s *sessionCache) newSessionContext(ctx context.Context, user, sessionID st
// This will fail if the session has expired and was removed
return nil, trace.Wrap(err)
}
return s.newSessionContextFromSession(session)
return s.newSessionContextFromSession(ctx, session)
}

func (s *sessionCache) newSessionContextFromSession(session types.WebSession) (*SessionContext, error) {
tlsConfig, err := s.tlsConfig(session.GetTLSCert(), session.GetPriv())
func (s *sessionCache) newSessionContextFromSession(ctx context.Context, session types.WebSession) (*SessionContext, error) {
tlsConfig, err := s.tlsConfig(ctx, session.GetTLSCert(), session.GetPriv())
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -1002,8 +1016,7 @@ func (s *sessionCache) newSessionContextFromSession(session types.WebSession) (*
return sctx, nil
}

func (s *sessionCache) tlsConfig(cert, privKey []byte) (*tls.Config, error) {
ctx := context.TODO()
func (s *sessionCache) tlsConfig(ctx context.Context, cert, privKey []byte) (*tls.Config, error) {
ca, err := s.proxyClient.GetCertAuthority(ctx, types.CertAuthID{
Type: types.HostCA,
DomainName: s.clusterName,
Expand Down

0 comments on commit bc40839

Please sign in to comment.