From 837c19c7b0e726c76d0168df9c1753eac2e2e060 Mon Sep 17 00:00:00 2001 From: lyric Date: Sat, 23 Jul 2016 22:23:07 +0800 Subject: [PATCH 1/2] Modify token store --- store/token.go | 78 +++++++++++++++++++++++++++++++------------------- 1 file changed, 48 insertions(+), 30 deletions(-) diff --git a/store/token.go b/store/token.go index b774410..e6794b8 100644 --- a/store/token.go +++ b/store/token.go @@ -31,50 +31,65 @@ type MemoryTokenStore struct { gcInterval time.Duration globalID int64 lock sync.RWMutex - basicList *list.List data map[string]oauth2.TokenInfo access map[string]string refresh map[string]string + basicList *list.List + listLock sync.RWMutex } func (mts *MemoryTokenStore) gc() { time.AfterFunc(mts.gcInterval, func() { defer mts.gc() - mts.lock.RLock() + rmeles := make([]*list.Element, 0, 32) + mts.listLock.RLock() ele := mts.basicList.Front() - mts.lock.RUnlock() - if ele == nil { - return + mts.listLock.RUnlock() + for ele != nil { + if rm := mts.gcElement(ele); rm { + rmeles = append(rmeles, ele) + } + mts.listLock.RLock() + ele = ele.Next() + mts.listLock.RUnlock() + } + + for _, e := range rmeles { + mts.listLock.Lock() + mts.basicList.Remove(e) + mts.listLock.Unlock() } - basicID := ele.Value.(string) + }) +} + +func (mts *MemoryTokenStore) gcElement(ele *list.Element) (rm bool) { + basicID := ele.Value.(string) + mts.lock.RLock() + ti, ok := mts.data[basicID] + mts.lock.RUnlock() + if !ok { + rm = true + return + } + ct := time.Now() + if refresh := ti.GetRefresh(); refresh != "" && + ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { mts.lock.RLock() - ti, ok := mts.data[basicID] + delete(mts.access, ti.GetAccess()) + delete(mts.refresh, refresh) + delete(mts.data, basicID) mts.lock.RUnlock() - if !ok { - mts.lock.Lock() - mts.basicList.Remove(ele) - mts.lock.Unlock() - return - } - ct := time.Now() - if refresh := ti.GetRefresh(); refresh != "" && - ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { - mts.lock.RLock() - delete(mts.access, ti.GetAccess()) - delete(mts.refresh, refresh) + rm = true + } else if ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { + mts.lock.RLock() + delete(mts.access, ti.GetAccess()) + if refresh := ti.GetRefresh(); refresh == "" { delete(mts.data, basicID) - mts.basicList.Remove(ele) - mts.lock.RUnlock() - } else if ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { - mts.lock.RLock() - delete(mts.access, ti.GetAccess()) - if refresh := ti.GetRefresh(); refresh == "" { - delete(mts.data, basicID) - mts.basicList.Remove(ele) - } - mts.lock.RUnlock() + rm = true } - }) + mts.lock.RUnlock() + } + return } func (mts *MemoryTokenStore) getBasicID(id int64, info oauth2.TokenInfo) string { @@ -92,7 +107,10 @@ func (mts *MemoryTokenStore) Create(info oauth2.TokenInfo) (err error) { if refresh := info.GetRefresh(); refresh != "" { mts.refresh[refresh] = basicID } + + mts.listLock.Lock() mts.basicList.PushBack(basicID) + mts.listLock.Unlock() return } From 683529c8920244ca39760de1410fb750b334e241 Mon Sep 17 00:00:00 2001 From: lyric Date: Sat, 23 Jul 2016 22:40:04 +0800 Subject: [PATCH 2/2] Modify internal error handler --- example/server/main.go | 4 ++-- server/handler.go | 2 +- server/server.go | 31 +++++++++++++++++++++---------- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/example/server/main.go b/example/server/main.go index 0c98dda..2ad0b6c 100644 --- a/example/server/main.go +++ b/example/server/main.go @@ -37,8 +37,8 @@ func main() { srv := server.NewServer(server.NewConfig(), manager) srv.SetUserAuthorizationHandler(userAuthorizeHandler) - srv.SetInternalErrorHandler(func(err error) { - fmt.Println("OAuth2 Error:", err.Error()) + srv.SetInternalErrorHandler(func(r *http.Request, err error) { + fmt.Println("OAuth2 Error:", r.RequestURI, err.Error()) }) http.HandleFunc("/login", loginHandler) diff --git a/server/handler.go b/server/handler.go index 56be634..dfbed41 100644 --- a/server/handler.go +++ b/server/handler.go @@ -29,7 +29,7 @@ type RefreshingScopeHandler func(newScope, oldScope string) (allowed bool) type ResponseErrorHandler func(re *errors.Response) // InternalErrorHandler Internal error handing -type InternalErrorHandler func(err error) +type InternalErrorHandler func(req *http.Request, err error) // ClientFormHandler Get client data from form func ClientFormHandler(r *http.Request) (clientID, clientSecret string, err error) { diff --git a/server/server.go b/server/server.go index 36540f3..7480bb1 100644 --- a/server/server.go +++ b/server/server.go @@ -228,9 +228,6 @@ func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) ( func (s *Server) GetErrorData(rerr, ierr error) (data map[string]interface{}, statusCode int) { if ierr != nil { rerr = errors.ErrServerError - if fn := s.InternalErrorHandler; fn != nil { - fn(ierr) - } } re := &errors.Response{ Error: rerr, @@ -253,11 +250,18 @@ func (s *Server) GetErrorData(rerr, ierr error) (data map[string]interface{}, st } // response redirect error -func (s *Server) resRedirectError(w http.ResponseWriter, req *AuthorizeRequest, rerr, ierr error) (err error) { +func (s *Server) resRedirectError(w http.ResponseWriter, r *http.Request, req *AuthorizeRequest, rerr, ierr error) (err error) { if req == nil { err = ierr return } + if fn := s.InternalErrorHandler; fn != nil { + verr := ierr + if verr == nil { + verr = rerr + } + fn(r, verr) + } data, _ := s.GetErrorData(rerr, ierr) err = s.resRedirect(w, req, data) return @@ -283,12 +287,12 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) }() req, rerr, ierr := s.ValidationAuthorizeRequest(r) if rerr != nil || ierr != nil { - err = s.resRedirectError(w, req, rerr, ierr) + err = s.resRedirectError(w, r, req, rerr, ierr) return } userID, err := s.UserAuthorizationHandler(w, r) if err != nil { - err = s.resRedirectError(w, req, nil, err) + err = s.resRedirectError(w, r, req, nil, err) return } else if userID == "" { return @@ -296,7 +300,7 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) req.UserID = userID ti, rerr, ierr := s.GetAuthorizeToken(req) if rerr != nil || ierr != nil { - err = s.resRedirectError(w, req, rerr, ierr) + err = s.resRedirectError(w, r, req, rerr, ierr) return } err = s.resRedirect(w, req, s.GetAuthorizeData(req.ResponseType, ti)) @@ -460,19 +464,26 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) (err }() gt, tgr, rerr, ierr := s.ValidationTokenRequest(r) if rerr != nil || ierr != nil { - err = s.resTokenError(w, rerr, ierr) + err = s.resTokenError(w, r, rerr, ierr) return } ti, rerr, ierr := s.GetAccessToken(gt, tgr) if rerr != nil || ierr != nil { - err = s.resTokenError(w, rerr, ierr) + err = s.resTokenError(w, r, rerr, ierr) return } err = s.resToken(w, s.GetTokenData(ti)) return } -func (s *Server) resTokenError(w http.ResponseWriter, rerr, ierr error) (err error) { +func (s *Server) resTokenError(w http.ResponseWriter, r *http.Request, rerr, ierr error) (err error) { + if fn := s.InternalErrorHandler; fn != nil { + verr := ierr + if verr == nil { + verr = rerr + } + fn(r, verr) + } data, statusCode := s.GetErrorData(rerr, ierr) s.resToken(w, data, statusCode) return