Skip to content

Commit

Permalink
Update update sql script to write writer field for Subscriptions table (
Browse files Browse the repository at this point in the history
#410)

* Update update sql script to write writer field for Subscriptions table

* remove unused code
  • Loading branch information
Dos-Ph committed Aug 26, 2020
1 parent ae056da commit c789b2b
Show file tree
Hide file tree
Showing 7 changed files with 325 additions and 33 deletions.
1 change: 1 addition & 0 deletions pkg/rid/models/subscriptions.go
Expand Up @@ -35,6 +35,7 @@ type Subscription struct {
Version *dssmodels.Version
AltitudeHi *float32
AltitudeLo *float32
Writer string
}

// SetCells is a convenience function that accepts an int64 array and converts
Expand Down
8 changes: 5 additions & 3 deletions pkg/rid/server/subscription_handler.go
Expand Up @@ -136,9 +136,10 @@ func (s *Server) CreateSubscription(
}

sub := &ridmodels.Subscription{
ID: id,
Owner: owner,
URL: params.Callbacks.IdentificationServiceAreaUrl,
ID: id,
Owner: owner,
URL: params.Callbacks.IdentificationServiceAreaUrl,
Writer: s.Locality,
}

if err := sub.SetExtents(params.Extents); err != nil {
Expand Down Expand Up @@ -214,6 +215,7 @@ func (s *Server) UpdateSubscription(
Owner: owner,
URL: params.Callbacks.IdentificationServiceAreaUrl,
Version: version,
Writer: s.Locality,
}

if err := sub.SetExtents(params.Extents); err != nil {
Expand Down
10 changes: 1 addition & 9 deletions pkg/rid/store/cockroach/identification_service_area.go
Expand Up @@ -26,10 +26,6 @@ const (
updateISAFields = "id, url, cells, starts_at, ends_at, writer, updated_at"
)

var (
v310 = *semver.New("3.1.0")
)

func NewISARepo(ctx context.Context, db dssql.Queryable, dbVersion semver.Version, logger *zap.Logger) repos.ISA {
if dbVersion.Compare(v310) >= 0 {
return &isaRepo{
Expand Down Expand Up @@ -77,11 +73,7 @@ func (c *isaRepo) process(ctx context.Context, query string, args ...interface{}
if err != nil {
return nil, stacktrace.Propagate(err, "Error scanning ISA row")
}
if writer.Valid {
i.Writer = writer.String
} else {
i.Writer = ""
}
i.Writer = writer.String
i.SetCells(cids)
payload = append(payload, i)
}
Expand Down
20 changes: 7 additions & 13 deletions pkg/rid/store/cockroach/store.go
Expand Up @@ -33,11 +33,13 @@ var (

// DatabaseName is the name of database storing remote ID data.
DatabaseName = "defaultdb"

v310 = *semver.New("3.1.0")
)

type repo struct {
repos.ISA
*subscriptionRepo
repos.Subscription
}

// Store is an implementation of store.Store using Cockroach DB as its backend
Expand Down Expand Up @@ -92,12 +94,8 @@ func (s *Store) Interact(ctx context.Context) (repos.Repository, error) {
}

return &repo{
ISA: NewISARepo(ctx, s.db, *storeVersion, logger),
subscriptionRepo: &subscriptionRepo{
Queryable: s.db,
logger: logger,
clock: s.clock,
},
ISA: NewISARepo(ctx, s.db, *storeVersion, logger),
Subscription: NewISASubscriptionRepo(ctx, s.db, *storeVersion, logger, s.clock),
}, nil
}

Expand All @@ -120,12 +118,8 @@ func (s *Store) Transact(ctx context.Context, f func(repo repos.Repository) erro
// Is this recover still necessary?
defer recoverRollbackRepanic(ctx, tx)
return f(&repo{
ISA: NewISARepo(ctx, tx, *storeVersion, logger),
subscriptionRepo: &subscriptionRepo{
Queryable: tx,
logger: logger,
clock: s.clock,
},
ISA: NewISARepo(ctx, tx, *storeVersion, logger),
Subscription: NewISASubscriptionRepo(ctx, tx, *storeVersion, logger, s.clock),
})
})
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/rid/store/cockroach/store_test.go
Expand Up @@ -209,7 +209,7 @@ func TestBasicTxn(t *testing.T) {
Queryable: tx1,
logger: logging.Logger,
},
subscriptionRepo: &subscriptionRepo{
Subscription: &subscriptionRepo{
Queryable: tx1,
logger: logging.Logger,
clock: DefaultClock,
Expand All @@ -223,7 +223,7 @@ func TestBasicTxn(t *testing.T) {
Queryable: tx2,
logger: logging.Logger,
},
subscriptionRepo: &subscriptionRepo{
Subscription: &subscriptionRepo{
Queryable: tx2,
logger: logging.Logger,
clock: DefaultClock,
Expand Down
280 changes: 280 additions & 0 deletions pkg/rid/store/cockroach/subcriptions_v3.go
@@ -0,0 +1,280 @@
package cockroach

import (
"context"
"fmt"

"github.com/dpjacques/clockwork"
dsserr "github.com/interuss/dss/pkg/errors"
"github.com/interuss/dss/pkg/geo"
dssmodels "github.com/interuss/dss/pkg/models"
ridmodels "github.com/interuss/dss/pkg/rid/models"

"github.com/golang/geo/s2"
dssql "github.com/interuss/dss/pkg/sql"
"github.com/lib/pq"
"github.com/palantir/stacktrace"
"go.uber.org/zap"
)

const (
subscriptionFieldsV3 = "id, owner, url, notification_index, cells, starts_at, ends_at, updated_at"
updateSubscriptionFieldsV3 = "id, url, notification_index, cells, starts_at, ends_at, updated_at"
)

// subscriptions is an implementation of the SubscriptionRepo for CRDB.
type subscriptionRepoV3 struct {
dssql.Queryable

clock clockwork.Clock
logger *zap.Logger
}

// process a query that should return one or many subscriptions.
func (c *subscriptionRepoV3) process(ctx context.Context, query string, args ...interface{}) ([]*ridmodels.Subscription, error) {
rows, err := c.QueryContext(ctx, query, args...)
if err != nil {
return nil, stacktrace.Propagate(err, fmt.Sprintf("Error in query: %s", query))
}
defer rows.Close()

var payload []*ridmodels.Subscription
cids := pq.Int64Array{}

for rows.Next() {
s := new(ridmodels.Subscription)

err := rows.Scan(
&s.ID,
&s.Owner,
&s.URL,
&s.NotificationIndex,
&cids,
&s.StartTime,
&s.EndTime,
&s.Version,
)
if err != nil {
return nil, stacktrace.Propagate(err, "Error scanning Subscription row")
}
s.SetCells(cids)
payload = append(payload, s)
}
if err := rows.Err(); err != nil {
return nil, stacktrace.Propagate(err, "Error in rows query result")
}
return payload, nil
}

// processOne processes a query that should return exactly a single subscription.
func (c *subscriptionRepoV3) processOne(ctx context.Context, query string, args ...interface{}) (*ridmodels.Subscription, error) {
subs, err := c.process(ctx, query, args...)
if err != nil {
return nil, err // No need to Propagate this error as this stack layer does not add useful information
}
if len(subs) > 1 {
return nil, stacktrace.NewError("Query returned %d subscriptions when only 0 or 1 was expected", len(subs))
}
if len(subs) == 0 {
return nil, nil
}
return subs[0], nil
}

// MaxSubscriptionCountInCellsByOwner counts how many subscriptions the
// owner has in each one of these cells, and returns the number of subscriptions
// in the cell with the highest number of subscriptions.
func (c *subscriptionRepoV3) MaxSubscriptionCountInCellsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) (int, error) {
// TODO:steeling this query is expensive. The standard defines the max sub
// per "area", but area is loosely defined. Since we may not have to be so
// strict we could keep this count in memory, (or in some other storage).
var query = `
SELECT
IFNULL(MAX(subscriptions_per_cell_id), 0)
FROM (
SELECT
COUNT(*) AS subscriptions_per_cell_id
FROM (
SELECT unnest(cells) as cell_id
FROM subscriptions
WHERE owner = $1
AND ends_at >= $2
)
WHERE
cell_id = ANY($3)
GROUP BY cell_id
)`

cids := make([]int64, len(cells))
for i, cell := range cells {
cids[i] = int64(cell)
}

row := c.QueryRowContext(ctx, query, owner, c.clock.Now(), pq.Int64Array(cids))
var ret int
err := row.Scan(&ret)
return ret, stacktrace.Propagate(err, "Error scanning subscription count row")
}

// GetSubscription returns the subscription identified by "id".
// Returns nil, nil if not found
func (c *subscriptionRepoV3) GetSubscription(ctx context.Context, id dssmodels.ID) (*ridmodels.Subscription, error) {
// TODO(steeling) we should enforce startTime and endTime to not be null at the DB level.
var query = fmt.Sprintf(`
SELECT %s FROM subscriptions
WHERE id = $1`, subscriptionFieldsV3)
return c.processOne(ctx, query, id)
}

// UpdateSubscription updates the Subscription.. not yet implemented.
// Returns nil, nil if ID, version not found
func (c *subscriptionRepoV3) UpdateSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) {
var (
updateQuery = fmt.Sprintf(`
UPDATE
subscriptions
SET (%s) = ($1, $2, $3, $4, $5, $6, transaction_timestamp())
WHERE id = $1 AND updated_at = $7
RETURNING
%s`, updateSubscriptionFieldsV3, subscriptionFieldsV3)
)

cids := make([]int64, len(s.Cells))

for i, cell := range s.Cells {
if err := geo.ValidateCell(cell); err != nil {
return nil, stacktrace.Propagate(err, "Error validating cell")
}
cids[i] = int64(cell)
}

return c.processOne(ctx, updateQuery,
s.ID,
s.URL,
s.NotificationIndex,
pq.Int64Array(cids),
s.StartTime,
s.EndTime,
s.Version.ToTimestamp())
}

// InsertSubscription inserts subscription into the store and returns
// the resulting subscription including its ID.
func (c *subscriptionRepoV3) InsertSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) {
var (
insertQuery = fmt.Sprintf(`
INSERT INTO
subscriptions
(%s)
VALUES
($1, $2, $3, $4, $5, $6, $7, transaction_timestamp())
RETURNING
%s`, subscriptionFieldsV3, subscriptionFieldsV3)
)

cids := make([]int64, len(s.Cells))

for i, cell := range s.Cells {
if err := geo.ValidateCell(cell); err != nil {
return nil, stacktrace.Propagate(err, "Error validating cell")
}
cids[i] = int64(cell)
}

return c.processOne(ctx, insertQuery,
s.ID,
s.Owner,
s.URL,
s.NotificationIndex,
pq.Int64Array(cids),
s.StartTime,
s.EndTime)
}

// DeleteSubscription deletes the subscription identified by ID.
// It must be done in a txn and the version verified.
// Returns nil, nil if ID, version not found
func (c *subscriptionRepoV3) DeleteSubscription(ctx context.Context, s *ridmodels.Subscription) (*ridmodels.Subscription, error) {
var (
query = fmt.Sprintf(`
DELETE FROM
subscriptions
WHERE
id = $1
AND updated_at = $2
RETURNING %s`, subscriptionFieldsV3)
)
return c.processOne(ctx, query, s.ID, s.Version.ToTimestamp())
}

// UpdateNotificationIdxsInCells incremement the notification for each sub in the given cells.
func (c *subscriptionRepoV3) UpdateNotificationIdxsInCells(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) {
var updateQuery = fmt.Sprintf(`
UPDATE subscriptions
SET notification_index = notification_index + 1
WHERE
cells && $1
AND ends_at >= $2
RETURNING %s`, subscriptionFieldsV3)

cids := make([]int64, len(cells))
for i, cell := range cells {
cids[i] = int64(cell)
}
return c.process(
ctx, updateQuery, pq.Int64Array(cids), c.clock.Now())
}

// SearchSubscriptions returns all subscriptions in "cells".
func (c *subscriptionRepoV3) SearchSubscriptions(ctx context.Context, cells s2.CellUnion) ([]*ridmodels.Subscription, error) {
var (
query = fmt.Sprintf(`
SELECT
%s
FROM
subscriptions
WHERE
cells && $1
AND
ends_at >= $2`, subscriptionFieldsV3)
)

if len(cells) == 0 {
return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided")
}

cids := make([]int64, len(cells))
for i, cell := range cells {
cids[i] = int64(cell)
}

return c.process(ctx, query, pq.Int64Array(cids), c.clock.Now())
}

// SearchSubscriptionsByOwner returns all subscriptions in "cells".
func (c *subscriptionRepoV3) SearchSubscriptionsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) ([]*ridmodels.Subscription, error) {
var (
query = fmt.Sprintf(`
SELECT
%s
FROM
subscriptions
WHERE
cells && $1
AND
subscriptions.owner = $2
AND
ends_at >= $3`, subscriptionFieldsV3)
)

if len(cells) == 0 {
return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided")
}

cids := make([]int64, len(cells))
for i, cell := range cells {
cids[i] = int64(cell)
}

return c.process(ctx, query, pq.Int64Array(cids), owner, c.clock.Now())
}

0 comments on commit c789b2b

Please sign in to comment.