forked from alexbakker/log4shell-tools
-
Notifications
You must be signed in to change notification settings - Fork 0
/
db.go
146 lines (121 loc) · 3.15 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package storage
import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v4"
	"github.com/jackc/pgx/v4/pgxpool"
)
const (
	// schema is executed by NewDB right after connecting. Every statement
	// uses IF NOT EXISTS, so re-applying it to an already-initialised
	// database is harmless. The created columns default to the current time
	// converted to UTC, and test_result rows are deleted together with their
	// parent test row through ON DELETE CASCADE.
	schema = `
CREATE TABLE IF NOT EXISTS test (
id UUID NOT NULL,
created timestamp NOT NULL DEFAULT timezone('utc'::text, CURRENT_TIMESTAMP),
finished timestamp,
PRIMARY KEY (id)
);
CREATE TABLE IF NOT EXISTS test_result (
id BIGSERIAL NOT NULL,
test_id UUID NOT NULL,
created timestamp NOT NULL DEFAULT timezone('utc'::text, CURRENT_TIMESTAMP),
type TEXT NOT NULL,
addr TEXT,
ptr TEXT,
PRIMARY KEY (id),
FOREIGN KEY (test_id) REFERENCES test (id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS test_result_test_id_idx ON test_result (test_id);
`
)
// DB wraps a pgx connection pool and exposes the queries against the test
// and test_result tables. Construct it with NewDB and release it with Close.
type DB struct {
	// p is the pool every query in this type is issued through.
	p *pgxpool.Pool
}
// NewDB connects to PostgreSQL using connStr, applies the schema and returns
// a ready-to-use DB. The schema is safe to re-apply on every start-up because
// each statement uses IF NOT EXISTS.
//
// The caller owns the returned DB and should release it with Close. On error
// no resources are left open.
func NewDB(connStr string) (*DB, error) {
	ctx := context.Background()

	p, err := pgxpool.Connect(ctx, connStr)
	if err != nil {
		return nil, fmt.Errorf("db connect: %w", err)
	}

	if _, err := p.Exec(ctx, schema); err != nil {
		// The pool is already open at this point; close it so a failed
		// schema init does not leak its connections.
		p.Close()
		return nil, fmt.Errorf("schema init: %w", err)
	}

	return &DB{p: p}, nil
}
// Close closes the underlying pgx connection pool (pgxpool.Pool.Close).
// NOTE(review): presumably the DB must not be used after Close — confirm
// against callers.
func (db *DB) Close() {
	db.p.Close()
}
// Test loads the test row with the given id.
//
// It returns (nil, nil) when no such test exists, so callers must check the
// returned pointer as well as the error. Any other query or scan failure is
// returned unchanged.
func (db *DB) Test(ctx context.Context, id uuid.UUID) (*Test, error) {
	row := db.p.QueryRow(ctx, `SELECT id, created, finished FROM test WHERE id = $1`, id.String())

	var test Test
	if err := row.Scan(&test.ID, &test.Created, &test.Finished); err != nil {
		// errors.Is instead of == so the not-found case is still recognised
		// if the sentinel ever arrives wrapped.
		if errors.Is(err, pgx.ErrNoRows) {
			return nil, nil
		}
		return nil, err
	}
	return &test, nil
}
// InsertTest creates a new test row with the given id. The created column
// is filled in by its column default, and finished is left NULL.
func (db *DB) InsertTest(ctx context.Context, id uuid.UUID) error {
	const insertTest = "INSERT INTO test (id) VALUES($1)"
	if _, err := db.p.Exec(ctx, insertTest, id); err != nil {
		return err
	}
	return nil
}
// InsertTestResult records one result for test t, with the given result
// type, address and optional PTR value. ptr is passed to the driver as-is;
// the ptr column is nullable, so a nil pointer is presumably stored as NULL
// (pgx convention — confirm). The created column uses its default.
func (db *DB) InsertTestResult(ctx context.Context, t *Test, resultType string, addr string, ptr *string) error {
	const insertResult = `
INSERT INTO test_result (test_id, type, addr, ptr)
VALUES($1, $2, $3, $4)
`
	params := []interface{}{t.ID, resultType, addr, ptr}
	_, execErr := db.p.Exec(ctx, insertResult, params...)
	return execErr
}
// TestResults returns every result recorded for test t, oldest first. Rows
// with the same created timestamp are ordered by insertion id so the output
// is deterministic.
//
// It returns a nil slice (and nil error) when the test has no results. If
// the query, any row scan, or the row iteration itself fails, the error is
// returned and no partial results are.
func (db *DB) TestResults(ctx context.Context, t *Test) ([]*TestResult, error) {
	rows, err := db.p.Query(ctx, `
SELECT created, type, addr, ptr
FROM test_result
WHERE test_id = $1
ORDER BY created ASC, id ASC
`, t.ID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var results []*TestResult
	for rows.Next() {
		var res TestResult
		if err := rows.Scan(&res.Created, &res.Type, &res.Addr, &res.Ptr); err != nil {
			return nil, err
		}
		results = append(results, &res)
	}
	// Next returns false both when the rows are exhausted and when reading
	// fails. Only Err tells the two apart, so without this check a
	// mid-stream failure would be reported as a successful, truncated result.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return results, nil
}
// PruneTestResults deletes every test created more than one day ago
// (measured against the current UTC time). It reports how many test rows
// were removed. Their test_result rows go with them through the
// ON DELETE CASCADE foreign key in the schema.
func (db *DB) PruneTestResults(ctx context.Context) (int64, error) {
	const pruneOld = `
DELETE FROM test
WHERE created < timezone('utc'::text, CURRENT_TIMESTAMP) - '1 day'::interval
`
	tag, execErr := db.p.Exec(ctx, pruneOld)
	if execErr != nil {
		return 0, execErr
	}
	deleted := tag.RowsAffected()
	return deleted, nil
}
// FinishTest sets the finished column of test t to the current UTC time.
// Updating an id that does not exist is not reported as an error.
func (db *DB) FinishTest(ctx context.Context, t *Test) error {
	const markFinished = `
UPDATE test
SET finished = timezone('utc'::text, CURRENT_TIMESTAMP)
WHERE id = $1
`
	if _, execErr := db.p.Exec(ctx, markFinished, t.ID); execErr != nil {
		return execErr
	}
	return nil
}
// ActiveTests counts tests that have not finished and were created within
// the last timeout, measured against the current UTC time. On error it
// returns 0 and the error.
func (db *DB) ActiveTests(ctx context.Context, timeout time.Duration) (int64, error) {
	// The window is sent as fractional seconds rather than whole minutes.
	// Truncating to minutes made every sub-minute timeout 0, so no test was
	// ever counted as active, and rounded e.g. 90s down to 60s.
	var count int64
	row := db.p.QueryRow(ctx, `
SELECT count(*)
FROM test
WHERE finished IS NULL
AND created > timezone('utc'::text, CURRENT_TIMESTAMP) - ('1 second'::interval * $1);
`, timeout.Seconds())
	if err := row.Scan(&count); err != nil {
		return 0, err
	}
	return count, nil
}