From ea771086ecef3f4d34e2842784845d9b6abf0eef Mon Sep 17 00:00:00 2001 From: Benjamin Buetikofer Date: Mon, 13 Nov 2023 21:18:24 +0100 Subject: [PATCH] auth improvements --- cmd/api/main.go | 2 +- cmd/api/middleware.go | 75 +++++++++++++++++-------------- cmd/api/url.go | 2 +- ui/html/pages/dashboard.tmpl.html | 3 +- 4 files changed, 45 insertions(+), 37 deletions(-) diff --git a/cmd/api/main.go b/cmd/api/main.go index e09bcb3..8100f23 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -93,7 +93,7 @@ func main() { sessionManager := scs.New() sessionManager.Store = postgresstore.New(dbd) - sessionManager.Lifetime = 7 * 24 * time.Hour + sessionManager.Lifetime = 14 * 24 * time.Hour app := &application{ sessionManager: sessionManager, diff --git a/cmd/api/middleware.go b/cmd/api/middleware.go index 03ab84a..dd1eae2 100644 --- a/cmd/api/middleware.go +++ b/cmd/api/middleware.go @@ -12,48 +12,57 @@ import ( func (app *application) authenticate(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - if app.isAuthenticated(c) { - return next(c) + if c.Request().Header.Get(echo.HeaderContentType) == echo.MIMEApplicationJSON { + return app.jsonAuthenticate(c, next) } - authorizationHeader := c.Request().Header.Get("Authorization") - if authorizationHeader == "" { - return c.JSON(http.StatusUnauthorized, "Unauthorized") + if !app.isAuthenticated(c) { + return c.Render(http.StatusUnauthorized, "login.tmpl.html", app.newTemplateData(c)) } + c.Request().Header.Set("Cache-Control", "no-store") + return next(c) - headerParts := strings.Split(authorizationHeader, " ") - if len(headerParts) != 2 || headerParts[0] != "Bearer" { - return c.JSON(http.StatusBadRequest, "Bad Request") - } + } +} - token := headerParts[1] - claims, err := jwt.HMACCheck([]byte(token), []byte(app.config.signingKey)) - if err != nil { - return c.JSON(http.StatusBadRequest, "Invalid Token") - } - if !claims.Valid(time.Now()) { - return c.JSON(http.StatusBadRequest, "Invalid Token") - } - if claims.Issuer != "shrink.ch" { - return c.JSON(http.StatusBadRequest, "Invalid Token") - } - if !claims.AcceptAudience("shrink.ch") { - return c.JSON(http.StatusBadRequest, "Invalid Token") - } +func (app *application) jsonAuthenticate(c echo.Context, next echo.HandlerFunc) error { + authorizationHeader := c.Request().Header.Get("Authorization") + if authorizationHeader == "" { + return c.JSON(http.StatusUnauthorized, "Unauthorized") + } - userID, err := uuid.Parse(claims.Subject) - if err != nil { - return c.JSON(http.StatusBadRequest, "Invalid Token") - } + headerParts := strings.Split(authorizationHeader, " ") + if len(headerParts) != 2 || headerParts[0] != "Bearer" { + return c.JSON(http.StatusBadRequest, "Bad Request") + } - user, err := app.models.Users.GetByID(userID) - if err != nil { - return c.JSON(http.StatusBadRequest, "Invalid Token") - } + token := headerParts[1] + claims, err := jwt.HMACCheck([]byte(token), []byte(app.config.signingKey)) + if err != nil { + return c.JSON(http.StatusBadRequest, "Invalid Token") + } + if !claims.Valid(time.Now()) { + return c.JSON(http.StatusBadRequest, "Invalid Token") + } + if claims.Issuer != "shrink.ch" { + return c.JSON(http.StatusBadRequest, "Invalid Token") + } + if !claims.AcceptAudience("shrink.ch") { + return c.JSON(http.StatusBadRequest, "Invalid Token") + } - c.Set("user", user) + userID, err := uuid.Parse(claims.Subject) + if err != nil { + return c.JSON(http.StatusBadRequest, "Invalid Token") + } - return next(c) + user, err := app.models.Users.GetByID(userID) + if err != nil { + return c.JSON(http.StatusBadRequest, "Invalid Token") } + + c.Set("user", user) + + return next(c) } func (app *application) requireRole(role string) echo.MiddlewareFunc { diff --git a/cmd/api/url.go b/cmd/api/url.go index d3633ec..d95a8b1 100644 --- a/cmd/api/url.go +++ b/cmd/api/url.go @@ -103,7 +103,7 @@ func (app *application) deleteUrlHandler(c echo.Context) error { urlUUID, err := uuid.Parse(c.Param("id")) if err != nil { app.sessionManager.Put(c.Request().Context(), "flash_error", "Bad Request?!") - return c.Render(http.StatusBadRequest, "dashboard.tmpl.html", app.newTemplateData(c)) + return app.dashboardHandler(c) } err = app.models.Urls.Delete(urlUUID) diff --git a/ui/html/pages/dashboard.tmpl.html b/ui/html/pages/dashboard.tmpl.html index 81d1108..84bdbed 100644 --- a/ui/html/pages/dashboard.tmpl.html +++ b/ui/html/pages/dashboard.tmpl.html @@ -20,8 +20,7 @@

Your URLs

- {{ .Original - }} + {{ .Original }} {{ .ShortUrl }}