diff --git a/conn.go b/conn.go new file mode 100644 index 00000000..bc6863c1 --- /dev/null +++ b/conn.go @@ -0,0 +1,73 @@ +package dbresolver + +import ( + "context" + "database/sql" + "strings" +) + +// Conn is a *sql.Conn wrapper. +// Its main purpose is to be able to return the internal Tx and Stmt interfaces. +type Conn interface { + Close() error + BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + PingContext(ctx context.Context) error + PrepareContext(ctx context.Context, query string) (Stmt, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row + Raw(f func(driverConn interface{}) error) (err error) +} + +type conn struct { + sourceDB *sql.DB + conn *sql.Conn +} + +func (c *conn) Close() error { + return c.conn.Close() +} + +func (c *conn) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { + stx, err := c.conn.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + + return &tx{ + sourceDB: c.sourceDB, + tx: stx, + }, nil +} + +func (c *conn) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return c.conn.ExecContext(ctx, query, args...) +} + +func (c *conn) PingContext(ctx context.Context) error { + return c.conn.PingContext(ctx) +} + +func (c *conn) PrepareContext(ctx context.Context, query string) (Stmt, error) { + pstmt, err := c.conn.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + + _query := strings.ToUpper(query) + writeFlag := strings.Contains(_query, "RETURNING") + + return newSingleDBStmt(c.sourceDB, pstmt, writeFlag), nil +} + +func (c *conn) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return c.conn.QueryContext(ctx, query, args...) +} + +func (c *conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + return c.conn.QueryRowContext(ctx, query, args...) +} + +func (c *conn) Raw(f func(driverConn interface{}) error) (err error) { + return c.conn.Raw(f) +} diff --git a/db.go b/db.go index 1e9132e9..32d17c5b 100644 --- a/db.go +++ b/db.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "database/sql/driver" + "strings" + "sync" "time" "go.uber.org/multierr" @@ -15,11 +17,11 @@ import ( // with multi dbs connection, we decided to forward all single connection DB related function to the first primary DB // For example, function like, `Conn()“, or `Stats()` only available for the primary DB, or the first primary DB (if using multi-primary) type DB interface { - Begin() (*sql.Tx, error) - BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) + Begin() (Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) Close() error // Conn only available for the primary db or the first primary db (if using multi-primary) - Conn(ctx context.Context) (*sql.Conn, error) + Conn(ctx context.Context) (Conn, error) Driver() driver.Driver Exec(query string, args ...interface{}) (sql.Result, error) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) @@ -86,8 +88,8 @@ func (db *sqlDB) Driver() driver.Driver { } // Begin starts a transaction on the RW-db. The isolation level is dependent on the driver. -func (db *sqlDB) Begin() (*sql.Tx, error) { - return db.ReadWrite().Begin() +func (db *sqlDB) Begin() (Tx, error) { + return db.BeginTx(context.Background(), nil) } // BeginTx starts a transaction with the provided context on the RW-db. @@ -95,8 +97,18 @@ func (db *sqlDB) Begin() (*sql.Tx, error) { // The provided TxOptions is optional and may be nil if defaults should be used. // If a non-default isolation level is used that the driver doesn't support, // an error will be returned. -func (db *sqlDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { - return db.ReadWrite().BeginTx(ctx, opts) +func (db *sqlDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) { + sourceDB := db.ReadWrite() + + stx, err := sourceDB.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + + return &tx{ + sourceDB: sourceDB, + tx: stx, + }, nil } // Exec executes a query without returning any rows. @@ -143,15 +155,24 @@ func (db *sqlDB) Prepare(query string) (_stmt Stmt, err error) { // The provided context is used for the preparation of the statement, not for // the execution of the statement. func (db *sqlDB) PrepareContext(ctx context.Context, query string) (_stmt Stmt, err error) { + dbStmt := map[*sql.DB]*sql.Stmt{} + var dbStmtLock sync.Mutex roStmts := make([]*sql.Stmt, len(db.replicas)) primaryStmts := make([]*sql.Stmt, len(db.primaries)) errPrimaries := doParallely(len(db.primaries), func(i int) (err error) { primaryStmts[i], err = db.primaries[i].PrepareContext(ctx, query) + dbStmtLock.Lock() + dbStmt[db.primaries[i]] = primaryStmts[i] + dbStmtLock.Unlock() return }) errReplicas := doParallely(len(db.replicas), func(i int) (err error) { roStmts[i], err = db.replicas[i].PrepareContext(ctx, query) + dbStmtLock.Lock() + dbStmt[db.replicas[i]] = roStmts[i] + dbStmtLock.Unlock() + // if connection error happens on RO connection, // ignore and fallback to RW connection if isDBConnectionError(err) { @@ -166,11 +187,15 @@ func (db *sqlDB) PrepareContext(ctx context.Context, query string) (_stmt Stmt, return } + _query := strings.ToUpper(query) + writeFlag := strings.Contains(_query, "RETURNING") + _stmt = &stmt{ - db: db, loadBalancer: db.stmtLoadBalancer, primaryStmts: primaryStmts, replicaStmts: roStmts, + dbStmt: dbStmt, + writeFlag: writeFlag, } return _stmt, nil } @@ -298,8 +323,16 @@ func (db *sqlDB) ReadWrite() *sql.DB { // Conn returns a single connection by either opening a new connection or returning an existing connection from the // connection pool of the first primary db. -func (db *sqlDB) Conn(ctx context.Context) (*sql.Conn, error) { - return db.primaries[0].Conn(ctx) +func (db *sqlDB) Conn(ctx context.Context) (Conn, error) { + c, err := db.primaries[0].Conn(ctx) + if err != nil { + return nil, err + } + + return &conn{ + sourceDB: db.primaries[0], + conn: c, + }, nil } // Stats returns database statistics for the first primary db diff --git a/db_test.go b/db_test.go index 12c0157b..96c18729 100644 --- a/db_test.go +++ b/db_test.go @@ -201,6 +201,56 @@ func testMW(t *testing.T, config DBConfig) { stmt.Exec() }) + t.Run("prepare tx", func(t *testing.T) { + query := "select 1" + + for _, mock := range mockPimaries { + mock.ExpectPrepare(query) + defer func(mock sqlmock.Sqlmock) { + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("sqlmock:unmet expectations: %s", err) + } + }(mock) + } + for _, mock := range mockReplicas { + mock.ExpectPrepare(query) + defer func(mock sqlmock.Sqlmock) { + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("sqlmock:unmet expectations: %s", err) + } + }(mock) + } + + stmt, err := resolver.Prepare(query) + if err != nil { + t.Error("prepare failed") + return + } + + robin := resolver.loadBalancer.predict(noOfPrimaries) + mock := mockPimaries[robin] + + mock.ExpectBegin() + + tx, err := resolver.Begin() + if err != nil { + t.Error("begin failed", err) + return + } + + txstmt := tx.Stmt(stmt) + + mock.ExpectExec(query).WillReturnResult(sqlmock.NewResult(0, 0)) + _, err = txstmt.Exec() + if err != nil { + t.Error("stmt exec failed", err) + return + } + + mock.ExpectCommit() + tx.Commit() + }) + t.Run("ping", func(t *testing.T) { for _, mock := range mockPimaries { mock.ExpectPing() diff --git a/stmt.go b/stmt.go index 09157bc7..0dca1500 100644 --- a/stmt.go +++ b/stmt.go @@ -20,10 +20,11 @@ type Stmt interface { } type stmt struct { - db *sqlDB loadBalancer StmtLoadBalancer primaryStmts []*sql.Stmt replicaStmts []*sql.Stmt + writeFlag bool + dbStmt map[*sql.DB]*sql.Stmt } // Close closes the statement by concurrently closing all underlying @@ -64,8 +65,15 @@ func (s *stmt) Query(args ...interface{}) (*sql.Rows, error) { // arguments and returns the query results as a *sql.Rows. // Query uses the read only DB as the underlying physical db. func (s *stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) { - rows, err := s.ROStmt().QueryContext(ctx, args...) - if isDBConnectionError(err) { + var curStmt *sql.Stmt + if s.writeFlag { + curStmt = s.RWStmt() + } else { + curStmt = s.ROStmt() + } + + rows, err := curStmt.QueryContext(ctx, args...) + if isDBConnectionError(err) && !s.writeFlag { rows, err = s.RWStmt().QueryContext(ctx, args...) } return rows, err @@ -88,8 +96,15 @@ func (s *stmt) QueryRow(args ...interface{}) *sql.Row { // Otherwise, the *sql.Row's Scan scans the first selected row and discards the rest. // QueryRowContext uses the read only DB as the underlying physical db. func (s *stmt) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row { - row := s.ROStmt().QueryRowContext(ctx, args...) - if isDBConnectionError(row.Err()) { + var curStmt *sql.Stmt + if s.writeFlag { + curStmt = s.RWStmt() + } else { + curStmt = s.ROStmt() + } + + row := curStmt.QueryRowContext(ctx, args...) + if isDBConnectionError(row.Err()) && !s.writeFlag { row = s.RWStmt().QueryRowContext(ctx, args...) } return row @@ -108,3 +123,29 @@ func (s *stmt) ROStmt() *sql.Stmt { func (s *stmt) RWStmt() *sql.Stmt { return s.loadBalancer.Resolve(s.primaryStmts) } + +// stmtForDB returns the corresponding *sql.Stmt instance for the given *sql.DB. +// Ihis is needed because sql.Tx.Stmt() requires that the passed *sql.Stmt be from the same database +// as the transaction. +func (s *stmt) stmtForDB(db *sql.DB) *sql.Stmt { + xsm, ok := s.dbStmt[db] + if ok { + return xsm + } + + // return any statement so errors can be detected by Tx.Stmt() + return s.RWStmt() +} + +// newSingleDBStmt creates a new stmt for a single DB connection. +// This is used by statements return by transaction and connections. +func newSingleDBStmt(sourceDB *sql.DB, st *sql.Stmt, writeFlag bool) *stmt { + return &stmt{ + loadBalancer: &RoundRobinLoadBalancer[*sql.Stmt]{}, + primaryStmts: []*sql.Stmt{st}, + dbStmt: map[*sql.DB]*sql.Stmt{ + sourceDB: st, + }, + writeFlag: writeFlag, + } +} diff --git a/tx.go b/tx.go new file mode 100644 index 00000000..3a5752b5 --- /dev/null +++ b/tx.go @@ -0,0 +1,84 @@ +package dbresolver + +import ( + "context" + "database/sql" +) + +// Tx is a *sql.Tx wrapper. +// Its main purpose is to be able to return the internal Stmt interface. +type Tx interface { + Commit() error + Rollback() error + Exec(query string, args ...interface{}) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + Prepare(query string) (Stmt, error) + PrepareContext(ctx context.Context, query string) (Stmt, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row + Stmt(stmt Stmt) Stmt + StmtContext(ctx context.Context, stmt Stmt) Stmt +} + +type tx struct { + sourceDB *sql.DB + tx *sql.Tx +} + +func (t *tx) Commit() error { + return t.tx.Commit() +} + +func (t *tx) Rollback() error { + return t.tx.Rollback() +} + +func (t *tx) Exec(query string, args ...interface{}) (sql.Result, error) { + return t.ExecContext(context.Background(), query, args...) +} + +func (t *tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return t.tx.ExecContext(ctx, query, args...) +} + +func (t *tx) Prepare(query string) (Stmt, error) { + return t.PrepareContext(context.Background(), query) +} + +func (t *tx) PrepareContext(ctx context.Context, query string) (Stmt, error) { + txstmt, err := t.tx.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + + return newSingleDBStmt(t.sourceDB, txstmt, true), nil +} + +func (t *tx) Query(query string, args ...interface{}) (*sql.Rows, error) { + return t.QueryContext(context.Background(), query, args...) +} + +func (t *tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return t.tx.QueryContext(ctx, query, args...) +} + +func (t *tx) QueryRow(query string, args ...interface{}) *sql.Row { + return t.QueryRowContext(context.Background(), query, args...) +} + +func (t *tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { + return t.tx.QueryRowContext(ctx, query, args...) +} + +func (t *tx) Stmt(s Stmt) Stmt { + return t.StmtContext(context.Background(), s) +} + +func (t *tx) StmtContext(ctx context.Context, s Stmt) Stmt { + if rstmt, ok := s.(*stmt); ok { + return newSingleDBStmt(t.sourceDB, t.tx.StmtContext(ctx, rstmt.stmtForDB(t.sourceDB)), true) + } + return s +}