diff --git a/.gitignore b/.gitignore index 379a88d..715c0f4 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,5 @@ _testmain.go # OSX *.DS_Store *.db +/example/client/client +/example/server/server diff --git a/errors/error.go b/errors/error.go index fb6f2e1..d903519 100644 --- a/errors/error.go +++ b/errors/error.go @@ -2,25 +2,12 @@ package errors import "errors" +// known errors var ( - // ErrNilValue Nil Value - ErrNilValue = errors.New("nil value") - - // ErrInvalidRedirectURI invalid redirect uri - ErrInvalidRedirectURI = errors.New("invalid redirect uri") - - // ErrInvalidAuthorizeCode invalid authorize code + ErrInvalidRedirectURI = errors.New("invalid redirect uri") ErrInvalidAuthorizeCode = errors.New("invalid authorize code") - - // ErrInvalidAccessToken invalid access token - ErrInvalidAccessToken = errors.New("invalid access token") - - // ErrInvalidRefreshToken invalid refresh token - ErrInvalidRefreshToken = errors.New("invalid refresh token") - - // ErrExpiredAccessToken expired access token - ErrExpiredAccessToken = errors.New("expired access token") - - // ErrExpiredRefreshToken expired refresh token - ErrExpiredRefreshToken = errors.New("expired refresh token") + ErrInvalidAccessToken = errors.New("invalid access token") + ErrInvalidRefreshToken = errors.New("invalid refresh token") + ErrExpiredAccessToken = errors.New("expired access token") + ErrExpiredRefreshToken = errors.New("expired refresh token") ) diff --git a/errors/response.go b/errors/response.go index 7e39bb6..6d2a3d8 100644 --- a/errors/response.go +++ b/errors/response.go @@ -10,36 +10,18 @@ type Response struct { StatusCode int `json:"-"` } +// https://tools.ietf.org/html/rfc6749#section-5.2 var ( - // ErrInvalidRequest invalid request - ErrInvalidRequest = errors.New("invalid_request") - - // ErrUnauthorizedClient unauthorized client - ErrUnauthorizedClient = errors.New("unauthorized_client") - - // ErrAccessDenied access denied - ErrAccessDenied = errors.New("access_denied") - - // ErrUnsupportedResponseType unsupported response type + ErrInvalidRequest = errors.New("invalid_request") + ErrUnauthorizedClient = errors.New("unauthorized_client") + ErrAccessDenied = errors.New("access_denied") ErrUnsupportedResponseType = errors.New("unsupported_response_type") - - // ErrInvalidScope invalid scope - ErrInvalidScope = errors.New("invalid_scope") - - // ErrServerError server error - ErrServerError = errors.New("server_error") - - // ErrTemporarilyUnavailable temporarily unavailable - ErrTemporarilyUnavailable = errors.New("temporarily_unavailable") - - // ErrInvalidClient invalid client - ErrInvalidClient = errors.New("invalid_client") - - // ErrInvalidGrant invalid grant - ErrInvalidGrant = errors.New("invalid_grant") - - // ErrUnsupportedGrantType unsupported grant type - ErrUnsupportedGrantType = errors.New("unsupported_grant_type") + ErrInvalidScope = errors.New("invalid_scope") + ErrServerError = errors.New("server_error") + ErrTemporarilyUnavailable = errors.New("temporarily_unavailable") + ErrInvalidClient = errors.New("invalid_client") + ErrInvalidGrant = errors.New("invalid_grant") + ErrUnsupportedGrantType = errors.New("unsupported_grant_type") ) // Descriptions error description diff --git a/example/README.md b/example/README.md index a5b4643..2e5d272 100644 --- a/example/README.md +++ b/example/README.md @@ -6,7 +6,8 @@ Run Server ``` bash $ cd example/server -$ go run main.go +$ go build server.go +$ ./server ``` Run Client @@ -14,7 +15,8 @@ Run Client ``` $ cd example/client -$ go run main.go +$ go build client.go +$ ./client ``` Open the browser diff --git a/example/client/main.go b/example/client/client.go similarity index 100% rename from example/client/main.go rename to example/client/client.go diff --git a/example/server/main.go b/example/server/server.go similarity index 100% rename from example/server/main.go rename to example/server/server.go diff --git a/manage.go b/manage.go index c8887a9..8de4556 100644 --- a/manage.go +++ b/manage.go @@ -18,27 +18,30 @@ type TokenGenerateRequest struct { // Manager Authorization management interface type Manager interface { - // GetClient Get the client information + // Check the interface implementation + CheckInterface() (err error) + + // Get the client information GetClient(clientID string) (cli ClientInfo, err error) - // GenerateAuthToken Generate the authorization token(code) + // Generate the authorization token(code) GenerateAuthToken(rt ResponseType, tgr *TokenGenerateRequest) (authToken TokenInfo, err error) - // GenerateAccessToken Generate the access token + // Generate the access token GenerateAccessToken(rt GrantType, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error) - // RefreshAccessToken Refreshing an access token + // Refreshing an access token RefreshAccessToken(tgr *TokenGenerateRequest) (accessToken TokenInfo, err error) - // RemoveAccessToken Use the access token to delete the token information + // Use the access token to delete the token information RemoveAccessToken(access string) (err error) - // RemoveRefreshToken Use the refresh token to delete the token information + // Use the refresh token to delete the token information RemoveRefreshToken(refresh string) (err error) - // LoadAccessToken According to the access token for corresponding token information + // According to the access token for corresponding token information LoadAccessToken(access string) (ti TokenInfo, err error) - // LoadRefreshToken According to the refresh token for corresponding token information + // According to the refresh token for corresponding token information LoadRefreshToken(refresh string) (ti TokenInfo, err error) } diff --git a/manage/config.go b/manage/config.go new file mode 100644 index 0000000..f9840ba --- /dev/null +++ b/manage/config.go @@ -0,0 +1,23 @@ +package manage + +import "time" + +// Config authorization configuration parameters +type Config struct { + // access token expiration time (in seconds) + AccessTokenExp time.Duration + // refresh token expiration time(in seconds) + RefreshTokenExp time.Duration + // whether to generate the refreshing token + IsGenerateRefresh bool +} + +// default configs +var ( + DefaultCodeExp = time.Minute * 10 + DefaultAuthorizeCodeTokenCfg = &Config{AccessTokenExp: time.Hour * 2, RefreshTokenExp: time.Hour * 24 * 3, IsGenerateRefresh: true} + DefaultImplicitTokenCfg = &Config{AccessTokenExp: time.Hour * 1} + DefaultPasswordTokenCfg = &Config{AccessTokenExp: time.Hour * 2, RefreshTokenExp: time.Hour * 24 * 7, IsGenerateRefresh: true} + DefaultClientTokenCfg = &Config{AccessTokenExp: time.Hour * 2} + DefaultRefreshTokenCfg = &Config{} +) diff --git a/manage/manage_test.go b/manage/manage_test.go index 2d11480..85a14f4 100644 --- a/manage/manage_test.go +++ b/manage/manage_test.go @@ -14,6 +14,12 @@ func TestManager(t *testing.T) { Convey("Manager test", t, func() { manager := manage.NewDefaultManager() manager.MapClientStorage(store.NewTestClientStore()) + manager.MustTokenStorage(store.NewMemoryTokenStore()) + + Convey("CheckInterface test", func() { + err := manager.CheckInterface() + So(err, ShouldBeNil) + }) Convey("GetClient test", func() { cli, err := manager.GetClient("1") @@ -22,7 +28,6 @@ func TestManager(t *testing.T) { }) Convey("Token test", func() { - manager.MustTokenStorage(store.NewMemoryTokenStore()) testManager(manager) }) }) diff --git a/manage/manager.go b/manage/manager.go index bd94d97..6649f5f 100644 --- a/manage/manager.go +++ b/manage/manager.go @@ -11,24 +11,10 @@ import ( "gopkg.in/oauth2.v3/models" ) -// Config Configuration parameters -type Config struct { - AccessTokenExp time.Duration // Access token expiration time (in seconds) - RefreshTokenExp time.Duration // Refresh token expiration time - IsGenerateRefresh bool // Whether to generate the refreshing token -} - // NewDefaultManager Create to default authorization management instance func NewDefaultManager() *Manager { m := NewManager() - - // default config - m.SetAuthorizeCodeExp(time.Minute * 10) - m.SetImplicitTokenCfg(&Config{AccessTokenExp: time.Hour * 1}) - m.SetClientTokenCfg(&Config{AccessTokenExp: time.Hour * 2}) - m.SetAuthorizeCodeTokenCfg(&Config{IsGenerateRefresh: true, AccessTokenExp: time.Hour * 2, RefreshTokenExp: time.Hour * 24 * 3}) - m.SetPasswordTokenCfg(&Config{IsGenerateRefresh: true, AccessTokenExp: time.Hour * 2, RefreshTokenExp: time.Hour * 24 * 7}) - + // default implementation m.MapTokenModel(models.NewToken()) m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) m.MapAccessGenerate(generates.NewAccessGenerate()) @@ -46,9 +32,29 @@ func NewManager() *Manager { // Manager Provide authorization management type Manager struct { - injector inject.Injector // Dependency injection - codeExp time.Duration // Authorize code expiration time - gtcfg map[oauth2.GrantType]*Config // Authorization grant configuration + injector inject.Injector + codeExp time.Duration + gtcfg map[oauth2.GrantType]*Config +} + +// get grant type config +func (m *Manager) grantConfig(gt oauth2.GrantType) *Config { + if c, ok := m.gtcfg[gt]; ok && c != nil { + return c + } + switch gt { + case oauth2.AuthorizationCode: + return DefaultAuthorizeCodeTokenCfg + case oauth2.Implicit: + return DefaultImplicitTokenCfg + case oauth2.PasswordCredentials: + return DefaultPasswordTokenCfg + case oauth2.ClientCredentials: + return DefaultClientTokenCfg + case oauth2.Refreshing: + return DefaultRefreshTokenCfg + } + return &Config{} } func (m *Manager) newTokenInfo(ti oauth2.TokenInfo) oauth2.TokenInfo { @@ -91,39 +97,27 @@ func (m *Manager) SetRefreshTokenCfg(cfg *Config) { } // MapTokenModel Mapping the token information model -func (m *Manager) MapTokenModel(token oauth2.TokenInfo) error { - if token == nil { - return errors.ErrNilValue - } +func (m *Manager) MapTokenModel(token oauth2.TokenInfo) { m.injector.Map(token) - return nil + return } // MapAuthorizeGenerate Mapping the authorize code generate interface -func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) error { - if gen == nil { - return errors.ErrNilValue - } +func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { m.injector.Map(gen) - return nil + return } // MapAccessGenerate Mapping the access token generate interface -func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) error { - if gen == nil { - return errors.ErrNilValue - } +func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) { m.injector.Map(gen) - return nil + return } // MapClientStorage Mapping the client store interface -func (m *Manager) MapClientStorage(stor oauth2.ClientStore) error { - if stor == nil { - return errors.ErrNilValue - } +func (m *Manager) MapClientStorage(stor oauth2.ClientStore) { m.injector.Map(stor) - return nil + return } // MustClientStorage Mandatory mapping the client store interface @@ -131,19 +125,13 @@ func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) { if err != nil { panic(err.Error()) } - if stor == nil { - panic("client store can't be nil value") - } m.injector.Map(stor) } // MapTokenStorage Mapping the token store interface -func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) error { - if stor == nil { - return errors.ErrNilValue - } +func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) { m.injector.Map(stor) - return nil + return } // MustTokenStorage Mandatory mapping the token store interface @@ -151,12 +139,19 @@ func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) { if err != nil { panic(err) } - if stor == nil { - panic("token store can't be nil value") - } m.injector.Map(stor) } +// CheckInterface Check the interface implementation +func (m *Manager) CheckInterface() error { + _, err := m.injector.Invoke(func( + oauth2.TokenInfo, oauth2.AccessGenerate, oauth2.TokenStore, + oauth2.ClientStore, oauth2.AuthorizeGenerate, + ) { + }) + return err +} + // GetClient Get the client information func (m *Manager) GetClient(clientID string) (cli oauth2.ClientInfo, err error) { _, ierr := m.injector.Invoke(func(stor oauth2.ClientStore) { @@ -198,28 +193,35 @@ func (m *Manager) GenerateAuthToken(rt oauth2.ResponseType, tgr *oauth2.TokenGen return } ti.SetCode(tv) - ti.SetCodeExpiresIn(m.codeExp) + codeExp := m.codeExp + if codeExp == 0 { + codeExp = DefaultCodeExp + } + ti.SetCodeExpiresIn(codeExp) ti.SetCodeCreateAt(td.CreateAt) if exp := tgr.AccessTokenExp; exp > 0 { ti.SetAccessExpiresIn(exp) } case oauth2.Token: - tv, rv, terr := tgen.Token(td, m.gtcfg[oauth2.Implicit].IsGenerateRefresh) + icfg := m.grantConfig(oauth2.Implicit) + tv, rv, terr := tgen.Token(td, icfg.IsGenerateRefresh) if terr != nil { err = terr return } ti.SetAccess(tv) ti.SetAccessCreateAt(td.CreateAt) - aexp := m.gtcfg[oauth2.Implicit].AccessTokenExp + // set access token expires + aexp := icfg.AccessTokenExp if exp := tgr.AccessTokenExp; exp > 0 { aexp = exp } ti.SetAccessExpiresIn(aexp) - if rv != "" && m.gtcfg[oauth2.Implicit].IsGenerateRefresh { + + if rv != "" { ti.SetRefresh(rv) ti.SetRefreshCreateAt(td.CreateAt) - ti.SetRefreshExpiresIn(m.gtcfg[oauth2.Implicit].RefreshTokenExp) + ti.SetRefreshExpiresIn(icfg.RefreshTokenExp) } } ti.SetClientID(tgr.ClientID) @@ -245,10 +247,7 @@ func (m *Manager) getAuthorizationCode(code string) (info oauth2.TokenInfo, err if terr != nil { err = terr return - } else if ti == nil { - err = errors.ErrInvalidAuthorizeCode - return - } else if ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) { + } else if ti == nil || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) { err = errors.ErrInvalidAuthorizeCode return } @@ -305,7 +304,8 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene UserID: tgr.UserID, CreateAt: time.Now(), } - av, rv, terr := gen.Token(td, m.gtcfg[gt].IsGenerateRefresh) + gcfg := m.grantConfig(gt) + av, rv, terr := gen.Token(td, gcfg.IsGenerateRefresh) if terr != nil { err = terr return @@ -316,16 +316,16 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene ti.SetScope(tgr.Scope) ti.SetAccessCreateAt(td.CreateAt) ti.SetAccess(av) - - aexp := m.gtcfg[gt].AccessTokenExp + // set access token expires + aexp := gcfg.AccessTokenExp if exp := tgr.AccessTokenExp; exp > 0 { aexp = exp } ti.SetAccessExpiresIn(aexp) - if rv != "" && m.gtcfg[gt].IsGenerateRefresh { - ti.SetRefreshCreateAt(td.CreateAt) - ti.SetRefreshExpiresIn(m.gtcfg[gt].RefreshTokenExp) + if rv != "" { ti.SetRefresh(rv) + ti.SetRefreshCreateAt(td.CreateAt) + ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp) } err = stor.Create(ti) @@ -356,18 +356,15 @@ func (m *Manager) RefreshAccessToken(tgr *oauth2.TokenGenerateRequest) (accessTo err = errors.ErrInvalidRefreshToken return } - oldAccess := ti.GetAccess() + oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh() _, ierr := m.injector.Invoke(func(stor oauth2.TokenStore, gen oauth2.AccessGenerate) { td := &oauth2.GenerateBasic{ Client: cli, UserID: ti.GetUserID(), CreateAt: time.Now(), } - isGenRefresh := false - if rcfg, ok := m.gtcfg[oauth2.Refreshing]; ok { - isGenRefresh = rcfg.IsGenerateRefresh - } - tv, rv, terr := gen.Token(td, isGenRefresh) + rcfg := m.grantConfig(oauth2.Refreshing) + tv, rv, terr := gen.Token(td, rcfg.IsGenerateRefresh) if terr != nil { err = terr return @@ -389,6 +386,13 @@ func (m *Manager) RefreshAccessToken(tgr *oauth2.TokenGenerateRequest) (accessTo err = verr return } + if rv != "" { + // remove the old refresh token + if verr := stor.RemoveByRefresh(oldRefresh); verr != nil { + err = verr + return + } + } accessToken = ti }) if ierr != nil && err == nil { diff --git a/server/handler.go b/server/handler.go index f1879a6..058348e 100644 --- a/server/handler.go +++ b/server/handler.go @@ -8,40 +8,32 @@ import ( "gopkg.in/oauth2.v3/errors" ) -// ClientInfoHandler Get client info from request -type ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) - -// ClientAuthorizedHandler Check the client allows to use this authorization grant type -type ClientAuthorizedHandler func(clientID string, grantType oauth2.GrantType) (allowed bool, err error) - -// ClientScopeHandler Check the client allows to use scope -type ClientScopeHandler func(clientID, scope string) (allowed bool, err error) - -// UserAuthorizationHandler Get user id from request authorization -type UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) - -// PasswordAuthorizationHandler Get user id from username and password -type PasswordAuthorizationHandler func(username, password string) (userID string, err error) - -// RefreshingScopeHandler Check the scope of the refreshing token -type RefreshingScopeHandler func(newScope, oldScope string) (allowed bool, err error) - -// ResponseErrorHandler Response error handing -type ResponseErrorHandler func(re *errors.Response) - -// InternalErrorHandler Internal error handing -type InternalErrorHandler func(r *http.Request, err error) - -// AuthorizeScopeHandler Set the authorized scope -type AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) - -// AccessTokenExpHandler Set expiration date for the access token -type AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) - -// ExtensionFieldsHandler In response to the access token with the extension of the field -type ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) +type ( + // ClientInfoHandler get client info from request + ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) + // ClientAuthorizedHandler check the client allows to use this authorization grant type + ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error) + // ClientScopeHandler check the client allows to use scope + ClientScopeHandler func(clientID, scope string) (allowed bool, err error) + // UserAuthorizationHandler get user id from request authorization + UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) + // PasswordAuthorizationHandler get user id from username and password + PasswordAuthorizationHandler func(username, password string) (userID string, err error) + // RefreshingScopeHandler check the scope of the refreshing token + RefreshingScopeHandler func(newScope, oldScope string) (allowed bool, err error) + // ResponseErrorHandler response error handing + ResponseErrorHandler func(re *errors.Response) + // InternalErrorHandler internal error handing + InternalErrorHandler func(r *http.Request, err error) + // AuthorizeScopeHandler set the authorized scope + AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) + // AccessTokenExpHandler set expiration date for the access token + AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) + // ExtensionFieldsHandler in response to the access token with the extension of the field + ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) +) -// ClientFormHandler Get client data from form +// ClientFormHandler get client data from form func ClientFormHandler(r *http.Request) (clientID, clientSecret string, err error) { clientID = r.Form.Get("client_id") clientSecret = r.Form.Get("client_secret") @@ -51,7 +43,7 @@ func ClientFormHandler(r *http.Request) (clientID, clientSecret string, err erro return } -// ClientBasicHandler Get client data from basic authorization +// ClientBasicHandler get client data from basic authorization func ClientBasicHandler(r *http.Request) (clientID, clientSecret string, err error) { username, password, ok := r.BasicAuth() if !ok { diff --git a/server/request.go b/server/request.go index 5dca880..6454d97 100644 --- a/server/request.go +++ b/server/request.go @@ -6,7 +6,7 @@ import ( "gopkg.in/oauth2.v3" ) -// AuthorizeRequest The authorization request +// AuthorizeRequest authorization request type AuthorizeRequest struct { ResponseType oauth2.ResponseType ClientID string diff --git a/server/server.go b/server/server.go index 523d1cd..3e37e36 100644 --- a/server/server.go +++ b/server/server.go @@ -13,13 +13,25 @@ import ( // NewServer Create to authorization server instance func NewServer(cfg *Config, manager oauth2.Manager) *Server { + if err := manager.CheckInterface(); err != nil { + panic(err) + } srv := &Server{ - Config: cfg, - Manager: manager, - ClientInfoHandler: ClientFormHandler, - ResponseErrorHandler: func(re *errors.Response) { - re.Description = "" - }, + Config: cfg, + Manager: manager, + } + // default handler + srv.ClientInfoHandler = ClientFormHandler + srv.ResponseErrorHandler = func(re *errors.Response) { + re.Description = "" + } + srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (userID string, err error) { + err = errors.ErrAccessDenied + return + } + srv.PasswordAuthorizationHandler = func(username, password string) (userID string, err error) { + err = errors.ErrAccessDenied + return } return srv }