diff --git a/builder_mysql.go b/builder_mysql.go index 174d5f9..e4f79d9 100644 --- a/builder_mysql.go +++ b/builder_mysql.go @@ -90,6 +90,10 @@ func (b *MysqlBuilder) Upsert(table string, cols Params, constraints ...string) } q.sql += " ON DUPLICATE KEY UPDATE " + strings.Join(lines, ", ") + s, p := b.db.processSQL(q.sql) + q.placeholders = p + q.rawSQL = s + q.Bind(q.params) return q } diff --git a/model_query.go b/model_query.go index 97bb2d9..60173dd 100644 --- a/model_query.go +++ b/model_query.go @@ -172,3 +172,27 @@ func (q *ModelQuery) Delete() error { _, err := q.builder.Delete(q.model.tableName, HashExp(pk)).WithContext(q.ctx).Execute() return err } + +// Upsert creates a Query that represents an UPSERT SQL statement. +// Upsert inserts a row into the table if the primary key or unique index is not found. +// Otherwise it will update the row with the new values. +// The keys of cols are the column names, while the values of cols are the corresponding column +// values to be inserted. +func (q *ModelQuery) Upsert(attrs ...string) error { + if q.lastError != nil { + return q.lastError + } + pk := q.model.pk() + if len(pk) == 0 { + return MissingPKError + } + var pks []string + + cols := q.model.columns(attrs, q.exclude) + for name := range pk { + cols[name] = pk[name] + pks = append(pks, name) + } + _, err := q.builder.Upsert(q.model.tableName, Params(cols), pks...).WithContext(q.ctx).Execute() + return err +} diff --git a/model_query_test.go b/model_query_test.go index 039c592..d5d933b 100644 --- a/model_query_test.go +++ b/model_query_test.go @@ -236,3 +236,105 @@ func TestModelQuery_Delete(t *testing.T) { var a int assert.NotNil(t, db.Model(&a).Delete()) } + +func TestModelQuery_Upsert(t *testing.T) { + db := getPreparedDB() + defer db.Close() + + id := 2 + name := "test" + email := "test@example.com" + { + // updating normally + customer := Customer{ + ID: id, + Name: name, + Email: email, + } + err := db.Model(&customer).Upsert() + if assert.Nil(t, err) { + var c Customer + db.Select().From("customer").Where(HashExp{"ID": id}).One(&c) + assert.Equal(t, name, c.Name) + assert.Equal(t, email, c.Email) + assert.Equal(t, 0, c.Status) + } + } + + { + // updating without primary keys + item2 := Item{ + Name: name, + } + err := db.Model(&item2).Upsert() + assert.Equal(t, MissingPKError, err) + } + + { + // updating all fields + customer := CustomerPtr{ + ID: &id, + Name: name, + Email: &email, + } + err := db.Model(&customer).Upsert() + if assert.Nil(t, err) { + assert.Equal(t, id, *customer.ID) + var c CustomerPtr + db.Select().From("customer").Where(HashExp{"ID": id}).One(&c) + assert.Equal(t, name, c.Name) + if assert.NotNil(t, c.Email) { + assert.Equal(t, email, *c.Email) + } + assert.Nil(t, c.Status) + } + } + + { + // updating selected fields only + id = 3 + customer := CustomerPtr{ + ID: &id, + Name: name, + Email: &email, + } + err := db.Model(&customer).Upsert("Name", "Email") + if assert.Nil(t, err) { + assert.Equal(t, id, *customer.ID) + var c CustomerPtr + db.Select().From("customer").Where(HashExp{"ID": id}).One(&c) + assert.Equal(t, name, c.Name) + if assert.NotNil(t, c.Email) { + assert.Equal(t, email, *c.Email) + } + if assert.NotNil(t, c.Status) { + assert.Equal(t, 2, *c.Status) + } + } + } + + { + // inserting normally + customer := Customer{ + ID: 5, + Name: name, + Email: email, + } + err := db.Model(&customer).Upsert() + if assert.Nil(t, err) { + assert.Equal(t, 5, customer.ID) + var c Customer + db.Select().From("customer").Where(HashExp{"ID": 5}).One(&c) + assert.Equal(t, name, c.Name) + assert.Equal(t, email, c.Email) + assert.Equal(t, 0, c.Status) + assert.False(t, c.Address.Valid) + } + } + + { + // updating non-struct + var a int + assert.NotNil(t, db.Model(&a).Upsert()) + } +}