forked from qor/auth
/
handlers.go
97 lines (78 loc) · 2.81 KB
/
handlers.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
package password
import (
"reflect"
"strings"
"github.com/qor/auth"
"github.com/qor/auth/auth_identity"
"github.com/qor/auth/claims"
"github.com/qor/qor/utils"
"github.com/qor/session"
)
// DefaultAuthorizeHandler is the default authorize handler for the password
// provider. It looks up the auth identity whose provider name and UID match
// the trimmed "login" form value, then compares the trimmed "password" form
// value against the identity's stored encrypted password.
//
// It returns the identity's claims on success, and otherwise:
//   - auth.ErrInvalidAccount if the login is empty or no matching identity exists;
//   - ErrUnconfirmed if the provider is Confirmable and the identity has no
//     ConfirmedAt (a confirmation mail is sent via provider.Config.ConfirmMailer);
//   - auth.ErrInvalidPassword if the password comparison fails;
//   - the underlying error if parsing the form or querying the database fails.
var DefaultAuthorizeHandler = func(context *auth.Context) (*claims.Claims, error) {
	var (
		authInfo    auth_identity.Basic
		req         = context.Request
		tx          = context.Auth.GetDB(req)
		provider, _ = context.Provider.(*Provider)
	)

	if err := req.ParseForm(); err != nil {
		return nil, err
	}

	authInfo.Provider = provider.GetName()
	authInfo.UID = strings.TrimSpace(req.Form.Get("login"))

	// GORM's Where(struct) ignores zero-valued fields, so an empty UID would
	// match the first identity of this provider. Reject it up front.
	if authInfo.UID == "" {
		return nil, auth.ErrInvalidAccount
	}

	lookup := tx.Model(context.Auth.AuthIdentityModel).Where(authInfo).Scan(&authInfo)
	if lookup.RecordNotFound() {
		return nil, auth.ErrInvalidAccount
	}
	if lookup.Error != nil {
		// Do not continue to the confirmation/password checks with an
		// authInfo that was never loaded.
		return nil, lookup.Error
	}

	if provider.Config.Confirmable && authInfo.ConfirmedAt == nil {
		// Best effort: the user lookup and mailer errors are ignored; the
		// caller is told ErrUnconfirmed regardless.
		currentUser, _ := context.Auth.UserStorer.Get(authInfo.ToClaims(), context)
		provider.Config.ConfirmMailer(authInfo.UID, context, authInfo.ToClaims(), currentUser)
		return nil, ErrUnconfirmed
	}

	// The password is trimmed exactly as DefaultRegisterHandler trims it
	// before digesting; the two must stay in sync.
	password := strings.TrimSpace(req.Form.Get("password"))
	if err := provider.Encryptor.Compare(authInfo.EncryptedPassword, password); err != nil {
		// Any comparison error is reported uniformly as an invalid password.
		return nil, auth.ErrInvalidPassword
	}
	return authInfo.ToClaims(), nil
}
// DefaultRegisterHandler is the default register handler for the password
// provider. It reads the "login" and "password" form values (both trimmed),
// rejects empty values and logins that already have an identity for this
// provider, saves a new user through context.Auth.UserStorer, and then
// creates the auth identity holding the digested password. When the provider
// is Confirmable, it also sets a success flash and sends a confirmation mail.
//
// It returns auth.ErrInvalidAccount for an empty or already-registered login,
// auth.ErrInvalidPassword for an empty password, and otherwise any error from
// form parsing, the database, the encryptor, or the user storer. If the
// confirmation mailer fails, its error is returned together with the claims of
// the identity that was already created.
var DefaultRegisterHandler = func(context *auth.Context) (*claims.Claims, error) {
	var (
		err         error
		currentUser interface{}
		schema      auth.Schema
		authInfo    auth_identity.Basic
		req         = context.Request
		tx          = context.Auth.GetDB(req)
		provider, _ = context.Provider.(*Provider)
	)

	if err = req.ParseForm(); err != nil {
		return nil, err
	}

	// Validate after trimming: a whitespace-only value would otherwise pass
	// and be stored as an empty UID or as the digest of an empty password.
	login := strings.TrimSpace(req.Form.Get("login"))
	if login == "" {
		return nil, auth.ErrInvalidAccount
	}
	password := strings.TrimSpace(req.Form.Get("password"))
	if password == "" {
		return nil, auth.ErrInvalidPassword
	}

	authInfo.Provider = provider.GetName()
	authInfo.UID = login

	lookup := tx.Model(context.Auth.AuthIdentityModel).Where(authInfo).Scan(&authInfo)
	switch {
	case lookup.RecordNotFound():
		// The login is not registered yet; continue with registration.
	case lookup.Error != nil:
		return nil, lookup.Error
	default:
		return nil, auth.ErrInvalidAccount
	}

	if authInfo.EncryptedPassword, err = provider.Encryptor.Digest(password); err != nil {
		return nil, err
	}

	schema.Provider = authInfo.Provider
	schema.UID = authInfo.UID
	schema.Email = authInfo.UID
	schema.RawInfo = req

	if currentUser, authInfo.UserID, err = context.Auth.UserStorer.Save(&schema, context); err != nil {
		return nil, err
	}

	// Persist the auth identity linked to the newly saved user.
	authIdentity := reflect.New(utils.ModelType(context.Auth.Config.AuthIdentityModel)).Interface()
	if err = tx.Where(authInfo).FirstOrCreate(authIdentity).Error; err != nil {
		return nil, err
	}

	if provider.Config.Confirmable {
		context.SessionStorer.Flash(context.Writer, req, session.Message{Message: ConfirmFlashMessage, Type: "success"})
		err = provider.Config.ConfirmMailer(schema.Email, context, authInfo.ToClaims(), currentUser)
	}
	return authInfo.ToClaims(), err
}