// auth.go

// Package auth provides an authorization middleware for blademaster (bm)
// handlers, with user (login required) and guest access policies.
package auth
import (
"github.com/go-kratos/kratos/pkg/ecode"
bm "github.com/go-kratos/kratos/pkg/net/http/blademaster"
"github.com/go-kratos/kratos/pkg/net/metadata"
)
// Config holds the settings of the authorization middleware.
type Config struct {
	// DisableCSRF turns the CSRF check in cookie authentication off when
	// true. The zero value (false) keeps the check enabled.
	DisableCSRF bool
}
// Auth is the authorization middleware. Build it with New so that conf is
// populated.
type Auth struct {
	// conf is the active configuration. New always sets it; it is nil only
	// on a zero-value Auth.
	conf *Config
}
// authFunc authenticates the request carried by ctx and returns the user id
// (mid) on success, or a non-nil error when authentication fails.
type authFunc func(ctx *bm.Context) (mid int64, err error)
// _defaultConf is the configuration New falls back to when it is given nil.
// It leaves CSRF checking enabled.
var _defaultConf = &Config{DisableCSRF: false}
// New creates an authorization middleware from conf.
//
// A nil conf selects the package default (_defaultConf). The default is used
// by pointer, so every Auth created with a nil conf shares that one Config.
func New(conf *Config) *Auth {
	if conf == nil {
		conf = _defaultConf
	}
	return &Auth{conf: conf}
}
// User marks a path as requiring an authenticated user.
//
// When the request form carries a non-empty `access_token`, the mobile
// (token) policy is applied; otherwise the web (cookie) policy is applied.
//
// NOTE(review): Request.Form is read directly, so this presumably relies on
// the form having been parsed before the middleware runs — confirm.
func (a *Auth) User(ctx *bm.Context) {
	if token := ctx.Request.Form.Get("access_token"); token != "" {
		a.UserMobile(ctx)
		return
	}
	a.UserWeb(ctx)
}
// UserWeb marks a path as requiring an authenticated web user; the request
// is authenticated through its cookie.
func (a *Auth) UserWeb(ctx *bm.Context) {
	var byCookie authFunc = a.authCookie
	a.midAuth(ctx, byCookie)
}
// UserMobile marks a path as requiring an authenticated mobile user; the
// request is authenticated through its access token.
func (a *Auth) UserMobile(ctx *bm.Context) {
	var byToken authFunc = a.authToken
	a.midAuth(ctx, byToken)
}
// Guest marks a path as using the guest policy.
//
// When the request form carries a non-empty `access_token`, the mobile
// (token) guest policy is applied; otherwise the web (cookie) guest policy
// is applied.
//
// NOTE(review): Request.Form is read directly, so this presumably relies on
// the form having been parsed before the middleware runs — confirm.
func (a *Auth) Guest(ctx *bm.Context) {
	if token := ctx.Request.Form.Get("access_token"); token != "" {
		a.GuestMobile(ctx)
		return
	}
	a.GuestWeb(ctx)
}
// GuestWeb marks a path as using the web guest policy; the request is
// authenticated through its cookie.
func (a *Auth) GuestWeb(ctx *bm.Context) {
	var byCookie authFunc = a.authCookie
	a.guestAuth(ctx, byCookie)
}
// GuestMobile marks a path as using the mobile guest policy; the request is
// authenticated through its access token.
func (a *Auth) GuestMobile(ctx *bm.Context) {
	var byToken authFunc = a.authToken
	a.guestAuth(ctx, byToken)
}
// authToken authorizes the request by the `access_token` form value.
//
// It returns ecode.Unauthorized when the token is missing or empty.
// Otherwise it returns the user id (mid) with a nil error; until the TODO
// below is implemented, that mid is always 0.
func (a *Auth) authToken(ctx *bm.Context) (int64, error) {
	token := ctx.Request.Form.Get("access_token")
	if token == "" {
		return 0, ecode.Unauthorized
	}
	// NOTE: call the login/auth service API here to obtain the user id that
	// corresponds to the token.
	// TODO: get mid from some code
	var mid int64
	return mid, nil
}
// authCookie authorizes the request by its "SESSION" cookie.
//
// It returns ecode.Unauthorized when the cookie is absent, or when the CSRF
// check applies and the submitted `csrf` form value does not match the
// expected one. The CSRF check applies only when a configuration is present,
// DisableCSRF is false and the request method is POST. On success it returns
// the user id (mid); until the TODO below is implemented, that mid is
// always 0.
func (a *Auth) authCookie(ctx *bm.Context) (int64, error) {
	r := ctx.Request
	// Request.Cookie reports an error exactly when it has no cookie to return.
	if _, err := r.Cookie("SESSION"); err != nil {
		return 0, ecode.Unauthorized
	}
	// NOTE: call the login/auth service API here to obtain the user id that
	// corresponds to the session.
	// TODO: get mid from some code
	var mid int64

	// The submitted value is read before deciding whether the check applies,
	// so form parsing happens for every request that has a session cookie.
	submitted := r.FormValue("csrf")
	csrfEnabled := a.conf != nil && !a.conf.DisableCSRF
	if !csrfEnabled || r.Method != "POST" {
		return mid, nil
	}

	// NOTE: with CSRF verification enabled, fetch the csrf value associated
	// with this user from the CSRF service.
	// TODO: get csrf from some code
	// NOTE(review): until that is done, expected is the empty string, so only
	// a POST with an empty or missing `csrf` value passes.
	var expected string
	if submitted != expected {
		return 0, ecode.Unauthorized
	}
	return mid, nil
}
// midAuth enforces the login-required policy using auth.
//
// When auth succeeds, the returned mid is stored on the context via setMid.
// When auth fails, the error is written as the JSON response (with nil data)
// and the handler chain is aborted.
func (a *Auth) midAuth(ctx *bm.Context, auth authFunc) {
	mid, err := auth(ctx)
	if err == nil {
		setMid(ctx, mid)
		return
	}
	ctx.JSON(nil, err)
	ctx.Abort()
}
// guestAuth enforces the guest policy using auth.
//
// When auth succeeds with a positive mid, the mid is stored on the context
// via setMid. Otherwise the request is aborted with a JSON error only if the
// error's code equals ecode.Unauthorized. In every other case (a different
// error, or a nil error with a non-positive mid) the request continues with
// no mid set.
//
// NOTE(review): ecode.Cause is also reached with a nil err when mid <= 0;
// presumably it maps nil to a non-Unauthorized code — confirm.
func (a *Auth) guestAuth(ctx *bm.Context, auth authFunc) {
	mid, err := auth(ctx)
	if err != nil || mid <= 0 {
		if code := ecode.Cause(err); ecode.Equal(code, ecode.Unauthorized) {
			ctx.JSON(nil, code)
			ctx.Abort()
		}
		return
	}
	// No error happened and mid is valid.
	setMid(ctx, mid)
}
// setMid records mid on the request context: always under the metadata.Mid
// key via ctx.Set, and additionally in the metadata map when
// metadata.FromContext finds one attached to ctx.
//
// NOTE: This function is not thread safe.
func setMid(ctx *bm.Context, mid int64) {
	ctx.Set(metadata.Mid, mid)
	md, ok := metadata.FromContext(ctx)
	if !ok {
		return
	}
	md[metadata.Mid] = mid
}