Skip to content

Commit

Permalink
Merge pull request #8 from embrace-io/upsert
Browse files Browse the repository at this point in the history
Upserts
  • Loading branch information
juansc committed Dec 6, 2018
2 parents ba58e68 + de70906 commit 4fe3c4a
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 1 deletion.
22 changes: 21 additions & 1 deletion dbr_test.go
Expand Up @@ -7,8 +7,8 @@ import (
"testing"
"time"

_ "github.com/go-sql-driver/mysql"
"github.com/embrace-io/dbr/dialect"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -84,6 +84,8 @@ func reset(t *testing.T, sess *Session) {
time_val timestamp NULL ,
bool_val bool NULL
)`, autoIncrementType),
`DROP TABLE IF EXISTS dbr_keys`,
`CREATE TABLE dbr_keys (key_value varchar(255) PRIMARY KEY, val_value varchar(255))`,
} {
_, err := sess.Exec(v)
require.NoError(t, err)
Expand Down Expand Up @@ -204,3 +206,21 @@ func TestTimeout(t *testing.T) {
require.Equal(t, context.DeadlineExceeded, err)
}
}

func TestOnConflict(t *testing.T) {
for _, sess := range testSession {
if sess.Dialect == dialect.SQLite3 || sess.Dialect == dialect.Clickhouse {
continue
}
for i := 0; i < 2; i++ {
b := sess.InsertInto("dbr_keys").Columns("key_value", "val_value").Values("key", "value")
b.OnConflict("dbr_keys_pkey").Action("val_value", Expr("CONCAT(?, 2)", Proposed("val_value")))
_, err := b.Exec()
require.NoError(t, err)
}
var value string
_, err := sess.SelectBySql("SELECT val_value FROM dbr_keys WHERE key_value=?", "key").Load(&value)
require.NoError(t, err)
require.Equal(t, "value2", value)
}
}
3 changes: 3 additions & 0 deletions dialect.go
Expand Up @@ -14,6 +14,9 @@ type Dialect interface {

Placeholder(n int) string

OnConflict(constraint string) string
Proposed(column string) string

CombinedOffset() bool
SupportsOn() bool
}
8 changes: 8 additions & 0 deletions dialect/clickhouse.go
Expand Up @@ -23,3 +23,11 @@ func (d clickhouse) SupportsOn() bool {
func (d clickhouse) CombinedOffset() bool {
return true
}

func (d clickhouse) OnConflict(_ string) string {
return ""
}

func (d clickhouse) Proposed(_ string) string {
return ""
}
8 changes: 8 additions & 0 deletions dialect/mysql.go
Expand Up @@ -72,3 +72,11 @@ func (d mysql) SupportsOn() bool {
func (d mysql) CombinedOffset() bool {
return false
}

func (d mysql) OnConflict(_ string) string {
return "ON DUPLICATE KEY UPDATE"
}

func (d mysql) Proposed(column string) string {
return fmt.Sprintf("VALUES(%s)", d.QuoteIdent(column))
}
8 changes: 8 additions & 0 deletions dialect/postgresql.go
Expand Up @@ -43,3 +43,11 @@ func (d postgreSQL) SupportsOn() bool {
func (d postgreSQL) CombinedOffset() bool {
return false
}

func (d postgreSQL) OnConflict(constraint string) string {
return fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET", d.QuoteIdent(constraint))
}

func (d postgreSQL) Proposed(column string) string {
return fmt.Sprintf("EXCLUDED.%s", d.QuoteIdent(column))
}
8 changes: 8 additions & 0 deletions dialect/sqlite3.go
Expand Up @@ -46,3 +46,11 @@ func (d sqlite3) SupportsOn() bool {
func (d sqlite3) CombinedOffset() bool {
return false
}

func (d sqlite3) OnConflict(_ string) string {
return ""
}

func (d sqlite3) Proposed(_ string) string {
return ""
}
57 changes: 57 additions & 0 deletions insert.go
Expand Up @@ -3,10 +3,17 @@ package dbr
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
)

// ConflictStmt is ` ON CONFLICT ...` part of InsertStmt
type ConflictStmt struct {
constraint string
actions map[string]interface{}
}

// InsertStmt builds `INSERT INTO ...`.
type InsertStmt struct {
runner
Expand All @@ -20,6 +27,16 @@ type InsertStmt struct {
Value [][]interface{}
ReturnColumn []string
RecordID *int64

Conflict *ConflictStmt
}

// Proposed is reference to proposed value in on conflict clause
func Proposed(column string) Builder {
return BuildFunc(func(d Dialect, b Buffer) error {
_, err := b.WriteString(d.Proposed(column))
return err
})
}

type InsertBuilder = InsertStmt
Expand Down Expand Up @@ -64,6 +81,29 @@ func (b *InsertStmt) Build(d Dialect, buf Buffer) error {
buf.WriteValue(tuple...)
}

if b.Conflict != nil && len(b.Conflict.actions) > 0 {
keyword := d.OnConflict(b.Conflict.constraint)
if len(keyword) == 0 {
return fmt.Errorf("Dialect %s does not support OnConflict", d)
}
buf.WriteString(" ")
buf.WriteString(keyword)
buf.WriteString(" ")
needComma := false
for _, column := range b.Column {
if v, ok := b.Conflict.actions[column]; ok {
if needComma {
buf.WriteString(",")
}
buf.WriteString(d.QuoteIdent(column))
buf.WriteString("=")
buf.WriteString(placeholder)
buf.WriteValue(v)
needComma = true
}
}
}

if len(b.ReturnColumn) > 0 {
buf.WriteString(" RETURNING ")
for i, col := range b.ReturnColumn {
Expand Down Expand Up @@ -224,3 +264,20 @@ func (b *InsertStmt) LoadContext(ctx context.Context, value interface{}) error {
func (b *InsertStmt) Load(value interface{}) error {
return b.LoadContext(context.Background(), value)
}

// OnConflictMap allows to add actions for constraint violation, e.g UPSERT
func (b *InsertStmt) OnConflictMap(constraint string, actions map[string]interface{}) *InsertStmt {
b.Conflict = &ConflictStmt{constraint: constraint, actions: actions}
return b
}

// OnConflict creates an empty OnConflict section fo insert statement , e.g UPSERT
func (b *InsertStmt) OnConflict(constraint string) *ConflictStmt {
return b.OnConflictMap(constraint, make(map[string]interface{})).Conflict
}

// Action adds action for column which will do if conflict happens
func (b *ConflictStmt) Action(column string, action interface{}) *ConflictStmt {
b.actions[column] = action
return b
}

0 comments on commit 4fe3c4a

Please sign in to comment.