Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: support save for Oracle #3364

Merged
merged 13 commits into from Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
118 changes: 114 additions & 4 deletions contrib/drivers/oracle/oracle_do_insert.go
Expand Up @@ -10,6 +10,8 @@ import (
"context"
"database/sql"
"fmt"
"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/text/gstr"
"strings"

"github.com/gogf/gf/v2/database/gdb"
Expand All @@ -24,10 +26,7 @@ func (d *Driver) DoInsert(
) (result sql.Result, err error) {
switch option.InsertOption {
case gdb.InsertOptionSave:
return nil, gerror.NewCode(
gcode.CodeNotSupported,
`Save operation is not supported by oracle driver`,
)
return d.doSave(ctx, link, table, list, option)

case gdb.InsertOptionReplace:
return nil, gerror.NewCode(
Expand Down Expand Up @@ -93,3 +92,114 @@ func (d *Driver) DoInsert(
}
return batchResult, nil
}

// doSave support upsert for Oracle
func (d *Driver) doSave(ctx context.Context,
link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption,
) (result sql.Result, err error) {
if len(option.OnConflict) == 0 {
return nil, gerror.NewCode(
gcode.CodeMissingParameter, `Please specify conflict columns`,
)
}

if len(list) == 0 {
return nil, gerror.NewCode(
gcode.CodeInvalidRequest, `Save operation list is empty by oracle driver`,
)
}

var (
one = list[0]
charL, charR = d.GetChars()
valueCharL, valueCharR = "'", "'"
oldme-git marked this conversation as resolved.
Show resolved Hide resolved

conflictKeys = option.OnConflict
conflictKeySet = gset.New(false)

// insertKeys: Handle valid keys that need to be inserted
// insertValues: Handle values that need to be inserted
// updateValues: Handle values that need to be updated
// queryValues: Handle data that need to be upsert
queryValues, insertKeys, insertValues, updateValues []string
)

// conflictKeys slice type conv to set type
for _, conflictKey := range conflictKeys {
conflictKeySet.Add(gstr.ToUpper(conflictKey))
}

for key, value := range one {
saveValue := gconv.String(value)
queryValues = append(
queryValues,
fmt.Sprintf(
valueCharL+"%s"+valueCharR+" AS "+charL+"%s"+charR,
saveValue, key,
),
)

insertKeys = append(insertKeys, charL+key+charR)
insertValues = append(insertValues, "T2."+charL+key+charR)

// filter conflict keys in updateValues
if !conflictKeySet.Contains(key) {
updateValues = append(
updateValues,
fmt.Sprintf(`T1.%s = T2.%s`, charL+key+charR, charL+key+charR),
)
}
}

batchResult := new(gdb.SqlResult)
sqlStr := parseSqlForUpsert(table, queryValues, insertKeys, insertValues, updateValues, conflictKeys)
r, err := d.DoExec(ctx, link, sqlStr)
if err != nil {
return r, err
}
if n, err := r.RowsAffected(); err != nil {
return r, err
} else {
batchResult.Result = r
batchResult.Affected += n
}
return batchResult, nil
}

// parseSqlForUpsert
// MERGE INTO {{table}} T1
// USING ( SELECT {{queryValues}} FROM DUAL T2
// ON (T1.{{duplicateKey}} = T2.{{duplicateKey}} AND ...)
// WHEN NOT MATCHED THEN
// INSERT {{insertKeys}} VALUES {{insertValues}}
// WHEN MATCHED THEN
// UPDATE SET {{updateValues}}
func parseSqlForUpsert(table string,
queryValues, insertKeys, insertValues, updateValues, duplicateKey []string,
) (sqlStr string) {
var (
queryValueStr = strings.Join(queryValues, ",")
insertKeyStr = strings.Join(insertKeys, ",")
insertValueStr = strings.Join(insertValues, ",")
updateValueStr = strings.Join(updateValues, ",")
duplicateKeyStr string
pattern = gstr.Trim(`MERGE INTO %s T1 USING (SELECT %s FROM DUAL) T2 ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s`)
)

for index, keys := range duplicateKey {
if index != 0 {
duplicateKeyStr += " AND "
}
duplicateTmp := fmt.Sprintf("T1.%s = T2.%s", keys, keys)
duplicateKeyStr += duplicateTmp
}

return fmt.Sprintf(pattern,
table,
queryValueStr,
duplicateKeyStr,
insertKeyStr,
insertValueStr,
updateValueStr,
)
}
6 changes: 3 additions & 3 deletions contrib/drivers/oracle/oracle_z_unit_basic_test.go
Expand Up @@ -19,7 +19,7 @@ import (
"github.com/gogf/gf/v2/test/gtest"
)

func TestTables(t *testing.T) {
func Test_Tables(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
tables := []string{"t_user1", "pop", "haha"}

Expand Down Expand Up @@ -60,7 +60,7 @@ func TestTables(t *testing.T) {
})
}

func TestTableFields(t *testing.T) {
func Test_Table_Fields(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
createTable("t_user")
defer dropTable("t_user")
Expand Down Expand Up @@ -107,7 +107,7 @@ func TestTableFields(t *testing.T) {
})
}

func TestDoInsert(t *testing.T) {
func Test_Do_Insert(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
createTable("t_user")
defer dropTable("t_user")
Expand Down
80 changes: 78 additions & 2 deletions contrib/drivers/oracle/oracle_z_unit_model_test.go
Expand Up @@ -128,7 +128,7 @@ func Test_Model_RightJoin(t *testing.T) {
})
}

func TestPage(t *testing.T) {
func Test_Page(t *testing.T) {
table := createInitTable()
defer dropTable(table)
result, err := db.Model(table).Page(1, 2).Order("ID").All()
Expand Down Expand Up @@ -162,7 +162,6 @@ func TestPage(t *testing.T) {
func Test_Model_Insert(t *testing.T) {
table := createTable()
defer dropTable(table)
// db.SetDebug(true)
gtest.C(t, func(t *gtest.T) {
user := db.Model(table)
result, err := user.Data(g.Map{
Expand Down Expand Up @@ -1101,6 +1100,83 @@ func Test_Model_WhereOrNotLike(t *testing.T) {
})
}

func Test_Model_Save(t *testing.T) {
table := createTable("test")
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
type User struct {
Id int
Passport string
Password string
NickName string
CreateTime *gtime.Time
}
var (
user User
count int
result sql.Result
createTime = gtime.Now().Format("Y-m-d")
err error
)

result, err = db.Model(table).Data(g.Map{
"id": 1,
"passport": "p1",
"password": "15d55ad283aa400af464c76d713c07ad",
"nickname": "n1",
"create_time": createTime,
}).OnConflict("id").Save()

t.AssertNil(err)
n, _ := result.RowsAffected()
t.Assert(n, 1)

err = db.Model(table).Scan(&user)
t.AssertNil(err)
t.Assert(user.Id, 1)
t.Assert(user.Passport, "p1")
t.Assert(user.Password, "15d55ad283aa400af464c76d713c07ad")
t.Assert(user.NickName, "n1")
t.Assert(user.CreateTime.Format("Y-m-d"), createTime)

_, err = db.Model(table).Data(g.Map{
"id": 1,
"passport": "p1",
"password": "25d55ad283aa400af464c76d713c07ad",
"nickname": "n2",
"create_time": createTime,
}).OnConflict("id").Save()
t.AssertNil(err)

err = db.Model(table).Scan(&user)
t.AssertNil(err)
t.Assert(user.Passport, "p1")
t.Assert(user.Password, "25d55ad283aa400af464c76d713c07ad")
t.Assert(user.NickName, "n2")
t.Assert(user.CreateTime.Format("Y-m-d"), createTime)

count, err = db.Model(table).Count()
t.AssertNil(err)
t.Assert(count, 1)
})
}

func Test_Model_Replace(t *testing.T) {
table := createTable()
defer dropTable(table)

gtest.C(t, func(t *gtest.T) {
_, err := db.Model(table).Data(g.Map{
"id": 1,
"passport": "t11",
"password": "25d55ad283aa400af464c76d713c07ad",
"nickname": "T11",
"create_time": "2018-10-24 10:00:00",
}).Replace()
t.Assert(err, "Replace operation is not supported by oracle driver")
})
}

/* not support the "AS"
func Test_Model_Raw(t *testing.T) {
table := createInitTable()
Expand Down
5 changes: 4 additions & 1 deletion contrib/drivers/pgsql/pgsql_format_upsert.go
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"

"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
Expand All @@ -19,7 +20,9 @@ import (
// For example: ON CONFLICT (id) DO UPDATE SET ...
func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInsertOption) (string, error) {
if len(option.OnConflict) == 0 {
return "", gerror.New("Please specify conflict columns")
return "", gerror.NewCode(
gcode.CodeMissingParameter, `Please specify conflict columns`,
)
}

var onDuplicateStr string
Expand Down
5 changes: 4 additions & 1 deletion contrib/drivers/sqlite/sqlite_format_upsert.go
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"

"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
Expand All @@ -19,7 +20,9 @@ import (
// For example: ON CONFLICT (id) DO UPDATE SET ...
func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInsertOption) (string, error) {
if len(option.OnConflict) == 0 {
return "", gerror.New("Please specify conflict columns")
return "", gerror.NewCode(
gcode.CodeMissingParameter, `Please specify conflict columns`,
)
}

var onDuplicateStr string
Expand Down
1 change: 1 addition & 0 deletions database/gdb/gdb_core_underlying.go
Expand Up @@ -396,6 +396,7 @@ func (c *Core) FormatUpsert(columns []string, list List, option DoInsertOption)
)
}
}

return InsertOnDuplicateKeyUpdate + " " + onDuplicateStr, nil
}

Expand Down