Skip to content

Commit

Permalink
Merge b4e9f90 into 240ea3c
Browse files Browse the repository at this point in the history
  • Loading branch information
keegancsmith committed Nov 15, 2017
2 parents 240ea3c + b4e9f90 commit 12a8f27
Showing 1 changed file with 82 additions and 1 deletion.
83 changes: 82 additions & 1 deletion sqlhooks.go
Expand Up @@ -3,6 +3,7 @@ package sqlhooks
import (
"context"
"database/sql/driver"
"errors"
)

// Hook is the hook callback signature
Expand All @@ -27,7 +28,13 @@ func (drv *Driver) Open(name string) (driver.Conn, error) {
return conn, err
}

return &Conn{conn, drv.hooks}, nil
wrapped := &Conn{conn, drv.hooks}
if isExecer(conn) {
// If conn implements an Execer interface, return a driver.Conn which
// also implements Execer
return &ExecerContext{wrapped}, nil
}
return wrapped, nil
}

// Conn implements a database/sql.driver.Conn
Expand Down Expand Up @@ -59,6 +66,68 @@ func (conn *Conn) Prepare(query string) (driver.Stmt, error) { return conn.Conn.
func (conn *Conn) Close() error { return conn.Conn.Close() }
func (conn *Conn) Begin() (driver.Tx, error) { return conn.Conn.Begin() }

// ExecerContext implements a database/sql.driver.ExecerContext
type ExecerContext struct {
*Conn
}

func isExecer(conn driver.Conn) bool {
switch conn.(type) {
case driver.ExecerContext:
return true
case driver.Execer:
return true
default:
return false
}
}

func (conn *ExecerContext) execContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
switch c := conn.Conn.Conn.(type) {
case driver.ExecerContext:
return c.ExecContext(ctx, query, args)
case driver.Execer:
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
return c.Exec(query, dargs)
default:
// This should not happen
return nil, errors.New("ExecerContext created for a non Execer driver.Conn")
}
}

func (conn *ExecerContext) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
var err error

list := namedToInterface(args)

// Exec `Before` Hooks
if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil {
return nil, err
}

results, err := conn.execContext(ctx, query, args)
if err != nil {
return results, err
}

if ctx, err = conn.hooks.After(ctx, query, list...); err != nil {
return nil, err
}

return results, err
}

func (conn *ExecerContext) Exec(query string, args []driver.Value) (driver.Result, error) {
// We have to implement Exec since it is required in the current version of
// Go for it to run ExecContext. From Go 10 it will be optional. However,
// this code should never run since database/sql always prefers to run
// ExecContext.
return nil, errors.New("Exec was called when ExecContext was implemented")
}

// Stmt implements a database/sql/driver.Stmt
type Stmt struct {
Stmt driver.Stmt
Expand Down Expand Up @@ -154,6 +223,18 @@ func namedToInterface(args []driver.NamedValue) []interface{} {
return list
}

// namedValueToValue copied from database/sql
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
dargs := make([]driver.Value, len(named))
for n, param := range named {
if len(param.Name) > 0 {
return nil, errors.New("sql: driver does not support the use of Named Parameters")
}
dargs[n] = param.Value
}
return dargs, nil
}

/*
type hooks struct {
}
Expand Down

0 comments on commit 12a8f27

Please sign in to comment.