diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..74aa1f6 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,2 @@ +[*.go] +end_of_line = crlf \ No newline at end of file diff --git a/manage/manager.go b/manage/manager.go index 74adf0e..26ba2b0 100755 --- a/manage/manager.go +++ b/manage/manager.go @@ -218,7 +218,7 @@ func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, } return ti, nil } - + // get authorization code data func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { ti, err := m.tokenStore.GetByCode(ctx, code) diff --git a/server/handler.go b/server/handler.go index bb2c6d8..e27939b 100755 --- a/server/handler.go +++ b/server/handler.go @@ -1,66 +1,66 @@ -package server - -import ( - "net/http" - "time" - - "github.com/go-oauth2/oauth2/v4" - "github.com/go-oauth2/oauth2/v4/errors" -) - -type ( - // ClientInfoHandler get client info from request - ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) - - // ClientAuthorizedHandler check the client allows to use this authorization grant type - ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error) - - // ClientScopeHandler check the client allows to use scope - ClientScopeHandler func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) - - // UserAuthorizationHandler get user id from request authorization - UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) - - // PasswordAuthorizationHandler get user id from username and password - PasswordAuthorizationHandler func(username, password string) (userID string, err error) - - // RefreshingScopeHandler check the scope of the refreshing token - RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error) - - // RefreshingValidationHandler check if refresh_token is still valid. eg no revocation or other - RefreshingValidationHandler func(ti oauth2.TokenInfo) (allowed bool, err error) - - // ResponseErrorHandler response error handing - ResponseErrorHandler func(re *errors.Response) - - // InternalErrorHandler internal error handing - InternalErrorHandler func(err error) (re *errors.Response) - - // AuthorizeScopeHandler set the authorized scope - AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) - - // AccessTokenExpHandler set expiration date for the access token - AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) - - // ExtensionFieldsHandler in response to the access token with the extension of the field - ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) -) - -// ClientFormHandler get client data from form -func ClientFormHandler(r *http.Request) (string, string, error) { - clientID := r.Form.Get("client_id") - if clientID == "" { - return "", "", errors.ErrInvalidClient - } - clientSecret := r.Form.Get("client_secret") - return clientID, clientSecret, nil -} - -// ClientBasicHandler get client data from basic authorization -func ClientBasicHandler(r *http.Request) (string, string, error) { - username, password, ok := r.BasicAuth() - if !ok { - return "", "", errors.ErrInvalidClient - } - return username, password, nil -} +package server + +import ( + "net/http" + "time" + + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/errors" +) + +type ( + // ClientInfoHandler get client info from request + ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error) + + // ClientAuthorizedHandler check the client allows to use this authorization grant type + ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error) + + // ClientScopeHandler check the client allows to use scope + ClientScopeHandler func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error) + + // UserAuthorizationHandler get user id from request authorization + UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error) + + // PasswordAuthorizationHandler get user id from username and password + PasswordAuthorizationHandler func(username, password string) (userID string, err error) + + // RefreshingScopeHandler check the scope of the refreshing token + RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error) + + // RefreshingValidationHandler check if refresh_token is still valid. eg no revocation or other + RefreshingValidationHandler func(ti oauth2.TokenInfo) (allowed bool, err error) + + // ResponseErrorHandler response error handing + ResponseErrorHandler func(re *errors.Response) + + // InternalErrorHandler internal error handing + InternalErrorHandler func(err error) (re *errors.Response) + + // AuthorizeScopeHandler set the authorized scope + AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error) + + // AccessTokenExpHandler set expiration date for the access token + AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error) + + // ExtensionFieldsHandler in response to the access token with the extension of the field + ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) +) + +// ClientFormHandler get client data from form +func ClientFormHandler(r *http.Request) (string, string, error) { + clientID := r.Form.Get("client_id") + if clientID == "" { + return "", "", errors.ErrInvalidClient + } + clientSecret := r.Form.Get("client_secret") + return clientID, clientSecret, nil +} + +// ClientBasicHandler get client data from basic authorization +func ClientBasicHandler(r *http.Request) (string, string, error) { + username, password, ok := r.BasicAuth() + if !ok { + return "", "", errors.ErrInvalidClient + } + return username, password, nil +} diff --git a/server/server.go b/server/server.go index 641fd4b..ca5c861 100755 --- a/server/server.go +++ b/server/server.go @@ -1,589 +1,600 @@ -package server - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "net/url" - "strings" - "time" - - "github.com/go-oauth2/oauth2/v4" - "github.com/go-oauth2/oauth2/v4/errors" -) - -// NewDefaultServer create a default authorization server -func NewDefaultServer(manager oauth2.Manager) *Server { - return NewServer(NewConfig(), manager) -} - -// NewServer create authorization server -func NewServer(cfg *Config, manager oauth2.Manager) *Server { - srv := &Server{ - Config: cfg, - Manager: manager, - } - - // default handler - srv.ClientInfoHandler = ClientBasicHandler - - srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { - return "", errors.ErrAccessDenied - } - - srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { - return "", errors.ErrAccessDenied - } - return srv -} - -// Server Provide authorization server -type Server struct { - Config *Config - Manager oauth2.Manager - ClientInfoHandler ClientInfoHandler - ClientAuthorizedHandler ClientAuthorizedHandler - ClientScopeHandler ClientScopeHandler - UserAuthorizationHandler UserAuthorizationHandler - PasswordAuthorizationHandler PasswordAuthorizationHandler - RefreshingValidationHandler RefreshingValidationHandler - RefreshingScopeHandler RefreshingScopeHandler - ResponseErrorHandler ResponseErrorHandler - InternalErrorHandler InternalErrorHandler - ExtensionFieldsHandler ExtensionFieldsHandler - AccessTokenExpHandler AccessTokenExpHandler - AuthorizeScopeHandler AuthorizeScopeHandler -} - -func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { - if req == nil { - return err - } - data, _, _ := s.GetErrorData(err) - return s.redirect(w, req, data) -} - -func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { - uri, err := s.GetRedirectURI(req, data) - if err != nil { - return err - } - - w.Header().Set("Location", uri) - w.WriteHeader(302) - return nil -} - -func (s *Server) tokenError(w http.ResponseWriter, err error) error { - data, statusCode, header := s.GetErrorData(err) - return s.token(w, data, header, statusCode) -} - -func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { - w.Header().Set("Content-Type", "application/json;charset=UTF-8") - w.Header().Set("Cache-Control", "no-store") - w.Header().Set("Pragma", "no-cache") - - for key := range header { - w.Header().Set(key, header.Get(key)) - } - - status := http.StatusOK - if len(statusCode) > 0 && statusCode[0] > 0 { - status = statusCode[0] - } - - w.WriteHeader(status) - return json.NewEncoder(w).Encode(data) -} - -// GetRedirectURI get redirect uri -func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { - u, err := url.Parse(req.RedirectURI) - if err != nil { - return "", err - } - - q := u.Query() - if req.State != "" { - q.Set("state", req.State) - } - - for k, v := range data { - q.Set(k, fmt.Sprint(v)) - } - - switch req.ResponseType { - case oauth2.Code: - u.RawQuery = q.Encode() - case oauth2.Token: - u.RawQuery = "" - fragment, err := url.QueryUnescape(q.Encode()) - if err != nil { - return "", err - } - u.Fragment = fragment - } - - return u.String(), nil -} - -// CheckResponseType check allows response type -func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { - for _, art := range s.Config.AllowedResponseTypes { - if art == rt { - return true - } - } - return false -} - -// CheckCodeChallengeMethod checks for allowed code challenge method -func (s *Server) CheckCodeChallengeMethod(ccm oauth2.CodeChallengeMethod) bool { - for _, c := range s.Config.AllowedCodeChallengeMethods { - if c == ccm { - return true - } - } - return false -} - -// ValidationAuthorizeRequest the authorization request validation -func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { - redirectURI := r.FormValue("redirect_uri") - clientID := r.FormValue("client_id") - if !(r.Method == "GET" || r.Method == "POST") || - clientID == "" { - return nil, errors.ErrInvalidRequest - } - - resType := oauth2.ResponseType(r.FormValue("response_type")) - if resType.String() == "" { - return nil, errors.ErrUnsupportedResponseType - } else if allowed := s.CheckResponseType(resType); !allowed { - return nil, errors.ErrUnauthorizedClient - } - - cc := r.FormValue("code_challenge") - if cc == "" && s.Config.ForcePKCE { - return nil, errors.ErrCodeChallengeRquired - } - if cc != "" && (len(cc) < 43 || len(cc) > 128) { - return nil, errors.ErrInvalidCodeChallengeLen - } - - ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method")) - // set default - if ccm == "" { - ccm = oauth2.CodeChallengePlain - } - if ccm.String() != "" && !s.CheckCodeChallengeMethod(ccm) { - return nil, errors.ErrUnsupportedCodeChallengeMethod - } - - req := &AuthorizeRequest{ - RedirectURI: redirectURI, - ResponseType: resType, - ClientID: clientID, - State: r.FormValue("state"), - Scope: r.FormValue("scope"), - Request: r, - CodeChallenge: cc, - CodeChallengeMethod: ccm, - } - return req, nil -} - -// GetAuthorizeToken get authorization token(code) -func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { - // check the client allows the grant type - if fn := s.ClientAuthorizedHandler; fn != nil { - gt := oauth2.AuthorizationCode - if req.ResponseType == oauth2.Token { - gt = oauth2.Implicit - } - - allowed, err := fn(req.ClientID, gt) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrUnauthorizedClient - } - } - - // check the client allows the authorized scope - if fn := s.ClientScopeHandler; fn != nil { - tgr := &oauth2.TokenGenerateRequest{ - ClientID: req.ClientID, - UserID: req.UserID, - RedirectURI: req.RedirectURI, - Scope: req.Scope, - AccessTokenExp: req.AccessTokenExp, - Request: req.Request, - } - allowed, err := fn(tgr) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - - tgr := &oauth2.TokenGenerateRequest{ - ClientID: req.ClientID, - UserID: req.UserID, - RedirectURI: req.RedirectURI, - Scope: req.Scope, - AccessTokenExp: req.AccessTokenExp, - Request: req.Request, - CodeChallenge: req.CodeChallenge, - CodeChallengeMethod: req.CodeChallengeMethod, - } - return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) -} - -// GetAuthorizeData get authorization response data -func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { - if rt == oauth2.Code { - return map[string]interface{}{ - "code": ti.GetCode(), - } - } - return s.GetTokenData(ti) -} - -// HandleAuthorizeRequest the authorization request handling -func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - - req, err := s.ValidationAuthorizeRequest(r) - if err != nil { - return s.redirectError(w, req, err) - } - - // user authorization - userID, err := s.UserAuthorizationHandler(w, r) - if err != nil { - return s.redirectError(w, req, err) - } else if userID == "" { - return nil - } - req.UserID = userID - - // specify the scope of authorization - if fn := s.AuthorizeScopeHandler; fn != nil { - scope, err := fn(w, r) - if err != nil { - return err - } else if scope != "" { - req.Scope = scope - } - } - - // specify the expiration time of access token - if fn := s.AccessTokenExpHandler; fn != nil { - exp, err := fn(w, r) - if err != nil { - return err - } - req.AccessTokenExp = exp - } - - ti, err := s.GetAuthorizeToken(ctx, req) - if err != nil { - return s.redirectError(w, req, err) - } - - // If the redirect URI is empty, the default domain provided by the client is used. - if req.RedirectURI == "" { - client, err := s.Manager.GetClient(ctx, req.ClientID) - if err != nil { - return err - } - req.RedirectURI = client.GetDomain() - } - - return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) -} - -// ValidationTokenRequest the token request validation -func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { - if v := r.Method; !(v == "POST" || - (s.Config.AllowGetAccessRequest && v == "GET")) { - return "", nil, errors.ErrInvalidRequest - } - - gt := oauth2.GrantType(r.FormValue("grant_type")) - if gt.String() == "" { - return "", nil, errors.ErrUnsupportedGrantType - } - - clientID, clientSecret, err := s.ClientInfoHandler(r) - if err != nil { - return "", nil, err - } - - tgr := &oauth2.TokenGenerateRequest{ - ClientID: clientID, - ClientSecret: clientSecret, - Request: r, - } - - switch gt { - case oauth2.AuthorizationCode: - tgr.RedirectURI = r.FormValue("redirect_uri") - tgr.Code = r.FormValue("code") - if tgr.RedirectURI == "" || - tgr.Code == "" { - return "", nil, errors.ErrInvalidRequest - } - tgr.CodeVerifier = r.FormValue("code_verifier") - if s.Config.ForcePKCE && tgr.CodeVerifier == "" { - return "", nil, errors.ErrInvalidRequest - } - case oauth2.PasswordCredentials: - tgr.Scope = r.FormValue("scope") - username, password := r.FormValue("username"), r.FormValue("password") - if username == "" || password == "" { - return "", nil, errors.ErrInvalidRequest - } - - userID, err := s.PasswordAuthorizationHandler(username, password) - if err != nil { - return "", nil, err - } else if userID == "" { - return "", nil, errors.ErrInvalidGrant - } - tgr.UserID = userID - case oauth2.ClientCredentials: - tgr.Scope = r.FormValue("scope") - case oauth2.Refreshing: - tgr.Refresh = r.FormValue("refresh_token") - tgr.Scope = r.FormValue("scope") - if tgr.Refresh == "" { - return "", nil, errors.ErrInvalidRequest - } - } - return gt, tgr, nil -} - -// CheckGrantType check allows grant type -func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { - for _, agt := range s.Config.AllowedGrantTypes { - if agt == gt { - return true - } - } - return false -} - -// GetAccessToken access token -func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { - if allowed := s.CheckGrantType(gt); !allowed { - return nil, errors.ErrUnauthorizedClient - } - - if fn := s.ClientAuthorizedHandler; fn != nil { - allowed, err := fn(tgr.ClientID, gt) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrUnauthorizedClient - } - } - - switch gt { - case oauth2.AuthorizationCode: - ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) - if err != nil { - switch err { - case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge, errors.ErrMissingCodeChallenge: - return nil, errors.ErrInvalidGrant - case errors.ErrInvalidClient: - return nil, errors.ErrInvalidClient - default: - return nil, err - } - } - return ti, nil - case oauth2.PasswordCredentials, oauth2.ClientCredentials: - if fn := s.ClientScopeHandler; fn != nil { - allowed, err := fn(tgr) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - return s.Manager.GenerateAccessToken(ctx, gt, tgr) - case oauth2.Refreshing: - // check scope - if scope, scopeFn := tgr.Scope, s.RefreshingScopeHandler; scope != "" && scopeFn != nil { - rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) - if err != nil { - if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { - return nil, errors.ErrInvalidGrant - } - return nil, err - } - - allowed, err := scopeFn(tgr, rti.GetScope()) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - - if validationFn := s.RefreshingValidationHandler; validationFn != nil { - rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) - if err != nil { - if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { - return nil, errors.ErrInvalidGrant - } - return nil, err - } - allowed, err := validationFn(rti) - if err != nil { - return nil, err - } else if !allowed { - return nil, errors.ErrInvalidScope - } - } - - ti, err := s.Manager.RefreshAccessToken(ctx, tgr) - if err != nil { - if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { - return nil, errors.ErrInvalidGrant - } - return nil, err - } - return ti, nil - } - - return nil, errors.ErrUnsupportedGrantType -} - -// GetTokenData token data -func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { - data := map[string]interface{}{ - "access_token": ti.GetAccess(), - "token_type": s.Config.TokenType, - "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), - } - - if scope := ti.GetScope(); scope != "" { - data["scope"] = scope - } - - if refresh := ti.GetRefresh(); refresh != "" { - data["refresh_token"] = refresh - } - - if fn := s.ExtensionFieldsHandler; fn != nil { - ext := fn(ti) - for k, v := range ext { - if _, ok := data[k]; ok { - continue - } - data[k] = v - } - } - return data -} - -// HandleTokenRequest token request handling -func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { - ctx := r.Context() - - gt, tgr, err := s.ValidationTokenRequest(r) - if err != nil { - return s.tokenError(w, err) - } - - ti, err := s.GetAccessToken(ctx, gt, tgr) - if err != nil { - return s.tokenError(w, err) - } - - return s.token(w, s.GetTokenData(ti), nil) -} - -// GetErrorData get error response data -func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { - var re errors.Response - if v, ok := errors.Descriptions[err]; ok { - re.Error = err - re.Description = v - re.StatusCode = errors.StatusCodes[err] - } else { - if fn := s.InternalErrorHandler; fn != nil { - if v := fn(err); v != nil { - re = *v - } - } - - if re.Error == nil { - re.Error = errors.ErrServerError - re.Description = errors.Descriptions[errors.ErrServerError] - re.StatusCode = errors.StatusCodes[errors.ErrServerError] - } - } - - if fn := s.ResponseErrorHandler; fn != nil { - fn(&re) - } - - data := make(map[string]interface{}) - if err := re.Error; err != nil { - data["error"] = err.Error() - } - - if v := re.ErrorCode; v != 0 { - data["error_code"] = v - } - - if v := re.Description; v != "" { - data["error_description"] = v - } - - if v := re.URI; v != "" { - data["error_uri"] = v - } - - statusCode := http.StatusInternalServerError - if v := re.StatusCode; v > 0 { - statusCode = v - } - - return data, statusCode, re.Header -} - -// BearerAuth parse bearer token -func (s *Server) BearerAuth(r *http.Request) (string, bool) { - auth := r.Header.Get("Authorization") - prefix := "Bearer " - token := "" - - if auth != "" && strings.HasPrefix(auth, prefix) { - token = auth[len(prefix):] - } else { - token = r.FormValue("access_token") - } - - return token, token != "" -} - -// ValidationBearerToken validation the bearer tokens -// https://tools.ietf.org/html/rfc6750 -func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { - ctx := r.Context() - - accessToken, ok := s.BearerAuth(r) - if !ok { - return nil, errors.ErrInvalidAccessToken - } - - return s.Manager.LoadAccessToken(ctx, accessToken) -} +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/errors" +) + +// NewDefaultServer create a default authorization server +func NewDefaultServer(manager oauth2.Manager) *Server { + return NewServer(NewConfig(), manager) +} + +// NewServer create authorization server +func NewServer(cfg *Config, manager oauth2.Manager) *Server { + srv := &Server{ + Config: cfg, + Manager: manager, + } + + // default handler + srv.ClientInfoHandler = ClientBasicHandler + + srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) { + return "", errors.ErrAccessDenied + } + + srv.PasswordAuthorizationHandler = func(username, password string) (string, error) { + return "", errors.ErrAccessDenied + } + return srv +} + +// Server Provide authorization server +type Server struct { + Config *Config + Manager oauth2.Manager + ClientInfoHandler ClientInfoHandler + ClientAuthorizedHandler ClientAuthorizedHandler + ClientScopeHandler ClientScopeHandler + UserAuthorizationHandler UserAuthorizationHandler + PasswordAuthorizationHandler PasswordAuthorizationHandler + RefreshingValidationHandler RefreshingValidationHandler + RefreshingScopeHandler RefreshingScopeHandler + ResponseErrorHandler ResponseErrorHandler + InternalErrorHandler InternalErrorHandler + ExtensionFieldsHandler ExtensionFieldsHandler + AccessTokenExpHandler AccessTokenExpHandler + AuthorizeScopeHandler AuthorizeScopeHandler +} + +func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error { + if req == nil { + return err + } + data, _, _ := s.GetErrorData(err) + return s.redirect(w, req, data) +} + +func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error { + uri, err := s.GetRedirectURI(req, data) + if err != nil { + return err + } + + w.Header().Set("Location", uri) + w.WriteHeader(302) + return nil +} + +func (s *Server) tokenError(w http.ResponseWriter, err error) error { + data, statusCode, header := s.GetErrorData(err) + return s.token(w, data, header, statusCode) +} + +func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error { + w.Header().Set("Content-Type", "application/json;charset=UTF-8") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + + for key := range header { + w.Header().Set(key, header.Get(key)) + } + + status := http.StatusOK + if len(statusCode) > 0 && statusCode[0] > 0 { + status = statusCode[0] + } + + w.WriteHeader(status) + return json.NewEncoder(w).Encode(data) +} + +// GetRedirectURI get redirect uri +func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) { + u, err := url.Parse(req.RedirectURI) + if err != nil { + return "", err + } + + q := u.Query() + if req.State != "" { + q.Set("state", req.State) + } + + for k, v := range data { + q.Set(k, fmt.Sprint(v)) + } + + switch req.ResponseType { + case oauth2.Code: + u.RawQuery = q.Encode() + case oauth2.Token: + u.RawQuery = "" + fragment, err := url.QueryUnescape(q.Encode()) + if err != nil { + return "", err + } + u.Fragment = fragment + } + + return u.String(), nil +} + +// CheckResponseType check allows response type +func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { + for _, art := range s.Config.AllowedResponseTypes { + if art == rt { + return true + } + } + return false +} + +// CheckCodeChallengeMethod checks for allowed code challenge method +func (s *Server) CheckCodeChallengeMethod(ccm oauth2.CodeChallengeMethod) bool { + for _, c := range s.Config.AllowedCodeChallengeMethods { + if c == ccm { + return true + } + } + return false +} + +// ValidationAuthorizeRequest the authorization request validation +func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { + redirectURI := r.FormValue("redirect_uri") + clientID := r.FormValue("client_id") + if !(r.Method == "GET" || r.Method == "POST") || + clientID == "" { + return nil, errors.ErrInvalidRequest + } + + resType := oauth2.ResponseType(r.FormValue("response_type")) + if resType.String() == "" { + return nil, errors.ErrUnsupportedResponseType + } else if allowed := s.CheckResponseType(resType); !allowed { + return nil, errors.ErrUnauthorizedClient + } + + cc := r.FormValue("code_challenge") + if cc == "" && s.Config.ForcePKCE { + return nil, errors.ErrCodeChallengeRquired + } + if cc != "" && (len(cc) < 43 || len(cc) > 128) { + return nil, errors.ErrInvalidCodeChallengeLen + } + + ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method")) + // set default + if ccm == "" { + ccm = oauth2.CodeChallengePlain + } + if ccm.String() != "" && !s.CheckCodeChallengeMethod(ccm) { + return nil, errors.ErrUnsupportedCodeChallengeMethod + } + + req := &AuthorizeRequest{ + RedirectURI: redirectURI, + ResponseType: resType, + ClientID: clientID, + State: r.FormValue("state"), + Scope: r.FormValue("scope"), + Request: r, + CodeChallenge: cc, + CodeChallengeMethod: ccm, + } + return req, nil +} + +// GetAuthorizeToken get authorization token(code) +func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) { + // check the client allows the grant type + if fn := s.ClientAuthorizedHandler; fn != nil { + gt := oauth2.AuthorizationCode + if req.ResponseType == oauth2.Token { + gt = oauth2.Implicit + } + + allowed, err := fn(req.ClientID, gt) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrUnauthorizedClient + } + } + + tgr := &oauth2.TokenGenerateRequest{ + ClientID: req.ClientID, + UserID: req.UserID, + RedirectURI: req.RedirectURI, + Scope: req.Scope, + AccessTokenExp: req.AccessTokenExp, + Request: req.Request, + } + + // check the client allows the authorized scope + if fn := s.ClientScopeHandler; fn != nil { + tgr := &oauth2.TokenGenerateRequest{ + ClientID: req.ClientID, + UserID: req.UserID, + RedirectURI: req.RedirectURI, + Scope: req.Scope, + AccessTokenExp: req.AccessTokenExp, + Request: req.Request, + } + + allowed, err := fn(tgr) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrInvalidScope + } + } + + tgr := &oauth2.TokenGenerateRequest{ + ClientID: req.ClientID, + UserID: req.UserID, + RedirectURI: req.RedirectURI, + Scope: req.Scope, + AccessTokenExp: req.AccessTokenExp, + Request: req.Request, + CodeChallenge: req.CodeChallenge, + CodeChallengeMethod: req.CodeChallengeMethod, + } + + return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr) +} + +// GetAuthorizeData get authorization response data +func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} { + if rt == oauth2.Code { + return map[string]interface{}{ + "code": ti.GetCode(), + } + } + return s.GetTokenData(ti) +} + +// HandleAuthorizeRequest the authorization request handling +func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + req, err := s.ValidationAuthorizeRequest(r) + if err != nil { + return s.redirectError(w, req, err) + } + + // user authorization + userID, err := s.UserAuthorizationHandler(w, r) + if err != nil { + return s.redirectError(w, req, err) + } else if userID == "" { + return nil + } + req.UserID = userID + + // specify the scope of authorization + if fn := s.AuthorizeScopeHandler; fn != nil { + scope, err := fn(w, r) + if err != nil { + return err + } else if scope != "" { + req.Scope = scope + } + } + + // specify the expiration time of access token + if fn := s.AccessTokenExpHandler; fn != nil { + exp, err := fn(w, r) + if err != nil { + return err + } + req.AccessTokenExp = exp + } + + ti, err := s.GetAuthorizeToken(ctx, req) + if err != nil { + return s.redirectError(w, req, err) + } + + // If the redirect URI is empty, the default domain provided by the client is used. + if req.RedirectURI == "" { + client, err := s.Manager.GetClient(ctx, req.ClientID) + if err != nil { + return err + } + req.RedirectURI = client.GetDomain() + } + + return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) +} + +// ValidationTokenRequest the token request validation +func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) { + if v := r.Method; !(v == "POST" || + (s.Config.AllowGetAccessRequest && v == "GET")) { + return "", nil, errors.ErrInvalidRequest + } + + gt := oauth2.GrantType(r.FormValue("grant_type")) + if gt.String() == "" { + return "", nil, errors.ErrUnsupportedGrantType + } + + clientID, clientSecret, err := s.ClientInfoHandler(r) + if err != nil { + return "", nil, err + } + + tgr := &oauth2.TokenGenerateRequest{ + ClientID: clientID, + ClientSecret: clientSecret, + Request: r, + } + + switch gt { + case oauth2.AuthorizationCode: + tgr.RedirectURI = r.FormValue("redirect_uri") + tgr.Code = r.FormValue("code") + if tgr.RedirectURI == "" || + tgr.Code == "" { + return "", nil, errors.ErrInvalidRequest + } + tgr.CodeVerifier = r.FormValue("code_verifier") + if s.Config.ForcePKCE && tgr.CodeVerifier == "" { + return "", nil, errors.ErrInvalidRequest + } + case oauth2.PasswordCredentials: + tgr.Scope = r.FormValue("scope") + username, password := r.FormValue("username"), r.FormValue("password") + if username == "" || password == "" { + return "", nil, errors.ErrInvalidRequest + } + + userID, err := s.PasswordAuthorizationHandler(username, password) + if err != nil { + return "", nil, err + } else if userID == "" { + return "", nil, errors.ErrInvalidGrant + } + tgr.UserID = userID + case oauth2.ClientCredentials: + tgr.Scope = r.FormValue("scope") + case oauth2.Refreshing: + tgr.Refresh = r.FormValue("refresh_token") + tgr.Scope = r.FormValue("scope") + if tgr.Refresh == "" { + return "", nil, errors.ErrInvalidRequest + } + } + return gt, tgr, nil +} + +// CheckGrantType check allows grant type +func (s *Server) CheckGrantType(gt oauth2.GrantType) bool { + for _, agt := range s.Config.AllowedGrantTypes { + if agt == gt { + return true + } + } + return false +} + +// GetAccessToken access token +func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) { + if allowed := s.CheckGrantType(gt); !allowed { + return nil, errors.ErrUnauthorizedClient + } + + if fn := s.ClientAuthorizedHandler; fn != nil { + allowed, err := fn(tgr.ClientID, gt) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrUnauthorizedClient + } + } + + switch gt { + case oauth2.AuthorizationCode: + ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr) + if err != nil { + switch err { + case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge, errors.ErrMissingCodeChallenge: + return nil, errors.ErrInvalidGrant + case errors.ErrInvalidClient: + return nil, errors.ErrInvalidClient + default: + return nil, err + } + } + return ti, nil + case oauth2.PasswordCredentials, oauth2.ClientCredentials: + if fn := s.ClientScopeHandler; fn != nil { + allowed, err := fn(tgr) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrInvalidScope + } + } + return s.Manager.GenerateAccessToken(ctx, gt, tgr) + case oauth2.Refreshing: + // check scope + if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil { + rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) + if err != nil { + if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { + return nil, errors.ErrInvalidGrant + } + return nil, err + } + + allowed, err := scopeFn(tgr, rti.GetScope()) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrInvalidScope + } + } + + if validationFn := s.RefreshingValidationHandler; validationFn != nil { + rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh) + if err != nil { + if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { + return nil, errors.ErrInvalidGrant + } + return nil, err + } + allowed, err := validationFn(rti) + if err != nil { + return nil, err + } else if !allowed { + return nil, errors.ErrInvalidScope + } + } + + ti, err := s.Manager.RefreshAccessToken(ctx, tgr) + if err != nil { + if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken { + return nil, errors.ErrInvalidGrant + } + return nil, err + } + return ti, nil + } + + return nil, errors.ErrUnsupportedGrantType +} + +// GetTokenData token data +func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} { + data := map[string]interface{}{ + "access_token": ti.GetAccess(), + "token_type": s.Config.TokenType, + "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), + } + + if scope := ti.GetScope(); scope != "" { + data["scope"] = scope + } + + if refresh := ti.GetRefresh(); refresh != "" { + data["refresh_token"] = refresh + } + + if fn := s.ExtensionFieldsHandler; fn != nil { + ext := fn(ti) + for k, v := range ext { + if _, ok := data[k]; ok { + continue + } + data[k] = v + } + } + return data +} + +// HandleTokenRequest token request handling +func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + gt, tgr, err := s.ValidationTokenRequest(r) + if err != nil { + return s.tokenError(w, err) + } + + ti, err := s.GetAccessToken(ctx, gt, tgr) + if err != nil { + return s.tokenError(w, err) + } + + return s.token(w, s.GetTokenData(ti), nil) +} + +// GetErrorData get error response data +func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) { + var re errors.Response + if v, ok := errors.Descriptions[err]; ok { + re.Error = err + re.Description = v + re.StatusCode = errors.StatusCodes[err] + } else { + if fn := s.InternalErrorHandler; fn != nil { + if v := fn(err); v != nil { + re = *v + } + } + + if re.Error == nil { + re.Error = errors.ErrServerError + re.Description = errors.Descriptions[errors.ErrServerError] + re.StatusCode = errors.StatusCodes[errors.ErrServerError] + } + } + + if fn := s.ResponseErrorHandler; fn != nil { + fn(&re) + } + + data := make(map[string]interface{}) + if err := re.Error; err != nil { + data["error"] = err.Error() + } + + if v := re.ErrorCode; v != 0 { + data["error_code"] = v + } + + if v := re.Description; v != "" { + data["error_description"] = v + } + + if v := re.URI; v != "" { + data["error_uri"] = v + } + + statusCode := http.StatusInternalServerError + if v := re.StatusCode; v > 0 { + statusCode = v + } + + return data, statusCode, re.Header +} + +// BearerAuth parse bearer token +func (s *Server) BearerAuth(r *http.Request) (string, bool) { + auth := r.Header.Get("Authorization") + prefix := "Bearer " + token := "" + + if auth != "" && strings.HasPrefix(auth, prefix) { + token = auth[len(prefix):] + } else { + token = r.FormValue("access_token") + } + + return token, token != "" +} + +// ValidationBearerToken validation the bearer tokens +// https://tools.ietf.org/html/rfc6750 +func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) { + ctx := r.Context() + + accessToken, ok := s.BearerAuth(r) + if !ok { + return nil, errors.ErrInvalidAccessToken + } + + return s.Manager.LoadAccessToken(ctx, accessToken) +}