-
Notifications
You must be signed in to change notification settings - Fork 0
/
attack.go
179 lines (145 loc) · 3.6 KB
/
attack.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
package flare
import (
"context"
"fmt"
"log"
"net/url"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
)
// flareDatabaseSchema is the DDL that CreateTestTable executes inside the
// flare_test database. It creates the items table (rows inserted by
// WriteNewItem) and the sessions table (rows upserted by SendHeartBeat).
// IF NOT EXISTS makes each statement skip a table that is already present,
// so running it against an existing database does not fail.
const flareDatabaseSchema = `
CREATE TABLE IF NOT EXISTS items (
id TEXT PRIMARY KEY
, name TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY
, last_keepalive_at TIMESTAMP WITH TIME ZONE NOT NULL
);
`
// TrafficGenerator issues write traffic against a PostgreSQL connection pool.
// It inserts rows into the items table (Attack, WriteNewItem) and upserts a
// heartbeat row into the sessions table under its own name (KeepAlive,
// SendHeartBeat).
type TrafficGenerator struct {
	pool *pgxpool.Pool // pool that every transaction is begun from
	name string        // session id written to the sessions table
}

// NewTrafficGenerator returns a TrafficGenerator that runs its queries on
// pool and reports its heartbeats under name.
func NewTrafficGenerator(pool *pgxpool.Pool, name string) *TrafficGenerator {
	g := new(TrafficGenerator)
	g.pool, g.name = pool, name
	return g
}
// KeepAlive calls SendHeartBeat repeatedly, pausing about 100ms between
// calls, until ctx is cancelled. It then logs that it is stopping and
// returns nil.
//
// A failed heartbeat is logged and does not stop the loop. The pause between
// heartbeats is interrupted by cancellation, so KeepAlive returns promptly
// rather than finishing an uninterruptible sleep first.
func (g *TrafficGenerator) KeepAlive(ctx context.Context) error {
	const keepAliveInterval = 100 * time.Millisecond

	for {
		if ctx.Err() != nil {
			log.Printf("Stop sending heartbeat...")
			return nil
		}
		if err := g.SendHeartBeat(ctx); err != nil {
			log.Printf("Failed to send a heartbeat: %s", err)
		}
		// Wait for the next tick, but wake up immediately on cancellation;
		// the check at the top of the loop then handles the shutdown.
		select {
		case <-ctx.Done():
		case <-time.After(keepAliveInterval):
		}
	}
}
// Attack calls WriteNewItem back-to-back, with no pause between calls, until
// ctx is cancelled. It then logs that it is stopping and returns nil.
// A failed write is logged and does not stop the loop.
func (g *TrafficGenerator) Attack(ctx context.Context) error {
	// ctx.Err() is non-nil exactly when ctx.Done() is closed.
	for ctx.Err() == nil {
		if err := g.WriteNewItem(ctx); err != nil {
			log.Printf("Failed to write a new item: %s", err)
		}
	}
	log.Printf("Stop writing new items...")
	return nil
}
// SendHeartBeat upserts this generator's row in the sessions table, keyed by
// g.name, and sets last_keepalive_at to the current client-side time. The
// write runs in its own transaction on g.pool.
//
// On any failure the transaction is rolled back, so its pooled connection
// goes back to g.pool instead of being held forever. The returned error
// wraps the underlying cause.
func (g *TrafficGenerator) SendHeartBeat(ctx context.Context) error {
	tx, err := g.pool.Begin(ctx)
	if err != nil {
		return fmt.Errorf("beginning a new transaction: %w", err)
	}
	// Rollback after a successful Commit is a no-op that returns
	// pgx.ErrTxClosed, so it is safe to defer unconditionally. It only
	// matters on the error paths, where it releases the pool connection.
	// Its error is ignored because the original failure is what gets reported.
	defer func() { _ = tx.Rollback(ctx) }()

	const upsertSession = `
INSERT INTO sessions (id, last_keepalive_at) VALUES ($1, $2)
ON CONFLICT (id) DO UPDATE SET last_keepalive_at = EXCLUDED.last_keepalive_at;`
	if _, err := tx.Exec(ctx, upsertSession, g.name, time.Now()); err != nil {
		return fmt.Errorf("updating the session: %w", err)
	}
	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("committing the session: %w", err)
	}
	return nil
}
// WriteNewItem inserts one row into the items table in its own transaction on
// g.pool. Both the id and the name are freshly generated random UUID strings.
//
// On any failure the transaction is rolled back, so its pooled connection
// goes back to g.pool instead of being held forever. The returned error
// wraps the underlying cause.
func (g *TrafficGenerator) WriteNewItem(ctx context.Context) error {
	tx, err := g.pool.Begin(ctx)
	if err != nil {
		return fmt.Errorf("beginning a new transaction: %w", err)
	}
	// Rollback after a successful Commit is a no-op that returns
	// pgx.ErrTxClosed, so it is safe to defer unconditionally. It only
	// matters on the error paths, where it releases the pool connection.
	// Its error is ignored because the original failure is what gets reported.
	defer func() { _ = tx.Rollback(ctx) }()

	if _, err := tx.Exec(
		ctx,
		`INSERT INTO items (id, name) VALUES ($1, $2);`,
		uuid.NewString(),
		uuid.NewString(),
	); err != nil {
		return fmt.Errorf("inserting a new item: %w", err)
	}
	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("committing the item: %w", err)
	}
	return nil
}
// CreateTestTable creates the flare_test database on the server that baseDSN
// points at. It then creates the flare tables in it (flareDatabaseSchema) and
// grants appUser ALL on every table in its public schema.
//
// baseDSN must be a URL-form connection string. Its database path is
// replaced, first to connect to the "postgres" database (where CREATE/DROP
// DATABASE are issued) and then to connect to flare_test itself. When
// dropDBBefore is true, any existing flare_test database is dropped first;
// a missing one is not an error. Both connections opened here are closed
// before returning.
func CreateTestTable(ctx context.Context, baseDSN, appUser string, dropDBBefore bool) error {
	// dbName is a trusted constant identifier, so interpolating it into the
	// DDL below is safe.
	const dbName = "flare_test"

	adminDSN, err := switchDatabase(baseDSN, "postgres")
	if err != nil {
		return err
	}
	adminConn, err := pgx.Connect(ctx, adminDSN)
	if err != nil {
		return err
	}
	defer adminConn.Close(ctx)

	if dropDBBefore {
		// IF EXISTS lets a first run with dropDBBefore=true succeed.
		if _, err := adminConn.Exec(ctx, fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, dbName)); err != nil {
			return fmt.Errorf("dropping a database: %w", err)
		}
	}
	if _, err := adminConn.Exec(ctx, fmt.Sprintf(`CREATE DATABASE %s;`, dbName)); err != nil {
		return fmt.Errorf("creating a database: %w", err)
	}

	testDSN, err := switchDatabase(baseDSN, dbName)
	if err != nil {
		return err
	}
	testConn, err := pgx.Connect(ctx, testDSN)
	if err != nil {
		return fmt.Errorf("changing to the new database: %w", err)
	}
	defer testConn.Close(ctx)

	if _, err := testConn.Exec(ctx, flareDatabaseSchema); err != nil {
		return fmt.Errorf("creating tables: %w", err)
	}
	grant := fmt.Sprintf(`GRANT ALL ON ALL TABLES In SCHEMA public TO %s;`, quoteIdentifier(appUser))
	if _, err := testConn.Exec(ctx, grant); err != nil {
		return fmt.Errorf("granting access to the app user: %w", err)
	}
	return nil
}
// switchDatabase returns baseDSN with its database name (the URL path)
// replaced by dbName. User info, host and query parameters are preserved.
//
// Only URL-form DSNs such as "postgres://user:pass@host:5432/db?sslmode=disable"
// are handled. A parse failure is returned wrapped, so errors.As can still
// reach the *url.Error.
func switchDatabase(baseDSN, dbName string) (string, error) {
	dsn, err := url.Parse(baseDSN)
	if err != nil {
		return "", fmt.Errorf("parsing the base DSN: %w", err)
	}
	// The path must be absolute. url.URL.String only inserts the "/" between
	// authority and path when Host is non-empty. So for a host-less DSN
	// (e.g. "postgres://user@/db?host=/tmp"), a relative path would be
	// emitted as the host name.
	dsn.Path = "/" + dbName
	// Drop any stale escaped form of the old path so String re-derives it.
	dsn.RawPath = ""
	return dsn.String(), nil
}