Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create Tx and Conn interfaces to allow prepared statements to be used in transactions #40

Merged
merged 10 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package dbresolver

import (
"context"
"database/sql"
"strings"
)

// Conn is a *sql.Conn wrapper.
// Its main purpose is to be able to return the internal Tx and Stmt interfaces.
type Conn interface {
Close() error
BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
PingContext(ctx context.Context) error
PrepareContext(ctx context.Context, query string) (Stmt, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
Raw(f func(driverConn interface{}) error) (err error)
}

type conn struct {
sourceDB *sql.DB
conn *sql.Conn
}

func (c *conn) Close() error {
return c.conn.Close()
}

func (c *conn) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
stx, err := c.conn.BeginTx(ctx, opts)
if err != nil {
return nil, err
}

return &tx{
sourceDB: c.sourceDB,
tx: stx,
}, nil
}

func (c *conn) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return c.conn.ExecContext(ctx, query, args...)
}

func (c *conn) PingContext(ctx context.Context) error {
return c.conn.PingContext(ctx)
}

func (c *conn) PrepareContext(ctx context.Context, query string) (Stmt, error) {
pstmt, err := c.conn.PrepareContext(ctx, query)
if err != nil {
return nil, err
}

_query := strings.ToUpper(query)
writeFlag := strings.Contains(_query, "RETURNING")

return newSingleDBStmt(c.sourceDB, pstmt, writeFlag), nil
}

func (c *conn) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
return c.conn.QueryContext(ctx, query, args...)
}

func (c *conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
return c.conn.QueryRowContext(ctx, query, args...)
}

func (c *conn) Raw(f func(driverConn interface{}) error) (err error) {
return c.conn.Raw(f)
}
53 changes: 43 additions & 10 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"database/sql"
"database/sql/driver"
"strings"
"sync"
"time"

"go.uber.org/multierr"
Expand All @@ -15,11 +17,11 @@ import (
// with multi dbs connection, we decided to forward all single connection DB related function to the first primary DB
// For example, function like, `Conn()“, or `Stats()` only available for the primary DB, or the first primary DB (if using multi-primary)
type DB interface {
Begin() (*sql.Tx, error)
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
Begin() (Tx, error)
BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error)
Close() error
// Conn only available for the primary db or the first primary db (if using multi-primary)
Conn(ctx context.Context) (*sql.Conn, error)
Conn(ctx context.Context) (Conn, error)
Driver() driver.Driver
Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Expand Down Expand Up @@ -86,17 +88,27 @@ func (db *sqlDB) Driver() driver.Driver {
}

// Begin starts a transaction on the RW-db. The isolation level is dependent on the driver.
func (db *sqlDB) Begin() (*sql.Tx, error) {
return db.ReadWrite().Begin()
func (db *sqlDB) Begin() (Tx, error) {
return db.BeginTx(context.Background(), nil)
}

// BeginTx starts a transaction with the provided context on the RW-db.
//
// The provided TxOptions is optional and may be nil if defaults should be used.
// If a non-default isolation level is used that the driver doesn't support,
// an error will be returned.
func (db *sqlDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
return db.ReadWrite().BeginTx(ctx, opts)
func (db *sqlDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
sourceDB := db.ReadWrite()

stx, err := sourceDB.BeginTx(ctx, opts)
if err != nil {
return nil, err
}

return &tx{
sourceDB: sourceDB,
tx: stx,
}, nil
}

// Exec executes a query without returning any rows.
Expand Down Expand Up @@ -143,15 +155,24 @@ func (db *sqlDB) Prepare(query string) (_stmt Stmt, err error) {
// The provided context is used for the preparation of the statement, not for
// the execution of the statement.
func (db *sqlDB) PrepareContext(ctx context.Context, query string) (_stmt Stmt, err error) {
dbStmt := map[*sql.DB]*sql.Stmt{}
var dbStmtLock sync.Mutex
roStmts := make([]*sql.Stmt, len(db.replicas))
primaryStmts := make([]*sql.Stmt, len(db.primaries))
errPrimaries := doParallely(len(db.primaries), func(i int) (err error) {
primaryStmts[i], err = db.primaries[i].PrepareContext(ctx, query)
dbStmtLock.Lock()
dbStmt[db.primaries[i]] = primaryStmts[i]
dbStmtLock.Unlock()
return
})

errReplicas := doParallely(len(db.replicas), func(i int) (err error) {
roStmts[i], err = db.replicas[i].PrepareContext(ctx, query)
dbStmtLock.Lock()
dbStmt[db.replicas[i]] = roStmts[i]
dbStmtLock.Unlock()

// if connection error happens on RO connection,
// ignore and fallback to RW connection
if isDBConnectionError(err) {
Expand All @@ -166,11 +187,15 @@ func (db *sqlDB) PrepareContext(ctx context.Context, query string) (_stmt Stmt,
return
}

_query := strings.ToUpper(query)
writeFlag := strings.Contains(_query, "RETURNING")

_stmt = &stmt{
db: db,
loadBalancer: db.stmtLoadBalancer,
primaryStmts: primaryStmts,
replicaStmts: roStmts,
dbStmt: dbStmt,
writeFlag: writeFlag,
}
return _stmt, nil
}
Expand Down Expand Up @@ -298,8 +323,16 @@ func (db *sqlDB) ReadWrite() *sql.DB {

// Conn returns a single connection by either opening a new connection or returning an existing connection from the
// connection pool of the first primary db.
func (db *sqlDB) Conn(ctx context.Context) (*sql.Conn, error) {
return db.primaries[0].Conn(ctx)
func (db *sqlDB) Conn(ctx context.Context) (Conn, error) {
c, err := db.primaries[0].Conn(ctx)
if err != nil {
return nil, err
}

return &conn{
sourceDB: db.primaries[0],
conn: c,
}, nil
}

// Stats returns database statistics for the first primary db
Expand Down
50 changes: 50 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
primaries := make([]*sql.DB, noOfPrimaries)
replicas := make([]*sql.DB, noOfReplicas)

mockPimaries := make([]sqlmock.Sqlmock, noOfPrimaries)

Check failure on line 39 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, x64, ubuntu-latest)

undeclared name: `sqlmock` (typecheck)

Check failure on line 39 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, x64, macos-latest)

undeclared name: `sqlmock` (typecheck)

Check failure on line 39 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, arm, ubuntu-latest)

undeclared name: `sqlmock` (typecheck)

Check failure on line 39 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, arm64, ubuntu-latest)

undeclared name: `sqlmock` (typecheck)

Check failure on line 39 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, arm, macos-latest)

undeclared name: `sqlmock` (typecheck)

Check failure on line 39 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, arm64, macos-latest)

undeclared name: `sqlmock` (typecheck)
mockReplicas := make([]sqlmock.Sqlmock, noOfReplicas)

Check failure on line 40 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, x64, ubuntu-latest)

undeclared name: `sqlmock` (typecheck)

Check failure on line 40 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, x64, macos-latest)

undeclared name: `sqlmock` (typecheck)

Check failure on line 40 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, arm, ubuntu-latest)

undeclared name: `sqlmock` (typecheck)

Check failure on line 40 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, arm64, ubuntu-latest)

undeclared name: `sqlmock` (typecheck)

Check failure on line 40 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, arm, macos-latest)

undeclared name: `sqlmock` (typecheck)

Check failure on line 40 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, arm64, macos-latest)

undeclared name: `sqlmock` (typecheck)

for i := 0; i < noOfPrimaries; i++ {
db, mock, err := createMock()
Expand Down Expand Up @@ -201,6 +201,56 @@
stmt.Exec()
})

t.Run("prepare tx", func(t *testing.T) {
query := "select 1"

for _, mock := range mockPimaries {
mock.ExpectPrepare(query)
defer func(mock sqlmock.Sqlmock) {
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("sqlmock:unmet expectations: %s", err)
}
}(mock)
}
for _, mock := range mockReplicas {
mock.ExpectPrepare(query)
defer func(mock sqlmock.Sqlmock) {
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("sqlmock:unmet expectations: %s", err)
}
}(mock)
}

stmt, err := resolver.Prepare(query)
if err != nil {
t.Error("prepare failed")
return
}

robin := resolver.loadBalancer.predict(noOfPrimaries)
mock := mockPimaries[robin]

mock.ExpectBegin()

tx, err := resolver.Begin()
if err != nil {
t.Error("begin failed", err)
return
}

txstmt := tx.Stmt(stmt)

mock.ExpectExec(query).WillReturnResult(sqlmock.NewResult(0, 0))
_, err = txstmt.Exec()
if err != nil {
t.Error("stmt exec failed", err)
return
}

mock.ExpectCommit()
tx.Commit()
})

t.Run("ping", func(t *testing.T) {
for _, mock := range mockPimaries {
mock.ExpectPing()
Expand Down Expand Up @@ -312,7 +362,7 @@
goto BEGIN_TEST_CASE
}

func createMock() (db *sql.DB, mock sqlmock.Sqlmock, err error) {

Check failure on line 365 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, x64, ubuntu-latest)

undeclared name: `sqlmock` (typecheck)

Check failure on line 365 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, x64, macos-latest)

undeclared name: `sqlmock` (typecheck)

Check failure on line 365 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, arm, ubuntu-latest)

undeclared name: `sqlmock` (typecheck)

Check failure on line 365 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, arm64, ubuntu-latest)

undeclared name: `sqlmock` (typecheck)

Check failure on line 365 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, arm, macos-latest)

undeclared name: `sqlmock` (typecheck)

Check failure on line 365 in db_test.go

View workflow job for this annotation

GitHub Actions / build (1.20.x, arm64, macos-latest)

undeclared name: `sqlmock` (typecheck)
db, mock, err = sqlmock.New(sqlmock.MonitorPingsOption(true), sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
return
}
Expand Down
51 changes: 46 additions & 5 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ type Stmt interface {
}

type stmt struct {
db *sqlDB
loadBalancer StmtLoadBalancer
primaryStmts []*sql.Stmt
replicaStmts []*sql.Stmt
writeFlag bool
dbStmt map[*sql.DB]*sql.Stmt
}

// Close closes the statement by concurrently closing all underlying
Expand Down Expand Up @@ -64,8 +65,15 @@ func (s *stmt) Query(args ...interface{}) (*sql.Rows, error) {
// arguments and returns the query results as a *sql.Rows.
// Query uses the read only DB as the underlying physical db.
func (s *stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
rows, err := s.ROStmt().QueryContext(ctx, args...)
if isDBConnectionError(err) {
var curStmt *sql.Stmt
if s.writeFlag {
curStmt = s.RWStmt()
} else {
curStmt = s.ROStmt()
}

rows, err := curStmt.QueryContext(ctx, args...)
if isDBConnectionError(err) && !s.writeFlag {
rows, err = s.RWStmt().QueryContext(ctx, args...)
}
return rows, err
Expand All @@ -88,8 +96,15 @@ func (s *stmt) QueryRow(args ...interface{}) *sql.Row {
// Otherwise, the *sql.Row's Scan scans the first selected row and discards the rest.
// QueryRowContext uses the read only DB as the underlying physical db.
func (s *stmt) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row {
row := s.ROStmt().QueryRowContext(ctx, args...)
if isDBConnectionError(row.Err()) {
var curStmt *sql.Stmt
if s.writeFlag {
curStmt = s.RWStmt()
} else {
curStmt = s.ROStmt()
}

row := curStmt.QueryRowContext(ctx, args...)
if isDBConnectionError(row.Err()) && !s.writeFlag {
row = s.RWStmt().QueryRowContext(ctx, args...)
}
return row
Expand All @@ -108,3 +123,29 @@ func (s *stmt) ROStmt() *sql.Stmt {
func (s *stmt) RWStmt() *sql.Stmt {
return s.loadBalancer.Resolve(s.primaryStmts)
}

// stmtForDB returns the corresponding *sql.Stmt instance for the given *sql.DB.
// Ihis is needed because sql.Tx.Stmt() requires that the passed *sql.Stmt be from the same database
// as the transaction.
func (s *stmt) stmtForDB(db *sql.DB) *sql.Stmt {
xsm, ok := s.dbStmt[db]
if ok {
return xsm
}

// return any statement so errors can be detected by Tx.Stmt()
return s.RWStmt()
}

// newSingleDBStmt creates a new stmt for a single DB connection.
// This is used by statements return by transaction and connections.
func newSingleDBStmt(sourceDB *sql.DB, st *sql.Stmt, writeFlag bool) *stmt {
return &stmt{
loadBalancer: &RoundRobinLoadBalancer[*sql.Stmt]{},
primaryStmts: []*sql.Stmt{st},
dbStmt: map[*sql.DB]*sql.Stmt{
sourceDB: st,
},
writeFlag: writeFlag,
}
}
Loading
Loading