/
postgres.go
140 lines (115 loc) · 2.3 KB
/
postgres.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package synlock
import (
	"context"
	"errors"
	"fmt"
	"net/url"
	"sync"
	"time"

	"github.com/jackc/pgx/v4"
	"github.com/jackc/pgx/v4/pgxpool"
)
// ErrPostgresInvalidAddr is returned by NewPostgres when either the Host or
// the Port option is empty.
var ErrPostgresInvalidAddr = errors.New("invalid postgres address")
// DefPostgresOpts holds default connection settings: the "postgres" user
// (no password) connecting to the "postgres" database on 127.0.0.1:5432.
// MaxConnections is left at zero, so NewPostgres keeps pgxpool's default
// pool size.
var DefPostgresOpts = PostgresOpts{
	User: "postgres",
	DB:   "postgres",
	Host: "127.0.0.1",
	Port: "5432",
}
// PostgresOpts describes how NewPostgres reaches the PostgreSQL server.
//
// Host and Port are required. User and Pass are only used when User is
// non-empty. MaxConnections caps the pool size when it is positive; zero or a
// negative value keeps the pool library's default.
type PostgresOpts struct {
	Host, Port, DB, User, Pass string
	MaxConnections             int
}
// Postgres hands out PostgresMutex values that all share one pgx connection
// pool. Create it with NewPostgres and release it with Close.
type Postgres struct {
	client *pgxpool.Pool
}

// Close closes the underlying connection pool. Without it the pool opened by
// NewPostgres could never be released. Per the pgxpool documentation, Close
// waits for checked-out connections to be returned. Mutexes that are still
// locked therefore keep it waiting until they are unlocked. Mutexes obtained
// from NewMutex must not be used after Close.
func (r *Postgres) Close() {
	r.client.Close()
}
// NewPostgres creates a pgx connection pool from conf (via
// pgxpool.ConnectConfig) and returns a Postgres that hands out advisory-lock
// mutexes backed by that pool.
//
// Host and Port are required; if either is empty ErrPostgresInvalidAddr is
// returned. Credentials are put into the connection URL only when User is
// non-empty (Pass is ignored otherwise). A positive MaxConnections overrides
// the pool's maximum size; otherwise the pgxpool default is kept. Errors from
// parsing the configuration are returned as-is. Connection errors are wrapped
// with "unable to connect to database" and remain inspectable with
// errors.Is/errors.As.
func NewPostgres(conf PostgresOpts) (_ *Postgres, err error) {
	if conf.Host == "" || conf.Port == "" {
		return nil, ErrPostgresInvalidAddr
	}

	// Build the URL with net/url instead of fmt.Sprintf so reserved characters
	// in the user, password or database name (such as '@', ':', '/', '#') are
	// percent-escaped rather than corrupting the connection string.
	connURL := url.URL{
		Scheme: "postgres",
		Host:   conf.Host + ":" + conf.Port,
		Path:   "/" + conf.DB,
	}
	if conf.User != "" {
		connURL.User = url.UserPassword(conf.User, conf.Pass)
	}

	cfg, err := pgxpool.ParseConfig(connURL.String())
	if err != nil {
		return nil, err
	}
	if conf.MaxConnections > 0 {
		cfg.MaxConns = int32(conf.MaxConnections)
	}

	pool, err := pgxpool.ConnectConfig(context.Background(), cfg)
	if err != nil {
		// %w keeps the underlying pgx error available to errors.Is/As.
		return nil, fmt.Errorf("unable to connect to database: %w", err)
	}
	return &Postgres{client: pool}, nil
}
// NewMutex returns a PostgresMutex for the advisory lock identified by key,
// sharing r's connection pool. No database work happens here; the returned
// error is always nil.
func (r *Postgres) NewMutex(key int64) (Mutex, error) {
	m := new(PostgresMutex)
	m.client = r.client
	m.key = key
	return m, nil
}
// PostgresMutex is the Mutex returned by Postgres.NewMutex. It is backed by a
// PostgreSQL transaction-level advisory lock on key. Because it embeds a
// sync.Mutex it must not be copied after first use.
type PostgresMutex struct {
	client *pgxpool.Pool // pool that lock transactions are started on
	key    int64         // advisory lock id passed to pg_try_advisory_xact_lock
	mu     sync.Mutex    // serializes Lock/Unlock in-process and guards tx
	tx     pgx.Tx        // transaction holding the advisory lock while locked
}
// Lock blocks until the PostgreSQL advisory lock for the mutex key is held.
//
// The in-process mutex is taken first so goroutines sharing this
// PostgresMutex never race over the stored transaction. If acquiring the
// advisory lock fails, the in-process mutex is released again and the error
// is returned. In that case the caller does not hold the lock and must not
// call Unlock.
func (s *PostgresMutex) Lock() error {
	s.mu.Lock()
	if err := s.lock(); err != nil {
		// Without this, one failed attempt would leave s.mu locked forever and
		// every later Lock on this mutex would deadlock.
		s.mu.Unlock()
		return err
	}
	return nil
}
// Unlock releases the advisory lock taken by Lock and then the in-process
// mutex. The in-process mutex is released (via defer) even when ending the
// database transaction fails or panics; that failure is returned to the
// caller.
func (s *PostgresMutex) Unlock() error {
	defer s.mu.Unlock()
	err := s.unlock()
	return err
}
// lock polls pg_try_advisory_xact_lock(key) until it succeeds.
//
// Each attempt starts a new transaction on the pool. On success that
// transaction is stored in s.tx and left open, because a transaction-level
// advisory lock is held until its transaction ends; unlock ends it. On
// failure the transaction is rolled back and lock sleeps before retrying. The
// sleep starts at 10ms and doubles up to a 1s cap. There is no overall
// timeout: lock keeps polling until it gets the lock or hits an error.
//
// Any error from Begin, the query, or a rollback is returned immediately.
// s.tx is only assigned when the lock is acquired.
func (s *PostgresMutex) lock() error {
	const (
		initialBackoff = 10 * time.Millisecond
		maxBackoff     = time.Second
	)
	ctx := context.Background()

	var backoff time.Duration
	for {
		if backoff > 0 {
			time.Sleep(backoff)
		}

		tx, err := s.client.Begin(ctx)
		if err != nil {
			return err
		}

		var ok bool
		err = tx.QueryRow(ctx, "SELECT pg_try_advisory_xact_lock($1)", s.key).Scan(&ok)
		if err != nil {
			// End the transaction so its pooled connection is not leaked. The
			// query error is the one worth reporting, so a rollback failure
			// here is deliberately discarded.
			_ = tx.Rollback(ctx)
			return err
		}
		if ok {
			s.tx = tx
			return nil
		}

		if err = tx.Rollback(ctx); err != nil {
			return err
		}

		// Capped exponential backoff. The original compared before doubling
		// and so oscillated between 1s and 2s once past the cap.
		switch {
		case backoff == 0:
			backoff = initialBackoff
		case backoff >= maxBackoff/2:
			backoff = maxBackoff
		default:
			backoff *= 2
		}
	}
}
// errPostgresMutexNotLocked is returned by unlock when no transaction is
// holding the advisory lock.
var errPostgresMutexNotLocked = errors.New("postgres mutex is not locked")

// unlock ends the transaction opened by lock with a rollback. This releases
// the transaction-level advisory lock and hands the connection back to the
// pool. It returns errPostgresMutexNotLocked instead of panicking on a nil
// transaction. The stored transaction is cleared first so a finished
// transaction is never reused.
func (s *PostgresMutex) unlock() error {
	tx := s.tx
	if tx == nil {
		return errPostgresMutexNotLocked
	}
	s.tx = nil
	return tx.Rollback(context.Background())
}