-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
userfuncs.go
280 lines (243 loc) · 7.37 KB
/
userfuncs.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
package harmonydb
import (
"context"
"errors"
"fmt"
"runtime"
"time"
"github.com/georgysavva/scany/v2/dbscan"
"github.com/jackc/pgerrcode"
"github.com/samber/lo"
"github.com/yugabyte/pgx/v5"
"github.com/yugabyte/pgx/v5/pgconn"
)
// errTx is reported by the non-transactional helpers (Exec, Query, QueryRow,
// Select, BeginTransaction) when usedInTransaction says the call is nested
// inside a transaction.
var (
	errTx = errors.New("cannot use a non-transaction func in a transaction")
)
// rawStringOnly is _intentionally_ private to force only constant strings in SQL queries.
// Because the type name is unexported, code outside this package cannot convert a
// string variable to it; only untyped string constants (literals and constant
// expressions) are assignable, so they are the only thing that compiles. Ex:
//
//	db.Exec(ctx, "INSERT INTO version (number) VALUES (1)")
//
// This prevents SQL injection attacks where runtime input contains query fragments:
// runtime values must be passed as query arguments instead.
type rawStringOnly string
// Exec runs a data-changing statement (INSERT, UPDATE or DELETE) and reports
// how many rows it affected.
// Schema changes (CREATE & DROP) do not belong here: keep them permanent by
// adding them to the ./sql/ files (next number).
// It returns errTx, without touching the database, when usedInTransaction
// reports that the call is nested inside a transaction.
func (db *DB) Exec(ctx context.Context, sql rawStringOnly, arguments ...any) (count int, err error) {
	if db.usedInTransaction() {
		return 0, errTx
	}
	tag, execErr := db.pgx.Exec(ctx, string(sql), arguments...)
	count = int(tag.RowsAffected())
	return count, execErr
}
// Qry is the row-cursor surface exposed through Query. The pgx rows returned by
// the driver satisfy it (DB.Query and Tx.Query store them in a Query).
type Qry interface {
	// Next advances the cursor; it must be called before the first Scan.
	Next() bool
	// Scan copies the current row's columns into the destinations, in column order.
	Scan(...any) error
	// Values returns the current row's column values.
	Values() ([]any, error)
	// Err reports the error, if any, that stopped iteration.
	Err() error
	// Close releases the cursor.
	Close()
}
// Query offers Next/Err/Close/Scan/Values by embedding a Qry; DB.Query and
// Tx.Query fill it with the rows returned by pgx.
type Query struct {
	Qry
}
// Query allows iterating returned values to save memory consumption,
// with the downside of needing to `defer q.Close()`. For a simpler
// interface, try Select().
// Next() must be called to advance the row cursor, including the first time.
// Ex:
//
//	q, err := db.Query(ctx, "SELECT id, name FROM users")
//	handleError(err)
//	defer q.Close()
//
//	for q.Next() {
//		var id int
//		var name string
//		handleError(q.Scan(&id, &name))
//		fmt.Println(id, name)
//	}
//
// When usedInTransaction reports that the call is nested inside a transaction,
// Query returns errTx together with a Query whose methods are safe to call:
// Next reports false, Close does nothing, and Err/Scan/Values report errTx.
func (db *DB) Query(ctx context.Context, sql rawStringOnly, arguments ...any) (*Query, error) {
	if db.usedInTransaction() {
		// Hand back a usable Qry: an empty Query (nil Qry) would panic on a
		// deferred Close or a Next call.
		return &Query{qryErr{}}, errTx
	}
	q, err := db.pgx.Query(ctx, string(sql), arguments...)
	return &Query{q}, err
}

// qryErr is the Qry returned alongside errTx by DB.Query, mirroring what
// rowErr does for QueryRow.
type qryErr struct{}

func (qryErr) Next() bool             { return false }
func (qryErr) Err() error             { return errTx }
func (qryErr) Close()                 {}
func (qryErr) Scan(_ ...any) error    { return errTx }
func (qryErr) Values() ([]any, error) { return nil, errTx }
// StructScan scans the current row into the struct pointed to by s, using
// scany's dbscan column-name matching (as Select does).
// This improves efficiency of processing large result sets
// by avoiding the need to allocate a slice of structs.
// If the Query is not backed by pgx rows (e.g. a Query returned together with
// an error), it returns that Query's error, or a descriptive error, instead of
// panicking on the type assertion.
func (q *Query) StructScan(s any) error {
	rows, ok := q.Qry.(pgx.Rows)
	if !ok {
		if q.Qry != nil {
			if err := q.Qry.Err(); err != nil {
				return err
			}
		}
		return errors.New("harmonydb: StructScan requires a Query backed by pgx rows")
	}
	return dbscan.ScanRow(s, dbscanRows{rows})
}
// Row is the single-row result of QueryRow: Scan copies the row's columns into
// the destinations in column order.
type Row interface {
	Scan(dest ...any) error
}

// rowErr is the Row handed back when QueryRow refuses to run inside a
// transaction; its Scan always reports errTx.
type rowErr struct{}

func (rowErr) Scan(...any) error {
	return errTx
}
// QueryRow fetches a single row and scans it by column order.
// It is a shortcut for the common case where only the first returned row matters.
// Ex:
//
//	var name, pet string
//	var ID = 123
//	err := db.QueryRow(ctx, "SELECT name, pet FROM users WHERE ID=$1", ID).Scan(&name, &pet)
//
// When usedInTransaction reports a nested call, the returned Row's Scan
// reports errTx and no query is run.
func (db *DB) QueryRow(ctx context.Context, sql rawStringOnly, arguments ...any) Row {
	if !db.usedInTransaction() {
		return db.pgx.QueryRow(ctx, string(sql), arguments...)
	}
	return rowErr{}
}
// dbscanRows adapts pgx.Rows to the rows interface consumed by scany's dbscan
// (Close returning an error, Columns, NextResultSet).
type dbscanRows struct {
	pgx.Rows
}

// Close closes the underlying pgx rows; it never reports an error.
func (d dbscanRows) Close() error {
	d.Rows.Close()
	return nil
}

// Columns lists the result's column names, in field-description order.
func (d dbscanRows) Columns() ([]string, error) {
	fields := d.Rows.FieldDescriptions()
	names := make([]string, len(fields))
	for i := range fields {
		names[i] = fields[i].Name
	}
	return names, nil
}

// NextResultSet always reports false: only a single result set is exposed.
func (d dbscanRows) NextResultSet() bool {
	return false
}
// Select runs a query and scans every returned row into the slice pointed to
// by sliceOfStructPtr, matching columns to struct fields by name (a `db:"..."`
// tag overrides the field name).
// Ex:
//
//	type user struct {
//		Name   string
//		ID     int
//		Number string `db:"tel_no"`
//	}
//	var users []user
//	pet := "cat"
//	err := db.Select(ctx, &users, "SELECT name, id, tel_no FROM customers WHERE pet=$1", pet)
//
// When usedInTransaction reports a nested call, it returns errTx without querying.
func (db *DB) Select(ctx context.Context, sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error {
	if db.usedInTransaction() {
		return errTx
	}
	rows, queryErr := db.pgx.Query(ctx, string(sql), arguments...)
	if queryErr != nil {
		return queryErr
	}
	adapter := dbscanRows{rows}
	defer adapter.Close()
	return dbscan.ScanAll(sliceOfStructPtr, adapter)
}
// Tx is the handle passed to BeginTransaction callbacks. It embeds the pgx
// transaction and carries the context BeginTransaction was called with, which
// is why its Exec/Query/QueryRow/Select methods take no ctx parameter.
type Tx struct {
	pgx.Tx
	ctx context.Context
}
// usedInTransaction is a helper to prevent nesting transactions
// & non-transaction calls in transactions. It only checks 20 frames.
// Fast: This memory should all be in CPU Caches.
//
// It records up to 20 return PCs of the current goroutine's stack, starting at
// the caller of the function that called usedInTransaction (skip=3 skips
// runtime.Callers, this helper, and the public entry point), and reports
// whether any of them equals db.BTFP, the PC stored once by BeginTransaction.
// The frame count depends on being called directly from the public entry point.
//
// NOTE(review): BeginTransaction stores BTFP from runtime.Callers(1, ...)
// inside its BTFPOnce closure, so the stored PC appears to lie in that closure
// rather than at a call site inside BeginTransaction; the return PCs of a
// nested call stack may therefore never equal it. Confirm with a test that
// nested Exec/Query/BeginTransaction calls really are rejected.
func (db *DB) usedInTransaction() bool {
	var framePtrs = (&[20]uintptr{})[:] // 20 can be stack-local (no alloc)
	framePtrs = framePtrs[:runtime.Callers(3, framePtrs)] // skip past our caller.
	return lo.Contains(framePtrs, db.BTFP.Load()) // Unsafe read @ beginTx overlap, but 'return false' is correct there.
}
// TransactionOptions tunes BeginTransaction's retry behaviour.
type TransactionOptions struct {
	// RetrySerializationError makes BeginTransaction re-run the transaction
	// when it fails with a serialization error (see IsErrSerialization).
	RetrySerializationError bool
	// InitialSerializationErrorRetryWait is the pause before the first retry;
	// BeginTransaction defaults it to 10ms and doubles it after every retry.
	InitialSerializationErrorRetryWait time.Duration
}

// TransactionOption mutates a TransactionOptions; BeginTransaction applies any
// number of them, in order.
type TransactionOption func(*TransactionOptions)

// OptionRetry enables retrying transactions that fail with a serialization error.
func OptionRetry() TransactionOption {
	enable := func(opts *TransactionOptions) {
		opts.RetrySerializationError = true
	}
	return enable
}

// OptionSerialRetryTime sets the wait before the first serialization retry.
func OptionSerialRetryTime(d time.Duration) TransactionOption {
	return TransactionOption(func(opts *TransactionOptions) {
		opts.InitialSerializationErrorRetryWait = d
	})
}
// BeginTransaction is how you can access transactions using this library.
// The entire transaction happens in the function passed in.
// The return must be true or a rollback will occur.
// Be sure to test the error for IsErrSerialization() if you want to retry
// when there is a DB serialization error, or pass OptionRetry() to have
// BeginTransaction retry automatically. Retries back off exponentially,
// starting at 10ms unless OptionSerialRetryTime says otherwise, and the wait
// between attempts returns ctx.Err() as soon as ctx is done.
//
//go:noinline
func (db *DB) BeginTransaction(ctx context.Context, f func(*Tx) (commit bool, err error), opt ...TransactionOption) (didCommit bool, retErr error) {
	db.BTFPOnce.Do(func() {
		fp := make([]uintptr, 20)
		runtime.Callers(1, fp)
		db.BTFP.Store(fp[0])
	})
	if db.usedInTransaction() {
		return false, errTx
	}

	opts := TransactionOptions{
		RetrySerializationError:            false,
		InitialSerializationErrorRetryWait: 10 * time.Millisecond,
	}
	for _, o := range opt {
		o(&opts)
	}

	wait := opts.InitialSerializationErrorRetryWait
	for {
		comm, err := db.transactionInner(ctx, f)
		if err == nil || !opts.RetrySerializationError || !IsErrSerialization(err) {
			return comm, err
		}
		// Back off before retrying. time.Sleep would ignore ctx, and because the
		// wait doubles without bound a cancelled caller could be stuck for a long time.
		timer := time.NewTimer(wait)
		select {
		case <-ctx.Done():
			timer.Stop()
			return false, ctx.Err()
		case <-timer.C:
		}
		wait *= 2
	}
}
// transactionInner runs f inside one pgx transaction and commits only when f
// returns (true, nil). In every other case (f returns false, f returns an
// error, or f panics) the deferred clean-up rolls the transaction back, so it
// is never left open.
// If that rollback itself fails, its error is returned, joined with the
// function's own error when there is one, so errors.As/errors.Is (and hence
// IsErrSerialization) still see the original failure.
func (db *DB) transactionInner(ctx context.Context, f func(*Tx) (commit bool, err error)) (didCommit bool, retErr error) {
	tx, err := db.pgx.BeginTx(ctx, pgx.TxOptions{})
	if err != nil {
		return false, err
	}
	var commit bool
	defer func() { // Panic clean-up.
		if !commit {
			if tmp := tx.Rollback(ctx); tmp != nil {
				if retErr == nil {
					retErr = tmp
				} else {
					retErr = errors.Join(retErr, tmp)
				}
			}
		}
	}()
	commit, err = f(&Tx{tx, ctx})
	if err != nil {
		// f may report (true, err). An error always means "don't commit", and
		// clearing the flag makes the deferred Rollback run; otherwise the
		// transaction would be neither committed nor rolled back.
		commit = false
		return false, err
	}
	if commit {
		if err = tx.Commit(ctx); err != nil {
			return false, err
		}
		return true, nil
	}
	return false, nil
}
// Exec runs a data-changing statement inside the transaction, using the
// transaction's context, and reports how many rows it affected.
func (t *Tx) Exec(sql rawStringOnly, arguments ...any) (count int, err error) {
	tag, execErr := t.Tx.Exec(t.ctx, string(sql), arguments...)
	count, err = int(tag.RowsAffected()), execErr
	return count, err
}
// Query runs a row-returning query inside the transaction. As with DB.Query,
// the caller must Close the result.
func (t *Tx) Query(sql rawStringOnly, arguments ...any) (*Query, error) {
	rows, err := t.Tx.Query(t.ctx, string(sql), arguments...)
	result := &Query{Qry: rows}
	return result, err
}
// QueryRow fetches a single row inside the transaction; scan it by column order.
func (t *Tx) QueryRow(sql rawStringOnly, arguments ...any) Row {
	ctx, query := t.ctx, string(sql)
	return t.Tx.QueryRow(ctx, query, arguments...)
}
// Select runs a query inside the transaction and scans every returned row into
// the slice pointed to by sliceOfStructPtr, matching columns to fields by name.
// A failure to start the query is wrapped with a "scany:" prefix.
func (t *Tx) Select(sliceOfStructPtr any, sql rawStringOnly, arguments ...any) error {
	rows, err := t.Tx.Query(t.ctx, string(sql), arguments...)
	if err != nil {
		return fmt.Errorf("scany: query multiple result rows: %w", err)
	}
	adapter := dbscanRows{rows}
	defer adapter.Close()
	return dbscan.ScanAll(sliceOfStructPtr, adapter)
}
// IsErrUniqueContraint reports whether err, or any error it wraps, is a
// PostgreSQL unique_violation (SQLSTATE 23505).
// The misspelling "Contraint" is part of the exported API and is kept for
// compatibility.
func IsErrUniqueContraint(err error) bool {
	var pgErr *pgconn.PgError
	if !errors.As(err, &pgErr) {
		return false
	}
	return pgErr.Code == pgerrcode.UniqueViolation
}
// IsErrSerialization reports whether err, or any error it wraps, is a
// PostgreSQL serialization_failure (SQLSTATE 40001), the error that
// BeginTransaction retries on when OptionRetry is given.
func IsErrSerialization(err error) bool {
	var pgErr *pgconn.PgError
	if errors.As(err, &pgErr) {
		return pgErr.Code == pgerrcode.SerializationFailure
	}
	return false
}