Skip to content

Commit

Permalink
fix: return ID on create (#590)
Browse files Browse the repository at this point in the history
* fix: return ID with CRDB and Postgres on create

* fix: remove IDField in sqlite create as well

* ci: trigger

* fix: string ID type

* fix: use dynamic ID field in Model.whereNamedID
  • Loading branch information
zepatrik committed Sep 19, 2020
1 parent 9d9abfa commit b1085ba
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 37 deletions.
14 changes: 6 additions & 8 deletions dialect_cockroach.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,30 +68,28 @@ func (p *cockroach) Create(s store, model *Model, cols columns.Columns) error {
keyType := model.PrimaryKeyType()
switch keyType {
case "int", "int64":
cols.Remove("id")
id := struct {
ID int `db:"id"`
}{}
cols.Remove(model.IDField())
w := cols.Writeable()
var query string
if len(w.Cols) > 0 {
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning id", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString())
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning %s", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString(), model.IDField())
} else {
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning id", p.Quote(model.TableName()))
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning %s", p.Quote(model.TableName()), model.IDField())
}
log(logging.SQL, query)
stmt, err := s.PrepareNamed(query)
if err != nil {
return err
}
err = stmt.Get(&id, model.Value)
id := map[string]interface{}{}
err = stmt.QueryRow(model.Value).MapScan(id)
if err != nil {
if err := stmt.Close(); err != nil {
return errors.WithMessage(err, "failed to close statement")
}
return err
}
model.setID(id.ID)
model.setID(id[model.IDField()])
return errors.WithMessage(stmt.Close(), "failed to close statement")
}
return genericCreate(s, model, cols, p)
Expand Down
3 changes: 2 additions & 1 deletion dialect_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ func genericCreate(s store, model *Model, cols columns.Columns, quoter quotable)
switch keyType {
case "int", "int64":
var id int64
cols.Remove(model.IDField())
w := cols.Writeable()
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quoter.Quote(model.TableName()), w.QuotedString(quoter), w.SymbolizedString())
log(logging.SQL, query)
Expand Down Expand Up @@ -76,7 +77,7 @@ func genericCreate(s store, model *Model, cols columns.Columns, quoter quotable)
return fmt.Errorf("missing ID value")
}
w := cols.Writeable()
w.Add("id")
w.Add(model.IDField())
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quoter.Quote(model.TableName()), w.QuotedString(quoter), w.SymbolizedString())
log(logging.SQL, query)
stmt, err := s.PrepareNamed(query)
Expand Down
14 changes: 6 additions & 8 deletions dialect_postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,30 +55,28 @@ func (p *postgresql) Create(s store, model *Model, cols columns.Columns) error {
keyType := model.PrimaryKeyType()
switch keyType {
case "int", "int64":
cols.Remove("id")
id := struct {
ID int `db:"id"`
}{}
cols.Remove(model.IDField())
w := cols.Writeable()
var query string
if len(w.Cols) > 0 {
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning id", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString())
query = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s) returning %s", p.Quote(model.TableName()), w.QuotedString(p), w.SymbolizedString(), model.IDField())
} else {
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning id", p.Quote(model.TableName()))
query = fmt.Sprintf("INSERT INTO %s DEFAULT VALUES returning %s", p.Quote(model.TableName()), model.IDField())
}
log(logging.SQL, query)
stmt, err := s.PrepareNamed(query)
if err != nil {
return err
}
err = stmt.Get(&id, model.Value)
id := map[string]interface{}{}
err = stmt.QueryRow(model.Value).MapScan(id)
if err != nil {
if err := stmt.Close(); err != nil {
return errors.WithMessage(err, "failed to close statement")
}
return err
}
model.setID(id.ID)
model.setID(id[model.IDField()])
return errors.WithMessage(stmt.Close(), "failed to close statement")
}
return genericCreate(s, model, cols, p)
Expand Down
1 change: 1 addition & 0 deletions dialect_sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func (m *sqlite) Create(s store, model *Model, cols columns.Columns) error {
switch keyType {
case "int", "int64":
var id int64
cols.Remove(model.IDField())
w := cols.Writeable()
var query string
if len(w.Cols) > 0 {
Expand Down
46 changes: 46 additions & 0 deletions executors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,52 @@ func Test_Create_With_Slice(t *testing.T) {
})
}

func Test_Create_With_Non_ID_PK(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
}
transaction(func(tx *Connection) {
r := require.New(t)

count, _ := tx.Count(&CrookedColour{})
djs := []CrookedColour{
{Name: "Phil Slabber"},
{Name: "Leon Debaughn"},
{Name: "Liam Merrett-Park"},
}
err := tx.Create(&djs)
r.NoError(err)

ctx, _ := tx.Count(&CrookedColour{})
r.Equal(count+3, ctx)
r.NotEqual(djs[0].ID, djs[1].ID)
r.NotEqual(djs[1].ID, djs[2].ID)
})
}

func Test_Create_With_Non_ID_PK_String(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
}
transaction(func(tx *Connection) {
r := require.New(t)

count, _ := tx.Count(&CrookedSong{})
djs := []CrookedSong{
{ID: "Flow"},
{ID: "Do It Like You"},
{ID: "I C Light"},
}
err := tx.Create(&djs)
r.NoError(err)

ctx, _ := tx.Count(&CrookedSong{})
r.Equal(count+3, ctx)
r.NotEqual(djs[0].ID, djs[1].ID)
r.NotEqual(djs[1].ID, djs[2].ID)
})
}

func Test_Eager_Create_Has_Many(t *testing.T) {
if PDB == nil {
t.Skip("skipping integration tests")
Expand Down
2 changes: 1 addition & 1 deletion model.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func (m *Model) whereID() string {
}

func (m *Model) whereNamedID() string {
return fmt.Sprintf("%s.%s = :id", m.TableName(), m.IDField())
return fmt.Sprintf("%s.%s = :%s", m.TableName(), m.IDField(), m.IDField())
}

func (m *Model) isSlice() bool {
Expand Down
36 changes: 17 additions & 19 deletions packrd/packed-packr.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions pop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,16 @@ type Parent struct {
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
Students []*Student `many_to_many:"parents_students"`
}

type CrookedColour struct {
ID int `db:"pk"`
Name string `db:"name"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
}

type CrookedSong struct {
ID string `db:"name"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
}
3 changes: 3 additions & 0 deletions testdata/migrations/20200914115538_crooked_colours.down.fizz
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
drop_table("crooked_colours")

drop_table("crooked_songs")
8 changes: 8 additions & 0 deletions testdata/migrations/20200914115538_crooked_colours.up.fizz
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
create_table("crooked_colours") {
t.Column("pk", "int", { primary: true })
t.Column("name", "string", {})
}

create_table("crooked_songs") {
t.Column("name", "string", { primary: true })
}

0 comments on commit b1085ba

Please sign in to comment.