diff --git a/README.md b/README.md index d24ed10..ce43714 100644 --- a/README.md +++ b/README.md @@ -139,8 +139,8 @@ Copyright (c) 2016 Lyric [License-Image]: https://img.shields.io/npm/l/express.svg [Build-Status-Url]: https://travis-ci.org/go-oauth2/oauth2 [Build-Status-Image]: https://travis-ci.org/go-oauth2/oauth2.svg?branch=master -[Release-Url]: https://github.com/go-oauth2/oauth2/releases/tag/v3.6.2 -[Release-image]: http://img.shields.io/badge/release-v3.6.2-1eb0fc.svg +[Release-Url]: https://github.com/go-oauth2/oauth2/releases/tag/v3.6.3 +[Release-image]: http://img.shields.io/badge/release-v3.6.3-1eb0fc.svg [ReportCard-Url]: https://goreportcard.com/report/gopkg.in/oauth2.v3 [ReportCard-Image]: https://goreportcard.com/badge/gopkg.in/oauth2.v3 [GoDoc-Url]: https://godoc.org/gopkg.in/oauth2.v3 diff --git a/server/server.go b/server/server.go index 1ebebad..9d30834 100644 --- a/server/server.go +++ b/server/server.go @@ -22,16 +22,20 @@ func NewServer(cfg *Config, manager oauth2.Manager) *Server { if err := manager.CheckInterface(); err != nil { panic(err) } + srv := &Server{ Config: cfg, Manager: manager, } + // default handler srv.ClientInfoHandler = ClientBasicHandler + srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (userID string, err error) { err = errors.ErrAccessDenied return } + srv.PasswordAuthorizationHandler = func(username, password string) (userID string, err error) { err = errors.ErrAccessDenied return @@ -86,10 +90,12 @@ func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, statu w.Header().Set("Content-Type", "application/json;charset=UTF-8") w.Header().Set("Cache-Control", "no-store") w.Header().Set("Pragma", "no-cache") + status := http.StatusOK if len(statusCode) > 0 && statusCode[0] > 0 { status = statusCode[0] } + w.WriteHeader(status) err = json.NewEncoder(w).Encode(data) return @@ -101,13 +107,16 @@ func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface if err != nil { return } + q := u.Query() if req.State != "" { q.Set("state", req.State) } + for k, v := range data { q.Set(k, fmt.Sprint(v)) } + switch req.ResponseType { case oauth2.Code: u.RawQuery = q.Encode() @@ -118,6 +127,7 @@ func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface return } } + uri = u.String() return } @@ -138,6 +148,7 @@ func (s *Server) ValidationAuthorizeRequest(r *http.Request) (req *AuthorizeRequ if err != nil { return } + clientID := r.FormValue("client_id") if r.Method != "GET" || clientID == "" || @@ -147,6 +158,7 @@ func (s *Server) ValidationAuthorizeRequest(r *http.Request) (req *AuthorizeRequ } resType := oauth2.ResponseType(r.FormValue("response_type")) + if resType.String() == "" { err = errors.ErrUnsupportedResponseType return @@ -170,9 +182,11 @@ func (s *Server) GetAuthorizeToken(req *AuthorizeRequest) (ti oauth2.TokenInfo, // check the client allows the grant type if fn := s.ClientAuthorizedHandler; fn != nil { gt := oauth2.AuthorizationCode + if req.ResponseType == oauth2.Token { gt = oauth2.Implicit } + allowed, verr := fn(req.ClientID, gt) if verr != nil { err = verr @@ -185,6 +199,7 @@ func (s *Server) GetAuthorizeToken(req *AuthorizeRequest) (ti oauth2.TokenInfo, // check the client allows the authorized scope if fn := s.ClientScopeHandler; fn != nil { + allowed, verr := fn(req.ClientID, req.Scope) if verr != nil { err = verr @@ -202,6 +217,7 @@ func (s *Server) GetAuthorizeToken(req *AuthorizeRequest) (ti oauth2.TokenInfo, Scope: req.Scope, AccessTokenExp: req.AccessTokenExp, } + ti, err = s.Manager.GenerateAuthToken(req.ResponseType, tgr) return } @@ -228,16 +244,19 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) // user authorization userID, verr := s.UserAuthorizationHandler(w, r) + if verr != nil { err = s.redirectError(w, req, verr) return } else if userID == "" { return } + req.UserID = userID // specify the scope of authorization if fn := s.AuthorizeScopeHandler; fn != nil { + scope, verr := fn(w, r) if verr != nil { err = verr @@ -249,6 +268,7 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) // specify the expiration time of access token if fn := s.AccessTokenExpHandler; fn != nil { + exp, verr := fn(w, r) if verr != nil { err = verr @@ -262,6 +282,7 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) err = s.redirectError(w, req, verr) return } + err = s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) return } @@ -273,23 +294,29 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (gt oauth2.GrantType, t err = errors.ErrInvalidRequest return } + gt = oauth2.GrantType(r.FormValue("grant_type")) + if gt.String() == "" { err = errors.ErrUnsupportedGrantType return } + clientID, clientSecret, err := s.ClientInfoHandler(r) if err != nil { return } + tgr = &oauth2.TokenGenerateRequest{ ClientID: clientID, ClientSecret: clientSecret, } + switch gt { case oauth2.AuthorizationCode: tgr.RedirectURI = r.FormValue("redirect_uri") tgr.Code = r.FormValue("code") + if tgr.RedirectURI == "" || tgr.Code == "" { err = errors.ErrInvalidRequest @@ -297,20 +324,29 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (gt oauth2.GrantType, t } case oauth2.PasswordCredentials: tgr.Scope = r.FormValue("scope") - userID, verr := s.PasswordAuthorizationHandler(r.FormValue("username"), r.FormValue("password")) + username, password := r.FormValue("username"), r.FormValue("password") + + if username == "" || password == "" { + err = errors.ErrInvalidRequest + return + } + + userID, verr := s.PasswordAuthorizationHandler(username, password) if verr != nil { err = verr return } else if userID == "" { - err = errors.ErrInvalidRequest + err = errors.ErrInvalidGrant return } + tgr.UserID = userID case oauth2.ClientCredentials: tgr.Scope = r.FormValue("scope") case oauth2.Refreshing: tgr.Refresh = r.FormValue("refresh_token") tgr.Scope = r.FormValue("scope") + if tgr.Refresh == "" { err = errors.ErrInvalidRequest } @@ -350,6 +386,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe case oauth2.AuthorizationCode: ati, verr := s.Manager.GenerateAccessToken(gt, tgr) if verr != nil { + if verr == errors.ErrInvalidAuthorizeCode { err = errors.ErrInvalidGrant } else if verr == errors.ErrInvalidClient { @@ -362,6 +399,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe ti = ati case oauth2.PasswordCredentials, oauth2.ClientCredentials: if fn := s.ClientScopeHandler; fn != nil { + allowed, verr := fn(tgr.ClientID, tgr.Scope) if verr != nil { err = verr @@ -375,6 +413,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe case oauth2.Refreshing: // check scope if scope, scopeFn := tgr.Scope, s.RefreshingScopeHandler; scope != "" && scopeFn != nil { + rti, verr := s.Manager.LoadRefreshToken(tgr.Refresh) if verr != nil { if verr == errors.ErrInvalidRefreshToken || verr == errors.ErrExpiredRefreshToken { @@ -394,6 +433,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe return } } + rti, verr := s.Manager.RefreshAccessToken(tgr) if verr != nil { if verr == errors.ErrInvalidRefreshToken || verr == errors.ErrExpiredRefreshToken { @@ -416,12 +456,15 @@ func (s *Server) GetTokenData(ti oauth2.TokenInfo) (data map[string]interface{}) "token_type": s.Config.TokenType, "expires_in": int64(ti.GetAccessExpiresIn() / time.Second), } + if scope := ti.GetScope(); scope != "" { data["scope"] = scope } + if refresh := ti.GetRefresh(); refresh != "" { data["refresh_token"] = refresh } + if fn := s.ExtensionFieldsHandler; fn != nil { ext := fn(ti) for k, v := range ext { @@ -441,6 +484,7 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) (err err = s.tokenError(w, verr) return } + ti, verr := s.GetAccessToken(gt, tgr) if verr != nil { err = s.tokenError(w, verr) @@ -479,16 +523,21 @@ func (s *Server) GetErrorData(err error) (data map[string]interface{}, statusCod data = map[string]interface{}{ "error": re.Error.Error(), } + if v := re.ErrorCode; v != 0 { data["error_code"] = v } + if v := re.Description; v != "" { data["error_description"] = v } + if v := re.URI; v != "" { data["error_uri"] = v } + statusCode = 400 + if v := re.StatusCode; v > 0 { statusCode = v } @@ -521,6 +570,7 @@ func (s *Server) ValidationBearerToken(r *http.Request) (ti oauth2.TokenInfo, er err = errors.ErrInvalidAccessToken return } + ti, err = s.Manager.LoadAccessToken(accessToken) return