forked from go-micro/go-micro
/
cockroach.go
224 lines (201 loc) · 4.83 KB
/
cockroach.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
// Package cockroach implements the cockroach store
package cockroach
import (
"database/sql"
"fmt"
"strings"
"time"
"unicode"
"github.com/lib/pq"
"github.com/micro/go-micro/store"
"github.com/micro/go-micro/util/log"
"github.com/pkg/errors"
)
// DefaultNamespace is the database name the sql store uses when no
// namespace option is provided.
var DefaultNamespace = "micro"

// DefaultPrefix is the table name the sql store uses when no prefix
// option is provided.
var DefaultPrefix = "micro"
// sqlStore is the value New returns as a store.Store. Its methods run SQL
// statements against the table <database>.<table> through db.
type sqlStore struct {
	// db is the database/sql handle every statement is issued on.
	db *sql.DB
	// database and table are interpolated into the SQL text as the
	// database name and the table name respectively.
	database, table string
	// options is the store.Options value associated with this store.
	options store.Options
}
// List returns every record in the table whose expiry is either unset or
// still in the future. Rows whose expiry has passed are left out of the
// result and deleted in a background goroutine. For records that carry an
// expiry, Record.Expiry is set to the time remaining until it.
//
// If scanning or iterating fails part-way, the records collected so far
// are returned together with the error.
func (s *sqlStore) List() ([]*store.Record, error) {
	query := fmt.Sprintf("SELECT key, value, expiry FROM %s.%s;", s.database, s.table)
	rows, err := s.db.Query(query)
	switch {
	case err == sql.ErrNoRows:
		// Treated as an empty result rather than a failure.
		return nil, nil
	case err != nil:
		return nil, err
	}
	defer rows.Close()

	var (
		found  []*store.Record
		expiry pq.NullTime
	)
	for rows.Next() {
		rec := new(store.Record)
		if scanErr := rows.Scan(&rec.Key, &rec.Value, &expiry); scanErr != nil {
			return found, scanErr
		}
		if expiry.Valid && expiry.Time.Before(time.Now()) {
			// record has expired: drop it asynchronously and skip it
			go s.Delete(rec.Key)
			continue
		}
		if expiry.Valid {
			rec.Expiry = time.Until(expiry.Time)
		}
		found = append(found, rec)
	}

	// Close explicitly so that a close failure (transaction rollback or
	// similar) is reported; the deferred Close above then becomes a no-op.
	if closeErr := rows.Close(); closeErr != nil {
		return found, closeErr
	}
	if iterErr := rows.Err(); iterErr != nil {
		return found, iterErr
	}
	return found, nil
}
// Read looks up each of the given keys, in order, and returns the matching
// records.
//
// The lookup stops at the first key that has no row or whose expiry has
// passed: the records gathered so far are returned together with
// store.ErrNotFound. An expired row is additionally deleted in a background
// goroutine. For records that carry an expiry, Record.Expiry is set to the
// time remaining until it. Any other scan error is returned as-is, again
// with the records gathered so far.
func (s *sqlStore) Read(keys ...string) ([]*store.Record, error) {
	q, err := s.db.Prepare(fmt.Sprintf("SELECT key, value, expiry FROM %s.%s WHERE key = $1;", s.database, s.table))
	if err != nil {
		return nil, err
	}
	// Release the prepared statement when done; otherwise every call to
	// Read leaves a statement allocated.
	defer q.Close()

	var records []*store.Record
	var timehelper pq.NullTime
	for _, key := range keys {
		record := &store.Record{}
		if err := q.QueryRow(key).Scan(&record.Key, &record.Value, &timehelper); err != nil {
			if err == sql.ErrNoRows {
				return records, store.ErrNotFound
			}
			return records, err
		}
		if timehelper.Valid {
			if timehelper.Time.Before(time.Now()) {
				// record has expired
				go s.Delete(key)
				return records, store.ErrNotFound
			}
			record.Expiry = time.Until(timehelper.Time)
		}
		records = append(records, record)
	}
	return records, nil
}
// Write inserts the given records, replacing the value and expiry of any
// row that already exists for the same key.
//
// A record with a non-zero Expiry is stored with an absolute expiry of
// now + Expiry; a zero Expiry is stored as NULL. Records are written one at
// a time, so when an error is returned the records before the failing one
// have already been written.
func (s *sqlStore) Write(rec ...*store.Record) error {
	q, err := s.db.Prepare(fmt.Sprintf(`INSERT INTO %s.%s(key, value, expiry)
		VALUES ($1, $2::bytea, $3)
		ON CONFLICT (key)
		DO UPDATE
		SET value = EXCLUDED.value, expiry = EXCLUDED.expiry;`, s.database, s.table))
	if err != nil {
		return err
	}
	// Release the prepared statement when done; otherwise every call to
	// Write leaves a statement allocated.
	defer q.Close()

	for _, r := range rec {
		// Stays nil (SQL NULL) unless the record carries an expiry.
		var expiry interface{}
		if r.Expiry != 0 {
			expiry = time.Now().Add(r.Expiry)
		}
		if _, err := q.Exec(r.Key, r.Value, expiry); err != nil {
			return errors.Wrap(err, "Couldn't insert record "+r.Key)
		}
	}
	return nil
}
// Delete removes the rows for the given keys, one statement execution per
// key. It stops at and returns the first error; keys before the failing one
// have already been deleted. A key with no matching row is not treated as
// an error by this method.
func (s *sqlStore) Delete(keys ...string) error {
	q, err := s.db.Prepare(fmt.Sprintf("DELETE FROM %s.%s WHERE key = $1;", s.database, s.table))
	if err != nil {
		return err
	}
	// Release the prepared statement when done; otherwise every call to
	// Delete (including the background expiry clean-ups) leaves a statement
	// allocated.
	defer q.Close()

	for _, key := range keys {
		result, err := q.Exec(key)
		if err != nil {
			return err
		}
		// The affected-row count itself is unused; only a failure to
		// obtain it is reported.
		if _, err := result.RowsAffected(); err != nil {
			return err
		}
	}
	return nil
}
// initDB creates the store's database and table if they do not exist yet.
// The table has a text primary key "key", a bytea "value" and a nullable
// "expiry" timestamp with time zone. Any failure terminates the process via
// log.Fatal.
func (s *sqlStore) initDB() {
	// Create the namespace's database
	_, err := s.db.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s ;", s.database))
	if err != nil {
		log.Fatal(err)
	}
	_, err = s.db.Exec(fmt.Sprintf("SET DATABASE = %s ;", s.database))
	if err != nil {
		log.Fatal(errors.Wrap(err, "Couldn't set database"))
	}
	// Create a table for the namespace's prefix. The table name is
	// qualified with the database: s.db is a connection pool, so the
	// SET DATABASE above is only guaranteed to affect the one connection
	// it ran on, which need not be the connection used for this statement.
	_, err = s.db.Exec(fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s.%s
	(
		key text NOT NULL,
		value bytea,
		expiry timestamp with time zone,
		CONSTRAINT %s_pkey PRIMARY KEY (key)
	);`, s.database, s.table, s.table))
	if err != nil {
		log.Fatal(errors.Wrap(err, "Couldn't create table"))
	}
}
// New returns a new micro Store backed by sql.
//
// Options consulted:
//   - Nodes: only the first entry is used; it defaults to "localhost:26257".
//     See connectionString for the accepted forms.
//   - Namespace: the database name; defaults to DefaultNamespace and must
//     consist of letters only.
//   - Prefix: the table name; defaults to DefaultPrefix.
//
// New opens the connection, pings it and creates the database and table if
// needed. Any failure along the way terminates the process via log.Fatal.
func New(opts ...store.Option) store.Store {
	var options store.Options
	for _, o := range opts {
		o(&options)
	}

	nodes := options.Nodes
	if len(nodes) == 0 {
		nodes = []string{"localhost:26257"}
	}

	namespace := options.Namespace
	if len(namespace) == 0 {
		namespace = DefaultNamespace
	}

	prefix := options.Prefix
	if len(prefix) == 0 {
		prefix = DefaultPrefix
	}

	// The namespace is interpolated into SQL text as an identifier, so it
	// is restricted to letters.
	for _, r := range namespace {
		if !unicode.IsLetter(r) {
			log.Fatal("store.namespace must only contain letters")
		}
	}

	// create source from first node
	source := connectionString(nodes[0])

	db, err := sql.Open("postgres", source)
	if err != nil {
		log.Fatal(err)
	}

	// sql.Open does not necessarily connect; Ping verifies the source works.
	if err := db.Ping(); err != nil {
		log.Fatal(err)
	}

	s := &sqlStore{
		db:       db,
		database: namespace,
		table:    prefix,
		options:  options,
	}
	s.initDB()

	return s
}

// connectionString turns a node entry into a data source name for the
// "postgres" driver.
//
// Entries that already look like a data source name are returned unchanged:
// anything containing a space or '=' (key=value form) or starting with
// "postgres://" or "postgresql://" (URL form). A "host:port" entry, with
// the host optionally a bracketed IPv6 literal, becomes "host=H port=P" so
// the port is not swallowed into the host parameter. Anything else is
// treated as a bare host name.
func connectionString(node string) string {
	if strings.ContainsAny(node, " =") ||
		strings.HasPrefix(node, "postgres://") ||
		strings.HasPrefix(node, "postgresql://") {
		return node
	}

	if i := strings.LastIndex(node, ":"); i > 0 && i < len(node)-1 {
		host, port := node[:i], node[i+1:]
		bracketed := len(host) > 2 && host[0] == '[' && host[len(host)-1] == ']'
		if bracketed {
			host = host[1 : len(host)-1]
		}
		// An unbracketed host that itself contains ':' is a bare IPv6
		// address, not host:port, and falls through to the host-only form.
		if bracketed || !strings.Contains(host, ":") {
			return fmt.Sprintf("host=%s port=%s", host, port)
		}
	}

	return fmt.Sprintf("host=%s", node)
}