-
Notifications
You must be signed in to change notification settings - Fork 109
/
Copy pathsqlite.go
322 lines (273 loc) · 8.39 KB
/
sqlite.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
package sqlite
import (
"context"
"database/sql"
"database/sql/driver"
"embed"
"fmt"
"io/fs"
"log"
"os"
"path/filepath"
"sort"
"time"
"github.com/benbjohnson/wtf"
_ "github.com/mattn/go-sqlite3"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
// Database metrics.
//
// Each gauge holds a row count that DB.updateStats refreshes from the
// corresponding table (users, dials, dial_memberships).
var (
// userCountGauge is set from SELECT COUNT(*) on the users table.
userCountGauge = promauto.NewGauge(prometheus.GaugeOpts{
Name: "wtf_db_users",
Help: "The total number of users",
})
// dialCountGauge is set from SELECT COUNT(*) on the dials table.
dialCountGauge = promauto.NewGauge(prometheus.GaugeOpts{
Name: "wtf_db_dials",
Help: "The total number of dials",
})
// dialMembershipCountGauge is set from SELECT COUNT(*) on the
// dial_memberships table.
dialMembershipCountGauge = promauto.NewGauge(prometheus.GaugeOpts{
Name: "wtf_db_dial_memberships",
Help: "The total number of dial memberships",
})
)
// migrationFS holds the embedded SQL migration files matched by the pattern
// below. It is read by DB.migrate (to list files) and DB.migrateFile (to load
// each file's contents).
//
//go:embed migration/*.sql
var migrationFS embed.FS
// DB represents the database connection.
//
// It bundles the underlying *sql.DB with a background context used by the
// stats monitor goroutine. Use NewDB to construct one (it fills in the
// defaults and the background context) and Open to connect.
type DB struct {
// Underlying database handle. Assigned by Open; Close tolerates it being nil.
db *sql.DB
ctx context.Context // background context; watched by monitor and passed to updateStats
cancel func() // cancel background context; invoked by Close
// Datasource name.
// Open rejects an empty value and treats ":memory:" as an in-memory
// database (no parent directory is created for it).
DSN string
// Destination for events to be published.
// NewDB defaults this to wtf.NopEventService().
EventService wtf.EventService
// Returns the current time. Defaults to time.Now().
// Can be mocked for tests.
// BeginTx calls it once per transaction to stamp the transaction start.
Now func() time.Time
}
// NewDB builds a DB bound to the given datasource name.
//
// The returned value is not yet connected; call Open for that. Defaults:
// Now is time.Now, EventService is the no-op service, and a cancellable
// background context is created for the stats monitor (cancelled by Close).
func NewDB(dsn string) *DB {
	ctx, cancel := context.WithCancel(context.Background())
	return &DB{
		ctx:          ctx,
		cancel:       cancel,
		DSN:          dsn,
		EventService: wtf.NopEventService(),
		Now:          time.Now,
	}
}
// Open connects to the SQLite database named by db.DSN.
//
// Steps, in order: validate the DSN, create the parent directory (skipped for
// ":memory:"), open the handle, switch the journal to WAL, turn on foreign
// key enforcement, run pending migrations, and finally start the background
// stats monitor. The first failing step aborts the sequence and its error is
// returned.
func (db *DB) Open() (err error) {
	// Nothing can be opened without a datasource name.
	if db.DSN == "" {
		return fmt.Errorf("dsn required")
	}

	// File-backed databases need their parent directory to exist.
	if db.DSN != ":memory:" {
		dir := filepath.Dir(db.DSN)
		if err := os.MkdirAll(dir, 0700); err != nil {
			return err
		}
	}

	// Obtain the database handle.
	db.db, err = sql.Open("sqlite3", db.DSN)
	if err != nil {
		return err
	}

	// Write-ahead logging lets readers proceed while a write is in progress.
	_, err = db.db.Exec(`PRAGMA journal_mode = wal;`)
	if err != nil {
		return fmt.Errorf("enable wal: %w", err)
	}

	// SQLite leaves foreign key enforcement off unless asked; the insert-time
	// overhead is worth the integrity guarantee.
	// NOTE(review): this PRAGMA is issued through the pooled *sql.DB, so it
	// presumably only affects the connection that ran it — confirm whether
	// every pooled connection needs it (e.g. via a DSN parameter).
	_, err = db.db.Exec(`PRAGMA foreign_keys = ON;`)
	if err != nil {
		return fmt.Errorf("foreign keys pragma: %w", err)
	}

	// Bring the schema up to date before anyone uses the connection.
	if err := db.migrate(); err != nil {
		return fmt.Errorf("migrate: %w", err)
	}

	// Collect stats in the background until Close cancels the context.
	go db.monitor()

	return nil
}
// migrate prepares migration tracking and applies every pending migration.
//
// Migration scripts are embedded from the migration folder and are applied in
// lexicographical order of their file names. Each applied script is recorded
// by name in the 'migrations' table so it is skipped on later runs, and each
// script runs inside its own transaction (see migrateFile) so a failure does
// not leave a half-applied migration behind.
func (db *DB) migrate() error {
	// The tracking table must exist before any script can be checked against it.
	const createTable = `CREATE TABLE IF NOT EXISTS migrations (name TEXT PRIMARY KEY);`
	if _, err := db.db.Exec(createTable); err != nil {
		return fmt.Errorf("cannot create migrations table: %w", err)
	}

	// List the embedded migration scripts.
	files, err := fs.Glob(migrationFS, "migration/*.sql")
	if err != nil {
		return err
	}
	sort.Strings(files)

	// Apply the scripts one at a time, stopping at the first failure.
	for i := range files {
		if err := db.migrateFile(files[i]); err != nil {
			return fmt.Errorf("migration error: name=%q err=%w", files[i], err)
		}
	}
	return nil
}
// migrateFile applies a single embedded migration script inside a
// transaction. If the script's name is already present in the "migrations"
// table it is skipped; otherwise the script is executed and its name is
// recorded so it will not run again. Any failure rolls the transaction back.
func (db *DB) migrateFile(name string) error {
	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	// Rollback after a successful Commit has nothing left to undo.
	defer tx.Rollback()

	// Skip scripts that have already been recorded.
	var applied int
	err = tx.QueryRow(`SELECT COUNT(*) FROM migrations WHERE name = ?`, name).Scan(&applied)
	if err != nil {
		return err
	}
	if applied != 0 {
		return nil
	}

	// Load the script from the embedded file system and run it.
	script, err := fs.ReadFile(migrationFS, name)
	if err != nil {
		return err
	}
	if _, err = tx.Exec(string(script)); err != nil {
		return err
	}

	// Record the script so later runs skip it.
	if _, err = tx.Exec(`INSERT INTO migrations (name) VALUES (?)`, name); err != nil {
		return err
	}

	return tx.Commit()
}
// Close cancels the background context (which stops the monitor goroutine)
// and then closes the underlying database handle, if one was opened.
// It returns the error from closing the handle, or nil when there is none.
func (db *DB) Close() error {
	// Signal background work to stop first.
	db.cancel()

	// No handle has been assigned, so there is nothing to close.
	if db.db == nil {
		return nil
	}
	return db.db.Close()
}
// BeginTx starts a database transaction and wraps it in a Tx.
//
// The wrapper carries a reference back to this DB and a single timestamp
// captured when the transaction begins: db.Now() converted to UTC and
// truncated to whole seconds. Because the timestamp comes from db.Now, tests
// can substitute a fake clock.
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
	sqlTx, err := db.db.BeginTx(ctx, opts)
	if err != nil {
		return nil, err
	}

	// Capture the transaction start time once so every statement in the
	// transaction sees the same value.
	startedAt := db.Now().UTC().Truncate(time.Second)

	return &Tx{Tx: sqlTx, db: db, now: startedAt}, nil
}
// monitor is the body of the background stats goroutine. Every ten seconds it
// recomputes the database metrics, logging (not returning) any failure, and
// it exits as soon as the DB's background context is cancelled.
func (db *DB) monitor() {
	const interval = 10 * time.Second

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-db.ctx.Done():
			return
		case <-ticker.C:
			// A failed refresh is only logged; the next tick retries.
			if err := db.updateStats(db.ctx); err != nil {
				log.Printf("stats error: %s", err)
			}
		}
	}
}
// updateStats refreshes the database gauges (users, dials, dial memberships)
// from row counts read inside one transaction. The transaction is always
// rolled back since nothing is written. A gauge is only updated if its own
// query succeeds; the first failing query aborts the refresh.
func (db *DB) updateStats(ctx context.Context) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer tx.Rollback()

	// rowCount runs a single-value COUNT query within the transaction.
	rowCount := func(query string) (float64, error) {
		var n int
		scanErr := tx.QueryRowContext(ctx, query).Scan(&n)
		return float64(n), scanErr
	}

	users, err := rowCount(`SELECT COUNT(*) FROM users;`)
	if err != nil {
		return fmt.Errorf("user count: %w", err)
	}
	userCountGauge.Set(users)

	dials, err := rowCount(`SELECT COUNT(*) FROM dials;`)
	if err != nil {
		return fmt.Errorf("dial count: %w", err)
	}
	dialCountGauge.Set(dials)

	memberships, err := rowCount(`SELECT COUNT(*) FROM dial_memberships;`)
	if err != nil {
		return fmt.Errorf("dial membership count: %w", err)
	}
	dialMembershipCountGauge.Set(memberships)

	return nil
}
// Tx wraps the SQL Tx object to provide a timestamp at the start of the transaction.
//
// Values are created by DB.BeginTx, which fills in all three fields. The
// embedded *sql.Tx exposes the standard transaction methods directly.
type Tx struct {
*sql.Tx
db *DB // database that started this transaction
now time.Time // start time from DB.Now(), in UTC and truncated to whole seconds
}
// lastInsertID reads the last inserted row ID from result and narrows it from
// int64 to int. The error reported by the result is passed through unchanged,
// alongside whatever ID value the result supplied.
func lastInsertID(result sql.Result) (int, error) {
	rowID, err := result.LastInsertId()
	return int(rowID), err
}
// NullTime is a time.Time wrapper for database columns. It stores times as
// RFC 3339 text and maps the zero time to and from SQL NULL.
type NullTime time.Time

// Scan reads a time value from the database into n.
//
// Accepted inputs:
//   - nil (SQL NULL) or an empty string/byte slice: n becomes the zero time.
//   - string or []byte: parsed as RFC 3339.
//   - time.Time: stored as-is.
//
// Any other type is rejected. Text that is not valid RFC 3339 is reported as
// an error instead of being silently turned into the zero time; n is left
// unchanged in that case.
func (n *NullTime) Scan(value interface{}) error {
	var text string
	switch v := value.(type) {
	case nil:
		*n = NullTime(time.Time{})
		return nil
	case time.Time:
		*n = NullTime(v)
		return nil
	case string:
		text = v
	case []byte:
		text = string(v)
	default:
		return fmt.Errorf("NullTime: cannot scan to time.Time: %T", value)
	}

	// Empty text carries no timestamp; keep treating it as "no time".
	if text == "" {
		*n = NullTime(time.Time{})
		return nil
	}

	parsed, err := time.Parse(time.RFC3339, text)
	if err != nil {
		return fmt.Errorf("NullTime: cannot parse %q as RFC 3339: %w", text, err)
	}
	*n = NullTime(parsed)
	return nil
}

// Value formats the time for the database. A nil receiver or a zero time is
// written as NULL; anything else is converted to UTC and rendered as an
// RFC 3339 string.
func (n *NullTime) Value() (driver.Value, error) {
	if n == nil {
		return nil, nil
	}
	t := time.Time(*n)
	if t.IsZero() {
		return nil, nil
	}
	return t.UTC().Format(time.RFC3339), nil
}
// FormatLimitOffset returns a SQL LIMIT/OFFSET clause for the given limit and
// offset. A value is only included when it is greater than zero; when neither
// is, the empty string is returned.
//
// SQLite only accepts OFFSET as part of a LIMIT clause, so an offset without
// a limit is rendered as `LIMIT -1 OFFSET n`; a negative LIMIT means "no
// upper bound" in SQLite.
func FormatLimitOffset(limit, offset int) string {
	switch {
	case limit > 0 && offset > 0:
		return fmt.Sprintf(`LIMIT %d OFFSET %d`, limit, offset)
	case limit > 0:
		return fmt.Sprintf(`LIMIT %d`, limit)
	case offset > 0:
		// A bare "OFFSET n" is a syntax error in SQLite.
		return fmt.Sprintf(`LIMIT -1 OFFSET %d`, offset)
	}
	return ""
}
// FormatError translates known database errors into WTF application errors.
//
// A nil error stays nil. The unique-constraint failure on
// dial_memberships(dial_id, user_id) is recognized by its exact message and
// becomes an ECONFLICT error. Every other error is returned unchanged.
func FormatError(err error) error {
	if err == nil {
		return nil
	}

	// Message text reported for a duplicate (dial_id, user_id) membership.
	const duplicateMembership = "UNIQUE constraint failed: dial_memberships.dial_id, dial_memberships.user_id"

	if err.Error() == duplicateMembership {
		return wtf.Errorf(wtf.ECONFLICT, "Dial membership already exists.")
	}
	return err
}
// logstr prints s with the built-in println and hands the same string back,
// so it can be wrapped around query text at a call site to dump the query
// while debugging.
func logstr(s string) string {
	text := s
	println(text)
	return text
}