Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ Copyright (c) 2016 Lyric
[License-Image]: https://img.shields.io/npm/l/express.svg
[Build-Status-Url]: https://travis-ci.org/go-oauth2/oauth2
[Build-Status-Image]: https://travis-ci.org/go-oauth2/oauth2.svg?branch=master
[Release-Url]: https://github.com/go-oauth2/oauth2/releases/tag/v3.6.2
[Release-image]: http://img.shields.io/badge/release-v3.6.2-1eb0fc.svg
[Release-Url]: https://github.com/go-oauth2/oauth2/releases/tag/v3.6.3
[Release-image]: http://img.shields.io/badge/release-v3.6.3-1eb0fc.svg
[ReportCard-Url]: https://goreportcard.com/report/gopkg.in/oauth2.v3
[ReportCard-Image]: https://goreportcard.com/badge/gopkg.in/oauth2.v3
[GoDoc-Url]: https://godoc.org/gopkg.in/oauth2.v3
Expand Down
54 changes: 52 additions & 2 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,20 @@ func NewServer(cfg *Config, manager oauth2.Manager) *Server {
if err := manager.CheckInterface(); err != nil {
panic(err)
}

srv := &Server{
Config: cfg,
Manager: manager,
}

// default handler
srv.ClientInfoHandler = ClientBasicHandler

srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (userID string, err error) {
err = errors.ErrAccessDenied
return
}

srv.PasswordAuthorizationHandler = func(username, password string) (userID string, err error) {
err = errors.ErrAccessDenied
return
Expand Down Expand Up @@ -86,10 +90,12 @@ func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, statu
w.Header().Set("Content-Type", "application/json;charset=UTF-8")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")

status := http.StatusOK
if len(statusCode) > 0 && statusCode[0] > 0 {
status = statusCode[0]
}

w.WriteHeader(status)
err = json.NewEncoder(w).Encode(data)
return
Expand All @@ -101,13 +107,16 @@ func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface
if err != nil {
return
}

q := u.Query()
if req.State != "" {
q.Set("state", req.State)
}

for k, v := range data {
q.Set(k, fmt.Sprint(v))
}

switch req.ResponseType {
case oauth2.Code:
u.RawQuery = q.Encode()
Expand All @@ -118,6 +127,7 @@ func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface
return
}
}

uri = u.String()
return
}
Expand All @@ -138,6 +148,7 @@ func (s *Server) ValidationAuthorizeRequest(r *http.Request) (req *AuthorizeRequ
if err != nil {
return
}

clientID := r.FormValue("client_id")
if r.Method != "GET" ||
clientID == "" ||
Expand All @@ -147,6 +158,7 @@ func (s *Server) ValidationAuthorizeRequest(r *http.Request) (req *AuthorizeRequ
}

resType := oauth2.ResponseType(r.FormValue("response_type"))

if resType.String() == "" {
err = errors.ErrUnsupportedResponseType
return
Expand All @@ -170,9 +182,11 @@ func (s *Server) GetAuthorizeToken(req *AuthorizeRequest) (ti oauth2.TokenInfo,
// check the client allows the grant type
if fn := s.ClientAuthorizedHandler; fn != nil {
gt := oauth2.AuthorizationCode

if req.ResponseType == oauth2.Token {
gt = oauth2.Implicit
}

allowed, verr := fn(req.ClientID, gt)
if verr != nil {
err = verr
Expand All @@ -185,6 +199,7 @@ func (s *Server) GetAuthorizeToken(req *AuthorizeRequest) (ti oauth2.TokenInfo,

// check the client allows the authorized scope
if fn := s.ClientScopeHandler; fn != nil {

allowed, verr := fn(req.ClientID, req.Scope)
if verr != nil {
err = verr
Expand All @@ -202,6 +217,7 @@ func (s *Server) GetAuthorizeToken(req *AuthorizeRequest) (ti oauth2.TokenInfo,
Scope: req.Scope,
AccessTokenExp: req.AccessTokenExp,
}

ti, err = s.Manager.GenerateAuthToken(req.ResponseType, tgr)
return
}
Expand All @@ -228,16 +244,19 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request)

// user authorization
userID, verr := s.UserAuthorizationHandler(w, r)

if verr != nil {
err = s.redirectError(w, req, verr)
return
} else if userID == "" {
return
}

req.UserID = userID

// specify the scope of authorization
if fn := s.AuthorizeScopeHandler; fn != nil {

scope, verr := fn(w, r)
if verr != nil {
err = verr
Expand All @@ -249,6 +268,7 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request)

// specify the expiration time of access token
if fn := s.AccessTokenExpHandler; fn != nil {

exp, verr := fn(w, r)
if verr != nil {
err = verr
Expand All @@ -262,6 +282,7 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request)
err = s.redirectError(w, req, verr)
return
}

err = s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti))
return
}
Expand All @@ -273,44 +294,59 @@ func (s *Server) ValidationTokenRequest(r *http.Request) (gt oauth2.GrantType, t
err = errors.ErrInvalidRequest
return
}

gt = oauth2.GrantType(r.FormValue("grant_type"))

if gt.String() == "" {
err = errors.ErrUnsupportedGrantType
return
}

clientID, clientSecret, err := s.ClientInfoHandler(r)
if err != nil {
return
}

tgr = &oauth2.TokenGenerateRequest{
ClientID: clientID,
ClientSecret: clientSecret,
}

switch gt {
case oauth2.AuthorizationCode:
tgr.RedirectURI = r.FormValue("redirect_uri")
tgr.Code = r.FormValue("code")

if tgr.RedirectURI == "" ||
tgr.Code == "" {
err = errors.ErrInvalidRequest
return
}
case oauth2.PasswordCredentials:
tgr.Scope = r.FormValue("scope")
userID, verr := s.PasswordAuthorizationHandler(r.FormValue("username"), r.FormValue("password"))
username, password := r.FormValue("username"), r.FormValue("password")

if username == "" || password == "" {
err = errors.ErrInvalidRequest
return
}

userID, verr := s.PasswordAuthorizationHandler(username, password)
if verr != nil {
err = verr
return
} else if userID == "" {
err = errors.ErrInvalidRequest
err = errors.ErrInvalidGrant
return
}

tgr.UserID = userID
case oauth2.ClientCredentials:
tgr.Scope = r.FormValue("scope")
case oauth2.Refreshing:
tgr.Refresh = r.FormValue("refresh_token")
tgr.Scope = r.FormValue("scope")

if tgr.Refresh == "" {
err = errors.ErrInvalidRequest
}
Expand Down Expand Up @@ -350,6 +386,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe
case oauth2.AuthorizationCode:
ati, verr := s.Manager.GenerateAccessToken(gt, tgr)
if verr != nil {

if verr == errors.ErrInvalidAuthorizeCode {
err = errors.ErrInvalidGrant
} else if verr == errors.ErrInvalidClient {
Expand All @@ -362,6 +399,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe
ti = ati
case oauth2.PasswordCredentials, oauth2.ClientCredentials:
if fn := s.ClientScopeHandler; fn != nil {

allowed, verr := fn(tgr.ClientID, tgr.Scope)
if verr != nil {
err = verr
Expand All @@ -375,6 +413,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe
case oauth2.Refreshing:
// check scope
if scope, scopeFn := tgr.Scope, s.RefreshingScopeHandler; scope != "" && scopeFn != nil {

rti, verr := s.Manager.LoadRefreshToken(tgr.Refresh)
if verr != nil {
if verr == errors.ErrInvalidRefreshToken || verr == errors.ErrExpiredRefreshToken {
Expand All @@ -394,6 +433,7 @@ func (s *Server) GetAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe
return
}
}

rti, verr := s.Manager.RefreshAccessToken(tgr)
if verr != nil {
if verr == errors.ErrInvalidRefreshToken || verr == errors.ErrExpiredRefreshToken {
Expand All @@ -416,12 +456,15 @@ func (s *Server) GetTokenData(ti oauth2.TokenInfo) (data map[string]interface{})
"token_type": s.Config.TokenType,
"expires_in": int64(ti.GetAccessExpiresIn() / time.Second),
}

if scope := ti.GetScope(); scope != "" {
data["scope"] = scope
}

if refresh := ti.GetRefresh(); refresh != "" {
data["refresh_token"] = refresh
}

if fn := s.ExtensionFieldsHandler; fn != nil {
ext := fn(ti)
for k, v := range ext {
Expand All @@ -441,6 +484,7 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) (err
err = s.tokenError(w, verr)
return
}

ti, verr := s.GetAccessToken(gt, tgr)
if verr != nil {
err = s.tokenError(w, verr)
Expand Down Expand Up @@ -479,16 +523,21 @@ func (s *Server) GetErrorData(err error) (data map[string]interface{}, statusCod
data = map[string]interface{}{
"error": re.Error.Error(),
}

if v := re.ErrorCode; v != 0 {
data["error_code"] = v
}

if v := re.Description; v != "" {
data["error_description"] = v
}

if v := re.URI; v != "" {
data["error_uri"] = v
}

statusCode = 400

if v := re.StatusCode; v > 0 {
statusCode = v
}
Expand Down Expand Up @@ -521,6 +570,7 @@ func (s *Server) ValidationBearerToken(r *http.Request) (ti oauth2.TokenInfo, er
err = errors.ErrInvalidAccessToken
return
}

ti, err = s.Manager.LoadAccessToken(accessToken)

return
Expand Down