From c9f1f320644bfd0d951d8643963e7ab25bb3d366 Mon Sep 17 00:00:00 2001 From: lyric Date: Sat, 25 Jun 2016 15:42:30 +0800 Subject: [PATCH 01/18] All the refactoring --- README.md | 72 +----- TODO.md | 7 + authorizationCode.go | 126 ---------- authorizationCodeGenerate.go | 100 -------- authorizationCodeGenerate_test.go | 38 --- authorizationCodeMemoryStore.go | 87 ------- authorizationCodeMemoryStore_test.go | 44 ---- authorizationCodeRedisStore.go | 89 ------- authorizationCodeRedisStore_test.go | 49 ---- authorizationCodeStore.go | 28 --- authorizationCode_test.go | 48 ---- clientCredentials.go | 40 --- clientCredentials_test.go | 26 -- clientMongoStore.go | 49 ---- clientMongoStore_test.go | 19 -- clientStore.go | 47 ---- config.go | 46 ---- config_redis.go | 41 --- const.go | 49 ++-- error.go | 39 +-- generate.go | 33 +++ implicit.go | 38 --- implicit_test.go | 29 --- manager.go | 127 ++++++++++ model.go | 98 ++++++++ oauth2.go | 359 --------------------------- oauth2_test.go | 43 ---- password.go | 51 ---- password_test.go | 33 --- storage.go | 43 ++++ tokenGenerate.go | 74 ------ tokenGenerate_test.go | 37 --- tokenMongoStore.go | 91 ------- tokenMongoStore_test.go | 40 --- tokenStore.go | 44 ---- util.go | 28 --- util_test.go | 14 -- 37 files changed, 346 insertions(+), 1880 deletions(-) create mode 100644 TODO.md delete mode 100644 authorizationCode.go delete mode 100644 authorizationCodeGenerate.go delete mode 100644 authorizationCodeGenerate_test.go delete mode 100644 authorizationCodeMemoryStore.go delete mode 100644 authorizationCodeMemoryStore_test.go delete mode 100644 authorizationCodeRedisStore.go delete mode 100644 authorizationCodeRedisStore_test.go delete mode 100644 authorizationCodeStore.go delete mode 100644 authorizationCode_test.go delete mode 100644 clientCredentials.go delete mode 100644 clientCredentials_test.go delete mode 100644 clientMongoStore.go delete mode 100644 clientMongoStore_test.go delete mode 100644 clientStore.go delete mode 100644 config.go delete mode 100644 config_redis.go create mode 100644 generate.go delete mode 100644 implicit.go delete mode 100644 implicit_test.go create mode 100644 manager.go create mode 100644 model.go delete mode 100644 oauth2.go delete mode 100644 oauth2_test.go delete mode 100644 password.go delete mode 100644 password_test.go create mode 100644 storage.go delete mode 100644 tokenGenerate.go delete mode 100644 tokenGenerate_test.go delete mode 100644 tokenMongoStore.go delete mode 100644 tokenMongoStore_test.go delete mode 100644 tokenStore.go delete mode 100644 util.go delete mode 100644 util_test.go diff --git a/README.md b/README.md index d6285ef..26ccf08 100644 --- a/README.md +++ b/README.md @@ -1,80 +1,14 @@ Golang OAuth 2.0协议实现 ======================== -[![GoDoc](https://godoc.org/gopkg.in/oauth2.v1?status.svg)](https://godoc.org/gopkg.in/oauth2.v1) -[![Go Report Card](https://goreportcard.com/badge/gopkg.in/oauth2.v1)](https://goreportcard.com/report/gopkg.in/oauth2.v1) +[![GoDoc](https://godoc.org/gopkg.in/oauth2.v2?status.svg)](https://godoc.org/gopkg.in/oauth2.v2) +[![Go Report Card](https://goreportcard.com/badge/gopkg.in/oauth2.v2)](https://goreportcard.com/report/gopkg.in/oauth2.v2) 获取 ---- ```bash -$ go get -v gopkg.in/oauth2.v1 -``` - -范例 ----- - -> 使用之前,初始化客户端信息 - -```go -package main - -import ( - "fmt" - - "gopkg.in/oauth2.v1" -) - -func main() { - // 初始化配置参数 - ocfg := &oauth2.OAuthConfig{ - ACConfig: &oauth2.ACConfig{ - ATExpiresIn: 60 * 60 * 24, - }, - } - mcfg := oauth2.NewMongoConfig("mongodb://127.0.0.1:27017", "test") - - // 创建默认的OAuth2管理实例(基于MongoDB) - manager, err := oauth2.NewDefaultOAuthManager(ocfg, mcfg, "xxx", "xxx") - if err != nil { - panic(err) - } - manager.SetACGenerate(oauth2.NewDefaultACGenerate()) - manager.SetACStore(oauth2.NewACMemoryStore(0)) - - // 模拟授权码模式 - // 使用默认参数,生成授权码 - code, err := manager.GetACManager(). - GenerateCode("clientID_x", "userID_x", "http://www.example.com/cb", "scopes") - if err != nil { - panic(err) - } - - // 生成访问令牌及更新令牌 - genToken, err := manager.GetACManager(). - GenerateToken(code, "http://www.example.com/cb", "clientID_x", "clientSecret_x", true) - if err != nil { - panic(err) - } - - // 检查访问令牌 - checkToken, err := manager.CheckAccessToken(genToken.AccessToken) - if err != nil { - panic(err) - } - - // TODO: 使用用户标识、申请的授权范围响应数据 - fmt.Println(checkToken.UserID, checkToken.Scope) - - // 更新令牌 - newToken, err := manager.RefreshAccessToken(checkToken.RefreshToken, "scopes") - if err != nil { - panic(err) - } - fmt.Println(newToken.AccessToken, newToken.ATExpiresIn) - // TODO: 将新的访问令牌响应给客户端 - -} +$ go get -v gopkg.in/oauth2.v2 ``` 执行测试 diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..b6c5a02 --- /dev/null +++ b/TODO.md @@ -0,0 +1,7 @@ +# OAuth2包的重构 + +* 将所有的Storage提取到公共的包 +* 所有文件命名、结构体命名遵循简单、明了 +* 所用的存储使用依赖注入进行管理 +* 移除针对各个授权类型的管理 +* 针对授权码的生成及存储增加独立的函数 \ No newline at end of file diff --git a/authorizationCode.go b/authorizationCode.go deleted file mode 100644 index dc648ce..0000000 --- a/authorizationCode.go +++ /dev/null @@ -1,126 +0,0 @@ -package oauth2 - -import ( - "time" - - "github.com/LyricTian/go.uuid" -) - -// NewACManager 创建授权码模式管理实例 -// oaManager OAuth授权管理 -// config 配置参数(nil则使用默认值) -func NewACManager(oaManager *OAuthManager, config *ACConfig) *ACManager { - if config == nil { - config = new(ACConfig) - } - if config.ACExpiresIn == 0 { - config.ACExpiresIn = DefaultACExpiresIn - } - if config.ATExpiresIn == 0 { - config.ATExpiresIn = DefaultATExpiresIn - } - if config.RTExpiresIn == 0 { - config.RTExpiresIn = DefaultRTExpiresIn - } - acManager := &ACManager{ - oAuthManager: oaManager, - config: config, - } - return acManager -} - -// ACManager 授权码模式管理(Authorization Code Manager) -type ACManager struct { - oAuthManager *OAuthManager // 授权管理 - config *ACConfig // 配置参数 -} - -// GenerateCode 生成授权码 -// clientID 客户端标识 -// userID 用户标识 -// redirectURI 重定向URI -// scopes 应用授权标识 -func (am *ACManager) GenerateCode(clientID, userID, redirectURI, scopes string) (code string, err error) { - cli, err := am.oAuthManager.ValidateClient(clientID, redirectURI) - if err != nil { - return - } - acInfo := ACInfo{ - ClientID: cli.ID(), - UserID: userID, - RedirectURI: redirectURI, - Scope: scopes, - Code: uuid.NewV4().String(), - CreateAt: time.Now().Unix(), - ExpiresIn: time.Duration(am.config.ACExpiresIn) * time.Second, - } - id, err := am.oAuthManager.ACStore.Put(acInfo) - if err != nil { - return - } - acInfo.ID = id - code, err = am.oAuthManager.ACGenerate.Code(&acInfo) - return -} - -// GenerateToken 生成令牌 -// code 授权码 -// redirectURI 重定向URI -// clientID 客户端标识 -// clientSecret 客户端秘钥 -// isGenerateRefresh 是否生成更新令牌 -func (am *ACManager) GenerateToken(code, redirectURI, clientID, clientSecret string, isGenerateRefresh bool) (token *Token, err error) { - acInfo, err := am.getACInfo(code) - if err != nil { - return - } else if acInfo.RedirectURI != redirectURI { - err = ErrACInvalid - return - } else if acInfo.ClientID != clientID { - err = ErrACInvalid - return - } - cli, err := am.oAuthManager.ClientStore.GetByID(acInfo.ClientID) - if err != nil { - return - } else if clientSecret != cli.Secret() { - err = ErrCSInvalid - return - } - - token, err = am.oAuthManager.GenerateToken(cli, - acInfo.UserID, - acInfo.Scope, - am.config.ATExpiresIn, - am.config.RTExpiresIn, - isGenerateRefresh) - - return -} - -// getACInfo 根据授权码获取授权信息 -func (am *ACManager) getACInfo(code string) (info *ACInfo, err error) { - if code == "" { - err = ErrACNotFound - return - } - acID, err := am.oAuthManager.ACGenerate.Parse(code) - if err != nil { - return - } - acInfo, err := am.oAuthManager.ACStore.TakeByID(acID) - if err != nil { - return - } - acValid, err := am.oAuthManager.ACGenerate.Verify(code, acInfo) - if err != nil { - return - } - if !acValid || - (acInfo.CreateAt+int64(acInfo.ExpiresIn/time.Second)) < time.Now().Unix() { - err = ErrACInvalid - return - } - info = acInfo - return -} diff --git a/authorizationCodeGenerate.go b/authorizationCodeGenerate.go deleted file mode 100644 index f4b75b8..0000000 --- a/authorizationCodeGenerate.go +++ /dev/null @@ -1,100 +0,0 @@ -package oauth2 - -import ( - "bytes" - "encoding/base64" - "errors" - "strconv" - "strings" - - "github.com/LyricTian/go.uuid" - - "gopkg.in/LyricTian/lib.v2" -) - -// ACGenerate 授权码生成接口(Authorization Code Generate) -type ACGenerate interface { - // Code 根据授权码相关信息生成授权码 - Code(info *ACInfo) (string, error) - - // Parse 解析授权码,返回授权信息ID - Parse(code string) (int64, error) - - // Verify 验证授权码的有效性 - Verify(code string, info *ACInfo) (bool, error) -} - -// NewDefaultACGenerate 创建默认的授权码生成方式 -func NewDefaultACGenerate() ACGenerate { - return &ACGenerateDefault{} -} - -// ACGenerateDefault 默认的授权码生成方式 -type ACGenerateDefault struct{} - -func (ag *ACGenerateDefault) genCode(info *ACInfo) (string, error) { - ns, _ := uuid.FromString(info.Code) - buf := bytes.NewBuffer(uuid.NewV3(ns, info.ClientID).Bytes()) - _, _ = buf.WriteString(info.UserID) - _, _ = buf.WriteString(strconv.FormatInt(info.CreateAt, 10)) - - md5Val, err := lib.NewEncryption(buf.Bytes()).MD5() - if err != nil { - return "", err - } - md5Val = md5Val[:15] - - return md5Val, nil -} - -// Code Authorization code -func (ag *ACGenerateDefault) Code(info *ACInfo) (string, error) { - codeVal, err := ag.genCode(info) - if err != nil { - return "", err - } - val := base64.URLEncoding.EncodeToString([]byte(codeVal + "." + strconv.FormatInt(info.ID, 10))) - return strings.TrimRight(val, "="), nil -} - -func (ag *ACGenerateDefault) parse(code string) (id int64, token string, err error) { - codeLen := len(code) % 4 - if codeLen > 0 { - codeLen = 4 - codeLen - } - code = code + strings.Repeat("=", codeLen) - codeBV, err := base64.URLEncoding.DecodeString(code) - if err != nil { - return - } - codeVal := strings.SplitN(string(codeBV), ".", 2) - if len(codeVal) != 2 { - err = errors.New("Token is invalid") - return - } - id, err = strconv.ParseInt(codeVal[1], 10, 64) - if err != nil { - return - } - token = codeVal[0] - return -} - -// Parse Parse authorization code -func (ag *ACGenerateDefault) Parse(code string) (id int64, err error) { - id, _, err = ag.parse(code) - return -} - -// Verify Verify code -func (ag *ACGenerateDefault) Verify(code string, info *ACInfo) (valid bool, err error) { - _, token, err := ag.parse(code) - if err != nil { - return - } - codeVal, err := ag.genCode(info) - if err != nil { - return - } - return token == codeVal, nil -} diff --git a/authorizationCodeGenerate_test.go b/authorizationCodeGenerate_test.go deleted file mode 100644 index 9ce5b62..0000000 --- a/authorizationCodeGenerate_test.go +++ /dev/null @@ -1,38 +0,0 @@ -package oauth2 - -import ( - "testing" - "time" - - "gopkg.in/LyricTian/lib.v2" - - . "github.com/smartystreets/goconvey/convey" -) - -func TestACGenerate(t *testing.T) { - Convey("Authorization code generate test", t, func() { - acGenerate := NewDefaultACGenerate() - info := &ACInfo{ - ID: 1, - ClientID: "123456", - UserID: "999999", - Code: lib.NewRandom(6).NumberAndLetter(), - CreateAt: time.Now().Unix(), - } - Convey("Generate code", func() { - code, err := acGenerate.Code(info) - So(err, ShouldBeNil) - So(code, ShouldNotBeBlank) - Convey("Parse code", func() { - id, err := acGenerate.Parse(code) - So(err, ShouldBeNil) - So(id, ShouldEqual, 1) - }) - Convey("Verify code", func() { - valid, err := acGenerate.Verify(code, info) - So(err, ShouldBeNil) - So(valid, ShouldBeTrue) - }) - }) - }) -} diff --git a/authorizationCodeMemoryStore.go b/authorizationCodeMemoryStore.go deleted file mode 100644 index 87f3fb4..0000000 --- a/authorizationCodeMemoryStore.go +++ /dev/null @@ -1,87 +0,0 @@ -package oauth2 - -import ( - "container/list" - "errors" - "sync" - "sync/atomic" - "time" -) - -// NewACMemoryStore 创建授权码的内存存储 -// gcInterval GC周期(单位秒,默认60秒执行一次) -func NewACMemoryStore(gcInterval int64) ACStore { - if gcInterval == 0 { - gcInterval = 60 - } - memStore := &ACMemoryStore{ - gcInterval: time.Second * time.Duration(gcInterval), - data: list.New(), - } - go memStore.gc() - return memStore -} - -// ACMemoryStore 提供授权码的内存存储 -type ACMemoryStore struct { - sync.RWMutex - globalID int64 - gcInterval time.Duration - data *list.List -} - -func (am *ACMemoryStore) gc() { - time.AfterFunc(am.gcInterval, func() { - defer am.gc() - for { - am.RLock() - ele := am.data.Front() - if ele == nil { - am.RUnlock() - break - } - item := ele.Value.(ACInfo) - am.RUnlock() - if (item.CreateAt + int64(item.ExpiresIn/time.Second)) < time.Now().Unix() { - am.Lock() - am.data.Remove(ele) - am.Unlock() - continue - } - break - } - }) -} - -// Put Put item -func (am *ACMemoryStore) Put(item ACInfo) (int64, error) { - am.Lock() - defer am.Unlock() - atomic.AddInt64(&am.globalID, 1) - item.ID = am.globalID - am.data.PushBack(item) - return item.ID, nil -} - -// TakeByID Take item by ID -func (am *ACMemoryStore) TakeByID(id int64) (*ACInfo, error) { - am.RLock() - var takeEle *list.Element - for ele := am.data.Back(); ele != nil; ele = ele.Prev() { - item := ele.Value.(ACInfo) - if item.ID == id { - takeEle = ele - break - } - } - if takeEle == nil { - am.RUnlock() - return nil, errors.New("Item not found") - } - item := takeEle.Value.(ACInfo) - am.RUnlock() - am.Lock() - am.data.Remove(takeEle) - am.Unlock() - return &item, nil -} diff --git a/authorizationCodeMemoryStore_test.go b/authorizationCodeMemoryStore_test.go deleted file mode 100644 index e6dd128..0000000 --- a/authorizationCodeMemoryStore_test.go +++ /dev/null @@ -1,44 +0,0 @@ -package oauth2 - -import ( - "testing" - "time" - - . "github.com/smartystreets/goconvey/convey" -) - -func TestACMemoryStore(t *testing.T) { - Convey("AC memory store test", t, func() { - store := NewACMemoryStore(1) - item := ACInfo{ - ClientID: "123456", - UserID: "999999", - CreateAt: time.Now().Unix(), - ExpiresIn: time.Millisecond * 500, - } - - Convey("Put Test", func() { - id, err := store.Put(item) - So(err, ShouldBeNil) - So(id, ShouldBeGreaterThan, 0) - Convey("Take Test", func() { - info, err := store.TakeByID(id) - So(err, ShouldBeNil) - So(info.ClientID, ShouldEqual, item.ClientID) - So(info.UserID, ShouldEqual, item.UserID) - }) - }) - - Convey("GC Test", func() { - id, err := store.Put(item) - So(err, ShouldBeNil) - So(id, ShouldBeGreaterThan, 0) - Convey("Take GC Test", func() { - time.Sleep(time.Millisecond * 1500) - info, err := store.TakeByID(id) - So(err, ShouldNotBeNil) - So(info, ShouldBeNil) - }) - }) - }) -} diff --git a/authorizationCodeRedisStore.go b/authorizationCodeRedisStore.go deleted file mode 100644 index c3c8db8..0000000 --- a/authorizationCodeRedisStore.go +++ /dev/null @@ -1,89 +0,0 @@ -package oauth2 - -import ( - "encoding/json" - "fmt" - - "gopkg.in/redis.v3" -) - -const ( - // DefaultACRedisIDKey Redis存储授权码唯一标识的键 - DefaultACRedisIDKey = "ACID" -) - -// NewACRedisStore 创建Redis存储的实例 -// config Redis配置参数 -// key Redis存储授权码唯一标识的键(默认为ACID) -func NewACRedisStore(cfg *RedisConfig, key string) (*ACRedisStore, error) { - opt := &redis.Options{ - Network: cfg.Network, - Addr: cfg.Addr, - Password: cfg.Password, - DB: cfg.DB, - MaxRetries: cfg.MaxRetries, - DialTimeout: cfg.DialTimeout, - ReadTimeout: cfg.ReadTimeout, - WriteTimeout: cfg.WriteTimeout, - PoolSize: cfg.PoolSize, - PoolTimeout: cfg.PoolTimeout, - } - cli := redis.NewClient(opt) - err := cli.Ping().Err() - if err != nil { - return nil, err - } - if key == "" { - key = DefaultACRedisIDKey - } - return &ACRedisStore{ - cli: cli, - key: key, - }, nil -} - -// ACRedisStore 提供授权码的redis存储 -type ACRedisStore struct { - cli *redis.Client - key string -} - -// Put 存储授权码 -func (ar *ACRedisStore) Put(item ACInfo) (id int64, err error) { - n, err := ar.cli.Incr(ar.key).Result() - if err != nil { - return - } - item.ID = n - jv, err := json.Marshal(item) - if err != nil { - return - } - key := fmt.Sprintf("%s_%d", ar.key, n) - err = ar.cli.Set(key, string(jv), item.ExpiresIn).Err() - if err != nil { - return - } - id = item.ID - return -} - -// TakeByID 取出授权码 -func (ar *ACRedisStore) TakeByID(id int64) (info *ACInfo, err error) { - key := fmt.Sprintf("%s_%d", ar.key, id) - data, err := ar.cli.Get(key).Result() - if err != nil { - return - } - var v ACInfo - err = json.Unmarshal([]byte(data), &v) - if err != nil { - return - } - err = ar.cli.Del(key).Err() - if err != nil { - return - } - info = &v - return -} diff --git a/authorizationCodeRedisStore_test.go b/authorizationCodeRedisStore_test.go deleted file mode 100644 index 96aad3a..0000000 --- a/authorizationCodeRedisStore_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package oauth2 - -import ( - "testing" - "time" - - . "github.com/smartystreets/goconvey/convey" -) - -func TestACRedisStore(t *testing.T) { - Convey("Authorization code redis store test", t, func() { - store, err := NewACRedisStore(&RedisConfig{ - Addr: "192.168.33.70:6379", - DB: 1, - }, "") - So(err, ShouldBeNil) - item := ACInfo{ - ClientID: "123456", - UserID: "999999", - Code: "", - CreateAt: time.Now().Unix(), - ExpiresIn: time.Millisecond * 500, - } - - Convey("Put Test", func() { - id, err := store.Put(item) - So(err, ShouldBeNil) - So(id, ShouldBeGreaterThan, 0) - Convey("Take Test", func() { - info, err := store.TakeByID(id) - So(err, ShouldBeNil) - So(info.ClientID, ShouldEqual, item.ClientID) - So(info.UserID, ShouldEqual, item.UserID) - }) - }) - - Convey("GC Test", func() { - id, err := store.Put(item) - So(err, ShouldBeNil) - So(id, ShouldBeGreaterThan, 0) - Convey("Take GC Test", func() { - time.Sleep(time.Millisecond * 1500) - info, err := store.TakeByID(id) - So(err, ShouldNotBeNil) - So(info, ShouldBeNil) - }) - }) - }) -} diff --git a/authorizationCodeStore.go b/authorizationCodeStore.go deleted file mode 100644 index 67658e3..0000000 --- a/authorizationCodeStore.go +++ /dev/null @@ -1,28 +0,0 @@ -package oauth2 - -import ( - "time" -) - -// ACInfo 授权码信息(Authorization Code Info) -type ACInfo struct { - ID int64 // 唯一标识 - ClientID string // 客户端标识 - UserID string // 用户标识 - RedirectURI string // 重定向URI - Scope string // 申请的权限范围 - Code string // 随机码 - CreateAt int64 // 创建时间(时间戳) - ExpiresIn time.Duration // 有效期(单位秒) -} - -// ACStore 授权码存储接口(临时存储,提供自动GC过期的元素)(Authorization Code Store) -type ACStore interface { - // Put 将元素放入存储,返回存储的ID - // 如果存储发生异常,则返回错误 - Put(item ACInfo) (int64, error) - - // TakeByID 根据ID取出元素 - // 如果元素找不到或发生异常,则返回错误 - TakeByID(id int64) (*ACInfo, error) -} diff --git a/authorizationCode_test.go b/authorizationCode_test.go deleted file mode 100644 index 954852b..0000000 --- a/authorizationCode_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package oauth2 - -import ( - "testing" - - . "github.com/smartystreets/goconvey/convey" -) - -func TestACManager(t *testing.T) { - ClientHandle(func(info Client) { - oManager, err := NewDefaultOAuthManager(nil, NewMongoConfig(MongoURL, DBName), "", "") - if err != nil { - t.Fatal(err) - } - oManager.SetACGenerate(NewDefaultACGenerate()) - oManager.SetACStore(NewACMemoryStore(0)) - - userID := "999999" - - Convey("Authorization Code Manager Test", t, func() { - manager := oManager.GetACManager() - - redirectURI := "http://www.example.com/cb" - code, err := manager.GenerateCode(info.ID(), userID, redirectURI, "all") - So(err, ShouldBeNil) - - accessToken, err := manager.GenerateToken(code, redirectURI, info.ID(), info.Secret(), true) - So(err, ShouldBeNil) - So(accessToken.UserID, ShouldEqual, userID) - - checkAT, err := oManager.CheckAccessToken(accessToken.AccessToken) - So(err, ShouldBeNil) - So(checkAT.ClientID, ShouldEqual, info.ID()) - So(checkAT.UserID, ShouldEqual, userID) - - newAT, err := oManager.RefreshAccessToken(checkAT.RefreshToken, "") - So(err, ShouldBeNil) - So(newAT.AccessToken, ShouldNotEqual, checkAT.AccessToken) - - err = oManager.RevokeAccessToken(newAT.AccessToken) - So(err, ShouldBeNil) - - checkAT, err = oManager.CheckAccessToken(newAT.AccessToken) - So(err, ShouldNotBeNil) - So(checkAT, ShouldBeNil) - }) - }) -} diff --git a/clientCredentials.go b/clientCredentials.go deleted file mode 100644 index 8302473..0000000 --- a/clientCredentials.go +++ /dev/null @@ -1,40 +0,0 @@ -package oauth2 - -// NewCCManager 创建默认的客户端模式管理实例 -// oaManager OAuth授权管理 -// config 配置参数(nil则使用默认值) -func NewCCManager(oaManager *OAuthManager, config *CCConfig) *CCManager { - if config == nil { - config = new(CCConfig) - } - if config.ATExpiresIn == 0 { - config.ATExpiresIn = DefaultCCATExpiresIn - } - ccManager := &CCManager{ - oAuthManager: oaManager, - config: config, - } - return ccManager -} - -// CCManager 客户端模式管理(Client Credentials Manager) -type CCManager struct { - oAuthManager *OAuthManager // 授权管理 - config *CCConfig // 配置参数 -} - -// GenerateToken 生成令牌(只生成访问令牌) -// clientID 客户端标识 -// clientSecret 客户端秘钥 -// scopes 应用授权标识 -func (cm *CCManager) GenerateToken(clientID, clientSecret, scopes string) (token *Token, err error) { - cli, err := cm.oAuthManager.GetClient(clientID) - if err != nil { - return - } else if cli.Secret() != clientSecret { - err = ErrCSInvalid - return - } - token, err = cm.oAuthManager.GenerateToken(cli, "", scopes, cm.config.ATExpiresIn, 0, false) - return -} diff --git a/clientCredentials_test.go b/clientCredentials_test.go deleted file mode 100644 index bea67d9..0000000 --- a/clientCredentials_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package oauth2 - -import ( - . "github.com/smartystreets/goconvey/convey" - - "testing" -) - -func TestCCManager(t *testing.T) { - ClientHandle(func(cli Client) { - oManager, err := NewDefaultOAuthManager(nil, NewMongoConfig(MongoURL, DBName), "", "") - if err != nil { - t.Fatal(err) - } - Convey("Client Credentials Manager Test", t, func() { - manager := oManager.GetCCManager() - - token, err := manager.GenerateToken(cli.ID(), cli.Secret(), "all") - So(err, ShouldBeNil) - - checkToken, err := oManager.CheckAccessToken(token.AccessToken) - So(err, ShouldBeNil) - So(checkToken.ClientID, ShouldEqual, cli.ID()) - }) - }) -} diff --git a/clientMongoStore.go b/clientMongoStore.go deleted file mode 100644 index fe2224a..0000000 --- a/clientMongoStore.go +++ /dev/null @@ -1,49 +0,0 @@ -package oauth2 - -import ( - "gopkg.in/LyricTian/lib.v2/mongo" - "gopkg.in/mgo.v2" -) - -const ( - // DefaultClientCollectionName 默认的客户端存储集合名称 - DefaultClientCollectionName = "ClientInfo" -) - -// NewClientMongoStore 创建基于MongoDB的客户端存储方式 -// mongoConfig MongoDB配置参数 -// cName 存储客户端的集合名称(默认为ClientInfo) -func NewClientMongoStore(mongoConfig *MongoConfig, cName string) (ClientStore, error) { - mHandler, err := mongo.InitHandlerWithDB(mongoConfig.URL, mongoConfig.DBName) - if err != nil { - return nil, err - } - if cName == "" { - cName = DefaultClientCollectionName - } - return &ClientMongoStore{ - cName: cName, - mHandler: mHandler, - }, nil -} - -// ClientMongoStore 基于MongoDB的默认客户端信息存储 -type ClientMongoStore struct { - cName string - mHandler *mongo.Handler -} - -// GetByID 根据ID获取客户端信息 -func (dcm *ClientMongoStore) GetByID(id string) (client Client, err error) { - dcm.mHandler.CHandle(dcm.cName, func(c *mgo.Collection) { - var result []DefaultClient - err = dcm.mHandler.C(dcm.cName).FindId(id).Limit(1).All(&result) - if err != nil { - return - } - if len(result) > 0 { - client = result[0] - } - }) - return -} diff --git a/clientMongoStore_test.go b/clientMongoStore_test.go deleted file mode 100644 index adb6f44..0000000 --- a/clientMongoStore_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package oauth2 - -import ( - "testing" - - . "github.com/smartystreets/goconvey/convey" -) - -func TestClientMongoStore(t *testing.T) { - ClientHandle(func(info Client) { - Convey("Client mongodb store test", t, func() { - clientStore, err := NewClientMongoStore(NewMongoConfig(MongoURL, DBName), "") - So(err, ShouldBeNil) - client, err := clientStore.GetByID(info.ID()) - So(err, ShouldBeNil) - So(client.Secret(), ShouldEqual, info.Secret()) - }) - }) -} diff --git a/clientStore.go b/clientStore.go deleted file mode 100644 index e4504af..0000000 --- a/clientStore.go +++ /dev/null @@ -1,47 +0,0 @@ -package oauth2 - -// Client 客户端的验证信息接口 -type Client interface { - // ID 客户端唯一标识 - ID() string - // Secret 客户端秘钥 - Secret() string - // Domain 客户端域名 - Domain() string - // RetainData 保留数据 - RetainData() interface{} -} - -// ClientStore 客户端存储接口(持久化存储) -type ClientStore interface { - // GetByID 根据ID获取客户端信息; - // 如果客户端不存在则返回nil - GetByID(id string) (Client, error) -} - -// DefaultClient 默认的客户端信息 -type DefaultClient struct { - ClientID string `bson:"_id"` // 客户端唯一标识 - ClientSecret string `bson:"Secret"` // 客户端秘钥 - ClientDomain string `bson:"Domain"` // 客户端域名 -} - -// ID Get ClientID -func (dc DefaultClient) ID() string { - return dc.ClientID -} - -// Secret Get ClientSecret -func (dc DefaultClient) Secret() string { - return dc.ClientSecret -} - -// Domain Get ClientDomain -func (dc DefaultClient) Domain() string { - return dc.ClientDomain -} - -// RetainData Get retain data -func (dc DefaultClient) RetainData() interface{} { - return dc -} diff --git a/config.go b/config.go deleted file mode 100644 index b5f0c68..0000000 --- a/config.go +++ /dev/null @@ -1,46 +0,0 @@ -package oauth2 - -// MongoConfig MongoDB配置参数 -type MongoConfig struct { - URL string // MongoDB连接字符串 - DBName string // 数据库名称 -} - -// NewMongoConfig 创建MongoDB配置参数的实例 -func NewMongoConfig(url, dbName string) *MongoConfig { - return &MongoConfig{ - URL: url, - DBName: dbName, - } -} - -// ACConfig 授权码模式配置参数(Authorization Code Config) -type ACConfig struct { - ACExpiresIn int64 // 授权码有效期(单位秒) - ATExpiresIn int64 // 访问令牌有效期(单位秒) - RTExpiresIn int64 // 更新令牌有效期(单位秒) -} - -// ImplicitConfig 简化模式配置参数 -type ImplicitConfig struct { - ATExpiresIn int64 // 访问令牌有效期(单位秒) -} - -// PasswordConfig 密码模式配置参数 -type PasswordConfig struct { - ATExpiresIn int64 // 访问令牌有效期(单位秒) - RTExpiresIn int64 // 更新令牌有效期(单位秒) -} - -// CCConfig 客户端模式配置参数(Client Credentials Config) -type CCConfig struct { - ATExpiresIn int64 // 访问令牌有效期(单位秒) -} - -// OAuthConfig OAuth授权配置参数 -type OAuthConfig struct { - ACConfig *ACConfig // 授权码模式配置参数 - ImplicitConfig *ImplicitConfig // 简化模式配置参数 - PasswordConfig *PasswordConfig // 密码模式配置参数 - CCConfig *CCConfig // 客户端模式配置参数 -} diff --git a/config_redis.go b/config_redis.go deleted file mode 100644 index 9f0546b..0000000 --- a/config_redis.go +++ /dev/null @@ -1,41 +0,0 @@ -package oauth2 - -import "time" - -// RedisConfig Redis的配置参数 -type RedisConfig struct { - // The network type, either tcp or unix. - // Default is tcp. - Network string - // host:port address. - Addr string - - // An optional password. Must match the password specified in the - // requirepass server configuration option. - Password string - // A database to be selected after connecting to server. - DB int64 - - // The maximum number of retries before giving up. - // Default is to not retry failed commands. - MaxRetries int - - // Sets the deadline for establishing new connections. If reached, - // dial will fail with a timeout. - // Default is 5 seconds. - DialTimeout time.Duration - // Sets the deadline for socket reads. If reached, commands will - // fail with a timeout instead of blocking. - ReadTimeout time.Duration - // Sets the deadline for socket writes. If reached, commands will - // fail with a timeout instead of blocking. - WriteTimeout time.Duration - - // The maximum number of socket connections. - // Default is 10 connections. - PoolSize int - // Specifies amount of time client waits for connection if all - // connections are busy before returning an error. - // Default is 1 second. - PoolTimeout time.Duration -} diff --git a/const.go b/const.go index 97b991c..2f49ae5 100644 --- a/const.go +++ b/const.go @@ -1,30 +1,29 @@ package oauth2 -const ( - // DefaultRandomCodeLen 默认随机码的长度 - DefaultRandomCodeLen = 6 - // DefaultACExpiresIn 默认授权码模式的授权码有效期(10分钟) - DefaultACExpiresIn = 60 * 10 - // DefaultATExpiresIn 默认授权码模式的访问令牌有效期(7天) - DefaultATExpiresIn = 60 * 60 * 24 * 7 - // DefaultRTExpiresIn 默认授权码模式的更新令牌有效期(30天) - DefaultRTExpiresIn = 60 * 60 * 24 * 30 - // DefaultIATExpiresIn 默认简化模式的访问令牌有效期(1小时) - DefaultIATExpiresIn = 60 * 60 - // DefaultCCATExpiresIn 默认客户端模式的访问令牌有效期(1天) - DefaultCCATExpiresIn = 60 * 60 * 24 -) - -// STATUS 提供一些状态标识 -type STATUS byte +// GrantType 定义授权模式 +type GrantType byte const ( - // Deleted 删除状态 - Deleted STATUS = iota - // Actived 激活状态 - Actived - // Blocked 冻结状态 - Blocked - // Expired 过期状态 - Expired + // AuthorizationCode 授权码模式 + AuthorizationCode GrantType = 1 << (iota + 1) + // Implicit 简化模式 + Implicit + // PasswordCredentials 密码模式 + PasswordCredentials + // ClientCredentials 客户端模式 + ClientCredentials ) + +func (gt GrantType) String() string { + switch gt { + case 1 << 1: + return "authorization_code" + case 1 << 2: + return "implicit" + case 1 << 3: + return "password" + case 1 << 4: + return "clientcredentials" + } + return "unknown" +} diff --git a/error.go b/error.go index 18ecf25..6e2dd3a 100644 --- a/error.go +++ b/error.go @@ -1,37 +1,20 @@ package oauth2 -import ( - "errors" -) +import "errors" var ( - // ErrClientNotFound Client not found - ErrClientNotFound = errors.New("The client is not found.") - - // ErrACNotFound Authorization code not found - ErrACNotFound = errors.New("The authorization code is not found.") - - // ErrACInvalid Authorization code invalid - ErrACInvalid = errors.New("The authorization code is invalid.") - - // ErrCSInvalid Client secret invalid - ErrCSInvalid = errors.New("The client secret is invalid.") - - // ErrATNotFound Refresh token not found - ErrATNotFound = errors.New("The access token is not found.") - - // ErrATInvalid Access token invalid - ErrATInvalid = errors.New("The access token is invalid.") + // ErrNotFound Not Found + ErrNotFound = errors.New("not found") - // ErrATExpire Access token expire - ErrATExpire = errors.New("The access token is expire.") + // ErrInvalid Invalid + ErrInvalid = errors.New("invalid") - // ErrRTNotFound Refresh token not found - ErrRTNotFound = errors.New("The refresh token is not found.") + // ErrExpired Expired + ErrExpired = errors.New("expired") - // ErrRTInvalid Refresh token invalid - ErrRTInvalid = errors.New("The refresh token is invalid.") + // ErrForbidden Forbidden + ErrForbidden = errors.New("forbidden") - // ErrRTExpire Refresh token expire - ErrRTExpire = errors.New("The refresh token is expire.") + // ErrNilValue Nil Value + ErrNilValue = errors.New("nil value") ) diff --git a/generate.go b/generate.go new file mode 100644 index 0000000..84a5dd2 --- /dev/null +++ b/generate.go @@ -0,0 +1,33 @@ +package oauth2 + +import "time" + +type ( + // TokenData 提供生成令牌的基础数据 + TokenData struct { + Client ClientInfo // 客户端信息 + UserID string // 用户标识 + Scope string // 权限范围 + CreateAt time.Time // 创建时间 + ExpiresIn time.Duration // 有效期 + Identifier string // 唯一标识码 + } + + // AuthorizeGenerate 授权令牌生成接口 + AuthorizeGenerate interface { + // 生成授权令牌 + Token(data *TokenData) (string, error) + + // 验证令牌的有效性 + Verify(token string, data *TokenData) (bool, error) + } + + // TokenGenerate 访问令牌生成接口 + TokenGenerate interface { + // 生成访问令牌 + AccessToken(data *TokenData) (string, error) + + // 生成刷新令牌 + RefreshToken(data *TokenData) (string, error) + } +) diff --git a/implicit.go b/implicit.go deleted file mode 100644 index 33d2d2f..0000000 --- a/implicit.go +++ /dev/null @@ -1,38 +0,0 @@ -package oauth2 - -// NewImplicitManager 创建默认的简化模式管理实例 -// oaManager OAuth授权管理 -// config 配置参数(nil则使用默认值) -func NewImplicitManager(oaManager *OAuthManager, config *ImplicitConfig) *ImplicitManager { - if config == nil { - config = new(ImplicitConfig) - } - if config.ATExpiresIn == 0 { - config.ATExpiresIn = DefaultIATExpiresIn - } - iManager := &ImplicitManager{ - oAuthManager: oaManager, - config: config, - } - return iManager -} - -// ImplicitManager 简化模式管理 -type ImplicitManager struct { - oAuthManager *OAuthManager // 授权管理 - config *ImplicitConfig // 配置参数 -} - -// GenerateToken 生成令牌(只生成访问令牌) -// clientID 客户端标识 -// userID 用户标识 -// redirectURI 重定向URI -// scopes 应用授权标识 -func (im *ImplicitManager) GenerateToken(clientID, userID, redirectURI, scopes string) (token *Token, err error) { - cli, err := im.oAuthManager.ValidateClient(clientID, redirectURI) - if err != nil { - return - } - token, err = im.oAuthManager.GenerateToken(cli, userID, scopes, im.config.ATExpiresIn, 0, false) - return -} diff --git a/implicit_test.go b/implicit_test.go deleted file mode 100644 index aaded8f..0000000 --- a/implicit_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package oauth2 - -import ( - . "github.com/smartystreets/goconvey/convey" - - "testing" -) - -func TestImplicitManager(t *testing.T) { - ClientHandle(func(cli Client) { - userID := "999999" - oManager, err := NewDefaultOAuthManager(nil, NewMongoConfig(MongoURL, DBName), "", "") - if err != nil { - t.Fatal(err) - } - - Convey("Implicit Manager Test", t, func() { - manager := oManager.GetImplicitManager() - - token, err := manager.GenerateToken(cli.ID(), userID, "http://www.example.com/cb", "all") - So(err, ShouldBeNil) - - checkToken, err := oManager.CheckAccessToken(token.AccessToken) - So(err, ShouldBeNil) - So(checkToken.ClientID, ShouldEqual, cli.ID()) - So(checkToken.UserID, ShouldEqual, userID) - }) - }) -} diff --git a/manager.go b/manager.go new file mode 100644 index 0000000..bbfaede --- /dev/null +++ b/manager.go @@ -0,0 +1,127 @@ +package oauth2 + +import ( + "time" + + "github.com/LyricTian/inject" +) + +// NewManager 创建Manager的实例 +func NewManager() *Manager { + return nil +} + +// Config 授权配置参数 +type Config struct { + CodeExpiresIn time.Duration // 授权码有效期 + AccessExpiresIn time.Duration // 访问令牌有效期 + RefreshExpiresIn time.Duration // 刷新令牌有效期 +} + +// Manager OAuth2授权管理 +type Manager struct { + injector inject.Injector + configs map[GrantType]*Config +} + +// SetConfig 设定配置参数 +func (m *Manager) SetConfig(gt GrantType, cfg *Config) { + m.configs[gt] = cfg +} + +// MapClientModel 注入客户端信息模型 +func (m *Manager) MapClientModel(cli ClientInfo) { + if cli == nil { + panic(ErrNilValue) + } + m.injector.Map(cli) +} + +// MapAuthorizeModel 注入授权信息模型 +func (m *Manager) MapAuthorizeModel(auth Authorize) { + if auth == nil { + panic(ErrNilValue) + } + m.injector.Map(auth) +} + +// MapTokenModel 注入令牌信息模型 +func (m *Manager) MapTokenModel(token TokenInfo) { + if token == nil { + panic(ErrNilValue) + } + m.injector.Map(token) +} + +// MapAuthorizeGenerate 注入授权令牌生成接口 +func (m *Manager) MapAuthorizeGenerate(gen AuthorizeGenerate) { + if gen == nil { + panic(ErrNilValue) + } + m.injector.Map(gen) +} + +// MapTokenGenerate 注入访问令牌生成接口 +func (m *Manager) MapTokenGenerate(gen TokenGenerate) { + if gen == nil { + panic(ErrNilValue) + } + m.injector.Map(gen) +} + +// MapClientStorage 注入客户端信息存储接口 +func (m *Manager) MapClientStorage(stor ClientStorage) { + if stor == nil { + panic(ErrNilValue) + } + m.injector.Map(stor) +} + +// MustClientStorage 注入客户端信息存储接口 +func (m *Manager) MustClientStorage(stor ClientStorage, err error) { + if err != nil { + panic(err) + } + if stor == nil { + panic(ErrNilValue) + } + m.injector.Map(stor) +} + +// MapAuthorizeStorage 注入授权码信息存储接口 +func (m *Manager) MapAuthorizeStorage(stor AuthorizeStorage) { + if stor == nil { + panic(ErrNilValue) + } + m.injector.Map(stor) +} + +// MustAuthorizeStorage 注入授权码信息存储接口 +func (m *Manager) MustAuthorizeStorage(stor AuthorizeStorage, err error) { + if err != nil { + panic(err) + } + if stor == nil { + panic(ErrNilValue) + } + m.injector.Map(stor) +} + +// MapTokenStorage 注入令牌信息存储接口 +func (m *Manager) MapTokenStorage(stor TokenStorage) { + if stor == nil { + panic(ErrNilValue) + } + m.injector.Map(stor) +} + +// MustTokenStorage 注入令牌信息存储接口 +func (m *Manager) MustTokenStorage(stor TokenStorage, err error) { + if err != nil { + panic(err) + } + if stor == nil { + panic(ErrNilValue) + } + m.injector.Map(stor) +} diff --git a/model.go b/model.go new file mode 100644 index 0000000..5826aec --- /dev/null +++ b/model.go @@ -0,0 +1,98 @@ +package oauth2 + +import "time" + +// 相关模型接口的定义 +type ( + // ClientInfo 客户端信息模型接口 + ClientInfo interface { + // 客户端唯一标识 + GetID() string + // 客户端秘钥 + GetSecret() string + // 客户端域名URL + GetDomain() string + // 自定义数据 + GetRetainData() interface{} + } + + // Authorize 授权信息模型接口 + Authorize interface { + // 客户端标识 + GetClientID() string + // 设置客户端标识 + SetClientID(string) + // 用户标识 + GetUserID() string + // 设置用户标识 + SetUserID(string) + // 重定向URI + GetRedirectURI() string + // 设置重定向URI + SetRedirectURI(string) + // 权限范围 + GetScope() string + // 设置权限范围 + SetScope(string) + // 创建时间 + GetCreateAt() time.Time + // 设置创建时间 + SetCreateAt(time.Time) + // 有效期 + GetExpiresIn() time.Duration + // 设置有效期 + SetExpiresIn(time.Duration) + // 授权令牌 + GetToken() string + // 设置授权令牌 + SetToken(string) + // 用于标识授权令牌的唯一标识码 + GetIdentifier() string + // 设置用于标识授权令牌的唯一标识码 + SetIdentifier(string) + } + + // TokenInfo 令牌信息模型接口 + TokenInfo interface { + // 客户端标识 + GetClientID() string + // 设置客户端标识 + SetClientID(string) + // 用户标识 + GetUserID() string + // 设置用户标识 + SetUserID(string) + // 访问令牌 + GetAccess() TokenBasic + // 设置访问令牌 + SetAccess(TokenBasic) + // 更新令牌 + GetRefresh() TokenBasic + // 设置更新令牌 + SetRefresh(TokenBasic) + // 权限范围 + GetScope() string + // 设置权限范围 + SetScope(string) + } + + // TokenBasic 令牌基础模型接口 + TokenBasic interface { + // 创建时间 + GetCreateAt() time.Time + // 设置创建时间 + SetCreateAt(time.Time) + // 有效期 + GetExpiresIn() time.Duration + // 设置有效期 + SetExpiresIn(time.Duration) + // 令牌 + GetToken() string + // 设置令牌 + SetToken(string) + // 用于标识令牌的唯一标识码 + GetIdentifier() string + // 设置用于标识令牌的唯一标识码 + SetIdentifier(string) + } +) diff --git a/oauth2.go b/oauth2.go deleted file mode 100644 index 37aef1b..0000000 --- a/oauth2.go +++ /dev/null @@ -1,359 +0,0 @@ -package oauth2 - -import ( - "github.com/LyricTian/go.uuid" - - "time" -) - -// NewOAuthManager 创建OAuth授权管理实例 -// cfg 配置参数 -func NewOAuthManager(cfg *OAuthConfig) *OAuthManager { - if cfg == nil { - cfg = new(OAuthConfig) - } - return &OAuthManager{ - Config: cfg, - } -} - -// NewDefaultOAuthManager 创建默认的OAuth授权管理实例 -// cfg 配置参数 -// mcfg MongoDB配置参数 -// ccName 存储客户端的集合名称(默认为ClientInfo) -// tcName 存储令牌的集合名称(默认为AuthToken) -func NewDefaultOAuthManager(cfg *OAuthConfig, mcfg *MongoConfig, ccName, tcName string) (*OAuthManager, error) { - oManager := NewOAuthManager(cfg) - clientStore, err := NewClientMongoStore(mcfg, ccName) - if err != nil { - return nil, err - } - oManager.SetClientStore(clientStore) - tokenStore, err := NewTokenMongoStore(mcfg, tcName) - if err != nil { - return nil, err - } - oManager.SetTokenStore(tokenStore) - oManager.SetTokenGenerate(NewDefaultTokenGenerate()) - - return oManager, nil -} - -// OAuthManager OAuth授权管理 -type OAuthManager struct { - Config *OAuthConfig // 配置参数 - ACGenerate ACGenerate // 授权码生成 - ACStore ACStore // 授权码存储 - TokenGenerate TokenGenerate // 令牌生成 - TokenStore TokenStore // 令牌存储 - ClientStore ClientStore // 客户端存储 -} - -// SetConfig 设置授权码生成接口 -func (om *OAuthManager) SetConfig(cfg *OAuthConfig) { - om.Config = cfg -} - -// SetACGenerate 设置授权码生成接口 -func (om *OAuthManager) SetACGenerate(generate ACGenerate) { - om.ACGenerate = generate -} - -// SetACStore 设置授权码存储接口 -func (om *OAuthManager) SetACStore(store ACStore) { - om.ACStore = store -} - -// SetTokenGenerate 设置令牌生成接口 -func (om *OAuthManager) SetTokenGenerate(generate TokenGenerate) { - om.TokenGenerate = generate -} - -// SetTokenStore 设置令牌存储接口 -func (om *OAuthManager) SetTokenStore(store TokenStore) { - om.TokenStore = store -} - -// SetClientStore 设置客户端存储接口 -func (om *OAuthManager) SetClientStore(store ClientStore) { - om.ClientStore = store -} - -// GetACManager 获取授权码模式管理实例 -func (om *OAuthManager) GetACManager() *ACManager { - return NewACManager(om, om.Config.ACConfig) -} - -// GetImplicitManager 获取简化模式管理实例 -func (om *OAuthManager) GetImplicitManager() *ImplicitManager { - return NewImplicitManager(om, om.Config.ImplicitConfig) -} - -// GetPasswordManager 获取密码模式管理实例 -func (om *OAuthManager) GetPasswordManager() *PasswordManager { - return NewPasswordManager(om, om.Config.PasswordConfig) -} - -// GetCCManager 获取客户端模式管理实例 -func (om *OAuthManager) GetCCManager() *CCManager { - return NewCCManager(om, om.Config.CCConfig) -} - -// GenerateToken 生成令牌 -// cli 客户端信息 -// userID 用户标识 -// scopes 应用授权标识 -// isGenerateRefresh 是否生成更新令牌 -func (om *OAuthManager) GenerateToken(cli Client, userID, scopes string, atExpireIn, rtExpireIn int64, isGenerateRefresh bool) (token *Token, err error) { - createAt := time.Now().Unix() - atID := uuid.NewV4().String() - atBI := NewTokenBasicInfo(cli, atID, userID, createAt) - atValue, err := om.TokenGenerate.AccessToken(atBI) - if err != nil { - return - } - tokenValue := Token{ - ClientID: cli.ID(), - UserID: userID, - AccessToken: atValue, - ATID: atID, - ATCreateAt: createAt, - ATExpiresIn: time.Duration(atExpireIn) * time.Second, - Scope: scopes, - CreateAt: createAt, - Status: Actived, - } - if isGenerateRefresh { - rtID := uuid.NewV4().String() - rtBI := NewTokenBasicInfo(cli, rtID, userID, createAt) - rtValue, rtErr := om.TokenGenerate.RefreshToken(rtBI) - if rtErr != nil { - err = rtErr - return - } - tokenValue.RefreshToken = rtValue - tokenValue.RTID = rtID - tokenValue.RTCreateAt = createAt - tokenValue.RTExpiresIn = time.Duration(rtExpireIn) * time.Second - } - id, err := om.TokenStore.Create(&tokenValue) - if err != nil { - return - } - tokenValue.ID = id - token = &tokenValue - return -} - -// GetClient 根据客户端标识获取客户端信息 -// clientID 客户端标识 -func (om *OAuthManager) GetClient(clientID string) (cli Client, err error) { - cli, err = om.ClientStore.GetByID(clientID) - if err != nil { - return - } else if cli == nil { - err = ErrClientNotFound - } - return -} - -// ValidateClient 验证客户端的重定向URI -// clientID 客户端标识 -// redirectURI 重定向URI -func (om *OAuthManager) ValidateClient(clientID, redirectURI string) (cli Client, err error) { - cli, err = om.GetClient(clientID) - if err != nil { - return - } else if v := ValidateURI(cli.Domain(), redirectURI); v != nil { - err = v - } - return -} - -// CheckAccessToken 检查访问令牌是否可用,同时返回该令牌的相关信息 -// accessToken 访问令牌 -func (om *OAuthManager) CheckAccessToken(accessToken string) (token *Token, err error) { - if accessToken == "" { - err = ErrATNotFound - return - } - tokenValue, err := om.TokenStore.GetByAccessToken(accessToken) - if err != nil { - return - } else if tokenValue == nil { - err = ErrATNotFound - return - } else if tokenValue.Status != Actived { - err = ErrATInvalid - return - } else if v := om.checkRefreshTokenExpire(tokenValue); v != nil { - err = v - return - } else if v := om.checkAccessTokenExpire(tokenValue); v != nil { - err = v - return - } else if v := om.checkAccessTokenValidity(accessToken, tokenValue); v != nil { - err = v - return - } - token = tokenValue - return -} - -// RevokeAccessToken 废除访问令牌(将该访问令牌的状态更改为删除) -// accessToken 访问令牌 -func (om *OAuthManager) RevokeAccessToken(accessToken string) (err error) { - if accessToken == "" { - err = ErrATNotFound - return - } - token, err := om.TokenStore.GetByAccessToken(accessToken) - if err != nil { - return - } else if token == nil { - err = ErrATNotFound - return - } else if token.Status != Actived { - err = ErrATInvalid - return - } else if v := om.checkAccessTokenValidity(accessToken, token); v != nil { - err = v - return - } - info := map[string]interface{}{ - "Status": Deleted, - } - err = om.TokenStore.Update(token.ID, info) - return -} - -// RefreshAccessToken 更新访问令牌(在更新令牌有效期内,更新访问令牌的有效期),同时返回更新后的令牌信息 -// refreshToken 更新令牌 -// scopes 申请的权限范围(不可以超出上一次申请的范围,如果省略该参数,则表示与上一次一致) -func (om *OAuthManager) RefreshAccessToken(refreshToken, scopes string) (token *Token, err error) { - if refreshToken == "" { - err = ErrRTNotFound - return - } - tokenValue, err := om.TokenStore.GetByRefreshToken(refreshToken) - if err != nil { - return - } else if tokenValue == nil { - err = ErrRTNotFound - return - } else if tokenValue.Status != Actived { - err = ErrRTInvalid - return - } else if v := om.checkRefreshTokenExpire(tokenValue); v != nil { - err = v - return - } else if v := om.checkRefreshTokenValidity(refreshToken, tokenValue); v != nil { - err = v - return - } - cli, err := om.GetClient(tokenValue.ClientID) - if err != nil { - return - } - tokenValue.ATCreateAt = time.Now().Unix() - tokenValue.ATID = uuid.NewV4().String() - atBI := NewTokenBasicInfo(cli, tokenValue.ATID, tokenValue.UserID, tokenValue.ATCreateAt) - atValue, err := om.TokenGenerate.AccessToken(atBI) - if err != nil { - return - } - tokenValue.AccessToken = atValue - tokenInfo := map[string]interface{}{ - "AccessToken": tokenValue.AccessToken, - "ATID": tokenValue.ATID, - "ATCreateAt": tokenValue.ATCreateAt, - } - if scopes != "" { - tokenValue.Scope = scopes - tokenInfo["Scope"] = tokenValue.Scope - } - err = om.TokenStore.Update(tokenValue.ID, tokenInfo) - if err != nil { - return - } - token = tokenValue - return -} - -// checkAccessTokenExpire 检查访问令牌是否过期, -// 如果访问令牌过期同时没有更新令牌的情况下, -// 则将令牌状态更改为过期 -func (om *OAuthManager) checkAccessTokenExpire(token *Token) error { - if token.AccessToken == "" { - return nil - } - nowUnix := time.Now().Unix() - if (token.ATCreateAt + int64(token.ATExpiresIn/time.Second)) > nowUnix { - return nil - } - var err error - if token.RefreshToken == "" { - info := map[string]interface{}{ - "Status": Expired, - } - err = om.TokenStore.Update(token.ID, info) - if err == nil { - err = ErrATExpire - } - } - return err -} - -// checkRefreshTokenExpire 检查更新令牌是否过期, -// 如果更新令牌过期则将令牌状态更改为过期 -func (om *OAuthManager) checkRefreshTokenExpire(token *Token) error { - if token.RefreshToken == "" { - return nil - } - nowUnix := time.Now().Unix() - if (token.RTCreateAt + int64(token.RTExpiresIn/time.Second)) > nowUnix { - return nil - } - info := map[string]interface{}{ - "Status": Expired, - } - err := om.TokenStore.Update(token.ID, info) - if err == nil { - err = ErrRTExpire - } - return err -} - -// checkAccessTokenValidity 检查访问令牌的有效性 -func (om *OAuthManager) checkAccessTokenValidity(tv string, token *Token) (err error) { - cli, err := om.GetClient(token.ClientID) - if err != nil { - return - } - bi := NewTokenBasicInfo(cli, token.ATID, token.UserID, token.ATCreateAt) - v, err := om.TokenGenerate.AccessToken(bi) - if err != nil { - return - } - if tv != v { - err = ErrATInvalid - } - return -} - -// checkRefreshTokenValidity 检查刷新令牌的有效性 -func (om *OAuthManager) checkRefreshTokenValidity(rv string, token *Token) (err error) { - cli, err := om.GetClient(token.ClientID) - if err != nil { - return - } - bi := NewTokenBasicInfo(cli, token.RTID, token.UserID, token.RTCreateAt) - v, err := om.TokenGenerate.RefreshToken(bi) - if err != nil { - return - } - if rv != v { - err = ErrRTInvalid - } - return -} diff --git a/oauth2_test.go b/oauth2_test.go deleted file mode 100644 index 87ffcb3..0000000 --- a/oauth2_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package oauth2 - -import ( - "gopkg.in/LyricTian/lib.v2" - "gopkg.in/LyricTian/lib.v2/mongo" - "gopkg.in/mgo.v2/bson" -) - -const ( - // MongoURL MongoDB连接字符串 - MongoURL = "mongodb://admin:123456@192.168.33.70:27017" - // DBName 数据库名称 - DBName = "test" -) - -var ( - oManager *OAuthManager -) - -// ClientHandle 执行客户端处理 -func ClientHandle(handle func(cli Client)) { - info := DefaultClient{ - ClientID: bson.NewObjectId().Hex(), - ClientDomain: "http://www.example.com", - } - info.ClientSecret, _ = lib.NewEncryption([]byte(info.ClientID)).MD5() - mHandler, err := mongo.InitHandlerWithDB(MongoURL, DBName) - if err != nil { - panic(err) - } - defer func() { - err = mHandler.C(DefaultClientCollectionName).RemoveId(info.ClientID) - if err != nil { - panic(err) - } - mHandler.Session().Close() - }() - err = mHandler.C(DefaultClientCollectionName).Insert(info) - if err != nil { - panic(err) - } - handle(info) -} diff --git a/password.go b/password.go deleted file mode 100644 index b1975fd..0000000 --- a/password.go +++ /dev/null @@ -1,51 +0,0 @@ -package oauth2 - -// NewPasswordManager 创建默认的密码模式管理实例 -// oaManager OAuth授权管理 -// config 配置参数(nil则使用默认值) -func NewPasswordManager(oaManager *OAuthManager, config *PasswordConfig) *PasswordManager { - if config == nil { - config = new(PasswordConfig) - } - if config.ATExpiresIn == 0 { - config.ATExpiresIn = DefaultATExpiresIn - } - if config.RTExpiresIn == 0 { - config.RTExpiresIn = DefaultRTExpiresIn - } - pManager := &PasswordManager{ - oAuthManager: oaManager, - config: config, - } - return pManager -} - -// PasswordManager 密码模式管理 -type PasswordManager struct { - oAuthManager *OAuthManager // 授权管理 - config *PasswordConfig // 配置参数 -} - -// GenerateToken 生成令牌(只生成访问令牌) -// clientID 客户端标识 -// userID 用户标识 -// clientSecret 客户端秘钥 -// scopes 应用授权标识 -func (pm *PasswordManager) GenerateToken(clientID, userID, clientSecret, scopes string, isGenerateRefresh bool) (token *Token, err error) { - cli, err := pm.oAuthManager.GetClient(clientID) - if err != nil { - return - } else if cli.Secret() != clientSecret { - err = ErrCSInvalid - return - } - - token, err = pm.oAuthManager.GenerateToken(cli, - userID, - scopes, - pm.config.ATExpiresIn, - pm.config.RTExpiresIn, - isGenerateRefresh) - - return -} diff --git a/password_test.go b/password_test.go deleted file mode 100644 index 65b7257..0000000 --- a/password_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package oauth2 - -import ( - "testing" - - . "github.com/smartystreets/goconvey/convey" -) - -func TestPasswordManager(t *testing.T) { - ClientHandle(func(info Client) { - userID := "999999" - oManager, err := NewDefaultOAuthManager(nil, NewMongoConfig(MongoURL, DBName), "", "") - if err != nil { - t.Fatal(err) - } - - Convey("Password Manager Test", t, func() { - manager := oManager.GetPasswordManager() - - token, err := manager.GenerateToken(info.ID(), userID, info.Secret(), "all", true) - So(err, ShouldBeNil) - - checkAT, err := oManager.CheckAccessToken(token.AccessToken) - So(err, ShouldBeNil) - So(checkAT.ClientID, ShouldEqual, info.ID()) - So(checkAT.UserID, ShouldEqual, userID) - - newAT, err := oManager.RefreshAccessToken(checkAT.RefreshToken, "") - So(err, ShouldBeNil) - So(newAT.AccessToken, ShouldNotEqual, checkAT.AccessToken) - }) - }) -} diff --git a/storage.go b/storage.go new file mode 100644 index 0000000..f66b302 --- /dev/null +++ b/storage.go @@ -0,0 +1,43 @@ +package oauth2 + +// 提供存储接口 +type ( + // ClientStorage 客户端信息存储接口 + ClientStorage interface { + // GetByID 根据ID获取客户端信息 + GetByID(id string) (ClientInfo, error) + } + + // AuthorizeStorage 授权码信息存储接口 + AuthorizeStorage interface { + // 将授权信息放入存储 + Put(info Authorize) error + + // 根据授权令牌取出授权信息 + TakeByToken(token string) (Authorize, error) + } + + // TokenStorage 令牌信息存储接口 + TokenStorage interface { + // Create 创建并存储新的令牌信息 + Create(info TokenInfo) error + + // UpdateByRefresh 根据刷新令牌更新令牌信息 + UpdateByRefresh(refresh string, info TokenInfo) error + + // 根据访问令牌获取令牌信息数据 + GetByAccess(access string) (TokenInfo, error) + + // 根据刷新令牌获取令牌信息数据 + GetByRefresh(refresh string) (TokenInfo, error) + + // 根据访问令牌废除令牌信息 + RevokeByAccess(access string) error + + // 将该访问令牌对应的令牌信息作过期处理 + ExpiredByAccess(access string) error + + // 将该刷新令牌对应的令牌信息作过期处理 + ExpiredByRefresh(refresh string) error + } +) diff --git a/tokenGenerate.go b/tokenGenerate.go deleted file mode 100644 index 65193f5..0000000 --- a/tokenGenerate.go +++ /dev/null @@ -1,74 +0,0 @@ -package oauth2 - -import ( - "bytes" - "strconv" - - "github.com/LyricTian/go.uuid" - - "gopkg.in/LyricTian/lib.v2" -) - -// NewTokenBasicInfo 创建用于生成令牌的基础信息 -// cli 客户端信息 -// userID 用户标识 -// createAt 创建令牌的时间戳 -func NewTokenBasicInfo(cli Client, tokenID, userID string, createAt int64) *TokenBasicInfo { - return &TokenBasicInfo{ - Client: cli, - UserID: userID, - TokenID: tokenID, - CreateAt: createAt, - } -} - -// TokenBasicInfo 用于生成令牌的基础信息 -type TokenBasicInfo struct { - Client Client // 客户端信息 - UserID string // 用户标识 - TokenID string // 令牌标识 - CreateAt int64 // 创建令牌的时间戳 -} - -// TokenGenerate 令牌生成接口 -type TokenGenerate interface { - // AccessToken 生成访问令牌 - AccessToken(basicInfo *TokenBasicInfo) (string, error) - - // RefreshToken 生成刷新令牌 - RefreshToken(basicInfo *TokenBasicInfo) (string, error) -} - -// NewDefaultTokenGenerate 创建默认的访问令牌生成方式 -func NewDefaultTokenGenerate() TokenGenerate { - return &TokenGenerateDefault{} -} - -// TokenGenerateDefault 提供访问令牌、更新令牌的默认生成函数 -type TokenGenerateDefault struct{} - -// AccessToken 生成访问令牌(md5) -// basicInfo 生成访问令牌的基础参数 -func (tg *TokenGenerateDefault) AccessToken(basicInfo *TokenBasicInfo) (token string, err error) { - ns, _ := uuid.FromString(basicInfo.TokenID) - buf := bytes.NewBuffer(uuid.NewV3(ns, basicInfo.Client.ID()).Bytes()) - if basicInfo.UserID != "" { - _, _ = buf.WriteString(basicInfo.UserID) - } - _, _ = buf.WriteString(strconv.FormatInt(basicInfo.CreateAt, 10)) - - return lib.NewEncryption(buf.Bytes()).MD5() -} - -// RefreshToken 生成刷新令牌(sha1) -// basicInfo 生成刷新令牌的基础参数 -func (tg *TokenGenerateDefault) RefreshToken(basicInfo *TokenBasicInfo) (string, error) { - ns, _ := uuid.FromString(basicInfo.TokenID) - buf := bytes.NewBuffer(uuid.NewV5(ns, basicInfo.Client.ID()).Bytes()) - if basicInfo.UserID != "" { - _, _ = buf.WriteString(basicInfo.UserID) - } - _, _ = buf.WriteString(strconv.FormatInt(basicInfo.CreateAt, 10)) - - return lib.NewEncryption(buf.Bytes()).Sha1() -} diff --git a/tokenGenerate_test.go b/tokenGenerate_test.go deleted file mode 100644 index e7a9003..0000000 --- a/tokenGenerate_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package oauth2 - -import ( - "testing" - "time" - - "github.com/LyricTian/go.uuid" - - . "github.com/smartystreets/goconvey/convey" -) - -func TestTokenGenerate(t *testing.T) { - cli := DefaultClient{ - ClientID: "123456", - ClientSecret: "654321", - ClientDomain: "http://www.lyric.name", - } - basicInfo := &TokenBasicInfo{ - Client: cli, - TokenID: uuid.NewV4().String(), - UserID: "999999", - CreateAt: time.Now().Unix(), - } - Convey("Token generate test", t, func() { - tokenGenerate := NewDefaultTokenGenerate() - Convey("Generate access token", func() { - token, err := tokenGenerate.AccessToken(basicInfo) - So(err, ShouldBeNil) - _, _ = Println("\n [P]Access Token:" + token) - }) - Convey("Generate refresh token", func() { - token, err := tokenGenerate.RefreshToken(basicInfo) - So(err, ShouldBeNil) - _, _ = Println("\n [P]Refresh Token:" + token) - }) - }) -} diff --git a/tokenMongoStore.go b/tokenMongoStore.go deleted file mode 100644 index eeeefe7..0000000 --- a/tokenMongoStore.go +++ /dev/null @@ -1,91 +0,0 @@ -package oauth2 - -import ( - "gopkg.in/LyricTian/lib.v2/mongo" - "gopkg.in/mgo.v2" - "gopkg.in/mgo.v2/bson" -) - -const ( - // DefaultTokenCollectionName 默认的令牌存储集合名称 - DefaultTokenCollectionName = "AuthToken" -) - -// NewTokenMongoStore 创建基于MongoDB的令牌存储方式 -// mongoConfig MongoDB配置参数 -// cName 存储令牌的集合名称(默认为AuthToken) -func NewTokenMongoStore(mongoConfig *MongoConfig, cName string) (TokenStore, error) { - mHandler, err := mongo.InitHandlerWithDB(mongoConfig.URL, mongoConfig.DBName) - if err != nil { - return nil, err - } - if cName == "" { - cName = DefaultTokenCollectionName - } - err = mHandler.C(cName).EnsureIndexKey("AccessToken", "RefreshToken") - if err != nil { - return nil, err - } - return &TokenMongoStore{ - cName: cName, - mHandler: mHandler, - }, nil -} - -// TokenMongoStore 基于MongoDB的令牌存储方式 -type TokenMongoStore struct { - cName string - mHandler *mongo.Handler -} - -// Create Add item -func (tm *TokenMongoStore) Create(item *Token) (id int64, err error) { - tm.mHandler.CHandle(tm.cName, func(c *mgo.Collection) { - tid, err := tm.mHandler.IncrID(tm.cName) - if err != nil { - return - } - item.ID = tid - err = c.Insert(item) - if err != nil { - return - } - id = item.ID - }) - return -} - -// Update Modify item -func (tm *TokenMongoStore) Update(id int64, info map[string]interface{}) (err error) { - tm.mHandler.CHandle(tm.cName, func(c *mgo.Collection) { - err = c.UpdateId(id, bson.M{"$set": info}) - if err != nil { - return - } - }) - return -} - -func (tm *TokenMongoStore) findOne(query interface{}) (token *Token, err error) { - tm.mHandler.CHandle(tm.cName, func(c *mgo.Collection) { - var result []Token - err = c.Find(query).Sort("-_id").Limit(1).All(&result) - if err != nil { - return - } - if len(result) > 0 { - token = &result[0] - } - }) - return -} - -// GetByAccessToken 根据访问令牌获取令牌信息 -func (tm *TokenMongoStore) GetByAccessToken(accessToken string) (*Token, error) { - return tm.findOne(bson.M{"AccessToken": accessToken}) -} - -// GetByRefreshToken 根据更新令牌获取令牌信息 -func (tm *TokenMongoStore) GetByRefreshToken(refreshToken string) (*Token, error) { - return tm.findOne(bson.M{"RefreshToken": refreshToken}) -} diff --git a/tokenMongoStore_test.go b/tokenMongoStore_test.go deleted file mode 100644 index 2ab1b87..0000000 --- a/tokenMongoStore_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package oauth2 - -import ( - "testing" - "time" - - . "github.com/smartystreets/goconvey/convey" -) - -func TestTokenMongoStore(t *testing.T) { - Convey("Token mongodb store test", t, func() { - tokenStore, err := NewTokenMongoStore(NewMongoConfig(MongoURL, DBName), "") - So(err, ShouldBeNil) - createAt := time.Now().Unix() - tokenValue := Token{ - ClientID: "123456", - UserID: "999999", - AccessToken: "654321", - ATCreateAt: createAt, - ATExpiresIn: time.Second * 1, - RefreshToken: "000000", - RTCreateAt: createAt, - RTExpiresIn: time.Second * 1, - CreateAt: createAt, - Status: Actived, - } - id, err := tokenStore.Create(&tokenValue) - So(err, ShouldBeNil) - So(id, ShouldBeGreaterThanOrEqualTo, 1) - tokenValue.ID = id - err = tokenStore.Update(id, map[string]interface{}{"Status": Expired}) - So(err, ShouldBeNil) - at, err := tokenStore.GetByAccessToken("654321") - So(err, ShouldBeNil) - So(at.Status, ShouldEqual, Expired) - rt, err := tokenStore.GetByRefreshToken("000000") - So(err, ShouldBeNil) - So(rt.ID, ShouldEqual, id) - }) -} diff --git a/tokenStore.go b/tokenStore.go deleted file mode 100644 index dbbf3f7..0000000 --- a/tokenStore.go +++ /dev/null @@ -1,44 +0,0 @@ -package oauth2 - -import ( - "time" -) - -// Token 令牌信息 -type Token struct { - ID int64 `bson:"_id"` // 唯一标识(自增ID) - ClientID string `bson:"ClientID"` // 客户端标识 - UserID string `bson:"UserID"` // 用户标识 - AccessToken string `bson:"AccessToken"` // 访问令牌 - ATID string `bson:"ATID"` // 访问令牌标识(uuid) - ATCreateAt int64 `bson:"ATCreateAt"` // 访问令牌创建时间(时间戳) - ATExpiresIn time.Duration `bson:"ATExpiresIn"` // 访问令牌有效期(单位秒) - RefreshToken string `bson:"RefreshToken"` // 更新令牌 - RTID string `bson:"RTID"` // 更新令牌标识(uuid) - RTCreateAt int64 `bson:"RTCreateAt"` // 更新令牌创建时间(时间戳) - RTExpiresIn time.Duration `bson:"RTExpiresIn"` // 更新令牌有效期(单位秒) - Scope string `bson:"Scope"` // 申请的权限范围 - CreateAt int64 `bson:"CreateAt"` // 创建时间(时间戳) - Status STATUS `bson:"Status"` // 令牌状态 -} - -// TokenStore 令牌存储接口(持久化存储) -type TokenStore interface { - // Create 创建新的令牌,返回令牌ID - // 如果创建发生异常,则返回错误 - Create(item *Token) (int64, error) - - // Update 根据ID更新令牌信息 - // id 令牌ID - // info 需要更新的字段信息(字段名称与结构体的字段名保持一致) - // 如果更新发生异常,则返回错误 - Update(id int64, info map[string]interface{}) error - - // GetByAccessToken 根据访问令牌,获取令牌信息 - // 如果不存则返回nil - GetByAccessToken(accessToken string) (*Token, error) - - // GetByRefreshToken 根据更新令牌,获取令牌信息 - // 如果不存则返回nil - GetByRefreshToken(refreshToken string) (*Token, error) -} diff --git a/util.go b/util.go deleted file mode 100644 index 8a92145..0000000 --- a/util.go +++ /dev/null @@ -1,28 +0,0 @@ -package oauth2 - -import ( - "errors" - "net/url" -) - -// ValidateURI 验证基础的Uri与重定向的URI是否一致 -func ValidateURI(baseURI string, redirectURI string) error { - base, err := url.Parse(baseURI) - if err != nil { - return err - } - redirect, err := url.Parse(redirectURI) - if err != nil { - return err - } - if base.Fragment != "" || redirect.Fragment != "" { - return errors.New("Url must not include fragment.") - } - if base.Scheme != redirect.Scheme { - return errors.New("Scheme don't match.") - } - if base.Host != redirect.Host { - return errors.New("Host don't match.") - } - return nil -} diff --git a/util_test.go b/util_test.go deleted file mode 100644 index 54689b2..0000000 --- a/util_test.go +++ /dev/null @@ -1,14 +0,0 @@ -package oauth2 - -import ( - "testing" - - . "github.com/smartystreets/goconvey/convey" -) - -func TestUtil(t *testing.T) { - Convey("ValidateURI Test", t, func() { - err := ValidateURI("http://www.example.com", "http://www.example.com/cb?code=xxx") - So(err, ShouldBeNil) - }) -} From 74676577dc9ba8e7004c1895df28919713337d58 Mon Sep 17 00:00:00 2001 From: lyric Date: Sun, 26 Jun 2016 10:39:19 +0800 Subject: [PATCH 02/18] Add manage --- TODO.md | 7 -- error.go | 20 ---- generate.go | 30 ++--- manage/config.go | 20 ++++ const.go => manage/const.go | 28 ++++- manage/error.go | 23 ++++ manage/manager.go | 233 ++++++++++++++++++++++++++++++++++++ manage/util.go | 25 ++++ manage/util_test.go | 14 +++ manager.go | 127 -------------------- model.go | 80 ++++--------- storage.go | 26 ++-- 12 files changed, 381 insertions(+), 252 deletions(-) delete mode 100644 TODO.md delete mode 100644 error.go create mode 100644 manage/config.go rename const.go => manage/const.go (59%) create mode 100644 manage/error.go create mode 100644 manage/manager.go create mode 100644 manage/util.go create mode 100644 manage/util_test.go delete mode 100644 manager.go diff --git a/TODO.md b/TODO.md deleted file mode 100644 index b6c5a02..0000000 --- a/TODO.md +++ /dev/null @@ -1,7 +0,0 @@ -# OAuth2包的重构 - -* 将所有的Storage提取到公共的包 -* 所有文件命名、结构体命名遵循简单、明了 -* 所用的存储使用依赖注入进行管理 -* 移除针对各个授权类型的管理 -* 针对授权码的生成及存储增加独立的函数 \ No newline at end of file diff --git a/error.go b/error.go deleted file mode 100644 index 6e2dd3a..0000000 --- a/error.go +++ /dev/null @@ -1,20 +0,0 @@ -package oauth2 - -import "errors" - -var ( - // ErrNotFound Not Found - ErrNotFound = errors.New("not found") - - // ErrInvalid Invalid - ErrInvalid = errors.New("invalid") - - // ErrExpired Expired - ErrExpired = errors.New("expired") - - // ErrForbidden Forbidden - ErrForbidden = errors.New("forbidden") - - // ErrNilValue Nil Value - ErrNilValue = errors.New("nil value") -) diff --git a/generate.go b/generate.go index 84a5dd2..55dca04 100644 --- a/generate.go +++ b/generate.go @@ -3,31 +3,23 @@ package oauth2 import "time" type ( - // TokenData 提供生成令牌的基础数据 - TokenData struct { - Client ClientInfo // 客户端信息 - UserID string // 用户标识 - Scope string // 权限范围 - CreateAt time.Time // 创建时间 - ExpiresIn time.Duration // 有效期 - Identifier string // 唯一标识码 + // TokenGenerateData 提供生成令牌的基础数据 + TokenGenerateData struct { + Client ClientInfo // 客户端信息 + UserID string // 用户标识 + Scope string // 权限范围 + CreateAt time.Time // 创建时间 } // AuthorizeGenerate 授权令牌生成接口 AuthorizeGenerate interface { - // 生成授权令牌 - Token(data *TokenData) (string, error) - - // 验证令牌的有效性 - Verify(token string, data *TokenData) (bool, error) + // 授权令牌 + Token(data *TokenGenerateData) (string, error) } - // TokenGenerate 访问令牌生成接口 + // TokenGenerate 令牌生成接口 TokenGenerate interface { - // 生成访问令牌 - AccessToken(data *TokenData) (string, error) - - // 生成刷新令牌 - RefreshToken(data *TokenData) (string, error) + // 生成令牌 + Token(data *TokenGenerateData, isGenRefresh bool) (string, string, error) } ) diff --git a/manage/config.go b/manage/config.go new file mode 100644 index 0000000..8f83f10 --- /dev/null +++ b/manage/config.go @@ -0,0 +1,20 @@ +package manage + +import "time" + +// Config 授权配置参数 +type Config struct { + TokenExpiresIn time.Duration // 令牌有效期 + RefreshExpiresIn time.Duration // 刷新令牌有效期 +} + +// TokenGenerateData 提供生成令牌的相应参数 +type TokenGenerateData struct { + ClientID string // 客户端标识 + ClientSecret string // 客户端密钥 + UserID string // 用户标识 + RedirectURI string // 重定向URI + Scope string // 授权范围 + Code string // 授权码(授权码模式使用) + IsGenerateRefresh bool // 是否生成刷新令牌 +} diff --git a/const.go b/manage/const.go similarity index 59% rename from const.go rename to manage/const.go index 2f49ae5..518217e 100644 --- a/const.go +++ b/manage/const.go @@ -1,4 +1,24 @@ -package oauth2 +package manage + +// ResponseType 授权类型 +type ResponseType byte + +const ( + // Code 授权码类型 + Code ResponseType = 1 << (iota + 1) + // Token 令牌类型 + Token +) + +func (rt ResponseType) String() string { + switch rt { + case 1 << 1: + return "code" + case 1 << 2: + return "token" + } + return "unknown" +} // GrantType 定义授权模式 type GrantType byte @@ -6,8 +26,6 @@ type GrantType byte const ( // AuthorizationCode 授权码模式 AuthorizationCode GrantType = 1 << (iota + 1) - // Implicit 简化模式 - Implicit // PasswordCredentials 密码模式 PasswordCredentials // ClientCredentials 客户端模式 @@ -19,10 +37,8 @@ func (gt GrantType) String() string { case 1 << 1: return "authorization_code" case 1 << 2: - return "implicit" - case 1 << 3: return "password" - case 1 << 4: + case 1 << 3: return "clientcredentials" } return "unknown" diff --git a/manage/error.go b/manage/error.go new file mode 100644 index 0000000..8277b8d --- /dev/null +++ b/manage/error.go @@ -0,0 +1,23 @@ +package manage + +import "errors" + +var ( + // ErrNilValue Nil Value + ErrNilValue = errors.New("nil value") + + // ErrClientNotFound Client not Found + ErrClientNotFound = errors.New("client not found") + + // ErrClientInvalid Client invalid + ErrClientInvalid = errors.New("client invalid") + + // ErrAuthTokenInvalid Authorize token invalid + ErrAuthTokenInvalid = errors.New("authorize token invalid") + + // ErrExpired Expired + ErrExpired = errors.New("expired") + + // ErrForbidden Forbidden + ErrForbidden = errors.New("forbidden") +) diff --git a/manage/manager.go b/manage/manager.go new file mode 100644 index 0000000..40cdb26 --- /dev/null +++ b/manage/manager.go @@ -0,0 +1,233 @@ +package manage + +import ( + "time" + + "github.com/LyricTian/inject" + "gopkg.in/oauth2.v2" +) + +// NewManager 创建Manager的实例 +func NewManager() *Manager { + return nil +} + +// Manager OAuth2授权管理 +type Manager struct { + injector inject.Injector // 注入器 + rtcfg map[ResponseType]*Config // 授权类型配置参数 + gtcfg map[GrantType]*Config // 授权模式配置参数 +} + +// SetRTConfig 设定授权类型配置参数 +// rt 授权类型 +// cfg 配置参数 +func (m *Manager) SetRTConfig(rt ResponseType, cfg *Config) { + m.rtcfg[rt] = cfg +} + +// SetGTConfig 设定授权模式配置参数 +// gt 授权模式 +// cfg 配置参数 +func (m *Manager) SetGTConfig(gt GrantType, cfg *Config) { + m.gtcfg[gt] = cfg +} + +// MapClientModel 注入客户端信息模型 +func (m *Manager) MapClientModel(cli oauth2.ClientInfo) { + if cli == nil { + panic(ErrNilValue) + } + m.injector.Map(cli) +} + +// MapTokenModel 注入令牌信息模型 +func (m *Manager) MapTokenModel(token oauth2.TokenInfo) { + if token == nil { + panic(ErrNilValue) + } + m.injector.Map(token) +} + +// MapAuthorizeGenerate 注入授权令牌生成接口 +func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { + if gen == nil { + panic(ErrNilValue) + } + m.injector.Map(gen) +} + +// MapTokenGenerate 注入访问令牌生成接口 +func (m *Manager) MapTokenGenerate(gen oauth2.TokenGenerate) { + if gen == nil { + panic(ErrNilValue) + } + m.injector.Map(gen) +} + +// MapClientStorage 注入客户端信息存储接口 +func (m *Manager) MapClientStorage(stor oauth2.ClientStorage) { + if stor == nil { + panic(ErrNilValue) + } + m.injector.Map(stor) +} + +// MustClientStorage 注入客户端信息存储接口 +func (m *Manager) MustClientStorage(stor oauth2.ClientStorage, err error) { + if err != nil { + panic(err) + } + if stor == nil { + panic(ErrNilValue) + } + m.injector.Map(stor) +} + +// MapTokenStorage 注入令牌信息存储接口 +func (m *Manager) MapTokenStorage(stor oauth2.TokenStorage) { + if stor == nil { + panic(ErrNilValue) + } + m.injector.Map(stor) +} + +// MustTokenStorage 注入令牌信息存储接口 +func (m *Manager) MustTokenStorage(stor oauth2.TokenStorage, err error) { + if err != nil { + panic(err) + } + if stor == nil { + panic(ErrNilValue) + } + m.injector.Map(stor) +} + +// GetClient 获取客户端信息 +func (m *Manager) GetClient(clientID string) (cli oauth2.ClientInfo, err error) { + err = m.injector.Apply(func(stor oauth2.ClientStorage) { + cli, err = stor.GetByID(clientID) + if err != nil { + return + } else if cli == nil { + err = ErrClientNotFound + } + }) + return +} + +// GenerateAuthToken 生成授权令牌 +// rt 授权类型 +// config 生成令牌的配置参数 +func (m *Manager) GenerateAuthToken(rt ResponseType, config *TokenGenerateData) (token string, err error) { + cli, err := m.GetClient(config.ClientID) + if err != nil { + return + } else if verr := ValidateURI(cli.GetDomain(), config.RedirectURI); verr != nil { + err = verr + return + } + _, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.AuthorizeGenerate, stor oauth2.TokenStorage) { + td := &oauth2.TokenGenerateData{ + Client: cli, + UserID: config.UserID, + Scope: config.Scope, + CreateAt: time.Now(), + } + tv, terr := gen.Token(td) + if terr != nil { + err = terr + return + } + ti.SetClientID(config.ClientID) + ti.SetUserID(config.UserID) + ti.SetRedirectURI(config.RedirectURI) + ti.SetScope(config.Scope) + ti.SetTokenCreateAt(td.CreateAt) + ti.SetTokenExpiresIn(m.rtcfg[rt].TokenExpiresIn) + ti.SetToken(tv) + err = stor.Create(ti) + if err != nil { + return + } + token = tv + }) + if ierr != nil && err == nil { + err = ierr + } + return +} + +// checkAuthToken 检查授权令牌 +func (m *Manager) checkAuthToken(config *TokenGenerateData) (err error) { + _, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage) { + ti, terr := stor.TakeByToken(config.Code) + if terr != nil { + err = terr + return + } else if ti.GetRedirectURI() != config.RedirectURI || ti.GetClientID() != config.ClientID { + err = ErrAuthTokenInvalid + return + } else if ti.GetTokenCreateAt().Add(ti.GetTokenExpiresIn()).Before(time.Now()) { + err = ErrAuthTokenInvalid + return + } + }) + if ierr != nil && err == nil { + err = ierr + } + return +} + +// GenerateToken 生成令牌 +// gt 授权模式 +// config 生成令牌的参数 +func (m *Manager) GenerateToken(gt GrantType, config *TokenGenerateData) (token, refresh string, err error) { + if gt == AuthorizationCode { + err = m.checkAuthToken(config) + if err != nil { + return + } + } + cli, err := m.GetClient(config.ClientID) + if err != nil { + return + } else if config.ClientSecret != "" && config.ClientSecret != cli.GetSecret() { + err = ErrClientInvalid + return + } + _, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.TokenGenerate, stor oauth2.TokenStorage) { + td := &oauth2.TokenGenerateData{ + Client: cli, + UserID: config.UserID, + Scope: config.Scope, + CreateAt: time.Now(), + } + tv, rv, terr := gen.Token(td, config.IsGenerateRefresh) + if terr != nil { + err = terr + return + } + ti.SetClientID(config.ClientID) + ti.SetUserID(config.UserID) + ti.SetRedirectURI(config.RedirectURI) + ti.SetScope(config.Scope) + ti.SetTokenCreateAt(td.CreateAt) + ti.SetTokenExpiresIn(m.gtcfg[gt].TokenExpiresIn) + ti.SetToken(tv) + if rv != "" { + ti.SetRefreshCreateAt(td.CreateAt) + ti.SetRefreshExpiresIn(m.gtcfg[gt].RefreshExpiresIn) + ti.SetRefresh(rv) + } + err = stor.Create(ti) + if err != nil { + return + } + token = tv + }) + if ierr != nil && err == nil { + err = ierr + } + return +} diff --git a/manage/util.go b/manage/util.go new file mode 100644 index 0000000..2677f80 --- /dev/null +++ b/manage/util.go @@ -0,0 +1,25 @@ +package manage + +import ( + "errors" + "net/url" +) + +// ValidateURI 校验重定向的URI与域名的一致性 +func ValidateURI(domain string, redirectURI string) error { + base, err := url.Parse(domain) + if err != nil { + return err + } + redirect, err := url.Parse(redirectURI) + if err != nil { + return err + } else if base.Fragment != "" || redirect.Fragment != "" { + return errors.New("Url must not include fragment.") + } else if base.Scheme != redirect.Scheme { + return errors.New("Scheme don't match.") + } else if base.Host != redirect.Host { + return errors.New("Host don't match.") + } + return nil +} diff --git a/manage/util_test.go b/manage/util_test.go new file mode 100644 index 0000000..6c4dd24 --- /dev/null +++ b/manage/util_test.go @@ -0,0 +1,14 @@ +package manage + +import ( + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestUtil(t *testing.T) { + Convey("ValidateURI Test", t, func() { + err := ValidateURI("http://www.example.com", "http://www.example.com/cb?code=xxx") + So(err, ShouldBeNil) + }) +} diff --git a/manager.go b/manager.go deleted file mode 100644 index bbfaede..0000000 --- a/manager.go +++ /dev/null @@ -1,127 +0,0 @@ -package oauth2 - -import ( - "time" - - "github.com/LyricTian/inject" -) - -// NewManager 创建Manager的实例 -func NewManager() *Manager { - return nil -} - -// Config 授权配置参数 -type Config struct { - CodeExpiresIn time.Duration // 授权码有效期 - AccessExpiresIn time.Duration // 访问令牌有效期 - RefreshExpiresIn time.Duration // 刷新令牌有效期 -} - -// Manager OAuth2授权管理 -type Manager struct { - injector inject.Injector - configs map[GrantType]*Config -} - -// SetConfig 设定配置参数 -func (m *Manager) SetConfig(gt GrantType, cfg *Config) { - m.configs[gt] = cfg -} - -// MapClientModel 注入客户端信息模型 -func (m *Manager) MapClientModel(cli ClientInfo) { - if cli == nil { - panic(ErrNilValue) - } - m.injector.Map(cli) -} - -// MapAuthorizeModel 注入授权信息模型 -func (m *Manager) MapAuthorizeModel(auth Authorize) { - if auth == nil { - panic(ErrNilValue) - } - m.injector.Map(auth) -} - -// MapTokenModel 注入令牌信息模型 -func (m *Manager) MapTokenModel(token TokenInfo) { - if token == nil { - panic(ErrNilValue) - } - m.injector.Map(token) -} - -// MapAuthorizeGenerate 注入授权令牌生成接口 -func (m *Manager) MapAuthorizeGenerate(gen AuthorizeGenerate) { - if gen == nil { - panic(ErrNilValue) - } - m.injector.Map(gen) -} - -// MapTokenGenerate 注入访问令牌生成接口 -func (m *Manager) MapTokenGenerate(gen TokenGenerate) { - if gen == nil { - panic(ErrNilValue) - } - m.injector.Map(gen) -} - -// MapClientStorage 注入客户端信息存储接口 -func (m *Manager) MapClientStorage(stor ClientStorage) { - if stor == nil { - panic(ErrNilValue) - } - m.injector.Map(stor) -} - -// MustClientStorage 注入客户端信息存储接口 -func (m *Manager) MustClientStorage(stor ClientStorage, err error) { - if err != nil { - panic(err) - } - if stor == nil { - panic(ErrNilValue) - } - m.injector.Map(stor) -} - -// MapAuthorizeStorage 注入授权码信息存储接口 -func (m *Manager) MapAuthorizeStorage(stor AuthorizeStorage) { - if stor == nil { - panic(ErrNilValue) - } - m.injector.Map(stor) -} - -// MustAuthorizeStorage 注入授权码信息存储接口 -func (m *Manager) MustAuthorizeStorage(stor AuthorizeStorage, err error) { - if err != nil { - panic(err) - } - if stor == nil { - panic(ErrNilValue) - } - m.injector.Map(stor) -} - -// MapTokenStorage 注入令牌信息存储接口 -func (m *Manager) MapTokenStorage(stor TokenStorage) { - if stor == nil { - panic(ErrNilValue) - } - m.injector.Map(stor) -} - -// MustTokenStorage 注入令牌信息存储接口 -func (m *Manager) MustTokenStorage(stor TokenStorage, err error) { - if err != nil { - panic(err) - } - if stor == nil { - panic(ErrNilValue) - } - m.injector.Map(stor) -} diff --git a/model.go b/model.go index 5826aec..0c9ae4e 100644 --- a/model.go +++ b/model.go @@ -16,8 +16,8 @@ type ( GetRetainData() interface{} } - // Authorize 授权信息模型接口 - Authorize interface { + // TokenInfo 令牌信息模型接口 + TokenInfo interface { // 客户端标识 GetClientID() string // 设置客户端标识 @@ -34,65 +34,31 @@ type ( GetScope() string // 设置权限范围 SetScope(string) - // 创建时间 - GetCreateAt() time.Time - // 设置创建时间 - SetCreateAt(time.Time) - // 有效期 - GetExpiresIn() time.Duration - // 设置有效期 - SetExpiresIn(time.Duration) - // 授权令牌 - GetToken() string - // 设置授权令牌 - SetToken(string) - // 用于标识授权令牌的唯一标识码 - GetIdentifier() string - // 设置用于标识授权令牌的唯一标识码 - SetIdentifier(string) - } - - // TokenInfo 令牌信息模型接口 - TokenInfo interface { - // 客户端标识 - GetClientID() string - // 设置客户端标识 - SetClientID(string) - // 用户标识 - GetUserID() string - // 设置用户标识 - SetUserID(string) - // 访问令牌 - GetAccess() TokenBasic - // 设置访问令牌 - SetAccess(TokenBasic) - // 更新令牌 - GetRefresh() TokenBasic - // 设置更新令牌 - SetRefresh(TokenBasic) - // 权限范围 - GetScope() string - // 设置权限范围 - SetScope(string) - } - // TokenBasic 令牌基础模型接口 - TokenBasic interface { - // 创建时间 - GetCreateAt() time.Time - // 设置创建时间 - SetCreateAt(time.Time) - // 有效期 - GetExpiresIn() time.Duration - // 设置有效期 - SetExpiresIn(time.Duration) // 令牌 GetToken() string // 设置令牌 SetToken(string) - // 用于标识令牌的唯一标识码 - GetIdentifier() string - // 设置用于标识令牌的唯一标识码 - SetIdentifier(string) + // 令牌创建时间 + GetTokenCreateAt() time.Time + // 设置令牌创建时间 + SetTokenCreateAt(time.Time) + // 令牌有效期 + GetTokenExpiresIn() time.Duration + // 设置令牌有效期 + SetTokenExpiresIn(time.Duration) + + // 刷新令牌 + GetRefresh() string + // 设置刷新令牌 + SetRefresh(string) + // 刷新令牌创建时间 + GetRefreshCreateAt() time.Time + // 设置刷新令牌创建时间 + SetRefreshCreateAt(time.Time) + // 刷新令牌有效期 + GetRefreshExpiresIn() time.Duration + // 设置刷新令牌有效期 + SetRefreshExpiresIn(time.Duration) } ) diff --git a/storage.go b/storage.go index f66b302..389f35d 100644 --- a/storage.go +++ b/storage.go @@ -8,15 +8,6 @@ type ( GetByID(id string) (ClientInfo, error) } - // AuthorizeStorage 授权码信息存储接口 - AuthorizeStorage interface { - // 将授权信息放入存储 - Put(info Authorize) error - - // 根据授权令牌取出授权信息 - TakeByToken(token string) (Authorize, error) - } - // TokenStorage 令牌信息存储接口 TokenStorage interface { // Create 创建并存储新的令牌信息 @@ -25,17 +16,20 @@ type ( // UpdateByRefresh 根据刷新令牌更新令牌信息 UpdateByRefresh(refresh string, info TokenInfo) error - // 根据访问令牌获取令牌信息数据 - GetByAccess(access string) (TokenInfo, error) + // DeleteByToken 根据令牌删除令牌信息 + DeleteByToken(val string) error + + // 根据令牌取出令牌信息数据(获取并删除) + TakeByToken(val string) (TokenInfo, error) + + // 根据令牌获取令牌信息数据 + GetByToken(val string) (TokenInfo, error) // 根据刷新令牌获取令牌信息数据 GetByRefresh(refresh string) (TokenInfo, error) - // 根据访问令牌废除令牌信息 - RevokeByAccess(access string) error - - // 将该访问令牌对应的令牌信息作过期处理 - ExpiredByAccess(access string) error + // 将该令牌对应的令牌信息作过期处理 + ExpiredByToken(val string) error // 将该刷新令牌对应的令牌信息作过期处理 ExpiredByRefresh(refresh string) error From 15212baa4b20d0136ef89adac8b11b648d95e066 Mon Sep 17 00:00:00 2001 From: lyric Date: Tue, 28 Jun 2016 08:59:28 +0800 Subject: [PATCH 03/18] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=B9=B6=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0Manager=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- manage/const.go => const.go | 4 +- generate.go | 13 ++- manage.go | 40 ++++++++ manage/config.go | 20 ---- manage/error.go | 14 ++- manage/manager.go | 196 +++++++++++++++++++++++++++++------- 6 files changed, 219 insertions(+), 68 deletions(-) rename manage/const.go => const.go (93%) create mode 100644 manage.go delete mode 100644 manage/config.go diff --git a/manage/const.go b/const.go similarity index 93% rename from manage/const.go rename to const.go index 518217e..d091275 100644 --- a/manage/const.go +++ b/const.go @@ -1,6 +1,6 @@ -package manage +package oauth2 -// ResponseType 授权类型 +// ResponseType 定义授权类型 type ResponseType byte const ( diff --git a/generate.go b/generate.go index 55dca04..de4ba41 100644 --- a/generate.go +++ b/generate.go @@ -3,23 +3,22 @@ package oauth2 import "time" type ( - // TokenGenerateData 提供生成令牌的基础数据 - TokenGenerateData struct { + // TokenGenerateBasic 提供生成令牌的基础数据 + TokenGenerateBasic struct { Client ClientInfo // 客户端信息 UserID string // 用户标识 - Scope string // 权限范围 CreateAt time.Time // 创建时间 } - // AuthorizeGenerate 授权令牌生成接口 - AuthorizeGenerate interface { + // AuthorizeTokenGenerate 授权令牌生成接口 + AuthorizeTokenGenerate interface { // 授权令牌 - Token(data *TokenGenerateData) (string, error) + Token(data *TokenGenerateBasic) (string, error) } // TokenGenerate 令牌生成接口 TokenGenerate interface { // 生成令牌 - Token(data *TokenGenerateData, isGenRefresh bool) (string, string, error) + Token(data *TokenGenerateBasic, isGenRefresh bool) (string, string, error) } ) diff --git a/manage.go b/manage.go new file mode 100644 index 0000000..3eda7f9 --- /dev/null +++ b/manage.go @@ -0,0 +1,40 @@ +package oauth2 + +// TokenGenerateRequest 提供生成令牌的请求参数 +type TokenGenerateRequest struct { + ClientID string // 客户端标识 + ClientSecret string // 客户端密钥 + UserID string // 用户标识 + RedirectURI string // 重定向URI + Scope string // 授权范围 + Code string // 授权码(授权码模式使用) + IsGenerateRefresh bool // 是否生成刷新令牌 +} + +// Manager OAuth2授权管理接口 +type Manager interface { + // GenerateAuthToken 生成授权令牌 + // rt 授权类型 + // tgr 生成令牌的请求参数 + GenerateAuthToken(rt ResponseType, tgr *TokenGenerateRequest) (token string, err error) + + // GenerateToken 生成访问令牌、刷新令牌 + // rt 授权模式 + // tgr 生成令牌的请求参数 + GenerateToken(rt GrantType, tgr *TokenGenerateRequest) (token, refresh string, err error) + + // RefreshToken 使用刷新令牌更新访问令牌 + // refresh 刷新令牌 + // scope 作用域 + RefreshToken(refresh, scope string) (token string, err error) + + // RevokeToken 使用访问令牌废除令牌信息 + // token 访问令牌 + RevokeToken(token string) (err error) + + // CheckToken 令牌检查,如果存在则返回令牌信息 + CheckToken(token string) (ti TokenInfo, err error) + + // CheckRefreshToken 访问令牌检查,如果存在则返回令牌信息 + CheckRefreshToken(refresh string) (ti TokenInfo, err error) +} diff --git a/manage/config.go b/manage/config.go deleted file mode 100644 index 8f83f10..0000000 --- a/manage/config.go +++ /dev/null @@ -1,20 +0,0 @@ -package manage - -import "time" - -// Config 授权配置参数 -type Config struct { - TokenExpiresIn time.Duration // 令牌有效期 - RefreshExpiresIn time.Duration // 刷新令牌有效期 -} - -// TokenGenerateData 提供生成令牌的相应参数 -type TokenGenerateData struct { - ClientID string // 客户端标识 - ClientSecret string // 客户端密钥 - UserID string // 用户标识 - RedirectURI string // 重定向URI - Scope string // 授权范围 - Code string // 授权码(授权码模式使用) - IsGenerateRefresh bool // 是否生成刷新令牌 -} diff --git a/manage/error.go b/manage/error.go index 8277b8d..21e7297 100644 --- a/manage/error.go +++ b/manage/error.go @@ -15,9 +15,15 @@ var ( // ErrAuthTokenInvalid Authorize token invalid ErrAuthTokenInvalid = errors.New("authorize token invalid") - // ErrExpired Expired - ErrExpired = errors.New("expired") + // ErrRefreshInvalid Refresh token invalid + ErrRefreshInvalid = errors.New("refresh token invalid") - // ErrForbidden Forbidden - ErrForbidden = errors.New("forbidden") + // ErrRefreshExpired Refresh token expired + ErrRefreshExpired = errors.New("refresh token expired") + + // ErrTokenInvalid Token invalid + ErrTokenInvalid = errors.New("token invalid") + + // ErrTokenExpired Token expired + ErrTokenExpired = errors.New("token expired") ) diff --git a/manage/manager.go b/manage/manager.go index 40cdb26..0987bb0 100644 --- a/manage/manager.go +++ b/manage/manager.go @@ -7,6 +7,12 @@ import ( "gopkg.in/oauth2.v2" ) +// Config 授权配置参数 +type Config struct { + TokenExpiresIn time.Duration // 令牌有效期 + RefreshExpiresIn time.Duration // 刷新令牌有效期 +} + // NewManager 创建Manager的实例 func NewManager() *Manager { return nil @@ -14,22 +20,22 @@ func NewManager() *Manager { // Manager OAuth2授权管理 type Manager struct { - injector inject.Injector // 注入器 - rtcfg map[ResponseType]*Config // 授权类型配置参数 - gtcfg map[GrantType]*Config // 授权模式配置参数 + injector inject.Injector // 注入器 + rtcfg map[oauth2.ResponseType]*Config // 授权类型配置参数 + gtcfg map[oauth2.GrantType]*Config // 授权模式配置参数 } // SetRTConfig 设定授权类型配置参数 // rt 授权类型 // cfg 配置参数 -func (m *Manager) SetRTConfig(rt ResponseType, cfg *Config) { +func (m *Manager) SetRTConfig(rt oauth2.ResponseType, cfg *Config) { m.rtcfg[rt] = cfg } // SetGTConfig 设定授权模式配置参数 // gt 授权模式 // cfg 配置参数 -func (m *Manager) SetGTConfig(gt GrantType, cfg *Config) { +func (m *Manager) SetGTConfig(gt oauth2.GrantType, cfg *Config) { m.gtcfg[gt] = cfg } @@ -50,7 +56,7 @@ func (m *Manager) MapTokenModel(token oauth2.TokenInfo) { } // MapAuthorizeGenerate 注入授权令牌生成接口 -func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { +func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeTokenGenerate) { if gen == nil { panic(ErrNilValue) } @@ -118,20 +124,19 @@ func (m *Manager) GetClient(clientID string) (cli oauth2.ClientInfo, err error) // GenerateAuthToken 生成授权令牌 // rt 授权类型 -// config 生成令牌的配置参数 -func (m *Manager) GenerateAuthToken(rt ResponseType, config *TokenGenerateData) (token string, err error) { - cli, err := m.GetClient(config.ClientID) +// tgr 生成令牌的配置参数 +func (m *Manager) GenerateAuthToken(rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (token string, err error) { + cli, err := m.GetClient(tgr.ClientID) if err != nil { return - } else if verr := ValidateURI(cli.GetDomain(), config.RedirectURI); verr != nil { + } else if verr := ValidateURI(cli.GetDomain(), tgr.RedirectURI); verr != nil { err = verr return } - _, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.AuthorizeGenerate, stor oauth2.TokenStorage) { - td := &oauth2.TokenGenerateData{ + _, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.AuthorizeTokenGenerate, stor oauth2.TokenStorage) { + td := &oauth2.TokenGenerateBasic{ Client: cli, - UserID: config.UserID, - Scope: config.Scope, + UserID: tgr.UserID, CreateAt: time.Now(), } tv, terr := gen.Token(td) @@ -139,10 +144,10 @@ func (m *Manager) GenerateAuthToken(rt ResponseType, config *TokenGenerateData) err = terr return } - ti.SetClientID(config.ClientID) - ti.SetUserID(config.UserID) - ti.SetRedirectURI(config.RedirectURI) - ti.SetScope(config.Scope) + ti.SetClientID(tgr.ClientID) + ti.SetUserID(tgr.UserID) + ti.SetRedirectURI(tgr.RedirectURI) + ti.SetScope(tgr.Scope) ti.SetTokenCreateAt(td.CreateAt) ti.SetTokenExpiresIn(m.rtcfg[rt].TokenExpiresIn) ti.SetToken(tv) @@ -159,13 +164,13 @@ func (m *Manager) GenerateAuthToken(rt ResponseType, config *TokenGenerateData) } // checkAuthToken 检查授权令牌 -func (m *Manager) checkAuthToken(config *TokenGenerateData) (err error) { +func (m *Manager) checkAuthToken(tgr *oauth2.TokenGenerateRequest) (err error) { _, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage) { - ti, terr := stor.TakeByToken(config.Code) + ti, terr := stor.TakeByToken(tgr.Code) if terr != nil { err = terr return - } else if ti.GetRedirectURI() != config.RedirectURI || ti.GetClientID() != config.ClientID { + } else if ti.GetRedirectURI() != tgr.RedirectURI || ti.GetClientID() != tgr.ClientID { err = ErrAuthTokenInvalid return } else if ti.GetTokenCreateAt().Add(ti.GetTokenExpiresIn()).Before(time.Now()) { @@ -181,37 +186,36 @@ func (m *Manager) checkAuthToken(config *TokenGenerateData) (err error) { // GenerateToken 生成令牌 // gt 授权模式 -// config 生成令牌的参数 -func (m *Manager) GenerateToken(gt GrantType, config *TokenGenerateData) (token, refresh string, err error) { - if gt == AuthorizationCode { - err = m.checkAuthToken(config) +// tgr 生成令牌的参数 +func (m *Manager) GenerateToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (token, refresh string, err error) { + if gt == oauth2.AuthorizationCode { + err = m.checkAuthToken(tgr) if err != nil { return } } - cli, err := m.GetClient(config.ClientID) + cli, err := m.GetClient(tgr.ClientID) if err != nil { return - } else if config.ClientSecret != "" && config.ClientSecret != cli.GetSecret() { + } else if tgr.ClientSecret != "" && tgr.ClientSecret != cli.GetSecret() { err = ErrClientInvalid return } _, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.TokenGenerate, stor oauth2.TokenStorage) { - td := &oauth2.TokenGenerateData{ + td := &oauth2.TokenGenerateBasic{ Client: cli, - UserID: config.UserID, - Scope: config.Scope, + UserID: tgr.UserID, CreateAt: time.Now(), } - tv, rv, terr := gen.Token(td, config.IsGenerateRefresh) + tv, rv, terr := gen.Token(td, tgr.IsGenerateRefresh) if terr != nil { err = terr return } - ti.SetClientID(config.ClientID) - ti.SetUserID(config.UserID) - ti.SetRedirectURI(config.RedirectURI) - ti.SetScope(config.Scope) + ti.SetClientID(tgr.ClientID) + ti.SetUserID(tgr.UserID) + ti.SetRedirectURI(tgr.RedirectURI) + ti.SetScope(tgr.Scope) ti.SetTokenCreateAt(td.CreateAt) ti.SetTokenExpiresIn(m.gtcfg[gt].TokenExpiresIn) ti.SetToken(tv) @@ -231,3 +235,125 @@ func (m *Manager) GenerateToken(gt GrantType, config *TokenGenerateData) (token, } return } + +// RefreshToken 刷新令牌 +func (m *Manager) RefreshToken(refresh, scope string) (token string, err error) { + ti, err := m.CheckRefreshToken(refresh) + if err != nil { + return + } + _, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage, gen oauth2.TokenGenerate) { + cli, cerr := m.GetClient(ti.GetClientID()) + if cerr != nil { + err = cerr + return + } + td := &oauth2.TokenGenerateBasic{ + Client: cli, + UserID: ti.GetUserID(), + CreateAt: time.Now(), + } + tv, _, terr := gen.Token(td, false) + if terr != nil { + err = terr + return + } + ti.SetToken(tv) + ti.SetTokenCreateAt(td.CreateAt) + if scope != "" { + ti.SetScope(scope) + } + err = stor.UpdateByRefresh(refresh, ti) + if err != nil { + return + } + token = tv + }) + if ierr != nil && err == nil { + err = ierr + } + return +} + +// RevokeToken 废除令牌 +func (m *Manager) RevokeToken(token string) (err error) { + if token == "" { + err = ErrTokenInvalid + return + } + _, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage) { + err = stor.DeleteByToken(token) + }) + if ierr != nil && err == nil { + err = ierr + } + return +} + +// CheckToken 令牌检查 +func (m *Manager) CheckToken(token string) (info oauth2.TokenInfo, err error) { + if token == "" { + err = ErrTokenInvalid + return + } + _, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage) { + ct := time.Now() + ti, terr := stor.GetByToken(token) + if terr != nil { + err = terr + return + } else if ti == nil { + err = ErrTokenInvalid + return + } else if ti.GetRefresh() != "" && ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { // 检查刷新令牌是否过期 + if verr := stor.ExpiredByRefresh(ti.GetRefresh()); verr != nil { + err = verr + return + } + err = ErrRefreshExpired + } else if ti.GetTokenCreateAt().Add(ti.GetTokenExpiresIn()).Before(ct) { // 检查令牌是否过期 + if verr := stor.ExpiredByToken(token); verr != nil { + err = verr + return + } + err = ErrTokenExpired + return + } + info = ti + }) + if ierr != nil && err == nil { + err = ierr + } + return +} + +// CheckRefreshToken 访问令牌检查 +func (m *Manager) CheckRefreshToken(refresh string) (info oauth2.TokenInfo, err error) { + if refresh == "" { + err = ErrRefreshInvalid + return + } + _, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage) { + ti, terr := stor.GetByRefresh(refresh) + if terr != nil { + err = terr + return + } else if ti == nil { + err = ErrRefreshInvalid + return + } else if ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) { + // 废除过期的令牌 + if verr := stor.ExpiredByRefresh(refresh); verr != nil { + err = verr + return + } + err = ErrRefreshExpired + return + } + info = ti + }) + if ierr != nil && err == nil { + err = ierr + } + return +} From 031d2cc6d6b42eb1afa1586d13fc72ab43a16d67 Mon Sep 17 00:00:00 2001 From: lyric Date: Tue, 28 Jun 2016 20:33:32 +0800 Subject: [PATCH 04/18] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BB=A4=E7=89=8C?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E3=80=81=E5=AE=A2=E6=88=B7=E7=AB=AF=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- const.go | 4 +- generate.go | 4 +- manage.go | 10 ++-- manage/manager.go | 35 ++++++++++---- model.go | 22 ++++----- models/client.go | 28 +++++++++++ models/token.go | 118 ++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 191 insertions(+), 30 deletions(-) create mode 100644 models/client.go create mode 100644 models/token.go diff --git a/const.go b/const.go index d091275..b4e0440 100644 --- a/const.go +++ b/const.go @@ -24,8 +24,8 @@ func (rt ResponseType) String() string { type GrantType byte const ( - // AuthorizationCode 授权码模式 - AuthorizationCode GrantType = 1 << (iota + 1) + // AuthorizationCodeCredentials 授权码模式 + AuthorizationCodeCredentials GrantType = 1 << (iota + 1) // PasswordCredentials 密码模式 PasswordCredentials // ClientCredentials 客户端模式 diff --git a/generate.go b/generate.go index de4ba41..c6447ca 100644 --- a/generate.go +++ b/generate.go @@ -13,12 +13,12 @@ type ( // AuthorizeTokenGenerate 授权令牌生成接口 AuthorizeTokenGenerate interface { // 授权令牌 - Token(data *TokenGenerateBasic) (string, error) + Token(data *TokenGenerateBasic) (token string, err error) } // TokenGenerate 令牌生成接口 TokenGenerate interface { // 生成令牌 - Token(data *TokenGenerateBasic, isGenRefresh bool) (string, string, error) + Token(data *TokenGenerateBasic, isGenRefresh bool) (token string, refresh string, err error) } ) diff --git a/manage.go b/manage.go index 3eda7f9..2a662e3 100644 --- a/manage.go +++ b/manage.go @@ -8,7 +8,7 @@ type TokenGenerateRequest struct { RedirectURI string // 重定向URI Scope string // 授权范围 Code string // 授权码(授权码模式使用) - IsGenerateRefresh bool // 是否生成刷新令牌 + IsGenerateRefresh bool // 是否生成更新令牌 } // Manager OAuth2授权管理接口 @@ -18,13 +18,13 @@ type Manager interface { // tgr 生成令牌的请求参数 GenerateAuthToken(rt ResponseType, tgr *TokenGenerateRequest) (token string, err error) - // GenerateToken 生成访问令牌、刷新令牌 + // GenerateToken 生成访问令牌、更新令牌 // rt 授权模式 // tgr 生成令牌的请求参数 GenerateToken(rt GrantType, tgr *TokenGenerateRequest) (token, refresh string, err error) - // RefreshToken 使用刷新令牌更新访问令牌 - // refresh 刷新令牌 + // RefreshToken 使用更新令牌更新访问令牌 + // refresh 更新令牌 // scope 作用域 RefreshToken(refresh, scope string) (token string, err error) @@ -35,6 +35,6 @@ type Manager interface { // CheckToken 令牌检查,如果存在则返回令牌信息 CheckToken(token string) (ti TokenInfo, err error) - // CheckRefreshToken 访问令牌检查,如果存在则返回令牌信息 + // CheckRefreshToken 更新令牌检查,如果存在则返回令牌信息 CheckRefreshToken(refresh string) (ti TokenInfo, err error) } diff --git a/manage/manager.go b/manage/manager.go index 0987bb0..27f9ed5 100644 --- a/manage/manager.go +++ b/manage/manager.go @@ -9,13 +9,28 @@ import ( // Config 授权配置参数 type Config struct { - TokenExpiresIn time.Duration // 令牌有效期 - RefreshExpiresIn time.Duration // 刷新令牌有效期 + TokenExp time.Duration // 令牌有效期 + RefreshExp time.Duration // g令牌有效期 } // NewManager 创建Manager的实例 func NewManager() *Manager { - return nil + m := &Manager{ + injector: inject.New(), + } + // 设定参数默认值 + + // 设定授权码的有效期为10分钟 + m.SetRTConfig(oauth2.Code, &Config{TokenExp: time.Minute * 10}) + // 设定简化模式授权令牌的有效期为1小时 + m.SetRTConfig(oauth2.Token, &Config{TokenExp: time.Hour * 1}) + + // 设定授权码模式令牌的有效期为2小时,g令牌的有效期为3天 + m.SetGTConfig(oauth2.PasswordCredentials, &Config{TokenExp: time.Hour * 2, RefreshExp: time.Hour * 24 * 3}) + + // 设定客户端模式令牌的有效期为1小时 + m.SetGTConfig(oauth2.ClientCredentials, &Config{TokenExp: time.Hour * 2}) + return m } // Manager OAuth2授权管理 @@ -149,7 +164,7 @@ func (m *Manager) GenerateAuthToken(rt oauth2.ResponseType, tgr *oauth2.TokenGen ti.SetRedirectURI(tgr.RedirectURI) ti.SetScope(tgr.Scope) ti.SetTokenCreateAt(td.CreateAt) - ti.SetTokenExpiresIn(m.rtcfg[rt].TokenExpiresIn) + ti.SetTokenExpiresIn(m.rtcfg[rt].TokenExp) ti.SetToken(tv) err = stor.Create(ti) if err != nil { @@ -188,7 +203,7 @@ func (m *Manager) checkAuthToken(tgr *oauth2.TokenGenerateRequest) (err error) { // gt 授权模式 // tgr 生成令牌的参数 func (m *Manager) GenerateToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (token, refresh string, err error) { - if gt == oauth2.AuthorizationCode { + if gt == oauth2.AuthorizationCodeCredentials { err = m.checkAuthToken(tgr) if err != nil { return @@ -217,11 +232,11 @@ func (m *Manager) GenerateToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe ti.SetRedirectURI(tgr.RedirectURI) ti.SetScope(tgr.Scope) ti.SetTokenCreateAt(td.CreateAt) - ti.SetTokenExpiresIn(m.gtcfg[gt].TokenExpiresIn) + ti.SetTokenExpiresIn(m.gtcfg[gt].TokenExp) ti.SetToken(tv) if rv != "" { ti.SetRefreshCreateAt(td.CreateAt) - ti.SetRefreshExpiresIn(m.gtcfg[gt].RefreshExpiresIn) + ti.SetRefreshExpiresIn(m.gtcfg[gt].RefreshExp) ti.SetRefresh(rv) } err = stor.Create(ti) @@ -236,7 +251,7 @@ func (m *Manager) GenerateToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe return } -// RefreshToken 刷新令牌 +// RefreshToken 更新访问令牌 func (m *Manager) RefreshToken(refresh, scope string) (token string, err error) { ti, err := m.CheckRefreshToken(refresh) if err != nil { @@ -305,7 +320,7 @@ func (m *Manager) CheckToken(token string) (info oauth2.TokenInfo, err error) { } else if ti == nil { err = ErrTokenInvalid return - } else if ti.GetRefresh() != "" && ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { // 检查刷新令牌是否过期 + } else if ti.GetRefresh() != "" && ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { // 检查g令牌是否过期 if verr := stor.ExpiredByRefresh(ti.GetRefresh()); verr != nil { err = verr return @@ -327,7 +342,7 @@ func (m *Manager) CheckToken(token string) (info oauth2.TokenInfo, err error) { return } -// CheckRefreshToken 访问令牌检查 +// CheckRefreshToken 更新令牌检查 func (m *Manager) CheckRefreshToken(refresh string) (info oauth2.TokenInfo, err error) { if refresh == "" { err = ErrRefreshInvalid diff --git a/model.go b/model.go index 0c9ae4e..f46be35 100644 --- a/model.go +++ b/model.go @@ -6,7 +6,7 @@ import "time" type ( // ClientInfo 客户端信息模型接口 ClientInfo interface { - // 客户端唯一标识 + // 客户端ID GetID() string // 客户端秘钥 GetSecret() string @@ -18,13 +18,13 @@ type ( // TokenInfo 令牌信息模型接口 TokenInfo interface { - // 客户端标识 + // 客户端ID GetClientID() string - // 设置客户端标识 + // 设置客户端ID SetClientID(string) - // 用户标识 + // 用户ID GetUserID() string - // 设置用户标识 + // 设置用户ID SetUserID(string) // 重定向URI GetRedirectURI() string @@ -48,17 +48,17 @@ type ( // 设置令牌有效期 SetTokenExpiresIn(time.Duration) - // 刷新令牌 + // 更新令牌 GetRefresh() string - // 设置刷新令牌 + // 设置更新令牌 SetRefresh(string) - // 刷新令牌创建时间 + // 更新令牌创建时间 GetRefreshCreateAt() time.Time - // 设置刷新令牌创建时间 + // 设置更新令牌创建时间 SetRefreshCreateAt(time.Time) - // 刷新令牌有效期 + // 更新令牌有效期 GetRefreshExpiresIn() time.Duration - // 设置刷新令牌有效期 + // 设置更新令牌有效期 SetRefreshExpiresIn(time.Duration) } ) diff --git a/models/client.go b/models/client.go new file mode 100644 index 0000000..dd8e3e1 --- /dev/null +++ b/models/client.go @@ -0,0 +1,28 @@ +package models + +// Client 客户端信息 +type Client struct { + ClientID string `bson:"ClientID"` // 客户端ID + Secret string `bson:"Secret"` // 密钥 + Domain string `bson:"Domain"` // 域名url +} + +// GetID 客户端ID +func (c *Client) GetID() string { + return c.ClientID +} + +// GetSecret 客户端秘钥 +func (c *Client) GetSecret() string { + return c.Secret +} + +// GetDomain 域名URL +func (c *Client) GetDomain() string { + return c.Domain +} + +// GetRetainData 自定义数据 +func (c *Client) GetRetainData() interface{} { + return nil +} diff --git a/models/token.go b/models/token.go new file mode 100644 index 0000000..48b40b1 --- /dev/null +++ b/models/token.go @@ -0,0 +1,118 @@ +package models + +import "time" + +// Token 令牌信息 +type Token struct { + ID int64 `bson:"_id"` // 唯一标识 + ClientID string `bson:"ClientID"` // 客户端标识 + UserID string `bson:"UserID"` // 用户标识 + RedirectURI string `bson:"RedirectURI"` // 重定向URI + Scope string `bson:"Scope"` // 权限范围 + Token string `bson:"Token"` // 令牌 + TokenCreateAt time.Time `bson:"TokenCreateAt"` // 令牌创建时间 + TokenExpiresIn time.Duration `bson:"TokenExpiresIn"` // 令牌有效期 + Refresh string `bson:"Refresh"` // 更新令牌 + RefreshCreateAt time.Time `bson:"RefreshCreateAt"` // 更新令牌创建时间 + RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"` // 更新令牌有效期 +} + +// GetClientID 客户端ID +func (t *Token) GetClientID() string { + return t.ClientID +} + +// SetClientID 设置客户端ID +func (t *Token) SetClientID(clientID string) { + t.ClientID = clientID +} + +// GetUserID 用户ID +func (t *Token) GetUserID() string { + return t.UserID +} + +// SetUserID 设置用户ID +func (t *Token) SetUserID(userID string) { + t.UserID = userID +} + +// GetRedirectURI 重定向URI +func (t *Token) GetRedirectURI() string { + return t.RedirectURI +} + +// SetRedirectURI 设置重定向URI +func (t *Token) SetRedirectURI(redirectURI string) { + t.RedirectURI = redirectURI +} + +// GetScope 权限范围 +func (t *Token) GetScope() string { + return t.Scope +} + +// SetScope 设置权限范围 +func (t *Token) SetScope(scope string) { + t.Scope = scope +} + +// GetToken 令牌 +func (t *Token) GetToken() string { + return t.Token +} + +// SetToken 设置令牌 +func (t *Token) SetToken(token string) { + t.Token = token +} + +// GetTokenCreateAt 令牌创建时间 +func (t *Token) GetTokenCreateAt() time.Time { + return t.TokenCreateAt +} + +// SetTokenCreateAt 设置令牌创建时间 +func (t *Token) SetTokenCreateAt(createAt time.Time) { + t.TokenCreateAt = createAt +} + +// GetTokenExpiresIn 令牌有效期 +func (t *Token) GetTokenExpiresIn() time.Duration { + return t.TokenExpiresIn +} + +// SetTokenExpiresIn 设置令牌有效期 +func (t *Token) SetTokenExpiresIn(exp time.Duration) { + t.TokenExpiresIn = exp +} + +// GetRefresh 更新令牌 +func (t *Token) GetRefresh() string { + return t.Refresh +} + +// SetRefresh 设置更新令牌 +func (t *Token) SetRefresh(refresh string) { + t.Refresh = refresh +} + +// GetRefreshCreateAt 更新令牌创建时间 +func (t *Token) GetRefreshCreateAt() time.Time { + return t.RefreshCreateAt +} + +// SetRefreshCreateAt 设置更新令牌创建时间 +func (t *Token) SetRefreshCreateAt(createAt time.Time) { + t.RefreshCreateAt = createAt +} + +// GetRefreshExpiresIn 更新令牌有效期 +func (t *Token) GetRefreshExpiresIn() time.Duration { + return t.RefreshExpiresIn +} + +// SetRefreshExpiresIn 设置更新令牌有效期 +func (t *Token) SetRefreshExpiresIn(exp time.Duration) { + t.RefreshExpiresIn = exp +} From 74d7ad64ccb6886d400158d2c807f09c33012a4a Mon Sep 17 00:00:00 2001 From: lyric Date: Wed, 29 Jun 2016 10:11:19 +0800 Subject: [PATCH 05/18] Fixed naming convention --- const.go | 4 ++ generate.go | 18 +++---- manage.go | 28 ++++++---- manage/error.go | 12 ++--- manage/manager.go | 132 ++++++++++++++++++++++++---------------------- model.go | 30 ++++++----- models/client.go | 8 +-- models/token.go | 68 ++++++++++++++---------- storage.go | 25 ++++----- 9 files changed, 175 insertions(+), 150 deletions(-) diff --git a/const.go b/const.go index b4e0440..2aa18a8 100644 --- a/const.go +++ b/const.go @@ -30,6 +30,8 @@ const ( PasswordCredentials // ClientCredentials 客户端模式 ClientCredentials + // RefreshCredentials 更新令牌模式 + RefreshCredentials ) func (gt GrantType) String() string { @@ -40,6 +42,8 @@ func (gt GrantType) String() string { return "password" case 1 << 3: return "clientcredentials" + case 1 << 4: + return "refreshtoken" } return "unknown" } diff --git a/generate.go b/generate.go index c6447ca..7eb0f06 100644 --- a/generate.go +++ b/generate.go @@ -3,22 +3,22 @@ package oauth2 import "time" type ( - // TokenGenerateBasic 提供生成令牌的基础数据 - TokenGenerateBasic struct { + // GenerateBasic 提供生成令牌的基础数据 + GenerateBasic struct { Client ClientInfo // 客户端信息 UserID string // 用户标识 CreateAt time.Time // 创建时间 } - // AuthorizeTokenGenerate 授权令牌生成接口 - AuthorizeTokenGenerate interface { + // AuthorizeGenerate 授权令牌生成接口 + AuthorizeGenerate interface { // 授权令牌 - Token(data *TokenGenerateBasic) (token string, err error) + Token(data *GenerateBasic) (code string, err error) } - // TokenGenerate 令牌生成接口 - TokenGenerate interface { - // 生成令牌 - Token(data *TokenGenerateBasic, isGenRefresh bool) (token string, refresh string, err error) + // AccessGenerate 访问令牌生成接口 + AccessGenerate interface { + // 访问令牌、更新令牌 + Token(data *GenerateBasic, isGenRefresh bool) (access, refresh string, err error) } ) diff --git a/manage.go b/manage.go index 2a662e3..fcf0947 100644 --- a/manage.go +++ b/manage.go @@ -18,23 +18,29 @@ type Manager interface { // tgr 生成令牌的请求参数 GenerateAuthToken(rt ResponseType, tgr *TokenGenerateRequest) (token string, err error) - // GenerateToken 生成访问令牌、更新令牌 + // GenerateAccessToken 生成访问令牌、更新令牌 // rt 授权模式 // tgr 生成令牌的请求参数 - GenerateToken(rt GrantType, tgr *TokenGenerateRequest) (token, refresh string, err error) + GenerateAccessToken(rt GrantType, tgr *TokenGenerateRequest) (access, refresh string, err error) - // RefreshToken 使用更新令牌更新访问令牌 + // RefreshAccessToken 更新访问令牌 // refresh 更新令牌 // scope 作用域 - RefreshToken(refresh, scope string) (token string, err error) + RefreshAccessToken(refresh, scope string) (access string, err error) - // RevokeToken 使用访问令牌废除令牌信息 - // token 访问令牌 - RevokeToken(token string) (err error) + // RemoveAccessToken 删除访问令牌 + // access 访问令牌 + RemoveAccessToken(access string) (err error) - // CheckToken 令牌检查,如果存在则返回令牌信息 - CheckToken(token string) (ti TokenInfo, err error) + // RemoveRefreshToken 删除更新令牌 + // refresh 更新令牌 + RemoveRefreshToken(refresh string) (err error) + + // LoadAccessToken 加载访问令牌信息 + // access 访问令牌 + LoadAccessToken(access string) (ti TokenInfo, err error) - // CheckRefreshToken 更新令牌检查,如果存在则返回令牌信息 - CheckRefreshToken(refresh string) (ti TokenInfo, err error) + // LoadRefreshToken 加载更新令牌信息 + // refresh 更新令牌 + LoadRefreshToken(refresh string) (ti TokenInfo, err error) } diff --git a/manage/error.go b/manage/error.go index 21e7297..b959173 100644 --- a/manage/error.go +++ b/manage/error.go @@ -15,15 +15,15 @@ var ( // ErrAuthTokenInvalid Authorize token invalid ErrAuthTokenInvalid = errors.New("authorize token invalid") + // ErrAccessInvalid Access token expired + ErrAccessInvalid = errors.New("access token invalid") + + // ErrAccessExpired Access token expired + ErrAccessExpired = errors.New("access token expired") + // ErrRefreshInvalid Refresh token invalid ErrRefreshInvalid = errors.New("refresh token invalid") // ErrRefreshExpired Refresh token expired ErrRefreshExpired = errors.New("refresh token expired") - - // ErrTokenInvalid Token invalid - ErrTokenInvalid = errors.New("token invalid") - - // ErrTokenExpired Token expired - ErrTokenExpired = errors.New("token expired") ) diff --git a/manage/manager.go b/manage/manager.go index 27f9ed5..a8da8fa 100644 --- a/manage/manager.go +++ b/manage/manager.go @@ -71,7 +71,7 @@ func (m *Manager) MapTokenModel(token oauth2.TokenInfo) { } // MapAuthorizeGenerate 注入授权令牌生成接口 -func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeTokenGenerate) { +func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { if gen == nil { panic(ErrNilValue) } @@ -79,7 +79,7 @@ func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeTokenGenerate) { } // MapTokenGenerate 注入访问令牌生成接口 -func (m *Manager) MapTokenGenerate(gen oauth2.TokenGenerate) { +func (m *Manager) MapTokenGenerate(gen oauth2.AccessGenerate) { if gen == nil { panic(ErrNilValue) } @@ -148,8 +148,8 @@ func (m *Manager) GenerateAuthToken(rt oauth2.ResponseType, tgr *oauth2.TokenGen err = verr return } - _, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.AuthorizeTokenGenerate, stor oauth2.TokenStorage) { - td := &oauth2.TokenGenerateBasic{ + _, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.AuthorizeGenerate, stor oauth2.TokenStorage) { + td := &oauth2.GenerateBasic{ Client: cli, UserID: tgr.UserID, CreateAt: time.Now(), @@ -163,9 +163,10 @@ func (m *Manager) GenerateAuthToken(rt oauth2.ResponseType, tgr *oauth2.TokenGen ti.SetUserID(tgr.UserID) ti.SetRedirectURI(tgr.RedirectURI) ti.SetScope(tgr.Scope) - ti.SetTokenCreateAt(td.CreateAt) - ti.SetTokenExpiresIn(m.rtcfg[rt].TokenExp) - ti.SetToken(tv) + ti.SetAuthType(rt.String()) + ti.SetAccess(tv) + ti.SetAccessCreateAt(td.CreateAt) + ti.SetAccessExpiresIn(m.rtcfg[rt].TokenExp) err = stor.Create(ti) if err != nil { return @@ -178,35 +179,18 @@ func (m *Manager) GenerateAuthToken(rt oauth2.ResponseType, tgr *oauth2.TokenGen return } -// checkAuthToken 检查授权令牌 -func (m *Manager) checkAuthToken(tgr *oauth2.TokenGenerateRequest) (err error) { - _, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage) { - ti, terr := stor.TakeByToken(tgr.Code) +// GenerateAccessToken 生成访问令牌、更新令牌 +// gt 授权模式 +// tgr 生成令牌的参数 +func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (token, refresh string, err error) { + if gt == oauth2.AuthorizationCodeCredentials { // 授权码模式 + ti, terr := m.LoadAccessToken(tgr.Code) if terr != nil { err = terr return } else if ti.GetRedirectURI() != tgr.RedirectURI || ti.GetClientID() != tgr.ClientID { err = ErrAuthTokenInvalid return - } else if ti.GetTokenCreateAt().Add(ti.GetTokenExpiresIn()).Before(time.Now()) { - err = ErrAuthTokenInvalid - return - } - }) - if ierr != nil && err == nil { - err = ierr - } - return -} - -// GenerateToken 生成令牌 -// gt 授权模式 -// tgr 生成令牌的参数 -func (m *Manager) GenerateToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (token, refresh string, err error) { - if gt == oauth2.AuthorizationCodeCredentials { - err = m.checkAuthToken(tgr) - if err != nil { - return } } cli, err := m.GetClient(tgr.ClientID) @@ -216,8 +200,8 @@ func (m *Manager) GenerateToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe err = ErrClientInvalid return } - _, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.TokenGenerate, stor oauth2.TokenStorage) { - td := &oauth2.TokenGenerateBasic{ + _, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.AccessGenerate, stor oauth2.TokenStorage) { + td := &oauth2.GenerateBasic{ Client: cli, UserID: tgr.UserID, CreateAt: time.Now(), @@ -231,9 +215,10 @@ func (m *Manager) GenerateToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe ti.SetUserID(tgr.UserID) ti.SetRedirectURI(tgr.RedirectURI) ti.SetScope(tgr.Scope) - ti.SetTokenCreateAt(td.CreateAt) - ti.SetTokenExpiresIn(m.gtcfg[gt].TokenExp) - ti.SetToken(tv) + ti.SetAuthType(gt.String()) + ti.SetAccessCreateAt(td.CreateAt) + ti.SetAccessExpiresIn(m.gtcfg[gt].TokenExp) + ti.SetAccess(tv) if rv != "" { ti.SetRefreshCreateAt(td.CreateAt) ti.SetRefreshExpiresIn(m.gtcfg[gt].RefreshExp) @@ -251,19 +236,19 @@ func (m *Manager) GenerateToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRe return } -// RefreshToken 更新访问令牌 -func (m *Manager) RefreshToken(refresh, scope string) (token string, err error) { - ti, err := m.CheckRefreshToken(refresh) +// RefreshAccessToken 更新访问令牌 +func (m *Manager) RefreshAccessToken(refresh, scope string) (token string, err error) { + ti, err := m.LoadRefreshToken(refresh) if err != nil { return } - _, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage, gen oauth2.TokenGenerate) { + _, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage, gen oauth2.AccessGenerate) { cli, cerr := m.GetClient(ti.GetClientID()) if cerr != nil { err = cerr return } - td := &oauth2.TokenGenerateBasic{ + td := &oauth2.GenerateBasic{ Client: cli, UserID: ti.GetUserID(), CreateAt: time.Now(), @@ -273,8 +258,9 @@ func (m *Manager) RefreshToken(refresh, scope string) (token string, err error) err = terr return } - ti.SetToken(tv) - ti.SetTokenCreateAt(td.CreateAt) + ti.SetAuthType(oauth2.RefreshCredentials.String()) + ti.SetAccess(tv) + ti.SetAccessCreateAt(td.CreateAt) if scope != "" { ti.SetScope(scope) } @@ -290,14 +276,14 @@ func (m *Manager) RefreshToken(refresh, scope string) (token string, err error) return } -// RevokeToken 废除令牌 -func (m *Manager) RevokeToken(token string) (err error) { - if token == "" { - err = ErrTokenInvalid +// RemoveAccessToken 删除访问令牌 +func (m *Manager) RemoveAccessToken(access string) (err error) { + if access == "" { + err = ErrAccessInvalid return } _, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage) { - err = stor.DeleteByToken(token) + err = stor.RemoveByAccess(access) }) if ierr != nil && err == nil { err = ierr @@ -305,33 +291,51 @@ func (m *Manager) RevokeToken(token string) (err error) { return } -// CheckToken 令牌检查 -func (m *Manager) CheckToken(token string) (info oauth2.TokenInfo, err error) { - if token == "" { - err = ErrTokenInvalid +// RemoveRefreshToken 删除更新令牌 +func (m *Manager) RemoveRefreshToken(refresh string) (err error) { + if refresh == "" { + err = ErrAccessInvalid + return + } + _, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage) { + err = stor.RemoveByRefresh(refresh) + }) + if ierr != nil && err == nil { + err = ierr + } + return +} + +// LoadAccessToken 加载访问令牌信息 +func (m *Manager) LoadAccessToken(access string) (info oauth2.TokenInfo, err error) { + if access == "" { + err = ErrAccessInvalid return } _, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage) { ct := time.Now() - ti, terr := stor.GetByToken(token) + ti, terr := stor.GetByAccess(access) if terr != nil { err = terr return } else if ti == nil { - err = ErrTokenInvalid + err = ErrAccessInvalid return - } else if ti.GetRefresh() != "" && ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { // 检查g令牌是否过期 - if verr := stor.ExpiredByRefresh(ti.GetRefresh()); verr != nil { + } else if ti.GetRefresh() != "" && ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { // 检查更新令牌是否过期 + // 删除过期的访问令牌 + if verr := stor.RemoveByRefresh(ti.GetRefresh()); verr != nil { err = verr return } err = ErrRefreshExpired - } else if ti.GetTokenCreateAt().Add(ti.GetTokenExpiresIn()).Before(ct) { // 检查令牌是否过期 - if verr := stor.ExpiredByToken(token); verr != nil { - err = verr - return + } else if ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { // 检查访问令牌是否过期 + if ti.GetRefresh() == "" { // 删除过期的访问令牌 + if verr := stor.RemoveByAccess(access); verr != nil { + err = verr + return + } } - err = ErrTokenExpired + err = ErrAccessExpired return } info = ti @@ -342,8 +346,8 @@ func (m *Manager) CheckToken(token string) (info oauth2.TokenInfo, err error) { return } -// CheckRefreshToken 更新令牌检查 -func (m *Manager) CheckRefreshToken(refresh string) (info oauth2.TokenInfo, err error) { +// LoadRefreshToken 加载更新令牌信息 +func (m *Manager) LoadRefreshToken(refresh string) (info oauth2.TokenInfo, err error) { if refresh == "" { err = ErrRefreshInvalid return @@ -357,8 +361,8 @@ func (m *Manager) CheckRefreshToken(refresh string) (info oauth2.TokenInfo, err err = ErrRefreshInvalid return } else if ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) { - // 废除过期的令牌 - if verr := stor.ExpiredByRefresh(refresh); verr != nil { + // 删除过期的更新令牌 + if verr := stor.RemoveByRefresh(refresh); verr != nil { err = verr return } diff --git a/model.go b/model.go index f46be35..2d9cf62 100644 --- a/model.go +++ b/model.go @@ -12,7 +12,7 @@ type ( GetSecret() string // 客户端域名URL GetDomain() string - // 自定义数据 + // 预留数据 GetRetainData() interface{} } @@ -34,19 +34,23 @@ type ( GetScope() string // 设置权限范围 SetScope(string) + // 令牌授权类型 + GetAuthType() string + // 设置令牌授权类型 + SetAuthType(string) - // 令牌 - GetToken() string - // 设置令牌 - SetToken(string) - // 令牌创建时间 - GetTokenCreateAt() time.Time - // 设置令牌创建时间 - SetTokenCreateAt(time.Time) - // 令牌有效期 - GetTokenExpiresIn() time.Duration - // 设置令牌有效期 - SetTokenExpiresIn(time.Duration) + // 访问令牌 + GetAccess() string + // 设置访问令牌 + SetAccess(string) + // 访问令牌创建时间 + GetAccessCreateAt() time.Time + // 设置访问令牌创建时间 + SetAccessCreateAt(time.Time) + // 访问令牌有效期 + GetAccessExpiresIn() time.Duration + // 设置访问令牌有效期 + SetAccessExpiresIn(time.Duration) // 更新令牌 GetRefresh() string diff --git a/models/client.go b/models/client.go index dd8e3e1..c93813e 100644 --- a/models/client.go +++ b/models/client.go @@ -2,9 +2,9 @@ package models // Client 客户端信息 type Client struct { - ClientID string `bson:"ClientID"` // 客户端ID - Secret string `bson:"Secret"` // 密钥 - Domain string `bson:"Domain"` // 域名url + ClientID string // 客户端ID + Secret string // 密钥 + Domain string // 域名url } // GetID 客户端ID @@ -22,7 +22,7 @@ func (c *Client) GetDomain() string { return c.Domain } -// GetRetainData 自定义数据 +// GetRetainData 预留数据 func (c *Client) GetRetainData() interface{} { return nil } diff --git a/models/token.go b/models/token.go index 48b40b1..3121fd8 100644 --- a/models/token.go +++ b/models/token.go @@ -4,17 +4,17 @@ import "time" // Token 令牌信息 type Token struct { - ID int64 `bson:"_id"` // 唯一标识 - ClientID string `bson:"ClientID"` // 客户端标识 - UserID string `bson:"UserID"` // 用户标识 - RedirectURI string `bson:"RedirectURI"` // 重定向URI - Scope string `bson:"Scope"` // 权限范围 - Token string `bson:"Token"` // 令牌 - TokenCreateAt time.Time `bson:"TokenCreateAt"` // 令牌创建时间 - TokenExpiresIn time.Duration `bson:"TokenExpiresIn"` // 令牌有效期 - Refresh string `bson:"Refresh"` // 更新令牌 - RefreshCreateAt time.Time `bson:"RefreshCreateAt"` // 更新令牌创建时间 - RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"` // 更新令牌有效期 + ClientID string // 客户端标识 + UserID string // 用户标识 + RedirectURI string // 重定向URI + Scope string // 权限范围 + AuthType string // 令牌授权类型 + Access string // 访问令牌 + AccessCreateAt time.Time // 访问令牌创建时间 + AccessExpiresIn time.Duration // 访问令牌有效期 + Refresh string // 更新令牌 + RefreshCreateAt time.Time // 更新令牌创建时间 + RefreshExpiresIn time.Duration // 更新令牌有效期 } // GetClientID 客户端ID @@ -57,34 +57,44 @@ func (t *Token) SetScope(scope string) { t.Scope = scope } -// GetToken 令牌 -func (t *Token) GetToken() string { - return t.Token +// GetAuthType 授权类型 +func (t *Token) GetAuthType() string { + return t.AuthType } -// SetToken 设置令牌 -func (t *Token) SetToken(token string) { - t.Token = token +// SetAuthType 设置授权类型 +func (t *Token) SetAuthType(authType string) { + t.AuthType = authType } -// GetTokenCreateAt 令牌创建时间 -func (t *Token) GetTokenCreateAt() time.Time { - return t.TokenCreateAt +// GetAccess 访问令牌 +func (t *Token) GetAccess() string { + return t.Access } -// SetTokenCreateAt 设置令牌创建时间 -func (t *Token) SetTokenCreateAt(createAt time.Time) { - t.TokenCreateAt = createAt +// SetAccess 设置访问令牌 +func (t *Token) SetAccess(access string) { + t.Access = access } -// GetTokenExpiresIn 令牌有效期 -func (t *Token) GetTokenExpiresIn() time.Duration { - return t.TokenExpiresIn +// GetAccessCreateAt 访问令牌创建时间 +func (t *Token) GetAccessCreateAt() time.Time { + return t.AccessCreateAt } -// SetTokenExpiresIn 设置令牌有效期 -func (t *Token) SetTokenExpiresIn(exp time.Duration) { - t.TokenExpiresIn = exp +// SetAccessCreateAt 设置访问令牌创建时间 +func (t *Token) SetAccessCreateAt(createAt time.Time) { + t.AccessCreateAt = createAt +} + +// GetAccessExpiresIn 访问令牌有效期 +func (t *Token) GetAccessExpiresIn() time.Duration { + return t.AccessExpiresIn +} + +// SetAccessExpiresIn 设置访问令牌有效期 +func (t *Token) SetAccessExpiresIn(exp time.Duration) { + t.AccessExpiresIn = exp } // GetRefresh 更新令牌 diff --git a/storage.go b/storage.go index 389f35d..d3e3c44 100644 --- a/storage.go +++ b/storage.go @@ -13,25 +13,22 @@ type ( // Create 创建并存储新的令牌信息 Create(info TokenInfo) error - // UpdateByRefresh 根据刷新令牌更新令牌信息 + // UpdateByRefresh 使用更新令牌更新令牌信息 UpdateByRefresh(refresh string, info TokenInfo) error - // DeleteByToken 根据令牌删除令牌信息 - DeleteByToken(val string) error + // RemoveByAccess 使用访问令牌删除令牌信息 + RemoveByAccess(access string) error - // 根据令牌取出令牌信息数据(获取并删除) - TakeByToken(val string) (TokenInfo, error) + // RemoveByRefresh 使用更新令牌删除令牌信息 + RemoveByRefresh(refresh string) error - // 根据令牌获取令牌信息数据 - GetByToken(val string) (TokenInfo, error) + // 使用访问令牌取出令牌信息数据(获取并删除) + TakeByAccess(access string) (TokenInfo, error) - // 根据刷新令牌获取令牌信息数据 - GetByRefresh(refresh string) (TokenInfo, error) - - // 将该令牌对应的令牌信息作过期处理 - ExpiredByToken(val string) error + // 使用访问令牌获取令牌信息数据 + GetByAccess(access string) (TokenInfo, error) - // 将该刷新令牌对应的令牌信息作过期处理 - ExpiredByRefresh(refresh string) error + // 根据更新令牌获取令牌信息数据 + GetByRefresh(refresh string) (TokenInfo, error) } ) From 6b1aa2097983463b8ff886e602e403d96b12fa2d Mon Sep 17 00:00:00 2001 From: lyric Date: Thu, 30 Jun 2016 17:23:47 +0800 Subject: [PATCH 06/18] Initialize folders --- README.md | 13 ++----------- example/.gitkeep | 0 generates/.gitkeep | 0 manage/manager.go | 9 ++++++--- server/.gitkeep | 0 storage.go | 3 --- storages/.gitkeep | 0 7 files changed, 8 insertions(+), 17 deletions(-) create mode 100644 example/.gitkeep create mode 100644 generates/.gitkeep create mode 100644 server/.gitkeep create mode 100644 storages/.gitkeep diff --git a/README.md b/README.md index 26ccf08..c93f8c9 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ -Golang OAuth 2.0协议实现 -======================== +Golang OAuth2 Server +===================== [![GoDoc](https://godoc.org/gopkg.in/oauth2.v2?status.svg)](https://godoc.org/gopkg.in/oauth2.v2) [![Go Report Card](https://goreportcard.com/badge/gopkg.in/oauth2.v2)](https://goreportcard.com/report/gopkg.in/oauth2.v2) @@ -11,15 +11,6 @@ Golang OAuth 2.0协议实现 $ go get -v gopkg.in/oauth2.v2 ``` -执行测试 -------- - -```bash -$ go test -v -# 或 -$ goconvey -port=9090 -``` - License ------- diff --git a/example/.gitkeep b/example/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/generates/.gitkeep b/generates/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/manage/manager.go b/manage/manager.go index a8da8fa..0434e7f 100644 --- a/manage/manager.go +++ b/manage/manager.go @@ -27,7 +27,6 @@ func NewManager() *Manager { // 设定授权码模式令牌的有效期为2小时,g令牌的有效期为3天 m.SetGTConfig(oauth2.PasswordCredentials, &Config{TokenExp: time.Hour * 2, RefreshExp: time.Hour * 24 * 3}) - // 设定客户端模式令牌的有效期为1小时 m.SetGTConfig(oauth2.ClientCredentials, &Config{TokenExp: time.Hour * 2}) return m @@ -264,8 +263,12 @@ func (m *Manager) RefreshAccessToken(refresh, scope string) (token string, err e if scope != "" { ti.SetScope(scope) } - err = stor.UpdateByRefresh(refresh, ti) - if err != nil { + if verr := stor.Create(ti); verr != nil { + err = verr + return + } + if verr := stor.RemoveByRefresh(refresh); verr != nil { + err = verr return } token = tv diff --git a/server/.gitkeep b/server/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/storage.go b/storage.go index d3e3c44..37f6b28 100644 --- a/storage.go +++ b/storage.go @@ -13,9 +13,6 @@ type ( // Create 创建并存储新的令牌信息 Create(info TokenInfo) error - // UpdateByRefresh 使用更新令牌更新令牌信息 - UpdateByRefresh(refresh string, info TokenInfo) error - // RemoveByAccess 使用访问令牌删除令牌信息 RemoveByAccess(access string) error diff --git a/storages/.gitkeep b/storages/.gitkeep new file mode 100644 index 0000000..e69de29 From 7c480eadd36108c9ebf761da6582a82fdc0eb296 Mon Sep 17 00:00:00 2001 From: lyric Date: Sat, 2 Jul 2016 11:26:43 +0800 Subject: [PATCH 07/18] Add generated token package --- generates/.gitkeep | 0 generates/access.go | 41 +++++++++++++++++++++++++++++++++++++ generates/access_test.go | 28 +++++++++++++++++++++++++ generates/authorize.go | 31 ++++++++++++++++++++++++++++ generates/authorize_test.go | 27 ++++++++++++++++++++++++ manage/manager.go | 12 +++++------ model.go | 16 +++++++-------- models/client.go | 4 ++-- 8 files changed, 143 insertions(+), 16 deletions(-) delete mode 100644 generates/.gitkeep create mode 100644 generates/access.go create mode 100644 generates/access_test.go create mode 100644 generates/authorize.go create mode 100644 generates/authorize_test.go diff --git a/generates/.gitkeep b/generates/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/generates/access.go b/generates/access.go new file mode 100644 index 0000000..2d7c666 --- /dev/null +++ b/generates/access.go @@ -0,0 +1,41 @@ +package generates + +import ( + "bytes" + "strconv" + "strings" + + "github.com/LyricTian/go.uuid" + "gopkg.in/LyricTian/lib.v2" + "gopkg.in/oauth2.v2" +) + +// NewAccessGenerate 创建访问令牌生成实例 +func NewAccessGenerate() *AccessGenerate { + return &AccessGenerate{} +} + +// AccessGenerate 访问令牌生成 +type AccessGenerate struct { +} + +// Token 生成令牌 +func (ag *AccessGenerate) Token(data *oauth2.GenerateBasic, isGenRefresh bool) (access, refresh string, err error) { + buf := bytes.NewBufferString(data.Client.GetID()) + buf.WriteString(data.UserID) + buf.WriteString(strconv.FormatInt(data.CreateAt.UnixNano(), 10)) + access, err = lib.NewEncryption(uuid.NewV3(uuid.NewV4(), buf.String()).Bytes()).MD5() + if err != nil { + return + } + access = strings.ToUpper(access) + if isGenRefresh { + refresh, err = lib.NewEncryption(uuid.NewV5(uuid.NewV4(), buf.String()).Bytes()).Sha1() + if err != nil { + return + } + refresh = strings.ToUpper(refresh) + } + + return +} diff --git a/generates/access_test.go b/generates/access_test.go new file mode 100644 index 0000000..169c027 --- /dev/null +++ b/generates/access_test.go @@ -0,0 +1,28 @@ +package generates + +import ( + "testing" + "time" + + . "github.com/smartystreets/goconvey/convey" + "gopkg.in/oauth2.v2" + "gopkg.in/oauth2.v2/models" +) + +func TestAccess(t *testing.T) { + Convey("Test Access Generate", t, func() { + data := &oauth2.GenerateBasic{ + Client: &models.Client{ + ClientID: "123456", + Secret: "123456", + }, + UserID: "000000", + CreateAt: time.Now(), + } + gen := NewAccessGenerate() + access, refresh, err := gen.Token(data, true) + So(err, ShouldBeNil) + Println("\nAccess:", access) + Println("Refresh:", refresh) + }) +} diff --git a/generates/authorize.go b/generates/authorize.go new file mode 100644 index 0000000..32fc50f --- /dev/null +++ b/generates/authorize.go @@ -0,0 +1,31 @@ +package generates + +import ( + "bytes" + "strings" + + "github.com/LyricTian/go.uuid" + "gopkg.in/LyricTian/lib.v2" + "gopkg.in/oauth2.v2" +) + +// NewAuthorizeGenerate 创建授权令牌生成实例 +func NewAuthorizeGenerate() *AuthorizeGenerate { + return &AuthorizeGenerate{} +} + +// AuthorizeGenerate 授权令牌生成 +type AuthorizeGenerate struct{} + +// Token 生成令牌 +func (ag *AuthorizeGenerate) Token(data *oauth2.GenerateBasic) (code string, err error) { + buf := bytes.NewBuffer(uuid.NewV1().Bytes()) + buf.WriteString(data.UserID) + buf.WriteString(data.Client.GetID()) + code, err = lib.NewEncryption(buf.Bytes()).MD5() + if err != nil { + return + } + code = strings.ToUpper(code) + return +} diff --git a/generates/authorize_test.go b/generates/authorize_test.go new file mode 100644 index 0000000..49abc81 --- /dev/null +++ b/generates/authorize_test.go @@ -0,0 +1,27 @@ +package generates + +import ( + "testing" + "time" + + . "github.com/smartystreets/goconvey/convey" + "gopkg.in/oauth2.v2" + "gopkg.in/oauth2.v2/models" +) + +func TestAuthorize(t *testing.T) { + Convey("Test Authorize Generate", t, func() { + data := &oauth2.GenerateBasic{ + Client: &models.Client{ + ClientID: "123456", + Secret: "123456", + }, + UserID: "000000", + CreateAt: time.Now(), + } + gen := NewAuthorizeGenerate() + code, err := gen.Token(data) + So(err, ShouldBeNil) + Println("\nCode:", code) + }) +} diff --git a/manage/manager.go b/manage/manager.go index 0434e7f..e98d449 100644 --- a/manage/manager.go +++ b/manage/manager.go @@ -19,14 +19,15 @@ func NewManager() *Manager { injector: inject.New(), } // 设定参数默认值 - // 设定授权码的有效期为10分钟 m.SetRTConfig(oauth2.Code, &Config{TokenExp: time.Minute * 10}) // 设定简化模式授权令牌的有效期为1小时 m.SetRTConfig(oauth2.Token, &Config{TokenExp: time.Hour * 1}) - // 设定授权码模式令牌的有效期为2小时,g令牌的有效期为3天 - m.SetGTConfig(oauth2.PasswordCredentials, &Config{TokenExp: time.Hour * 2, RefreshExp: time.Hour * 24 * 3}) + // 设定授权码模式令牌的有效期为2小时,更新令牌的有效期为3天 + m.SetGTConfig(oauth2.AuthorizationCodeCredentials, &Config{TokenExp: time.Hour * 2, RefreshExp: time.Hour * 24 * 3}) + // 设定密码模式令牌的有效期为2小时,更新令牌的有效期为7天 + m.SetGTConfig(oauth2.PasswordCredentials, &Config{TokenExp: time.Hour * 2, RefreshExp: time.Hour * 24 * 7}) // 设定客户端模式令牌的有效期为1小时 m.SetGTConfig(oauth2.ClientCredentials, &Config{TokenExp: time.Hour * 2}) return m @@ -93,7 +94,7 @@ func (m *Manager) MapClientStorage(stor oauth2.ClientStorage) { m.injector.Map(stor) } -// MustClientStorage 注入客户端信息存储接口 +// MustClientStorage 强制注入客户端信息存储接口 func (m *Manager) MustClientStorage(stor oauth2.ClientStorage, err error) { if err != nil { panic(err) @@ -112,7 +113,7 @@ func (m *Manager) MapTokenStorage(stor oauth2.TokenStorage) { m.injector.Map(stor) } -// MustTokenStorage 注入令牌信息存储接口 +// MustTokenStorage 强制注入令牌信息存储接口 func (m *Manager) MustTokenStorage(stor oauth2.TokenStorage, err error) { if err != nil { panic(err) @@ -257,7 +258,6 @@ func (m *Manager) RefreshAccessToken(refresh, scope string) (token string, err e err = terr return } - ti.SetAuthType(oauth2.RefreshCredentials.String()) ti.SetAccess(tv) ti.SetAccessCreateAt(td.CreateAt) if scope != "" { diff --git a/model.go b/model.go index 2d9cf62..822fb55 100644 --- a/model.go +++ b/model.go @@ -12,8 +12,8 @@ type ( GetSecret() string // 客户端域名URL GetDomain() string - // 预留数据 - GetRetainData() interface{} + // Other data + GetOtherData() interface{} } // TokenInfo 令牌信息模型接口 @@ -39,17 +39,17 @@ type ( // 设置令牌授权类型 SetAuthType(string) - // 访问令牌 + // 访问令牌(或授权令牌) GetAccess() string - // 设置访问令牌 + // 设置访问令牌(或授权令牌) SetAccess(string) - // 访问令牌创建时间 + // 访问令牌(或授权令牌)创建时间 GetAccessCreateAt() time.Time - // 设置访问令牌创建时间 + // 设置访问令牌(或授权令牌)创建时间 SetAccessCreateAt(time.Time) - // 访问令牌有效期 + // 访问令牌(或授权令牌)有效期 GetAccessExpiresIn() time.Duration - // 设置访问令牌有效期 + // 设置访问令牌(或授权令牌)有效期 SetAccessExpiresIn(time.Duration) // 更新令牌 diff --git a/models/client.go b/models/client.go index c93813e..1dc9742 100644 --- a/models/client.go +++ b/models/client.go @@ -22,7 +22,7 @@ func (c *Client) GetDomain() string { return c.Domain } -// GetRetainData 预留数据 -func (c *Client) GetRetainData() interface{} { +// GetOtherData Other data +func (c *Client) GetOtherData() interface{} { return nil } From f23fea57cf719ec6bf51ee286e5abd1ba7c043a3 Mon Sep 17 00:00:00 2001 From: lyric Date: Sun, 3 Jul 2016 17:23:48 +0800 Subject: [PATCH 08/18] Add redis token store --- generates/access_test.go | 4 +- generates/authorize_test.go | 4 +- manage/manager.go | 29 ++++---- model.go | 4 +- models/client.go | 12 ++-- storages/.gitkeep | 0 storage.go => store.go | 11 ++-- store/client/temp.go | 36 ++++++++++ store/token/redis.go | 120 +++++++++++++++++++++++++++++++++ store/token/redis_config.go | 41 ++++++++++++ store/token/redis_test.go | 128 ++++++++++++++++++++++++++++++++++++ 11 files changed, 358 insertions(+), 31 deletions(-) delete mode 100644 storages/.gitkeep rename storage.go => store.go (70%) create mode 100644 store/client/temp.go create mode 100644 store/token/redis.go create mode 100644 store/token/redis_config.go create mode 100644 store/token/redis_test.go diff --git a/generates/access_test.go b/generates/access_test.go index 169c027..3a78542 100644 --- a/generates/access_test.go +++ b/generates/access_test.go @@ -13,8 +13,8 @@ func TestAccess(t *testing.T) { Convey("Test Access Generate", t, func() { data := &oauth2.GenerateBasic{ Client: &models.Client{ - ClientID: "123456", - Secret: "123456", + ID: "123456", + Secret: "123456", }, UserID: "000000", CreateAt: time.Now(), diff --git a/generates/authorize_test.go b/generates/authorize_test.go index 49abc81..6439b72 100644 --- a/generates/authorize_test.go +++ b/generates/authorize_test.go @@ -13,8 +13,8 @@ func TestAuthorize(t *testing.T) { Convey("Test Authorize Generate", t, func() { data := &oauth2.GenerateBasic{ Client: &models.Client{ - ClientID: "123456", - Secret: "123456", + ID: "123456", + Secret: "123456", }, UserID: "000000", CreateAt: time.Now(), diff --git a/manage/manager.go b/manage/manager.go index e98d449..7de01d6 100644 --- a/manage/manager.go +++ b/manage/manager.go @@ -87,7 +87,7 @@ func (m *Manager) MapTokenGenerate(gen oauth2.AccessGenerate) { } // MapClientStorage 注入客户端信息存储接口 -func (m *Manager) MapClientStorage(stor oauth2.ClientStorage) { +func (m *Manager) MapClientStorage(stor oauth2.ClientStore) { if stor == nil { panic(ErrNilValue) } @@ -95,7 +95,7 @@ func (m *Manager) MapClientStorage(stor oauth2.ClientStorage) { } // MustClientStorage 强制注入客户端信息存储接口 -func (m *Manager) MustClientStorage(stor oauth2.ClientStorage, err error) { +func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) { if err != nil { panic(err) } @@ -106,7 +106,7 @@ func (m *Manager) MustClientStorage(stor oauth2.ClientStorage, err error) { } // MapTokenStorage 注入令牌信息存储接口 -func (m *Manager) MapTokenStorage(stor oauth2.TokenStorage) { +func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) { if stor == nil { panic(ErrNilValue) } @@ -114,7 +114,7 @@ func (m *Manager) MapTokenStorage(stor oauth2.TokenStorage) { } // MustTokenStorage 强制注入令牌信息存储接口 -func (m *Manager) MustTokenStorage(stor oauth2.TokenStorage, err error) { +func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) { if err != nil { panic(err) } @@ -126,7 +126,7 @@ func (m *Manager) MustTokenStorage(stor oauth2.TokenStorage, err error) { // GetClient 获取客户端信息 func (m *Manager) GetClient(clientID string) (cli oauth2.ClientInfo, err error) { - err = m.injector.Apply(func(stor oauth2.ClientStorage) { + err = m.injector.Apply(func(stor oauth2.ClientStore) { cli, err = stor.GetByID(clientID) if err != nil { return @@ -148,7 +148,7 @@ func (m *Manager) GenerateAuthToken(rt oauth2.ResponseType, tgr *oauth2.TokenGen err = verr return } - _, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.AuthorizeGenerate, stor oauth2.TokenStorage) { + _, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.AuthorizeGenerate, stor oauth2.TokenStore) { td := &oauth2.GenerateBasic{ Client: cli, UserID: tgr.UserID, @@ -191,7 +191,12 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene } else if ti.GetRedirectURI() != tgr.RedirectURI || ti.GetClientID() != tgr.ClientID { err = ErrAuthTokenInvalid return + } else if verr := m.RemoveAccessToken(tgr.Code); verr != nil { // 删除授权码 + err = verr + return } + tgr.UserID = ti.GetUserID() + tgr.Scope = ti.GetScope() } cli, err := m.GetClient(tgr.ClientID) if err != nil { @@ -200,7 +205,7 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene err = ErrClientInvalid return } - _, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.AccessGenerate, stor oauth2.TokenStorage) { + _, ierr := m.injector.Invoke(func(ti oauth2.TokenInfo, gen oauth2.AccessGenerate, stor oauth2.TokenStore) { td := &oauth2.GenerateBasic{ Client: cli, UserID: tgr.UserID, @@ -242,7 +247,7 @@ func (m *Manager) RefreshAccessToken(refresh, scope string) (token string, err e if err != nil { return } - _, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage, gen oauth2.AccessGenerate) { + _, ierr := m.injector.Invoke(func(stor oauth2.TokenStore, gen oauth2.AccessGenerate) { cli, cerr := m.GetClient(ti.GetClientID()) if cerr != nil { err = cerr @@ -285,7 +290,7 @@ func (m *Manager) RemoveAccessToken(access string) (err error) { err = ErrAccessInvalid return } - _, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage) { + _, ierr := m.injector.Invoke(func(stor oauth2.TokenStore) { err = stor.RemoveByAccess(access) }) if ierr != nil && err == nil { @@ -300,7 +305,7 @@ func (m *Manager) RemoveRefreshToken(refresh string) (err error) { err = ErrAccessInvalid return } - _, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage) { + _, ierr := m.injector.Invoke(func(stor oauth2.TokenStore) { err = stor.RemoveByRefresh(refresh) }) if ierr != nil && err == nil { @@ -315,7 +320,7 @@ func (m *Manager) LoadAccessToken(access string) (info oauth2.TokenInfo, err err err = ErrAccessInvalid return } - _, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage) { + _, ierr := m.injector.Invoke(func(stor oauth2.TokenStore) { ct := time.Now() ti, terr := stor.GetByAccess(access) if terr != nil { @@ -355,7 +360,7 @@ func (m *Manager) LoadRefreshToken(refresh string) (info oauth2.TokenInfo, err e err = ErrRefreshInvalid return } - _, ierr := m.injector.Invoke(func(stor oauth2.TokenStorage) { + _, ierr := m.injector.Invoke(func(stor oauth2.TokenStore) { ti, terr := stor.GetByRefresh(refresh) if terr != nil { err = terr diff --git a/model.go b/model.go index 822fb55..e7c7fde 100644 --- a/model.go +++ b/model.go @@ -12,8 +12,8 @@ type ( GetSecret() string // 客户端域名URL GetDomain() string - // Other data - GetOtherData() interface{} + // 用户数据 + GetUserData() interface{} } // TokenInfo 令牌信息模型接口 diff --git a/models/client.go b/models/client.go index 1dc9742..16e519b 100644 --- a/models/client.go +++ b/models/client.go @@ -2,14 +2,14 @@ package models // Client 客户端信息 type Client struct { - ClientID string // 客户端ID - Secret string // 密钥 - Domain string // 域名url + ID string // 客户端ID + Secret string // 密钥 + Domain string // 域名url } // GetID 客户端ID func (c *Client) GetID() string { - return c.ClientID + return c.ID } // GetSecret 客户端秘钥 @@ -22,7 +22,7 @@ func (c *Client) GetDomain() string { return c.Domain } -// GetOtherData Other data -func (c *Client) GetOtherData() interface{} { +// GetUserData 用户数据 +func (c *Client) GetUserData() interface{} { return nil } diff --git a/storages/.gitkeep b/storages/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/storage.go b/store.go similarity index 70% rename from storage.go rename to store.go index 37f6b28..c2ac40d 100644 --- a/storage.go +++ b/store.go @@ -2,14 +2,14 @@ package oauth2 // 提供存储接口 type ( - // ClientStorage 客户端信息存储接口 - ClientStorage interface { + // ClientStore 客户端信息存储接口 + ClientStore interface { // GetByID 根据ID获取客户端信息 GetByID(id string) (ClientInfo, error) } - // TokenStorage 令牌信息存储接口 - TokenStorage interface { + // TokenStore 令牌信息存储接口 + TokenStore interface { // Create 创建并存储新的令牌信息 Create(info TokenInfo) error @@ -19,9 +19,6 @@ type ( // RemoveByRefresh 使用更新令牌删除令牌信息 RemoveByRefresh(refresh string) error - // 使用访问令牌取出令牌信息数据(获取并删除) - TakeByAccess(access string) (TokenInfo, error) - // 使用访问令牌获取令牌信息数据 GetByAccess(access string) (TokenInfo, error) diff --git a/store/client/temp.go b/store/client/temp.go new file mode 100644 index 0000000..07a00cf --- /dev/null +++ b/store/client/temp.go @@ -0,0 +1,36 @@ +package client + +import ( + "errors" + + "gopkg.in/oauth2.v2" + "gopkg.in/oauth2.v2/models" +) + +// NewTempStore 创建客户端临时存储实例 +func NewTempStore() oauth2.ClientStore { + return &TempStore{ + data: map[string]*models.Client{ + "1": &models.Client{ + ID: "1", + Secret: "11", + Domain: "http://localhost", + }, + }, + } +} + +// TempStore 客户端信息的临时存储 +type TempStore struct { + data map[string]*models.Client +} + +// GetByID 获取客户端信息 +func (ts *TempStore) GetByID(id string) (cli oauth2.ClientInfo, err error) { + if c, ok := ts.data[id]; ok { + cli = c + return + } + err = errors.New("not found") + return +} diff --git a/store/token/redis.go b/store/token/redis.go new file mode 100644 index 0000000..ece515b --- /dev/null +++ b/store/token/redis.go @@ -0,0 +1,120 @@ +package token + +import ( + "encoding/json" + + "gopkg.in/oauth2.v2" + "gopkg.in/oauth2.v2/models" + "gopkg.in/redis.v4" +) + +// NewRedisStore 创建redis存储的实例 +func NewRedisStore(cfg *RedisConfig) (store oauth2.TokenStore, err error) { + opt := &redis.Options{ + Network: cfg.Network, + Addr: cfg.Addr, + Password: cfg.Password, + DB: cfg.DB, + MaxRetries: cfg.MaxRetries, + DialTimeout: cfg.DialTimeout, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + PoolSize: cfg.PoolSize, + PoolTimeout: cfg.PoolTimeout, + } + cli := redis.NewClient(opt) + if verr := cli.Ping().Err(); verr != nil { + err = verr + return + } + store = &RedisStore{cli: cli} + return +} + +// RedisStore 令牌的redis存储 +type RedisStore struct { + cli *redis.Client +} + +// Create 存储令牌信息 +func (rs *RedisStore) Create(info oauth2.TokenInfo) (err error) { + jv, err := json.Marshal(info) + if err != nil { + return + } + pipe := rs.cli.Pipeline() + + aexp := info.GetAccessExpiresIn() + if refresh := info.GetRefresh(); refresh != "" { + exp := info.GetRefreshExpiresIn() + ttl := rs.cli.TTL(refresh) + if verr := ttl.Err(); verr != nil { + err = verr + return + } + if v := ttl.Val(); v.Seconds() > 0 { + exp = v + } + if aexp.Seconds() > exp.Seconds() { + aexp = exp + } + pipe.Set(refresh, jv, exp) + } + pipe.Set(info.GetAccess(), jv, aexp) + + if _, verr := pipe.Exec(); verr != nil { + err = verr + } + return +} + +// remove +func (rs *RedisStore) remove(key string) (err error) { + del := rs.cli.Del(key) + if verr := del.Err(); verr != nil { + err = verr + } + return +} + +// RemoveByAccess 移除令牌 +func (rs *RedisStore) RemoveByAccess(access string) (err error) { + err = rs.remove(access) + return +} + +// RemoveByRefresh 移除令牌 +func (rs *RedisStore) RemoveByRefresh(refresh string) (err error) { + err = rs.remove(refresh) + return +} + +func (rs *RedisStore) get(key string) (ti oauth2.TokenInfo, err error) { + gv, gerr := rs.cli.Get(key).Result() + if gerr != nil { + if gerr == redis.Nil { + return + } + err = gerr + return + } + var tm models.Token + if verr := json.Unmarshal([]byte(gv), &tm); verr != nil { + err = verr + return + } + ti = &tm + return +} + +// GetByAccess 获取令牌数据 +func (rs *RedisStore) GetByAccess(access string) (ti oauth2.TokenInfo, err error) { + ti, err = rs.get(access) + return +} + +// GetByRefresh 获取令牌数据 +func (rs *RedisStore) GetByRefresh(refresh string) (ti oauth2.TokenInfo, err error) { + ti, err = rs.get(refresh) + return +} diff --git a/store/token/redis_config.go b/store/token/redis_config.go new file mode 100644 index 0000000..c793498 --- /dev/null +++ b/store/token/redis_config.go @@ -0,0 +1,41 @@ +package token + +import "time" + +// RedisConfig Redis配置参数 +type RedisConfig struct { + // The network type, either tcp or unix. + // Default is tcp. + Network string + // host:port address. + Addr string + + // An optional password. Must match the password specified in the + // requirepass server configuration option. + Password string + // A database to be selected after connecting to server. + DB int + + // The maximum number of retries before giving up. + // Default is to not retry failed commands. + MaxRetries int + + // Sets the deadline for establishing new connections. If reached, + // dial will fail with a timeout. + // Default is 5 seconds. + DialTimeout time.Duration + // Sets the deadline for socket reads. If reached, commands will + // fail with a timeout instead of blocking. + ReadTimeout time.Duration + // Sets the deadline for socket writes. If reached, commands will + // fail with a timeout instead of blocking. + WriteTimeout time.Duration + + // The maximum number of socket connections. + // Default is 10 connections. + PoolSize int + // Specifies amount of time client waits for connection if all + // connections are busy before returning an error. + // Default is 1 second. + PoolTimeout time.Duration +} diff --git a/store/token/redis_test.go b/store/token/redis_test.go new file mode 100644 index 0000000..151e258 --- /dev/null +++ b/store/token/redis_test.go @@ -0,0 +1,128 @@ +package token + +import ( + "testing" + "time" + + . "github.com/smartystreets/goconvey/convey" + "gopkg.in/oauth2.v2" + "gopkg.in/oauth2.v2/models" +) + +func TestRedisStore(t *testing.T) { + Convey("Test redis store", t, func() { + cfg := &RedisConfig{ + Addr: "192.168.33.70:6379", + } + store, err := NewRedisStore(cfg) + So(err, ShouldBeNil) + + info := &models.Token{ + ClientID: "1", + UserID: "1_1", + RedirectURI: "http://localhost/", + Scope: "all", + AuthType: oauth2.Code.String(), + Access: "1_1_1", + AccessCreateAt: time.Now(), + AccessExpiresIn: time.Second * 10, + Refresh: "1_1_2", + RefreshCreateAt: time.Now(), + RefreshExpiresIn: time.Minute * 1, + } + err = store.Create(info) + So(err, ShouldBeNil) + + ainfo, err := store.GetByAccess(info.GetAccess()) + So(err, ShouldBeNil) + So(ainfo.GetRefresh(), ShouldEqual, info.GetRefresh()) + + err = store.RemoveByAccess(info.GetAccess()) + So(err, ShouldBeNil) + + ainfo, err = store.GetByAccess(info.GetAccess()) + So(err, ShouldBeNil) + So(ainfo, ShouldBeNil) + + rinfo, err := store.GetByRefresh(info.GetRefresh()) + So(err, ShouldBeNil) + So(rinfo.GetAccess(), ShouldEqual, info.GetAccess()) + + err = store.RemoveByRefresh(info.GetRefresh()) + So(err, ShouldBeNil) + + rinfo, err = store.GetByRefresh(info.GetRefresh()) + So(err, ShouldBeNil) + So(rinfo, ShouldBeNil) + }) +} + +func TestRedisStoreAccessExpired(t *testing.T) { + Convey("Test redis store access token expired", t, func() { + cfg := &RedisConfig{ + Addr: "192.168.33.70:6379", + } + store, err := NewRedisStore(cfg) + So(err, ShouldBeNil) + info := &models.Token{ + ClientID: "1", + UserID: "1_2", + RedirectURI: "http://localhost/", + Scope: "all", + AuthType: oauth2.Code.String(), + Access: "1_2_1", + AccessCreateAt: time.Now(), + AccessExpiresIn: time.Second * 1, + Refresh: "1_2_2", + RefreshCreateAt: time.Now(), + RefreshExpiresIn: time.Second * 5, + } + err = store.Create(info) + So(err, ShouldBeNil) + + time.Sleep(time.Second * 1) + + ainfo, err := store.GetByAccess(info.GetAccess()) + So(err, ShouldBeNil) + So(ainfo, ShouldBeNil) + + rinfo, err := store.GetByRefresh(info.GetRefresh()) + So(err, ShouldBeNil) + So(rinfo, ShouldNotBeNil) + }) +} + +func TestRedisStoreRefreshExpired(t *testing.T) { + Convey("Test redis store refresh token expired", t, func() { + cfg := &RedisConfig{ + Addr: "192.168.33.70:6379", + } + store, err := NewRedisStore(cfg) + So(err, ShouldBeNil) + info := &models.Token{ + ClientID: "1", + UserID: "1_3", + RedirectURI: "http://localhost/", + Scope: "all", + AuthType: oauth2.Code.String(), + Access: "1_3_1", + AccessCreateAt: time.Now(), + AccessExpiresIn: time.Second * 2, + Refresh: "1_3_2", + RefreshCreateAt: time.Now(), + RefreshExpiresIn: time.Second * 1, + } + err = store.Create(info) + So(err, ShouldBeNil) + + time.Sleep(time.Second * 1) + + ainfo, err := store.GetByAccess(info.GetAccess()) + So(err, ShouldBeNil) + So(ainfo, ShouldBeNil) + + rinfo, err := store.GetByRefresh(info.GetRefresh()) + So(err, ShouldBeNil) + So(rinfo, ShouldBeNil) + }) +} From b5f24be4666c5ec066b67bdd222ed9e1839767fc Mon Sep 17 00:00:00 2001 From: lyric Date: Sun, 3 Jul 2016 23:06:33 +0800 Subject: [PATCH 09/18] Add mongodb token store --- manage/manager.go | 19 +---- models/token.go | 22 +++--- store/token/mongo.go | 148 ++++++++++++++++++++++++++++++++++++ store/token/mongo_test.go | 51 +++++++++++++ store/token/redis_config.go | 2 +- store/token/redis_test.go | 97 ++--------------------- store/token/token_test.go | 113 +++++++++++++++++++++++++++ 7 files changed, 334 insertions(+), 118 deletions(-) create mode 100644 store/token/mongo.go create mode 100644 store/token/mongo_test.go create mode 100644 store/token/token_test.go diff --git a/manage/manager.go b/manage/manager.go index 7de01d6..71e7b51 100644 --- a/manage/manager.go +++ b/manage/manager.go @@ -247,6 +247,7 @@ func (m *Manager) RefreshAccessToken(refresh, scope string) (token string, err e if err != nil { return } + access := ti.GetAccess() _, ierr := m.injector.Invoke(func(stor oauth2.TokenStore, gen oauth2.AccessGenerate) { cli, cerr := m.GetClient(ti.GetClientID()) if cerr != nil { @@ -272,7 +273,7 @@ func (m *Manager) RefreshAccessToken(refresh, scope string) (token string, err e err = verr return } - if verr := stor.RemoveByRefresh(refresh); verr != nil { + if verr := stor.RemoveByAccess(access); verr != nil { err = verr return } @@ -330,19 +331,8 @@ func (m *Manager) LoadAccessToken(access string) (info oauth2.TokenInfo, err err err = ErrAccessInvalid return } else if ti.GetRefresh() != "" && ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { // 检查更新令牌是否过期 - // 删除过期的访问令牌 - if verr := stor.RemoveByRefresh(ti.GetRefresh()); verr != nil { - err = verr - return - } err = ErrRefreshExpired } else if ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { // 检查访问令牌是否过期 - if ti.GetRefresh() == "" { // 删除过期的访问令牌 - if verr := stor.RemoveByAccess(access); verr != nil { - err = verr - return - } - } err = ErrAccessExpired return } @@ -369,11 +359,6 @@ func (m *Manager) LoadRefreshToken(refresh string) (info oauth2.TokenInfo, err e err = ErrRefreshInvalid return } else if ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) { - // 删除过期的更新令牌 - if verr := stor.RemoveByRefresh(refresh); verr != nil { - err = verr - return - } err = ErrRefreshExpired return } diff --git a/models/token.go b/models/token.go index 3121fd8..ede71af 100644 --- a/models/token.go +++ b/models/token.go @@ -4,17 +4,17 @@ import "time" // Token 令牌信息 type Token struct { - ClientID string // 客户端标识 - UserID string // 用户标识 - RedirectURI string // 重定向URI - Scope string // 权限范围 - AuthType string // 令牌授权类型 - Access string // 访问令牌 - AccessCreateAt time.Time // 访问令牌创建时间 - AccessExpiresIn time.Duration // 访问令牌有效期 - Refresh string // 更新令牌 - RefreshCreateAt time.Time // 更新令牌创建时间 - RefreshExpiresIn time.Duration // 更新令牌有效期 + ClientID string `bson:"ClientID"` // 客户端标识 + UserID string `bson:"UserID"` // 用户标识 + RedirectURI string `bson:"RedirectURI"` // 重定向URI + Scope string `bson:"Scope"` // 权限范围 + AuthType string `bson:"AuthType"` // 令牌授权类型 + Access string `bson:"Access"` // 访问令牌 + AccessCreateAt time.Time `bson:"AccessCreateAt"` // 访问令牌创建时间 + AccessExpiresIn time.Duration `bson:"AccessExpiresIn"` // 访问令牌有效期 + Refresh string `bson:"Refresh"` // 更新令牌 + RefreshCreateAt time.Time `bson:"RefreshCreateAt"` // 更新令牌创建时间 + RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"` // 更新令牌有效期 } // GetClientID 客户端ID diff --git a/store/token/mongo.go b/store/token/mongo.go new file mode 100644 index 0000000..84520cd --- /dev/null +++ b/store/token/mongo.go @@ -0,0 +1,148 @@ +package token + +import ( + "time" + + "gopkg.in/LyricTian/lib.v2/mongo" + "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" + "gopkg.in/oauth2.v2" + "gopkg.in/oauth2.v2/models" +) + +// MongoConfig MongoDB Configuration +type MongoConfig struct { + // Connection String + URL string + // DB Name(default oauth2) + DB string + // Collection Name(default tokens) + C string +} + +// NewMongoStore 创建MongoDB的令牌存储 +func NewMongoStore(cfg *MongoConfig) (store oauth2.TokenStore, err error) { + if cfg.DB == "" { + cfg.DB = "oauth2" + } + if cfg.C == "" { + cfg.C = "tokens" + } + handler, err := mongo.InitHandlerWithDB(cfg.URL, cfg.DB) + if err != nil { + return + } + // 创建自动过期索引 + err = handler.C(cfg.C).EnsureIndex(mgo.Index{ + Key: []string{"ExpiredAt"}, + ExpireAfter: time.Second, + }) + if err != nil { + return + } + err = handler.C(cfg.C).EnsureIndexKey("Access") + if err != nil { + return + } + err = handler.C(cfg.C).EnsureIndexKey("Refresh") + if err != nil { + return + } + store = &MongoStore{ + handler: handler, + cfg: cfg, + } + return +} + +// MongoStore MongoDB Store +type MongoStore struct { + cfg *MongoConfig + handler *mongo.Handler +} + +// Create 存储令牌信息 +func (ms *MongoStore) Create(info oauth2.TokenInfo) (err error) { + tm := info.(*models.Token) + var expiredAt time.Time + if refresh := tm.Refresh; refresh != "" { + expiredAt = tm.RefreshCreateAt.Add(tm.RefreshExpiresIn) + rinfo, rerr := ms.GetByRefresh(refresh) + if rerr != nil { + err = rerr + return + } + if rinfo != nil { + expiredAt = rinfo.GetRefreshCreateAt().Add(rinfo.GetRefreshExpiresIn()) + } + } + if expiredAt.IsZero() { + expiredAt = tm.AccessCreateAt.Add(tm.AccessExpiresIn) + } + doc := map[string]interface{}{ + "ExpiredAt": expiredAt, + "ClientID": tm.ClientID, + "UserID": tm.UserID, + "RedirectURI": tm.RedirectURI, + "Scope": tm.Scope, + "AuthType": tm.AuthType, + "Access": tm.Access, + "AccessCreateAt": tm.AccessCreateAt, + "AccessExpiresIn": tm.AccessExpiresIn, + "Refresh": tm.Refresh, + "RefreshCreateAt": tm.RefreshCreateAt, + "RefreshExpiresIn": tm.RefreshExpiresIn, + } + + ms.handler.CHandle(ms.cfg.C, func(c *mgo.Collection) { + err = c.Insert(doc) + }) + return +} + +func (ms *MongoStore) remove(selector interface{}) (err error) { + ms.handler.CHandle(ms.cfg.C, func(c *mgo.Collection) { + err = c.Remove(selector) + }) + return +} + +// RemoveByAccess 移除令牌 +func (ms *MongoStore) RemoveByAccess(access string) (err error) { + err = ms.remove(bson.M{"Access": access}) + return +} + +// RemoveByRefresh 移除令牌 +func (ms *MongoStore) RemoveByRefresh(refresh string) (err error) { + err = ms.remove(bson.M{"Refresh": refresh}) + return +} + +func (ms *MongoStore) get(find interface{}) (info oauth2.TokenInfo, err error) { + ms.handler.CHandle(ms.cfg.C, func(c *mgo.Collection) { + var tm models.Token + aerr := c.Find(find).Select(bson.M{"_id": 0}).One(&tm) + if aerr != nil { + if aerr == mgo.ErrNotFound { + return + } + err = aerr + return + } + info = &tm + }) + return +} + +// GetByAccess 获取令牌数据 +func (ms *MongoStore) GetByAccess(access string) (info oauth2.TokenInfo, err error) { + info, err = ms.get(bson.M{"Access": access}) + return +} + +// GetByRefresh 获取令牌数据 +func (ms *MongoStore) GetByRefresh(refresh string) (info oauth2.TokenInfo, err error) { + info, err = ms.get(bson.M{"Refresh": refresh}) + return +} diff --git a/store/token/mongo_test.go b/store/token/mongo_test.go new file mode 100644 index 0000000..3a6db26 --- /dev/null +++ b/store/token/mongo_test.go @@ -0,0 +1,51 @@ +package token + +import ( + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +const ( + mongoURL = "mongodb://admin:123456@192.168.33.70:27017" +) + +func TestMongoStore(t *testing.T) { + Convey("Test mongo store", t, func() { + cfg := &MongoConfig{ + URL: mongoURL, + } + store, err := NewMongoStore(cfg) + So(err, ShouldBeNil) + + Convey("Test mongo store access", func() { + testAccessStore(store) + }) + + Convey("Test mongo store refresh", func() { + testRefreshStore(store) + }) + }) +} + +func TestMongoStoreAccessExpired(t *testing.T) { + Convey("Test mongo store access token expired", t, func() { + cfg := &MongoConfig{ + URL: mongoURL, + } + store, err := NewMongoStore(cfg) + So(err, ShouldBeNil) + testAccessExpired(store) + }) +} + +func TestMongoStoreRefreshExpired(t *testing.T) { + Convey("Test mongo store refresh token expired", t, func() { + cfg := &MongoConfig{ + URL: mongoURL, + } + store, err := NewMongoStore(cfg) + So(err, ShouldBeNil) + testRefreshExpired(store) + }) +} diff --git a/store/token/redis_config.go b/store/token/redis_config.go index c793498..625e99d 100644 --- a/store/token/redis_config.go +++ b/store/token/redis_config.go @@ -2,7 +2,7 @@ package token import "time" -// RedisConfig Redis配置参数 +// RedisConfig Redis Configuration type RedisConfig struct { // The network type, either tcp or unix. // Default is tcp. diff --git a/store/token/redis_test.go b/store/token/redis_test.go index 151e258..57dc0c2 100644 --- a/store/token/redis_test.go +++ b/store/token/redis_test.go @@ -2,11 +2,8 @@ package token import ( "testing" - "time" . "github.com/smartystreets/goconvey/convey" - "gopkg.in/oauth2.v2" - "gopkg.in/oauth2.v2/models" ) func TestRedisStore(t *testing.T) { @@ -17,43 +14,13 @@ func TestRedisStore(t *testing.T) { store, err := NewRedisStore(cfg) So(err, ShouldBeNil) - info := &models.Token{ - ClientID: "1", - UserID: "1_1", - RedirectURI: "http://localhost/", - Scope: "all", - AuthType: oauth2.Code.String(), - Access: "1_1_1", - AccessCreateAt: time.Now(), - AccessExpiresIn: time.Second * 10, - Refresh: "1_1_2", - RefreshCreateAt: time.Now(), - RefreshExpiresIn: time.Minute * 1, - } - err = store.Create(info) - So(err, ShouldBeNil) - - ainfo, err := store.GetByAccess(info.GetAccess()) - So(err, ShouldBeNil) - So(ainfo.GetRefresh(), ShouldEqual, info.GetRefresh()) - - err = store.RemoveByAccess(info.GetAccess()) - So(err, ShouldBeNil) - - ainfo, err = store.GetByAccess(info.GetAccess()) - So(err, ShouldBeNil) - So(ainfo, ShouldBeNil) - - rinfo, err := store.GetByRefresh(info.GetRefresh()) - So(err, ShouldBeNil) - So(rinfo.GetAccess(), ShouldEqual, info.GetAccess()) + Convey("Test redis store access", func() { + testAccessStore(store) + }) - err = store.RemoveByRefresh(info.GetRefresh()) - So(err, ShouldBeNil) - - rinfo, err = store.GetByRefresh(info.GetRefresh()) - So(err, ShouldBeNil) - So(rinfo, ShouldBeNil) + Convey("Test redis store refresh", func() { + testRefreshStore(store) + }) }) } @@ -64,31 +31,7 @@ func TestRedisStoreAccessExpired(t *testing.T) { } store, err := NewRedisStore(cfg) So(err, ShouldBeNil) - info := &models.Token{ - ClientID: "1", - UserID: "1_2", - RedirectURI: "http://localhost/", - Scope: "all", - AuthType: oauth2.Code.String(), - Access: "1_2_1", - AccessCreateAt: time.Now(), - AccessExpiresIn: time.Second * 1, - Refresh: "1_2_2", - RefreshCreateAt: time.Now(), - RefreshExpiresIn: time.Second * 5, - } - err = store.Create(info) - So(err, ShouldBeNil) - - time.Sleep(time.Second * 1) - - ainfo, err := store.GetByAccess(info.GetAccess()) - So(err, ShouldBeNil) - So(ainfo, ShouldBeNil) - - rinfo, err := store.GetByRefresh(info.GetRefresh()) - So(err, ShouldBeNil) - So(rinfo, ShouldNotBeNil) + testAccessExpired(store) }) } @@ -99,30 +42,6 @@ func TestRedisStoreRefreshExpired(t *testing.T) { } store, err := NewRedisStore(cfg) So(err, ShouldBeNil) - info := &models.Token{ - ClientID: "1", - UserID: "1_3", - RedirectURI: "http://localhost/", - Scope: "all", - AuthType: oauth2.Code.String(), - Access: "1_3_1", - AccessCreateAt: time.Now(), - AccessExpiresIn: time.Second * 2, - Refresh: "1_3_2", - RefreshCreateAt: time.Now(), - RefreshExpiresIn: time.Second * 1, - } - err = store.Create(info) - So(err, ShouldBeNil) - - time.Sleep(time.Second * 1) - - ainfo, err := store.GetByAccess(info.GetAccess()) - So(err, ShouldBeNil) - So(ainfo, ShouldBeNil) - - rinfo, err := store.GetByRefresh(info.GetRefresh()) - So(err, ShouldBeNil) - So(rinfo, ShouldBeNil) + testRefreshExpired(store) }) } diff --git a/store/token/token_test.go b/store/token/token_test.go new file mode 100644 index 0000000..13badaa --- /dev/null +++ b/store/token/token_test.go @@ -0,0 +1,113 @@ +package token + +import ( + "time" + + . "github.com/smartystreets/goconvey/convey" + "gopkg.in/oauth2.v2" + "gopkg.in/oauth2.v2/models" +) + +func testAccessStore(store oauth2.TokenStore) { + info := &models.Token{ + ClientID: "1", + UserID: "1_1", + RedirectURI: "http://localhost/", + Scope: "all", + AuthType: oauth2.Code.String(), + Access: "1_1_1", + AccessCreateAt: time.Now(), + AccessExpiresIn: time.Second * 5, + } + err := store.Create(info) + So(err, ShouldBeNil) + + ainfo, err := store.GetByAccess(info.GetAccess()) + So(err, ShouldBeNil) + So(ainfo.GetUserID(), ShouldEqual, info.GetUserID()) + + err = store.RemoveByAccess(info.GetAccess()) + So(err, ShouldBeNil) + + ainfo, err = store.GetByAccess(info.GetAccess()) + So(err, ShouldBeNil) + So(ainfo, ShouldBeNil) +} + +func testRefreshStore(store oauth2.TokenStore) { + info := &models.Token{ + ClientID: "1", + UserID: "1_1", + RedirectURI: "http://localhost/", + Scope: "all", + AuthType: oauth2.Code.String(), + Access: "1_1_2", + AccessCreateAt: time.Now(), + AccessExpiresIn: time.Second * 5, + Refresh: "1_1_2_1", + RefreshCreateAt: time.Now(), + RefreshExpiresIn: time.Minute * 1, + } + err := store.Create(info) + So(err, ShouldBeNil) + + rinfo, err := store.GetByRefresh(info.GetRefresh()) + So(err, ShouldBeNil) + So(rinfo.GetUserID(), ShouldEqual, info.GetUserID()) + + err = store.RemoveByRefresh(info.GetRefresh()) + So(err, ShouldBeNil) + + rinfo, err = store.GetByRefresh(info.GetRefresh()) + So(err, ShouldBeNil) + So(rinfo, ShouldBeNil) +} + +func testAccessExpired(store oauth2.TokenStore) { + info := &models.Token{ + ClientID: "1", + UserID: "1_2", + RedirectURI: "http://localhost/", + Scope: "all", + AuthType: oauth2.Code.String(), + Access: "1_2_1", + AccessCreateAt: time.Now(), + AccessExpiresIn: time.Second * 1, + } + err := store.Create(info) + So(err, ShouldBeNil) + + time.Sleep(time.Millisecond * 3000) + + ainfo, err := store.GetByAccess(info.GetAccess()) + So(err, ShouldBeNil) + So(ainfo, ShouldBeNil) +} + +func testRefreshExpired(store oauth2.TokenStore) { + info := &models.Token{ + ClientID: "1", + UserID: "1_3", + RedirectURI: "http://localhost/", + Scope: "all", + AuthType: oauth2.Code.String(), + Access: "1_3_1", + AccessCreateAt: time.Now(), + AccessExpiresIn: time.Second * 2, + Refresh: "1_3_2", + RefreshCreateAt: time.Now(), + RefreshExpiresIn: time.Second * 1, + } + err := store.Create(info) + So(err, ShouldBeNil) + + time.Sleep(time.Millisecond * 3000) + + ainfo, err := store.GetByAccess(info.GetAccess()) + So(err, ShouldBeNil) + So(ainfo, ShouldBeNil) + + rinfo, err := store.GetByRefresh(info.GetRefresh()) + So(err, ShouldBeNil) + So(rinfo, ShouldBeNil) +} From b4447382bc179211a7e2f1d676153303cb041d95 Mon Sep 17 00:00:00 2001 From: lyric Date: Mon, 4 Jul 2016 09:38:13 +0800 Subject: [PATCH 10/18] Fixed token mongo store --- store/token/mongo.go | 4 +-- store/token/mongo_test.go | 26 ++---------------- store/token/redis_test.go | 26 ++---------------- store/token/token_test.go | 55 +++------------------------------------ 4 files changed, 9 insertions(+), 102 deletions(-) diff --git a/store/token/mongo.go b/store/token/mongo.go index 84520cd..ee699f3 100644 --- a/store/token/mongo.go +++ b/store/token/mongo.go @@ -35,7 +35,7 @@ func NewMongoStore(cfg *MongoConfig) (store oauth2.TokenStore, err error) { // 创建自动过期索引 err = handler.C(cfg.C).EnsureIndex(mgo.Index{ Key: []string{"ExpiredAt"}, - ExpireAfter: time.Second, + ExpireAfter: time.Second * 1, }) if err != nil { return @@ -122,7 +122,7 @@ func (ms *MongoStore) RemoveByRefresh(refresh string) (err error) { func (ms *MongoStore) get(find interface{}) (info oauth2.TokenInfo, err error) { ms.handler.CHandle(ms.cfg.C, func(c *mgo.Collection) { var tm models.Token - aerr := c.Find(find).Select(bson.M{"_id": 0}).One(&tm) + aerr := c.Find(find).Select(bson.M{"_id": 0}).Sort("-_id").One(&tm) if aerr != nil { if aerr == mgo.ErrNotFound { return diff --git a/store/token/mongo_test.go b/store/token/mongo_test.go index 3a6db26..6e5a2fd 100644 --- a/store/token/mongo_test.go +++ b/store/token/mongo_test.go @@ -18,34 +18,12 @@ func TestMongoStore(t *testing.T) { store, err := NewMongoStore(cfg) So(err, ShouldBeNil) - Convey("Test mongo store access", func() { + Convey("Test access token store", func() { testAccessStore(store) }) - Convey("Test mongo store refresh", func() { + Convey("Test refresh token store", func() { testRefreshStore(store) }) }) } - -func TestMongoStoreAccessExpired(t *testing.T) { - Convey("Test mongo store access token expired", t, func() { - cfg := &MongoConfig{ - URL: mongoURL, - } - store, err := NewMongoStore(cfg) - So(err, ShouldBeNil) - testAccessExpired(store) - }) -} - -func TestMongoStoreRefreshExpired(t *testing.T) { - Convey("Test mongo store refresh token expired", t, func() { - cfg := &MongoConfig{ - URL: mongoURL, - } - store, err := NewMongoStore(cfg) - So(err, ShouldBeNil) - testRefreshExpired(store) - }) -} diff --git a/store/token/redis_test.go b/store/token/redis_test.go index 57dc0c2..d658656 100644 --- a/store/token/redis_test.go +++ b/store/token/redis_test.go @@ -14,34 +14,12 @@ func TestRedisStore(t *testing.T) { store, err := NewRedisStore(cfg) So(err, ShouldBeNil) - Convey("Test redis store access", func() { + Convey("Test access token store", func() { testAccessStore(store) }) - Convey("Test redis store refresh", func() { + Convey("Test refresh token store", func() { testRefreshStore(store) }) }) } - -func TestRedisStoreAccessExpired(t *testing.T) { - Convey("Test redis store access token expired", t, func() { - cfg := &RedisConfig{ - Addr: "192.168.33.70:6379", - } - store, err := NewRedisStore(cfg) - So(err, ShouldBeNil) - testAccessExpired(store) - }) -} - -func TestRedisStoreRefreshExpired(t *testing.T) { - Convey("Test redis store refresh token expired", t, func() { - cfg := &RedisConfig{ - Addr: "192.168.33.70:6379", - } - store, err := NewRedisStore(cfg) - So(err, ShouldBeNil) - testRefreshExpired(store) - }) -} diff --git a/store/token/token_test.go b/store/token/token_test.go index 13badaa..87742a5 100644 --- a/store/token/token_test.go +++ b/store/token/token_test.go @@ -37,14 +37,14 @@ func testAccessStore(store oauth2.TokenStore) { func testRefreshStore(store oauth2.TokenStore) { info := &models.Token{ ClientID: "1", - UserID: "1_1", + UserID: "1_2", RedirectURI: "http://localhost/", Scope: "all", AuthType: oauth2.Code.String(), - Access: "1_1_2", + Access: "1_2_1", AccessCreateAt: time.Now(), AccessExpiresIn: time.Second * 5, - Refresh: "1_1_2_1", + Refresh: "1_2_2", RefreshCreateAt: time.Now(), RefreshExpiresIn: time.Minute * 1, } @@ -62,52 +62,3 @@ func testRefreshStore(store oauth2.TokenStore) { So(err, ShouldBeNil) So(rinfo, ShouldBeNil) } - -func testAccessExpired(store oauth2.TokenStore) { - info := &models.Token{ - ClientID: "1", - UserID: "1_2", - RedirectURI: "http://localhost/", - Scope: "all", - AuthType: oauth2.Code.String(), - Access: "1_2_1", - AccessCreateAt: time.Now(), - AccessExpiresIn: time.Second * 1, - } - err := store.Create(info) - So(err, ShouldBeNil) - - time.Sleep(time.Millisecond * 3000) - - ainfo, err := store.GetByAccess(info.GetAccess()) - So(err, ShouldBeNil) - So(ainfo, ShouldBeNil) -} - -func testRefreshExpired(store oauth2.TokenStore) { - info := &models.Token{ - ClientID: "1", - UserID: "1_3", - RedirectURI: "http://localhost/", - Scope: "all", - AuthType: oauth2.Code.String(), - Access: "1_3_1", - AccessCreateAt: time.Now(), - AccessExpiresIn: time.Second * 2, - Refresh: "1_3_2", - RefreshCreateAt: time.Now(), - RefreshExpiresIn: time.Second * 1, - } - err := store.Create(info) - So(err, ShouldBeNil) - - time.Sleep(time.Millisecond * 3000) - - ainfo, err := store.GetByAccess(info.GetAccess()) - So(err, ShouldBeNil) - So(ainfo, ShouldBeNil) - - rinfo, err := store.GetByRefresh(info.GetRefresh()) - So(err, ShouldBeNil) - So(rinfo, ShouldBeNil) -} From 4c41505763c71b1048037aa0083289193aed68a6 Mon Sep 17 00:00:00 2001 From: lyric Date: Mon, 4 Jul 2016 18:53:22 +0800 Subject: [PATCH 11/18] Add manager test --- manage/manage_test.go | 27 +++++++++++++++++++++++++++ manage/manager.go | 4 ++-- manage/util_test.go | 8 +++++--- models/client.go | 5 +++++ models/token.go | 5 +++++ 5 files changed, 44 insertions(+), 5 deletions(-) create mode 100644 manage/manage_test.go diff --git a/manage/manage_test.go b/manage/manage_test.go new file mode 100644 index 0000000..ad26844 --- /dev/null +++ b/manage/manage_test.go @@ -0,0 +1,27 @@ +package manage + +import ( + "testing" + + . "github.com/smartystreets/goconvey/convey" + "gopkg.in/oauth2.v2/generates" + "gopkg.in/oauth2.v2/models" + "gopkg.in/oauth2.v2/store/client" + "gopkg.in/oauth2.v2/store/token" +) + +func TestManager(t *testing.T) { + Convey("Manager Test", t, func() { + manager := NewManager() + + manager.MapClientModel(models.NewClient()) + manager.MapTokenModel(models.NewToken()) + manager.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) + manager.MapAccessGenerate(generates.NewAccessGenerate()) + manager.MapClientStorage(client.NewTempStore()) + manager.MustTokenStorage(token.NewRedisStore( + &token.RedisConfig{Addr: "192.168.33.70:6379"}, + )) + + }) +} diff --git a/manage/manager.go b/manage/manager.go index 71e7b51..b1590e1 100644 --- a/manage/manager.go +++ b/manage/manager.go @@ -78,8 +78,8 @@ func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) { m.injector.Map(gen) } -// MapTokenGenerate 注入访问令牌生成接口 -func (m *Manager) MapTokenGenerate(gen oauth2.AccessGenerate) { +// MapAccessGenerate 注入访问令牌生成接口 +func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) { if gen == nil { panic(ErrNilValue) } diff --git a/manage/util_test.go b/manage/util_test.go index 6c4dd24..f936cbd 100644 --- a/manage/util_test.go +++ b/manage/util_test.go @@ -7,8 +7,10 @@ import ( ) func TestUtil(t *testing.T) { - Convey("ValidateURI Test", t, func() { - err := ValidateURI("http://www.example.com", "http://www.example.com/cb?code=xxx") - So(err, ShouldBeNil) + Convey("Util Test", t, func() { + Convey("ValidateURI Test", func() { + err := ValidateURI("http://www.example.com", "http://www.example.com/cb?code=xxx") + So(err, ShouldBeNil) + }) }) } diff --git a/models/client.go b/models/client.go index 16e519b..4978243 100644 --- a/models/client.go +++ b/models/client.go @@ -1,5 +1,10 @@ package models +// NewClient 创建客户端模型实例 +func NewClient() *Client { + return &Client{} +} + // Client 客户端信息 type Client struct { ID string // 客户端ID diff --git a/models/token.go b/models/token.go index ede71af..60b13c5 100644 --- a/models/token.go +++ b/models/token.go @@ -2,6 +2,11 @@ package models import "time" +// NewToken 创建令牌模型实例 +func NewToken() *Token { + return &Token{} +} + // Token 令牌信息 type Token struct { ClientID string `bson:"ClientID"` // 客户端标识 From 6469c6797f569a0b1c97188d852123d755ba98cb Mon Sep 17 00:00:00 2001 From: lyric Date: Tue, 5 Jul 2016 18:15:21 +0800 Subject: [PATCH 12/18] Add manager test file --- manage/manage_test.go | 80 ++++++++++++++++++++++++++++++++++++++++--- manage/manager.go | 20 +++++++---- store/token/redis.go | 12 +++++-- 3 files changed, 99 insertions(+), 13 deletions(-) diff --git a/manage/manage_test.go b/manage/manage_test.go index ad26844..0b60fed 100644 --- a/manage/manage_test.go +++ b/manage/manage_test.go @@ -4,6 +4,7 @@ import ( "testing" . "github.com/smartystreets/goconvey/convey" + "gopkg.in/oauth2.v2" "gopkg.in/oauth2.v2/generates" "gopkg.in/oauth2.v2/models" "gopkg.in/oauth2.v2/store/client" @@ -11,7 +12,7 @@ import ( ) func TestManager(t *testing.T) { - Convey("Manager Test", t, func() { + Convey("Manager test", t, func() { manager := NewManager() manager.MapClientModel(models.NewClient()) @@ -19,9 +20,80 @@ func TestManager(t *testing.T) { manager.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) manager.MapAccessGenerate(generates.NewAccessGenerate()) manager.MapClientStorage(client.NewTempStore()) - manager.MustTokenStorage(token.NewRedisStore( - &token.RedisConfig{Addr: "192.168.33.70:6379"}, - )) + + Convey("GetClient test", func() { + cli, err := manager.GetClient("1") + So(err, ShouldBeNil) + So(cli.GetSecret(), ShouldEqual, "11") + }) + + Convey("Redis store test", func() { + manager.MustTokenStorage(token.NewRedisStore( + &token.RedisConfig{Addr: "192.168.33.70:6379"}, + )) + testManager(manager) + }) + + Convey("MongoDB store test", func() { + manager.MustTokenStorage(token.NewMongoStore( + &token.MongoConfig{URL: "mongodb://admin:123456@192.168.33.70:27017"}, + )) + testManager(manager) + }) }) } + +func testManager(manager oauth2.Manager) { + reqParams := &oauth2.TokenGenerateRequest{ + ClientID: "1", + UserID: "123456", + RedirectURI: "http://localhost/oauth2", + Scope: "all", + } + code, err := manager.GenerateAuthToken(oauth2.Code, reqParams) + So(err, ShouldBeNil) + So(code, ShouldNotBeEmpty) + + atParams := &oauth2.TokenGenerateRequest{ + ClientID: "1", + RedirectURI: "http://localhost/oauth2", + Code: code, + IsGenerateRefresh: true, + } + accessToken, refreshToken, err := manager.GenerateAccessToken(oauth2.AuthorizationCodeCredentials, atParams) + So(err, ShouldBeNil) + So(accessToken, ShouldNotBeEmpty) + So(refreshToken, ShouldNotBeEmpty) + + _, err = manager.LoadAccessToken(code) + So(err, ShouldNotBeNil) + + ainfo, err := manager.LoadAccessToken(accessToken) + So(err, ShouldBeNil) + So(ainfo.GetClientID(), ShouldEqual, atParams.ClientID) + + rinfo, err := manager.LoadRefreshToken(refreshToken) + So(err, ShouldBeNil) + So(rinfo.GetClientID(), ShouldEqual, atParams.ClientID) + + refreshAT, err := manager.RefreshAccessToken(refreshToken, "owner") + So(err, ShouldBeNil) + So(refreshAT, ShouldNotBeEmpty) + + _, err = manager.LoadAccessToken(accessToken) + So(err, ShouldNotBeNil) + + refreshAInfo, err := manager.LoadAccessToken(refreshAT) + So(err, ShouldBeNil) + So(refreshAInfo.GetScope(), ShouldEqual, "owner") + + err = manager.RemoveRefreshToken(refreshToken) + So(err, ShouldBeNil) + + _, err = manager.LoadAccessToken(refreshAT) + So(err, ShouldNotBeNil) + + _, err = manager.LoadRefreshToken(refreshToken) + So(err, ShouldNotBeNil) +} diff --git a/manage/manager.go b/manage/manager.go index b1590e1..3bcc651 100644 --- a/manage/manager.go +++ b/manage/manager.go @@ -17,6 +17,8 @@ type Config struct { func NewManager() *Manager { m := &Manager{ injector: inject.New(), + rtcfg: make(map[oauth2.ResponseType]*Config), + gtcfg: make(map[oauth2.GrantType]*Config), } // 设定参数默认值 // 设定授权码的有效期为10分钟 @@ -126,7 +128,7 @@ func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) { // GetClient 获取客户端信息 func (m *Manager) GetClient(clientID string) (cli oauth2.ClientInfo, err error) { - err = m.injector.Apply(func(stor oauth2.ClientStore) { + _, ierr := m.injector.Invoke(func(stor oauth2.ClientStore) { cli, err = stor.GetByID(clientID) if err != nil { return @@ -134,6 +136,9 @@ func (m *Manager) GetClient(clientID string) (cli oauth2.ClientInfo, err error) err = ErrClientNotFound } }) + if err == nil && ierr != nil { + err = ierr + } return } @@ -182,7 +187,7 @@ func (m *Manager) GenerateAuthToken(rt oauth2.ResponseType, tgr *oauth2.TokenGen // GenerateAccessToken 生成访问令牌、更新令牌 // gt 授权模式 // tgr 生成令牌的参数 -func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (token, refresh string, err error) { +func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (access, refresh string, err error) { if gt == oauth2.AuthorizationCodeCredentials { // 授权码模式 ti, terr := m.LoadAccessToken(tgr.Code) if terr != nil { @@ -211,7 +216,7 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene UserID: tgr.UserID, CreateAt: time.Now(), } - tv, rv, terr := gen.Token(td, tgr.IsGenerateRefresh) + av, rv, terr := gen.Token(td, tgr.IsGenerateRefresh) if terr != nil { err = terr return @@ -223,7 +228,7 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene ti.SetAuthType(gt.String()) ti.SetAccessCreateAt(td.CreateAt) ti.SetAccessExpiresIn(m.gtcfg[gt].TokenExp) - ti.SetAccess(tv) + ti.SetAccess(av) if rv != "" { ti.SetRefreshCreateAt(td.CreateAt) ti.SetRefreshExpiresIn(m.gtcfg[gt].RefreshExp) @@ -233,7 +238,8 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene if err != nil { return } - token = tv + access = av + refresh = rv }) if ierr != nil && err == nil { err = ierr @@ -269,11 +275,11 @@ func (m *Manager) RefreshAccessToken(refresh, scope string) (token string, err e if scope != "" { ti.SetScope(scope) } - if verr := stor.Create(ti); verr != nil { + if verr := stor.RemoveByAccess(access); verr != nil { err = verr return } - if verr := stor.RemoveByAccess(access); verr != nil { + if verr := stor.Create(ti); verr != nil { err = verr return } diff --git a/store/token/redis.go b/store/token/redis.go index ece515b..54f84f7 100644 --- a/store/token/redis.go +++ b/store/token/redis.go @@ -70,8 +70,16 @@ func (rs *RedisStore) Create(info oauth2.TokenInfo) (err error) { // remove func (rs *RedisStore) remove(key string) (err error) { - del := rs.cli.Del(key) - if verr := del.Err(); verr != nil { + info, err := rs.get(key) + if err != nil || info == nil { + return + } + pipe := rs.cli.Pipeline() + pipe.Del(info.GetAccess()) + if v := info.GetRefresh(); v != "" { + pipe.Del(v) + } + if _, verr := pipe.Exec(); verr != nil { err = verr } return From 6491a6a08b521000fd4ce2ad3162a38f6bbd23e3 Mon Sep 17 00:00:00 2001 From: lyric Date: Wed, 6 Jul 2016 16:33:29 +0800 Subject: [PATCH 13/18] Fixed some implement --- const.go | 36 ++++++++++-------------------------- manage.go | 4 ++++ manage/manager.go | 1 + model.go | 4 ++-- models/client.go | 4 ++-- 5 files changed, 19 insertions(+), 30 deletions(-) diff --git a/const.go b/const.go index 2aa18a8..29d9efd 100644 --- a/const.go +++ b/const.go @@ -1,49 +1,33 @@ package oauth2 // ResponseType 定义授权类型 -type ResponseType byte +type ResponseType string const ( // Code 授权码类型 - Code ResponseType = 1 << (iota + 1) + Code ResponseType = "code" // Token 令牌类型 - Token + Token ResponseType = "token" ) func (rt ResponseType) String() string { - switch rt { - case 1 << 1: - return "code" - case 1 << 2: - return "token" - } - return "unknown" + return string(rt) } // GrantType 定义授权模式 -type GrantType byte +type GrantType string const ( // AuthorizationCodeCredentials 授权码模式 - AuthorizationCodeCredentials GrantType = 1 << (iota + 1) + AuthorizationCodeCredentials GrantType = "authorization_code" // PasswordCredentials 密码模式 - PasswordCredentials + PasswordCredentials GrantType = "password" // ClientCredentials 客户端模式 - ClientCredentials + ClientCredentials GrantType = "clientcredentials" // RefreshCredentials 更新令牌模式 - RefreshCredentials + RefreshCredentials GrantType = "refreshtoken" ) func (gt GrantType) String() string { - switch gt { - case 1 << 1: - return "authorization_code" - case 1 << 2: - return "password" - case 1 << 3: - return "clientcredentials" - case 1 << 4: - return "refreshtoken" - } - return "unknown" + return string(gt) } diff --git a/manage.go b/manage.go index fcf0947..16f3e60 100644 --- a/manage.go +++ b/manage.go @@ -13,6 +13,10 @@ type TokenGenerateRequest struct { // Manager OAuth2授权管理接口 type Manager interface { + // GetClient 获取客户端信息 + // clientID 客户端标识 + GetClient(clientID string) (cli ClientInfo, err error) + // GenerateAuthToken 生成授权令牌 // rt 授权类型 // tgr 生成令牌的请求参数 diff --git a/manage/manager.go b/manage/manager.go index 3bcc651..a2fe5a9 100644 --- a/manage/manager.go +++ b/manage/manager.go @@ -127,6 +127,7 @@ func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) { } // GetClient 获取客户端信息 +// clientID 客户端标识 func (m *Manager) GetClient(clientID string) (cli oauth2.ClientInfo, err error) { _, ierr := m.injector.Invoke(func(stor oauth2.ClientStore) { cli, err = stor.GetByID(clientID) diff --git a/model.go b/model.go index e7c7fde..163a4b4 100644 --- a/model.go +++ b/model.go @@ -12,8 +12,8 @@ type ( GetSecret() string // 客户端域名URL GetDomain() string - // 用户数据 - GetUserData() interface{} + // 扩展数据 + GetExtraData() interface{} } // TokenInfo 令牌信息模型接口 diff --git a/models/client.go b/models/client.go index 4978243..9f6871b 100644 --- a/models/client.go +++ b/models/client.go @@ -27,7 +27,7 @@ func (c *Client) GetDomain() string { return c.Domain } -// GetUserData 用户数据 -func (c *Client) GetUserData() interface{} { +// GetExtraData 扩展数据 +func (c *Client) GetExtraData() interface{} { return nil } From 222cdc9665e019d4289cc6eca1e3f00d298c883a Mon Sep 17 00:00:00 2001 From: lyric Date: Fri, 8 Jul 2016 14:07:39 +0800 Subject: [PATCH 14/18] Add server package --- generates/access_test.go | 4 +- generates/authorize_test.go | 2 +- manage.go | 7 +- manage/error.go | 4 +- manage/manage_test.go | 17 ++- manage/manager.go | 40 +++---- server/.gitkeep | 0 server/authorize.go | 53 +++++++++ server/config.go | 22 ++++ server/error.go | 23 ++++ server/server.go | 208 ++++++++++++++++++++++++++++++++++++ 11 files changed, 349 insertions(+), 31 deletions(-) delete mode 100644 server/.gitkeep create mode 100644 server/authorize.go create mode 100644 server/config.go create mode 100644 server/error.go create mode 100644 server/server.go diff --git a/generates/access_test.go b/generates/access_test.go index 3a78542..ff533ec 100644 --- a/generates/access_test.go +++ b/generates/access_test.go @@ -22,7 +22,7 @@ func TestAccess(t *testing.T) { gen := NewAccessGenerate() access, refresh, err := gen.Token(data, true) So(err, ShouldBeNil) - Println("\nAccess:", access) - Println("Refresh:", refresh) + So(access, ShouldNotBeEmpty) + So(refresh, ShouldNotBeEmpty) }) } diff --git a/generates/authorize_test.go b/generates/authorize_test.go index 6439b72..d94d67b 100644 --- a/generates/authorize_test.go +++ b/generates/authorize_test.go @@ -22,6 +22,6 @@ func TestAuthorize(t *testing.T) { gen := NewAuthorizeGenerate() code, err := gen.Token(data) So(err, ShouldBeNil) - Println("\nCode:", code) + So(code, ShouldNotBeEmpty) }) } diff --git a/manage.go b/manage.go index 16f3e60..f8c139a 100644 --- a/manage.go +++ b/manage.go @@ -8,6 +8,7 @@ type TokenGenerateRequest struct { RedirectURI string // 重定向URI Scope string // 授权范围 Code string // 授权码(授权码模式使用) + Refresh string // 刷新令牌 IsGenerateRefresh bool // 是否生成更新令牌 } @@ -20,17 +21,17 @@ type Manager interface { // GenerateAuthToken 生成授权令牌 // rt 授权类型 // tgr 生成令牌的请求参数 - GenerateAuthToken(rt ResponseType, tgr *TokenGenerateRequest) (token string, err error) + GenerateAuthToken(rt ResponseType, tgr *TokenGenerateRequest) (authToken TokenInfo, err error) // GenerateAccessToken 生成访问令牌、更新令牌 // rt 授权模式 // tgr 生成令牌的请求参数 - GenerateAccessToken(rt GrantType, tgr *TokenGenerateRequest) (access, refresh string, err error) + GenerateAccessToken(rt GrantType, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error) // RefreshAccessToken 更新访问令牌 // refresh 更新令牌 // scope 作用域 - RefreshAccessToken(refresh, scope string) (access string, err error) + RefreshAccessToken(refresh, scope string) (accessToken TokenInfo, err error) // RemoveAccessToken 删除访问令牌 // access 访问令牌 diff --git a/manage/error.go b/manage/error.go index b959173..c5ec58d 100644 --- a/manage/error.go +++ b/manage/error.go @@ -12,8 +12,8 @@ var ( // ErrClientInvalid Client invalid ErrClientInvalid = errors.New("client invalid") - // ErrAuthTokenInvalid Authorize token invalid - ErrAuthTokenInvalid = errors.New("authorize token invalid") + // ErrAuthCodeInvalid Authorize token invalid + ErrAuthCodeInvalid = errors.New("authorize code invalid") // ErrAccessInvalid Access token expired ErrAccessInvalid = errors.New("access token invalid") diff --git a/manage/manage_test.go b/manage/manage_test.go index 0b60fed..3afb4e1 100644 --- a/manage/manage_test.go +++ b/manage/manage_test.go @@ -51,18 +51,23 @@ func testManager(manager oauth2.Manager) { RedirectURI: "http://localhost/oauth2", Scope: "all", } - code, err := manager.GenerateAuthToken(oauth2.Code, reqParams) + cti, err := manager.GenerateAuthToken(oauth2.Code, reqParams) So(err, ShouldBeNil) + + code := cti.GetAccess() So(code, ShouldNotBeEmpty) atParams := &oauth2.TokenGenerateRequest{ - ClientID: "1", - RedirectURI: "http://localhost/oauth2", + ClientID: reqParams.ClientID, + ClientSecret: "11", + RedirectURI: reqParams.RedirectURI, Code: code, IsGenerateRefresh: true, } - accessToken, refreshToken, err := manager.GenerateAccessToken(oauth2.AuthorizationCodeCredentials, atParams) + ati, err := manager.GenerateAccessToken(oauth2.AuthorizationCodeCredentials, atParams) So(err, ShouldBeNil) + + accessToken, refreshToken := ati.GetAccess(), ati.GetRefresh() So(accessToken, ShouldNotBeEmpty) So(refreshToken, ShouldNotBeEmpty) @@ -77,8 +82,10 @@ func testManager(manager oauth2.Manager) { So(err, ShouldBeNil) So(rinfo.GetClientID(), ShouldEqual, atParams.ClientID) - refreshAT, err := manager.RefreshAccessToken(refreshToken, "owner") + rti, err := manager.RefreshAccessToken(refreshToken, "owner") So(err, ShouldBeNil) + + refreshAT := rti.GetAccess() So(refreshAT, ShouldNotBeEmpty) _, err = manager.LoadAccessToken(accessToken) diff --git a/manage/manager.go b/manage/manager.go index a2fe5a9..092b3ad 100644 --- a/manage/manager.go +++ b/manage/manager.go @@ -146,7 +146,7 @@ func (m *Manager) GetClient(clientID string) (cli oauth2.ClientInfo, err error) // GenerateAuthToken 生成授权令牌 // rt 授权类型 // tgr 生成令牌的配置参数 -func (m *Manager) GenerateAuthToken(rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (token string, err error) { +func (m *Manager) GenerateAuthToken(rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (authToken oauth2.TokenInfo, err error) { cli, err := m.GetClient(tgr.ClientID) if err != nil { return @@ -177,7 +177,7 @@ func (m *Manager) GenerateAuthToken(rt oauth2.ResponseType, tgr *oauth2.TokenGen if err != nil { return } - token = tv + authToken = ti }) if ierr != nil && err == nil { err = ierr @@ -188,14 +188,14 @@ func (m *Manager) GenerateAuthToken(rt oauth2.ResponseType, tgr *oauth2.TokenGen // GenerateAccessToken 生成访问令牌、更新令牌 // gt 授权模式 // tgr 生成令牌的参数 -func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (access, refresh string, err error) { +func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (accessToken oauth2.TokenInfo, err error) { if gt == oauth2.AuthorizationCodeCredentials { // 授权码模式 ti, terr := m.LoadAccessToken(tgr.Code) if terr != nil { err = terr return } else if ti.GetRedirectURI() != tgr.RedirectURI || ti.GetClientID() != tgr.ClientID { - err = ErrAuthTokenInvalid + err = ErrAuthCodeInvalid return } else if verr := m.RemoveAccessToken(tgr.Code); verr != nil { // 删除授权码 err = verr @@ -239,8 +239,7 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene if err != nil { return } - access = av - refresh = rv + accessToken = ti }) if ierr != nil && err == nil { err = ierr @@ -249,16 +248,25 @@ func (m *Manager) GenerateAccessToken(gt oauth2.GrantType, tgr *oauth2.TokenGene } // RefreshAccessToken 更新访问令牌 -func (m *Manager) RefreshAccessToken(refresh, scope string) (token string, err error) { - ti, err := m.LoadRefreshToken(refresh) +func (m *Manager) RefreshAccessToken(tgr *oauth2.TokenGenerateRequest) (accessToken oauth2.TokenInfo, err error) { + cli, err := m.GetClient(tgr.ClientID) if err != nil { return + } else if tgr.ClientSecret != "" && tgr.ClientSecret != cli.GetSecret() { + err = ErrClientInvalid + return + } + ti, err := m.LoadRefreshToken(tgr.Refresh) + if err != nil { + return + } else if ti.GetClientID() != tgr.ClientID { + err = ErrRefreshInvalid + return } - access := ti.GetAccess() _, ierr := m.injector.Invoke(func(stor oauth2.TokenStore, gen oauth2.AccessGenerate) { - cli, cerr := m.GetClient(ti.GetClientID()) - if cerr != nil { - err = cerr + // 移除旧的访问令牌 + if verr := stor.RemoveByAccess(ti.GetAccess()); verr != nil { + err = verr return } td := &oauth2.GenerateBasic{ @@ -273,18 +281,14 @@ func (m *Manager) RefreshAccessToken(refresh, scope string) (token string, err e } ti.SetAccess(tv) ti.SetAccessCreateAt(td.CreateAt) - if scope != "" { + if scope := tgr.Scope; scope != "" { ti.SetScope(scope) } - if verr := stor.RemoveByAccess(access); verr != nil { - err = verr - return - } if verr := stor.Create(ti); verr != nil { err = verr return } - token = tv + accessToken = ti }) if ierr != nil && err == nil { err = ierr diff --git a/server/.gitkeep b/server/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/server/authorize.go b/server/authorize.go new file mode 100644 index 0000000..9fe5247 --- /dev/null +++ b/server/authorize.go @@ -0,0 +1,53 @@ +package server + +import ( + "encoding/base64" + "net/http" + "strings" + + "gopkg.in/oauth2.v2" +) + +// AuthorizeRequest 授权请求 +type AuthorizeRequest struct { + Type oauth2.ResponseType + ClientID string + Scope string + RedirectURI string + State string + UserID string +} + +// ClientHandler 获取客户端信息 +type ClientHandler func(r *http.Request) (clientID, clientSecret string, err error) + +// UserHandler 获取用户信息 +type UserHandler func(username, password string) (userID string, err error) + +// ClientFormHandler 客户端表单信息 +func ClientFormHandler(r *http.Request) (clientID, clientSecret string, err error) { + clientID = r.Form.Get("client_id") + clientSecret = r.Form.Get("client_secret") + return +} + +// ClientBasicHandler 客户端基础认证信息 +func ClientBasicHandler(r *http.Request) (clientID, clientSecret string, err error) { + s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) + if len(s) != 2 || s[0] != "Basic" { + err = ErrAuthorizationHeaderInvalid + return + } + b, err := base64.StdEncoding.DecodeString(s[1]) + if err != nil { + return + } + pair := strings.SplitN(string(b), ":", 2) + if len(pair) != 2 { + err = ErrAuthorizationHeaderInvalid + return + } + clientID = pair[0] + clientSecret = pair[1] + return +} diff --git a/server/config.go b/server/config.go new file mode 100644 index 0000000..d5be60b --- /dev/null +++ b/server/config.go @@ -0,0 +1,22 @@ +package server + +import "gopkg.in/oauth2.v2" + +// Config 配置参数 +type Config struct { + // TokenType 令牌类型(默认为Bearer) + TokenType string + // AllowedResponseType 允许的授权类型(默认code) + AllowedResponseType []oauth2.ResponseType + // AllowedGrantType 允许的授权模式(默认authorization_code) + AllowedGrantType []oauth2.GrantType +} + +// NewConfig 创建默认的配置参数 +func NewConfig() *Config { + return &Config{ + TokenType: "Bearer", + AllowedResponseType: []oauth2.ResponseType{oauth2.Code}, + AllowedGrantType: []oauth2.GrantType{oauth2.AuthorizationCodeCredentials}, + } +} diff --git a/server/error.go b/server/error.go new file mode 100644 index 0000000..bd39b4e --- /dev/null +++ b/server/error.go @@ -0,0 +1,23 @@ +package server + +import "errors" + +var ( + // ErrRequestMethodInvalid Request method invalid + ErrRequestMethodInvalid = errors.New("request method invalid") + + // ErrResponseTypeInvalid Response type invalid + ErrResponseTypeInvalid = errors.New("response type invalid") + + // ErrGrantTypeInvalid Grant type invalid + ErrGrantTypeInvalid = errors.New("grant type invalid") + + // ErrClientInvalid Client invalid + ErrClientInvalid = errors.New("client invalid") + + // ErrUserInvalid User invalid + ErrUserInvalid = errors.New("user invalid") + + // ErrAuthorizationHeaderInvalid Authorization header invalid + ErrAuthorizationHeaderInvalid = errors.New("authorization header invalid") +) diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..d116c4c --- /dev/null +++ b/server/server.go @@ -0,0 +1,208 @@ +package server + +import ( + "encoding/json" + "net/http" + "net/url" + "strconv" + "time" + + "gopkg.in/oauth2.v2" +) + +// NewServer 创建OAuth2服务实例 +func NewServer(cfg *Config, manager oauth2.Manager) *Server { + return &Server{ + cfg: cfg, + manager: manager, + } +} + +// Server OAuth2服务处理 +type Server struct { + cfg *Config + manager oauth2.Manager +} + +// checkResponseType 检查允许的授权类型 +func (s *Server) checkResponseType(rt oauth2.ResponseType) bool { + for _, art := range s.cfg.AllowedResponseType { + if art == rt { + return true + } + } + return false +} + +// checkGrantType 检查允许的授权模式 +func (s *Server) checkGrantType(gt oauth2.GrantType) bool { + for _, agt := range s.cfg.AllowedGrantType { + if agt == gt { + return true + } + } + return false +} + +// GetAuthorizeRequest 获取授权请求参数 +func (s *Server) GetAuthorizeRequest(r *http.Request) (authReq *AuthorizeRequest, err error) { + if r.Method != "GET" { + err = ErrRequestMethodInvalid + return + } + r.ParseForm() + redirectURI, err := url.QueryUnescape(r.Form.Get("redirect_uri")) + if err != nil { + return + } + authReq = &AuthorizeRequest{ + Type: oauth2.ResponseType(r.Form.Get("response_type")), + RedirectURI: redirectURI, + State: r.Form.Get("state"), + Scope: r.Form.Get("scope"), + ClientID: r.Form.Get("client_id"), + } + if authReq.Type == "" || !s.checkResponseType(authReq.Type) { + err = ErrResponseTypeInvalid + return + } else if authReq.ClientID == "" { + err = ErrClientInvalid + } + return +} + +// HandleAuthorizeRequest 处理授权请求 +func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, authReq *AuthorizeRequest) (err error) { + if authReq.UserID == "" { + err = ErrUserInvalid + return + } + tgr := &oauth2.TokenGenerateRequest{ + ClientID: authReq.ClientID, + UserID: authReq.UserID, + RedirectURI: authReq.RedirectURI, + Scope: authReq.Scope, + } + ti, terr := s.manager.GenerateAuthToken(oauth2.Code, tgr) + if terr != nil { + err = terr + return + } + s.ResRedirectURI(w, authReq, ti) + return +} + +// HandleTokenRequest 处理令牌请求 +// cli 获取客户端信息 +// user 获取用户信息 +func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request, ch ClientHandler, uh UserHandler) (err error) { + if r.Method != "POST" { + err = ErrRequestMethodInvalid + return + } + if verr := r.ParseForm(); verr != nil { + err = verr + return + } + gt := oauth2.GrantType(r.Form.Get("grant_type")) + if gt == "" || !s.checkGrantType(gt) { + err = ErrGrantTypeInvalid + return + } + + var ti oauth2.TokenInfo + clientID, clientSecret, err := ch(r) + if err != nil { + return + } + if clientID == "" || clientSecret == "" { + err = ErrClientInvalid + return + } + tgr := &oauth2.TokenGenerateRequest{ + ClientID: clientID, + ClientSecret: clientSecret, + } + + switch oauth2.GrantType(r.Form.Get("grant_type")) { + case oauth2.AuthorizationCodeCredentials: + tgr.RedirectURI = r.Form.Get("redirect_uri") + tgr.Code = r.Form.Get("code") + tgr.IsGenerateRefresh = true + ti, err = s.manager.GenerateAccessToken(oauth2.AuthorizationCodeCredentials, tgr) + case oauth2.PasswordCredentials: + userID, uerr := uh(r.Form.Get("username"), r.Form.Get("password")) + if uerr != nil { + err = uerr + return + } + tgr.UserID = userID + tgr.Scope = r.Form.Get("scope") + tgr.IsGenerateRefresh = true + case oauth2.ClientCredentials: + tgr.Scope = r.Form.Get("scope") + case oauth2.RefreshCredentials: + tgr.Refresh = r.Form.Get("refresh_token") + tgr.Scope = r.Form.Get("scope") + } + + if err != nil { + return + } + err = s.ResJSON(w, ti) + return +} + +func (s *Server) handleReponse(w http.ResponseWriter) { + w.Header().Add("Cache-Control", "no-cache, no-store, max-age=0, must-revalidate") + w.Header().Add("Pragma", "no-cache") + w.Header().Add("Expires", "Fri, 01 Jan 1990 00:00:00 GMT") +} + +// ResRedirectURI 响应数据到重定向URI +func (s *Server) ResRedirectURI(w http.ResponseWriter, authReq *AuthorizeRequest, ti oauth2.TokenInfo) (err error) { + u, err := url.Parse(authReq.RedirectURI) + if err != nil { + return + } + q := u.Query() + q.Set("state", authReq.State) + switch authReq.Type { + case oauth2.Code: + q.Set("code", ti.GetAccess()) + u.RawQuery = q.Encode() + case oauth2.Token: + q.Set("access_token", ti.GetAccess()) + q.Set("token_type", s.cfg.TokenType) + q.Set("expires_in", strconv.FormatInt(int64(ti.GetAccessExpiresIn()/time.Second), 10)) + q.Set("scope", ti.GetScope()) + u.RawQuery = "" + u.Fragment, err = url.QueryUnescape(q.Encode()) + if err != nil { + return + } + } + s.handleReponse(w) + w.Header().Add("Location", u.String()) + w.WriteHeader(302) + return +} + +// ResJSON 响应Json数据 +func (s *Server) ResJSON(w http.ResponseWriter, ti oauth2.TokenInfo) (err error) { + data := map[string]interface{}{ + "access_token": ti.GetAccess(), + "token_type": s.cfg.TokenType, + "expires_in": ti.GetAccessExpiresIn() / time.Second, + } + if scope := ti.GetScope(); scope != "" { + data["scope"] = scope + } + if refresh := ti.GetRefresh(); refresh != "" { + data["refresh_token"] = refresh + } + s.handleReponse(w) + w.Header().Set("Content-Type", "application/json;charset=UTF-8") + w.WriteHeader(http.StatusOK) + return json.NewEncoder(w).Encode(data) +} From 69732ada54046155cb7e18773bfcbda8d5f2af80 Mon Sep 17 00:00:00 2001 From: lyric Date: Fri, 8 Jul 2016 15:18:12 +0800 Subject: [PATCH 15/18] Fixed server package --- manage.go | 5 ++--- server/authorize.go | 36 +++++++++++++++++++----------------- server/config.go | 3 +++ server/error.go | 6 ++++++ server/server.go | 44 +++++++++++++++++++++++++++++++++++++------- 5 files changed, 67 insertions(+), 27 deletions(-) diff --git a/manage.go b/manage.go index f8c139a..d4635ce 100644 --- a/manage.go +++ b/manage.go @@ -29,9 +29,8 @@ type Manager interface { GenerateAccessToken(rt GrantType, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error) // RefreshAccessToken 更新访问令牌 - // refresh 更新令牌 - // scope 作用域 - RefreshAccessToken(refresh, scope string) (accessToken TokenInfo, err error) + // tgr 生成令牌的请求参数 + RefreshAccessToken(tgr *TokenGenerateRequest) (accessToken TokenInfo, err error) // RemoveAccessToken 删除访问令牌 // access 访问令牌 diff --git a/server/authorize.go b/server/authorize.go index 9fe5247..0f59523 100644 --- a/server/authorize.go +++ b/server/authorize.go @@ -1,9 +1,7 @@ package server import ( - "encoding/base64" "net/http" - "strings" "gopkg.in/oauth2.v2" ) @@ -18,36 +16,40 @@ type AuthorizeRequest struct { UserID string } -// ClientHandler 获取客户端信息 +// ClientHandler 客户端处理(获取请求的客户端认证信息) type ClientHandler func(r *http.Request) (clientID, clientSecret string, err error) -// UserHandler 获取用户信息 +// UserHandler 用户处理(密码模式,根据用户名、密码获取用户标识) type UserHandler func(username, password string) (userID string, err error) +// ScopeHandler 授权范围处理(更新令牌时的授权范围检查) +type ScopeHandler func(new, old string) (err error) + +// TokenRequestHandler 令牌请求处理 +type TokenRequestHandler struct { + ClientHandler ClientHandler + UserHandler UserHandler + ScopeHandler ScopeHandler +} + // ClientFormHandler 客户端表单信息 func ClientFormHandler(r *http.Request) (clientID, clientSecret string, err error) { clientID = r.Form.Get("client_id") clientSecret = r.Form.Get("client_secret") + if clientID == "" || clientSecret == "" { + err = ErrAuthorizationFormInvalid + } return } // ClientBasicHandler 客户端基础认证信息 func ClientBasicHandler(r *http.Request) (clientID, clientSecret string, err error) { - s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) - if len(s) != 2 || s[0] != "Basic" { - err = ErrAuthorizationHeaderInvalid - return - } - b, err := base64.StdEncoding.DecodeString(s[1]) - if err != nil { - return - } - pair := strings.SplitN(string(b), ":", 2) - if len(pair) != 2 { + username, password, ok := r.BasicAuth() + if !ok { err = ErrAuthorizationHeaderInvalid return } - clientID = pair[0] - clientSecret = pair[1] + clientID = username + clientSecret = password return } diff --git a/server/config.go b/server/config.go index d5be60b..eb2d923 100644 --- a/server/config.go +++ b/server/config.go @@ -10,6 +10,8 @@ type Config struct { AllowedResponseType []oauth2.ResponseType // AllowedGrantType 允许的授权模式(默认authorization_code) AllowedGrantType []oauth2.GrantType + // Handler 令牌请求处理 + Handler *TokenRequestHandler } // NewConfig 创建默认的配置参数 @@ -18,5 +20,6 @@ func NewConfig() *Config { TokenType: "Bearer", AllowedResponseType: []oauth2.ResponseType{oauth2.Code}, AllowedGrantType: []oauth2.GrantType{oauth2.AuthorizationCodeCredentials}, + Handler: &TokenRequestHandler{}, } } diff --git a/server/error.go b/server/error.go index bd39b4e..f138392 100644 --- a/server/error.go +++ b/server/error.go @@ -18,6 +18,12 @@ var ( // ErrUserInvalid User invalid ErrUserInvalid = errors.New("user invalid") + // ErrAuthorizationFormInvalid Authorization form invalid + ErrAuthorizationFormInvalid = errors.New("authorization form invalid") + // ErrAuthorizationHeaderInvalid Authorization header invalid ErrAuthorizationHeaderInvalid = errors.New("authorization header invalid") + + // ErrRefreshInvalid Refresh token invalid + ErrRefreshInvalid = errors.New("refresh token invalid") ) diff --git a/server/server.go b/server/server.go index d116c4c..050f923 100644 --- a/server/server.go +++ b/server/server.go @@ -24,6 +24,21 @@ type Server struct { manager oauth2.Manager } +// SetClientHandler 设置客户端处理 +func (s *Server) SetClientHandler(handler ClientHandler) { + s.cfg.Handler.ClientHandler = handler +} + +// SetUserHandler 设置用户处理 +func (s *Server) SetUserHandler(handler UserHandler) { + s.cfg.Handler.UserHandler = handler +} + +// SetScopeHandler 设置授权范围处理 +func (s *Server) SetScopeHandler(handler ScopeHandler) { + s.cfg.Handler.ScopeHandler = handler +} + // checkResponseType 检查允许的授权类型 func (s *Server) checkResponseType(rt oauth2.ResponseType) bool { for _, art := range s.cfg.AllowedResponseType { @@ -95,7 +110,7 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, authReq *Authoriz // HandleTokenRequest 处理令牌请求 // cli 获取客户端信息 // user 获取用户信息 -func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request, ch ClientHandler, uh UserHandler) (err error) { +func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) (err error) { if r.Method != "POST" { err = ErrRequestMethodInvalid return @@ -111,14 +126,10 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request, ch C } var ti oauth2.TokenInfo - clientID, clientSecret, err := ch(r) + clientID, clientSecret, err := s.cfg.Handler.ClientHandler(r) if err != nil { return } - if clientID == "" || clientSecret == "" { - err = ErrClientInvalid - return - } tgr := &oauth2.TokenGenerateRequest{ ClientID: clientID, ClientSecret: clientSecret, @@ -131,7 +142,7 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request, ch C tgr.IsGenerateRefresh = true ti, err = s.manager.GenerateAccessToken(oauth2.AuthorizationCodeCredentials, tgr) case oauth2.PasswordCredentials: - userID, uerr := uh(r.Form.Get("username"), r.Form.Get("password")) + userID, uerr := s.cfg.Handler.UserHandler(r.Form.Get("username"), r.Form.Get("password")) if uerr != nil { err = uerr return @@ -139,11 +150,30 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request, ch C tgr.UserID = userID tgr.Scope = r.Form.Get("scope") tgr.IsGenerateRefresh = true + ti, err = s.manager.GenerateAccessToken(oauth2.PasswordCredentials, tgr) case oauth2.ClientCredentials: tgr.Scope = r.Form.Get("scope") + ti, err = s.manager.GenerateAccessToken(oauth2.ClientCredentials, tgr) case oauth2.RefreshCredentials: tgr.Refresh = r.Form.Get("refresh_token") tgr.Scope = r.Form.Get("scope") + if tgr.Scope != "" { // 检查授权范围 + rti, rerr := s.manager.LoadRefreshToken(tgr.Refresh) + if rerr != nil { + err = rerr + return + } else if rti.GetClientID() != tgr.ClientID { + err = ErrRefreshInvalid + return + } else if verr := s.cfg.Handler.ScopeHandler(tgr.Scope, rti.GetScope()); verr != nil { + err = verr + return + } + } + ti, err = s.manager.RefreshAccessToken(tgr) + if err == nil { + ti.SetRefresh("") + } } if err != nil { From 149680ad80d092737fe0b72e464fd4ca55ac63de Mon Sep 17 00:00:00 2001 From: lyric Date: Fri, 8 Jul 2016 15:23:56 +0800 Subject: [PATCH 16/18] Fixed manage test file --- manage/manage_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/manage/manage_test.go b/manage/manage_test.go index 3afb4e1..8f495e5 100644 --- a/manage/manage_test.go +++ b/manage/manage_test.go @@ -82,7 +82,9 @@ func testManager(manager oauth2.Manager) { So(err, ShouldBeNil) So(rinfo.GetClientID(), ShouldEqual, atParams.ClientID) - rti, err := manager.RefreshAccessToken(refreshToken, "owner") + atParams.Refresh = refreshToken + atParams.Scope = "owner" + rti, err := manager.RefreshAccessToken(atParams) So(err, ShouldBeNil) refreshAT := rti.GetAccess() From 432c315663ceef361c5e41db52df4c9b4577c990 Mon Sep 17 00:00:00 2001 From: lyric Date: Fri, 8 Jul 2016 15:26:46 +0800 Subject: [PATCH 17/18] Remove example --- example/.gitkeep | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 example/.gitkeep diff --git a/example/.gitkeep b/example/.gitkeep deleted file mode 100644 index e69de29..0000000 From 82e5495a88767ae6cf8ba185628559bc8653b233 Mon Sep 17 00:00:00 2001 From: lyric Date: Sat, 9 Jul 2016 11:58:09 +0800 Subject: [PATCH 18/18] Add example --- README.md | 74 ++++++++++++++++++++++++++++++++++++++++-- example/README.md | 33 +++++++++++++++++++ example/client/main.go | 59 +++++++++++++++++++++++++++++++++ example/server/main.go | 51 +++++++++++++++++++++++++++++ manage/manager.go | 30 ++++++++++++++++- server/server.go | 19 ++++++++++- store/client/temp.go | 20 +++++++----- 7 files changed, 273 insertions(+), 13 deletions(-) create mode 100644 example/README.md create mode 100644 example/client/main.go create mode 100644 example/server/main.go diff --git a/README.md b/README.md index c93f8c9..c807e07 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ -Golang OAuth2 Server -===================== +OAuth2服务端 +=========== + +> 基于Golang实现的OAuth2协议,具有简单化、模块化的特点 [![GoDoc](https://godoc.org/gopkg.in/oauth2.v2?status.svg)](https://godoc.org/gopkg.in/oauth2.v2) [![Go Report Card](https://goreportcard.com/badge/gopkg.in/oauth2.v2)](https://goreportcard.com/report/gopkg.in/oauth2.v2) @@ -8,9 +10,75 @@ Golang OAuth2 Server ---- ```bash -$ go get -v gopkg.in/oauth2.v2 +$ go get -u gopkg.in/oauth2.v2/... +``` + +使用 +---- + +``` go +package main + +import ( + "log" + "net/http" + + "gopkg.in/oauth2.v2/manage" + "gopkg.in/oauth2.v2/models" + "gopkg.in/oauth2.v2/server" + "gopkg.in/oauth2.v2/store/client" + "gopkg.in/oauth2.v2/store/token" +) + +func main() { + manager := manage.NewRedisManager( + &token.RedisConfig{Addr: "192.168.33.70:6379"}, + ) + manager.MapClientStorage(client.NewTempStore()) + srv := server.NewServer(server.NewConfig(), manager) + + http.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) { + authReq, err := srv.GetAuthorizeRequest(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + // TODO: 登录验证、授权处理 + authReq.UserID = "000000" + + err = srv.HandleAuthorizeRequest(w, authReq) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + } + }) + + http.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + err := srv.HandleTokenRequest(w, r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + } + }) + + log.Fatal(http.ListenAndServe(":9096", nil)) +} + +``` + +测试 +---- + +``` bash +$ goconvey -port=9092 ``` +> goconvey使用明细[https://github.com/smartystreets/goconvey](https://github.com/smartystreets/goconvey) + +范例 +---- + +模拟授权码模式的测试范例,请查看[example](/example) + + License ------- diff --git a/example/README.md b/example/README.md new file mode 100644 index 0000000..2e40d62 --- /dev/null +++ b/example/README.md @@ -0,0 +1,33 @@ +OAuth2 服务端/客户端模拟 +===================== + +运行服务端 +-------- + +``` +$ cd example/server +$ go run main.go +``` + +运行客户端 +-------- + +``` +$ cd example/client +$ go run main.go +``` + +打开浏览器 +-------- + +[http://localhost:9094](http://localhost:9094) + +``` json +{ + "access_token": "143C1A45CFF9E0922F9DC68F7EBC81DC", + "expires_in": 7200, + "refresh_token": "5BD7453B8E7C5A3A308166F1675AD57216811391", + "scope": "all", + "token_type": "Bearer" +} +``` \ No newline at end of file diff --git a/example/client/main.go b/example/client/main.go new file mode 100644 index 0000000..031a194 --- /dev/null +++ b/example/client/main.go @@ -0,0 +1,59 @@ +package main + +import ( + "io" + "log" + "net/http" + "net/url" +) + +const ( + redirectURI = "http://localhost:9094/oauth2" + serverURI = "http://localhost:9096" +) + +func main() { + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + u, err := url.Parse(serverURI + "/authorize") + if err != nil { + panic(err) + } + q := u.Query() + q.Add("response_type", "code") + q.Add("client_id", "222222") + q.Add("scope", "all") + q.Add("state", "xyz") + q.Add("redirect_uri", url.QueryEscape(redirectURI)) + u.RawQuery = q.Encode() + http.Redirect(w, r, u.String(), http.StatusFound) + }) + + http.HandleFunc("/oauth2", func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + state := r.Form.Get("state") + if state != "xyz" { + http.Error(w, "State invalid", http.StatusBadRequest) + return + } + code := r.Form.Get("code") + if code == "" { + http.Error(w, "Code not found", http.StatusBadRequest) + return + } + uv := url.Values{} + uv.Add("code", code) + uv.Add("redirect_uri", redirectURI) + uv.Add("grant_type", "authorization_code") + uv.Add("client_id", "222222") + uv.Add("client_secret", "22222222") + resp, err := http.PostForm(serverURI+"/token", uv) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + io.Copy(w, resp.Body) + }) + + log.Println("OAuth2 client is running at 9094 port.") + log.Fatal(http.ListenAndServe(":9094", nil)) +} diff --git a/example/server/main.go b/example/server/main.go new file mode 100644 index 0000000..c82f4ff --- /dev/null +++ b/example/server/main.go @@ -0,0 +1,51 @@ +package main + +import ( + "log" + "net/http" + + "gopkg.in/oauth2.v2/manage" + "gopkg.in/oauth2.v2/models" + "gopkg.in/oauth2.v2/server" + "gopkg.in/oauth2.v2/store/client" + "gopkg.in/oauth2.v2/store/token" +) + +func main() { + // 创建基于redis的oauth2管理实例 + manager := manage.NewRedisManager( + &token.RedisConfig{Addr: "192.168.33.70:6379"}, + ) + // 使用临时客户端存储 + manager.MapClientStorage(client.NewTempStore(&models.Client{ + ID: "222222", + Secret: "22222222", + Domain: "http://localhost:9094", + })) + + srv := server.NewServer(server.NewConfig(), manager) + + http.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) { + authReq, err := srv.GetAuthorizeRequest(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + authReq.UserID = "000000" + // TODO: 登录验证、授权处理 + err = srv.HandleAuthorizeRequest(w, authReq) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + } + }) + + http.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + err := srv.HandleTokenRequest(w, r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + } + }) + + log.Println("OAuth2 server is running at 9096 port.") + log.Fatal(http.ListenAndServe(":9096", nil)) +} diff --git a/manage/manager.go b/manage/manager.go index 092b3ad..202aa38 100644 --- a/manage/manager.go +++ b/manage/manager.go @@ -5,12 +5,39 @@ import ( "github.com/LyricTian/inject" "gopkg.in/oauth2.v2" + "gopkg.in/oauth2.v2/generates" + "gopkg.in/oauth2.v2/models" + "gopkg.in/oauth2.v2/store/token" ) // Config 授权配置参数 type Config struct { TokenExp time.Duration // 令牌有效期 - RefreshExp time.Duration // g令牌有效期 + RefreshExp time.Duration // 更新令牌有效期 +} + +// NewRedisManager 创建基于redis存储的管理实例 +func NewRedisManager(redisCfg *token.RedisConfig) *Manager { + m := NewManager() + m.MapClientModel(models.NewClient()) + m.MapTokenModel(models.NewToken()) + m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) + m.MapAccessGenerate(generates.NewAccessGenerate()) + m.MustTokenStorage(token.NewRedisStore(redisCfg)) + + return m +} + +// NewMongoManager 创建基于mongodb存储的管理实例 +func NewMongoManager(mongoCfg *token.MongoConfig) *Manager { + m := NewManager() + m.MapClientModel(models.NewClient()) + m.MapTokenModel(models.NewToken()) + m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate()) + m.MapAccessGenerate(generates.NewAccessGenerate()) + m.MustTokenStorage(token.NewMongoStore(mongoCfg)) + + return m } // NewManager 创建Manager的实例 @@ -32,6 +59,7 @@ func NewManager() *Manager { m.SetGTConfig(oauth2.PasswordCredentials, &Config{TokenExp: time.Hour * 2, RefreshExp: time.Hour * 24 * 7}) // 设定客户端模式令牌的有效期为1小时 m.SetGTConfig(oauth2.ClientCredentials, &Config{TokenExp: time.Hour * 2}) + return m } diff --git a/server/server.go b/server/server.go index 050f923..590841b 100644 --- a/server/server.go +++ b/server/server.go @@ -12,10 +12,12 @@ import ( // NewServer 创建OAuth2服务实例 func NewServer(cfg *Config, manager oauth2.Manager) *Server { - return &Server{ + srv := &Server{ cfg: cfg, manager: manager, } + srv.SetClientHandler(ClientFormHandler) + return srv } // Server OAuth2服务处理 @@ -24,6 +26,21 @@ type Server struct { manager oauth2.Manager } +// SetTokenType 设置令牌类型 +func (s *Server) SetTokenType(tokenType string) { + s.cfg.TokenType = tokenType +} + +// SetAllowedResponseType 设置允许的授权类型 +func (s *Server) SetAllowedResponseType(allowedTypes ...oauth2.ResponseType) { + s.cfg.AllowedResponseType = allowedTypes +} + +// SetAllowedGrantType 允许的授权模式 +func (s *Server) SetAllowedGrantType(allowedTypes ...oauth2.GrantType) { + s.cfg.AllowedGrantType = allowedTypes +} + // SetClientHandler 设置客户端处理 func (s *Server) SetClientHandler(handler ClientHandler) { s.cfg.Handler.ClientHandler = handler diff --git a/store/client/temp.go b/store/client/temp.go index 07a00cf..c0e126c 100644 --- a/store/client/temp.go +++ b/store/client/temp.go @@ -8,16 +8,20 @@ import ( ) // NewTempStore 创建客户端临时存储实例 -func NewTempStore() oauth2.ClientStore { - return &TempStore{ - data: map[string]*models.Client{ - "1": &models.Client{ - ID: "1", - Secret: "11", - Domain: "http://localhost", - }, +func NewTempStore(clients ...*models.Client) oauth2.ClientStore { + data := map[string]*models.Client{ + "1": &models.Client{ + ID: "1", + Secret: "11", + Domain: "http://localhost", }, } + for _, cli := range clients { + data[cli.ID] = cli + } + return &TempStore{ + data: data, + } } // TempStore 客户端信息的临时存储