Skip to content

Commit

Permalink
feat(auth/oidc): Use BadRequest response for state mismatches, enhanc…
Browse files Browse the repository at this point in the history
…e logging with context (#403)

* feat(core/auth/oidc): Use BadRequest response for state mismatches, enhance logging with context

* lint: avoid dynamic errors
  • Loading branch information
carstendietrich committed May 2, 2024
1 parent 939d802 commit 85bcd6f
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 25 deletions.
30 changes: 18 additions & 12 deletions core/auth/oauth/oidc.go
Expand Up @@ -109,6 +109,11 @@ var (

return ok
}

errNoStateInRequest = errors.New("no state in request")
errStateMismatch = errors.New("state mismatch")
errNoIDTokenClaim = errors.New("claim id_token missing")
errGeneric = errors.New("OpenID Connect error")
)

func oidcFactory(cfg config.Map) (auth.RequestIdentifier, error) {
Expand Down Expand Up @@ -443,12 +448,12 @@ func (i *openIDIdentifier) Authenticate(ctx context.Context, request *web.Reques

authConfig, err := i.config(request)
if err != nil {
return i.responder.ServerError(err)
return i.responder.ServerErrorWithContext(ctx, err)
}

u, err := url.Parse(authConfig.AuthCodeURL(state, options...))
if err != nil {
return i.responder.ServerError(err)
return i.responder.ServerErrorWithContext(ctx, err)
}

return i.responder.URLRedirect(u)
Expand All @@ -464,25 +469,26 @@ func (i *openIDIdentifier) Callback(ctx context.Context, request *web.Request, r
}
}

return i.responder.ServerError(fmt.Errorf("OpenID Connect error: %q (%q)", errString, errDetails))
return i.responder.ServerErrorWithContext(ctx, fmt.Errorf("%w: %q (%q)", errGeneric, errString, errDetails))
}

queryState, err := request.Query1("state")
if err != nil {
return i.responder.ServerError(errors.New("no state in request"))
return i.responder.BadRequestWithContext(ctx, errNoStateInRequest)
}

if !i.validateSessionCode(request, queryState) {
return i.responder.ServerError(errors.New("state mismatch"))
return i.responder.BadRequestWithContext(ctx, errStateMismatch)
}

code, err := request.Query1("code")
if err != nil {
return i.responder.ServerError(err)
return i.responder.BadRequestWithContext(ctx, fmt.Errorf("%w: code", err))
}

oauthConfig, err := i.config(request)
if err != nil {
return i.responder.ServerError(err)
return i.responder.ServerErrorWithContext(ctx, err)
}

options := make([]oauth2.AuthCodeOption, 0)
Expand All @@ -495,13 +501,13 @@ func (i *openIDIdentifier) Callback(ctx context.Context, request *web.Request, r

oauth2Token, err := oauthConfig.Exchange(ctx, code, options...)
if err != nil {
return i.responder.ServerError(err)
return i.responder.ServerErrorWithContext(ctx, err)
}

// Extract the ID Token from OAuth2 token.
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
return i.responder.ServerError(errors.New("claim id_token missing"))
return i.responder.ServerErrorWithContext(ctx, errNoIDTokenClaim)
}

verifierConfig := &oidc.Config{ClientID: i.oauth2Config.ClientID}
Expand All @@ -513,7 +519,7 @@ func (i *openIDIdentifier) Callback(ctx context.Context, request *web.Request, r
// Parse and verify ID Token payload.
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
return i.responder.ServerError(err)
return i.responder.ServerErrorWithContext(ctx, err)
}

var (
Expand All @@ -523,7 +529,7 @@ func (i *openIDIdentifier) Callback(ctx context.Context, request *web.Request, r
)

if err := idToken.Claims(&tempIDTokenClaims); err != nil {
return i.responder.ServerError(err)
return i.responder.ServerErrorWithContext(ctx, err)
}
for k, v := range i.oidcConfig.Claims.IDToken {
idTokenClaims[k] = tempIDTokenClaims[v]
Expand Down Expand Up @@ -561,7 +567,7 @@ func (i *openIDIdentifier) Callback(ctx context.Context, request *web.Request, r
identity, err := i.Identify(ctx, request)
if err != nil {
i.Logout(ctx, request)
return i.responder.ServerError(err)
return i.responder.ServerErrorWithContext(ctx, err)
}

i.eventRouter.Dispatch(ctx, &auth.WebLoginEvent{Broker: i.broker, Request: request, Identity: identity})
Expand Down
4 changes: 2 additions & 2 deletions core/auth/oauth/oidc_test.go
Expand Up @@ -83,12 +83,12 @@ func TestParallelStateRaceConditions(t *testing.T) {
request.URL.RawQuery = url.Values{"state": []string{state2}}.Encode()
resp = identifier.Callback(context.Background(), web.CreateRequest(request, session), nil)
errResp = resp.(*web.ServerErrorResponse)
assert.EqualError(t, errResp.Error, "query value not found")
assert.EqualError(t, errResp.Error, "query value not found: code")

request.URL.RawQuery = url.Values{"state": []string{state1}}.Encode()
resp = identifier.Callback(context.Background(), web.CreateRequest(request, session), nil)
errResp = resp.(*web.ServerErrorResponse)
assert.EqualError(t, errResp.Error, "query value not found")
assert.EqualError(t, errResp.Error, "query value not found: code")

request.URL.RawQuery = url.Values{"state": []string{state1}}.Encode()
resp = identifier.Callback(context.Background(), web.CreateRequest(request, session), nil)
Expand Down
13 changes: 7 additions & 6 deletions core/requestlogger/logger_test.go
Expand Up @@ -30,12 +30,13 @@ func TestLogger(t *testing.T) {
request.Request().Header.Set("Referer", "https://example.com/")

responder := new(web.Responder).Inject(&web.Router{}, flamingo.NullLogger{}, &struct {
Engine flamingo.TemplateEngine "inject:\",optional\""
Debug bool "inject:\"config:flamingo.debug.mode\""
TemplateForbidden string "inject:\"config:flamingo.template.err403\""
TemplateNotFound string "inject:\"config:flamingo.template.err404\""
TemplateUnavailable string "inject:\"config:flamingo.template.err503\""
TemplateErrorWithCode string "inject:\"config:flamingo.template.errWithCode\""
Engine flamingo.TemplateEngine `inject:",optional"`
Debug bool `inject:"config:flamingo.debug.mode"`
TemplateBadRequest string `inject:"config:flamingo.template.err400"`
TemplateForbidden string `inject:"config:flamingo.template.err403"`
TemplateNotFound string `inject:"config:flamingo.template.err404"`
TemplateUnavailable string `inject:"config:flamingo.template.err503"`
TemplateErrorWithCode string `inject:"config:flamingo.template.errWithCode"`
}{})

tests := []struct {
Expand Down
1 change: 1 addition & 0 deletions framework/module.go
Expand Up @@ -104,6 +104,7 @@ flamingo: {
path?: string
}
template: {
err400: string | *"error/400"
err403: string | *"error/403"
err404: string | *"error/404"
errWithCode: string | *"error/withCode"
Expand Down
40 changes: 35 additions & 5 deletions framework/web/result.go
Expand Up @@ -32,6 +32,7 @@ type (
debug bool

templateForbidden string
templateBadRequest string
templateNotFound string
templateUnavailable string
templateErrorWithCode string
Expand Down Expand Up @@ -127,6 +128,7 @@ const (
func (r *Responder) Inject(router *Router, logger flamingo.Logger, cfg *struct {
Engine flamingo.TemplateEngine `inject:",optional"`
Debug bool `inject:"config:flamingo.debug.mode"`
TemplateBadRequest string `inject:"config:flamingo.template.err400"`
TemplateForbidden string `inject:"config:flamingo.template.err403"`
TemplateNotFound string `inject:"config:flamingo.template.err404"`
TemplateUnavailable string `inject:"config:flamingo.template.err503"`
Expand All @@ -135,6 +137,7 @@ func (r *Responder) Inject(router *Router, logger flamingo.Logger, cfg *struct {
r.engine = cfg.Engine
r.router = router
r.templateForbidden = cfg.TemplateForbidden
r.templateBadRequest = cfg.TemplateBadRequest
r.templateNotFound = cfg.TemplateNotFound
r.templateUnavailable = cfg.TemplateUnavailable
r.templateErrorWithCode = cfg.TemplateErrorWithCode
Expand Down Expand Up @@ -429,36 +432,63 @@ func (r *Responder) ServerErrorWithCodeAndTemplate(err error, tpl string, status

// ServerError creates a 500 error response
func (r *Responder) ServerError(err error) *ServerErrorResponse {
return r.ServerErrorWithContext(context.Background(), err)
}

// ServerErrorWithContext creates a 500 error response and uses the provided context for enhanced logging
func (r *Responder) ServerErrorWithContext(ctx context.Context, err error) *ServerErrorResponse {
if errors.Is(err, context.Canceled) {
r.getLogger().Debug(fmt.Sprintf("%+v\n", err))
r.getLogger().WithContext(ctx).Debug(fmt.Sprintf("%+v\n", err))
} else {
r.getLogger().Error(fmt.Sprintf("%+v\n", err))
r.getLogger().WithContext(ctx).Error(fmt.Sprintf("%+v\n", err))
}

return r.ServerErrorWithCodeAndTemplate(err, r.templateErrorWithCode, http.StatusInternalServerError)
}

// Unavailable creates a 503 error response
func (r *Responder) Unavailable(err error) *ServerErrorResponse {
r.getLogger().Error(fmt.Sprintf("%+v\n", err))
return r.UnavailableWithContext(context.Background(), err)
}

// UnavailableWithContext creates a 503 error response and uses the provided context for enhanced logging
func (r *Responder) UnavailableWithContext(ctx context.Context, err error) *ServerErrorResponse {
r.getLogger().WithContext(ctx).Error(fmt.Sprintf("%+v\n", err))

return r.ServerErrorWithCodeAndTemplate(err, r.templateUnavailable, http.StatusServiceUnavailable)
}

// NotFound creates a 404 error response
func (r *Responder) NotFound(err error) *ServerErrorResponse {
r.getLogger().Warn(err)
return r.NotFoundWithContext(context.Background(), err)
}

// NotFoundWithContext creates a 404 error response and uses the provided context for enhanced logging
func (r *Responder) NotFoundWithContext(ctx context.Context, err error) *ServerErrorResponse {
r.getLogger().WithContext(ctx).Warn(err)

return r.ServerErrorWithCodeAndTemplate(err, r.templateNotFound, http.StatusNotFound)
}

// Forbidden creates a 403 error response
func (r *Responder) Forbidden(err error) *ServerErrorResponse {
r.getLogger().Warn(err)
return r.ForbiddenWithContext(context.Background(), err)
}

// ForbiddenWithContext creates a 403 error response and uses the provided context for enhanced logging
func (r *Responder) ForbiddenWithContext(ctx context.Context, err error) *ServerErrorResponse {
r.getLogger().WithContext(ctx).Warn(err)

return r.ServerErrorWithCodeAndTemplate(err, r.templateForbidden, http.StatusForbidden)
}

// BadRequestWithContext creates a 400 error response and uses the provided context for enhanced logging
func (r *Responder) BadRequestWithContext(ctx context.Context, err error) *ServerErrorResponse {
r.getLogger().WithContext(ctx).Info(err)

return r.ServerErrorWithCodeAndTemplate(err, r.templateForbidden, http.StatusBadRequest)
}

// SetNoCache helper
func (r *ServerErrorResponse) SetNoCache() *ServerErrorResponse {
r.Response.SetNoCache()
Expand Down

0 comments on commit 85bcd6f

Please sign in to comment.