Skip to content

Commit

Permalink
Refactor FirstOrCreate, FirstOrInit
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Apr 26, 2022
1 parent 0211ac9 commit 6a6dfda
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 16 deletions.
24 changes: 12 additions & 12 deletions finisher_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})

if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 {
if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignInterfacesToValue(where.Exprs)
Expand All @@ -312,25 +312,26 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {

// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions)
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{
tx = db.getInstance()
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
if tx = queryTx.Find(dest, conds...); tx.Error == nil {
if tx.RowsAffected == 0 {
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if result := queryTx.Find(dest, conds...); result.Error == nil {
if result.RowsAffected == 0 {
if c, ok := result.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignInterfacesToValue(where.Exprs)
result.assignInterfacesToValue(where.Exprs)
}
}

// initialize with attrs, conds
if len(tx.Statement.attrs) > 0 {
tx.assignInterfacesToValue(tx.Statement.attrs...)
if len(db.Statement.attrs) > 0 {
result.assignInterfacesToValue(db.Statement.attrs...)
}

// initialize with attrs, conds
if len(tx.Statement.assigns) > 0 {
tx.assignInterfacesToValue(tx.Statement.assigns...)
if len(db.Statement.assigns) > 0 {
result.assignInterfacesToValue(db.Statement.assigns...)
}

return tx.Create(dest)
Expand All @@ -351,8 +352,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {

return tx.Model(dest).Updates(assigns)
} else {
// can not use Find RowsAffected
tx.RowsAffected = 0
tx.Error = result.Error
}
}
return tx
Expand Down
7 changes: 3 additions & 4 deletions tests/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@ require (
github.com/google/uuid v1.3.0
github.com/jinzhu/now v1.1.5
github.com/lib/pq v1.10.5
github.com/mattn/go-sqlite3 v1.14.12 // indirect
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect
gorm.io/driver/mysql v1.3.3
gorm.io/driver/postgres v1.3.4
gorm.io/driver/sqlite v1.3.1
gorm.io/driver/postgres v1.3.5
gorm.io/driver/sqlite v1.3.2
gorm.io/driver/sqlserver v1.3.2
gorm.io/gorm v1.23.3
gorm.io/gorm v1.23.4
)

replace gorm.io/gorm => ../

0 comments on commit 6a6dfda

Please sign in to comment.