Skip to content

Commit

Permalink
feat: user token can be used to authenticate a request
Browse files Browse the repository at this point in the history
Signed-off-by: Kush Sharma <thekushsharma@gmail.com>
  • Loading branch information
kushsharma committed Jun 7, 2023
1 parent 774288e commit e84f056
Show file tree
Hide file tree
Showing 28 changed files with 1,846 additions and 1,691 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ GOVERSION := $(shell go version | cut -d ' ' -f 3 | cut -d '.' -f 2)

.PHONY: build check fmt lint test test-race vet test-cover-html help install proto ui
.DEFAULT_GOAL := build
PROTON_COMMIT := "e383abda68a4543eaf09fa578ce8862465fdf3aa"
PROTON_COMMIT := "75601b9e0c409299789b6a373f22d9362dfaea69"

ui:
@echo " > generating ui build"
Expand Down
17 changes: 15 additions & 2 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ import (
"syscall"
"time"

"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/odpf/shield/core/authenticate/token"

"github.com/odpf/shield/pkg/server"

"github.com/odpf/shield/pkg/server/consts"
Expand Down Expand Up @@ -181,6 +184,16 @@ func buildAPIDependencies(
dbc *db.Client,
sdb *spicedb.SpiceDB,
) (api.Deps, error) {
var tokenKeySet jwk.Set
if len(cfg.App.Authentication.Token.RSAPath) > 0 {
if ks, err := jwk.ReadFile(cfg.App.Authentication.Token.RSAPath); err != nil {
return api.Deps{}, fmt.Errorf("failed to parse rsa key: %w", err)
} else {
tokenKeySet = ks
}
}
tokenService := token.NewService(tokenKeySet, cfg.App.Authentication.Token.Issuer)

sessionService := session.NewService(logger, postgres.NewSessionRepository(logger, dbc), consts.SessionValidity)

namespaceRepository := postgres.NewNamespaceRepository(dbc)
Expand All @@ -202,7 +215,7 @@ func buildAPIDependencies(
roleService := role.NewService(roleRepository, relationService, permissionService)

userRepository := postgres.NewUserRepository(dbc)
userService := user.NewService(userRepository, sessionService, relationService)
userService := user.NewService(userRepository, sessionService, relationService, tokenService)

groupRepository := postgres.NewGroupRepository(dbc)
groupService := group.NewService(groupRepository, relationService, userService)
Expand Down Expand Up @@ -246,7 +259,7 @@ func buildAPIDependencies(
)
}
registrationService := authenticate.NewRegistrationService(logger, cfg.App.Authentication,
postgres.NewFlowRepository(logger, dbc), userService, mailDialer)
postgres.NewFlowRepository(logger, dbc), userService, mailDialer, tokenService)

invitationService := invitation.NewService(mailDialer, postgres.NewInvitationRepository(logger, dbc),
organizationService, groupService, userService, relationService)
Expand Down
71 changes: 19 additions & 52 deletions core/authenticate/registration_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,34 @@ import (
"strings"
"time"

"github.com/odpf/shield/core/authenticate/token"

"github.com/odpf/shield/pkg/utils"

"github.com/odpf/shield/pkg/mailer"

"github.com/odpf/salt/log"

"github.com/google/uuid"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/odpf/shield/core/authenticate/strategy"
"github.com/odpf/shield/core/organization"
"github.com/odpf/shield/core/user"
"github.com/odpf/shield/pkg/str"
"github.com/robfig/cron/v3"
)

const (
// TODO(kushsharma): should we expose this in config?
tokenValidity = time.Hour * 24 * 14
defaultFlowExp = time.Minute * 10
maxOTPAttempt = 5
otpAttemptKey = "attempt"
)

var (
refreshTime = "0 0 * * *" // Once a day at midnight
ErrMissingRSADisableToken = errors.New("rsa key missing in config, generate and pass file path")
ErrStrategyNotApplicable = errors.New("strategy not applicable")
ErrUnsupportedMethod = errors.New("unsupported authentication method")
ErrInvalidMailOTP = errors.New("invalid mail otp")
ErrFlowInvalid = errors.New("invalid flow or expired")
refreshTime = "0 0 * * *" // Once a day at midnight
ErrStrategyNotApplicable = errors.New("strategy not applicable")
ErrUnsupportedMethod = errors.New("unsupported authentication method")
ErrInvalidMailOTP = errors.New("invalid mail otp")
ErrFlowInvalid = errors.New("invalid flow or expired")
)

type UserService interface {
Expand Down Expand Up @@ -78,17 +74,18 @@ type RegistrationFinishResponse struct {
}

type RegistrationService struct {
log log.Logger
cron *cron.Cron
flowRepo FlowRepository
userService UserService
config Config
mailDialer mailer.Dialer
Now func() time.Time
log log.Logger
cron *cron.Cron
flowRepo FlowRepository
userService UserService
config Config
mailDialer mailer.Dialer
Now func() time.Time
tokenService token.Service
}

func NewRegistrationService(logger log.Logger, config Config, flowRepo FlowRepository,
userService UserService, mailDialer mailer.Dialer) *RegistrationService {
userService UserService, mailDialer mailer.Dialer, tokenService token.Service) *RegistrationService {
r := &RegistrationService{
log: logger,
cron: cron.New(),
Expand All @@ -99,6 +96,7 @@ func NewRegistrationService(logger log.Logger, config Config, flowRepo FlowRepos
Now: func() time.Time {
return time.Now().UTC()
},
tokenService: tokenService,
}
return r
}
Expand Down Expand Up @@ -314,39 +312,8 @@ func (r RegistrationService) applyOIDC(ctx context.Context, request Registration
}, nil
}

func (r RegistrationService) Token(user user.User, orgs []organization.Organization) ([]byte, error) {
if len(r.config.Token.RSAPath) == 0 {
return nil, ErrMissingRSADisableToken
}
keySet, err := jwk.ReadFile(r.config.Token.RSAPath)
if err != nil {
return nil, fmt.Errorf("failed to parse rsa key: %w", err)
}
// use first key to sign token
rsaKey, ok := keySet.Key(0)
if !ok {
return nil, errors.New("missing rsa key to generate token")
}

var orgNames []string
for _, o := range orgs {
orgNames = append(orgNames, o.Name)
}

tok, err := jwt.NewBuilder().
Issuer(r.config.Token.Issuer).
IssuedAt(time.Now().UTC()).
NotBefore(time.Now().UTC()).
Expiration(time.Now().UTC().Add(tokenValidity)).
JwtID(uuid.New().String()).
Subject(user.ID).
Claim("org", strings.Join(orgNames, ",")).
Build()
if err != nil {
return nil, err
}

return jwt.Sign(tok, jwt.WithKey(jwa.RS256, rsaKey))
func (r RegistrationService) Token(ctx context.Context, user user.User, metadata map[string]string) ([]byte, error) {
return r.tokenService.Build(ctx, user.ID, metadata)
}

func (r RegistrationService) InitFlows(ctx context.Context) error {
Expand Down
96 changes: 96 additions & 0 deletions core/authenticate/token/service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package token

import (
"context"
"errors"
"fmt"
"time"

"github.com/google/uuid"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/odpf/shield/pkg/server/consts"
"google.golang.org/grpc/metadata"
)

var (
ErrMissingRSADisableToken = errors.New("rsa key missing in config, generate and pass file path")
ErrInvalidToken = errors.New("failed to verify a valid token")
ErrNoToken = errors.New("no token")
)

const (
// TODO(kushsharma): should we expose this in config?
tokenValidity = time.Hour * 24 * 7
)

type Service struct {
keySet jwk.Set
issuer string
}

func NewService(keySet jwk.Set, issuer string) Service {
return Service{
keySet: keySet,
issuer: issuer,
}
}

func (s Service) Build(ctx context.Context, userID string, metadata map[string]string) ([]byte, error) {
if s.keySet == nil {
return nil, ErrMissingRSADisableToken
}

// use first key to sign token
rsaKey, ok := s.keySet.Key(0)
if !ok {
return nil, errors.New("missing rsa key to generate token")
}

body := jwt.NewBuilder().
Issuer(s.issuer).
IssuedAt(time.Now().UTC()).
NotBefore(time.Now().UTC()).
Expiration(time.Now().UTC().Add(tokenValidity)).
JwtID(uuid.New().String()).
Subject(userID)
for claimKey, claimVal := range metadata {
body = body.Claim(claimKey, claimVal)
}

tok, err := body.Build()
if err != nil {
return nil, err
}
return jwt.Sign(tok, jwt.WithKey(jwa.RS256, rsaKey))
}

func (s Service) Parse(ctx context.Context, userToken []byte) (string, map[string]any, error) {
if s.keySet == nil {
return "", nil, ErrMissingRSADisableToken
}
// verify token with jwks
verifiedToken, err := jwt.Parse(userToken, jwt.WithKeySet(s.keySet))
if err != nil {
return "", nil, fmt.Errorf("%s: %w", ErrInvalidToken.Error(), err)
}
tokenClaims, err := verifiedToken.AsMap(ctx)
if err != nil {
return "", nil, fmt.Errorf("%s: %w", ErrInvalidToken.Error(), err)
}
return verifiedToken.Subject(), tokenClaims, nil
}

func (s Service) ParseFromContext(ctx context.Context) (string, map[string]any, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", nil, ErrNoToken
}

tokenHeaders := md.Get(consts.UserTokenGatewayKey)
if len(tokenHeaders) == 0 || len(tokenHeaders[0]) == 0 {
return "", nil, ErrNoToken
}
return s.Parse(ctx, []byte(tokenHeaders[0]))
}
3 changes: 3 additions & 0 deletions core/resource/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ func (s Service) CheckAuthz(ctx context.Context, res Resource, permissionName st
return false, err
}

// TODO(kushsharma): a user can pass object name instead of id in the request
// we should support converting name to id based on object namespace

return s.relationService.CheckPermission(ctx, relation.Subject{
ID: currentUser.ID,
// TODO(kushsharma): refactor this to also support app/serviceuser
Expand Down
11 changes: 10 additions & 1 deletion core/user/context.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package user

import "context"
import (
"context"

"github.com/odpf/shield/pkg/server/consts"
)

type contextEmailKey struct{}

Expand All @@ -12,3 +16,8 @@ func GetEmailFromContext(ctx context.Context) (string, bool) {
email, ok := ctx.Value(contextEmailKey{}).(string)
return email, ok
}

func GetUserFromContext(ctx context.Context) (User, bool) {
u, ok := ctx.Value(consts.AuthenticatedUserContextKey).(*User)
return *u, ok
}
31 changes: 30 additions & 1 deletion core/user/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"strings"
"time"

"github.com/odpf/shield/core/authenticate/token"

"github.com/odpf/shield/pkg/utils"

shieldsession "github.com/odpf/shield/core/authenticate/session"
Expand All @@ -27,18 +29,25 @@ type RelationRepository interface {
LookupResources(ctx context.Context, rel relation.Relation) ([]string, error)
}

type TokenService interface {
ParseFromContext(ctx context.Context) (string, map[string]any, error)
}

type Service struct {
repository Repository
relationService RelationRepository
sessionService SessionService
tokenService TokenService
Now func() time.Time
}

func NewService(repository Repository, sessionService SessionService, relationRepo RelationRepository) *Service {
func NewService(repository Repository, sessionService SessionService,
relationRepo RelationRepository, tokenService TokenService) *Service {
return &Service{
repository: repository,
sessionService: sessionService,
relationService: relationRepo,
tokenService: tokenService,
Now: func() time.Time {
return time.Now().UTC()
},
Expand Down Expand Up @@ -113,6 +122,11 @@ func (s Service) UpdateByEmail(ctx context.Context, toUpdate User) (User, error)

func (s Service) FetchCurrentUser(ctx context.Context) (User, error) {
var currentUser User
// check if already enriched by auth middleware
if val, ok := GetUserFromContext(ctx); ok {
currentUser = val
return currentUser, nil
}

// extract user from session if present
session, err := s.sessionService.ExtractFromContext(ctx)
Expand All @@ -128,7 +142,22 @@ func (s Service) FetchCurrentUser(ctx context.Context) (User, error) {
return User{}, err
}

// extract user from token if present
userID, _, err := s.tokenService.ParseFromContext(ctx)
if err == nil && utils.IsValidUUID(userID) {
// userID is a valid uuid
currentUser, err = s.GetByID(ctx, userID)
if err != nil {
return User{}, err
}
return currentUser, nil
}
if err != nil && !errors.Is(err, token.ErrNoToken) {
return User{}, err
}

// check if header with user email is set
// TODO(kushsharma): this should ideally be deprecated
if val, ok := GetEmailFromContext(ctx); ok && len(val) > 0 {
currentUser, err = s.GetByEmail(ctx, strings.TrimSpace(val))
if err != nil {
Expand Down

0 comments on commit e84f056

Please sign in to comment.