From 6fc2437c117c221740e20f3d8ef1549f764b1a58 Mon Sep 17 00:00:00 2001 From: arran ubels Date: Mon, 6 Jul 2020 17:25:16 +1000 Subject: [PATCH] Not actually required as per section 4.1.3 of the RFC and if done breaks SPA / Jam Stack apps. https://tools.ietf.org/html/rfc6749#section-4.1.3 --- manage/manager.go | 942 +++++++++++++++++++++++----------------------- server/handler.go | 126 +++---- 2 files changed, 534 insertions(+), 534 deletions(-) diff --git a/manage/manager.go b/manage/manager.go index 514bc23..633cea4 100755 --- a/manage/manager.go +++ b/manage/manager.go @@ -1,471 +1,471 @@ -package manage - -import ( - "context" - "time" - - "github.com/go-oauth2/oauth2/v4" - "github.com/go-oauth2/oauth2/v4/errors" - "github.com/go-oauth2/oauth2/v4/generates" - "github.com/go-oauth2/oauth2/v4/models" -) - -// NewDefaultManager create to default authorization management instance -func NewDefaultManager() *Manager { - m := NewManager() - // default implementation - m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) - m.MapAccessGenerate(generates.NewAccessGenerate()) - - return m -} - -// NewManager create to authorization management instance -func NewManager() *Manager { - return &Manager{ - gtcfg: make(map[oauth2.GrantType]*Config), - validateURI: DefaultValidateURI, - } -} - -// Manager provide authorization management -type Manager struct { - codeExp time.Duration - gtcfg map[oauth2.GrantType]*Config - rcfg *RefreshingConfig - validateURI ValidateURIHandler - authorizeGenerate oauth2.AuthorizeGenerate - accessGenerate oauth2.AccessGenerate - tokenStore oauth2.TokenStore - clientStore oauth2.ClientStore -} - -// get grant type config -func (m *Manager) grantConfig(gt oauth2.GrantType) *Config { - if c, ok := m.gtcfg[gt]; ok && c != nil { - return c - } - switch gt { - case oauth2.AuthorizationCode: - return DefaultAuthorizeCodeTokenCfg - case oauth2.Implicit: - return DefaultImplicitTokenCfg - case oauth2.PasswordCredentials: - return DefaultPasswordTokenCfg - case oauth2.ClientCredentials: - return DefaultClientTokenCfg - } - return &Config{} -} - -// SetAuthorizeCodeExp set the authorization code expiration time -func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) { - m.codeExp = exp -} - -// SetAuthorizeCodeTokenCfg set the authorization code grant token config -func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) { - m.gtcfg[oauth2.AuthorizationCode] = cfg -} - -// SetImplicitTokenCfg set the implicit grant token config -func (m *Manager) SetImplicitTokenCfg(cfg *Config) { - m.gtcfg[oauth2.Implicit] = cfg -} - -// SetPasswordTokenCfg set the password grant token config -func (m *Manager) SetPasswordTokenCfg(cfg *Config) { - m.gtcfg[oauth2.PasswordCredentials] = cfg -} - -// SetClientTokenCfg set the client grant token config -func (m *Manager) SetClientTokenCfg(cfg *Config) { - m.gtcfg[oauth2.ClientCredentials] = cfg -} - -// SetRefreshTokenCfg set the refreshing token config -func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) { - m.rcfg = cfg -} - -// SetValidateURIHandler set the validates that RedirectURI is contained in baseURI -func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) { - m.validateURI = handler -} - -// MapAuthorizeGenerate mapping the authorize code generate interface -func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { - m.authorizeGenerate = gen -} - -// MapAccessGenerate mapping the access token generate interface -func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) { - m.accessGenerate = gen -} - -// MapClientStorage mapping the client store interface -func (m *Manager) MapClientStorage(stor oauth2.ClientStore) { - m.clientStore = stor -} - -// MustClientStorage mandatory mapping the client store interface -func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) { - if err != nil { - panic(err.Error()) - } - m.clientStore = stor -} - -// MapTokenStorage mapping the token store interface -func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) { - m.tokenStore = stor -} - -// MustTokenStorage mandatory mapping the token store interface -func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) { - if err != nil { - panic(err) - } - m.tokenStore = stor -} - -// GetClient get the client information -func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) { - cli, err = m.clientStore.GetByID(ctx, clientID) - if err != nil { - return - } else if cli == nil { - err = errors.ErrInvalidClient - } - return -} - -// GenerateAuthToken generate the authorization token(code) -func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - cli, err := m.GetClient(ctx, tgr.ClientID) - if err != nil { - return nil, err - } else if tgr.RedirectURI != "" { - if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { - return nil, err - } - } - - ti := models.NewToken() - ti.SetClientID(tgr.ClientID) - ti.SetUserID(tgr.UserID) - ti.SetRedirectURI(tgr.RedirectURI) - ti.SetScope(tgr.Scope) - - createAt := time.Now() - td := &oauth2.GenerateBasic{ - Client: cli, - UserID: tgr.UserID, - CreateAt: createAt, - TokenInfo: ti, - Request: tgr.Request, - } - switch rt { - case oauth2.Code: - codeExp := m.codeExp - if codeExp == 0 { - codeExp = DefaultCodeExp - } - ti.SetCodeCreateAt(createAt) - ti.SetCodeExpiresIn(codeExp) - if exp := tgr.AccessTokenExp; exp > 0 { - ti.SetAccessExpiresIn(exp) - } - - tv, err := m.authorizeGenerate.Token(ctx, td) - if err != nil { - return nil, err - } - ti.SetCode(tv) - case oauth2.Token: - // set access token expires - icfg := m.grantConfig(oauth2.Implicit) - aexp := icfg.AccessTokenExp - if exp := tgr.AccessTokenExp; exp > 0 { - aexp = exp - } - ti.SetAccessCreateAt(createAt) - ti.SetAccessExpiresIn(aexp) - - if icfg.IsGenerateRefresh { - ti.SetRefreshCreateAt(createAt) - ti.SetRefreshExpiresIn(icfg.RefreshTokenExp) - } - - tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh) - if err != nil { - return nil, err - } - ti.SetAccess(tv) - - if rv != "" { - ti.SetRefresh(rv) - } - } - - err = m.tokenStore.Create(ctx, ti) - if err != nil { - return nil, err - } - return ti, nil -} - -// get authorization code data -func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { - ti, err := m.tokenStore.GetByCode(ctx, code) - if err != nil { - return nil, err - } else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) { - err = errors.ErrInvalidAuthorizeCode - return nil, errors.ErrInvalidAuthorizeCode - } - return ti, nil -} - -// delete authorization code data -func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error { - return m.tokenStore.RemoveByCode(ctx, code) -} - -// get and delete authorization code data -func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - code := tgr.Code - ti, err := m.getAuthorizationCode(ctx, code) - if err != nil { - return nil, err - } else if ti.GetClientID() != tgr.ClientID { - return nil, errors.ErrInvalidAuthorizeCode - } else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI { - return nil, errors.ErrInvalidAuthorizeCode - } - - err = m.delAuthorizationCode(ctx, code) - if err != nil { - return nil, err - } - return ti, nil -} - -// GenerateAccessToken generate the access token -func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - cli, err := m.GetClient(ctx, tgr.ClientID) - if err != nil { - return nil, err - } - if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok { - if !cliPass.VerifyPassword(tgr.ClientSecret) { - return nil, errors.ErrInvalidClient - } - } else if tgr.ClientSecret != cli.GetSecret() { - return nil, errors.ErrInvalidClient - } - if tgr.RedirectURI != "" { - if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { - return nil, err - } - } - - if gt == oauth2.AuthorizationCode { - ti, err := m.getAndDelAuthorizationCode(ctx, tgr) - if err != nil { - return nil, err - } - tgr.UserID = ti.GetUserID() - tgr.Scope = ti.GetScope() - if exp := ti.GetAccessExpiresIn(); exp > 0 { - tgr.AccessTokenExp = exp - } - } - - ti := models.NewToken() - ti.SetClientID(tgr.ClientID) - ti.SetUserID(tgr.UserID) - ti.SetRedirectURI(tgr.RedirectURI) - ti.SetScope(tgr.Scope) - - createAt := time.Now() - ti.SetAccessCreateAt(createAt) - - // set access token expires - gcfg := m.grantConfig(gt) - aexp := gcfg.AccessTokenExp - if exp := tgr.AccessTokenExp; exp > 0 { - aexp = exp - } - ti.SetAccessExpiresIn(aexp) - if gcfg.IsGenerateRefresh { - ti.SetRefreshCreateAt(createAt) - ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp) - } - - td := &oauth2.GenerateBasic{ - Client: cli, - UserID: tgr.UserID, - CreateAt: createAt, - TokenInfo: ti, - Request: tgr.Request, - } - - av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh) - if err != nil { - return nil, err - } - ti.SetAccess(av) - - if rv != "" { - ti.SetRefresh(rv) - } - - err = m.tokenStore.Create(ctx, ti) - if err != nil { - return nil, err - } - - return ti, nil -} - -// RefreshAccessToken refreshing an access token -func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - cli, err := m.GetClient(ctx, tgr.ClientID) - if err != nil { - return nil, err - } else if tgr.ClientSecret != cli.GetSecret() { - return nil, errors.ErrInvalidClient - } - - ti, err := m.LoadRefreshToken(ctx, tgr.Refresh) - if err != nil { - return nil, err - } else if ti.GetClientID() != tgr.ClientID { - return nil, errors.ErrInvalidRefreshToken - } - - oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh() - - td := &oauth2.GenerateBasic{ - Client: cli, - UserID: ti.GetUserID(), - CreateAt: time.Now(), - TokenInfo: ti, - Request: tgr.Request, - } - - rcfg := DefaultRefreshTokenCfg - if v := m.rcfg; v != nil { - rcfg = v - } - - ti.SetAccessCreateAt(td.CreateAt) - if v := rcfg.AccessTokenExp; v > 0 { - ti.SetAccessExpiresIn(v) - } - - if v := rcfg.RefreshTokenExp; v > 0 { - ti.SetRefreshExpiresIn(v) - } - - if rcfg.IsResetRefreshTime { - ti.SetRefreshCreateAt(td.CreateAt) - } - - if scope := tgr.Scope; scope != "" { - ti.SetScope(scope) - } - - tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh) - if err != nil { - return nil, err - } - - ti.SetAccess(tv) - if rv != "" { - ti.SetRefresh(rv) - } - - if err := m.tokenStore.Create(ctx, ti); err != nil { - return nil, err - } - - if rcfg.IsRemoveAccess { - // remove the old access token - if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil { - return nil, err - } - } - - if rcfg.IsRemoveRefreshing && rv != "" { - // remove the old refresh token - if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil { - return nil, err - } - } - - if rv == "" { - ti.SetRefresh("") - ti.SetRefreshCreateAt(time.Now()) - ti.SetRefreshExpiresIn(0) - } - - return ti, nil -} - -// RemoveAccessToken use the access token to delete the token information -func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error { - if access == "" { - return errors.ErrInvalidAccessToken - } - return m.tokenStore.RemoveByAccess(ctx, access) -} - -// RemoveRefreshToken use the refresh token to delete the token information -func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error { - if refresh == "" { - return errors.ErrInvalidAccessToken - } - return m.tokenStore.RemoveByRefresh(ctx, refresh) -} - -// LoadAccessToken according to the access token for corresponding token information -func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) { - if access == "" { - return nil, errors.ErrInvalidAccessToken - } - - ct := time.Now() - ti, err := m.tokenStore.GetByAccess(ctx, access) - if err != nil { - return nil, err - } else if ti == nil || ti.GetAccess() != access { - return nil, errors.ErrInvalidAccessToken - } else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 && - ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { - return nil, errors.ErrExpiredRefreshToken - } else if ti.GetAccessExpiresIn() != 0 && - ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { - return nil, errors.ErrExpiredAccessToken - } - return ti, nil -} - -// LoadRefreshToken according to the refresh token for corresponding token information -func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { - if refresh == "" { - return nil, errors.ErrInvalidRefreshToken - } - - ti, err := m.tokenStore.GetByRefresh(ctx, refresh) - if err != nil { - return nil, err - } else if ti == nil || ti.GetRefresh() != refresh { - return nil, errors.ErrInvalidRefreshToken - } else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire - ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) { - return nil, errors.ErrExpiredRefreshToken - } - return ti, nil -} +package manage + +import ( + "context" + "time" + + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/errors" + "github.com/go-oauth2/oauth2/v4/generates" + "github.com/go-oauth2/oauth2/v4/models" +) + +// NewDefaultManager create to default authorization management instance +func NewDefaultManager() *Manager { + m := NewManager() + // default implementation + m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) + m.MapAccessGenerate(generates.NewAccessGenerate()) + + return m +} + +// NewManager create to authorization management instance +func NewManager() *Manager { + return &Manager{ + gtcfg: make(map[oauth2.GrantType]*Config), + validateURI: DefaultValidateURI, + } +} + +// Manager provide authorization management +type Manager struct { + codeExp time.Duration + gtcfg map[oauth2.GrantType]*Config + rcfg *RefreshingConfig + validateURI ValidateURIHandler + authorizeGenerate oauth2.AuthorizeGenerate + accessGenerate oauth2.AccessGenerate + tokenStore oauth2.TokenStore + clientStore oauth2.ClientStore +} + +// get grant type config +func (m *Manager) grantConfig(gt oauth2.GrantType) *Config { + if c, ok := m.gtcfg[gt]; ok && c != nil { + return c + } + switch gt { + case oauth2.AuthorizationCode: + return DefaultAuthorizeCodeTokenCfg + case oauth2.Implicit: + return DefaultImplicitTokenCfg + case oauth2.PasswordCredentials: + return DefaultPasswordTokenCfg + case oauth2.ClientCredentials: + return DefaultClientTokenCfg + } + return &Config{} +} + +// SetAuthorizeCodeExp set the authorization code expiration time +func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) { + m.codeExp = exp +} + +// SetAuthorizeCodeTokenCfg set the authorization code grant token config +func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) { + m.gtcfg[oauth2.AuthorizationCode] = cfg +} + +// SetImplicitTokenCfg set the implicit grant token config +func (m *Manager) SetImplicitTokenCfg(cfg *Config) { + m.gtcfg[oauth2.Implicit] = cfg +} + +// SetPasswordTokenCfg set the password grant token config +func (m *Manager) SetPasswordTokenCfg(cfg *Config) { + m.gtcfg[oauth2.PasswordCredentials] = cfg +} + +// SetClientTokenCfg set the client grant token config +func (m *Manager) SetClientTokenCfg(cfg *Config) { + m.gtcfg[oauth2.ClientCredentials] = cfg +} + +// SetRefreshTokenCfg set the refreshing token config +func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) { + m.rcfg = cfg +} + +// SetValidateURIHandler set the validates that RedirectURI is contained in baseURI +func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) { + m.validateURI = handler +} + +// MapAuthorizeGenerate mapping the authorize code generate interface +func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { + m.authorizeGenerate = gen +} + +// MapAccessGenerate mapping the access token generate interface +func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) { + m.accessGenerate = gen +} + +// MapClientStorage mapping the client store interface +func (m *Manager) MapClientStorage(stor oauth2.ClientStore) { + m.clientStore = stor +} + +// MustClientStorage mandatory mapping the client store interface +func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) { + if err != nil { + panic(err.Error()) + } + m.clientStore = stor +} + +// MapTokenStorage mapping the token store interface +func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) { + m.tokenStore = stor +} + +// MustTokenStorage mandatory mapping the token store interface +func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) { + if err != nil { + panic(err) + } + m.tokenStore = stor +} + +// GetClient get the client information +func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) { + cli, err = m.clientStore.GetByID(ctx, clientID) + if err != nil { + return + } else if cli == nil { + err = errors.ErrInvalidClient + } + return +} + +// GenerateAuthToken generate the authorization token(code) +func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { + cli, err := m.GetClient(ctx, tgr.ClientID) + if err != nil { + return nil, err + } else if tgr.RedirectURI != "" { + if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { + return nil, err + } + } + + ti := models.NewToken() + ti.SetClientID(tgr.ClientID) + ti.SetUserID(tgr.UserID) + ti.SetRedirectURI(tgr.RedirectURI) + ti.SetScope(tgr.Scope) + + createAt := time.Now() + td := &oauth2.GenerateBasic{ + Client: cli, + UserID: tgr.UserID, + CreateAt: createAt, + TokenInfo: ti, + Request: tgr.Request, + } + switch rt { + case oauth2.Code: + codeExp := m.codeExp + if codeExp == 0 { + codeExp = DefaultCodeExp + } + ti.SetCodeCreateAt(createAt) + ti.SetCodeExpiresIn(codeExp) + if exp := tgr.AccessTokenExp; exp > 0 { + ti.SetAccessExpiresIn(exp) + } + + tv, err := m.authorizeGenerate.Token(ctx, td) + if err != nil { + return nil, err + } + ti.SetCode(tv) + case oauth2.Token: + // set access token expires + icfg := m.grantConfig(oauth2.Implicit) + aexp := icfg.AccessTokenExp + if exp := tgr.AccessTokenExp; exp > 0 { + aexp = exp + } + ti.SetAccessCreateAt(createAt) + ti.SetAccessExpiresIn(aexp) + + if icfg.IsGenerateRefresh { + ti.SetRefreshCreateAt(createAt) + ti.SetRefreshExpiresIn(icfg.RefreshTokenExp) + } + + tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh) + if err != nil { + return nil, err + } + ti.SetAccess(tv) + + if rv != "" { + ti.SetRefresh(rv) + } + } + + err = m.tokenStore.Create(ctx, ti) + if err != nil { + return nil, err + } + return ti, nil +} + +// get authorization code data +func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { + ti, err := m.tokenStore.GetByCode(ctx, code) + if err != nil { + return nil, err + } else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) { + err = errors.ErrInvalidAuthorizeCode + return nil, errors.ErrInvalidAuthorizeCode + } + return ti, nil +} + +// delete authorization code data +func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error { + return m.tokenStore.RemoveByCode(ctx, code) +} + +// get and delete authorization code data +func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { + code := tgr.Code + ti, err := m.getAuthorizationCode(ctx, code) + if err != nil { + return nil, err + } else if ti.GetClientID() != tgr.ClientID { + return nil, errors.ErrInvalidAuthorizeCode + } else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI { + return nil, errors.ErrInvalidAuthorizeCode + } + + err = m.delAuthorizationCode(ctx, code) + if err != nil { + return nil, err + } + return ti, nil +} + +// GenerateAccessToken generate the access token +func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { + cli, err := m.GetClient(ctx, tgr.ClientID) + if err != nil { + return nil, err + } + if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok { + if !cliPass.VerifyPassword(tgr.ClientSecret) { + return nil, errors.ErrInvalidClient + } + } else if len(tgr.ClientSecret) > 0 && tgr.ClientSecret != cli.GetSecret() { + return nil, errors.ErrInvalidClient + } + if tgr.RedirectURI != "" { + if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil { + return nil, err + } + } + + if gt == oauth2.AuthorizationCode { + ti, err := m.getAndDelAuthorizationCode(ctx, tgr) + if err != nil { + return nil, err + } + tgr.UserID = ti.GetUserID() + tgr.Scope = ti.GetScope() + if exp := ti.GetAccessExpiresIn(); exp > 0 { + tgr.AccessTokenExp = exp + } + } + + ti := models.NewToken() + ti.SetClientID(tgr.ClientID) + ti.SetUserID(tgr.UserID) + ti.SetRedirectURI(tgr.RedirectURI) + ti.SetScope(tgr.Scope) + + createAt := time.Now() + ti.SetAccessCreateAt(createAt) + + // set access token expires + gcfg := m.grantConfig(gt) + aexp := gcfg.AccessTokenExp + if exp := tgr.AccessTokenExp; exp > 0 { + aexp = exp + } + ti.SetAccessExpiresIn(aexp) + if gcfg.IsGenerateRefresh { + ti.SetRefreshCreateAt(createAt) + ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp) + } + + td := &oauth2.GenerateBasic{ + Client: cli, + UserID: tgr.UserID, + CreateAt: createAt, + TokenInfo: ti, + Request: tgr.Request, + } + + av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh) + if err != nil { + return nil, err + } + ti.SetAccess(av) + + if rv != "" { + ti.SetRefresh(rv) + } + + err = m.tokenStore.Create(ctx, ti) + if err != nil { + return nil, err + } + + return ti, nil +} + +// RefreshAccessToken refreshing an access token +func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { + cli, err := m.GetClient(ctx, tgr.ClientID) + if err != nil { + return nil, err + } else if tgr.ClientSecret != cli.GetSecret() { + return nil, errors.ErrInvalidClient + } + + ti, err := m.LoadRefreshToken(ctx, tgr.Refresh) + if err != nil { + return nil, err + } else if ti.GetClientID() != tgr.ClientID { + return nil, errors.ErrInvalidRefreshToken + } + + oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh() + + td := &oauth2.GenerateBasic{ + Client: cli, + UserID: ti.GetUserID(), + CreateAt: time.Now(), + TokenInfo: ti, + Request: tgr.Request, + } + + rcfg := DefaultRefreshTokenCfg + if v := m.rcfg; v != nil { + rcfg = v + } + + ti.SetAccessCreateAt(td.CreateAt) + if v := rcfg.AccessTokenExp; v > 0 { + ti.SetAccessExpiresIn(v) + } + + if v := rcfg.RefreshTokenExp; v > 0 { + ti.SetRefreshExpiresIn(v) + } + + if rcfg.IsResetRefreshTime { + ti.SetRefreshCreateAt(td.CreateAt) + } + + if scope := tgr.Scope; scope != "" { + ti.SetScope(scope) + } + + tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh) + if err != nil { + return nil, err + } + + ti.SetAccess(tv) + if rv != "" { + ti.SetRefresh(rv) + } + + if err := m.tokenStore.Create(ctx, ti); err != nil { + return nil, err + } + + if rcfg.IsRemoveAccess { + // remove the old access token + if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil { + return nil, err + } + } + + if rcfg.IsRemoveRefreshing && rv != "" { + // remove the old refresh token + if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil { + return nil, err + } + } + + if rv == "" { + ti.SetRefresh("") + ti.SetRefreshCreateAt(time.Now()) + ti.SetRefreshExpiresIn(0) + } + + return ti, nil +} + +// RemoveAccessToken use the access token to delete the token information +func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error { + if access == "" { + return errors.ErrInvalidAccessToken + } + return m.tokenStore.RemoveByAccess(ctx, access) +} + +// RemoveRefreshToken use the refresh token to delete the token information +func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error { + if refresh == "" { + return errors.ErrInvalidAccessToken + } + return m.tokenStore.RemoveByRefresh(ctx, refresh) +} + +// LoadAccessToken according to the access token for corresponding token information +func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) { + if access == "" { + return nil, errors.ErrInvalidAccessToken + } + + ct := time.Now() + ti, err := m.tokenStore.GetByAccess(ctx, access) + if err != nil { + return nil, err + } else if ti == nil || ti.GetAccess() != access { + return nil, errors.ErrInvalidAccessToken + } else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 && + ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { + return nil, errors.ErrExpiredRefreshToken + } else if ti.GetAccessExpiresIn() != 0 && + ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { + return nil, errors.ErrExpiredAccessToken + } + return ti, nil +} + +// LoadRefreshToken according to the refresh token for corresponding token information +func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { + if refresh == "" { + return nil, errors.ErrInvalidRefreshToken + } + + ti, err := m.tokenStore.GetByRefresh(ctx, refresh) + if err != nil { + return nil, err + } else if ti == nil || ti.GetRefresh() != refresh { + return nil, errors.ErrInvalidRefreshToken + } else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire + ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) { + return nil, errors.ErrExpiredRefreshToken + } + return ti, nil +} diff --git a/server/handler.go b/server/handler.go index 67d9c9f..3b1f713 100755 --- a/server/handler.go +++ b/server/handler.go @@ -1,63 +1,63 @@ -package server - -import ( - "net/http" - "time" - - "github.com/go-oauth2/oauth2/v4" - "github.com/go-oauth2/oauth2/v4/errors" -) - -type ( - // ClientInfoHandler get client info from request - ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) - - // ClientAuthorizedHandler check the client allows to use this authorization grant type - ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error) - - // ClientScopeHandler check the client allows to use scope - ClientScopeHandler func(clientID, scope string) (allowed bool, err error) - - // UserAuthorizationHandler get user id from request authorization - UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) - - // PasswordAuthorizationHandler get user id from username and password - PasswordAuthorizationHandler func(username, password string) (userID string, err error) - - // RefreshingScopeHandler check the scope of the refreshing token - RefreshingScopeHandler func(newScope, oldScope string) (allowed bool, err error) - - // ResponseErrorHandler response error handing - ResponseErrorHandler func(re *errors.Response) - - // InternalErrorHandler internal error handing - InternalErrorHandler func(err error) (re *errors.Response) - - // AuthorizeScopeHandler set the authorized scope - AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) - - // AccessTokenExpHandler set expiration date for the access token - AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) - - // ExtensionFieldsHandler in response to the access token with the extension of the field - ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) -) - -// ClientFormHandler get client data from form -func ClientFormHandler(r *http.Request) (string, string, error) { - clientID := r.Form.Get("client_id") - clientSecret := r.Form.Get("client_secret") - if clientID == "" || clientSecret == "" { - return "", "", errors.ErrInvalidClient - } - return clientID, clientSecret, nil -} - -// ClientBasicHandler get client data from basic authorization -func ClientBasicHandler(r *http.Request) (string, string, error) { - username, password, ok := r.BasicAuth() - if !ok { - return "", "", errors.ErrInvalidClient - } - return username, password, nil -} +package server + +import ( + "net/http" + "time" + + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/errors" +) + +type ( + // ClientInfoHandler get client info from request + ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) + + // ClientAuthorizedHandler check the client allows to use this authorization grant type + ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error) + + // ClientScopeHandler check the client allows to use scope + ClientScopeHandler func(clientID, scope string) (allowed bool, err error) + + // UserAuthorizationHandler get user id from request authorization + UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) + + // PasswordAuthorizationHandler get user id from username and password + PasswordAuthorizationHandler func(username, password string) (userID string, err error) + + // RefreshingScopeHandler check the scope of the refreshing token + RefreshingScopeHandler func(newScope, oldScope string) (allowed bool, err error) + + // ResponseErrorHandler response error handing + ResponseErrorHandler func(re *errors.Response) + + // InternalErrorHandler internal error handing + InternalErrorHandler func(err error) (re *errors.Response) + + // AuthorizeScopeHandler set the authorized scope + AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) + + // AccessTokenExpHandler set expiration date for the access token + AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) + + // ExtensionFieldsHandler in response to the access token with the extension of the field + ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) +) + +// ClientFormHandler get client data from form +func ClientFormHandler(r *http.Request) (string, string, error) { + clientID := r.Form.Get("client_id") + if clientID == "" { + return "", "", errors.ErrInvalidClient + } + clientSecret := r.Form.Get("client_secret") + return clientID, clientSecret, nil +} + +// ClientBasicHandler get client data from basic authorization +func ClientBasicHandler(r *http.Request) (string, string, error) { + username, password, ok := r.BasicAuth() + if !ok { + return "", "", errors.ErrInvalidClient + } + return username, password, nil +}