Skip to content

Commit

Permalink
csvutil/csvdriver: implement new database/sql/driver.XyzContext
Browse files Browse the repository at this point in the history
Updates #720.
  • Loading branch information
sbinet committed May 22, 2020
1 parent 4723470 commit 54e0051
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 27 deletions.
78 changes: 58 additions & 20 deletions csvutil/csvdriver/driver.go
Expand Up @@ -6,6 +6,7 @@
package csvdriver // import "go-hep.org/x/hep/csvutil/csvdriver"

import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
Expand All @@ -21,13 +22,19 @@ import (
)

var (
_ driver.Driver = (*csvDriver)(nil)
_ driver.Conn = (*csvConn)(nil)
_ driver.Execer = (*csvConn)(nil)
_ driver.Queryer = (*csvConn)(nil)
_ driver.Tx = (*csvConn)(nil)
_ driver.Driver = (*csvDriver)(nil)
_ drvConn = (*csvConn)(nil)
_ driver.ExecerContext = (*csvConn)(nil)
_ driver.QueryerContext = (*csvConn)(nil)
_ driver.Tx = (*csvConn)(nil)
)

type drvConn interface {
driver.Conn
driver.ConnBeginTx
driver.ConnPrepareContext
}

// Conn describes how a connection to the CSV-driver should be established.
type Conn struct {
File string `json:"file"` // name of the file to be open
Expand All @@ -50,7 +57,6 @@ func (c *Conn) setDefaults() {
if c.Comment == 0 {
c.Comment = '#'
}
return
}

func (c Conn) toJSON() (string, error) {
Expand Down Expand Up @@ -194,9 +200,9 @@ type csvConn struct {
drv *csvDriver
refs int

conn driver.Conn
exec driver.Execer
query driver.Queryer
conn drvConn
exec driver.ExecerContext
query driver.QueryerContext
tx driver.Tx
}

Expand All @@ -206,15 +212,42 @@ func (conn *csvConn) initDB() error {
return err
}

conn.conn = c
conn.exec = c.(driver.Execer)
conn.query = c.(driver.Queryer)
conn.conn = connWrap(c)
conn.exec = c.(driver.ExecerContext)
conn.query = c.(driver.QueryerContext)
return nil
}

func connWrap(c driver.Conn) drvConn {
if c, ok := c.(drvConn); ok {
return c
}

wrap := &drvConnWrap{
Conn: c,
ConnPrepareContext: c.(driver.ConnPrepareContext),
}
return wrap
}

type drvConnWrap struct {
driver.Conn
driver.ConnPrepareContext
}

func (drv *drvConnWrap) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
// FIXME(sbinet): drop this hack when modernc.org/ql implements driver.ConnBeginTx.
return drv.Conn.Begin() //lint:ignore SA1019 drop this hack when modernc.org/ql supports driver.ConnBeginTx
}

// PrepareContext returns a prepared statement, bound to this connection.
func (conn *csvConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
return conn.conn.PrepareContext(ctx, query)
}

// Prepare returns a prepared statement, bound to this connection.
func (conn *csvConn) Prepare(query string) (driver.Stmt, error) {
return conn.conn.Prepare(query)
return conn.conn.PrepareContext(context.Background(), query)
}

// Close invalidates and potentially stops any current
Expand Down Expand Up @@ -256,22 +289,27 @@ func (conn *csvConn) Close() error {
return err
}

// Begin starts and returns a new transaction.
func (conn *csvConn) Begin() (driver.Tx, error) {
tx, err := conn.conn.Begin()
// BeginTx starts and returns a new transaction.
func (conn *csvConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
tx, err := conn.conn.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
conn.tx = tx
return tx, err
}

func (conn *csvConn) Exec(query string, args []driver.Value) (driver.Result, error) {
return conn.exec.Exec(query, args)
// Begin starts and returns a new transaction.
func (conn *csvConn) Begin() (driver.Tx, error) {
return conn.BeginTx(context.Background(), driver.TxOptions{})
}

func (conn *csvConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
return conn.exec.ExecContext(ctx, query, args)
}

func (conn *csvConn) Query(query string, args []driver.Value) (driver.Rows, error) {
rows, err := conn.query.Query(query, args)
func (conn *csvConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
rows, err := conn.query.QueryContext(ctx, query, args)
if err != nil {
return nil, err
}
Expand Down
21 changes: 14 additions & 7 deletions csvutil/csvdriver/import.go
Expand Up @@ -5,6 +5,7 @@
package csvdriver

import (
"context"
"database/sql/driver"
"fmt"
"io"
Expand Down Expand Up @@ -36,12 +37,14 @@ func (conn *csvConn) importCSV() error {
}
defer tx.Commit()

_, err = conn.Exec("create table csv ("+schema.Decl()+")", nil)
ctx := context.Background()

_, err = conn.ExecContext(ctx, "create table csv ("+schema.Decl()+")", nil)
if err != nil {
return err
}

_, err = conn.Exec("create index csv_id on csv (id());", nil)
_, err = conn.ExecContext(ctx, "create index csv_id on csv (id());", nil)
if err != nil {
return err
}
Expand All @@ -65,9 +68,9 @@ func (conn *csvConn) importCSV() error {
return err
}
for i, arg := range pargs {
vargs[i] = reflect.ValueOf(arg).Elem().Interface()
vargs[i].Value = reflect.ValueOf(arg).Elem().Interface()
}
_, err = conn.Exec(insert, vargs)
_, err = conn.ExecContext(ctx, insert, vargs)
if err != nil {
return err
}
Expand Down Expand Up @@ -177,12 +180,16 @@ func (st *schemaType) Decl() string {
return strings.Join(o, ", ")
}

func (st *schemaType) Args() ([]driver.Value, []interface{}) {
vargs := make([]driver.Value, len(*st))
func (st *schemaType) Args() ([]driver.NamedValue, []interface{}) {
vargs := make([]driver.NamedValue, len(*st))
pargs := make([]interface{}, len(*st))
for i, v := range *st {
ptr := reflect.New(v.v.Type())
vargs[i] = ptr.Elem().Interface()
vargs[i] = driver.NamedValue{
Name: v.n,
Ordinal: i + 1,
Value: ptr.Elem().Interface(),
}
pargs[i] = ptr.Interface()
}
return vargs, pargs
Expand Down

0 comments on commit 54e0051

Please sign in to comment.