From d239726d3abb9b802f1a81d27d6ff3122712ca50 Mon Sep 17 00:00:00 2001 From: lyric Date: Thu, 26 May 2016 22:44:16 +0800 Subject: [PATCH 1/2] Initialize files --- README.md | 89 ++++++++++ authorizationCode.go | 153 +++++++++++++++++ authorizationCodeGenerate.go | 101 ++++++++++++ authorizationCodeGenerate_test.go | 39 +++++ authorizationCodeMemoryStore.go | 87 ++++++++++ authorizationCodeMemoryStore_test.go | 40 +++++ authorizationCodeStore.go | 28 ++++ authorizationCode_test.go | 46 ++++++ clientCredentials.go | 62 +++++++ clientCredentials_test.go | 27 +++ clientMongoStore.go | 49 ++++++ clientMongoStore_test.go | 21 +++ clientStore.go | 47 ++++++ config.go | 47 ++++++ const.go | 30 ++++ error.go | 37 +++++ implicit.go | 61 +++++++ implicit_test.go | 29 ++++ oauth2.go | 238 +++++++++++++++++++++++++++ oauth2_test.go | 40 +++++ password.go | 77 +++++++++ password_test.go | 34 ++++ tokenGenerate.go | 74 +++++++++ tokenGenerate_test.go | 36 ++++ tokenMongoStore.go | 87 ++++++++++ tokenMongoStore_test.go | 42 +++++ tokenStore.go | 41 +++++ util.go | 28 ++++ util_test.go | 16 ++ 29 files changed, 1706 insertions(+) create mode 100644 README.md create mode 100644 authorizationCode.go create mode 100644 authorizationCodeGenerate.go create mode 100644 authorizationCodeGenerate_test.go create mode 100644 authorizationCodeMemoryStore.go create mode 100644 authorizationCodeMemoryStore_test.go create mode 100644 authorizationCodeStore.go create mode 100644 authorizationCode_test.go create mode 100644 clientCredentials.go create mode 100644 clientCredentials_test.go create mode 100644 clientMongoStore.go create mode 100644 clientMongoStore_test.go create mode 100644 clientStore.go create mode 100644 config.go create mode 100644 const.go create mode 100644 error.go create mode 100644 implicit.go create mode 100644 implicit_test.go create mode 100644 oauth2.go create mode 100644 oauth2_test.go create mode 100644 password.go create mode 100644 password_test.go create mode 100644 tokenGenerate.go create mode 100644 tokenGenerate_test.go create mode 100644 tokenMongoStore.go create mode 100644 tokenMongoStore_test.go create mode 100644 tokenStore.go create mode 100644 util.go create mode 100644 util_test.go diff --git a/README.md b/README.md new file mode 100644 index 0000000..3808a4f --- /dev/null +++ b/README.md @@ -0,0 +1,89 @@ +Golang OAuth 2.0 +================ + +[![GoDoc](https://godoc.org/gopkg.in/oauth2.v1?status.svg)](https://godoc.org/gopkg.in/oauth2.v1) +[![Go Report Card](https://goreportcard.com/badge/gopkg.in/oauth2.v1)](https://goreportcard.com/report/gopkg.in/oauth2.v1) + +> 基于Golang实现的OAuth 2.0协议相关操作,包括:令牌(或授权码)的生成、存储、验证操作以及更新令牌、废除令牌; 具有简单、灵活的特点; 其中所涉及的相关http请求操作在这里不做处理; 支持授权码模式、简化模式、密码模式、客户端模式; 默认使用MongoDB存储相关信息 + +获取 +---- + +```bash +$ go get -v gopkg.in/oauth2.v1 +``` + +范例 +---- + +> 数据初始化:初始化相关的客户端信息 + +```go +package main + +import ( + "fmt" + + "gopkg.in/oauth2.v1" +) + +func main() { + mongoConfig := oauth2.NewMongoConfig("mongodb://127.0.0.1:27017", "test") + // 创建默认的OAuth2管理实例(基于MongoDB) + manager, err := oauth2.CreateDefaultOAuthManager(mongoConfig, "", "", nil) + if err != nil { + panic(err) + } + // 模拟授权码模式 + // 使用默认参数,生成授权码 + code, err := manager.GetACManager(). + GenerateCode("clientID_x", "userID_x", "http://www.example.com/cb", "scopes") + if err != nil { + panic(err) + } + // 生成访问令牌及更新令牌 + genToken, err := manager.GetACManager(). + GenerateToken(code, "http://www.example.com/cb", "clientID_x", "clientSecret_x", true) + if err != nil { + panic(err) + } + // 检查访问令牌 + checkToken, err := manager.CheckAccessToken(genToken.AccessToken) + if err != nil { + panic(err) + } + // TODO: 使用用户标识、申请的授权范围响应数据 + fmt.Println(checkToken.UserID, checkToken.Scope) + // 申请一个新的访问令牌 + newToken, err := manager.RefreshAccessToken(checkToken.RefreshToken, "scopes") + if err != nil { + panic(err) + } + fmt.Println(newToken.AccessToken, newToken.ATExpiresIn) + // TODO: 将新的访问令牌响应给客户端 +} +``` + +执行测试 +---- + +```bash +$ go test -v +# 或 +$ goconvey --port=9090 +``` + +License +------- + +``` +Copyright 2016.All rights reserved. +``` + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + +``` + http://www.apache.org/licenses/LICENSE-2.0 +``` + +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. diff --git a/authorizationCode.go b/authorizationCode.go new file mode 100644 index 0000000..6f48e33 --- /dev/null +++ b/authorizationCode.go @@ -0,0 +1,153 @@ +package oauth2 + +import ( + "time" + + "gopkg.in/LyricTian/lib.v2" +) + +// NewACManager 创建授权码模式管理实例 +// oaManager OAuth授权管理 +// config 配置参数(nil则使用默认值) +func NewACManager(oaManager *OAuthManager, config *ACConfig) *ACManager { + if config == nil { + config = new(ACConfig) + } + if config.RandomCodeLen == 0 { + config.RandomCodeLen = DefaultRandomCodeLen + } + if config.ACExpiresIn == 0 { + config.ACExpiresIn = DefaultACExpiresIn + } + if config.ATExpiresIn == 0 { + config.ATExpiresIn = DefaultATExpiresIn + } + if config.RTExpiresIn == 0 { + config.RTExpiresIn = DefaultRTExpiresIn + } + acManager := &ACManager{ + oAuthManager: oaManager, + config: config, + } + return acManager +} + +// ACManager 授权码模式管理(Authorization Code Manager) +type ACManager struct { + oAuthManager *OAuthManager // 授权管理 + config *ACConfig // 配置参数 +} + +// GenerateCode 生成授权码 +// clientID 客户端标识 +// userID 用户标识 +// redirectURI 重定向URI +// scopes 应用授权标识 +func (am *ACManager) GenerateCode(clientID, userID, redirectURI, scopes string) (code string, err error) { + cli, err := am.oAuthManager.ValidateClient(clientID, redirectURI) + if err != nil { + return + } + acInfo := ACInfo{ + ClientID: cli.ID(), + UserID: userID, + RedirectURI: redirectURI, + Scope: scopes, + Code: lib.NewRandom(am.config.RandomCodeLen).NumberAndLetter(), + CreateAt: time.Now().Unix(), + ExpiresIn: time.Duration(am.config.ACExpiresIn) * time.Second, + } + id, err := am.oAuthManager.ACStore.Put(acInfo) + if err != nil { + return + } + acInfo.ID = id + code, err = am.oAuthManager.ACGenerate.Code(&acInfo) + return +} + +// GenerateToken 生成令牌 +// code 授权码 +// redirectURI 重定向URI +// clientID 客户端标识 +// clientSecret 客户端秘钥 +// isGenerateRefresh 是否生成更新令牌 +func (am *ACManager) GenerateToken(code, redirectURI, clientID, clientSecret string, isGenerateRefresh bool) (token *Token, err error) { + acInfo, err := am.getACInfo(code) + if err != nil { + return + } else if acInfo.RedirectURI != redirectURI { + err = ErrACInvalid + return + } else if acInfo.ClientID != clientID { + err = ErrACInvalid + return + } + cli, err := am.oAuthManager.ClientStore.GetByID(acInfo.ClientID) + if err != nil { + return + } else if clientSecret != cli.Secret() { + err = ErrCSInvalid + return + } + createAt := time.Now().Unix() + basicInfo := NewTokenBasicInfo(cli, acInfo.UserID, createAt) + atValue, err := am.oAuthManager.TokenGenerate.AccessToken(basicInfo) + if err != nil { + return + } + tokenValue := Token{ + ClientID: acInfo.ClientID, + UserID: acInfo.UserID, + AccessToken: atValue, + ATCreateAt: createAt, + ATExpiresIn: time.Duration(am.config.ATExpiresIn) * time.Second, + Scope: acInfo.Scope, + CreateAt: createAt, + Status: Actived, + } + if isGenerateRefresh { + rtValue, rtErr := am.oAuthManager.TokenGenerate.RefreshToken(basicInfo) + if rtErr != nil { + err = rtErr + return + } + tokenValue.RefreshToken = rtValue + tokenValue.RTCreateAt = createAt + tokenValue.RTExpiresIn = time.Duration(am.config.RTExpiresIn) * time.Second + } + id, err := am.oAuthManager.TokenStore.Create(tokenValue) + if err != nil { + return + } + tokenValue.ID = id + token = &tokenValue + return +} + +// getACInfo 根据授权码获取授权信息 +func (am *ACManager) getACInfo(code string) (info *ACInfo, err error) { + if code == "" { + err = ErrACNotFound + return + } + acID, err := am.oAuthManager.ACGenerate.Parse(code) + if err != nil { + return + } + acInfo, err := am.oAuthManager.ACStore.TakeByID(acID) + if err != nil { + return + } + acValid, err := am.oAuthManager.ACGenerate.Verify(code, acInfo) + if err != nil { + return + } + if !acValid || + (acInfo.CreateAt+int64(acInfo.ExpiresIn/time.Second)) < time.Now().Unix() { + err = ErrACInvalid + return + } + info = acInfo + return +} diff --git a/authorizationCodeGenerate.go b/authorizationCodeGenerate.go new file mode 100644 index 0000000..1414998 --- /dev/null +++ b/authorizationCodeGenerate.go @@ -0,0 +1,101 @@ +package oauth2 + +import ( + "bytes" + "encoding/base64" + "errors" + "strconv" + "strings" + + "gopkg.in/LyricTian/lib.v2" +) + +// ACGenerate 授权码生成接口(Authorization Code Generate) +type ACGenerate interface { + // Code 根据授权码相关信息生成授权码 + Code(info *ACInfo) (string, error) + + // Parse 解析授权码,返回授权信息ID + Parse(code string) (int64, error) + + // Verify 验证授权码的有效性 + Verify(code string, info *ACInfo) (bool, error) +} + +// NewDefaultACGenerate 创建默认的授权码生成方式 +func NewDefaultACGenerate() ACGenerate { + return &ACGenerateDefault{} +} + +// ACGenerateDefault 默认的授权码生成方式 +type ACGenerateDefault struct{} + +func (ag *ACGenerateDefault) genToken(info *ACInfo) (string, error) { + var buf bytes.Buffer + _, _ = buf.WriteString(info.ClientID) + _ = buf.WriteByte('_') + _, _ = buf.WriteString(info.UserID) + _ = buf.WriteByte('\n') + _, _ = buf.WriteString(strconv.FormatInt(info.CreateAt, 10)) + _ = buf.WriteByte('\n') + _, _ = buf.WriteString(info.Code) + md5Val, err := lib.NewEncryption(buf.Bytes()).MD5() + if err != nil { + return "", err + } + buf.Reset() + md5Val = md5Val[:15] + return md5Val, nil +} + +// Code Authorization code +func (ag *ACGenerateDefault) Code(info *ACInfo) (string, error) { + tokenVal, err := ag.genToken(info) + if err != nil { + return "", err + } + val := base64.URLEncoding.EncodeToString([]byte(tokenVal + "." + strconv.FormatInt(info.ID, 10))) + return strings.TrimRight(val, "="), nil +} + +func (ag *ACGenerateDefault) parse(code string) (id int64, token string, err error) { + codeLen := len(code) % 4 + if codeLen > 0 { + codeLen = 4 - codeLen + } + code = code + strings.Repeat("=", codeLen) + codeVal, err := base64.URLEncoding.DecodeString(code) + if err != nil { + return + } + tokenVal := strings.SplitN(string(codeVal), ".", 2) + if len(tokenVal) != 2 { + err = errors.New("Token is invalid") + return + } + id, err = strconv.ParseInt(tokenVal[1], 10, 64) + if err != nil { + return + } + token = tokenVal[0] + return +} + +// Parse Parse authorization code +func (ag *ACGenerateDefault) Parse(code string) (id int64, err error) { + id, _, err = ag.parse(code) + return +} + +// Verify Verify code +func (ag *ACGenerateDefault) Verify(code string, info *ACInfo) (valid bool, err error) { + _, token, err := ag.parse(code) + if err != nil { + return + } + tokenVal, err := ag.genToken(info) + if err != nil { + return + } + return token == tokenVal, nil +} diff --git a/authorizationCodeGenerate_test.go b/authorizationCodeGenerate_test.go new file mode 100644 index 0000000..9c07c82 --- /dev/null +++ b/authorizationCodeGenerate_test.go @@ -0,0 +1,39 @@ +package oauth2_test + +import ( + "testing" + "time" + + "gopkg.in/LyricTian/lib.v2" + "gopkg.in/oauth2.v1" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestACGenerate(t *testing.T) { + Convey("Authorization code generate test", t, func() { + acGenerate := oauth2.NewDefaultACGenerate() + info := &oauth2.ACInfo{ + ID: 1, + ClientID: "123456", + UserID: "999999", + Code: lib.NewRandom(6).NumberAndLetter(), + CreateAt: time.Now().Unix(), + } + Convey("Generate code", func() { + code, err := acGenerate.Code(info) + So(err, ShouldBeNil) + So(code, ShouldNotBeBlank) + Convey("Parse code", func() { + id, err := acGenerate.Parse(code) + So(err, ShouldBeNil) + So(id, ShouldEqual, 1) + }) + Convey("Verify code", func() { + valid, err := acGenerate.Verify(code, info) + So(err, ShouldBeNil) + So(valid, ShouldBeTrue) + }) + }) + }) +} diff --git a/authorizationCodeMemoryStore.go b/authorizationCodeMemoryStore.go new file mode 100644 index 0000000..87f3fb4 --- /dev/null +++ b/authorizationCodeMemoryStore.go @@ -0,0 +1,87 @@ +package oauth2 + +import ( + "container/list" + "errors" + "sync" + "sync/atomic" + "time" +) + +// NewACMemoryStore 创建授权码的内存存储 +// gcInterval GC周期(单位秒,默认60秒执行一次) +func NewACMemoryStore(gcInterval int64) ACStore { + if gcInterval == 0 { + gcInterval = 60 + } + memStore := &ACMemoryStore{ + gcInterval: time.Second * time.Duration(gcInterval), + data: list.New(), + } + go memStore.gc() + return memStore +} + +// ACMemoryStore 提供授权码的内存存储 +type ACMemoryStore struct { + sync.RWMutex + globalID int64 + gcInterval time.Duration + data *list.List +} + +func (am *ACMemoryStore) gc() { + time.AfterFunc(am.gcInterval, func() { + defer am.gc() + for { + am.RLock() + ele := am.data.Front() + if ele == nil { + am.RUnlock() + break + } + item := ele.Value.(ACInfo) + am.RUnlock() + if (item.CreateAt + int64(item.ExpiresIn/time.Second)) < time.Now().Unix() { + am.Lock() + am.data.Remove(ele) + am.Unlock() + continue + } + break + } + }) +} + +// Put Put item +func (am *ACMemoryStore) Put(item ACInfo) (int64, error) { + am.Lock() + defer am.Unlock() + atomic.AddInt64(&am.globalID, 1) + item.ID = am.globalID + am.data.PushBack(item) + return item.ID, nil +} + +// TakeByID Take item by ID +func (am *ACMemoryStore) TakeByID(id int64) (*ACInfo, error) { + am.RLock() + var takeEle *list.Element + for ele := am.data.Back(); ele != nil; ele = ele.Prev() { + item := ele.Value.(ACInfo) + if item.ID == id { + takeEle = ele + break + } + } + if takeEle == nil { + am.RUnlock() + return nil, errors.New("Item not found") + } + item := takeEle.Value.(ACInfo) + am.RUnlock() + am.Lock() + am.data.Remove(takeEle) + am.Unlock() + return &item, nil +} diff --git a/authorizationCodeMemoryStore_test.go b/authorizationCodeMemoryStore_test.go new file mode 100644 index 0000000..6c86d02 --- /dev/null +++ b/authorizationCodeMemoryStore_test.go @@ -0,0 +1,40 @@ +package oauth2_test + +import ( + "testing" + "time" + + "gopkg.in/oauth2.v1" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestACMemoryStore(t *testing.T) { + Convey("AC memory store test", t, func() { + store := oauth2.NewACMemoryStore(1) + item := oauth2.ACInfo{ + ClientID: "123456", + UserID: "999999", + CreateAt: time.Now().Unix(), + ExpiresIn: time.Millisecond * 500, + } + Convey("Put Test", func() { + id, err := store.Put(item) + So(err, ShouldBeNil) + So(id, ShouldEqual, 1) + item.ID = id + Convey("Take Test", func() { + info, err := store.TakeByID(id) + So(err, ShouldBeNil) + So(info.ClientID, ShouldEqual, item.ClientID) + So(info.UserID, ShouldEqual, item.UserID) + }) + Convey("Take GC Test", func() { + time.Sleep(time.Second * 2) + info, err := store.TakeByID(id) + So(err, ShouldNotBeNil) + So(info, ShouldBeNil) + }) + }) + }) +} diff --git a/authorizationCodeStore.go b/authorizationCodeStore.go new file mode 100644 index 0000000..67658e3 --- /dev/null +++ b/authorizationCodeStore.go @@ -0,0 +1,28 @@ +package oauth2 + +import ( + "time" +) + +// ACInfo 授权码信息(Authorization Code Info) +type ACInfo struct { + ID int64 // 唯一标识 + ClientID string // 客户端标识 + UserID string // 用户标识 + RedirectURI string // 重定向URI + Scope string // 申请的权限范围 + Code string // 随机码 + CreateAt int64 // 创建时间(时间戳) + ExpiresIn time.Duration // 有效期(单位秒) +} + +// ACStore 授权码存储接口(临时存储,提供自动GC过期的元素)(Authorization Code Store) +type ACStore interface { + // Put 将元素放入存储,返回存储的ID + // 如果存储发生异常,则返回错误 + Put(item ACInfo) (int64, error) + + // TakeByID 根据ID取出元素 + // 如果元素找不到或发生异常,则返回错误 + TakeByID(id int64) (*ACInfo, error) +} diff --git a/authorizationCode_test.go b/authorizationCode_test.go new file mode 100644 index 0000000..e9bc583 --- /dev/null +++ b/authorizationCode_test.go @@ -0,0 +1,46 @@ +package oauth2_test + +import ( + "testing" + + "gopkg.in/oauth2.v1" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestACManager(t *testing.T) { + ClientHandle(func(info oauth2.Client) { + userID := "999999" + oManager, err := oauth2.CreateDefaultOAuthManager(oauth2.NewMongoConfig(MongoURL, DBName), "", "", nil) + if err != nil { + t.Error(err) + } + Convey("Authorization Code Manager Test", t, func() { + manager := oManager.GetACManager() + + redirectURI := "http://www.example.com/cb" + code, err := manager.GenerateCode(info.ID(), userID, redirectURI, "all") + So(err, ShouldBeNil) + + accessToken, err := manager.GenerateToken(code, redirectURI, info.ID(), info.Secret(), true) + So(err, ShouldBeNil) + So(accessToken.UserID, ShouldEqual, userID) + + checkAT, err := oManager.CheckAccessToken(accessToken.AccessToken) + So(err, ShouldBeNil) + So(checkAT.ClientID, ShouldEqual, info.ID()) + So(checkAT.UserID, ShouldEqual, userID) + + newAT, err := oManager.RefreshAccessToken(checkAT.RefreshToken, "") + So(err, ShouldBeNil) + So(newAT.AccessToken, ShouldNotEqual, checkAT.AccessToken) + + err = oManager.RevokeAccessToken(newAT.AccessToken) + So(err, ShouldBeNil) + + checkAT, err = oManager.CheckAccessToken(newAT.AccessToken) + So(err, ShouldNotBeNil) + So(checkAT, ShouldBeNil) + }) + }) +} diff --git a/clientCredentials.go b/clientCredentials.go new file mode 100644 index 0000000..a4f2f5b --- /dev/null +++ b/clientCredentials.go @@ -0,0 +1,62 @@ +package oauth2 + +import "time" + +// NewCCManager 创建默认的客户端模式管理实例 +// oaManager OAuth授权管理 +// config 配置参数(nil则使用默认值) +func NewCCManager(oaManager *OAuthManager, config *CCConfig) *CCManager { + if config == nil { + config = new(CCConfig) + } + if config.ATExpiresIn == 0 { + config.ATExpiresIn = DefaultCCATExpiresIn + } + ccManager := &CCManager{ + oAuthManager: oaManager, + config: config, + } + return ccManager +} + +// CCManager 客户端模式管理(Client Credentials Manager) +type CCManager struct { + oAuthManager *OAuthManager // 授权管理 + config *CCConfig // 配置参数 +} + +// GenerateToken 生成令牌(只生成访问令牌) +// clientID 客户端标识 +// clientSecret 客户端秘钥 +// scopes 应用授权标识 +func (cm *CCManager) GenerateToken(clientID, clientSecret, scopes string) (token *Token, err error) { + cli, err := cm.oAuthManager.GetClient(clientID) + if err != nil { + return + } else if cli.Secret() != clientSecret { + err = ErrCSInvalid + return + } + createAt := time.Now().Unix() + basicInfo := NewTokenBasicInfo(cli, "", createAt) + atValue, err := cm.oAuthManager.TokenGenerate.AccessToken(basicInfo) + if err != nil { + return + } + tokenValue := Token{ + ClientID: clientID, + AccessToken: atValue, + ATCreateAt: createAt, + ATExpiresIn: time.Duration(cm.config.ATExpiresIn) * time.Second, + Scope: scopes, + CreateAt: createAt, + Status: Actived, + } + id, err := cm.oAuthManager.TokenStore.Create(tokenValue) + if err != nil { + return + } + tokenValue.ID = id + token = &tokenValue + return +} diff --git a/clientCredentials_test.go b/clientCredentials_test.go new file mode 100644 index 0000000..528e2c4 --- /dev/null +++ b/clientCredentials_test.go @@ -0,0 +1,27 @@ +package oauth2_test + +import ( + . "github.com/smartystreets/goconvey/convey" + "gopkg.in/oauth2.v1" + + "testing" +) + +func TestCCManager(t *testing.T) { + ClientHandle(func(cli oauth2.Client) { + oManager, err := oauth2.CreateDefaultOAuthManager(oauth2.NewMongoConfig(MongoURL, DBName), "", "", nil) + if err != nil { + t.Error(err) + } + Convey("Client Credentials Manager Test", t, func() { + manager := oManager.GetCCManager() + + token, err := manager.GenerateToken(cli.ID(), cli.Secret(), "all") + So(err, ShouldBeNil) + + checkToken, err := oManager.CheckAccessToken(token.AccessToken) + So(err, ShouldBeNil) + So(checkToken.ClientID, ShouldEqual, cli.ID()) + }) + }) +} diff --git a/clientMongoStore.go b/clientMongoStore.go new file mode 100644 index 0000000..fe2224a --- /dev/null +++ b/clientMongoStore.go @@ -0,0 +1,49 @@ +package oauth2 + +import ( + "gopkg.in/LyricTian/lib.v2/mongo" + "gopkg.in/mgo.v2" +) + +const ( + // DefaultClientCollectionName 默认的客户端存储集合名称 + DefaultClientCollectionName = "ClientInfo" +) + +// NewClientMongoStore 创建基于MongoDB的客户端存储方式 +// mongoConfig MongoDB配置参数 +// cName 存储客户端的集合名称(默认为ClientInfo) +func NewClientMongoStore(mongoConfig *MongoConfig, cName string) (ClientStore, error) { + mHandler, err := mongo.InitHandlerWithDB(mongoConfig.URL, mongoConfig.DBName) + if err != nil { + return nil, err + } + if cName == "" { + cName = DefaultClientCollectionName + } + return &ClientMongoStore{ + cName: cName, + mHandler: mHandler, + }, nil +} + +// ClientMongoStore 基于MongoDB的默认客户端信息存储 +type ClientMongoStore struct { + cName string + mHandler *mongo.Handler +} + +// GetByID 根据ID获取客户端信息 +func (dcm *ClientMongoStore) GetByID(id string) (client Client, err error) { + dcm.mHandler.CHandle(dcm.cName, func(c *mgo.Collection) { + var result []DefaultClient + err = dcm.mHandler.C(dcm.cName).FindId(id).Limit(1).All(&result) + if err != nil { + return + } + if len(result) > 0 { + client = result[0] + } + }) + return +} diff --git a/clientMongoStore_test.go b/clientMongoStore_test.go new file mode 100644 index 0000000..92c0b78 --- /dev/null +++ b/clientMongoStore_test.go @@ -0,0 +1,21 @@ +package oauth2_test + +import ( + "testing" + + "gopkg.in/oauth2.v1" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestClientMongoStore(t *testing.T) { + ClientHandle(func(info oauth2.Client) { + Convey("Client mongodb store test", t, func() { + clientStore, err := oauth2.NewClientMongoStore(oauth2.NewMongoConfig(MongoURL, DBName), "") + So(err, ShouldBeNil) + client, err := clientStore.GetByID(info.ID()) + So(err, ShouldBeNil) + So(client.Secret(), ShouldEqual, info.Secret()) + }) + }) +} diff --git a/clientStore.go b/clientStore.go new file mode 100644 index 0000000..e4504af --- /dev/null +++ b/clientStore.go @@ -0,0 +1,47 @@ +package oauth2 + +// Client 客户端的验证信息接口 +type Client interface { + // ID 客户端唯一标识 + ID() string + // Secret 客户端秘钥 + Secret() string + // Domain 客户端域名 + Domain() string + // RetainData 保留数据 + RetainData() interface{} +} + +// ClientStore 客户端存储接口(持久化存储) +type ClientStore interface { + // GetByID 根据ID获取客户端信息; + // 如果客户端不存在则返回nil + GetByID(id string) (Client, error) +} + +// DefaultClient 默认的客户端信息 +type DefaultClient struct { + ClientID string `bson:"_id"` // 客户端唯一标识 + ClientSecret string `bson:"Secret"` // 客户端秘钥 + ClientDomain string `bson:"Domain"` // 客户端域名 +} + +// ID Get ClientID +func (dc DefaultClient) ID() string { + return dc.ClientID +} + +// Secret Get ClientSecret +func (dc DefaultClient) Secret() string { + return dc.ClientSecret +} + +// Domain Get ClientDomain +func (dc DefaultClient) Domain() string { + return dc.ClientDomain +} + +// RetainData Get retain data +func (dc DefaultClient) RetainData() interface{} { + return dc +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..9943712 --- /dev/null +++ b/config.go @@ -0,0 +1,47 @@ +package oauth2 + +// MongoConfig MongoDB配置参数 +type MongoConfig struct { + URL string // MongoDB连接字符串 + DBName string // 数据库名称 +} + +// NewMongoConfig 创建MongoDB配置参数的实例 +func NewMongoConfig(url, dbName string) *MongoConfig { + return &MongoConfig{ + URL: url, + DBName: dbName, + } +} + +// ACConfig 授权码模式配置参数(Authorization Code Config) +type ACConfig struct { + RandomCodeLen int // 随机码的长度(用于生成授权码的随机码) + ACExpiresIn int64 // 授权码有效期(单位秒) + ATExpiresIn int64 // 访问令牌有效期(单位秒) + RTExpiresIn int64 // 更新令牌有效期(单位秒) +} + +// ImplicitConfig 简化模式配置参数 +type ImplicitConfig struct { + ATExpiresIn int64 // 访问令牌有效期(单位秒) +} + +// PasswordConfig 密码模式配置参数 +type PasswordConfig struct { + ATExpiresIn int64 // 访问令牌有效期(单位秒) + RTExpiresIn int64 // 更新令牌有效期(单位秒) +} + +// CCConfig 客户端模式配置参数(Client Credentials Config) +type CCConfig struct { + ATExpiresIn int64 // 访问令牌有效期(单位秒) +} + +// OAuthConfig OAuth授权配置参数 +type OAuthConfig struct { + ACConfig *ACConfig // 授权码模式配置参数 + ImplicitConfig *ImplicitConfig // 简化模式配置参数 + PasswordConfig *PasswordConfig // 密码模式配置参数 + CCConfig *CCConfig // 客户端模式配置参数 +} diff --git a/const.go b/const.go new file mode 100644 index 0000000..97b991c --- /dev/null +++ b/const.go @@ -0,0 +1,30 @@ +package oauth2 + +const ( + // DefaultRandomCodeLen 默认随机码的长度 + DefaultRandomCodeLen = 6 + // DefaultACExpiresIn 默认授权码模式的授权码有效期(10分钟) + DefaultACExpiresIn = 60 * 10 + // DefaultATExpiresIn 默认授权码模式的访问令牌有效期(7天) + DefaultATExpiresIn = 60 * 60 * 24 * 7 + // DefaultRTExpiresIn 默认授权码模式的更新令牌有效期(30天) + DefaultRTExpiresIn = 60 * 60 * 24 * 30 + // DefaultIATExpiresIn 默认简化模式的访问令牌有效期(1小时) + DefaultIATExpiresIn = 60 * 60 + // DefaultCCATExpiresIn 默认客户端模式的访问令牌有效期(1天) + DefaultCCATExpiresIn = 60 * 60 * 24 +) + +// STATUS 提供一些状态标识 +type STATUS byte + +const ( + // Deleted 删除状态 + Deleted STATUS = iota + // Actived 激活状态 + Actived + // Blocked 冻结状态 + Blocked + // Expired 过期状态 + Expired +) diff --git a/error.go b/error.go new file mode 100644 index 0000000..18ecf25 --- /dev/null +++ b/error.go @@ -0,0 +1,37 @@ +package oauth2 + +import ( + "errors" +) + +var ( + // ErrClientNotFound Client not found + ErrClientNotFound = errors.New("The client is not found.") + + // ErrACNotFound Authorization code not found + ErrACNotFound = errors.New("The authorization code is not found.") + + // ErrACInvalid Authorization code invalid + ErrACInvalid = errors.New("The authorization code is invalid.") + + // ErrCSInvalid Client secret invalid + ErrCSInvalid = errors.New("The client secret is invalid.") + + // ErrATNotFound Refresh token not found + ErrATNotFound = errors.New("The access token is not found.") + + // ErrATInvalid Access token invalid + ErrATInvalid = errors.New("The access token is invalid.") + + // ErrATExpire Access token expire + ErrATExpire = errors.New("The access token is expire.") + + // ErrRTNotFound Refresh token not found + ErrRTNotFound = errors.New("The refresh token is not found.") + + // ErrRTInvalid Refresh token invalid + ErrRTInvalid = errors.New("The refresh token is invalid.") + + // ErrRTExpire Refresh token expire + ErrRTExpire = errors.New("The refresh token is expire.") +) diff --git a/implicit.go b/implicit.go new file mode 100644 index 0000000..a07e7ca --- /dev/null +++ b/implicit.go @@ -0,0 +1,61 @@ +package oauth2 + +import "time" + +// NewImplicitManager 创建默认的简化模式管理实例 +// oaManager OAuth授权管理 +// config 配置参数(nil则使用默认值) +func NewImplicitManager(oaManager *OAuthManager, config *ImplicitConfig) *ImplicitManager { + if config == nil { + config = new(ImplicitConfig) + } + if config.ATExpiresIn == 0 { + config.ATExpiresIn = DefaultIATExpiresIn + } + iManager := &ImplicitManager{ + oAuthManager: oaManager, + config: config, + } + return iManager +} + +// ImplicitManager 简化模式管理 +type ImplicitManager struct { + oAuthManager *OAuthManager // 授权管理 + config *ImplicitConfig // 配置参数 +} + +// GenerateToken 生成令牌(只生成访问令牌) +// clientID 客户端标识 +// userID 用户标识 +// redirectURI 重定向URI +// scopes 应用授权标识 +func (im *ImplicitManager) GenerateToken(clientID, userID, redirectURI, scopes string) (token *Token, err error) { + cli, err := im.oAuthManager.ValidateClient(clientID, redirectURI) + if err != nil { + return + } + createAt := time.Now().Unix() + basicInfo := NewTokenBasicInfo(cli, userID, createAt) + atValue, err := im.oAuthManager.TokenGenerate.AccessToken(basicInfo) + if err != nil { + return + } + tokenValue := Token{ + ClientID: clientID, + UserID: userID, + AccessToken: atValue, + ATCreateAt: createAt, + ATExpiresIn: time.Duration(im.config.ATExpiresIn) * time.Second, + Scope: scopes, + CreateAt: createAt, + Status: Actived, + } + id, err := im.oAuthManager.TokenStore.Create(tokenValue) + if err != nil { + return + } + tokenValue.ID = id + token = &tokenValue + return +} diff --git a/implicit_test.go b/implicit_test.go new file mode 100644 index 0000000..4c454ea --- /dev/null +++ b/implicit_test.go @@ -0,0 +1,29 @@ +package oauth2_test + +import ( + . "github.com/smartystreets/goconvey/convey" + "gopkg.in/oauth2.v1" + + "testing" +) + +func TestImplicitManager(t *testing.T) { + ClientHandle(func(cli oauth2.Client) { + userID := "999999" + oManager, err := oauth2.CreateDefaultOAuthManager(oauth2.NewMongoConfig(MongoURL, DBName), "", "", nil) + if err != nil { + t.Error(err) + } + Convey("Implicit Manager Test", t, func() { + manager := oManager.GetImplicitManager() + + token, err := manager.GenerateToken(cli.ID(), userID, "http://www.example.com/cb", "all") + So(err, ShouldBeNil) + + checkToken, err := oManager.CheckAccessToken(token.AccessToken) + So(err, ShouldBeNil) + So(checkToken.ClientID, ShouldEqual, cli.ID()) + So(checkToken.UserID, ShouldEqual, userID) + }) + }) +} diff --git a/oauth2.go b/oauth2.go new file mode 100644 index 0000000..9df2aa5 --- /dev/null +++ b/oauth2.go @@ -0,0 +1,238 @@ +package oauth2 + +import "time" + +// CreateDefaultOAuthManager 创建默认的OAuth授权管理实例 +// mongoConfig MongoDB配置参数 +// tokenCollectionName 存储令牌的集合名称(默认为AuthToken) +// clientCollectionName 存储客户端的集合名称(默认为ClientInfo) +// oauthConfig 配置参数 +func CreateDefaultOAuthManager(mongoConfig *MongoConfig, tokenCollectionName, clientCollectionName string, oauthConfig *OAuthConfig) (*OAuthManager, error) { + if oauthConfig == nil { + oauthConfig = new(OAuthConfig) + } + oaManager := &OAuthManager{ + Config: oauthConfig, + ACGenerate: NewDefaultACGenerate(), + ACStore: NewACMemoryStore(0), + TokenGenerate: NewDefaultTokenGenerate(), + } + tokenStore, err := NewTokenMongoStore(mongoConfig, tokenCollectionName) + if err != nil { + return nil, err + } + oaManager.TokenStore = tokenStore + clientStore, err := NewClientMongoStore(mongoConfig, clientCollectionName) + if err != nil { + return nil, err + } + oaManager.ClientStore = clientStore + return oaManager, nil +} + +// OAuthManager OAuth授权管理 +type OAuthManager struct { + Config *OAuthConfig // 配置参数 + ACGenerate ACGenerate // 授权码生成 + ACStore ACStore // 授权码存储 + TokenGenerate TokenGenerate // 令牌生成 + TokenStore TokenStore // 令牌存储 + ClientStore ClientStore // 客户端存储 +} + +// SetACGenerate 设置授权码生成接口 +func (om *OAuthManager) SetACGenerate(generate ACGenerate) { + om.ACGenerate = generate +} + +// SetACStore 设置授权码存储接口 +func (om *OAuthManager) SetACStore(store ACStore) { + om.ACStore = store +} + +// GetACManager 获取授权码模式管理实例 +func (om *OAuthManager) GetACManager() *ACManager { + return NewACManager(om, om.Config.ACConfig) +} + +// GetImplicitManager 获取简化模式管理实例 +func (om *OAuthManager) GetImplicitManager() *ImplicitManager { + return NewImplicitManager(om, om.Config.ImplicitConfig) +} + +// GetPasswordManager 获取密码模式管理实例 +func (om *OAuthManager) GetPasswordManager() *PasswordManager { + return NewPasswordManager(om, om.Config.PasswordConfig) +} + +// GetCCManager 获取客户端模式管理实例 +func (om *OAuthManager) GetCCManager() *CCManager { + return NewCCManager(om, om.Config.CCConfig) +} + +// GetClient 根据客户端标识获取客户端信息 +// clientID 客户端标识 +func (om *OAuthManager) GetClient(clientID string) (cli Client, err error) { + cli, err = om.ClientStore.GetByID(clientID) + if err != nil { + return + } else if cli == nil { + err = ErrClientNotFound + } + return +} + +// ValidateClient 验证客户端的重定向URI +// clientID 客户端标识 +// redirectURI 重定向URI +func (om *OAuthManager) ValidateClient(clientID, redirectURI string) (cli Client, err error) { + cli, err = om.GetClient(clientID) + if err != nil { + return + } else if v := ValidateURI(cli.Domain(), redirectURI); v != nil { + err = v + } + return +} + +// CheckAccessToken 检查访问令牌是否可用,同时返回该令牌的相关信息 +// accessToken 访问令牌 +func (om *OAuthManager) CheckAccessToken(accessToken string) (token *Token, err error) { + if accessToken == "" { + err = ErrATNotFound + return + } + tokenValue, err := om.TokenStore.GetByAccessToken(accessToken) + if err != nil { + return + } else if tokenValue == nil { + err = ErrATNotFound + return + } else if tokenValue.Status != Actived { + err = ErrATInvalid + return + } else if v := om.checkRefreshTokenExpire(tokenValue); v != nil { + err = v + return + } else if v := om.checkAccessTokenExpire(tokenValue); v != nil { + err = v + return + } + token = tokenValue + return +} + +// RevokeAccessToken 废除访问令牌(将该访问令牌的状态更改为删除) +// accessToken 访问令牌 +func (om *OAuthManager) RevokeAccessToken(accessToken string) (err error) { + if accessToken == "" { + err = ErrATNotFound + return + } + token, err := om.TokenStore.GetByAccessToken(accessToken) + if err != nil { + return + } else if token == nil { + err = ErrATNotFound + return + } else if token.Status != Actived { + err = ErrATInvalid + return + } + info := map[string]interface{}{ + "Status": Deleted, + } + err = om.TokenStore.Update(token.ID, info) + return +} + +// RefreshAccessToken 更新访问令牌(在更新令牌有效期内,更新访问令牌的有效期),同时返回更新后的令牌信息 +// refreshToken 更新令牌 +// scopes 申请的权限范围(不可以超出上一次申请的范围,如果省略该参数,则表示与上一次一致) +func (om *OAuthManager) RefreshAccessToken(refreshToken, scopes string) (token *Token, err error) { + if refreshToken == "" { + err = ErrRTNotFound + return + } + tokenValue, err := om.TokenStore.GetByRefreshToken(refreshToken) + if err != nil { + return + } else if tokenValue == nil { + err = ErrRTNotFound + return + } else if tokenValue.Status != Actived { + err = ErrRTInvalid + return + } else if v := om.checkRefreshTokenExpire(tokenValue); v != nil { + err = v + return + } + cli, err := om.GetClient(tokenValue.ClientID) + if err != nil { + return + } + tokenValue.ATCreateAt = time.Now().Unix() + atValue, err := om.TokenGenerate.AccessToken(NewTokenBasicInfo(cli, tokenValue.UserID, tokenValue.ATCreateAt)) + if err != nil { + return + } + tokenValue.AccessToken = atValue + tokenInfo := map[string]interface{}{ + "AccessToken": tokenValue.AccessToken, + "ATCreateAt": tokenValue.ATCreateAt, + } + if scopes != "" { + tokenValue.Scope = scopes + tokenInfo["Scope"] = tokenValue.Scope + } + err = om.TokenStore.Update(tokenValue.ID, tokenInfo) + if err != nil { + return + } + token = tokenValue + return +} + +// checkAccessTokenExpire 检查访问令牌是否过期, +// 如果访问令牌过期同时没有更新令牌的情况下, +// 则将令牌状态更改为过期 +func (om *OAuthManager) checkAccessTokenExpire(token *Token) error { + if token.AccessToken == "" { + return nil + } + nowUnix := time.Now().Unix() + if (token.ATCreateAt + int64(token.ATExpiresIn/time.Second)) > nowUnix { + return nil + } + var err error + if token.RefreshToken == "" { + info := map[string]interface{}{ + "Status": Expired, + } + err = om.TokenStore.Update(token.ID, info) + if err == nil { + err = ErrATExpire + } + } + return err +} + +// checkRefreshTokenExpire 检查更新令牌是否过期, +// 如果更新令牌过期则将令牌状态更改为过期 +func (om *OAuthManager) checkRefreshTokenExpire(token *Token) error { + if token.RefreshToken == "" { + return nil + } + nowUnix := time.Now().Unix() + if (token.RTCreateAt + int64(token.RTExpiresIn/time.Second)) > nowUnix { + return nil + } + info := map[string]interface{}{ + "Status": Expired, + } + err := om.TokenStore.Update(token.ID, info) + if err == nil { + err = ErrRTExpire + } + return err +} diff --git a/oauth2_test.go b/oauth2_test.go new file mode 100644 index 0000000..92b40b6 --- /dev/null +++ b/oauth2_test.go @@ -0,0 +1,40 @@ +package oauth2_test + +import ( + "gopkg.in/LyricTian/lib.v2" + "gopkg.in/LyricTian/lib.v2/mongo" + "gopkg.in/mgo.v2/bson" + "gopkg.in/oauth2.v1" +) + +const ( + // MongoURL MongoDB连接字符串 + MongoURL = "mongodb://admin:123456@45.78.35.157:37017" + // DBName 数据库名称 + DBName = "test" +) + +// ClientHandle 执行客户端处理 +func ClientHandle(handle func(cli oauth2.Client)) { + info := oauth2.DefaultClient{ + ClientID: bson.NewObjectId().Hex(), + ClientDomain: "http://www.example.com", + } + info.ClientSecret, _ = lib.NewEncryption([]byte(info.ClientID)).MD5() + mHandler, err := mongo.InitHandlerWithDB(MongoURL, DBName) + if err != nil { + panic(err) + } + defer func() { + err = mHandler.C(oauth2.DefaultClientCollectionName).RemoveId(info.ClientID) + if err != nil { + panic(err) + } + mHandler.Session().Close() + }() + err = mHandler.C(oauth2.DefaultClientCollectionName).Insert(info) + if err != nil { + panic(err) + } + handle(info) +} diff --git a/password.go b/password.go new file mode 100644 index 0000000..93894d3 --- /dev/null +++ b/password.go @@ -0,0 +1,77 @@ +package oauth2 + +import "time" + +// NewPasswordManager 创建默认的密码模式管理实例 +// oaManager OAuth授权管理 +// config 配置参数(nil则使用默认值) +func NewPasswordManager(oaManager *OAuthManager, config *PasswordConfig) *PasswordManager { + if config == nil { + config = new(PasswordConfig) + } + if config.ATExpiresIn == 0 { + config.ATExpiresIn = DefaultATExpiresIn + } + if config.RTExpiresIn == 0 { + config.RTExpiresIn = DefaultRTExpiresIn + } + pManager := &PasswordManager{ + oAuthManager: oaManager, + config: config, + } + return pManager +} + +// PasswordManager 密码模式管理 +type PasswordManager struct { + oAuthManager *OAuthManager // 授权管理 + config *PasswordConfig // 配置参数 +} + +// GenerateToken 生成令牌(只生成访问令牌) +// clientID 客户端标识 +// userID 用户标识 +// clientSecret 客户端秘钥 +// scopes 应用授权标识 +func (pm *PasswordManager) GenerateToken(clientID, userID, clientSecret, scopes string, isGenerateRefresh bool) (token *Token, err error) { + cli, err := pm.oAuthManager.GetClient(clientID) + if err != nil { + return + } else if cli.Secret() != clientSecret { + err = ErrCSInvalid + return + } + createAt := time.Now().Unix() + basicInfo := NewTokenBasicInfo(cli, userID, createAt) + atValue, err := pm.oAuthManager.TokenGenerate.AccessToken(basicInfo) + if err != nil { + return + } + tokenValue := Token{ + ClientID: clientID, + UserID: userID, + AccessToken: atValue, + ATCreateAt: createAt, + ATExpiresIn: time.Duration(pm.config.ATExpiresIn) * time.Second, + Scope: scopes, + CreateAt: createAt, + Status: Actived, + } + if isGenerateRefresh { + rtValue, rtErr := pm.oAuthManager.TokenGenerate.RefreshToken(basicInfo) + if rtErr != nil { + err = rtErr + return + } + tokenValue.RefreshToken = rtValue + tokenValue.RTCreateAt = createAt + tokenValue.RTExpiresIn = time.Duration(pm.config.RTExpiresIn) * time.Second + } + id, err := pm.oAuthManager.TokenStore.Create(tokenValue) + if err != nil { + return + } + tokenValue.ID = id + token = &tokenValue + return +} diff --git a/password_test.go b/password_test.go new file mode 100644 index 0000000..575aa6e --- /dev/null +++ b/password_test.go @@ -0,0 +1,34 @@ +package oauth2_test + +import ( + "testing" + + "gopkg.in/oauth2.v1" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestPasswordManager(t *testing.T) { + ClientHandle(func(info oauth2.Client) { + userID := "999999" + oManager, err := oauth2.CreateDefaultOAuthManager(oauth2.NewMongoConfig(MongoURL, DBName), "", "", nil) + if err != nil { + t.Error(err) + } + Convey("Password Manager Test", t, func() { + manager := oManager.GetPasswordManager() + + token, err := manager.GenerateToken(info.ID(), userID, info.Secret(), "all", true) + So(err, ShouldBeNil) + + checkAT, err := oManager.CheckAccessToken(token.AccessToken) + So(err, ShouldBeNil) + So(checkAT.ClientID, ShouldEqual, info.ID()) + So(checkAT.UserID, ShouldEqual, userID) + + newAT, err := oManager.RefreshAccessToken(checkAT.RefreshToken, "") + So(err, ShouldBeNil) + So(newAT.AccessToken, ShouldNotEqual, checkAT.AccessToken) + }) + }) +} diff --git a/tokenGenerate.go b/tokenGenerate.go new file mode 100644 index 0000000..5d4c674 --- /dev/null +++ b/tokenGenerate.go @@ -0,0 +1,74 @@ +package oauth2 + +import ( + "strconv" + + "gopkg.in/LyricTian/lib.v2" + + "bytes" +) + +// NewTokenBasicInfo 创建用于生成令牌的基础信息 +// cli 客户端信息 +// userID 用户标识 +// createAt 创建令牌的时间戳 +func NewTokenBasicInfo(cli Client, userID string, createAt int64) TokenBasicInfo { + return TokenBasicInfo{ + Client: cli, + UserID: userID, + CreateAt: createAt, + } +} + +// TokenBasicInfo 用于生成令牌的基础信息 +type TokenBasicInfo struct { + Client Client // 客户端信息 + UserID string // 用户标识 + CreateAt int64 // 创建令牌的时间戳 +} + +// TokenGenerate 令牌生成接口 +type TokenGenerate interface { + // AccessToken 根据客户端信息生成访问令牌 + AccessToken(basicInfo TokenBasicInfo) (string, error) + + // RefreshToken 根据客户端信息生成更新令牌 + RefreshToken(basicInfo TokenBasicInfo) (string, error) +} + +// NewDefaultTokenGenerate 创建默认的访问令牌生成方式 +func NewDefaultTokenGenerate() TokenGenerate { + return &TokenGenerateDefault{} +} + +// TokenGenerateDefault 提供默认的令牌生成 +// 采用MD5(ClientID+ClientSecret+RandomCode+Nanosecond Timestamp)的生成方式 +type TokenGenerateDefault struct{} + +func (tg *TokenGenerateDefault) generate(basicInfo TokenBasicInfo) (string, error) { + var buf bytes.Buffer + _, _ = buf.WriteString(basicInfo.Client.ID()) + if basicInfo.UserID != "" { + _ = buf.WriteByte('_') + _, _ = buf.WriteString(basicInfo.UserID) + } + _ = buf.WriteByte('\n') + _, _ = buf.WriteString(basicInfo.Client.Secret()) + _ = buf.WriteByte('\n') + _, _ = buf.WriteString(lib.NewRandom(6).NumberAndLetter()) + _ = buf.WriteByte('\n') + _, _ = buf.WriteString(strconv.FormatInt(basicInfo.CreateAt, 10)) + val, err := lib.NewEncryption(buf.Bytes()).MD5() + buf.Reset() + return val, err +} + +// AccessToken Generate access token +func (tg *TokenGenerateDefault) AccessToken(basicInfo TokenBasicInfo) (string, error) { + return tg.generate(basicInfo) +} + +// RefreshToken Generate refresh token +func (tg *TokenGenerateDefault) RefreshToken(basicInfo TokenBasicInfo) (string, error) { + return tg.generate(basicInfo) +} diff --git a/tokenGenerate_test.go b/tokenGenerate_test.go new file mode 100644 index 0000000..ec35017 --- /dev/null +++ b/tokenGenerate_test.go @@ -0,0 +1,36 @@ +package oauth2_test + +import ( + "testing" + "time" + + "gopkg.in/oauth2.v1" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestTokenGenerate(t *testing.T) { + cli := oauth2.DefaultClient{ + ClientID: "123456", + ClientSecret: "654321", + ClientDomain: "http://www.lyric.name", + } + basicInfo := oauth2.TokenBasicInfo{ + Client: cli, + UserID: "999999", + CreateAt: time.Now().Unix(), + } + Convey("Token generate test", t, func() { + tokenGenerate := oauth2.NewDefaultTokenGenerate() + Convey("Generate access token", func() { + token, err := tokenGenerate.AccessToken(basicInfo) + So(err, ShouldBeNil) + _, _ = Println("\n [P]Access Token:" + token) + }) + Convey("Generate refresh token", func() { + token, err := tokenGenerate.RefreshToken(basicInfo) + So(err, ShouldBeNil) + _, _ = Println("\n [P]Refresh Token:" + token) + }) + }) +} diff --git a/tokenMongoStore.go b/tokenMongoStore.go new file mode 100644 index 0000000..289e53d --- /dev/null +++ b/tokenMongoStore.go @@ -0,0 +1,87 @@ +package oauth2 + +import ( + "gopkg.in/LyricTian/lib.v2/mongo" + "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" +) + +const ( + // DefaultTokenCollectionName 默认的令牌存储集合名称 + DefaultTokenCollectionName = "AuthToken" +) + +// NewTokenMongoStore 创建基于MongoDB的令牌存储方式 +// mongoConfig MongoDB配置参数 +// cName 存储令牌的集合名称(默认为AuthToken) +func NewTokenMongoStore(mongoConfig *MongoConfig, cName string) (TokenStore, error) { + mHandler, err := mongo.InitHandlerWithDB(mongoConfig.URL, mongoConfig.DBName) + if err != nil { + return nil, err + } + if cName == "" { + cName = DefaultTokenCollectionName + } + return &TokenMongoStore{ + cName: cName, + mHandler: mHandler, + }, nil +} + +// TokenMongoStore 基于MongoDB的令牌存储方式 +type TokenMongoStore struct { + cName string + mHandler *mongo.Handler +} + +// Create Add item +func (tm *TokenMongoStore) Create(item Token) (id int64, err error) { + tm.mHandler.CHandle(tm.cName, func(c *mgo.Collection) { + tid, err := tm.mHandler.IncrID(tm.cName) + if err != nil { + return + } + item.ID = tid + err = c.Insert(item) + if err != nil { + return + } + id = item.ID + }) + return +} + +// Update Modify item +func (tm *TokenMongoStore) Update(id int64, info map[string]interface{}) (err error) { + tm.mHandler.CHandle(tm.cName, func(c *mgo.Collection) { + err = c.Update(bson.M{"ID": id}, bson.M{"$set": info}) + if err != nil { + return + } + }) + return +} + +func (tm *TokenMongoStore) findOne(query interface{}) (token *Token, err error) { + tm.mHandler.CHandle(tm.cName, func(c *mgo.Collection) { + var result []Token + err = c.Find(query).Sort("-ID").Limit(1).All(&result) + if err != nil { + return + } + if len(result) > 0 { + token = &result[0] + } + }) + return +} + +// GetByAccessToken 根据访问令牌获取令牌信息 +func (tm *TokenMongoStore) GetByAccessToken(accessToken string) (*Token, error) { + return tm.findOne(bson.M{"AccessToken": accessToken}) +} + +// GetByRefreshToken 根据更新令牌获取令牌信息 +func (tm *TokenMongoStore) GetByRefreshToken(refreshToken string) (*Token, error) { + return tm.findOne(bson.M{"RefreshToken": refreshToken}) +} diff --git a/tokenMongoStore_test.go b/tokenMongoStore_test.go new file mode 100644 index 0000000..6123557 --- /dev/null +++ b/tokenMongoStore_test.go @@ -0,0 +1,42 @@ +package oauth2_test + +import ( + "testing" + "time" + + "gopkg.in/oauth2.v1" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestTokenMongoStore(t *testing.T) { + Convey("Token mongodb store test", t, func() { + tokenStore, err := oauth2.NewTokenMongoStore(oauth2.NewMongoConfig(MongoURL, DBName), "") + So(err, ShouldBeNil) + createAt := time.Now().Unix() + tokenValue := oauth2.Token{ + ClientID: "123456", + UserID: "999999", + AccessToken: "654321", + ATCreateAt: createAt, + ATExpiresIn: time.Second * 1, + RefreshToken: "000000", + RTCreateAt: createAt, + RTExpiresIn: time.Second * 1, + CreateAt: createAt, + Status: oauth2.Actived, + } + id, err := tokenStore.Create(tokenValue) + So(err, ShouldBeNil) + So(id, ShouldBeGreaterThanOrEqualTo, 1) + tokenValue.ID = id + err = tokenStore.Update(id, map[string]interface{}{"Status": oauth2.Expired}) + So(err, ShouldBeNil) + at, err := tokenStore.GetByAccessToken("654321") + So(err, ShouldBeNil) + So(at.Status, ShouldEqual, oauth2.Expired) + rt, err := tokenStore.GetByRefreshToken("000000") + So(err, ShouldBeNil) + So(rt.ID, ShouldEqual, id) + }) +} diff --git a/tokenStore.go b/tokenStore.go new file mode 100644 index 0000000..1269d7a --- /dev/null +++ b/tokenStore.go @@ -0,0 +1,41 @@ +package oauth2 + +import ( + "time" +) + +// Token 令牌信息 +type Token struct { + ID int64 `bson:"ID"` // 唯一标识(自增ID) + ClientID string `bson:"ClientID"` // 客户端标识 + UserID string `bson:"UserID"` // 用户标识 + AccessToken string `bson:"AccessToken"` // 访问令牌 + ATCreateAt int64 `bson:"ATCreateAt"` // 访问令牌创建时间(时间戳) + ATExpiresIn time.Duration `bson:"ATExpiresIn"` // 访问令牌有效期(单位秒) + RefreshToken string `bson:"RefreshToken"` // 更新令牌 + RTCreateAt int64 `bson:"RTCreateAt"` // 更新令牌创建时间(时间戳) + RTExpiresIn time.Duration `bson:"RTExpiresIn"` // 更新令牌有效期(单位秒) + Scope string `bson:"Scope"` // 申请的权限范围 + CreateAt int64 `bson:"CreateAt"` // 创建时间(时间戳) + Status STATUS `bson:"Status"` // 令牌状态 +} + +// TokenStore 令牌存储接口(持久化存储) +type TokenStore interface { + // Create 创建新的令牌,返回令牌ID + // 如果创建发生异常,则返回错误 + Create(item Token) (int64, error) + + // Update 根据ID更新令牌信息 + // info 需要更新的字段信息(字段名称与结构体的字段名保持一致) + // 如果更新发生异常,则返回错误 + Update(id int64, info map[string]interface{}) error + + // GetByAccessToken 根据访问令牌,获取令牌信息 + // 如果不存则返回nil + GetByAccessToken(accessToken string) (*Token, error) + + // GetByRefreshToken 根据更新令牌,获取令牌信息 + // 如果不存则返回nil + GetByRefreshToken(refreshToken string) (*Token, error) +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..8a92145 --- /dev/null +++ b/util.go @@ -0,0 +1,28 @@ +package oauth2 + +import ( + "errors" + "net/url" +) + +// ValidateURI 验证基础的Uri与重定向的URI是否一致 +func ValidateURI(baseURI string, redirectURI string) error { + base, err := url.Parse(baseURI) + if err != nil { + return err + } + redirect, err := url.Parse(redirectURI) + if err != nil { + return err + } + if base.Fragment != "" || redirect.Fragment != "" { + return errors.New("Url must not include fragment.") + } + if base.Scheme != redirect.Scheme { + return errors.New("Scheme don't match.") + } + if base.Host != redirect.Host { + return errors.New("Host don't match.") + } + return nil +} diff --git a/util_test.go b/util_test.go new file mode 100644 index 0000000..61e736e --- /dev/null +++ b/util_test.go @@ -0,0 +1,16 @@ +package oauth2_test + +import ( + "testing" + + "gopkg.in/oauth2.v1" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestUtil(t *testing.T) { + Convey("ValidateURI Test", t, func() { + err := oauth2.ValidateURI("http://www.example.com", "http://www.example.com/cb?code=xxx") + So(err, ShouldBeNil) + }) +} From 36978b2e50346f1b9c9eacf565fd4b8a2d93837d Mon Sep 17 00:00:00 2001 From: lyric Date: Thu, 26 May 2016 22:45:54 +0800 Subject: [PATCH 2/2] Add .DS_Store gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index daf913b..34b0df6 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,6 @@ _testmain.go *.exe *.test *.prof + +# OSX +*.DS_Store