Skip to content
This repository has been archived by the owner on Jun 14, 2019. It is now read-only.

Commit

Permalink
Add insert select support (#39)
Browse files Browse the repository at this point in the history
* add insert select support

* refactor insert select

* improve sort

* update test

* fix

* hide fiddle sql tests

* update README

* update README
  • Loading branch information
lunny committed Sep 28, 2018
1 parent 377feed commit 395bcf3
Show file tree
Hide file tree
Showing 13 changed files with 148 additions and 62 deletions.
6 changes: 6 additions & 0 deletions README.md
Expand Up @@ -13,6 +13,12 @@ Make sure you have installed Go 1.8+ and then:

```Go
sql, args, err := builder.Insert(Eq{"c": 1, "d": 2}).Into("table1").ToSQL()

// INSERT INTO table1 SELECT * FROM table2
sql, err := builder.Insert().Into("table1").Select().From("table2").ToBoundSQL()

// INSERT INTO table1 (a, b) SELECT b, c FROM table2
sql, err = builder.Insert("a, b").Into("table1").Select("b, c").From("table2").ToBoundSQL()
```

# Select
Expand Down
60 changes: 50 additions & 10 deletions builder.go
Expand Up @@ -7,6 +7,7 @@ package builder
import (
sql2 "database/sql"
"fmt"
"sort"
)

type optype byte
Expand Down Expand Up @@ -49,14 +50,16 @@ type Builder struct {
optype
dialect string
isNested bool
tableName string
into string
from string
subQuery *Builder
cond Cond
selects []string
joins []join
unions []union
limitation *limit
inserts Eq
insertCols []string
insertVals []interface{}
updates []Eq
orderBy string
groupBy string
Expand Down Expand Up @@ -111,15 +114,15 @@ func (b *Builder) From(subject interface{}, alias ...string) *Builder {
b.subQuery = subject.(*Builder)

if len(alias) > 0 {
b.tableName = alias[0]
b.from = alias[0]
} else {
b.isNested = true
}
case string:
b.tableName = subject.(string)
b.from = subject.(string)

if len(alias) > 0 {
b.tableName = b.tableName + " " + alias[0]
b.from = b.from + " " + alias[0]
}
}

Expand All @@ -128,12 +131,15 @@ func (b *Builder) From(subject interface{}, alias ...string) *Builder {

// TableName returns the table name
func (b *Builder) TableName() string {
return b.tableName
if b.optype == insertType {
return b.into
}
return b.from
}

// Into sets insert table name
func (b *Builder) Into(tableName string) *Builder {
b.tableName = tableName
b.into = tableName
return b
}

Expand Down Expand Up @@ -221,7 +227,9 @@ func (b *Builder) FullJoin(joinTable string, joinCond interface{}) *Builder {
// Select sets select SQL
func (b *Builder) Select(cols ...string) *Builder {
b.selects = cols
b.optype = selectType
if b.optype == condType {
b.optype = selectType
}
return b
}

Expand All @@ -238,8 +246,40 @@ func (b *Builder) Or(cond Cond) *Builder {
}

// Insert sets insert SQL
func (b *Builder) Insert(eq Eq) *Builder {
b.inserts = eq
func (b *Builder) Insert(eq ...interface{}) *Builder {
if len(eq) > 0 {
var paramType = -1
for _, e := range eq {
switch t := e.(type) {
case Eq:
if paramType == -1 {
paramType = 0
}
if paramType != 0 {
break
}
for k, v := range t {
b.insertCols = append(b.insertCols, k)
b.insertVals = append(b.insertVals, v)
}
case string:
if paramType == -1 {
paramType = 1
}
if paramType != 1 {
break
}
b.insertCols = append(b.insertCols, t)
}
}
}

if len(b.insertCols) == len(b.insertVals) {
sort.Slice(b.insertVals, func(i, j int) bool {
return b.insertCols[i] < b.insertCols[j]
})
sort.Strings(b.insertCols)
}
b.optype = insertType
return b
}
Expand Down
4 changes: 2 additions & 2 deletions builder_delete.go
Expand Up @@ -15,11 +15,11 @@ func Delete(conds ...Cond) *Builder {
}

func (b *Builder) deleteWriteTo(w Writer) error {
if len(b.tableName) <= 0 {
if len(b.from) <= 0 {
return ErrNoTableName
}

if _, err := fmt.Fprintf(w, "DELETE FROM %s WHERE ", b.tableName); err != nil {
if _, err := fmt.Fprintf(w, "DELETE FROM %s WHERE ", b.from); err != nil {
return err
}

Expand Down
38 changes: 28 additions & 10 deletions builder_insert.go
Expand Up @@ -10,30 +10,49 @@ import (
)

// Insert creates an insert Builder
func Insert(eq Eq) *Builder {
func Insert(eq ...interface{}) *Builder {
builder := &Builder{cond: NewCond()}
return builder.Insert(eq)
return builder.Insert(eq...)
}

func (b *Builder) insertSelectWriteTo(w Writer) error {
if _, err := fmt.Fprintf(w, "INSERT INTO %s ", b.into); err != nil {
return err
}

if len(b.insertCols) > 0 {
fmt.Fprintf(w, "(")
for _, col := range b.insertCols {
fmt.Fprintf(w, col)
}
fmt.Fprintf(w, ") ")
}

return b.selectWriteTo(w)
}

func (b *Builder) insertWriteTo(w Writer) error {
if len(b.tableName) <= 0 {
if len(b.into) <= 0 {
return ErrNoTableName
}
if len(b.inserts) <= 0 {
if len(b.insertCols) <= 0 && b.from == "" {
return ErrNoColumnToInsert
}

if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.tableName); err != nil {
if b.into != "" && b.from != "" {
return b.insertSelectWriteTo(w)
}

if _, err := fmt.Fprintf(w, "INSERT INTO %s (", b.into); err != nil {
return err
}

var args = make([]interface{}, 0)
var bs []byte
var valBuffer = bytes.NewBuffer(bs)
var i = 0

for _, col := range b.inserts.sortedKeys() {
value := b.inserts[col]
for i, col := range b.insertCols {
value := b.insertVals[i]
fmt.Fprint(w, col)
if e, ok := value.(expr); ok {
fmt.Fprintf(valBuffer, "(%s)", e.sql)
Expand All @@ -43,15 +62,14 @@ func (b *Builder) insertWriteTo(w Writer) error {
args = append(args, value)
}

if i != len(b.inserts)-1 {
if i != len(b.insertCols)-1 {
if _, err := fmt.Fprint(w, ","); err != nil {
return err
}
if _, err := fmt.Fprint(valBuffer, ","); err != nil {
return err
}
}
i = i + 1
}

if _, err := fmt.Fprint(w, ") Values ("); err != nil {
Expand Down
41 changes: 41 additions & 0 deletions builder_insert_test.go
@@ -0,0 +1,41 @@
// Copyright 2018 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package builder

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestBuilderInsert(t *testing.T) {
sql, err := Insert(Eq{"c": 1, "d": 2}).Into("table1").ToBoundSQL()
assert.NoError(t, err)
assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,2)", sql)

sql, err = Insert(Eq{"c": 1, "d": Expr("SELECT b FROM t WHERE d=? LIMIT 1", 2)}).Into("table1").ToBoundSQL()
assert.NoError(t, err)
assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,(SELECT b FROM t WHERE d=2 LIMIT 1))", sql)

sql, err = Insert(Eq{"c": 1, "d": 2}).ToBoundSQL()
assert.Error(t, err)
assert.EqualValues(t, ErrNoTableName, err)
assert.EqualValues(t, "", sql)

sql, err = Insert(Eq{}).Into("table1").ToBoundSQL()
assert.Error(t, err)
assert.EqualValues(t, ErrNoColumnToInsert, err)
assert.EqualValues(t, "", sql)
}

func TestBuidlerInsert_Select(t *testing.T) {
sql, err := Insert().Into("table1").Select().From("table2").ToBoundSQL()
assert.NoError(t, err)
assert.EqualValues(t, "INSERT INTO table1 SELECT * FROM table2", sql)

sql, err = Insert("a, b").Into("table1").Select("b, c").From("table2").ToBoundSQL()
assert.NoError(t, err)
assert.EqualValues(t, "INSERT INTO table1 (a, b) SELECT b, c FROM table2", sql)
}
4 changes: 3 additions & 1 deletion builder_limit.go
Expand Up @@ -56,7 +56,9 @@ func (b *Builder) limitWriteTo(w Writer) error {
case SQLITE, MYSQL, POSTGRES:
// if type UNION, we need to write previous content back to current writer
if b.optype == unionType {
b.WriteTo(ow)
if err := b.WriteTo(ow); err != nil {
return err
}
}

if limit.offset == 0 {
Expand Down
9 changes: 2 additions & 7 deletions builder_limit_test.go
Expand Up @@ -4,12 +4,7 @@

package builder

import (
"testing"

"github.com/stretchr/testify/assert"
)

/*
func TestBuilder_Limit4Mssql(t *testing.T) {
sqlFromFile, err := readPreparationSQLFromFile("testdata/mssql_fiddle_data.sql")
assert.NoError(t, err)
Expand Down Expand Up @@ -126,4 +121,4 @@ func TestBuilder_Limit4Oracle(t *testing.T) {
assert.NoError(t, err)
assert.EqualValues(t, "SELECT a,b,c FROM (SELECT a,b,c,ROWNUM RN FROM ((SELECT a,b,c FROM (SELECT * FROM (SELECT a,b,c,ROWNUM RN FROM table1 WHERE a<>'0' ORDER BY a ASC) at WHERE at.RN<=15) att WHERE att.RN>10) UNION ALL (SELECT a,b,c FROM (SELECT a,b,c,ROWNUM RN FROM table1 WHERE b<>'48' ORDER BY a DESC) at WHERE at.RN<=10)) at) at WHERE at.RN<=3", sql)
assert.NoError(t, f.executableCheck(sql))
}
}*/
10 changes: 5 additions & 5 deletions builder_select.go
Expand Up @@ -15,7 +15,7 @@ func Select(cols ...string) *Builder {
}

func (b *Builder) selectWriteTo(w Writer) error {
if len(b.tableName) <= 0 && !b.isNested {
if len(b.from) <= 0 && !b.isNested {
return ErrNoTableName
}

Expand Down Expand Up @@ -46,11 +46,11 @@ func (b *Builder) selectWriteTo(w Writer) error {
}

if b.subQuery == nil {
if _, err := fmt.Fprint(w, " FROM ", b.tableName); err != nil {
if _, err := fmt.Fprint(w, " FROM ", b.from); err != nil {
return err
}
} else {
if b.cond.IsValid() && len(b.tableName) <= 0 {
if b.cond.IsValid() && len(b.from) <= 0 {
return ErrUnnamedDerivedTable
}
if b.subQuery.dialect != "" && b.dialect != b.subQuery.dialect {
Expand All @@ -69,10 +69,10 @@ func (b *Builder) selectWriteTo(w Writer) error {
return err
}

if len(b.tableName) == 0 {
if len(b.from) == 0 {
fmt.Fprintf(w, ")")
} else {
fmt.Fprintf(w, ") %v", b.tableName)
fmt.Fprintf(w, ") %v", b.from)
}
default:
return ErrUnexpectedSubQuery
Expand Down
9 changes: 5 additions & 4 deletions builder_select_test.go
Expand Up @@ -15,6 +15,7 @@ func TestBuilder_Select(t *testing.T) {
sql, args, err := Select("c, d").From("table1").ToSQL()
assert.NoError(t, err)
assert.EqualValues(t, "SELECT c, d FROM table1", sql)
assert.EqualValues(t, []interface{}(nil), args)

sql, args, err = Select("c, d").From("table1").Where(Eq{"a": 1}).ToSQL()
assert.NoError(t, err)
Expand Down Expand Up @@ -104,24 +105,24 @@ func TestBuilder_From(t *testing.T) {
assert.EqualValues(t, []interface{}{1, 2, 1}, args)

// from union without alias
sql, args, err = Select("sub.id").From(
_, _, err = Select("sub.id").From(
Select("id").From("table1").Where(Eq{"a": 1}).Union(
"all", Select("id").From("table1").Where(Eq{"a": 2}))).Where(Eq{"b": 1}).ToSQL()
assert.Error(t, err)
assert.EqualValues(t, ErrUnnamedDerivedTable, err)

// will raise error
sql, args, err = Select("c").From(Insert(Eq{"a": 1}).From("table1"), "table1").ToSQL()
_, _, err = Select("c").From(Insert(Eq{"a": 1}).From("table1"), "table1").ToSQL()
assert.Error(t, err)
assert.EqualValues(t, ErrUnexpectedSubQuery, err)

// will raise error
sql, args, err = Select("c").From(Delete(Eq{"a": 1}).From("table1"), "table1").ToSQL()
_, _, err = Select("c").From(Delete(Eq{"a": 1}).From("table1"), "table1").ToSQL()
assert.Error(t, err)
assert.EqualValues(t, ErrUnexpectedSubQuery, err)

// from a sub-query in different dialect
sql, args, err = MySQL().Select("sub.id").From(
_, _, err = MySQL().Select("sub.id").From(
Oracle().Select("id").From("table1").Where(Eq{"a": 1}), "sub").Where(Eq{"b": 1}).ToSQL()
assert.Error(t, err)
assert.EqualValues(t, ErrInconsistentDialect, err)
Expand Down
18 changes: 0 additions & 18 deletions builder_test.go
Expand Up @@ -596,24 +596,6 @@ func TestBuilderCond(t *testing.T) {
}
}

func TestBuilderInsert(t *testing.T) {
sql, err := Insert(Eq{"c": 1, "d": 2}).Into("table1").ToBoundSQL()
assert.NoError(t, err)
assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,2)", sql)

sql, err = Insert(Eq{"c": 1, "d": Expr("SELECT b FROM t WHERE d=? LIMIT 1", 2)}).Into("table1").ToBoundSQL()
assert.NoError(t, err)
assert.EqualValues(t, "INSERT INTO table1 (c,d) Values (1,(SELECT b FROM t WHERE d=2 LIMIT 1))", sql)

sql, err = Insert(Eq{"c": 1, "d": 2}).ToBoundSQL()
assert.Error(t, err)
assert.EqualValues(t, ErrNoTableName, err)

sql, err = Insert(Eq{}).Into("table1").ToBoundSQL()
assert.Error(t, err)
assert.EqualValues(t, ErrNoColumnToInsert, err)
}

func TestSubquery(t *testing.T) {
subb := Select("id").From("table_b").Where(Eq{"b": "a"})
b := Select("a, b").From("table_a").Where(
Expand Down

0 comments on commit 395bcf3

Please sign in to comment.