diff --git a/server/server.go b/server/server.go index 978105a..7c25e90 100644 --- a/server/server.go +++ b/server/server.go @@ -125,7 +125,7 @@ func Start(ctx context.Context, listener net.Listener, logger *zap.Logger, cfg * app := http.HandlerFunc(proxy.handler) if cfg.AuthVerify { if cfg.AuthVerifyRequireLogin { - http.Handle(cfg.AuthVerifyPath, middleware.RequireAccount(http.HandlerFunc(noContentHandler))) + http.Handle(cfg.AuthVerifyPath, authVerifyWithLogin(logger, proxy, middleware)) } else { http.Handle(cfg.AuthVerifyPath, authVerify(middleware)) } @@ -191,9 +191,21 @@ func setupHttpClient(idpCaFile string) (*http.Client, error) { return client, nil } -// HTTP handler that replies to each request with a “204 no content”. -func noContentHandler(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(http.StatusNoContent) +func authVerifyWithLogin(logger *zap.Logger, proxy *Proxy, middleware *samlsp.Middleware) http.Handler { + return middleware.RequireAccount(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + session := samlsp.SessionFromContext(r.Context()) + + sessionClaims, ok := session.(samlsp.JWTSessionClaims) + if !ok { + logger.Error("session is not expected type") + w.WriteHeader(http.StatusInternalServerError) + return + } + + proxy.addHeaders(sessionClaims, w.Header()) // pass over SAML attrs as headers + + w.WriteHeader(http.StatusNoContent) + })) } func authVerify(middleware *samlsp.Middleware) http.Handler {