Skip to content
Permalink
Browse files

Include current user in request context

  • Loading branch information...
Jim McBeath
Jim McBeath committed Feb 25, 2018
1 parent 582cfea commit 545b636f536d950ee63facb38a4886e757b369ab
Showing with 102 additions and 15 deletions.
  1. +27 −4 auth/authapi.go
  2. +38 −3 auth/authapi_test.go
  3. +13 −3 auth/token.go
  4. +4 −1 auth/token_test.go
  5. +20 −4 users/users.go
@@ -1,12 +1,15 @@
package auth

import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"strconv"
"time"

"github.com/jimmc/mimsrv/users"
)

const (
@@ -17,6 +20,11 @@ type LoginStatus struct {
LoggedIn bool
}

type authKey int
const (
ctxUserKey = iota + 1
)

func (h *Handler) initApiHandler() {
mux := http.NewServeMux()
mux.HandleFunc(h.apiPrefix("login"), h.login)
@@ -30,14 +38,29 @@ func (h *Handler) RequireAuth(httpHandler http.Handler) http.Handler {
token := cookieValue(r, tokenCookieName)
idstr := clientIdString(r)
if isValidToken(token, idstr) {
httpHandler.ServeHTTP(w, r)
user := userFromToken(token)
mimRequest := requestWithContextUser(r, user)
httpHandler.ServeHTTP(w, mimRequest)
} else {
// No token, or token is not valid
http.Error(w, "Invalid token", http.StatusUnauthorized)
}
})
}

func requestWithContextUser(r *http.Request, user *users.User) *http.Request {
mimContext := context.WithValue(r.Context(), ctxUserKey, user)
return r.WithContext(mimContext)
}

func CurrentUser(r *http.Request) *users.User {
v := r.Context().Value(ctxUserKey)
if v == nil {
return nil
}
return v.(*users.User)
}

func (h *Handler) apiPrefix(s string) string {
return fmt.Sprintf("%s%s/", h.config.Prefix, s)
}
@@ -56,7 +79,7 @@ func (h *Handler) login(w http.ResponseWriter, r *http.Request) {
if user != nil && h.nonceIsValidNow(userid, nonce, seconds) {
// OK to log in; generate a bearer token and put in a cookie
idstr := clientIdString(r)
http.SetCookie(w, tokenCookie(userid, idstr))
http.SetCookie(w, tokenCookie(user, idstr))
} else {
http.Error(w, "Invalid userid or nonce", http.StatusUnauthorized)
return
@@ -66,8 +89,8 @@ func (h *Handler) login(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"status": "ok"}`))
}

func tokenCookie(userid, idstr string) *http.Cookie {
token := newToken(userid, idstr)
func tokenCookie(user *users.User, idstr string) *http.Cookie {
token := newToken(user, idstr)
return &http.Cookie{
Name: tokenCookieName,
Path: "/",
@@ -4,6 +4,9 @@ import (
"net/http"
"net/http/httptest"
"testing"

"github.com/jimmc/mimsrv/permissions"
"github.com/jimmc/mimsrv/users"
)

func TestRequireAuth(t *testing.T) {
@@ -18,9 +21,9 @@ func TestRequireAuth(t *testing.T) {
t.Fatalf("error create auth list request: %v", err)
}

handlerResult := ""
var reqUser *users.User
baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerResult = "called"
reqUser = CurrentUser(r)
})
wrappedHandler := h.RequireAuth(baseHandler)

@@ -31,11 +34,43 @@ func TestRequireAuth(t *testing.T) {
}

rr = httptest.NewRecorder()
user := users.NewUser("user1", "cw1", nil)
idstr := clientIdString(req)
cookie := tokenCookie("user1", idstr)
cookie := tokenCookie(user, idstr)
req.AddCookie(cookie)
reqUser = nil
wrappedHandler.ServeHTTP(rr, req)
if got, want := rr.Code, http.StatusOK; got != want {
t.Errorf("request with auth: got status %d, want %d", got, want)
}
if reqUser == nil {
t.Errorf("authenicated request should carry a current user")
}
if got, want := reqUser.Id(), user.Id(); got != want {
t.Errorf("authenticated userid: got %s, want %s", got, want)
}
if got, want := reqUser.HasPermission(permissions.CanEdit), false; got != want {
t.Errorf("permission for CanEdit: got %v, want %v", got, want)
}

req, err = http.NewRequest("GET", "/api/list/d1", nil)
if err != nil {
t.Fatalf("error create auth list request: %v", err)
}
rr = httptest.NewRecorder()
user = users.NewUser("user1", "cw1", permissions.FromString("edit"))
idstr = clientIdString(req)
cookie = tokenCookie(user, idstr)
req.AddCookie(cookie)
reqUser = nil
wrappedHandler.ServeHTTP(rr, req)
if got, want := rr.Code, http.StatusOK; got != want {
t.Errorf("request with auth: got status %d, want %d", got, want)
}
if reqUser == nil {
t.Errorf("authenicated request should carry a current user")
}
if got, want := reqUser.HasPermission(permissions.CanEdit), true; got != want {
t.Errorf("permission for CanEdit: got %v, want %v", got, want)
}
}
@@ -4,6 +4,8 @@ import (
"fmt"
"math/rand"
"time"

"github.com/jimmc/mimsrv/users"
)

const (
@@ -16,7 +18,7 @@ var (

type Token struct {
Key string
userid string
user *users.User
idstr string
expiry time.Time
}
@@ -25,9 +27,9 @@ func initTokens() {
tokens = make(map[string]*Token)
}

func newToken(userid, idstr string) *Token {
func newToken(user *users.User, idstr string) *Token {
token := &Token{
userid: userid,
user: user,
idstr: idstr,
expiry: timeNow().Add(tokenExpirationDuration),
}
@@ -50,3 +52,11 @@ func isValidToken(tokenKey, idstr string) bool {
}
return true
}

func userFromToken(tokenKey string) *users.User {
token := tokens[tokenKey]
if token == nil {
return nil
}
return token.user
}
@@ -3,14 +3,17 @@ package auth
import (
"testing"
"time"

"github.com/jimmc/mimsrv/users"
)

func TestIsValid(t *testing.T) {
initTokens()
if isValidToken("user1", "id1") {
t.Fatal("token was deemed valid before any tokens added")
}
token := newToken("user1", "id1")
user1 := users.NewUser("user1", "cw1", nil)
token := newToken(user1, "id1")
if !isValidToken(token.Key, "id1") {
t.Fatalf("Token %s should be valid", token.Key)
}
@@ -125,6 +125,14 @@ func (m *Users) addRecord(userid, cryptword string, perms string) {
m.records = append(m.records, record)
}

func NewUser(userid, cryptword string, perms *permissions.Permissions) *User {
return &User{
userid: userid,
cryptword: cryptword,
perms: perms,
}
}

func (m *Users) User(userid string) *User {
return m.users[userid]
}
@@ -161,10 +169,7 @@ func (m *Users) HasPermission(userid string, perm permissions.Permission) bool {
if user == nil {
return false
}
if user.perms == nil {
return false
}
return user.perms.HasPermission(perm)
return user.HasPermission(perm)
}

func (u *User) Cryptword() string {
@@ -174,3 +179,14 @@ func (u *User) Cryptword() string {
func (u *User) SetCryptword(cryptword string) {
u.cryptword = cryptword
}

func (u *User) Id() string {
return u.userid
}

func (u *User) HasPermission(perm permissions.Permission) bool {
if u.perms == nil {
return false
}
return u.perms.HasPermission(perm)
}

0 comments on commit 545b636

Please sign in to comment.
You can’t perform that action at this time.