Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ gen:

deps:
@go mod tidy
@cd tests && go mod tidy

lint:
@golangci-lint run
Expand Down
52 changes: 31 additions & 21 deletions interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,53 +10,63 @@ var (
_ driver.DriverContext = Interceptor{}
)

// TODO: document that database/sql falls back to Prepare if the driver returns ErrSkip for Exec/Query.

// Interceptor is a [driver.Driver] wrapper that allows to register callbacks for database queries.
// It must first be registered with [sql.Register] with the same name that is then passed to [sql.Open]:
// Interceptor is a [driver.Driver] wrapper that allows to register callbacks for SQL queries.
// The main use case is to instrument code with logs, metrics, and traces without introducing an [sql.DB] wrapper.
// An interceptor must first be registered with [sql.Register] using the same name that is then passed to [sql.Open]:
//
// interceptor := queries.Interceptor{...}
// sql.Register("interceptor", interceptor)
// db, err := sql.Open("interceptor", "dsn")
//
// Only the Driver field must be set; all callbacks are optional.
//
// Note that some drivers only partially implement [driver.ExecerContext] and [driver.QueryerContext].
// A driver may return [driver.ErrSkip], which [sql.DB] interprets as a signal to fall back to a prepared statement.
// For example, the [go-sql-driver/mysql] driver only executes a query within [sql.DB.ExecContext] or [sql.DB.QueryContext] if the query has no arguments.
// Otherwise, it prepares a [driver.Stmt] using [driver.ConnPrepareContext], executes it, and closes it.
// In such cases, you may want to implement both the PrepareContext and ExecContext/QueryContext callbacks,
// even if you don't prepare statements manually via [sql.DB.PrepareContext].
// TODO: provide an example of such an implementation.
//
// [go-sql-driver/mysql]: https://github.com/go-sql-driver/mysql
type Interceptor struct {
// Driver is a database driver.
// It must implement [driver.Pinger], [driver.ExecerContext], [driver.QueryerContext],
// [driver.ConnPrepareContext], and [driver.ConnBeginTx] (most drivers do).
// Required.
// Driver is an implementation of [driver.Driver].
// It must also implement [driver.Pinger], [driver.ConnPrepareContext], and [driver.ConnBeginTx].
Driver driver.Driver

// ExecContext is a callback for both [sql.DB.ExecContext] and [sql.Tx.ExecContext].
// ExecContext is a callback for [sql.DB.ExecContext] and [sql.Tx.ExecContext].
// The implementation must call execer.ExecerContext(ctx, query, args) and return the result.
// Optional.
// Note that if the driver does not implement [driver.ExecerContext], the callback will never be called.
// In this case, consider implementing the PrepareContext callback instead.
ExecContext func(ctx context.Context, query string, args []driver.NamedValue, execer driver.ExecerContext) (driver.Result, error)

// QueryContext is a callback for both [sql.DB.QueryContext] and [sql.Tx.QueryContext].
// QueryContext is a callback for [sql.DB.QueryContext] and [sql.Tx.QueryContext].
// The implementation must call queryer.QueryContext(ctx, query, args) and return the result.
// Optional.
// Note that if the driver does not implement [driver.QueryerContext], the callback will never be called.
// In this case, consider implementing the PrepareContext callback instead.
QueryContext func(ctx context.Context, query string, args []driver.NamedValue, queryer driver.QueryerContext) (driver.Rows, error)

// PrepareContext is a callback for [sql.DB.PrepareContext].
// PrepareContext is a callback for [sql.DB.PrepareContext] and [sql.Tx.PrepareContext].
// The implementation must call preparer.ConnPrepareContext(ctx, query) and return the result.
// Optional.
PrepareContext func(ctx context.Context, query string, preparer driver.ConnPrepareContext) (driver.Stmt, error)
}

// Open implements [driver.Driver].
func (i Interceptor) Open(name string) (driver.Conn, error) {
func (Interceptor) Open(string) (driver.Conn, error) {
panic("unreachable") // driver.DriverContext always takes precedence over driver.Driver.
}

// OpenConnector implements [driver.DriverContext].
func (i Interceptor) OpenConnector(name string) (driver.Connector, error) {
if d, ok := i.Driver.(driver.DriverContext); ok {
connector, err := d.OpenConnector(name)
c, err := d.OpenConnector(name)
if err != nil {
return nil, err
}
return wrappedConnector{connector, i}, nil
return wrappedConnector{c, i}, nil
}
connector := dsnConnector{name, i.Driver}
return wrappedConnector{connector, i}, nil
c := dsnConnector{name, i.Driver}
return wrappedConnector{c, i}, nil
}

var (
Expand Down Expand Up @@ -86,7 +96,7 @@ func (c wrappedConn) Ping(ctx context.Context) error {
func (c wrappedConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
execer, ok := c.Conn.(driver.ExecerContext)
if !ok {
panic("queries: driver does not implement driver.ExecerContext")
return nil, driver.ErrSkip
}
if c.interceptor.ExecContext != nil {
return c.interceptor.ExecContext(ctx, query, args, execer)
Expand All @@ -98,7 +108,7 @@ func (c wrappedConn) ExecContext(ctx context.Context, query string, args []drive
func (c wrappedConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
queryer, ok := c.Conn.(driver.QueryerContext)
if !ok {
panic("queries: driver does not implement driver.QueryerContext")
return nil, driver.ErrSkip
}
if c.interceptor.QueryContext != nil {
return c.interceptor.QueryContext(ctx, query, args, queryer)
Expand Down
4 changes: 2 additions & 2 deletions interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ func TestInterceptor_unimplemented(t *testing.T) {
assert.Panics[E](t, pingFn, "queries: driver does not implement driver.Pinger")

execFn := func() { _, _ = db.ExecContext(ctx, "") }
assert.Panics[E](t, execFn, "queries: driver does not implement driver.ExecerContext")
assert.Panics[E](t, execFn, "queries: driver does not implement driver.ConnPrepareContext")

queryFn := func() { _, _ = db.QueryContext(ctx, "") } //nolint:gocritic // sqlQuery: unused result is fine here.
assert.Panics[E](t, queryFn, "queries: driver does not implement driver.QueryerContext")
assert.Panics[E](t, queryFn, "queries: driver does not implement driver.ConnPrepareContext")

prepareFn := func() { _, _ = db.PrepareContext(ctx, "") }
assert.Panics[E](t, prepareFn, "queries: driver does not implement driver.ConnPrepareContext")
Expand Down
29 changes: 26 additions & 3 deletions tests/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,36 @@ import (
"modernc.org/sqlite"
)

// --------------------------------------------------------------------------------------
// | Interface / Driver | jackc/pgx | go-sql-driver/mysql | modernc.org/sqlite |
// |-----------------------------|-----------|---------------------|--------------------|
// | [driver.DriverContext] | + | + | - |
// | [driver.Pinger] | + | + | + |
// | [driver.ExecerContext] | + | + | + |
// | [driver.QueryerContext] | + | + | + |
// | [driver.ConnPrepareContext] | + | + | + |
// | [driver.ConnBeginTx] | + | + | + |
// | [driver.SessionResetter] | + | + | + |
// | [driver.Validator] | - | + | + |
// | [driver.NamedValueChecker] | + | + | - |
// --------------------------------------------------------------------------------------

var DBs = map[string]struct {
driver driver.Driver
dsn string
}{
"postgres": {pgx.GetDefaultDriver(), "postgres://postgres:postgres@localhost:5432/postgres"}, // https://github.com/jackc/pgx
"mysql": {new(mysql.MySQLDriver), "root:root@tcp(localhost:3306)/mysql?parseTime=true"}, // https://github.com/go-sql-driver/mysql
"sqlite": {new(sqlite.Driver), "test.sqlite"}, // https://gitlab.com/cznic/sqlite
"postgres": { // https://github.com/jackc/pgx
pgx.GetDefaultDriver(),
"postgres://postgres:postgres@localhost:5432/postgres",
},
"mysql": { // https://github.com/go-sql-driver/mysql
new(mysql.MySQLDriver),
"root:root@tcp(localhost:3306)/mysql?parseTime=true",
},
"sqlite": { // https://gitlab.com/cznic/sqlite
new(sqlite.Driver),
"test.sqlite",
},
}

type User struct {
Expand Down
Loading