Skip to content

Commit

Permalink
Merge branch 'main' into db-tx
Browse files Browse the repository at this point in the history
  • Loading branch information
bxcodec committed Jul 24, 2023
2 parents d5d7058 + 404cf52 commit 61e6300
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 10 deletions.
7 changes: 3 additions & 4 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type sqlDB struct {
replicas []*sql.DB
loadBalancer DBLoadBalancer
stmtLoadBalancer StmtLoadBalancer
queryTypeChecker QueryTypeChecker
}

// PrimaryDBs return all the active primary DB
Expand Down Expand Up @@ -209,8 +210,7 @@ func (db *sqlDB) Query(query string, args ...interface{}) (*sql.Rows, error) {
// The args are for any placeholder parameters in the query.
func (db *sqlDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
var curDB *sql.DB
_query := strings.ToUpper(query)
writeFlag := strings.Contains(_query, "RETURNING")
writeFlag := db.queryTypeChecker.Check(query) == QueryTypeWrite

if writeFlag {
curDB = db.ReadWrite()
Expand All @@ -237,8 +237,7 @@ func (db *sqlDB) QueryRow(query string, args ...interface{}) *sql.Row {
// Errors are deferred until Row's Scan method is called.
func (db *sqlDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
var curDB *sql.DB
_query := strings.ToUpper(query)
writeFlag := strings.Contains(_query, "RETURNING")
writeFlag := db.queryTypeChecker.Check(query) == QueryTypeWrite

if writeFlag {
curDB = db.ReadWrite()
Expand Down
22 changes: 16 additions & 6 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ const (

// Option define the option property
type Option struct {
PrimaryDBs []*sql.DB
ReplicaDBs []*sql.DB
StmtLB StmtLoadBalancer
DBLB DBLoadBalancer
PrimaryDBs []*sql.DB
ReplicaDBs []*sql.DB
StmtLB StmtLoadBalancer
DBLB DBLoadBalancer
QueryTypeChecker QueryTypeChecker
}

// OptionFunc used for option chaining
Expand All @@ -39,6 +40,14 @@ func WithReplicaDBs(replicaDBs ...*sql.DB) OptionFunc {
}
}

// WithQueryTypeChecker sets the query type checker instance.
// The default one just checks for the presence of the string "RETURNING" in the uppercase query.
func WithQueryTypeChecker(checker QueryTypeChecker) OptionFunc {
return func(opt *Option) {
opt.QueryTypeChecker = checker
}
}

// WithLoadBalancer configure the loadbalancer for the resolver
func WithLoadBalancer(lb LoadBalancerPolicy) OptionFunc {
return func(opt *Option) {
Expand All @@ -61,7 +70,8 @@ func WithLoadBalancer(lb LoadBalancerPolicy) OptionFunc {

func defaultOption() *Option {
return &Option{
DBLB: &RoundRobinLoadBalancer[*sql.DB]{},
StmtLB: &RoundRobinLoadBalancer[*sql.Stmt]{},
DBLB: &RoundRobinLoadBalancer[*sql.DB]{},
StmtLB: &RoundRobinLoadBalancer[*sql.Stmt]{},
QueryTypeChecker: &DefaultQueryTypeChecker{},
}
}
28 changes: 28 additions & 0 deletions query.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package dbresolver

import "strings"

type QueryType int

const (
QueryTypeUnknown QueryType = iota
QueryTypeRead
QueryTypeWrite
)

// QueryTypeChecker is used to try to detect the query type, like for detecting RETURNING clauses in
// INSERT/UPDATE clauses.
type QueryTypeChecker interface {
Check(query string) QueryType
}

// DefaultQueryTypeChecker searches for a "RETURNING" string inside the query to detect a write query.
type DefaultQueryTypeChecker struct {
}

func (c DefaultQueryTypeChecker) Check(query string) QueryType {
if strings.Contains(strings.ToUpper(query), "RETURNING") {
return QueryTypeWrite
}
return QueryTypeUnknown
}
1 change: 1 addition & 0 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ func New(opts ...OptionFunc) DB {
replicas: opt.ReplicaDBs,
loadBalancer: opt.DBLB,
stmtLoadBalancer: opt.StmtLB,
queryTypeChecker: opt.QueryTypeChecker,
}
}

0 comments on commit 61e6300

Please sign in to comment.