diff --git a/README.md b/README.md index f9082cc..d6285ef 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,6 @@ Golang OAuth 2.0协议实现 [![GoDoc](https://godoc.org/gopkg.in/oauth2.v1?status.svg)](https://godoc.org/gopkg.in/oauth2.v1) [![Go Report Card](https://goreportcard.com/badge/gopkg.in/oauth2.v1)](https://goreportcard.com/report/gopkg.in/oauth2.v1) -> 基于Golang实现的OAuth 2.0协议相关操作,包括:令牌(或授权码)的生成、存储、验证操作以及更新令牌、废除令牌; 具有简单、灵活的特点; 其中所涉及的相关http请求操作在这里不做处理; 支持授权码模式、简化模式、密码模式、客户端模式; 默认使用MongoDB存储相关信息 - 获取 ---- @@ -16,7 +14,7 @@ $ go get -v gopkg.in/oauth2.v1 范例 ---- -> 数据初始化:初始化相关的客户端信息 +> 使用之前,初始化客户端信息 ```go package main @@ -28,12 +26,22 @@ import ( ) func main() { - mongoConfig := oauth2.NewMongoConfig("mongodb://127.0.0.1:27017", "test") + // 初始化配置参数 + ocfg := &oauth2.OAuthConfig{ + ACConfig: &oauth2.ACConfig{ + ATExpiresIn: 60 * 60 * 24, + }, + } + mcfg := oauth2.NewMongoConfig("mongodb://127.0.0.1:27017", "test") + // 创建默认的OAuth2管理实例(基于MongoDB) - manager, err := oauth2.CreateDefaultOAuthManager(mongoConfig, "", "", nil) + manager, err := oauth2.NewDefaultOAuthManager(ocfg, mcfg, "xxx", "xxx") if err != nil { panic(err) } + manager.SetACGenerate(oauth2.NewDefaultACGenerate()) + manager.SetACStore(oauth2.NewACMemoryStore(0)) + // 模拟授权码模式 // 使用默认参数,生成授权码 code, err := manager.GetACManager(). @@ -41,36 +49,41 @@ func main() { if err != nil { panic(err) } + // 生成访问令牌及更新令牌 genToken, err := manager.GetACManager(). GenerateToken(code, "http://www.example.com/cb", "clientID_x", "clientSecret_x", true) if err != nil { panic(err) } + // 检查访问令牌 checkToken, err := manager.CheckAccessToken(genToken.AccessToken) if err != nil { panic(err) } + // TODO: 使用用户标识、申请的授权范围响应数据 fmt.Println(checkToken.UserID, checkToken.Scope) - // 申请一个新的访问令牌 + + // 更新令牌 newToken, err := manager.RefreshAccessToken(checkToken.RefreshToken, "scopes") if err != nil { panic(err) } fmt.Println(newToken.AccessToken, newToken.ATExpiresIn) // TODO: 将新的访问令牌响应给客户端 + } ``` 执行测试 ----- +------- ```bash $ go test -v # 或 -$ goconvey --port=9090 +$ goconvey -port=9090 ``` License diff --git a/authorizationCode.go b/authorizationCode.go index 671a1de..dc648ce 100644 --- a/authorizationCode.go +++ b/authorizationCode.go @@ -3,7 +3,7 @@ package oauth2 import ( "time" - "gopkg.in/LyricTian/lib.v2" + "github.com/LyricTian/go.uuid" ) // NewACManager 创建授权码模式管理实例 @@ -13,9 +13,6 @@ func NewACManager(oaManager *OAuthManager, config *ACConfig) *ACManager { if config == nil { config = new(ACConfig) } - if config.RandomCodeLen == 0 { - config.RandomCodeLen = DefaultRandomCodeLen - } if config.ACExpiresIn == 0 { config.ACExpiresIn = DefaultACExpiresIn } @@ -53,7 +50,7 @@ func (am *ACManager) GenerateCode(clientID, userID, redirectURI, scopes string) UserID: userID, RedirectURI: redirectURI, Scope: scopes, - Code: lib.NewRandom(am.config.RandomCodeLen).NumberAndLetter(), + Code: uuid.NewV4().String(), CreateAt: time.Now().Unix(), ExpiresIn: time.Duration(am.config.ACExpiresIn) * time.Second, } diff --git a/authorizationCodeGenerate.go b/authorizationCodeGenerate.go index 1414998..f4b75b8 100644 --- a/authorizationCodeGenerate.go +++ b/authorizationCodeGenerate.go @@ -7,6 +7,8 @@ import ( "strconv" "strings" + "github.com/LyricTian/go.uuid" + "gopkg.in/LyricTian/lib.v2" ) @@ -30,31 +32,28 @@ func NewDefaultACGenerate() ACGenerate { // ACGenerateDefault 默认的授权码生成方式 type ACGenerateDefault struct{} -func (ag *ACGenerateDefault) genToken(info *ACInfo) (string, error) { - var buf bytes.Buffer - _, _ = buf.WriteString(info.ClientID) - _ = buf.WriteByte('_') +func (ag *ACGenerateDefault) genCode(info *ACInfo) (string, error) { + ns, _ := uuid.FromString(info.Code) + buf := bytes.NewBuffer(uuid.NewV3(ns, info.ClientID).Bytes()) _, _ = buf.WriteString(info.UserID) - _ = buf.WriteByte('\n') _, _ = buf.WriteString(strconv.FormatInt(info.CreateAt, 10)) - _ = buf.WriteByte('\n') - _, _ = buf.WriteString(info.Code) + md5Val, err := lib.NewEncryption(buf.Bytes()).MD5() if err != nil { return "", err } - buf.Reset() md5Val = md5Val[:15] + return md5Val, nil } // Code Authorization code func (ag *ACGenerateDefault) Code(info *ACInfo) (string, error) { - tokenVal, err := ag.genToken(info) + codeVal, err := ag.genCode(info) if err != nil { return "", err } - val := base64.URLEncoding.EncodeToString([]byte(tokenVal + "." + strconv.FormatInt(info.ID, 10))) + val := base64.URLEncoding.EncodeToString([]byte(codeVal + "." + strconv.FormatInt(info.ID, 10))) return strings.TrimRight(val, "="), nil } @@ -64,20 +63,20 @@ func (ag *ACGenerateDefault) parse(code string) (id int64, token string, err err codeLen = 4 - codeLen } code = code + strings.Repeat("=", codeLen) - codeVal, err := base64.URLEncoding.DecodeString(code) + codeBV, err := base64.URLEncoding.DecodeString(code) if err != nil { return } - tokenVal := strings.SplitN(string(codeVal), ".", 2) - if len(tokenVal) != 2 { + codeVal := strings.SplitN(string(codeBV), ".", 2) + if len(codeVal) != 2 { err = errors.New("Token is invalid") return } - id, err = strconv.ParseInt(tokenVal[1], 10, 64) + id, err = strconv.ParseInt(codeVal[1], 10, 64) if err != nil { return } - token = tokenVal[0] + token = codeVal[0] return } @@ -93,9 +92,9 @@ func (ag *ACGenerateDefault) Verify(code string, info *ACInfo) (valid bool, err if err != nil { return } - tokenVal, err := ag.genToken(info) + codeVal, err := ag.genCode(info) if err != nil { return } - return token == tokenVal, nil + return token == codeVal, nil } diff --git a/authorizationCodeGenerate_test.go b/authorizationCodeGenerate_test.go index 9c07c82..9ce5b62 100644 --- a/authorizationCodeGenerate_test.go +++ b/authorizationCodeGenerate_test.go @@ -1,19 +1,18 @@ -package oauth2_test +package oauth2 import ( "testing" "time" "gopkg.in/LyricTian/lib.v2" - "gopkg.in/oauth2.v1" . "github.com/smartystreets/goconvey/convey" ) func TestACGenerate(t *testing.T) { Convey("Authorization code generate test", t, func() { - acGenerate := oauth2.NewDefaultACGenerate() - info := &oauth2.ACInfo{ + acGenerate := NewDefaultACGenerate() + info := &ACInfo{ ID: 1, ClientID: "123456", UserID: "999999", diff --git a/authorizationCodeMemoryStore_test.go b/authorizationCodeMemoryStore_test.go index 6c86d02..e6dd128 100644 --- a/authorizationCodeMemoryStore_test.go +++ b/authorizationCodeMemoryStore_test.go @@ -1,36 +1,40 @@ -package oauth2_test +package oauth2 import ( "testing" "time" - "gopkg.in/oauth2.v1" - . "github.com/smartystreets/goconvey/convey" ) func TestACMemoryStore(t *testing.T) { Convey("AC memory store test", t, func() { - store := oauth2.NewACMemoryStore(1) - item := oauth2.ACInfo{ + store := NewACMemoryStore(1) + item := ACInfo{ ClientID: "123456", UserID: "999999", CreateAt: time.Now().Unix(), ExpiresIn: time.Millisecond * 500, } + Convey("Put Test", func() { id, err := store.Put(item) So(err, ShouldBeNil) - So(id, ShouldEqual, 1) - item.ID = id + So(id, ShouldBeGreaterThan, 0) Convey("Take Test", func() { info, err := store.TakeByID(id) So(err, ShouldBeNil) So(info.ClientID, ShouldEqual, item.ClientID) So(info.UserID, ShouldEqual, item.UserID) }) + }) + + Convey("GC Test", func() { + id, err := store.Put(item) + So(err, ShouldBeNil) + So(id, ShouldBeGreaterThan, 0) Convey("Take GC Test", func() { - time.Sleep(time.Second * 2) + time.Sleep(time.Millisecond * 1500) info, err := store.TakeByID(id) So(err, ShouldNotBeNil) So(info, ShouldBeNil) diff --git a/authorizationCodeRedisStore.go b/authorizationCodeRedisStore.go new file mode 100644 index 0000000..c3c8db8 --- /dev/null +++ b/authorizationCodeRedisStore.go @@ -0,0 +1,89 @@ +package oauth2 + +import ( + "encoding/json" + "fmt" + + "gopkg.in/redis.v3" +) + +const ( + // DefaultACRedisIDKey Redis存储授权码唯一标识的键 + DefaultACRedisIDKey = "ACID" +) + +// NewACRedisStore 创建Redis存储的实例 +// config Redis配置参数 +// key Redis存储授权码唯一标识的键(默认为ACID) +func NewACRedisStore(cfg *RedisConfig, key string) (*ACRedisStore, error) { + opt := &redis.Options{ + Network: cfg.Network, + Addr: cfg.Addr, + Password: cfg.Password, + DB: cfg.DB, + MaxRetries: cfg.MaxRetries, + DialTimeout: cfg.DialTimeout, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + PoolSize: cfg.PoolSize, + PoolTimeout: cfg.PoolTimeout, + } + cli := redis.NewClient(opt) + err := cli.Ping().Err() + if err != nil { + return nil, err + } + if key == "" { + key = DefaultACRedisIDKey + } + return &ACRedisStore{ + cli: cli, + key: key, + }, nil +} + +// ACRedisStore 提供授权码的redis存储 +type ACRedisStore struct { + cli *redis.Client + key string +} + +// Put 存储授权码 +func (ar *ACRedisStore) Put(item ACInfo) (id int64, err error) { + n, err := ar.cli.Incr(ar.key).Result() + if err != nil { + return + } + item.ID = n + jv, err := json.Marshal(item) + if err != nil { + return + } + key := fmt.Sprintf("%s_%d", ar.key, n) + err = ar.cli.Set(key, string(jv), item.ExpiresIn).Err() + if err != nil { + return + } + id = item.ID + return +} + +// TakeByID 取出授权码 +func (ar *ACRedisStore) TakeByID(id int64) (info *ACInfo, err error) { + key := fmt.Sprintf("%s_%d", ar.key, id) + data, err := ar.cli.Get(key).Result() + if err != nil { + return + } + var v ACInfo + err = json.Unmarshal([]byte(data), &v) + if err != nil { + return + } + err = ar.cli.Del(key).Err() + if err != nil { + return + } + info = &v + return +} diff --git a/authorizationCodeRedisStore_test.go b/authorizationCodeRedisStore_test.go new file mode 100644 index 0000000..96aad3a --- /dev/null +++ b/authorizationCodeRedisStore_test.go @@ -0,0 +1,49 @@ +package oauth2 + +import ( + "testing" + "time" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestACRedisStore(t *testing.T) { + Convey("Authorization code redis store test", t, func() { + store, err := NewACRedisStore(&RedisConfig{ + Addr: "192.168.33.70:6379", + DB: 1, + }, "") + So(err, ShouldBeNil) + item := ACInfo{ + ClientID: "123456", + UserID: "999999", + Code: "", + CreateAt: time.Now().Unix(), + ExpiresIn: time.Millisecond * 500, + } + + Convey("Put Test", func() { + id, err := store.Put(item) + So(err, ShouldBeNil) + So(id, ShouldBeGreaterThan, 0) + Convey("Take Test", func() { + info, err := store.TakeByID(id) + So(err, ShouldBeNil) + So(info.ClientID, ShouldEqual, item.ClientID) + So(info.UserID, ShouldEqual, item.UserID) + }) + }) + + Convey("GC Test", func() { + id, err := store.Put(item) + So(err, ShouldBeNil) + So(id, ShouldBeGreaterThan, 0) + Convey("Take GC Test", func() { + time.Sleep(time.Millisecond * 1500) + info, err := store.TakeByID(id) + So(err, ShouldNotBeNil) + So(info, ShouldBeNil) + }) + }) + }) +} diff --git a/authorizationCode_test.go b/authorizationCode_test.go index e9bc583..954852b 100644 --- a/authorizationCode_test.go +++ b/authorizationCode_test.go @@ -1,20 +1,22 @@ -package oauth2_test +package oauth2 import ( "testing" - "gopkg.in/oauth2.v1" - . "github.com/smartystreets/goconvey/convey" ) func TestACManager(t *testing.T) { - ClientHandle(func(info oauth2.Client) { - userID := "999999" - oManager, err := oauth2.CreateDefaultOAuthManager(oauth2.NewMongoConfig(MongoURL, DBName), "", "", nil) + ClientHandle(func(info Client) { + oManager, err := NewDefaultOAuthManager(nil, NewMongoConfig(MongoURL, DBName), "", "") if err != nil { - t.Error(err) + t.Fatal(err) } + oManager.SetACGenerate(NewDefaultACGenerate()) + oManager.SetACStore(NewACMemoryStore(0)) + + userID := "999999" + Convey("Authorization Code Manager Test", t, func() { manager := oManager.GetACManager() diff --git a/clientCredentials_test.go b/clientCredentials_test.go index 528e2c4..bea67d9 100644 --- a/clientCredentials_test.go +++ b/clientCredentials_test.go @@ -1,17 +1,16 @@ -package oauth2_test +package oauth2 import ( . "github.com/smartystreets/goconvey/convey" - "gopkg.in/oauth2.v1" "testing" ) func TestCCManager(t *testing.T) { - ClientHandle(func(cli oauth2.Client) { - oManager, err := oauth2.CreateDefaultOAuthManager(oauth2.NewMongoConfig(MongoURL, DBName), "", "", nil) + ClientHandle(func(cli Client) { + oManager, err := NewDefaultOAuthManager(nil, NewMongoConfig(MongoURL, DBName), "", "") if err != nil { - t.Error(err) + t.Fatal(err) } Convey("Client Credentials Manager Test", t, func() { manager := oManager.GetCCManager() diff --git a/clientMongoStore_test.go b/clientMongoStore_test.go index 92c0b78..adb6f44 100644 --- a/clientMongoStore_test.go +++ b/clientMongoStore_test.go @@ -1,17 +1,15 @@ -package oauth2_test +package oauth2 import ( "testing" - "gopkg.in/oauth2.v1" - . "github.com/smartystreets/goconvey/convey" ) func TestClientMongoStore(t *testing.T) { - ClientHandle(func(info oauth2.Client) { + ClientHandle(func(info Client) { Convey("Client mongodb store test", t, func() { - clientStore, err := oauth2.NewClientMongoStore(oauth2.NewMongoConfig(MongoURL, DBName), "") + clientStore, err := NewClientMongoStore(NewMongoConfig(MongoURL, DBName), "") So(err, ShouldBeNil) client, err := clientStore.GetByID(info.ID()) So(err, ShouldBeNil) diff --git a/config.go b/config.go index 9943712..b5f0c68 100644 --- a/config.go +++ b/config.go @@ -16,10 +16,9 @@ func NewMongoConfig(url, dbName string) *MongoConfig { // ACConfig 授权码模式配置参数(Authorization Code Config) type ACConfig struct { - RandomCodeLen int // 随机码的长度(用于生成授权码的随机码) - ACExpiresIn int64 // 授权码有效期(单位秒) - ATExpiresIn int64 // 访问令牌有效期(单位秒) - RTExpiresIn int64 // 更新令牌有效期(单位秒) + ACExpiresIn int64 // 授权码有效期(单位秒) + ATExpiresIn int64 // 访问令牌有效期(单位秒) + RTExpiresIn int64 // 更新令牌有效期(单位秒) } // ImplicitConfig 简化模式配置参数 diff --git a/config_redis.go b/config_redis.go new file mode 100644 index 0000000..9f0546b --- /dev/null +++ b/config_redis.go @@ -0,0 +1,41 @@ +package oauth2 + +import "time" + +// RedisConfig Redis的配置参数 +type RedisConfig struct { + // The network type, either tcp or unix. + // Default is tcp. + Network string + // host:port address. + Addr string + + // An optional password. Must match the password specified in the + // requirepass server configuration option. + Password string + // A database to be selected after connecting to server. + DB int64 + + // The maximum number of retries before giving up. + // Default is to not retry failed commands. + MaxRetries int + + // Sets the deadline for establishing new connections. If reached, + // dial will fail with a timeout. + // Default is 5 seconds. + DialTimeout time.Duration + // Sets the deadline for socket reads. If reached, commands will + // fail with a timeout instead of blocking. + ReadTimeout time.Duration + // Sets the deadline for socket writes. If reached, commands will + // fail with a timeout instead of blocking. + WriteTimeout time.Duration + + // The maximum number of socket connections. + // Default is 10 connections. + PoolSize int + // Specifies amount of time client waits for connection if all + // connections are busy before returning an error. + // Default is 1 second. + PoolTimeout time.Duration +} diff --git a/implicit_test.go b/implicit_test.go index 4c454ea..aaded8f 100644 --- a/implicit_test.go +++ b/implicit_test.go @@ -1,19 +1,19 @@ -package oauth2_test +package oauth2 import ( . "github.com/smartystreets/goconvey/convey" - "gopkg.in/oauth2.v1" "testing" ) func TestImplicitManager(t *testing.T) { - ClientHandle(func(cli oauth2.Client) { + ClientHandle(func(cli Client) { userID := "999999" - oManager, err := oauth2.CreateDefaultOAuthManager(oauth2.NewMongoConfig(MongoURL, DBName), "", "", nil) + oManager, err := NewDefaultOAuthManager(nil, NewMongoConfig(MongoURL, DBName), "", "") if err != nil { - t.Error(err) + t.Fatal(err) } + Convey("Implicit Manager Test", t, func() { manager := oManager.GetImplicitManager() diff --git a/oauth2.go b/oauth2.go index 311d090..37aef1b 100644 --- a/oauth2.go +++ b/oauth2.go @@ -6,32 +6,37 @@ import ( "time" ) -// CreateDefaultOAuthManager 创建默认的OAuth授权管理实例 -// mongoConfig MongoDB配置参数 -// tokenCollectionName 存储令牌的集合名称(默认为AuthToken) -// clientCollectionName 存储客户端的集合名称(默认为ClientInfo) -// oauthConfig 配置参数 -func CreateDefaultOAuthManager(mongoConfig *MongoConfig, tokenCollectionName, clientCollectionName string, oauthConfig *OAuthConfig) (*OAuthManager, error) { - if oauthConfig == nil { - oauthConfig = new(OAuthConfig) - } - oaManager := &OAuthManager{ - Config: oauthConfig, - ACGenerate: NewDefaultACGenerate(), - ACStore: NewACMemoryStore(0), - TokenGenerate: NewDefaultTokenGenerate(), - } - tokenStore, err := NewTokenMongoStore(mongoConfig, tokenCollectionName) +// NewOAuthManager 创建OAuth授权管理实例 +// cfg 配置参数 +func NewOAuthManager(cfg *OAuthConfig) *OAuthManager { + if cfg == nil { + cfg = new(OAuthConfig) + } + return &OAuthManager{ + Config: cfg, + } +} + +// NewDefaultOAuthManager 创建默认的OAuth授权管理实例 +// cfg 配置参数 +// mcfg MongoDB配置参数 +// ccName 存储客户端的集合名称(默认为ClientInfo) +// tcName 存储令牌的集合名称(默认为AuthToken) +func NewDefaultOAuthManager(cfg *OAuthConfig, mcfg *MongoConfig, ccName, tcName string) (*OAuthManager, error) { + oManager := NewOAuthManager(cfg) + clientStore, err := NewClientMongoStore(mcfg, ccName) if err != nil { return nil, err } - oaManager.TokenStore = tokenStore - clientStore, err := NewClientMongoStore(mongoConfig, clientCollectionName) + oManager.SetClientStore(clientStore) + tokenStore, err := NewTokenMongoStore(mcfg, tcName) if err != nil { return nil, err } - oaManager.ClientStore = clientStore - return oaManager, nil + oManager.SetTokenStore(tokenStore) + oManager.SetTokenGenerate(NewDefaultTokenGenerate()) + + return oManager, nil } // OAuthManager OAuth授权管理 @@ -44,6 +49,11 @@ type OAuthManager struct { ClientStore ClientStore // 客户端存储 } +// SetConfig 设置授权码生成接口 +func (om *OAuthManager) SetConfig(cfg *OAuthConfig) { + om.Config = cfg +} + // SetACGenerate 设置授权码生成接口 func (om *OAuthManager) SetACGenerate(generate ACGenerate) { om.ACGenerate = generate @@ -54,6 +64,21 @@ func (om *OAuthManager) SetACStore(store ACStore) { om.ACStore = store } +// SetTokenGenerate 设置令牌生成接口 +func (om *OAuthManager) SetTokenGenerate(generate TokenGenerate) { + om.TokenGenerate = generate +} + +// SetTokenStore 设置令牌存储接口 +func (om *OAuthManager) SetTokenStore(store TokenStore) { + om.TokenStore = store +} + +// SetClientStore 设置客户端存储接口 +func (om *OAuthManager) SetClientStore(store ClientStore) { + om.ClientStore = store +} + // GetACManager 获取授权码模式管理实例 func (om *OAuthManager) GetACManager() *ACManager { return NewACManager(om, om.Config.ACConfig) diff --git a/oauth2_test.go b/oauth2_test.go index 92b40b6..87ffcb3 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -1,22 +1,25 @@ -package oauth2_test +package oauth2 import ( "gopkg.in/LyricTian/lib.v2" "gopkg.in/LyricTian/lib.v2/mongo" "gopkg.in/mgo.v2/bson" - "gopkg.in/oauth2.v1" ) const ( // MongoURL MongoDB连接字符串 - MongoURL = "mongodb://admin:123456@45.78.35.157:37017" + MongoURL = "mongodb://admin:123456@192.168.33.70:27017" // DBName 数据库名称 DBName = "test" ) +var ( + oManager *OAuthManager +) + // ClientHandle 执行客户端处理 -func ClientHandle(handle func(cli oauth2.Client)) { - info := oauth2.DefaultClient{ +func ClientHandle(handle func(cli Client)) { + info := DefaultClient{ ClientID: bson.NewObjectId().Hex(), ClientDomain: "http://www.example.com", } @@ -26,13 +29,13 @@ func ClientHandle(handle func(cli oauth2.Client)) { panic(err) } defer func() { - err = mHandler.C(oauth2.DefaultClientCollectionName).RemoveId(info.ClientID) + err = mHandler.C(DefaultClientCollectionName).RemoveId(info.ClientID) if err != nil { panic(err) } mHandler.Session().Close() }() - err = mHandler.C(oauth2.DefaultClientCollectionName).Insert(info) + err = mHandler.C(DefaultClientCollectionName).Insert(info) if err != nil { panic(err) } diff --git a/password_test.go b/password_test.go index 575aa6e..65b7257 100644 --- a/password_test.go +++ b/password_test.go @@ -1,20 +1,19 @@ -package oauth2_test +package oauth2 import ( "testing" - "gopkg.in/oauth2.v1" - . "github.com/smartystreets/goconvey/convey" ) func TestPasswordManager(t *testing.T) { - ClientHandle(func(info oauth2.Client) { + ClientHandle(func(info Client) { userID := "999999" - oManager, err := oauth2.CreateDefaultOAuthManager(oauth2.NewMongoConfig(MongoURL, DBName), "", "", nil) + oManager, err := NewDefaultOAuthManager(nil, NewMongoConfig(MongoURL, DBName), "", "") if err != nil { - t.Error(err) + t.Fatal(err) } + Convey("Password Manager Test", t, func() { manager := oManager.GetPasswordManager() diff --git a/tokenGenerate_test.go b/tokenGenerate_test.go index d2aa759..e7a9003 100644 --- a/tokenGenerate_test.go +++ b/tokenGenerate_test.go @@ -1,4 +1,4 @@ -package oauth2_test +package oauth2 import ( "testing" @@ -6,25 +6,23 @@ import ( "github.com/LyricTian/go.uuid" - "gopkg.in/oauth2.v1" - . "github.com/smartystreets/goconvey/convey" ) func TestTokenGenerate(t *testing.T) { - cli := oauth2.DefaultClient{ + cli := DefaultClient{ ClientID: "123456", ClientSecret: "654321", ClientDomain: "http://www.lyric.name", } - basicInfo := &oauth2.TokenBasicInfo{ + basicInfo := &TokenBasicInfo{ Client: cli, TokenID: uuid.NewV4().String(), UserID: "999999", CreateAt: time.Now().Unix(), } Convey("Token generate test", t, func() { - tokenGenerate := oauth2.NewDefaultTokenGenerate() + tokenGenerate := NewDefaultTokenGenerate() Convey("Generate access token", func() { token, err := tokenGenerate.AccessToken(basicInfo) So(err, ShouldBeNil) diff --git a/tokenMongoStore_test.go b/tokenMongoStore_test.go index 8d1ac11..2ab1b87 100644 --- a/tokenMongoStore_test.go +++ b/tokenMongoStore_test.go @@ -1,20 +1,18 @@ -package oauth2_test +package oauth2 import ( "testing" "time" - "gopkg.in/oauth2.v1" - . "github.com/smartystreets/goconvey/convey" ) func TestTokenMongoStore(t *testing.T) { Convey("Token mongodb store test", t, func() { - tokenStore, err := oauth2.NewTokenMongoStore(oauth2.NewMongoConfig(MongoURL, DBName), "") + tokenStore, err := NewTokenMongoStore(NewMongoConfig(MongoURL, DBName), "") So(err, ShouldBeNil) createAt := time.Now().Unix() - tokenValue := oauth2.Token{ + tokenValue := Token{ ClientID: "123456", UserID: "999999", AccessToken: "654321", @@ -24,17 +22,17 @@ func TestTokenMongoStore(t *testing.T) { RTCreateAt: createAt, RTExpiresIn: time.Second * 1, CreateAt: createAt, - Status: oauth2.Actived, + Status: Actived, } id, err := tokenStore.Create(&tokenValue) So(err, ShouldBeNil) So(id, ShouldBeGreaterThanOrEqualTo, 1) tokenValue.ID = id - err = tokenStore.Update(id, map[string]interface{}{"Status": oauth2.Expired}) + err = tokenStore.Update(id, map[string]interface{}{"Status": Expired}) So(err, ShouldBeNil) at, err := tokenStore.GetByAccessToken("654321") So(err, ShouldBeNil) - So(at.Status, ShouldEqual, oauth2.Expired) + So(at.Status, ShouldEqual, Expired) rt, err := tokenStore.GetByRefreshToken("000000") So(err, ShouldBeNil) So(rt.ID, ShouldEqual, id) diff --git a/util_test.go b/util_test.go index 61e736e..54689b2 100644 --- a/util_test.go +++ b/util_test.go @@ -1,16 +1,14 @@ -package oauth2_test +package oauth2 import ( "testing" - "gopkg.in/oauth2.v1" - . "github.com/smartystreets/goconvey/convey" ) func TestUtil(t *testing.T) { Convey("ValidateURI Test", t, func() { - err := oauth2.ValidateURI("http://www.example.com", "http://www.example.com/cb?code=xxx") + err := ValidateURI("http://www.example.com", "http://www.example.com/cb?code=xxx") So(err, ShouldBeNil) }) }