-
Notifications
You must be signed in to change notification settings - Fork 84
/
subcriptions_v3.go
280 lines (246 loc) · 7.86 KB
/
subcriptions_v3.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
package cockroach
import (
"context"
"fmt"
"github.com/dpjacques/clockwork"
dsserr "github.com/interuss/dss/pkg/errors"
"github.com/interuss/dss/pkg/geo"
dssmodels "github.com/interuss/dss/pkg/models"
ridmodels "github.com/interuss/dss/pkg/rid/models"
"github.com/golang/geo/s2"
dssql "github.com/interuss/dss/pkg/sql"
"github.com/lib/pq"
"github.com/palantir/stacktrace"
"go.uber.org/zap"
)
const (
// subscriptionFieldsV3 is the column list selected/returned for a
// subscription. process() scans each row in exactly this order, so every
// query whose rows go through process() must select these columns in order.
subscriptionFieldsV3 = "id, owner, url, notification_index, cells, starts_at, ends_at, updated_at"
// updateSubscriptionFieldsV3 is the column list written by UpdateSubscription.
// owner is deliberately absent, so an update leaves the owner unchanged.
updateSubscriptionFieldsV3 = "id, url, notification_index, cells, starts_at, ends_at, updated_at"
)
// subscriptionRepoV3 is an implementation of the SubscriptionRepo for CRDB.
type subscriptionRepoV3 struct {
// Queryable executes the SQL issued by the methods below
// (QueryContext / QueryRowContext are promoted from it).
dssql.Queryable
// clock supplies the "now" that queries compare against ends_at to skip
// subscriptions that have already ended.
clock clockwork.Clock
logger *zap.Logger
}
// process runs query with args and scans every returned row into a
// Subscription. The query must select the columns of subscriptionFieldsV3,
// in that order. When no rows match it returns a nil slice and a nil error.
func (c *subscriptionRepoV3) process(ctx context.Context, query string, args ...interface{}) ([]*ridmodels.Subscription, error) {
	rows, err := c.QueryContext(ctx, query, args...)
	if err != nil {
		// Propagate treats its message as a format string, so the query is
		// passed as an argument; pre-formatting it would let any '%' inside
		// the query be misread as a verb.
		return nil, stacktrace.Propagate(err, "Error in query: %s", query)
	}
	defer rows.Close()

	var payload []*ridmodels.Subscription
	for rows.Next() {
		s := new(ridmodels.Subscription)
		// A fresh scan target per row keeps rows from sharing a buffer.
		cids := pq.Int64Array{}
		if err := rows.Scan(
			&s.ID,
			&s.Owner,
			&s.URL,
			&s.NotificationIndex,
			&cids,
			&s.StartTime,
			&s.EndTime,
			&s.Version,
		); err != nil {
			return nil, stacktrace.Propagate(err, "Error scanning Subscription row")
		}
		s.SetCells(cids)
		payload = append(payload, s)
	}
	// rows.Next returning false can mean an iteration error, not just EOF.
	if err := rows.Err(); err != nil {
		return nil, stacktrace.Propagate(err, "Error in rows query result")
	}
	return payload, nil
}
// processOne runs query and expects at most one subscription back.
// It returns (nil, nil) when nothing matched, the subscription when exactly
// one matched, and an error when the query produced more than one row.
func (c *subscriptionRepoV3) processOne(ctx context.Context, query string, args ...interface{}) (*ridmodels.Subscription, error) {
	found, err := c.process(ctx, query, args...)
	if err != nil {
		// process already attached a stack trace; wrapping again adds nothing.
		return nil, err
	}
	switch n := len(found); {
	case n == 0:
		return nil, nil
	case n == 1:
		return found[0], nil
	default:
		return nil, stacktrace.NewError("Query returned %d subscriptions when only 0 or 1 was expected", n)
	}
}
// MaxSubscriptionCountInCellsByOwner counts, for each of the given cells, how
// many of owner's subscriptions (with ends_at not before the current clock
// time) include that cell, and returns the largest such count. When no
// subscription touches any of the cells the result is 0.
func (c *subscriptionRepoV3) MaxSubscriptionCountInCellsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) (int, error) {
	// TODO:steeling this query is expensive. The standard defines the max sub
	// per "area", but area is loosely defined. Since we may not have to be so
	// strict we could keep this count in memory, (or in some other storage).
	const query = `
SELECT
IFNULL(MAX(subscriptions_per_cell_id), 0)
FROM (
SELECT
COUNT(*) AS subscriptions_per_cell_id
FROM (
SELECT unnest(cells) as cell_id
FROM subscriptions
WHERE owner = $1
AND ends_at >= $2
)
WHERE
cell_id = ANY($3)
GROUP BY cell_id
)`

	cellIDs := make([]int64, 0, len(cells))
	for _, cell := range cells {
		cellIDs = append(cellIDs, int64(cell))
	}

	var maxCount int
	err := c.QueryRowContext(ctx, query, owner, c.clock.Now(), pq.Int64Array(cellIDs)).Scan(&maxCount)
	// Propagate yields nil when err is nil, so this covers both outcomes.
	return maxCount, stacktrace.Propagate(err, "Error scanning subscription count row")
}
// GetSubscription fetches the subscription whose primary key is id.
// A missing subscription is reported as (nil, nil) rather than as an error.
func (c *subscriptionRepoV3) GetSubscription(ctx context.Context, id dssmodels.ID) (*ridmodels.Subscription, error) {
	// TODO(steeling) we should enforce startTime and endTime to not be null at the DB level.
	selectByID := fmt.Sprintf(`
SELECT %s FROM subscriptions
WHERE id = $1`, subscriptionFieldsV3)
	return c.processOne(ctx, selectByID, id)
}
// UpdateSubscription overwrites the url, notification_index, cells, starts_at
// and ends_at of the stored subscription whose id is s.ID, but only if that
// row's updated_at still equals s.Version.ToTimestamp(). The row's updated_at
// is set to the transaction timestamp; owner is not modified.
//
// Every cell in s.Cells is checked with geo.ValidateCell before anything is
// sent to the database; the first invalid cell aborts with an error.
//
// Returns the updated subscription, or nil, nil if no row matched the ID and
// version.
func (c *subscriptionRepoV3) UpdateSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) {
	cids := make([]int64, len(s.Cells))
	for i, cell := range s.Cells {
		if err := geo.ValidateCell(cell); err != nil {
			return nil, stacktrace.Propagate(err, "Error validating cell")
		}
		cids[i] = int64(cell)
	}

	// $1 is reused in WHERE: the id column is "set" to its own value, and the
	// updated_at = $7 predicate rejects writes against a stale version.
	updateQuery := fmt.Sprintf(`
UPDATE
subscriptions
SET (%s) = ($1, $2, $3, $4, $5, $6, transaction_timestamp())
WHERE id = $1 AND updated_at = $7
RETURNING
%s`, updateSubscriptionFieldsV3, subscriptionFieldsV3)

	return c.processOne(ctx, updateQuery,
		s.ID,
		s.URL,
		s.NotificationIndex,
		pq.Int64Array(cids),
		s.StartTime,
		s.EndTime,
		s.Version.ToTimestamp())
}
// InsertSubscription writes s as a new row and returns the row as stored
// (via RETURNING), with updated_at stamped to the transaction timestamp.
// Each cell of s.Cells is checked with geo.ValidateCell first; the first
// invalid cell aborts the insert with an error.
func (c *subscriptionRepoV3) InsertSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) {
	cellIDs := make([]int64, 0, len(s.Cells))
	for _, cell := range s.Cells {
		if err := geo.ValidateCell(cell); err != nil {
			return nil, stacktrace.Propagate(err, "Error validating cell")
		}
		cellIDs = append(cellIDs, int64(cell))
	}

	insertQuery := fmt.Sprintf(`
INSERT INTO
subscriptions
(%s)
VALUES
($1, $2, $3, $4, $5, $6, $7, transaction_timestamp())
RETURNING
%s`, subscriptionFieldsV3, subscriptionFieldsV3)

	return c.processOne(ctx, insertQuery,
		s.ID, s.Owner, s.URL, s.NotificationIndex,
		pq.Int64Array(cellIDs), s.StartTime, s.EndTime)
}
// DeleteSubscription removes the row whose id is s.ID and whose updated_at
// equals s.Version.ToTimestamp(), returning the deleted row.
// Callers are expected to run this inside a transaction where the version
// has been verified.
// Returns nil, nil if no row matched the ID and version.
func (c *subscriptionRepoV3) DeleteSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) {
	deleteQuery := fmt.Sprintf(`
DELETE FROM
subscriptions
WHERE
id = $1
AND updated_at = $2
RETURNING %s`, subscriptionFieldsV3)

	version := s.Version.ToTimestamp()
	return c.processOne(ctx, deleteQuery, s.ID, version)
}
// UpdateNotificationIdxsInCells increments notification_index on every
// subscription whose cells overlap the given cells and whose ends_at is not
// before the current clock time, and returns those updated subscriptions.
func (c *subscriptionRepoV3) UpdateNotificationIdxsInCells(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) {
	cellIDs := make([]int64, 0, len(cells))
	for _, cell := range cells {
		cellIDs = append(cellIDs, int64(cell))
	}

	bumpQuery := fmt.Sprintf(`
UPDATE subscriptions
SET notification_index = notification_index + 1
WHERE
cells && $1
AND ends_at >= $2
RETURNING %s`, subscriptionFieldsV3)

	return c.process(ctx, bumpQuery, pq.Int64Array(cellIDs), c.clock.Now())
}
// SearchSubscriptions returns every subscription whose cells overlap the
// given cells and whose ends_at is not before the current clock time.
// An empty cells argument is rejected with a dsserr.BadRequest error.
func (c *subscriptionRepoV3) SearchSubscriptions(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) {
	if len(cells) == 0 {
		return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided")
	}

	cellIDs := make([]int64, 0, len(cells))
	for _, cell := range cells {
		cellIDs = append(cellIDs, int64(cell))
	}

	searchQuery := fmt.Sprintf(`
SELECT
%s
FROM
subscriptions
WHERE
cells && $1
AND
ends_at >= $2`, subscriptionFieldsV3)

	return c.process(ctx, searchQuery, pq.Int64Array(cellIDs), c.clock.Now())
}
// SearchSubscriptionsByOwner returns the subscriptions belonging to owner
// whose cells overlap the given cells and whose ends_at is not before the
// current clock time. An empty cells argument is rejected with a
// dsserr.BadRequest error.
func (c *subscriptionRepoV3) SearchSubscriptionsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) ([]*ridmodels.Subscription, error) {
	if len(cells) == 0 {
		return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided")
	}

	cellIDs := make([]int64, 0, len(cells))
	for _, cell := range cells {
		cellIDs = append(cellIDs, int64(cell))
	}

	ownerQuery := fmt.Sprintf(`
SELECT
%s
FROM
subscriptions
WHERE
cells && $1
AND
subscriptions.owner = $2
AND
ends_at >= $3`, subscriptionFieldsV3)

	return c.process(ctx, ownerQuery, pq.Int64Array(cellIDs), owner, c.clock.Now())
}