-
Notifications
You must be signed in to change notification settings - Fork 0
/
pgx.go
199 lines (165 loc) · 6.67 KB
/
pgx.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
// Copyright 2021 George S. Kong. All rights reserved.
// Use of this source code is governed by a license that can be found in the LICENSE.txt file.
// XXX - test that the indices make prune() and DeleteByUserID() perform well when there are lots of sessions
// XXX - test that prune() keeps old expired sessions from wasting space in the database
// Package qspgx is a back-end for qsess which uses PostgreSQL, accessed via the pgx package.
package qspgx
import (
"context"
"encoding/binary"
"io"
"strconv"
"time"
"github.com/gkong/go-qweb/qsess"
"github.com/jackc/pgx/v4/pgxpool"
)
// DefaultPruneIntervalSecs is the initial wait, in seconds, between passes of
// the background pruner that NewPgxStore starts.
const DefaultPruneIntervalSecs = 2 * 60 // prune every 2 minutes

// noctx is the background context handed to the database calls in this file;
// it carries no deadline and is never cancelled.
var noctx = context.Background()
// pgxStore holds per-store information and conforms to the SessBackEnd interface.
type pgxStore struct {
	// db is the connection pool through which every statement is executed.
	db *pgxpool.Pool
	// table is the name of the database table holding session rows.
	table string

	// SQL text that depends only on the table name; it is assembled once by
	// NewPgxStore so the per-request paths do not concatenate strings.
	pGetQuerySQL, pGetDeleteSQL, pDeleteSQL, pDeleteByUserIDSQL string
}
// NewPgxStore creates a new session store, using a PostgreSQL database accessed via pgxpool.
//
// table is the name of a database table to hold session data (it will be created if it doesn't exist),
// together with indexes on its userid and expires columns.
// The name is spliced into SQL text verbatim, so it must be a trusted, valid identifier.
//
// errLog is handed to the background pruning goroutine.
//
// cipherkeys are one or more 32-byte encryption keys, to be used with AES-GCM.
// For encryption, only the first key is used; for decryption all keys are tried (allowing key rotation).
//
// Additional configuration options can be set by manipulating fields in the returned qsess.Store.
//
// If qsess.NewStore fails, a nil store is returned. If creating the table or
// an index fails, the store is still returned alongside the error, and its
// pruning goroutine is already running (stop it via the store's PruneKill channel).
func NewPgxStore(pdb *pgxpool.Pool, tableName string, errLog io.Writer, cipherkeys ...[]byte) (*qsess.Store, error) {
	ps := &pgxStore{
		db:                 pdb,
		table:              tableName,
		pGetQuerySQL:       `SELECT data, userid, FLOOR(EXTRACT(EPOCH FROM (expires-NOW()))), maxage, minrefresh FROM ` + tableName + ` WHERE id = $1`,
		pGetDeleteSQL:      `DELETE FROM ` + tableName + ` WHERE id = $1`,
		pDeleteSQL:         `DELETE FROM ` + tableName + ` WHERE id = $1`,
		pDeleteByUserIDSQL: `DELETE FROM ` + tableName + ` WHERE userid = $1`,
	}

	st, err := qsess.NewStore(ps, false, cipherkeys...)
	if err != nil {
		return nil, pgxErr{"NewPgxStore - NewStore - ", err}
	}

	// The pruner is controlled through these two channels, which are exposed
	// to the caller as fields of the returned store.
	st.PruneInterval = make(chan int)
	st.PruneKill = make(chan int)
	go ps.prune(DefaultPruneIntervalSecs, st.PruneInterval, st.PruneKill, errLog)

	_, err = pdb.Exec(noctx,
		`CREATE TABLE IF NOT EXISTS `+tableName+` (
			id SERIAL PRIMARY KEY,
			data BYTEA,
			userid BYTEA,
			expires TIMESTAMP NOT NULL,
			maxage INTEGER,
			minrefresh INTEGER
		)`)
	if err != nil {
		return st, pgxErr{"NewPgxStore - CREATE TABLE failed - ", err}
	}

	// userid is indexed for DeleteByUserID, expires for prune. Building the
	// error text from the column name keeps the message in step with the
	// statement that actually failed.
	for _, col := range []string{"userid", "expires"} {
		stmt := `CREATE INDEX IF NOT EXISTS ` + tableName + `_` + col + ` ON ` + tableName + ` (` + col + `)`
		if _, err = pdb.Exec(noctx, stmt); err != nil {
			return st, pgxErr{"NewPgxStore - CREATE " + col + " index failed - ", err}
		}
	}

	return st, nil
}
// Get looks up the session whose serialized id is sessIDbytes and returns its
// data, user id, remaining time-to-live in seconds, maxage and minrefresh.
// The second argument is not used by this back-end.
//
// A row whose time-to-live has reached zero or below is deleted on the spot
// and reported as an error. Every error return carries empty (non-nil) byte
// slices and zero integers.
func (ps *pgxStore) Get(sessIDbytes []byte, uidNOTUSED []byte) ([]byte, []byte, int, int, int, error) {
	// fail produces the uniform result that accompanies every error.
	fail := func(msg string, cause error) ([]byte, []byte, int, int, int, error) {
		return []byte{}, []byte{}, 0, 0, 0, pgxErr{msg, cause}
	}

	var (
		id                           = bytesToSessID(sessIDbytes)
		payload, owner               []byte
		secsLeft, maxAge, minRefresh int
	)

	err := ps.db.QueryRow(noctx, ps.pGetQuerySQL, id).Scan(&payload, &owner, &secsLeft, &maxAge, &minRefresh)
	if err != nil {
		return fail("pgxStore.Get - row.Scan failed - ", err)
	}

	if secsLeft > 0 {
		return payload, owner, secsLeft, maxAge, minRefresh, nil
	}

	// The session has expired: remove it now rather than waiting for the pruner.
	if _, err = ps.db.Exec(noctx, ps.pGetDeleteSQL, id); err != nil {
		return fail("pgxStore.Get - DELETE failed - ", err)
	}
	return fail("pgxStore.Get - record has expired", nil)
}
// Save writes a session to the database, setting its expiry to maxAgeSecs
// seconds from now.
//
// When *sessID is nil a new row is inserted and *sessID is set to the
// serialized form of the id the database assigned. Otherwise the row that
// *sessID names is updated; updating a row that does not exist is an error.
func (ps *pgxStore) Save(sessID *[]byte, data []byte, userID []byte, maxAgeSecs int, minRefreshSecs int) error {
	// XXX - find a way to make maxAgeSecs a parameter, so we can move the SQL strings into pgxStore and not have to recompute every time
	expiry := "NOW() + INTERVAL '" + strconv.Itoa(maxAgeSecs) + " seconds'"

	if *sessID != nil {
		// An id was supplied: it refers to an existing record, so update it.
		updateSQL := `UPDATE ` + ps.table +
			` SET data = $1, userid = $2, expires = ` + expiry + `, maxage = $3, minrefresh = $4 WHERE id = $5`
		tag, err := ps.db.Exec(noctx, updateSQL, data, userID, maxAgeSecs, minRefreshSecs, bytesToSessID(*sessID))
		if err != nil {
			return pgxErr{"pgxStore.Save - UPDATE failed - ", err}
		}
		// UPDATE succeeds even when no record matches, so a missing row can
		// only be detected through the affected-row count.
		if tag.RowsAffected() >= 1 {
			return nil
		}
		return pgxErr{"pgxStore.Save - UPDATE affected no rows - ", nil}
	}

	// No id yet: insert a new record and hand its generated id back to the caller.
	insertSQL := `INSERT INTO ` + ps.table +
		` (data, userid, expires, maxage, minrefresh) VALUES($1, $2, ` + expiry + `, $3, $4) RETURNING id`
	var assigned uint32
	if err := ps.db.QueryRow(noctx, insertSQL, data, userID, maxAgeSecs, minRefreshSecs).Scan(&assigned); err != nil {
		return pgxErr{"pgxStore.Save - row.Scan failed - ", err}
	}
	*sessID = sessIDToBytes(assigned)
	return nil
}
// Delete removes the session row whose serialized id is sessID.
// The second argument is not used by this back-end.
// Only a failure of the DELETE statement itself is reported as an error.
func (ps *pgxStore) Delete(sessID []byte, uidNOTUSED []byte) error {
	_, err := ps.db.Exec(noctx, ps.pDeleteSQL, bytesToSessID(sessID))
	if err == nil {
		return nil
	}
	return pgxErr{"pgxStore.Delete - DELETE failed - ", err}
}
// DeleteByUserID removes every session row whose userid column equals userID.
// Only a failure of the DELETE statement itself is reported as an error.
func (ps *pgxStore) DeleteByUserID(userID []byte) error {
	_, err := ps.db.Exec(noctx, ps.pDeleteByUserIDSQL, userID)
	if err == nil {
		return nil
	}
	return pgxErr{"pgxStore.DeleteByUserID - DELETE failed - ", err}
}
// prune() periodically deletes expired sessions from the session store.
// the "expires" field must be indexed for this to run efficiently.
//
// prune runs in a goroutine, started by NewPgxStore.
// It runs until it receives something on its pruneKill channel.
// You can change its wait interval by sending a number of seconds to its pruneInterval channel;
// receiving a new interval also triggers an immediate pruning pass.
//
// A failed DELETE does not stop the loop. It is reported as one line written
// to log, unless log is nil, in which case it is dropped.
func (ps *pgxStore) prune(waitSecs int, pruneInterval <-chan int, pruneKill <-chan int, log io.Writer) {
	// The statement depends only on the table name, so build it once.
	pruneSQL := `DELETE FROM ` + ps.table + ` WHERE expires < NOW()`

	for {
		select {
		case waitSecs = <-pruneInterval:
		case <-pruneKill:
			return
		case <-time.After(time.Duration(waitSecs) * time.Second):
		}

		if _, err := ps.db.Exec(noctx, pruneSQL); err != nil && log != nil {
			// Logging is best effort: nothing useful can be done if the
			// log writer itself fails, so its result is deliberately ignored.
			_, _ = io.WriteString(log, pgxErr{"pgxStore.prune - DELETE failed - ", err}.Error()+"\n")
		}
	}
}
// sessIDToBytes serializes a session id (the uint32 database key) into a
// freshly allocated 4-byte slice, least significant byte first.
func sessIDToBytes(id uint32) []byte {
	var raw [4]byte
	binary.LittleEndian.PutUint32(raw[:], id)
	return raw[:]
}
// bytesToSessID decodes a session id from the first four bytes of b, read
// least significant byte first; it is the inverse of sessIDToBytes.
func bytesToSessID(b []byte) uint32 {
	// len(b) is deliberately not checked: we rely on qsess's promise to hand
	// back the id bytes untouched. A slice shorter than four bytes panics.
	order := binary.LittleEndian
	return order.Uint32(b)
}
// pgxErr is the error type returned by this package. It pairs a message
// naming the failing operation with the underlying cause, if there is one.
type pgxErr struct {
	msg string // description of the operation that failed
	err error  // underlying cause; nil when the failure has none
}

// Error renders the message prefixed with "qspgx.", followed by " - " and the
// cause's text when a cause is present.
func (e pgxErr) Error() string {
	if e.err != nil {
		return "qspgx." + e.msg + " - " + e.err.Error()
	}
	return "qspgx." + e.msg
}

// Unwrap returns the underlying cause (nil if there is none), so that callers
// can inspect it with errors.Is and errors.As.
func (e pgxErr) Unwrap() error {
	return e.err
}