Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Audience support. #16

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 13 additions & 1 deletion api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"net/textproto"
"regexp"

"github.com/dgrijalva/jwt-go"
Expand All @@ -14,7 +15,10 @@ import (
"github.com/rs/cors"
)

const defaultVersion = "unknown version"
const (
audHeaderName = "X-JWT-AUD"
defaultVersion = "unknown version"
)

var bearerRegexp = regexp.MustCompile(`^(?:B|b)earer (\S+$)`)

Expand Down Expand Up @@ -54,6 +58,14 @@ func (a *API) requireAuthentication(ctx context.Context, w http.ResponseWriter,
return context.WithValue(ctx, "jwt", token)
}

func (a *API) requestAud(r *http.Request) string {
p := textproto.MIMEHeader(r.Header)
if h, exist := p[textproto.CanonicalMIMEHeaderKey(audHeaderName)]; exist && len(h) > 0 {
return h[0]
}
return a.config.JWT.Aud
}

// ListenAndServe starts the REST API
func (a *API) ListenAndServe(hostAndPort string) error {
return http.ListenAndServe(hostAndPort, a.handler)
Expand Down
7 changes: 6 additions & 1 deletion api/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@ import (
func (a *API) Logout(ctx context.Context, w http.ResponseWriter, r *http.Request) {
token := getToken(ctx)

a.db.Logout(token.Claims["id"])
id, ok := token.Claims["id"].(string)
if !ok {
BadRequestError(w, "Could not read User ID claim")
return
}

a.db.Logout(id)
w.WriteHeader(204)
}
3 changes: 2 additions & 1 deletion api/recover.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ func (a *API) Recover(ctx context.Context, w http.ResponseWriter, r *http.Reques
return
}

user, err := a.db.FindUserByEmail(params.Email)
aud := a.requestAud(r)
user, err := a.db.FindUserByEmailAndAudience(params.Email, aud)
if err != nil {
if models.IsNotFoundError(err) {
NotFoundError(w, err.Error())
Expand Down
10 changes: 6 additions & 4 deletions api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,16 @@ func (a *API) Signup(ctx context.Context, w http.ResponseWriter, r *http.Request
return
}

user, err := a.db.FindUserByEmail(params.Email)
aud := a.requestAud(r)

user, err := a.db.FindUserByEmailAndAudience(params.Email, aud)
if err != nil {
if !models.IsNotFoundError(err) {
InternalServerError(w, err.Error())
return
}

user, err = a.signupNewUser(params)
user, err = a.signupNewUser(params, aud)
if err != nil {
InternalServerError(w, err.Error())
return
Expand All @@ -61,8 +63,8 @@ func (a *API) Signup(ctx context.Context, w http.ResponseWriter, r *http.Request
sendJSON(w, 200, user)
}

func (a *API) signupNewUser(params *SignupParams) (*models.User, error) {
user, err := models.NewUser(params.Email, params.Password, params.Data)
func (a *API) signupNewUser(params *SignupParams, aud string) (*models.User, error) {
user, err := models.NewUser(params.Email, params.Password, aud, params.Data)
if err != nil {
return nil, err
}
Expand Down
7 changes: 5 additions & 2 deletions api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
username := r.FormValue("username")
password := r.FormValue("password")

user, err := a.db.FindUserByEmail(username)
aud := a.requestAud(r)
user, err := a.db.FindUserByEmailAndAudience(username, aud)
if err != nil {
if models.IsNotFoundError(err) {
sendJSON(w, 400, &OAuthError{Error: "invalid_grant", Description: "No user found with this email"})
Expand Down Expand Up @@ -71,7 +72,8 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
return
}

user, token, err := a.db.FindUserWithRefreshToken(tokenStr)
aud := a.requestAud(r)
user, token, err := a.db.FindUserWithRefreshToken(tokenStr, aud)
if err != nil {
if models.IsNotFoundError(err) {
sendJSON(w, 400, &OAuthError{Error: "invalid_grant", Description: "Invalid Refresh Token"})
Expand Down Expand Up @@ -111,6 +113,7 @@ func (a *API) generateAccessToken(user *models.User) (string, error) {

token.Claims["id"] = user.ID
token.Claims["email"] = user.Email
token.Claims["aud"] = user.Aud
token.Claims["exp"] = time.Now().Add(time.Second * time.Duration(a.config.JWT.Exp)).Unix()
token.Claims["app_metadata"] = user.AppMetaData
token.Claims["user_metadata"] = user.UserMetaData
Expand Down
14 changes: 13 additions & 1 deletion api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@ func (a *API) UserGet(ctx context.Context, w http.ResponseWriter, r *http.Reques
return
}

tokenAud, ok := token.Claims["aud"].(string)
if !ok {
BadRequestError(w, "Could not read User Aud claim")
return
}

aud := a.requestAud(r)
if aud != tokenAud {
BadRequestError(w, "Token audience doesn't match request audience")
return
}

user, err := a.db.FindUserByID(id)
if err != nil {
if models.IsNotFoundError(err) {
Expand Down Expand Up @@ -72,7 +84,7 @@ func (a *API) UserUpdate(ctx context.Context, w http.ResponseWriter, r *http.Req

var sendChangeEmailVerification bool
if params.Email != "" {
exists, err := a.db.IsDuplicatedEmail(params.Email, user.ID)
exists, err := a.db.IsDuplicatedEmail(params.Email, user.Aud, user.ID)
if err != nil {
InternalServerError(w, err.Error())
return
Expand Down
1 change: 1 addition & 0 deletions conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type DBConfiguration struct {
type JWTConfiguration struct {
Secret string `json:"secret"`
Exp int `json:"exp"`
Aud string `json:"aud"`
AdminGroupName string `json:"admin_group_name"`
AdminGroupDisabled bool `json:"admin_group_disabled"`
}
Expand Down
3 changes: 2 additions & 1 deletion config.example.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{
"jwt": {
"secret": "CHANGE-THIS! VERY IMPORTANT!",
"exp": 3600
"exp": 3600,
"aud": "api.netlify.com"
},
"db": {
"driver": "sqlite3",
Expand Down
4 changes: 3 additions & 1 deletion models/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
type User struct {
ID string `json:"id" bson:"_id,omitempty"`

Aud string `json:"aud" bson:"aud"`
Email string `json:"email" bson:"email"`
EncryptedPassword string `json:"-" bson:"encrypted_password"`
ConfirmedAt time.Time `json:"confirmed_at" bson:"confirmed_at"`
Expand All @@ -38,9 +39,10 @@ type User struct {
}

// NewUser initializes a new user from an email, password and user data.
func NewUser(email, password string, userData map[string]interface{}) (*User, error) {
func NewUser(email, password, aud string, userData map[string]interface{}) (*User, error) {
user := &User{
ID: uuid.NewRandom().String(),
Aud: aud,
Email: email,
UserMetaData: userData,
}
Expand Down
50 changes: 28 additions & 22 deletions storage/mongo/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,7 @@ func (conn *Connection) CreateUser(user *models.User) error {
return errors.Wrap(err, "Error creating user")
}

if !conn.config.JWT.AdminGroupDisabled {
v, err := c.Find(bson.M{"_id": bson.M{"$ne": user.ID}}).Count()
if err != nil {
return errors.Wrap(err, "Error making user an admin")
}

if v == 0 {
user.SetRole(conn.config.JWT.AdminGroupName)
if err := c.Update(bson.M{"_id": user.ID}, bson.M{"$set": user}); err != nil {
return errors.Wrap(err, "Error making user an admin")
}
}
}

return nil
return conn.makeUserAdmin(c, user)
}

func (conn *Connection) findUser(query bson.M) (*models.User, error) {
Expand All @@ -76,8 +62,8 @@ func (conn *Connection) FindUserByConfirmationToken(token string) (*models.User,
return conn.findUser(bson.M{"confirmation_token": token})
}

func (conn *Connection) FindUserByEmail(email string) (*models.User, error) {
return conn.findUser(bson.M{"email": email})
func (conn *Connection) FindUserByEmailAndAudience(email, aud string) (*models.User, error) {
return conn.findUser(bson.M{"email": email, "aud": aud})
}

func (conn *Connection) FindUserByID(id string) (*models.User, error) {
Expand All @@ -88,7 +74,7 @@ func (conn *Connection) FindUserByRecoveryToken(token string) (*models.User, err
return conn.findUser(bson.M{"recovery_token": token})
}

func (conn *Connection) FindUserWithRefreshToken(token string) (*models.User, *models.RefreshToken, error) {
func (conn *Connection) FindUserWithRefreshToken(token, aud string) (*models.User, *models.RefreshToken, error) {
refreshToken := &models.RefreshToken{}
rc := conn.db.C(refreshToken.TableName())

Expand All @@ -100,7 +86,7 @@ func (conn *Connection) FindUserWithRefreshToken(token string) (*models.User, *m
}
}

user, err := conn.findUser(bson.M{"_id": refreshToken.UserID})
user, err := conn.findUser(bson.M{"_id": refreshToken.UserID, "aud": aud})
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -164,8 +150,8 @@ func (conn *Connection) GrantRefreshTokenSwap(user *models.User, token *models.R
return newToken, nil
}

func (conn *Connection) IsDuplicatedEmail(email, id string) (bool, error) {
_, err := conn.findUser(bson.M{"email": email, "_id": bson.M{"$ne": id}})
func (conn *Connection) IsDuplicatedEmail(email, aud, id string) (bool, error) {
_, err := conn.findUser(bson.M{"email": email, "aud": aud, "_id": bson.M{"$ne": id}})
if err != nil {
if models.IsNotFoundError(err) {
return false, nil
Expand All @@ -177,12 +163,32 @@ func (conn *Connection) IsDuplicatedEmail(email, id string) (bool, error) {
return true, nil
}

func (conn *Connection) Logout(id interface{}) {
func (conn *Connection) Logout(id string) {
t := &models.RefreshToken{}
c := conn.db.C(t.TableName())
c.RemoveAll(bson.M{"user_id": id})
}

func (conn *Connection) makeUserAdmin(c *mgo.Collection, user *models.User) error {
if conn.config.JWT.AdminGroupDisabled {
return nil
}

v, err := c.Find(bson.M{"_id": bson.M{"$ne": user.ID}}).Count()
if err != nil {
return errors.Wrap(err, "Error making user an admin")
}

if v == 0 {
user.SetRole(conn.config.JWT.AdminGroupName)
if err := c.Update(bson.M{"_id": user.ID}, bson.M{"$set": user}); err != nil {
return errors.Wrap(err, "Error making user an admin")
}
}

return nil
}

func (conn *Connection) RevokeToken(token *models.RefreshToken) error {
token.Revoked = true

Expand Down
54 changes: 35 additions & 19 deletions storage/sql/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,32 @@ type Connection struct {
config *conf.Configuration
}

func (conn *Connection) Close() error {
return conn.db.Close()
}

func (conn *Connection) Automigrate() error {
conn.db = conn.db.AutoMigrate(&UserObj{}, &models.RefreshToken{})
return conn.db.Error
}

func (conn *Connection) Close() error {
return conn.db.Close()
}

func (conn *Connection) CreateUser(user *models.User) error {
obj := &UserObj{
User: user,
FirstRoleName: conn.config.JWT.AdminGroupName,
AutoAsignRoles: !conn.config.JWT.AdminGroupDisabled,
tx := conn.db.Begin()
if _, err := conn.createUserWithTransaction(tx, user); err != nil {
return err
}
tx.Commit()
return nil
}

if result := conn.db.Create(obj); result.Error != nil {
return errors.Wrap(result.Error, "Error creating user")
func (conn *Connection) createUserWithTransaction(tx *gorm.DB, user *models.User) (*UserObj, error) {
obj := conn.newUserObj(user)
if result := tx.Create(obj); result.Error != nil {
tx.Rollback()
return nil, errors.Wrap(result.Error, "Error creating user")
}
return nil

return obj, nil
}

func (conn *Connection) findUser(query string, args ...interface{}) (*models.User, error) {
Expand All @@ -70,8 +76,8 @@ func (conn *Connection) FindUserByConfirmationToken(token string) (*models.User,
return conn.findUser("confirmation_token = ?", token)
}

func (conn *Connection) FindUserByEmail(email string) (*models.User, error) {
return conn.findUser("email = ?", email)
func (conn *Connection) FindUserByEmailAndAudience(email, aud string) (*models.User, error) {
return conn.findUser("email = ? and aud = ?", email, aud)
}

func (conn *Connection) FindUserByID(id string) (*models.User, error) {
Expand All @@ -82,7 +88,7 @@ func (conn *Connection) FindUserByRecoveryToken(token string) (*models.User, err
return conn.findUser("recovery_token = ?", token)
}

func (conn *Connection) FindUserWithRefreshToken(token string) (*models.User, *models.RefreshToken, error) {
func (conn *Connection) FindUserWithRefreshToken(token, aud string) (*models.User, *models.RefreshToken, error) {
refreshToken := &models.RefreshToken{}
if result := conn.db.First(refreshToken, "token = ?", token); result.Error != nil {
if result.RecordNotFound() {
Expand All @@ -92,7 +98,7 @@ func (conn *Connection) FindUserWithRefreshToken(token string) (*models.User, *m
}
}

user, err := conn.findUser("id = ?", refreshToken.UserID)
user, err := conn.findUser("id = ? and aud = ?", refreshToken.UserID, aud)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -134,8 +140,8 @@ func (conn *Connection) GrantRefreshTokenSwap(user *models.User, token *models.R
return newToken, nil
}

func (conn *Connection) IsDuplicatedEmail(email, id string) (bool, error) {
_, err := conn.findUser("id != ? and email = ?", id, email)
func (conn *Connection) IsDuplicatedEmail(email, aud, id string) (bool, error) {
_, err := conn.findUser("id != ? and email = ? and aud = ?", id, email, aud)
if err != nil {
if models.IsNotFoundError(err) {
return false, nil
Expand All @@ -146,7 +152,7 @@ func (conn *Connection) IsDuplicatedEmail(email, id string) (bool, error) {
return true, nil
}

func (conn *Connection) Logout(id interface{}) {
func (conn *Connection) Logout(id string) {
conn.db.Where("user_id = ?", id).Delete(&models.RefreshToken{})
}

Expand Down Expand Up @@ -179,10 +185,20 @@ func (conn *Connection) RollbackRefreshTokenSwap(newToken, oldToken *models.Refr
}

func (conn *Connection) UpdateUser(user *models.User) error {
tx := conn.db.Begin()
if err := conn.updateUserWithTransaction(tx, user); err != nil {
return err
}
tx.Commit()
return nil
}

func (conn *Connection) updateUserWithTransaction(tx *gorm.DB, user *models.User) error {
obj := &UserObj{
User: user,
}
if result := conn.db.Save(obj); result.Error != nil {
if result := tx.Save(obj); result.Error != nil {
tx.Rollback()
return errors.Wrap(result.Error, "Error updating user record")
}
return nil
Expand Down