Skip to content

Commit

Permalink
feat: Create Tx and Conn interfaces to allow prepared statements to b…
Browse files Browse the repository at this point in the history
…e used in transactions (#40)

* tx

* fix test

* comments

* wrap connection

* comments

* statement doesn't need sqlDB instance

* don't use any

* return any statement

* check RETURNING in prepared queries

---------

Co-authored-by: Iman Tumorang <iman.tumorang@gmail.com>
  • Loading branch information
RangelReale and bxcodec committed Sep 20, 2023
1 parent 404cf52 commit 1f99231
Show file tree
Hide file tree
Showing 5 changed files with 296 additions and 15 deletions.
73 changes: 73 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package dbresolver

import (
"context"
"database/sql"
"strings"
)

// Conn is a *sql.Conn wrapper.
// Its main purpose is to be able to return the internal Tx and Stmt interfaces.
type Conn interface {
Close() error
BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
PingContext(ctx context.Context) error
PrepareContext(ctx context.Context, query string) (Stmt, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
Raw(f func(driverConn interface{}) error) (err error)
}

type conn struct {
sourceDB *sql.DB
conn *sql.Conn
}

func (c *conn) Close() error {
return c.conn.Close()
}

func (c *conn) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
stx, err := c.conn.BeginTx(ctx, opts)
if err != nil {
return nil, err
}

return &tx{
sourceDB: c.sourceDB,
tx: stx,
}, nil
}

func (c *conn) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return c.conn.ExecContext(ctx, query, args...)
}

func (c *conn) PingContext(ctx context.Context) error {
return c.conn.PingContext(ctx)
}

func (c *conn) PrepareContext(ctx context.Context, query string) (Stmt, error) {
pstmt, err := c.conn.PrepareContext(ctx, query)
if err != nil {
return nil, err
}

_query := strings.ToUpper(query)
writeFlag := strings.Contains(_query, "RETURNING")

return newSingleDBStmt(c.sourceDB, pstmt, writeFlag), nil
}

func (c *conn) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
return c.conn.QueryContext(ctx, query, args...)
}

func (c *conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
return c.conn.QueryRowContext(ctx, query, args...)
}

func (c *conn) Raw(f func(driverConn interface{}) error) (err error) {
return c.conn.Raw(f)
}
53 changes: 43 additions & 10 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"database/sql"
"database/sql/driver"
"strings"
"sync"
"time"

"go.uber.org/multierr"
Expand All @@ -15,11 +17,11 @@ import (
// with multi dbs connection, we decided to forward all single connection DB related function to the first primary DB
// For example, function like, `Conn()“, or `Stats()` only available for the primary DB, or the first primary DB (if using multi-primary)
type DB interface {
Begin() (*sql.Tx, error)
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
Begin() (Tx, error)
BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error)
Close() error
// Conn only available for the primary db or the first primary db (if using multi-primary)
Conn(ctx context.Context) (*sql.Conn, error)
Conn(ctx context.Context) (Conn, error)
Driver() driver.Driver
Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Expand Down Expand Up @@ -86,17 +88,27 @@ func (db *sqlDB) Driver() driver.Driver {
}

// Begin starts a transaction on the RW-db. The isolation level is dependent on the driver.
func (db *sqlDB) Begin() (*sql.Tx, error) {
return db.ReadWrite().Begin()
func (db *sqlDB) Begin() (Tx, error) {
return db.BeginTx(context.Background(), nil)
}

// BeginTx starts a transaction with the provided context on the RW-db.
//
// The provided TxOptions is optional and may be nil if defaults should be used.
// If a non-default isolation level is used that the driver doesn't support,
// an error will be returned.
func (db *sqlDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
return db.ReadWrite().BeginTx(ctx, opts)
func (db *sqlDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
sourceDB := db.ReadWrite()

stx, err := sourceDB.BeginTx(ctx, opts)
if err != nil {
return nil, err
}

return &tx{
sourceDB: sourceDB,
tx: stx,
}, nil
}

// Exec executes a query without returning any rows.
Expand Down Expand Up @@ -143,15 +155,24 @@ func (db *sqlDB) Prepare(query string) (_stmt Stmt, err error) {
// The provided context is used for the preparation of the statement, not for
// the execution of the statement.
func (db *sqlDB) PrepareContext(ctx context.Context, query string) (_stmt Stmt, err error) {
dbStmt := map[*sql.DB]*sql.Stmt{}
var dbStmtLock sync.Mutex
roStmts := make([]*sql.Stmt, len(db.replicas))
primaryStmts := make([]*sql.Stmt, len(db.primaries))
errPrimaries := doParallely(len(db.primaries), func(i int) (err error) {
primaryStmts[i], err = db.primaries[i].PrepareContext(ctx, query)
dbStmtLock.Lock()
dbStmt[db.primaries[i]] = primaryStmts[i]
dbStmtLock.Unlock()
return
})

errReplicas := doParallely(len(db.replicas), func(i int) (err error) {
roStmts[i], err = db.replicas[i].PrepareContext(ctx, query)
dbStmtLock.Lock()
dbStmt[db.replicas[i]] = roStmts[i]
dbStmtLock.Unlock()

// if connection error happens on RO connection,
// ignore and fallback to RW connection
if isDBConnectionError(err) {
Expand All @@ -166,11 +187,15 @@ func (db *sqlDB) PrepareContext(ctx context.Context, query string) (_stmt Stmt,
return
}

_query := strings.ToUpper(query)
writeFlag := strings.Contains(_query, "RETURNING")

_stmt = &stmt{
db: db,
loadBalancer: db.stmtLoadBalancer,
primaryStmts: primaryStmts,
replicaStmts: roStmts,
dbStmt: dbStmt,
writeFlag: writeFlag,
}
return _stmt, nil
}
Expand Down Expand Up @@ -298,8 +323,16 @@ func (db *sqlDB) ReadWrite() *sql.DB {

// Conn returns a single connection by either opening a new connection or returning an existing connection from the
// connection pool of the first primary db.
func (db *sqlDB) Conn(ctx context.Context) (*sql.Conn, error) {
return db.primaries[0].Conn(ctx)
func (db *sqlDB) Conn(ctx context.Context) (Conn, error) {
c, err := db.primaries[0].Conn(ctx)
if err != nil {
return nil, err
}

return &conn{
sourceDB: db.primaries[0],
conn: c,
}, nil
}

// Stats returns database statistics for the first primary db
Expand Down
50 changes: 50 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,56 @@ func testMW(t *testing.T, config DBConfig) {
stmt.Exec()
})

t.Run("prepare tx", func(t *testing.T) {
query := "select 1"

for _, mock := range mockPimaries {
mock.ExpectPrepare(query)
defer func(mock sqlmock.Sqlmock) {
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("sqlmock:unmet expectations: %s", err)
}
}(mock)
}
for _, mock := range mockReplicas {
mock.ExpectPrepare(query)
defer func(mock sqlmock.Sqlmock) {
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("sqlmock:unmet expectations: %s", err)
}
}(mock)
}

stmt, err := resolver.Prepare(query)
if err != nil {
t.Error("prepare failed")
return
}

robin := resolver.loadBalancer.predict(noOfPrimaries)
mock := mockPimaries[robin]

mock.ExpectBegin()

tx, err := resolver.Begin()
if err != nil {
t.Error("begin failed", err)
return
}

txstmt := tx.Stmt(stmt)

mock.ExpectExec(query).WillReturnResult(sqlmock.NewResult(0, 0))
_, err = txstmt.Exec()
if err != nil {
t.Error("stmt exec failed", err)
return
}

mock.ExpectCommit()
tx.Commit()
})

t.Run("ping", func(t *testing.T) {
for _, mock := range mockPimaries {
mock.ExpectPing()
Expand Down
51 changes: 46 additions & 5 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ type Stmt interface {
}

type stmt struct {
db *sqlDB
loadBalancer StmtLoadBalancer
primaryStmts []*sql.Stmt
replicaStmts []*sql.Stmt
writeFlag bool
dbStmt map[*sql.DB]*sql.Stmt
}

// Close closes the statement by concurrently closing all underlying
Expand Down Expand Up @@ -64,8 +65,15 @@ func (s *stmt) Query(args ...interface{}) (*sql.Rows, error) {
// arguments and returns the query results as a *sql.Rows.
// Query uses the read only DB as the underlying physical db.
func (s *stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
rows, err := s.ROStmt().QueryContext(ctx, args...)
if isDBConnectionError(err) {
var curStmt *sql.Stmt
if s.writeFlag {
curStmt = s.RWStmt()
} else {
curStmt = s.ROStmt()
}

rows, err := curStmt.QueryContext(ctx, args...)
if isDBConnectionError(err) && !s.writeFlag {
rows, err = s.RWStmt().QueryContext(ctx, args...)
}
return rows, err
Expand All @@ -88,8 +96,15 @@ func (s *stmt) QueryRow(args ...interface{}) *sql.Row {
// Otherwise, the *sql.Row's Scan scans the first selected row and discards the rest.
// QueryRowContext uses the read only DB as the underlying physical db.
func (s *stmt) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row {
row := s.ROStmt().QueryRowContext(ctx, args...)
if isDBConnectionError(row.Err()) {
var curStmt *sql.Stmt
if s.writeFlag {
curStmt = s.RWStmt()
} else {
curStmt = s.ROStmt()
}

row := curStmt.QueryRowContext(ctx, args...)
if isDBConnectionError(row.Err()) && !s.writeFlag {
row = s.RWStmt().QueryRowContext(ctx, args...)
}
return row
Expand All @@ -108,3 +123,29 @@ func (s *stmt) ROStmt() *sql.Stmt {
func (s *stmt) RWStmt() *sql.Stmt {
return s.loadBalancer.Resolve(s.primaryStmts)
}

// stmtForDB returns the corresponding *sql.Stmt instance for the given *sql.DB.
// Ihis is needed because sql.Tx.Stmt() requires that the passed *sql.Stmt be from the same database
// as the transaction.
func (s *stmt) stmtForDB(db *sql.DB) *sql.Stmt {
xsm, ok := s.dbStmt[db]
if ok {
return xsm
}

// return any statement so errors can be detected by Tx.Stmt()
return s.RWStmt()
}

// newSingleDBStmt creates a new stmt for a single DB connection.
// This is used by statements return by transaction and connections.
func newSingleDBStmt(sourceDB *sql.DB, st *sql.Stmt, writeFlag bool) *stmt {
return &stmt{
loadBalancer: &RoundRobinLoadBalancer[*sql.Stmt]{},
primaryStmts: []*sql.Stmt{st},
dbStmt: map[*sql.DB]*sql.Stmt{
sourceDB: st,
},
writeFlag: writeFlag,
}
}

0 comments on commit 1f99231

Please sign in to comment.