Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: psgql tx unsupport LastInsertId #2815

Merged
merged 3 commits into from Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 17 additions & 10 deletions contrib/drivers/pgsql/pgsql.go
Expand Up @@ -35,8 +35,8 @@ type Driver struct {

const (
internalPrimaryKeyInCtx gctx.StrKey = "primary_key"
defaultSchema = "public"
quoteChar = `"`
defaultSchema string = "public"
quoteChar string = `"`
)

func init() {
Expand Down Expand Up @@ -372,14 +372,22 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ...
)

// Transaction checks.
if link != nil && link.IsTransaction() {
isUseCoreDoExec = true
} else {
if link == nil {
if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil {
isUseCoreDoExec = true
// Firstly, check and retrieve transaction link from context.
link = tx
} else if link, err = d.MasterLink(); err != nil {
// Or else it creates one from master node.
return nil, err
}
} else if !link.IsTransaction() {
// If current link is not transaction link, it checks and retrieves transaction from context.
if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil {
link = tx
}
}

// Check if it is an insert operation with primary key.
if value := ctx.Value(internalPrimaryKeyInCtx); value != nil {
var ok bool
pkField, ok = value.(gdb.TableField)
Expand Down Expand Up @@ -408,8 +416,7 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ...
}

// Sql filtering.
// TODO: internal function formatSql
// sql, args = formatSql(sql, args)
sql, args = d.FormatSqlBeforeExecuting(sql, args)
sql, args, err = d.DoFilter(ctx, link, sql, args)
if err != nil {
return nil, err
Expand Down Expand Up @@ -442,10 +449,10 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ...
}

if out.Records[affected-1][primaryKey] != nil {
lastInsertId := out.Records[affected-1][primaryKey].Int()
lastInsertId := out.Records[affected-1][primaryKey].Int64()
return Result{
affected: int64(affected),
lastInsertId: int64(lastInsertId),
lastInsertId: lastInsertId,
}, nil
}
}
Expand Down
39 changes: 39 additions & 0 deletions contrib/drivers/pgsql/pgsql_z_test.go
Expand Up @@ -7,8 +7,10 @@
package pgsql_test

import (
"context"
"testing"

"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/test/gtest"
Expand Down Expand Up @@ -45,6 +47,43 @@ func Test_LastInsertId(t *testing.T) {
})
}

func Test_TxLastInsertId(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
tableName := createTable()
defer dropTable(tableName)
err := db.Transaction(context.TODO(), func(ctx context.Context, tx gdb.TX) error {
// user
res, err := tx.Model(tableName).Insert(g.List{
{"passport": "user1", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
{"passport": "user2", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
{"passport": "user3", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
})
t.Assert(err, nil)
lastInsertId, err := res.LastInsertId()
t.Assert(err, nil)
t.AssertEQ(lastInsertId, int64(3))
rowsAffected, err := res.RowsAffected()
t.Assert(err, nil)
t.AssertEQ(rowsAffected, int64(3))

res1, err := tx.Model(tableName).Insert(g.List{
{"passport": "user4", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
{"passport": "user5", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
})
t.Assert(err, nil)
lastInsertId1, err := res1.LastInsertId()
t.Assert(err, nil)
t.AssertEQ(lastInsertId1, int64(5))
rowsAffected1, err := res1.RowsAffected()
t.Assert(err, nil)
t.AssertEQ(rowsAffected1, int64(2))
return nil

})
t.Assert(err, nil)
})
}

func Test_Driver_DoFilter(t *testing.T) {
var (
ctx = gctx.New()
Expand Down
2 changes: 2 additions & 0 deletions database/gdb/gdb.go
Expand Up @@ -179,6 +179,8 @@ type DB interface {

// TX defines the interfaces for ORM transaction operations.
type TX interface {
Link

Ctx(ctx context.Context) TX
Raw(rawSql string, args ...interface{}) *Model
Model(tableNameQueryOrStruct ...interface{}) *Model
Expand Down
11 changes: 11 additions & 0 deletions database/gdb/gdb_core.go
Expand Up @@ -796,3 +796,14 @@ func (c *Core) isSoftCreatedFieldName(fieldName string) bool {
}
return false
}

// FormatSqlBeforeExecuting formats the sql string and its arguments before executing.
// The internal handleArguments function might be called twice during the SQL procedure,
// but do not worry about it, it's safe and efficient.
func (c *Core) FormatSqlBeforeExecuting(sql string, args []interface{}) (newSql string, newArgs []interface{}) {
// DO NOT do this as there may be multiple lines and comments in the sql.
// sql = gstr.Trim(sql)
// sql = gstr.Replace(sql, "\n", " ")
// sql, _ = gregex.ReplaceString(`\s{2,}`, ` `, sql)
return handleArguments(sql, args)
}
25 changes: 25 additions & 0 deletions database/gdb/gdb_core_transaction.go
Expand Up @@ -517,3 +517,28 @@ func (tx *TXCore) Update(table string, data interface{}, condition interface{},
func (tx *TXCore) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
return tx.Model(table).Ctx(tx.ctx).Where(condition, args...).Delete()
}

// QueryContext implements interface function Link.QueryContext.
func (tx *TXCore) QueryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) {
return tx.tx.QueryContext(ctx, sql, args...)
}

// ExecContext implements interface function Link.ExecContext.
func (tx *TXCore) ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) {
return tx.tx.ExecContext(ctx, sql, args...)
}

// PrepareContext implements interface function Link.PrepareContext.
func (tx *TXCore) PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error) {
return tx.tx.PrepareContext(ctx, sql)
}

// IsOnMaster implements interface function Link.IsOnMaster.
func (tx *TXCore) IsOnMaster() bool {
return true
}

// IsTransaction implements interface function Link.IsTransaction.
func (tx *TXCore) IsTransaction() bool {
return true
}
7 changes: 4 additions & 3 deletions database/gdb/gdb_core_underlying.go
Expand Up @@ -10,9 +10,10 @@ package gdb
import (
"context"
"database/sql"
"reflect"

"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
"reflect"

"github.com/gogf/gf/v2/util/gconv"

Expand Down Expand Up @@ -55,7 +56,7 @@ func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...inter
}

// Sql filtering.
sql, args = formatSql(sql, args)
sql, args = c.FormatSqlBeforeExecuting(sql, args)
sql, args, err = c.db.DoFilter(ctx, link, sql, args)
if err != nil {
return nil, err
Expand Down Expand Up @@ -116,7 +117,7 @@ func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interf
}

// SQL filtering.
sql, args = formatSql(sql, args)
sql, args = c.FormatSqlBeforeExecuting(sql, args)
sql, args, err = c.db.DoFilter(ctx, link, sql, args)
if err != nil {
return nil, err
Expand Down
11 changes: 0 additions & 11 deletions database/gdb/gdb_func.go
Expand Up @@ -373,17 +373,6 @@ func GetPrimaryKeyCondition(primary string, where ...interface{}) (newWhereCondi
return where
}

// formatSql formats the sql string and its arguments before executing.
// The internal handleArguments function might be called twice during the SQL procedure,
// but do not worry about it, it's safe and efficient.
func formatSql(sql string, args []interface{}) (newSql string, newArgs []interface{}) {
// DO NOT do this as there may be multiple lines and comments in the sql.
// sql = gstr.Trim(sql)
// sql = gstr.Replace(sql, "\n", " ")
// sql, _ = gregex.ReplaceString(`\s{2,}`, ` `, sql)
return handleArguments(sql, args)
}

type formatWhereHolderInput struct {
WhereHolder
OmitNil bool
Expand Down