Skip to content

Commit

Permalink
Session improvements (#510)
Browse files Browse the repository at this point in the history
  • Loading branch information
0xCA committed Jan 6, 2024
1 parent 46b0934 commit fa33d3f
Show file tree
Hide file tree
Showing 9 changed files with 274 additions and 36 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ docker-compose up
| `BIND_ADDRESS` | The addresses that can access to the web interface and the port, use unix:///abspath/to/file.socket for unix domain socket. | 0.0.0.0:80 |
| `SESSION_SECRET` | The secret key used to encrypt the session cookies. Set this to a random value | N/A |
| `SESSION_SECRET_FILE` | Optional filepath for the secret key used to encrypt the session cookies. Leave `SESSION_SECRET` blank to take effect | N/A |
| `SESSION_MAX_DURATION` | Max time in days a remembered session is refreshed and valid. Non-refreshed session is valid for 7 days max, regardless of this setting. | 90 |
| `SUBNET_RANGES` | The list of address subdivision ranges. Format: `SR Name:10.0.1.0/24; SR2:10.0.2.0/24,10.0.3.0/24` Each CIDR must be inside one of the server interfaces. | N/A |
| `WGUI_USERNAME` | The username for the login page. Used for db initialization only | `admin` |
| `WGUI_PASSWORD` | The password for the user on the login page. Will be hashed automatically. Used for db initialization only | `admin` |
Expand Down
23 changes: 16 additions & 7 deletions handler/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,32 +93,41 @@ func Login(db store.IStore) echo.HandlerFunc {
}

if userCorrect && passwordCorrect {
// TODO: refresh the token
ageMax := 0
expiration := time.Now().Add(24 * time.Hour)
if rememberMe {
ageMax = 86400
expiration.Add(144 * time.Hour)
ageMax = 86400 * 7
}

cookiePath := util.GetCookiePath()

sess, _ := session.Get("session", c)
sess.Options = &sessions.Options{
Path: util.BasePath,
Path: cookiePath,
MaxAge: ageMax,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}

// set session_token
tokenUID := xid.New().String()
now := time.Now().UTC().Unix()
sess.Values["username"] = dbuser.Username
sess.Values["user_hash"] = util.GetDBUserCRC32(dbuser)
sess.Values["admin"] = dbuser.Admin
sess.Values["session_token"] = tokenUID
sess.Values["max_age"] = ageMax
sess.Values["created_at"] = now
sess.Values["updated_at"] = now
sess.Save(c.Request(), c.Response())

// set session_token in cookie
cookie := new(http.Cookie)
cookie.Name = "session_token"
cookie.Path = cookiePath
cookie.Value = tokenUID
cookie.Expires = expiration
cookie.MaxAge = ageMax
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteLaxMode
c.SetCookie(cookie)

return c.JSON(http.StatusOK, jsonHTTPResponse{true, "Logged in successfully"})
Expand Down Expand Up @@ -256,7 +265,7 @@ func UpdateUser(db store.IStore) echo.HandlerFunc {
log.Infof("Updated user information successfully")

if previousUsername == currentUser(c) {
setUser(c, user.Username, user.Admin)
setUser(c, user.Username, user.Admin, util.GetDBUserCRC32(user))
}

return c.JSON(http.StatusOK, jsonHTTPResponse{true, "Updated user information successfully"})
Expand Down
168 changes: 167 additions & 1 deletion handler/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package handler
import (
"fmt"
"net/http"
"time"

"github.com/gorilla/sessions"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"github.com/ngoduykhanh/wireguard-ui/util"
Expand All @@ -23,6 +25,15 @@ func ValidSession(next echo.HandlerFunc) echo.HandlerFunc {
}
}

// RefreshSession must only be used after ValidSession middleware
// RefreshSession checks if the session is eligible for the refresh, but doesn't check if it's fully valid
func RefreshSession(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
doRefreshSession(c)
return next(c)
}
}

func NeedsAdmin(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if !isAdmin(c) {
Expand All @@ -41,9 +52,146 @@ func isValidSession(c echo.Context) bool {
if err != nil || sess.Values["session_token"] != cookie.Value {
return false
}

// Check time bounds
createdAt := getCreatedAt(sess)
updatedAt := getUpdatedAt(sess)
maxAge := getMaxAge(sess)
// Temporary session is considered valid within 24h if browser is not closed before
// This value is not saved and is used as virtual expiration
if maxAge == 0 {
maxAge = 86400
}
expiration := updatedAt + int64(maxAge)
now := time.Now().UTC().Unix()
if updatedAt > now || expiration < now || createdAt+util.SessionMaxDuration < now {
return false
}

// Check if user still exists and unchanged
username := fmt.Sprintf("%s", sess.Values["username"])
userHash := getUserHash(sess)
if uHash, ok := util.DBUsersToCRC32[username]; !ok || userHash != uHash {
return false
}

return true
}

// Refreshes a "remember me" session when the user visits web pages (not API)
// Session must be valid before calling this function
// Refresh is performed at most once per 24h
func doRefreshSession(c echo.Context) {
if util.DisableLogin {
return
}

sess, _ := session.Get("session", c)
maxAge := getMaxAge(sess)
if maxAge <= 0 {
return
}

oldCookie, err := c.Cookie("session_token")
if err != nil || sess.Values["session_token"] != oldCookie.Value {
return
}

// Refresh no sooner than 24h
createdAt := getCreatedAt(sess)
updatedAt := getUpdatedAt(sess)
expiration := updatedAt + int64(getMaxAge(sess))
now := time.Now().UTC().Unix()
if updatedAt > now || expiration < now || now-updatedAt < 86_400 || createdAt+util.SessionMaxDuration < now {
return
}

cookiePath := util.GetCookiePath()

sess.Values["updated_at"] = now
sess.Options = &sessions.Options{
Path: cookiePath,
MaxAge: maxAge,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
sess.Save(c.Request(), c.Response())

cookie := new(http.Cookie)
cookie.Name = "session_token"
cookie.Path = cookiePath
cookie.Value = oldCookie.Value
cookie.MaxAge = maxAge
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteLaxMode
c.SetCookie(cookie)
}

// Get time in seconds this session is valid without updating
func getMaxAge(sess *sessions.Session) int {
if util.DisableLogin {
return 0
}

maxAge := sess.Values["max_age"]

switch typedMaxAge := maxAge.(type) {
case int:
return typedMaxAge
default:
return 0
}
}

// Get a timestamp in seconds of the time the session was created
func getCreatedAt(sess *sessions.Session) int64 {
if util.DisableLogin {
return 0
}

createdAt := sess.Values["created_at"]

switch typedCreatedAt := createdAt.(type) {
case int64:
return typedCreatedAt
default:
return 0
}
}

// Get a timestamp in seconds of the last session update
func getUpdatedAt(sess *sessions.Session) int64 {
if util.DisableLogin {
return 0
}

lastUpdate := sess.Values["updated_at"]

switch typedLastUpdate := lastUpdate.(type) {
case int64:
return typedLastUpdate
default:
return 0
}
}

// Get CRC32 of a user at the moment of log in
// Any changes to user will result in logout of other (not updated) sessions
func getUserHash(sess *sessions.Session) uint32 {
if util.DisableLogin {
return 0
}

userHash := sess.Values["user_hash"]

switch typedUserHash := userHash.(type) {
case uint32:
return typedUserHash
default:
return 0
}
}

// currentUser to get username of logged in user
func currentUser(c echo.Context) string {
if util.DisableLogin {
Expand All @@ -66,9 +214,10 @@ func isAdmin(c echo.Context) bool {
return admin == "true"
}

func setUser(c echo.Context, username string, admin bool) {
func setUser(c echo.Context, username string, admin bool, userCRC32 uint32) {
sess, _ := session.Get("session", c)
sess.Values["username"] = username
sess.Values["user_hash"] = userCRC32
sess.Values["admin"] = admin
sess.Save(c.Request(), c.Response())
}
Expand All @@ -77,7 +226,24 @@ func setUser(c echo.Context, username string, admin bool) {
func clearSession(c echo.Context) {
sess, _ := session.Get("session", c)
sess.Values["username"] = ""
sess.Values["user_hash"] = 0
sess.Values["admin"] = false
sess.Values["session_token"] = ""
sess.Values["max_age"] = -1
sess.Options.MaxAge = -1
sess.Save(c.Request(), c.Response())

cookiePath := util.GetCookiePath()

cookie, err := c.Cookie("session_token")
if err != nil {
cookie = new(http.Cookie)
}

cookie.Name = "session_token"
cookie.Path = cookiePath
cookie.MaxAge = -1
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteLaxMode
c.SetCookie(cookie)
}
20 changes: 12 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"crypto/sha512"
"embed"
"flag"
"fmt"
Expand Down Expand Up @@ -48,6 +49,7 @@ var (
flagTelegramAllowConfRequest = false
flagTelegramFloodWait = 60
flagSessionSecret = util.RandomString(32)
flagSessionMaxDuration = 90
flagWgConfTemplate string
flagBasePath string
flagSubnetRanges string
Expand Down Expand Up @@ -91,6 +93,7 @@ func init() {
flag.StringVar(&flagWgConfTemplate, "wg-conf-template", util.LookupEnvOrString("WG_CONF_TEMPLATE", flagWgConfTemplate), "Path to custom wg.conf template.")
flag.StringVar(&flagBasePath, "base-path", util.LookupEnvOrString("BASE_PATH", flagBasePath), "The base path of the URL")
flag.StringVar(&flagSubnetRanges, "subnet-ranges", util.LookupEnvOrString("SUBNET_RANGES", flagSubnetRanges), "IP ranges to choose from when assigning an IP for a client.")
flag.IntVar(&flagSessionMaxDuration, "session-max-duration", util.LookupEnvOrInt("SESSION_MAX_DURATION", flagSessionMaxDuration), "Max time in days a remembered session is refreshed and valid.")

var (
smtpPasswordLookup = util.LookupEnvOrString("SMTP_PASSWORD", flagSmtpPassword)
Expand Down Expand Up @@ -135,7 +138,8 @@ func init() {
util.SendgridApiKey = flagSendgridApiKey
util.EmailFrom = flagEmailFrom
util.EmailFromName = flagEmailFromName
util.SessionSecret = []byte(flagSessionSecret)
util.SessionSecret = sha512.Sum512([]byte(flagSessionSecret))
util.SessionMaxDuration = int64(flagSessionMaxDuration) * 86_400 // Store in seconds
util.WgConfTemplate = flagWgConfTemplate
util.BasePath = util.ParseBasePath(flagBasePath)
util.SubnetRanges = util.ParseSubnetRanges(flagSubnetRanges)
Expand Down Expand Up @@ -204,7 +208,7 @@ func main() {
// register routes
app := router.New(tmplDir, extraData, util.SessionSecret)

app.GET(util.BasePath, handler.WireGuardClients(db), handler.ValidSession)
app.GET(util.BasePath, handler.WireGuardClients(db), handler.ValidSession, handler.RefreshSession)

// Important: Make sure that all non-GET routes check the request content type using handler.ContentTypeJson to
// mitigate CSRF attacks. This is effective, because browsers don't allow setting the Content-Type header on
Expand All @@ -214,8 +218,8 @@ func main() {
app.GET(util.BasePath+"/login", handler.LoginPage())
app.POST(util.BasePath+"/login", handler.Login(db), handler.ContentTypeJson)
app.GET(util.BasePath+"/logout", handler.Logout(), handler.ValidSession)
app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession)
app.GET(util.BasePath+"/users-settings", handler.UsersSettings(), handler.ValidSession, handler.NeedsAdmin)
app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession, handler.RefreshSession)
app.GET(util.BasePath+"/users-settings", handler.UsersSettings(), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin)
app.POST(util.BasePath+"/update-user", handler.UpdateUser(db), handler.ValidSession, handler.ContentTypeJson)
app.POST(util.BasePath+"/create-user", handler.CreateUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.POST(util.BasePath+"/remove-user", handler.RemoveUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
Expand All @@ -241,19 +245,19 @@ func main() {
app.POST(util.BasePath+"/client/set-status", handler.SetClientStatus(db), handler.ValidSession, handler.ContentTypeJson)
app.POST(util.BasePath+"/remove-client", handler.RemoveClient(db), handler.ValidSession, handler.ContentTypeJson)
app.GET(util.BasePath+"/download", handler.DownloadClient(db), handler.ValidSession)
app.GET(util.BasePath+"/wg-server", handler.WireGuardServer(db), handler.ValidSession, handler.NeedsAdmin)
app.GET(util.BasePath+"/wg-server", handler.WireGuardServer(db), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin)
app.POST(util.BasePath+"/wg-server/interfaces", handler.WireGuardServerInterfaces(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.POST(util.BasePath+"/wg-server/keypair", handler.WireGuardServerKeyPair(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.GET(util.BasePath+"/global-settings", handler.GlobalSettings(db), handler.ValidSession, handler.NeedsAdmin)
app.GET(util.BasePath+"/global-settings", handler.GlobalSettings(db), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin)
app.POST(util.BasePath+"/global-settings", handler.GlobalSettingSubmit(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.GET(util.BasePath+"/status", handler.Status(db), handler.ValidSession)
app.GET(util.BasePath+"/status", handler.Status(db), handler.ValidSession, handler.RefreshSession)
app.GET(util.BasePath+"/api/clients", handler.GetClients(db), handler.ValidSession)
app.GET(util.BasePath+"/api/client/:id", handler.GetClient(db), handler.ValidSession)
app.GET(util.BasePath+"/api/machine-ips", handler.MachineIPAddresses(), handler.ValidSession)
app.GET(util.BasePath+"/api/subnet-ranges", handler.GetOrderedSubnetRanges(), handler.ValidSession)
app.GET(util.BasePath+"/api/suggest-client-ips", handler.SuggestIPAllocation(db), handler.ValidSession)
app.POST(util.BasePath+"/api/apply-wg-config", handler.ApplyServerConfig(db, tmplDir), handler.ValidSession, handler.ContentTypeJson)
app.GET(util.BasePath+"/wake_on_lan_hosts", handler.GetWakeOnLanHosts(db), handler.ValidSession)
app.GET(util.BasePath+"/wake_on_lan_hosts", handler.GetWakeOnLanHosts(db), handler.ValidSession, handler.RefreshSession)
app.POST(util.BasePath+"/wake_on_lan_host", handler.SaveWakeOnLanHost(db), handler.ValidSession, handler.ContentTypeJson)
app.DELETE(util.BasePath+"/wake_on_lan_host/:mac_address", handler.DeleteWakeOnHost(db), handler.ValidSession, handler.ContentTypeJson)
app.PUT(util.BasePath+"/wake_on_lan_host/:mac_address", handler.WakeOnHost(db), handler.ValidSession, handler.ContentTypeJson)
Expand Down
12 changes: 10 additions & 2 deletions router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,17 @@ func (t *TemplateRegistry) Render(w io.Writer, name string, data interface{}, c
}

// New function
func New(tmplDir fs.FS, extraData map[string]interface{}, secret []byte) *echo.Echo {
func New(tmplDir fs.FS, extraData map[string]interface{}, secret [64]byte) *echo.Echo {
e := echo.New()
e.Use(session.Middleware(sessions.NewCookieStore(secret)))

cookiePath := util.GetCookiePath()

cookieStore := sessions.NewCookieStore(secret[:32], secret[32:])
cookieStore.Options.Path = cookiePath
cookieStore.Options.HttpOnly = true
cookieStore.MaxAge(86400 * 7)

e.Use(session.Middleware(cookieStore))

// read html template file to string
tmplBaseString, err := util.StringFromEmbedFile(tmplDir, "base.html")
Expand Down

0 comments on commit fa33d3f

Please sign in to comment.