Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ Golang OAuth 2.0协议实现
[![GoDoc](https://godoc.org/gopkg.in/oauth2.v1?status.svg)](https://godoc.org/gopkg.in/oauth2.v1)
[![Go Report Card](https://goreportcard.com/badge/gopkg.in/oauth2.v1)](https://goreportcard.com/report/gopkg.in/oauth2.v1)

> 基于Golang实现的OAuth 2.0协议相关操作,包括:令牌(或授权码)的生成、存储、验证操作以及更新令牌、废除令牌; 具有简单、灵活的特点; 其中所涉及的相关http请求操作在这里不做处理; 支持授权码模式、简化模式、密码模式、客户端模式; 默认使用MongoDB存储相关信息

获取
----

Expand All @@ -16,7 +14,7 @@ $ go get -v gopkg.in/oauth2.v1
范例
----

> 数据初始化:初始化相关的客户端信息
> 使用之前,初始化客户端信息

```go
package main
Expand All @@ -28,49 +26,64 @@ import (
)

func main() {
mongoConfig := oauth2.NewMongoConfig("mongodb://127.0.0.1:27017", "test")
// 初始化配置参数
ocfg := &oauth2.OAuthConfig{
ACConfig: &oauth2.ACConfig{
ATExpiresIn: 60 * 60 * 24,
},
}
mcfg := oauth2.NewMongoConfig("mongodb://127.0.0.1:27017", "test")

// 创建默认的OAuth2管理实例(基于MongoDB)
manager, err := oauth2.CreateDefaultOAuthManager(mongoConfig, "", "", nil)
manager, err := oauth2.NewDefaultOAuthManager(ocfg, mcfg, "xxx", "xxx")
if err != nil {
panic(err)
}
manager.SetACGenerate(oauth2.NewDefaultACGenerate())
manager.SetACStore(oauth2.NewACMemoryStore(0))

// 模拟授权码模式
// 使用默认参数,生成授权码
code, err := manager.GetACManager().
GenerateCode("clientID_x", "userID_x", "http://www.example.com/cb", "scopes")
if err != nil {
panic(err)
}

// 生成访问令牌及更新令牌
genToken, err := manager.GetACManager().
GenerateToken(code, "http://www.example.com/cb", "clientID_x", "clientSecret_x", true)
if err != nil {
panic(err)
}

// 检查访问令牌
checkToken, err := manager.CheckAccessToken(genToken.AccessToken)
if err != nil {
panic(err)
}

// TODO: 使用用户标识、申请的授权范围响应数据
fmt.Println(checkToken.UserID, checkToken.Scope)
// 申请一个新的访问令牌

// 更新令牌
newToken, err := manager.RefreshAccessToken(checkToken.RefreshToken, "scopes")
if err != nil {
panic(err)
}
fmt.Println(newToken.AccessToken, newToken.ATExpiresIn)
// TODO: 将新的访问令牌响应给客户端

}
```

执行测试
----
-------

```bash
$ go test -v
# 或
$ goconvey --port=9090
$ goconvey -port=9090
```

License
Expand Down
7 changes: 2 additions & 5 deletions authorizationCode.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package oauth2
import (
"time"

"gopkg.in/LyricTian/lib.v2"
"github.com/LyricTian/go.uuid"
)

// NewACManager 创建授权码模式管理实例
Expand All @@ -13,9 +13,6 @@ func NewACManager(oaManager *OAuthManager, config *ACConfig) *ACManager {
if config == nil {
config = new(ACConfig)
}
if config.RandomCodeLen == 0 {
config.RandomCodeLen = DefaultRandomCodeLen
}
if config.ACExpiresIn == 0 {
config.ACExpiresIn = DefaultACExpiresIn
}
Expand Down Expand Up @@ -53,7 +50,7 @@ func (am *ACManager) GenerateCode(clientID, userID, redirectURI, scopes string)
UserID: userID,
RedirectURI: redirectURI,
Scope: scopes,
Code: lib.NewRandom(am.config.RandomCodeLen).NumberAndLetter(),
Code: uuid.NewV4().String(),
CreateAt: time.Now().Unix(),
ExpiresIn: time.Duration(am.config.ACExpiresIn) * time.Second,
}
Expand Down
33 changes: 16 additions & 17 deletions authorizationCodeGenerate.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"strconv"
"strings"

"github.com/LyricTian/go.uuid"

"gopkg.in/LyricTian/lib.v2"
)

Expand All @@ -30,31 +32,28 @@ func NewDefaultACGenerate() ACGenerate {
// ACGenerateDefault 默认的授权码生成方式
type ACGenerateDefault struct{}

func (ag *ACGenerateDefault) genToken(info *ACInfo) (string, error) {
var buf bytes.Buffer
_, _ = buf.WriteString(info.ClientID)
_ = buf.WriteByte('_')
func (ag *ACGenerateDefault) genCode(info *ACInfo) (string, error) {
ns, _ := uuid.FromString(info.Code)
buf := bytes.NewBuffer(uuid.NewV3(ns, info.ClientID).Bytes())
_, _ = buf.WriteString(info.UserID)
_ = buf.WriteByte('\n')
_, _ = buf.WriteString(strconv.FormatInt(info.CreateAt, 10))
_ = buf.WriteByte('\n')
_, _ = buf.WriteString(info.Code)

md5Val, err := lib.NewEncryption(buf.Bytes()).MD5()
if err != nil {
return "", err
}
buf.Reset()
md5Val = md5Val[:15]

return md5Val, nil
}

// Code Authorization code
func (ag *ACGenerateDefault) Code(info *ACInfo) (string, error) {
tokenVal, err := ag.genToken(info)
codeVal, err := ag.genCode(info)
if err != nil {
return "", err
}
val := base64.URLEncoding.EncodeToString([]byte(tokenVal + "." + strconv.FormatInt(info.ID, 10)))
val := base64.URLEncoding.EncodeToString([]byte(codeVal + "." + strconv.FormatInt(info.ID, 10)))
return strings.TrimRight(val, "="), nil
}

Expand All @@ -64,20 +63,20 @@ func (ag *ACGenerateDefault) parse(code string) (id int64, token string, err err
codeLen = 4 - codeLen
}
code = code + strings.Repeat("=", codeLen)
codeVal, err := base64.URLEncoding.DecodeString(code)
codeBV, err := base64.URLEncoding.DecodeString(code)
if err != nil {
return
}
tokenVal := strings.SplitN(string(codeVal), ".", 2)
if len(tokenVal) != 2 {
codeVal := strings.SplitN(string(codeBV), ".", 2)
if len(codeVal) != 2 {
err = errors.New("Token is invalid")
return
}
id, err = strconv.ParseInt(tokenVal[1], 10, 64)
id, err = strconv.ParseInt(codeVal[1], 10, 64)
if err != nil {
return
}
token = tokenVal[0]
token = codeVal[0]
return
}

Expand All @@ -93,9 +92,9 @@ func (ag *ACGenerateDefault) Verify(code string, info *ACInfo) (valid bool, err
if err != nil {
return
}
tokenVal, err := ag.genToken(info)
codeVal, err := ag.genCode(info)
if err != nil {
return
}
return token == tokenVal, nil
return token == codeVal, nil
}
7 changes: 3 additions & 4 deletions authorizationCodeGenerate_test.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
package oauth2_test
package oauth2

import (
"testing"
"time"

"gopkg.in/LyricTian/lib.v2"
"gopkg.in/oauth2.v1"

. "github.com/smartystreets/goconvey/convey"
)

func TestACGenerate(t *testing.T) {
Convey("Authorization code generate test", t, func() {
acGenerate := oauth2.NewDefaultACGenerate()
info := &oauth2.ACInfo{
acGenerate := NewDefaultACGenerate()
info := &ACInfo{
ID: 1,
ClientID: "123456",
UserID: "999999",
Expand Down
20 changes: 12 additions & 8 deletions authorizationCodeMemoryStore_test.go
Original file line number Diff line number Diff line change
@@ -1,36 +1,40 @@
package oauth2_test
package oauth2

import (
"testing"
"time"

"gopkg.in/oauth2.v1"

. "github.com/smartystreets/goconvey/convey"
)

func TestACMemoryStore(t *testing.T) {
Convey("AC memory store test", t, func() {
store := oauth2.NewACMemoryStore(1)
item := oauth2.ACInfo{
store := NewACMemoryStore(1)
item := ACInfo{
ClientID: "123456",
UserID: "999999",
CreateAt: time.Now().Unix(),
ExpiresIn: time.Millisecond * 500,
}

Convey("Put Test", func() {
id, err := store.Put(item)
So(err, ShouldBeNil)
So(id, ShouldEqual, 1)
item.ID = id
So(id, ShouldBeGreaterThan, 0)
Convey("Take Test", func() {
info, err := store.TakeByID(id)
So(err, ShouldBeNil)
So(info.ClientID, ShouldEqual, item.ClientID)
So(info.UserID, ShouldEqual, item.UserID)
})
})

Convey("GC Test", func() {
id, err := store.Put(item)
So(err, ShouldBeNil)
So(id, ShouldBeGreaterThan, 0)
Convey("Take GC Test", func() {
time.Sleep(time.Second * 2)
time.Sleep(time.Millisecond * 1500)
info, err := store.TakeByID(id)
So(err, ShouldNotBeNil)
So(info, ShouldBeNil)
Expand Down
89 changes: 89 additions & 0 deletions authorizationCodeRedisStore.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package oauth2

import (
"encoding/json"
"fmt"

"gopkg.in/redis.v3"
)

const (
// DefaultACRedisIDKey Redis存储授权码唯一标识的键
DefaultACRedisIDKey = "ACID"
)

// NewACRedisStore 创建Redis存储的实例
// config Redis配置参数
// key Redis存储授权码唯一标识的键(默认为ACID)
func NewACRedisStore(cfg *RedisConfig, key string) (*ACRedisStore, error) {
opt := &redis.Options{
Network: cfg.Network,
Addr: cfg.Addr,
Password: cfg.Password,
DB: cfg.DB,
MaxRetries: cfg.MaxRetries,
DialTimeout: cfg.DialTimeout,
ReadTimeout: cfg.ReadTimeout,
WriteTimeout: cfg.WriteTimeout,
PoolSize: cfg.PoolSize,
PoolTimeout: cfg.PoolTimeout,
}
cli := redis.NewClient(opt)
err := cli.Ping().Err()
if err != nil {
return nil, err
}
if key == "" {
key = DefaultACRedisIDKey
}
return &ACRedisStore{
cli: cli,
key: key,
}, nil
}

// ACRedisStore 提供授权码的redis存储
type ACRedisStore struct {
cli *redis.Client
key string
}

// Put 存储授权码
func (ar *ACRedisStore) Put(item ACInfo) (id int64, err error) {
n, err := ar.cli.Incr(ar.key).Result()
if err != nil {
return
}
item.ID = n
jv, err := json.Marshal(item)
if err != nil {
return
}
key := fmt.Sprintf("%s_%d", ar.key, n)
err = ar.cli.Set(key, string(jv), item.ExpiresIn).Err()
if err != nil {
return
}
id = item.ID
return
}

// TakeByID 取出授权码
func (ar *ACRedisStore) TakeByID(id int64) (info *ACInfo, err error) {
key := fmt.Sprintf("%s_%d", ar.key, id)
data, err := ar.cli.Get(key).Result()
if err != nil {
return
}
var v ACInfo
err = json.Unmarshal([]byte(data), &v)
if err != nil {
return
}
err = ar.cli.Del(key).Err()
if err != nil {
return
}
info = &v
return
}
Loading