Skip to content

Commit

Permalink
groot/rsql/rsqldrv: migrate to new database/sql/driver.XyzContext int…
Browse files Browse the repository at this point in the history
…erfaces

Updates #720.
  • Loading branch information
sbinet committed May 22, 2020
1 parent a5deed1 commit 889375c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
35 changes: 18 additions & 17 deletions groot/rsql/rsqldrv/driver.go
Expand Up @@ -6,6 +6,7 @@
package rsqldrv // import "go-hep.org/x/hep/groot/rsql/rsqldrv"

import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
Expand Down Expand Up @@ -192,31 +193,31 @@ func (conn *driverConn) Rollback() error {
panic("conn-rollback: not implemented")
}

func (conn *driverConn) Exec(query string, args []driver.Value) (driver.Result, error) {
func (conn *driverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
stmt, err := sqlparser.Parse(query)
if err != nil {
return nil, err
}

return conn.exec(stmt, args)
return conn.exec(ctx, stmt, args)
}

func (conn *driverConn) exec(stmt sqlparser.Statement, args []driver.Value) (driver.Result, error) {
func (conn *driverConn) exec(ctx context.Context, stmt sqlparser.Statement, args []driver.NamedValue) (driver.Result, error) {
panic("not implemented")
}

func (conn *driverConn) Query(query string, args []driver.Value) (driver.Rows, error) {
func (conn *driverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
stmt, err := sqlparser.Parse(query)
if err != nil {
return nil, err
}
return conn.query(stmt, args)
return conn.query(ctx, stmt, args)
}

func (conn *driverConn) query(stmt sqlparser.Statement, args []driver.Value) (driver.Rows, error) {
func (conn *driverConn) query(ctx context.Context, stmt sqlparser.Statement, args []driver.NamedValue) (driver.Rows, error) {
switch stmt := stmt.(type) {
case *sqlparser.Select:
rows, err := newDriverRows(conn, stmt, args)
rows, err := newDriverRows(ctx, conn, stmt, args)
return rows, err
}
panic("not implemented")
Expand All @@ -233,7 +234,7 @@ func (res *driverResult) RowsAffected() (int64, error) { return res.rows, nil }
// driverRows is an iterator over an executed query's results.
type driverRows struct {
conn *driverConn
args []driver.Value
args []driver.NamedValue
cols []string
types []colDescr // types of the columns
deps []string // names of the columns to be read
Expand All @@ -251,7 +252,7 @@ type colDescr struct {
Type reflect.Type
}

func newDriverRows(conn *driverConn, stmt *sqlparser.Select, args []driver.Value) (*driverRows, error) {
func newDriverRows(ctx context.Context, conn *driverConn, stmt *sqlparser.Select, args []driver.NamedValue) (*driverRows, error) {
var (
name = ""
f = conn.f
Expand Down Expand Up @@ -363,7 +364,7 @@ func varsFrom(vars []rtree.ReadVar) []interface{} {

// extractDepsFromSelect analyses the query and extracts the branches that need to be read
// for the query to be properly executed.
func (rows *driverRows) extractDepsFromSelect(tree rtree.Tree, stmt *sqlparser.Select, args []driver.Value) ([]rtree.ReadVar, error) {
func (rows *driverRows) extractDepsFromSelect(tree rtree.Tree, stmt *sqlparser.Select, args []driver.NamedValue) ([]rtree.ReadVar, error) {
var (
vars []rtree.ReadVar

Expand Down Expand Up @@ -459,7 +460,7 @@ func (rows *driverRows) extractDepsFromSelect(tree rtree.Tree, stmt *sqlparser.S
return vars, nil
}

func (rows *driverRows) extractColsFromSelect(tree rtree.Tree, stmt *sqlparser.Select, args []driver.Value) ([]string, error) {
func (rows *driverRows) extractColsFromSelect(tree rtree.Tree, stmt *sqlparser.Select, args []driver.NamedValue) ([]string, error) {
var cols []string

collect := func(node sqlparser.SQLNode) (bool, error) {
Expand Down Expand Up @@ -632,7 +633,7 @@ func (stmt *driverStmt) Query(args []driver.Value) (driver.Rows, error) {
panic("not implemented")
}

func newExprFrom(expr sqlparser.Expr, args []driver.Value) (expression, error) {
func newExprFrom(expr sqlparser.Expr, args []driver.NamedValue) (expression, error) {
switch expr := expr.(type) {
case *sqlparser.ComparisonExpr:
op := operatorFrom(expr.Operator)
Expand Down Expand Up @@ -717,11 +718,11 @@ func newExprFrom(expr sqlparser.Expr, args []driver.Value) (expression, error) {
}

var (
_ driver.Driver = (*rootDriver)(nil)
_ driver.Conn = (*driverConn)(nil)
_ driver.Execer = (*driverConn)(nil)
_ driver.Queryer = (*driverConn)(nil)
_ driver.Tx = (*driverConn)(nil)
_ driver.Driver = (*rootDriver)(nil)
_ driver.Conn = (*driverConn)(nil)
_ driver.ExecerContext = (*driverConn)(nil)
_ driver.QueryerContext = (*driverConn)(nil)
_ driver.Tx = (*driverConn)(nil)

_ driver.Result = (*driverResult)(nil)
_ driver.Rows = (*driverRows)(nil)
Expand Down
6 changes: 3 additions & 3 deletions groot/rsql/rsqldrv/expr.go
Expand Up @@ -29,7 +29,7 @@ type execCtx struct {
mu sync.RWMutex
}

func newExecCtx(db *driverConn, args []driver.Value) *execCtx {
func newExecCtx(db *driverConn, args []driver.NamedValue) *execCtx {
ectx := execCtx{db: db}
return &ectx
}
Expand Down Expand Up @@ -1314,7 +1314,7 @@ type valueExpr struct {
v interface{}
}

func newValueExpr(expr *sqlparser.SQLVal, args []driver.Value) (expression, error) {
func newValueExpr(expr *sqlparser.SQLVal, args []driver.NamedValue) (expression, error) {
s := string(expr.Val)
switch expr.Type {
// case sqlparser.HexVal: // FIXME(sbinet): difference with HexNum?
Expand Down Expand Up @@ -1359,7 +1359,7 @@ func newValueExpr(expr *sqlparser.SQLVal, args []driver.Value) (expression, erro
i-- // :v1 --> index-0
return &valueExpr{
expr: expr,
v: idealValArgFrom(args[i]), // FIXME(sbinet): unwrap driver.Value?
v: idealValArgFrom(args[i].Value), // FIXME(sbinet): unwrap driver.Value?
}, nil

default:
Expand Down

0 comments on commit 889375c

Please sign in to comment.