Skip to content

Commit

Permalink
fix: reduce request transaction lifetime
Browse files Browse the repository at this point in the history
PostgreSQL row and table locks are held until the end of the
transaction.

Previously we were using a single transaction for the entire lifetime of
the request. One of the first things we did with that transaction was
update the access key and user record associated with the request.

That update would block any other requests from that user until the
first request finished. For short lived requests that was fine, but
longer requests (anything making HTTP requests to an IDP, or blocking
and waiting for updates) would prevent the user from making concurrent
requests.

This commit fixes the problem by committing the "middleware" transaction
first, then starting a new transaction for the request handler.
The middleware transaction should be short lived, because does only a
few database operations and returns. It doesn't block or make external
requests.
  • Loading branch information
dnephin committed Sep 23, 2022
1 parent 8761e63 commit b697a95
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 22 deletions.
6 changes: 1 addition & 5 deletions internal/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,11 +610,7 @@ func (s Server) loadConfig(config Config) error {
if err != nil {
return err
}
defer func() {
if err := tx.Rollback(); err != nil {
logging.L.Error().Err(err).Msg("failed to rollback database transaction")
}
}()
defer logError(tx.Rollback, "failed to rollback loadConfig transaction")
tx = tx.WithOrgID(org.ID)

if config.DefaultOrganizationDomain != org.Domain {
Expand Down
39 changes: 35 additions & 4 deletions internal/server/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,13 @@ func handleInfraDestinationHeader(c *gin.Context) error {
// gin.Context.
// See validateRequestOrganization for a related function used for unauthenticated
// routes.
func authenticateRequest(c *gin.Context, tx *data.Transaction, srv *Server) error {
func authenticateRequest(c *gin.Context, srv *Server) error {
tx, err := srv.db.Begin(c.Request.Context())
if err != nil {
return err
}
defer logError(tx.Rollback, "failed to rollback middleware transaction")

authned, err := requireAccessKey(c, tx, srv)
if err != nil {
return err
Expand All @@ -86,6 +92,7 @@ func authenticateRequest(c *gin.Context, tx *data.Transaction, srv *Server) erro
return internal.ErrBadRequest
}

// TODO: move to caller
rCtx := access.RequestContext{
Request: c.Request,
DBTxn: tx.WithOrgID(authned.Organization.ID),
Expand All @@ -96,6 +103,10 @@ func authenticateRequest(c *gin.Context, tx *data.Transaction, srv *Server) erro
if err := handleInfraDestinationHeader(c); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
rCtx.DBTxn = nil
return nil
}

Expand Down Expand Up @@ -131,7 +142,13 @@ func validateOrgMatchesRequest(req *http.Request, tx data.GormTxn, accessKeyOrg
//
// validateRequestOrganization is also responsible for adding RequestContext to the
// gin.Context.
func validateRequestOrganization(c *gin.Context, tx *data.Transaction, srv *Server) error {
func validateRequestOrganization(c *gin.Context, srv *Server) error {
tx, err := srv.db.Begin(c.Request.Context())
if err != nil {
return err
}
defer logError(tx.Rollback, "failed to rollback middleware transaction")

// ignore errors, access key is not required
authned, _ := requireAccessKey(c, tx, srv)

Expand All @@ -152,12 +169,15 @@ func validateRequestOrganization(c *gin.Context, tx *data.Transaction, srv *Serv
}
if org != nil {
authned.Organization = org
tx = tx.WithOrgID(authned.Organization.ID)
}

if err := tx.Commit(); err != nil {
return err
}

// TODO: move to caller
rCtx := access.RequestContext{
Request: c.Request,
DBTxn: tx,
Authenticated: authned,
}
c.Set(access.RequestContextKey, rCtx)
Expand Down Expand Up @@ -296,3 +316,14 @@ func reqBearerToken(c *gin.Context, opts Options) (string, error) {

return bearer, nil
}

// logError calls fn and writes a log line at the warning level if the error is
// not nil. The log level is a warning because the error is not handled, which
// generally indicates the problem is not a critical error.
// logError accepts a function instead of an error so that it can be used with
// defer.
func logError(fn func() error, msg string) {
if err := fn(); err != nil {
logging.L.Warn().Err(err).Msg(msg)
}
}
31 changes: 18 additions & 13 deletions internal/server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/infrahq/infra/api"
"github.com/infrahq/infra/internal"
"github.com/infrahq/infra/internal/access"
"github.com/infrahq/infra/internal/logging"
"github.com/infrahq/infra/internal/validate"
"github.com/infrahq/infra/metrics"
Expand Down Expand Up @@ -196,27 +197,19 @@ func wrapRoute[Req, Res any](a *API, routeID routeIdentifier, route route[Req, R
}
}

tx, err := a.server.db.Begin(c.Request.Context())
if err != nil {
return err
}
defer func() {
if err := tx.Rollback(); err != nil {
logging.L.Error().Err(err).Msg("failed to rollback database transaction")
}
}()

var err error
if route.noAuthentication {
err = validateRequestOrganization(c, tx, a.server)
err = validateRequestOrganization(c, a.server)
} else {
err = authenticateRequest(c, tx, a.server)
err = authenticateRequest(c, a.server)
}
if err != nil {
return err
}

rCtx := getRequestContext(c)
if !route.noOrgRequired {
if org := getRequestContext(c).Authenticated.Organization; org == nil {
if org := rCtx.Authenticated.Organization; org == nil {
return internal.ErrBadRequest
}
}
Expand All @@ -226,6 +219,18 @@ func wrapRoute[Req, Res any](a *API, routeID routeIdentifier, route route[Req, R
return err
}

tx, err := a.server.db.Begin(c.Request.Context())
if err != nil {
return err
}
defer logError(tx.Rollback, "failed to rollback request handler transaction")

if org := rCtx.Authenticated.Organization; org != nil {
tx = tx.WithOrgID(org.ID)
}
rCtx.DBTxn = tx
c.Set(access.RequestContextKey, rCtx)

resp, err := route.handler(c, req)
if err != nil {
return err
Expand Down

0 comments on commit b697a95

Please sign in to comment.