Skip to content

Commit

Permalink
Merge pull request #3299 from infrahq/dnephin/api-separate-transactio…
Browse files Browse the repository at this point in the history
…n-for-middleware

fix: reduce request transaction lifetime
  • Loading branch information
dnephin committed Sep 27, 2022
2 parents 6f70a4e + ee92dec commit 5406d44
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 22 deletions.
6 changes: 1 addition & 5 deletions internal/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,11 +610,7 @@ func (s Server) loadConfig(config Config) error {
if err != nil {
return err
}
defer func() {
if err := tx.Rollback(); err != nil {
logging.L.Error().Err(err).Msg("failed to rollback database transaction")
}
}()
defer logError(tx.Rollback, "failed to rollback loadConfig transaction")
tx = tx.WithOrgID(org.ID)

if config.DefaultOrganizationDomain != org.Domain {
Expand Down
39 changes: 35 additions & 4 deletions internal/server/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,13 @@ func handleInfraDestinationHeader(c *gin.Context) error {
// gin.Context.
// See validateRequestOrganization for a related function used for unauthenticated
// routes.
func authenticateRequest(c *gin.Context, tx *data.Transaction, srv *Server) error {
func authenticateRequest(c *gin.Context, srv *Server) error {
tx, err := srv.db.Begin(c.Request.Context())
if err != nil {
return err
}
defer logError(tx.Rollback, "failed to rollback middleware transaction")

authned, err := requireAccessKey(c, tx, srv)
if err != nil {
return err
Expand All @@ -86,6 +92,7 @@ func authenticateRequest(c *gin.Context, tx *data.Transaction, srv *Server) erro
return internal.ErrBadRequest
}

// TODO: move to caller
rCtx := access.RequestContext{
Request: c.Request,
DBTxn: tx.WithOrgID(authned.Organization.ID),
Expand All @@ -96,6 +103,10 @@ func authenticateRequest(c *gin.Context, tx *data.Transaction, srv *Server) erro
if err := handleInfraDestinationHeader(c); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
}
rCtx.DBTxn = nil
return nil
}

Expand Down Expand Up @@ -131,7 +142,13 @@ func validateOrgMatchesRequest(req *http.Request, tx data.GormTxn, accessKeyOrg
//
// validateRequestOrganization is also responsible for adding RequestContext to the
// gin.Context.
func validateRequestOrganization(c *gin.Context, tx *data.Transaction, srv *Server) error {
func validateRequestOrganization(c *gin.Context, srv *Server) error {
tx, err := srv.db.Begin(c.Request.Context())
if err != nil {
return err
}
defer logError(tx.Rollback, "failed to rollback middleware transaction")

// ignore errors, access key is not required
authned, _ := requireAccessKey(c, tx, srv)

Expand All @@ -152,12 +169,15 @@ func validateRequestOrganization(c *gin.Context, tx *data.Transaction, srv *Serv
}
if org != nil {
authned.Organization = org
tx = tx.WithOrgID(authned.Organization.ID)
}

if err := tx.Commit(); err != nil {
return err
}

// TODO: move to caller
rCtx := access.RequestContext{
Request: c.Request,
DBTxn: tx,
Authenticated: authned,
}
c.Set(access.RequestContextKey, rCtx)
Expand Down Expand Up @@ -296,3 +316,14 @@ func reqBearerToken(c *gin.Context, opts Options) (string, error) {

return bearer, nil
}

// logError calls fn and writes a log line at the warning level if the error is
// not nil. The log level is a warning because the error is not handled, which
// generally indicates the problem is not a critical error.
// logError accepts a function instead of an error so that it can be used with
// defer.
func logError(fn func() error, msg string) {
if err := fn(); err != nil {
logging.L.Warn().Err(err).Msg(msg)
}
}
31 changes: 18 additions & 13 deletions internal/server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/infrahq/infra/api"
"github.com/infrahq/infra/internal"
"github.com/infrahq/infra/internal/access"
"github.com/infrahq/infra/internal/logging"
"github.com/infrahq/infra/internal/server/redis"
"github.com/infrahq/infra/internal/validate"
Expand Down Expand Up @@ -197,26 +198,18 @@ func wrapRoute[Req, Res any](a *API, routeID routeIdentifier, route route[Req, R
}
}

tx, err := a.server.db.Begin(c.Request.Context())
if err != nil {
return err
}
defer func() {
if err := tx.Rollback(); err != nil {
logging.L.Error().Err(err).Msg("failed to rollback database transaction")
}
}()

var err error
if route.noAuthentication {
err = validateRequestOrganization(c, tx, a.server)
err = validateRequestOrganization(c, a.server)
} else {
err = authenticateRequest(c, tx, a.server)
err = authenticateRequest(c, a.server)
}
if err != nil {
return err
}

org := getRequestContext(c).Authenticated.Organization
rCtx := getRequestContext(c)
org := rCtx.Authenticated.Organization
if !route.noOrgRequired {
if org == nil {
return internal.ErrBadRequest
Expand All @@ -235,6 +228,18 @@ func wrapRoute[Req, Res any](a *API, routeID routeIdentifier, route route[Req, R
return err
}

tx, err := a.server.db.Begin(c.Request.Context())
if err != nil {
return err
}
defer logError(tx.Rollback, "failed to rollback request handler transaction")

if org := rCtx.Authenticated.Organization; org != nil {
tx = tx.WithOrgID(org.ID)
}
rCtx.DBTxn = tx
c.Set(access.RequestContextKey, rCtx)

resp, err := route.handler(c, req)
if err != nil {
return err
Expand Down

0 comments on commit 5406d44

Please sign in to comment.