// backend.go (forked from canonical/candid).

package sqlstore
import (
"bytes"
"database/sql"
"strings"
"text/template"
"time"
"github.com/go-macaroon-bakery/macaroon-bakery/v3/bakery"
"github.com/go-macaroon-bakery/macaroon-bakery/v3/bakery/postgresrootkeystore"
"github.com/juju/aclstore/v2"
"github.com/juju/simplekv/sqlsimplekv"
"github.com/juju/utils/v2/debugstatus"
errgo "gopkg.in/errgo.v1"
"github.com/kian99/candid/meeting"
"github.com/kian99/candid/store"
)
// backend provides a wrapper around an SQL database that can be used
// as the persistent storage for the various types of store required by
// the identity service.
type backend struct {
	// db is the database handle shared by every store handed out by
	// this backend. Close closes it.
	db *sql.DB
	// driver builds SQL statements from its query templates and runs
	// them against a queryer.
	driver *driver
	// rootKeys is the root key cache from which BakeryRootKeyStore
	// creates bakery root key stores. Close closes it.
	rootKeys *postgresrootkeystore.RootKeys
	// aclStore is returned unchanged by ACLStore.
	aclStore aclstore.ACLStore
}
// NewBackend creates a new store.Backend implementation using the
// given driverName and *sql.DB. The driverName must match the value
// used to open the database; only "postgres" is currently supported.
//
// Closing the returned Backend will also close db. NewBackend itself
// does not close db when it returns an error.
func NewBackend(driverName string, db *sql.DB) (store.Backend, error) {
	if driverName != "postgres" {
		return nil, errgo.Newf("unsupported database driver %q", driverName)
	}
	driver, err := newPostgresDriver(db)
	if err != nil {
		return nil, errgo.Notef(err, "cannot initialise database")
	}
	aclStore, err := sqlsimplekv.NewStore(driverName, db, "acls")
	if err != nil {
		return nil, errgo.Mask(err)
	}
	// The root key cache is created exactly once, and only after every
	// fallible step above has succeeded, so the error paths have nothing
	// to release. Ownership passes to the backend; (*backend).Close
	// closes it.
	return &backend{
		db:       db,
		driver:   driver,
		rootKeys: postgresrootkeystore.NewRootKeys(db, "rootkeys", 1000),
		aclStore: aclstore.NewACLStore(aclStore),
	}, nil
}
// Close shuts the backend down. The root key cache is closed first,
// then the database handle that was given to NewBackend. Any error
// returned while closing the database is ignored.
func (b *backend) Close() {
	rk, db := b.rootKeys, b.db
	rk.Close()
	db.Close()
}
// Store returns a store.Store whose persistent storage is this
// backend's database. A new value is allocated on every call.
func (b *backend) Store() store.Store {
	var s store.Store = &identityStore{b}
	return s
}
// BakeryRootKeyStore returns a bakery.RootKeyStore created from the
// backend's root key cache, using a policy whose ExpiryDuration is
// 365 days.
func (b *backend) BakeryRootKeyStore() bakery.RootKeyStore {
	const keyLifetime = 365 * 24 * time.Hour
	policy := postgresrootkeystore.Policy{ExpiryDuration: keyLifetime}
	return b.rootKeys.NewStore(policy)
}
// ProviderDataStore returns a store.ProviderDataStore whose persistent
// storage is this backend's database. A new value is allocated on every
// call.
func (b *backend) ProviderDataStore() store.ProviderDataStore {
	var pds store.ProviderDataStore = &providerDataStore{b}
	return pds
}
// MeetingStore returns a meeting.Store whose persistent storage is this
// backend's database. A new value is allocated on every call.
func (b *backend) MeetingStore() meeting.Store {
	var ms meeting.Store = &meetingStore{b}
	return ms
}
// ACLStore returns the ACL store that NewBackend created over the
// "acls" table. Every call returns the same value.
func (b *backend) ACLStore() aclstore.ACLStore {
	acls := b.aclStore
	return acls
}
// DebugStatusCheckerFuncs implements store.Backend.DebugStatusCheckerFuncs.
// The SQL backend registers no status checks, so the result is always
// a nil slice.
func (b *backend) DebugStatusCheckerFuncs() []debugstatus.CheckerFunc {
	var checkers []debugstatus.CheckerFunc
	return checkers
}
// withTx runs f in a new transaction. If f returns nil, the transaction
// is committed and the result of Commit is returned. If f returns an
// error or panics, the transaction is rolled back; a rollback failure is
// only logged. Any error returned by f will not have its cause masked.
func (b *backend) withTx(f func(*sql.Tx) error) error {
	tx, err := b.db.Begin()
	if err != nil {
		return errgo.Mask(err)
	}
	committed := false
	defer func() {
		if committed {
			return
		}
		// Runs on both the error path and a panic in f. Without it, a
		// panicking f would leave the transaction, and the pooled
		// connection it holds, open.
		if err := tx.Rollback(); err != nil {
			logger.Errorf("failed to rollback transaction: %s", err)
		}
	}()
	if err := f(tx); err != nil {
		return errgo.Mask(err, errgo.Any)
	}
	// After Commit returns the transaction is finished whatever the
	// outcome, so the deferred rollback must not run.
	committed = true
	return errgo.Mask(tx.Commit())
}
// tmplID identifies one of the SQL query templates held by a driver. It
// is used as an index into driver.tmpls.
type tmplID int

// Template identifiers. Values come from iota, so the position of each
// name in this list is its value. numTmpl must stay last because it
// sizes the driver.tmpls array.
const (
	tmplClearIdentitySet tmplID = iota
	tmplClearMFACredentials
	tmplFindIdentities
	tmplFindMeetings
	tmplGetMeeting
	tmplGetMFACredentials
	tmplGetProviderData
	tmplGetProviderDataForUpdate
	tmplIdentityCounts
	tmplIdentityFrom
	tmplIdentityID
	tmplIncrementMFACredentialSignCount
	tmplInsertMFACredential
	tmplInsertProviderData
	tmplPullIdentitySet
	tmplPushIdentitySet
	tmplPutMeeting
	tmplRemoveMeetings
	tmplRemoveMFACredential
	tmplSelectIdentitySet
	tmplUpdateIdentity
	tmplUpsertIdentity
	// numTmpl is the number of templates, not a template itself.
	numTmpl
)
// queryer is the common set of query methods provided by both *sql.DB
// and *sql.Tx. Because the driver helpers accept a queryer, they can run
// statements either inside or outside a transaction.
type queryer interface {
	QueryRow(query string, args ...interface{}) *sql.Row
	Query(query string, args ...interface{}) (*sql.Rows, error)
	Exec(query string, args ...interface{}) (sql.Result, error)
}
// argBuilder collects the arguments of an SQL query while its template
// is being executed. Template parameter types embed it so that templates
// can call Arg.
type argBuilder interface {
	// args reports, in order, every value passed to Arg so far. The
	// result is the argument list for running the generated query.
	args() []interface{}

	// Arg records v as the next query argument and returns the
	// placeholder text that refers to it in the generated SQL.
	Arg(v interface{}) string
}
// driver holds the database-specific parts of the SQL store: the query
// templates and the hooks needed to execute them.
type driver struct {
	// name is presumably the database/sql driver name (e.g. "postgres");
	// TODO confirm against newPostgresDriver.
	name string
	// tmpls holds one parsed template per tmplID, filled in by
	// parseTemplate and rendered by executeTemplate.
	tmpls [numTmpl]*template.Template
	// argBuilderFunc returns an argBuilder. NOTE(review): presumably a
	// fresh one per query, producing driver-specific placeholders;
	// confirm at call sites.
	argBuilderFunc func() argBuilder
	// isDuplicateFunc reports whether an error from the database is,
	// presumably, a uniqueness/duplicate-key violation; TODO confirm.
	isDuplicateFunc func(error) bool
}
// exec renders the template identified by id with params, then runs the
// resulting statement through q.Exec. The arguments passed are those
// params collected while rendering. The cause of an Exec error is
// preserved for callers.
func (d *driver) exec(q queryer, id tmplID, params argBuilder) (sql.Result, error) {
	stmt, err := d.executeTemplate(id, params)
	if err != nil {
		return nil, errgo.Notef(err, "cannot build query")
	}
	// params.args() must be read only after rendering, which is what
	// populates it.
	result, execErr := q.Exec(stmt, params.args()...)
	return result, errgo.Mask(execErr, errgo.Any)
}
// query renders the template identified by id with params, then runs the
// resulting statement through q.Query. The arguments passed are those
// params collected while rendering. The cause of a Query error is
// preserved for callers.
func (d *driver) query(q queryer, id tmplID, params argBuilder) (*sql.Rows, error) {
	stmt, err := d.executeTemplate(id, params)
	if err != nil {
		return nil, errgo.Notef(err, "cannot build query")
	}
	// params.args() must be read only after rendering, which is what
	// populates it.
	rows, queryErr := q.Query(stmt, params.args()...)
	return rows, errgo.Mask(queryErr, errgo.Any)
}
// queryRow renders the template identified by id with params, then runs
// the resulting statement through q.QueryRow. The only error returned
// here is a template failure. As with database/sql, errors from the
// query itself are reported later by the returned row's Scan.
func (d *driver) queryRow(q queryer, id tmplID, params argBuilder) (*sql.Row, error) {
	stmt, err := d.executeTemplate(id, params)
	if err == nil {
		return q.QueryRow(stmt, params.args()...), nil
	}
	return nil, errgo.Notef(err, "cannot build query")
}
// parseTemplate parses tmpl as a text/template and stores the result in
// the slot for id. Templates may call "join", which is strings.Join. The
// slot is overwritten even on failure, with whatever Parse returned.
func (d *driver) parseTemplate(id tmplID, tmpl string) error {
	funcs := template.FuncMap{"join": strings.Join}
	parsed, err := template.New("").Funcs(funcs).Parse(tmpl)
	d.tmpls[id] = parsed
	return errgo.Mask(err)
}
// executeTemplate renders the template stored for id, using params as
// the template data, and returns the generated SQL text. When params is
// an argBuilder, the template's Arg calls record the query arguments as
// a side effect of rendering.
func (d *driver) executeTemplate(id tmplID, params interface{}) (string, error) {
	var sqlText bytes.Buffer
	tmpl := d.tmpls[id]
	if err := tmpl.Execute(&sqlText, params); err != nil {
		return "", errgo.Mask(err)
	}
	return sqlText.String(), nil
}
// comparisons maps each store.Comparison to the SQL comparison operator
// that expresses it.
var comparisons = map[store.Comparison]string{
	store.LessThan:           "<",
	store.LessThanOrEqual:    "<=",
	store.Equal:              "=",
	store.NotEqual:           "<>",
	store.GreaterThanOrEqual: ">=",
	store.GreaterThan:        ">",
}