From 54e0051ac9f0ac1129e4f63bacfc569d12ec963d Mon Sep 17 00:00:00 2001 From: Sebastien Binet Date: Fri, 22 May 2020 10:40:11 +0200 Subject: [PATCH] csvutil/csvdriver: implement new database/sql/driver.XyzContext Updates go-hep/hep#720. --- csvutil/csvdriver/driver.go | 78 +++++++++++++++++++++++++++---------- csvutil/csvdriver/import.go | 21 ++++++---- 2 files changed, 72 insertions(+), 27 deletions(-) diff --git a/csvutil/csvdriver/driver.go b/csvutil/csvdriver/driver.go index cc15f5955..668ef1504 100644 --- a/csvutil/csvdriver/driver.go +++ b/csvutil/csvdriver/driver.go @@ -6,6 +6,7 @@ package csvdriver // import "go-hep.org/x/hep/csvutil/csvdriver" import ( + "context" "database/sql" "database/sql/driver" "encoding/json" @@ -21,13 +22,19 @@ import ( ) var ( - _ driver.Driver = (*csvDriver)(nil) - _ driver.Conn = (*csvConn)(nil) - _ driver.Execer = (*csvConn)(nil) - _ driver.Queryer = (*csvConn)(nil) - _ driver.Tx = (*csvConn)(nil) + _ driver.Driver = (*csvDriver)(nil) + _ drvConn = (*csvConn)(nil) + _ driver.ExecerContext = (*csvConn)(nil) + _ driver.QueryerContext = (*csvConn)(nil) + _ driver.Tx = (*csvConn)(nil) ) +type drvConn interface { + driver.Conn + driver.ConnBeginTx + driver.ConnPrepareContext +} + // Conn describes how a connection to the CSV-driver should be established. type Conn struct { File string `json:"file"` // name of the file to be open @@ -50,7 +57,6 @@ func (c *Conn) setDefaults() { if c.Comment == 0 { c.Comment = '#' } - return } func (c Conn) toJSON() (string, error) { @@ -194,9 +200,9 @@ type csvConn struct { drv *csvDriver refs int - conn driver.Conn - exec driver.Execer - query driver.Queryer + conn drvConn + exec driver.ExecerContext + query driver.QueryerContext tx driver.Tx } @@ -206,15 +212,42 @@ func (conn *csvConn) initDB() error { return err } - conn.conn = c - conn.exec = c.(driver.Execer) - conn.query = c.(driver.Queryer) + conn.conn = connWrap(c) + conn.exec = c.(driver.ExecerContext) + conn.query = c.(driver.QueryerContext) return nil } +func connWrap(c driver.Conn) drvConn { + if c, ok := c.(drvConn); ok { + return c + } + + wrap := &drvConnWrap{ + Conn: c, + ConnPrepareContext: c.(driver.ConnPrepareContext), + } + return wrap +} + +type drvConnWrap struct { + driver.Conn + driver.ConnPrepareContext +} + +func (drv *drvConnWrap) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + // FIXME(sbinet): drop this hack when modernc.org/ql implements driver.ConnBeginTx. + return drv.Conn.Begin() //lint:ignore SA1019 drop this hack when modernc.org/ql supports driver.ConnBeginTx +} + +// PrepareContext returns a prepared statement, bound to this connection. +func (conn *csvConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + return conn.conn.PrepareContext(ctx, query) +} + // Prepare returns a prepared statement, bound to this connection. func (conn *csvConn) Prepare(query string) (driver.Stmt, error) { - return conn.conn.Prepare(query) + return conn.conn.PrepareContext(context.Background(), query) } // Close invalidates and potentially stops any current @@ -256,9 +289,9 @@ func (conn *csvConn) Close() error { return err } -// Begin starts and returns a new transaction. -func (conn *csvConn) Begin() (driver.Tx, error) { - tx, err := conn.conn.Begin() +// BeginTx starts and returns a new transaction. +func (conn *csvConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + tx, err := conn.conn.BeginTx(ctx, opts) if err != nil { return nil, err } @@ -266,12 +299,17 @@ func (conn *csvConn) Begin() (driver.Tx, error) { return tx, err } -func (conn *csvConn) Exec(query string, args []driver.Value) (driver.Result, error) { - return conn.exec.Exec(query, args) +// Begin starts and returns a new transaction. +func (conn *csvConn) Begin() (driver.Tx, error) { + return conn.BeginTx(context.Background(), driver.TxOptions{}) +} + +func (conn *csvConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + return conn.exec.ExecContext(ctx, query, args) } -func (conn *csvConn) Query(query string, args []driver.Value) (driver.Rows, error) { - rows, err := conn.query.Query(query, args) +func (conn *csvConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + rows, err := conn.query.QueryContext(ctx, query, args) if err != nil { return nil, err } diff --git a/csvutil/csvdriver/import.go b/csvutil/csvdriver/import.go index f1452cd79..dfefe4cdb 100644 --- a/csvutil/csvdriver/import.go +++ b/csvutil/csvdriver/import.go @@ -5,6 +5,7 @@ package csvdriver import ( + "context" "database/sql/driver" "fmt" "io" @@ -36,12 +37,14 @@ func (conn *csvConn) importCSV() error { } defer tx.Commit() - _, err = conn.Exec("create table csv ("+schema.Decl()+")", nil) + ctx := context.Background() + + _, err = conn.ExecContext(ctx, "create table csv ("+schema.Decl()+")", nil) if err != nil { return err } - _, err = conn.Exec("create index csv_id on csv (id());", nil) + _, err = conn.ExecContext(ctx, "create index csv_id on csv (id());", nil) if err != nil { return err } @@ -65,9 +68,9 @@ func (conn *csvConn) importCSV() error { return err } for i, arg := range pargs { - vargs[i] = reflect.ValueOf(arg).Elem().Interface() + vargs[i].Value = reflect.ValueOf(arg).Elem().Interface() } - _, err = conn.Exec(insert, vargs) + _, err = conn.ExecContext(ctx, insert, vargs) if err != nil { return err } @@ -177,12 +180,16 @@ func (st *schemaType) Decl() string { return strings.Join(o, ", ") } -func (st *schemaType) Args() ([]driver.Value, []interface{}) { - vargs := make([]driver.Value, len(*st)) +func (st *schemaType) Args() ([]driver.NamedValue, []interface{}) { + vargs := make([]driver.NamedValue, len(*st)) pargs := make([]interface{}, len(*st)) for i, v := range *st { ptr := reflect.New(v.v.Type()) - vargs[i] = ptr.Elem().Interface() + vargs[i] = driver.NamedValue{ + Name: v.n, + Ordinal: i + 1, + Value: ptr.Elem().Interface(), + } pargs[i] = ptr.Interface() } return vargs, pargs