Skip to content

Commit

Permalink
fix: psgql tx unsupport LastInsertId (#2815)
Browse files Browse the repository at this point in the history
  • Loading branch information
hailaz committed Aug 3, 2023
1 parent 2fbe412 commit a4e7cc4
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 24 deletions.
27 changes: 17 additions & 10 deletions contrib/drivers/pgsql/pgsql.go
Expand Up @@ -35,8 +35,8 @@ type Driver struct {

const (
internalPrimaryKeyInCtx gctx.StrKey = "primary_key"
defaultSchema = "public"
quoteChar = `"`
defaultSchema string = "public"
quoteChar string = `"`
)

func init() {
Expand Down Expand Up @@ -372,14 +372,22 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ...
)

// Transaction checks.
if link != nil && link.IsTransaction() {
isUseCoreDoExec = true
} else {
if link == nil {
if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil {
isUseCoreDoExec = true
// Firstly, check and retrieve transaction link from context.
link = tx
} else if link, err = d.MasterLink(); err != nil {
// Or else it creates one from master node.
return nil, err
}
} else if !link.IsTransaction() {
// If current link is not transaction link, it checks and retrieves transaction from context.
if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil {
link = tx
}
}

// Check if it is an insert operation with primary key.
if value := ctx.Value(internalPrimaryKeyInCtx); value != nil {
var ok bool
pkField, ok = value.(gdb.TableField)
Expand Down Expand Up @@ -408,8 +416,7 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ...
}

// Sql filtering.
// TODO: internal function formatSql
// sql, args = formatSql(sql, args)
sql, args = d.FormatSqlBeforeExecuting(sql, args)
sql, args, err = d.DoFilter(ctx, link, sql, args)
if err != nil {
return nil, err
Expand Down Expand Up @@ -442,10 +449,10 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ...
}

if out.Records[affected-1][primaryKey] != nil {
lastInsertId := out.Records[affected-1][primaryKey].Int()
lastInsertId := out.Records[affected-1][primaryKey].Int64()
return Result{
affected: int64(affected),
lastInsertId: int64(lastInsertId),
lastInsertId: lastInsertId,
}, nil
}
}
Expand Down
39 changes: 39 additions & 0 deletions contrib/drivers/pgsql/pgsql_z_test.go
Expand Up @@ -7,8 +7,10 @@
package pgsql_test

import (
"context"
"testing"

"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/test/gtest"
Expand Down Expand Up @@ -45,6 +47,43 @@ func Test_LastInsertId(t *testing.T) {
})
}

func Test_TxLastInsertId(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
tableName := createTable()
defer dropTable(tableName)
err := db.Transaction(context.TODO(), func(ctx context.Context, tx gdb.TX) error {
// user
res, err := tx.Model(tableName).Insert(g.List{
{"passport": "user1", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
{"passport": "user2", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
{"passport": "user3", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
})
t.Assert(err, nil)
lastInsertId, err := res.LastInsertId()
t.Assert(err, nil)
t.AssertEQ(lastInsertId, int64(3))
rowsAffected, err := res.RowsAffected()
t.Assert(err, nil)
t.AssertEQ(rowsAffected, int64(3))

res1, err := tx.Model(tableName).Insert(g.List{
{"passport": "user4", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
{"passport": "user5", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
})
t.Assert(err, nil)
lastInsertId1, err := res1.LastInsertId()
t.Assert(err, nil)
t.AssertEQ(lastInsertId1, int64(5))
rowsAffected1, err := res1.RowsAffected()
t.Assert(err, nil)
t.AssertEQ(rowsAffected1, int64(2))
return nil

})
t.Assert(err, nil)
})
}

func Test_Driver_DoFilter(t *testing.T) {
var (
ctx = gctx.New()
Expand Down
2 changes: 2 additions & 0 deletions database/gdb/gdb.go
Expand Up @@ -179,6 +179,8 @@ type DB interface {

// TX defines the interfaces for ORM transaction operations.
type TX interface {
Link

Ctx(ctx context.Context) TX
Raw(rawSql string, args ...interface{}) *Model
Model(tableNameQueryOrStruct ...interface{}) *Model
Expand Down
11 changes: 11 additions & 0 deletions database/gdb/gdb_core.go
Expand Up @@ -796,3 +796,14 @@ func (c *Core) isSoftCreatedFieldName(fieldName string) bool {
}
return false
}

// FormatSqlBeforeExecuting formats the sql string and its arguments before executing.
// The internal handleArguments function might be called twice during the SQL procedure,
// but do not worry about it, it's safe and efficient.
func (c *Core) FormatSqlBeforeExecuting(sql string, args []interface{}) (newSql string, newArgs []interface{}) {
// DO NOT do this as there may be multiple lines and comments in the sql.
// sql = gstr.Trim(sql)
// sql = gstr.Replace(sql, "\n", " ")
// sql, _ = gregex.ReplaceString(`\s{2,}`, ` `, sql)
return handleArguments(sql, args)
}
25 changes: 25 additions & 0 deletions database/gdb/gdb_core_transaction.go
Expand Up @@ -517,3 +517,28 @@ func (tx *TXCore) Update(table string, data interface{}, condition interface{},
func (tx *TXCore) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
return tx.Model(table).Ctx(tx.ctx).Where(condition, args...).Delete()
}

// QueryContext implements interface function Link.QueryContext.
func (tx *TXCore) QueryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) {
return tx.tx.QueryContext(ctx, sql, args...)
}

// ExecContext implements interface function Link.ExecContext.
func (tx *TXCore) ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) {
return tx.tx.ExecContext(ctx, sql, args...)
}

// PrepareContext implements interface function Link.PrepareContext.
func (tx *TXCore) PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error) {
return tx.tx.PrepareContext(ctx, sql)
}

// IsOnMaster implements interface function Link.IsOnMaster.
func (tx *TXCore) IsOnMaster() bool {
return true
}

// IsTransaction implements interface function Link.IsTransaction.
func (tx *TXCore) IsTransaction() bool {
return true
}
7 changes: 4 additions & 3 deletions database/gdb/gdb_core_underlying.go
Expand Up @@ -10,9 +10,10 @@ package gdb
import (
"context"
"database/sql"
"reflect"

"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
"reflect"

"github.com/gogf/gf/v2/util/gconv"

Expand Down Expand Up @@ -55,7 +56,7 @@ func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...inter
}

// Sql filtering.
sql, args = formatSql(sql, args)
sql, args = c.FormatSqlBeforeExecuting(sql, args)
sql, args, err = c.db.DoFilter(ctx, link, sql, args)
if err != nil {
return nil, err
Expand Down Expand Up @@ -116,7 +117,7 @@ func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interf
}

// SQL filtering.
sql, args = formatSql(sql, args)
sql, args = c.FormatSqlBeforeExecuting(sql, args)
sql, args, err = c.db.DoFilter(ctx, link, sql, args)
if err != nil {
return nil, err
Expand Down
11 changes: 0 additions & 11 deletions database/gdb/gdb_func.go
Expand Up @@ -373,17 +373,6 @@ func GetPrimaryKeyCondition(primary string, where ...interface{}) (newWhereCondi
return where
}

// formatSql formats the sql string and its arguments before executing.
// The internal handleArguments function might be called twice during the SQL procedure,
// but do not worry about it, it's safe and efficient.
func formatSql(sql string, args []interface{}) (newSql string, newArgs []interface{}) {
// DO NOT do this as there may be multiple lines and comments in the sql.
// sql = gstr.Trim(sql)
// sql = gstr.Replace(sql, "\n", " ")
// sql, _ = gregex.ReplaceString(`\s{2,}`, ` `, sql)
return handleArguments(sql, args)
}

type formatWhereHolderInput struct {
WhereHolder
OmitNil bool
Expand Down

0 comments on commit a4e7cc4

Please sign in to comment.