// provider.go (forked from fasthttp/session).
package postgres
import (
	"log"
	"sync"
	"time"

	"github.com/codingbeard/session"
	"github.com/savsgio/gotils"
)
// provider is the package-level Postgres provider instance that init registers.
var provider = NewProvider()

// encrypt supplies the default base64 serialize/unserialize functions that
// Init falls back to when the config leaves them nil.
var encrypt = session.NewEncrypt()
// NewProvider returns a Postgres session provider with an empty Config, an
// empty Dao and a store pool that allocates *Store values on demand.
// Init is expected to be called afterwards; it replaces config and db.
func NewProvider() *Provider {
	p := new(Provider)
	p.config = new(Config)
	p.db = new(Dao)
	p.storePool = sync.Pool{
		New: func() interface{} { return new(Store) },
	}
	return p
}
// acquireStore takes a *Store from the pool and initializes it for sessionID
// with the given expiration before handing it out.
func (pp *Provider) acquireStore(sessionID []byte, expiration time.Duration) *Store {
	s := pp.storePool.Get().(*Store)
	s.Init(sessionID, expiration)
	return s
}
// releaseStore clears s via Reset and puts it back into the pool for reuse.
func (pp *Provider) releaseStore(s *Store) {
	s.Reset()
	pp.storePool.Put(s)
}
// Init configures the provider and opens the Postgres connection.
//
// expiration is stored and later passed to the DAO when rows are inserted or
// regenerated. cfg must be a non-nil *Config whose Name() is ProviderName;
// otherwise errInvalidProviderConfig is returned. Host and Port are required
// (errConfigHostEmpty / errConfigPortZero). A nil SerializeFunc or
// UnSerializeFunc defaults to base64 encode/decode. The pool limits come from
// SetMaxOpenConn and SetMaxIdleConn respectively. The returned error is the
// result of pinging the database.
func (pp *Provider) Init(expiration time.Duration, cfg session.ProviderConfig) error {
	if cfg == nil || cfg.Name() != ProviderName {
		return errInvalidProviderConfig
	}
	// Comma-ok: a different ProviderConfig type that reports our name must
	// yield an error, not a panic.
	config, ok := cfg.(*Config)
	if !ok {
		return errInvalidProviderConfig
	}
	pp.config = config
	pp.expiration = expiration

	if pp.config.Host == "" {
		return errConfigHostEmpty
	}
	if pp.config.Port == 0 {
		return errConfigPortZero
	}
	if pp.config.SerializeFunc == nil {
		pp.config.SerializeFunc = encrypt.Base64Encode
	}
	if pp.config.UnSerializeFunc == nil {
		pp.config.UnSerializeFunc = encrypt.Base64Decode
	}

	// Only replace pp.db once a DAO was actually created, so a failed Init
	// does not leave the provider holding a nil db.
	db, err := NewDao("postgres", pp.config.getPostgresDSN(), pp.config.TableName)
	if err != nil {
		return err
	}
	pp.db = db

	// Each limit uses its own setting; the open-connection cap must not
	// silently mirror the idle-connection setting.
	pp.db.Connection.SetMaxOpenConns(pp.config.SetMaxOpenConn)
	pp.db.Connection.SetMaxIdleConns(pp.config.SetMaxIdleConn)

	return pp.db.Connection.Ping()
}
// Get returns the session store for sessionID.
//
// If a row for sessionID exists, its contents are decoded into the store with
// config.UnSerializeFunc. Otherwise an empty row is inserted for sessionID
// (current Unix time, provider expiration) and an empty store is returned.
// On any error the store is returned to the pool and (nil, err) is returned.
func (pp *Provider) Get(sessionID []byte) (session.Storer, error) {
	row, err := pp.db.getSessionBySessionID(sessionID)
	if err != nil {
		return nil, err
	}
	// gotils.S2B below views row.contents without copying, so the row is only
	// recycled when this function returns, after decoding is done.
	defer releaseDBRow(row)

	store := pp.acquireStore(sessionID, pp.expiration)

	if row.sessionID != "" { // Exists
		if err := pp.config.UnSerializeFunc(store.DataPointer(), gotils.S2B(row.contents)); err != nil {
			pp.releaseStore(store)
			return nil, err
		}
		return store, nil
	}

	// Not exist: insert an empty row for this session ID.
	if _, err := pp.db.insert(sessionID, nil, time.Now().Unix(), pp.expiration); err != nil {
		pp.releaseStore(store)
		return nil, err
	}
	return store, nil
}
// Put resets store and returns it to the provider's pool.
// store is expected to be a *Store handed out by this provider; the type
// assertion panics for any other Storer implementation.
func (pp *Provider) Put(store session.Storer) {
	s := store.(*Store)
	pp.releaseStore(s)
}
// Regenerate moves the session stored under oldID to newID and returns its store.
//
// If oldID exists, its contents are decoded into a store keyed by newID and
// then db.regenerate(oldID, newID, now, expiration) is called. If oldID does
// not exist, an empty row is inserted for newID. On any error the store is
// returned to the pool and (nil, err) is returned.
func (pp *Provider) Regenerate(oldID, newID []byte) (session.Storer, error) {
	row, err := pp.db.getSessionBySessionID(oldID)
	if err != nil {
		return nil, err
	}
	// gotils.S2B below views row.contents without copying, so the row is only
	// recycled when this function returns.
	defer releaseDBRow(row)

	store := pp.acquireStore(newID, pp.expiration)
	now := time.Now().Unix()

	if row.sessionID != "" { // Exists
		// Decode before touching the DB, so that a corrupt payload does not
		// leave the row already moved to newID while the caller gets an error.
		if err := pp.config.UnSerializeFunc(store.DataPointer(), gotils.S2B(row.contents)); err != nil {
			pp.releaseStore(store)
			return nil, err
		}
		if _, err := pp.db.regenerate(oldID, newID, now, pp.expiration); err != nil {
			pp.releaseStore(store)
			return nil, err
		}
		return store, nil
	}

	// Not exist: start a fresh, empty session under newID.
	if _, err := pp.db.insert(newID, nil, now, pp.expiration); err != nil {
		pp.releaseStore(store)
		return nil, err
	}
	return store, nil
}
// Destroy deletes the row for sessionID and returns any error from the DAO.
// The affected-row count reported by deleteBySessionID is ignored.
func (pp *Provider) Destroy(sessionID []byte) error {
	if _, err := pp.db.deleteBySessionID(sessionID); err != nil {
		return err
	}
	return nil
}
// Count returns the number of stored sessions as reported by db.countSessions.
func (pp *Provider) Count() int {
	total := pp.db.countSessions()
	return total
}
// NeedGC reports whether GC should be invoked for this provider. It always
// returns true, so GC gets the chance to delete expired rows.
func (pp *Provider) NeedGC() bool {
	const gcRequired = true
	return gcRequired
}
// GC deletes expired session rows via db.deleteExpiredSessions.
//
// GC has no error return, and it is presumably driven by the session manager
// in the background (NOTE(review): confirm the caller). A panic here would
// therefore take down the whole process on a transient database error. GC is
// best-effort instead: the failure is logged, and expired rows remain for a
// later GC call to remove.
func (pp *Provider) GC() {
	if _, err := pp.db.deleteExpiredSessions(); err != nil {
		log.Printf("session/postgres: gc: deleting expired sessions: %v", err)
	}
}
// init registers the package-level provider under ProviderName. A failed
// registration is treated as a programming error and panics at startup.
func init() {
	if err := session.Register(ProviderName, provider); err != nil {
		panic(err)
	}
}