Skip to content

Commit

Permalink
fix(dp): return 409 in CRUD when cbsd serial number exists
Browse files Browse the repository at this point in the history
  • Loading branch information
Wojciech Sadowy committed Apr 6, 2022
1 parent bf16d6a commit e7bf1f1
Show file tree
Hide file tree
Showing 11 changed files with 147 additions and 23 deletions.
2 changes: 1 addition & 1 deletion dp/cloud/go/services/dp/dp/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func main() {
if err != nil {
glog.Fatalf("Error opening db connection: %s", err)
}
cbsdStore := dp_storage.NewCbsdManager(db, sqorc.GetSqlBuilder())
cbsdStore := dp_storage.NewCbsdManager(db, sqorc.GetSqlBuilder(), sqorc.GetErrorChecker())

interval := time.Second * time.Duration(serviceConfig.CbsdInactivityIntervalSec)
protos.RegisterCbsdManagementServer(srv.GrpcServer, servicers.NewCbsdManager(cbsdStore, interval))
Expand Down
4 changes: 3 additions & 1 deletion dp/cloud/go/services/dp/obsidian/cbsd/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func createCbsd(c echo.Context) error {
ctx := c.Request().Context()
_, err = client.CreateCbsd(ctx, &req)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err)
return getHttpError(err)
}
return c.NoContent(http.StatusCreated)
}
Expand Down Expand Up @@ -217,6 +217,8 @@ func getHttpError(err error) error {
switch s, _ := status.FromError(err); s.Code() {
case codes.NotFound:
return echo.NewHTTPError(http.StatusNotFound, err)
case codes.AlreadyExists:
return echo.NewHTTPError(http.StatusConflict, err)
default:
return echo.NewHTTPError(http.StatusInternalServerError, err)
}
Expand Down
2 changes: 2 additions & 0 deletions dp/cloud/go/services/dp/servicers/cbsd_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ func makeErr(err error, wrap string) error {
code := codes.Internal
if err == merrors.ErrNotFound {
code = codes.NotFound
} else if err == merrors.ErrAlreadyExists {
code = codes.AlreadyExists
}
return status.Error(code, e.Error())
}
30 changes: 16 additions & 14 deletions dp/cloud/go/services/dp/storage/cbsd_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,20 @@ type DetailedCbsd struct {
GrantState *DBGrantState
}

func NewCbsdManager(db *sql.DB, builder sqorc.StatementBuilder) *cbsdManager {
func NewCbsdManager(db *sql.DB, builder sqorc.StatementBuilder, errorChecker sqorc.ErrorChecker) *cbsdManager {
return &cbsdManager{
db: db,
builder: builder,
cache: &enumCache{cache: map[string]map[string]int64{}},
db: db,
builder: builder,
cache: &enumCache{cache: map[string]map[string]int64{}},
errorChecker: errorChecker,
}
}

type cbsdManager struct {
db *sql.DB
builder sqorc.StatementBuilder
cache *enumCache
db *sql.DB
builder sqorc.StatementBuilder
cache *enumCache
errorChecker sqorc.ErrorChecker
}

type enumCache struct {
Expand All @@ -67,7 +69,7 @@ func (c *cbsdManager) CreateCbsd(networkId string, data *DBCbsd) error {
err := runner.createCbsdWithActiveModeConfig(networkId, data)
return nil, err
})
return makeError(err)
return makeError(err, c.errorChecker)
}

func (c *cbsdManager) UpdateCbsd(networkId string, id int64, data *DBCbsd) error {
Expand All @@ -76,7 +78,7 @@ func (c *cbsdManager) UpdateCbsd(networkId string, id int64, data *DBCbsd) error
err := runner.updateCbsd(networkId, id, data)
return nil, err
})
return makeError(err)
return makeError(err, c.errorChecker)
}

func (c *cbsdManager) DeleteCbsd(networkId string, id int64) error {
Expand All @@ -85,7 +87,7 @@ func (c *cbsdManager) DeleteCbsd(networkId string, id int64) error {
err := runner.markCbsdAsDeleted(networkId, id)
return nil, err
})
return makeError(err)
return makeError(err, c.errorChecker)
}

func (c *cbsdManager) FetchCbsd(networkId string, id int64) (*DetailedCbsd, error) {
Expand All @@ -94,7 +96,7 @@ func (c *cbsdManager) FetchCbsd(networkId string, id int64) (*DetailedCbsd, erro
return runner.fetchDetailedCbsd(networkId, id)
})
if err != nil {
return nil, makeError(err)
return nil, makeError(err, c.errorChecker)
}
return cbsd.(*DetailedCbsd), nil
}
Expand All @@ -105,7 +107,7 @@ func (c *cbsdManager) ListCbsd(networkId string, pagination *Pagination) (*Detai
return runner.listDetailedCbsd(networkId, pagination)
})
if err != nil {
return nil, makeError(err)
return nil, makeError(err, c.errorChecker)
}
return cbsds.(*DetailedCbsdList), nil
}
Expand Down Expand Up @@ -290,11 +292,11 @@ func countCbsds(networkId string, builder sq.StatementBuilderType) (int64, error
Count()
}

func makeError(err error) error {
func makeError(err error, checker sqorc.ErrorChecker) error {
if err == sql.ErrNoRows {
return merrors.ErrNotFound
}
return err
return checker.GetError(err)
}

func getCbsdFiltersWithId(networkId string, id int64) sq.Eq {
Expand Down
34 changes: 32 additions & 2 deletions dp/cloud/go/services/dp/storage/cbsd_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License.
package storage_test

import (
"fmt"
"testing"
"time"

Expand All @@ -40,9 +41,10 @@ type CbsdManagerTestSuite struct {

func (s *CbsdManagerTestSuite) SetupSuite() {
builder := sqorc.GetSqlBuilder()
errorChecker := sqorc.SQLiteErrorChecker{}
database, err := sqorc.Open("sqlite3", ":memory:")
s.Require().NoError(err)
s.cbsdManager = storage.NewCbsdManager(database, builder)
s.cbsdManager = storage.NewCbsdManager(database, builder, errorChecker)
s.resourceManager = dbtest.NewResourceManager(s.T(), database, builder)
err = s.resourceManager.CreateTables(
&storage.DBCbsdState{},
Expand Down Expand Up @@ -151,6 +153,30 @@ func (s *CbsdManagerTestSuite) TestCreateCbsd() {
s.Require().NoError(err)
}

func (s *CbsdManagerTestSuite) TestCreateCbsdWithExistingSerialNumber() {
err := s.cbsdManager.CreateCbsd(someNetwork, getBaseCbsd())
s.Require().NoError(err)
err = s.cbsdManager.CreateCbsd(someNetwork, getBaseCbsd())
s.Assert().ErrorIs(err, merrors.ErrAlreadyExists)
}

func (s *CbsdManagerTestSuite) TestUpdateCbsdWithSerialNumberOfExistingCbsd() {
cbsd1 := getBaseCbsd()
cbsd1.Id = db.MakeInt(1)
cbsd2 := getBaseCbsd()
cbsd2.Id = db.MakeInt(2)
cbsd2.CbsdSerialNumber = db.MakeString("cbsd_serial_number2")
err := s.cbsdManager.CreateCbsd(someNetwork, cbsd1)
s.Require().NoError(err)
err = s.cbsdManager.CreateCbsd(someNetwork, cbsd2)
s.Require().NoError(err)

cbsd2.CbsdSerialNumber = cbsd1.CbsdSerialNumber

err = s.cbsdManager.UpdateCbsd(someNetwork, cbsd2.Id.Int64, cbsd2)
s.Assert().ErrorIs(err, merrors.ErrAlreadyExists)
}

func (s *CbsdManagerTestSuite) TestUpdateCbsd() {
var cbsdId int64
err := s.resourceManager.InTransaction(func() {
Expand Down Expand Up @@ -344,6 +370,8 @@ func (s *CbsdManagerTestSuite) TestListWithPagination() {
for i := range models {
cbsd := getCbsd(someNetwork, stateId)
cbsd.Id = db.MakeInt(int64(i + 1))

cbsd.CbsdSerialNumber = db.MakeString(fmt.Sprintf("some_serial_number%d", i+1))
models[i] = cbsd
}
err := s.resourceManager.InsertResources(db.NewExcludeMask(), models...)
Expand All @@ -363,8 +391,10 @@ func (s *CbsdManagerTestSuite) TestListWithPagination() {
Cbsds: make([]*storage.DetailedCbsd, limit),
}
for i := range expected.Cbsds {
cbsd := getDetailedCbsd(int64(i + 1 + offset))
cbsd.CbsdSerialNumber = db.MakeString(fmt.Sprintf("some_serial_number%d", i+1+offset))
expected.Cbsds[i] = &storage.DetailedCbsd{
Cbsd: getDetailedCbsd(int64(i + 1 + offset)),
Cbsd: cbsd,
CbsdState: &storage.DBCbsdState{Name: db.MakeString("unregistered")},
Grant: &storage.DBGrant{},
GrantState: &storage.DBGrantState{},
Expand Down
1 change: 1 addition & 0 deletions dp/cloud/go/services/dp/storage/db/fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type Field struct {
Nullable bool
HasDefault bool
DefaultValue interface{}
Unique bool
}

func (f *Field) GetValue() interface{} {
Expand Down
4 changes: 4 additions & 0 deletions dp/cloud/go/services/dp/storage/db/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ func addColumns(builder sqorc.CreateTableBuilder, fields FieldMap) sqorc.CreateT
colBuilder = colBuilder.Default(field.DefaultValue)
}
builder = colBuilder.EndColumn()

if field.Unique {
builder = builder.Unique(column)
}
}
return builder
}
Expand Down
1 change: 1 addition & 0 deletions dp/cloud/go/services/dp/storage/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ func (c *DBCbsd) Fields() db.FieldMap {
"cbsd_serial_number": &db.Field{
BaseType: db.StringType{X: &c.CbsdSerialNumber},
Nullable: true,
Unique: true,
},
"last_seen": &db.Field{
BaseType: db.TimeType{X: &c.LastSeen},
Expand Down
29 changes: 25 additions & 4 deletions dp/cloud/python/magma/test_runner/tests/test_dp_with_orc8r.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,27 @@ def test_frequency_preferences(self):

self.delete_cbsd(cbsd_id)

def test_creating_cbsd_with_the_same_unique_fields_returns_409(self):
builder = CbsdAPIDataBuilder()

self.when_cbsd_is_created(builder.build_post_data())
self.when_cbsd_is_created(builder.build_post_data(), expected_status=HTTPStatus.CONFLICT)

def test_updating_cbsd_returns_409_when_setting_existing_serial_num(self):
builder = CbsdAPIDataBuilder()

cbsd1_payload = builder.build_post_data()
cbsd2_payload = builder.build_post_data()
cbsd1_payload["serial_number"] = "foo"
cbsd2_payload["serial_number"] = "bar"
self.when_cbsd_is_created(cbsd1_payload) # cbsd_id = 1
self.when_cbsd_is_created(cbsd2_payload) # cbsd_id = 2
self.when_cbsd_is_updated(
cbsd_id=2,
data=cbsd1_payload,
expected_status=HTTPStatus.CONFLICT,
)

def test_fetching_logs_with_custom_filters(self):
self.given_cbsd_provisioned(CbsdAPIDataBuilder())
self.when_elastic_indexes_data(keep_alive=False)
Expand Down Expand Up @@ -194,9 +215,9 @@ def when_elastic_indexes_data(self, *, keep_alive: bool):
if keep_alive:
self.when_cbsd_asks_for_state()

def when_cbsd_is_created(self, data: Dict[str, Any]):
def when_cbsd_is_created(self, data: Dict[str, Any], expected_status: int = HTTPStatus.CREATED):
r = send_request_to_backend('post', 'cbsds', json=data)
self.assertEqual(r.status_code, HTTPStatus.CREATED)
self.assertEqual(r.status_code, expected_status)

def when_cbsd_is_fetched(self) -> Dict[str, Any]:
r = send_request_to_backend('get', 'cbsds')
Expand All @@ -217,9 +238,9 @@ def when_cbsd_is_deleted(self, cbsd_id: int):
r = send_request_to_backend('delete', f'cbsds/{cbsd_id}')
self.assertEqual(r.status_code, HTTPStatus.NO_CONTENT)

def when_cbsd_is_updated(self, cbsd_id: int, data: Dict[str, Any]):
def when_cbsd_is_updated(self, cbsd_id: int, data: Dict[str, Any], expected_status: int = HTTPStatus.NO_CONTENT):
r = send_request_to_backend('put', f'cbsds/{cbsd_id}', json=data)
self.assertEqual(r.status_code, HTTPStatus.NO_CONTENT)
self.assertEqual(r.status_code, expected_status)

def when_cbsd_asks_for_state(self) -> CBSDStateResult:
return self.dp_client.GetCBSDState(get_cbsd_request())
Expand Down
23 changes: 22 additions & 1 deletion orc8r/cloud/go/sqorc/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@ import (
)

const (
SQLDialectEnv = "SQL_DIALECT"
PostgresDialect = "psql"
MariaDialect = "maria"
SQLiteDialect = "sqlite"
)

// GetSqlBuilder returns a squirrel Builder for the configured SQL dialect as
// found in the SQL_DIALECT env var.
func GetSqlBuilder() StatementBuilder {
dialect, envFound := os.LookupEnv("SQL_DIALECT")
dialect, envFound := os.LookupEnv(SQLDialectEnv)
// Default to postgresql
if !envFound {
return NewPostgresStatementBuilder()
Expand All @@ -50,6 +52,25 @@ func GetSqlBuilder() StatementBuilder {
}
}

// GetErrorChecker returns a squirrel Builder for the configured SQL dialect as
// found in the SQL_DIALECT env var.
func GetErrorChecker() ErrorChecker {
dialect, envFound := os.LookupEnv(SQLDialectEnv)
// Default to postgresql
if !envFound {
return PostgresErrorChecker{}
}

switch strings.ToLower(dialect) {
case PostgresDialect:
return PostgresErrorChecker{}
case SQLiteDialect:
return SQLiteErrorChecker{}
default:
panic(fmt.Sprintf("unsupported sql dialect %s", dialect))
}
}

// StatementBuilder is an interface which tracks squirrel's
// StatementBuilderType with the difference that Insert returns this package's
// InsertBuilder interface type.
Expand Down
40 changes: 40 additions & 0 deletions orc8r/cloud/go/sqorc/error_checker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package sqorc

import (
"magma/orc8r/lib/go/merrors"

"github.com/lib/pq"
"github.com/mattn/go-sqlite3"
)

const (
uniqueViolation = "unique_violation"
)

type ErrorChecker interface {
GetError(error) error
}

type SQLiteErrorChecker struct{}

type PostgresErrorChecker struct{}

func (c SQLiteErrorChecker) GetError(err error) error {
if e, ok := err.(sqlite3.Error); ok {
switch e.Code {
case sqlite3.ErrConstraint:
return merrors.ErrAlreadyExists
}
}
return err
}

func (c PostgresErrorChecker) GetError(err error) error {
if e, ok := err.(*pq.Error); ok {
switch e.Code.Name() {
case uniqueViolation:
return merrors.ErrAlreadyExists
}
}
return err
}

0 comments on commit e7bf1f1

Please sign in to comment.