Skip to content

Commit

Permalink
fix: return error when ID field is missing (#566)
Browse files Browse the repository at this point in the history
Instead of optimistically guessing that the ID of a model is an int when the ID field is not defined, this patch now explicitly returns an error. This resolves several edgecase scenarios and closes #565
  • Loading branch information
aeneasr committed Jun 29, 2020
1 parent 4900410 commit faf98ab
Show file tree
Hide file tree
Showing 11 changed files with 85 additions and 29 deletions.
5 changes: 4 additions & 1 deletion dialect_cockroach.go
Expand Up @@ -68,7 +68,10 @@ func (p *cockroach) Details() *ConnectionDetails {
}

func (p *cockroach) Create(s store, model *Model, cols columns.Columns) error {
keyType := model.PrimaryKeyType()
keyType, err := model.PrimaryKeyType()
if err != nil {
return err
}
switch keyType {
case "int", "int64":
cols.Remove("id")
Expand Down
5 changes: 4 additions & 1 deletion dialect_common.go
Expand Up @@ -44,7 +44,10 @@ func (commonDialect) Quote(key string) string {
}

func genericCreate(s store, model *Model, cols columns.Columns, quoter quotable) error {
keyType := model.PrimaryKeyType()
keyType, err := model.PrimaryKeyType()
if err != nil {
return err
}
switch keyType {
case "int", "int64":
var id int64
Expand Down
5 changes: 4 additions & 1 deletion dialect_postgresql.go
Expand Up @@ -55,7 +55,10 @@ func (p *postgresql) Details() *ConnectionDetails {
}

func (p *postgresql) Create(s store, model *Model, cols columns.Columns) error {
keyType := model.PrimaryKeyType()
keyType, err := model.PrimaryKeyType()
if err != nil {
return err
}
switch keyType {
case "int", "int64":
cols.Remove("id")
Expand Down
5 changes: 4 additions & 1 deletion dialect_sqlite.go
Expand Up @@ -64,7 +64,10 @@ func (m *sqlite) MigrationURL() string {

func (m *sqlite) Create(s store, model *Model, cols columns.Columns) error {
return m.locker(m.smGil, func() error {
keyType := model.PrimaryKeyType()
keyType, err := model.PrimaryKeyType()
if err != nil {
return err
}
switch keyType {
case "int", "int64":
var id int64
Expand Down
23 changes: 13 additions & 10 deletions finders.go
Expand Up @@ -38,13 +38,16 @@ func (q *Query) Find(model interface{}, id interface{}) error {
// Pick argument type based on column type. This is required for keeping backwards compatibility with:
//
// https://github.com/gobuffalo/buffalo/blob/master/genny/resource/templates/use_model/actions/resource-name.go.tmpl#L76
switch m.ID().(type) {
case int32, int64, uint32, uint64, int8, uint8, int16, uint16, int:
pkt, err := m.PrimaryKeyType()
if err != nil {
return err
}

switch pkt {
case "int32", "int64", "uint32", "uint64", "int8", "uint8", "int16", "uint16", "int":
if tid, ok := id.(string); ok {
var err error
id, err = strconv.Atoi(tid)
if err == nil {
return q.Where(idq, id).First(model)
if intID, err := strconv.Atoi(tid); err == nil {
return q.Where(idq, intID).First(model)
}
}
}
Expand Down Expand Up @@ -144,7 +147,7 @@ func (q *Query) All(models interface{}) error {
})

if err != nil {
return errors.Wrap(err, "unable to fetch records")
return err //errors.Wrap(err, "unable to fetch records")
}

if q.eager {
Expand Down Expand Up @@ -291,7 +294,7 @@ func (q *Query) eagerDefaultAssociations(model interface{}) error {
// q.Where("name = ?", "mark").Exists(&User{})
func (q *Query) Exists(model interface{}) (bool, error) {
tmpQuery := Q(q.Connection)
q.Clone(tmpQuery) //avoid meddling with original query
q.Clone(tmpQuery) // avoid meddling with original query

var res bool

Expand Down Expand Up @@ -337,7 +340,7 @@ func (q Query) Count(model interface{}) (int, error) {
// q.Where("sex = ?", "f").Count(&User{}, "name")
func (q Query) CountByField(model interface{}, field string) (int, error) {
tmpQuery := Q(q.Connection)
q.Clone(tmpQuery) //avoid meddling with original query
q.Clone(tmpQuery) // avoid meddling with original query

res := &rowCount{}

Expand All @@ -346,7 +349,7 @@ func (q Query) CountByField(model interface{}, field string) (int, error) {
tmpQuery.orderClauses = clauses{}
tmpQuery.limitResults = 0
query, args := tmpQuery.ToSQL(&Model{Value: model})
//when query contains custom selected fields / executed using RawQuery,
// when query contains custom selected fields / executed using RawQuery,
// sql may already contains limit and offset

if rLimitOffset.MatchString(query) {
Expand Down
42 changes: 34 additions & 8 deletions finders_test.go
Expand Up @@ -34,6 +34,32 @@ func Test_Find(t *testing.T) {
})
}

func Test_Create_MissingID(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
}
transaction(func(tx *Connection) {
r := require.New(t)
client := Client{ClientID: "client-0001"}
err := tx.Create(&client)
r.Error(err)
r.Contains(err.Error(), "model *pop.Client is missing required field ID")
})
}

func Test_Find_MissingID(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
}
transaction(func(tx *Connection) {
r := require.New(t)
r.NoError(tx.RawQuery("INSERT INTO clients (id) VALUES (?)", "client-0001").Exec())

u := Client{}
r.EqualError(tx.Find(&u, "client-0001"), "model *pop.Client is missing required field ID")
})
}

func Test_Find_LeadingZeros(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
Expand Down Expand Up @@ -275,7 +301,7 @@ func Test_Find_Eager_Has_One(t *testing.T) {
r.Equal(u.Name.String, "Mark")
r.Equal(u.FavoriteSong.ID, coolSong.ID)

//eager should work with rawquery
// eager should work with rawquery
uid := u.ID
u = User{}
err = tx.RawQuery("select * from users where id=?", uid).First(&u)
Expand Down Expand Up @@ -419,14 +445,14 @@ func Test_Find_Eager_Many_To_Many(t *testing.T) {
err = tx.Create(&ownerProperty2)
r.NoError(err)

//eager should work with rawquery
// eager should work with rawquery
uid := u.ID
u = User{}
err = tx.RawQuery("select * from users where id=?", uid).Eager("Houses").First(&u)
r.NoError(err)
r.Equal(1, len(u.Houses))

//eager ALL
// eager ALL
var users []User
err = tx.RawQuery("select * from users order by created_at asc").Eager("Houses").All(&users)
r.NoError(err)
Expand Down Expand Up @@ -731,15 +757,15 @@ func Test_Count_Disregards_Pagination(t *testing.T) {

q := tx.Paginate(1, 3)
r.NoError(q.All(&firstUsers))
r.Equal(len(names), q.Paginator.TotalEntriesSize) //ensure paginator populates count
r.Equal(len(names), q.Paginator.TotalEntriesSize) // ensure paginator populates count
r.Equal(3, len(firstUsers))

firstUsers = Users{}
q = tx.RawQuery("select * from users").Paginate(1, 3)
r.NoError(q.All(&firstUsers))
r.Equal(1, q.Paginator.Page)
r.Equal(3, q.Paginator.PerPage)
r.Equal(len(names), q.Paginator.TotalEntriesSize) //ensure paginator populates count
r.Equal(len(names), q.Paginator.TotalEntriesSize) // ensure paginator populates count

r.Equal(3, len(firstUsers))
totalFirstPage := q.Paginator.TotalPages
Expand All @@ -758,7 +784,7 @@ func Test_Count_Disregards_Pagination(t *testing.T) {
q = tx.RawQuery("select * from users limit 2").Paginate(1, 5)
err := q.All(&firstUsers)
r.NoError(err)
r.Equal(2, len(firstUsers)) //raw query limit applies
r.Equal(2, len(firstUsers)) // raw query limit applies

firstUsers = Users{}
q = tx.RawQuery("select * from users limit 2 offset 1").Paginate(1, 5)
Expand All @@ -782,7 +808,7 @@ func Test_Count_Disregards_Pagination(t *testing.T) {
firstUsers = Users{}
q = tx.RawQuery(`select * from users limit 2 offset
1
`).Paginate(1, 5) //ending space and tab
`).Paginate(1, 5) // ending space and tab
err = q.All(&firstUsers)
r.NoError(err)
r.Equal(2, len(firstUsers))
Expand All @@ -806,7 +832,7 @@ func Test_Count_Disregards_Pagination(t *testing.T) {
q = tx.RawQuery("select * from users FETCH FIRST 3 rows only").Paginate(1, 5)
err = q.All(&firstUsers)
r.NoError(err)
r.Equal(3, len(firstUsers)) //should fetch only 3
r.Equal(3, len(firstUsers)) // should fetch only 3
}
})
}
Expand Down
11 changes: 6 additions & 5 deletions model.go
Expand Up @@ -2,6 +2,7 @@ package pop

import (
"fmt"
"github.com/pkg/errors"
"reflect"
"sync"
"time"
Expand Down Expand Up @@ -34,9 +35,9 @@ type Model struct {
func (m *Model) ID() interface{} {
fbn, err := m.fieldByName("ID")
if err != nil {
return 0
return nil
}
if m.PrimaryKeyType() == "UUID" {
if pkt, _ := m.PrimaryKeyType(); pkt == "UUID" {
return fbn.Interface().(uuid.UUID).String()
}
return fbn.Interface()
Expand All @@ -57,12 +58,12 @@ func (m *Model) IDField() string {
}

// PrimaryKeyType gives the primary key type of the `Model`.
func (m *Model) PrimaryKeyType() string {
func (m *Model) PrimaryKeyType() (string, error) {
fbn, err := m.fieldByName("ID")
if err != nil {
return "int"
return "", errors.Errorf("model %T is missing required field ID", m.Value)
}
return fbn.Type().Name()
return fbn.Type().Name(), nil
}

// TableNameAble interface allows for the customize table mapping
Expand Down
11 changes: 10 additions & 1 deletion pop_test.go
Expand Up @@ -7,11 +7,12 @@ import (
"time"

"github.com/gobuffalo/nulls"
"github.com/gobuffalo/pop/v5/logging"
"github.com/gobuffalo/validate/v3"
"github.com/gobuffalo/validate/v3/validators"
"github.com/gofrs/uuid"
"github.com/stretchr/testify/suite"

"github.com/gobuffalo/pop/v5/logging"
)

var PDB *Connection
Expand Down Expand Up @@ -73,6 +74,14 @@ func ts(s string) string {
return PDB.Dialect.TranslateSQL(s)
}

type Client struct {
ClientID string `db:"id"`
}

func (c Client) TableName() string {
return "clients"
}

type User struct {
ID int `db:"id"`
UserName string `db:"user_name"`
Expand Down
2 changes: 1 addition & 1 deletion test.sh
Expand Up @@ -38,7 +38,7 @@ function cleanup {
trap cleanup EXIT

docker-compose up -d
sleep 4 # Ensure mysql is online
sleep 5 # Ensure mysql is online

go build -v -tags sqlite -o tsoda ./soda

Expand Down
1 change: 1 addition & 0 deletions testdata/migrations/20200621140800_clients.down.fizz
@@ -0,0 +1 @@
drop_table("clients")
4 changes: 4 additions & 0 deletions testdata/migrations/20200621140800_clients.up.fizz
@@ -0,0 +1,4 @@
create_table("clients") {
t.Column("id", "string", {"length": 32, "primary": true})
t.DisableTimestamps()
}

0 comments on commit faf98ab

Please sign in to comment.