Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,29 +1,37 @@
.POSIX:
.SUFFIXES:

fmt:
@golangci-lint fmt
bench:
@go test -run='^$$' -bench=. -cpuprofile=profile.cpu -memprofile=profile.mem

gen:
@go generate ./...
clean:
@rm -rf tests/coverdata tests/coverage.out tests/test.sqlite

deps:
@go mod tidy
@cd tests && go mod tidy

fmt:
@golangci-lint fmt

gen:
@go generate ./...

lint:
@golangci-lint run

test:
@rm -rf tests/coverdata tests/coverage.out tests/test.sqlite && mkdir tests/coverdata
test: test/unit test/integration

test/unit: clean
@mkdir -p tests/coverdata
@go test -race -shuffle=on -cover . -args -test.gocoverdir=$$PWD/tests/coverdata

test/integration: clean
@mkdir -p tests/coverdata
@$(CONTAINER_RUNNER) compose --file=tests/compose.yaml up --detach
@go test -v -race -coverpkg=go-simpler.org/queries ./tests -args -test.gocoverdir=$$PWD/tests/coverdata
@$(CONTAINER_RUNNER) compose --file=tests/compose.yaml down
@go tool covdata textfmt -i=tests/coverdata -o=tests/coverage.out

test/cover: test
@go tool cover -html=tests/coverage.out

bench:
@go test -run='^$$' -bench=. -cpuprofile=profile.cpu -memprofile=profile.mem
8 changes: 7 additions & 1 deletion interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ var (
// Otherwise, it prepares a [driver.Stmt] using [driver.ConnPrepareContext], executes it, and closes it.
// In such cases, you may want to implement both the PrepareContext and ExecContext/QueryContext callbacks,
// even if you don't prepare statements manually via [sql.DB.PrepareContext].
// TODO: provide an example of such an implementation.
//
// [go-sql-driver/mysql]: https://github.com/go-sql-driver/mysql
type Interceptor struct {
Expand All @@ -49,6 +48,10 @@ type Interceptor struct {
// PrepareContext is a callback for [sql.DB.PrepareContext] and [sql.Tx.PrepareContext].
// The implementation must call preparer.ConnPrepareContext(ctx, query) and return the result.
PrepareContext func(ctx context.Context, query string, preparer driver.ConnPrepareContext) (driver.Stmt, error)

// BeginTx is a callback for [sql.DB.BeginTx].
// The implementation must call beginner.BeginTx(ctx, opts) and return the result.
BeginTx func(ctx context.Context, opts driver.TxOptions, beginner driver.ConnBeginTx) (driver.Tx, error)
}

// Open implements [driver.Driver].
Expand Down Expand Up @@ -134,6 +137,9 @@ func (c wrappedConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver
if !ok {
panic("queries: driver does not implement driver.ConnBeginTx")
}
if c.interceptor.BeginTx != nil {
return c.interceptor.BeginTx(ctx, opts, beginner)
}
return beginner.BeginTx(ctx, opts)
}

Expand Down
24 changes: 20 additions & 4 deletions interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ func TestInterceptor(t *testing.T) {
var execCalled bool
var queryCalled bool
var prepareCalled bool
var beginTxCalled bool

interceptor := queries.Interceptor{
Driver: mockDriver{conn: spyConn{}},
Expand All @@ -33,9 +34,13 @@ func TestInterceptor(t *testing.T) {
prepareCalled = true
return preparer.PrepareContext(ctx, query)
},
BeginTx: func(ctx context.Context, opts driver.TxOptions, beginner driver.ConnBeginTx) (driver.Tx, error) {
beginTxCalled = true
return beginner.BeginTx(ctx, opts)
},
}

driverName := t.Name() + "_interceptor"
driverName := t.Name()
sql.Register(driverName, interceptor)

db, err := sql.Open(driverName, "")
Expand All @@ -53,6 +58,10 @@ func TestInterceptor(t *testing.T) {
_, err = db.PrepareContext(ctx, "")
assert.IsErr[E](t, err, errCalled)
assert.Equal[E](t, prepareCalled, true)

_, err = db.BeginTx(ctx, nil)
assert.IsErr[E](t, err, errCalled)
assert.Equal[E](t, beginTxCalled, true)
}

func TestInterceptor_passthrough(t *testing.T) {
Expand All @@ -62,7 +71,7 @@ func TestInterceptor_passthrough(t *testing.T) {
Driver: mockDriver{conn: spyConn{}},
}

driverName := t.Name() + "_interceptor"
driverName := t.Name()
sql.Register(driverName, interceptor)

db, err := sql.Open(driverName, "")
Expand All @@ -77,6 +86,9 @@ func TestInterceptor_passthrough(t *testing.T) {

_, err = db.PrepareContext(ctx, "")
assert.IsErr[E](t, err, errCalled)

_, err = db.BeginTx(ctx, nil)
assert.IsErr[E](t, err, errCalled)
}

func TestInterceptor_unimplemented(t *testing.T) {
Expand All @@ -86,7 +98,7 @@ func TestInterceptor_unimplemented(t *testing.T) {
Driver: mockDriver{conn: unimplementedConn{}},
}

driverName := t.Name() + "_interceptor"
driverName := t.Name()
sql.Register(driverName, interceptor)

db, err := sql.Open(driverName, "")
Expand All @@ -113,7 +125,7 @@ func TestInterceptor_driver(t *testing.T) {
mdriver := mockDriver{}
interceptor := queries.Interceptor{Driver: mdriver}

driverName := t.Name() + "_interceptor"
driverName := t.Name()
sql.Register(driverName, interceptor)

db, err := sql.Open(driverName, "")
Expand Down Expand Up @@ -148,3 +160,7 @@ func (spyConn) QueryContext(context.Context, string, []driver.NamedValue) (drive
func (spyConn) PrepareContext(context.Context, string) (driver.Stmt, error) {
return nil, errCalled
}

func (spyConn) BeginTx(context.Context, driver.TxOptions) (driver.Tx, error) {
return nil, errCalled
}
9 changes: 9 additions & 0 deletions tests/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ func TestIntegration(t *testing.T) {
var execCalls int
var queryCalls int
var prepareCalls int
var beginTxCalls int

interceptor := queries.Interceptor{
Driver: driverIface,
Expand All @@ -117,6 +118,11 @@ func TestIntegration(t *testing.T) {
t.Logf("PrepareContext: %s", query)
return preparer.PrepareContext(ctx, query)
},
BeginTx: func(ctx context.Context, opts driver.TxOptions, beginner driver.ConnBeginTx) (driver.Tx, error) {
beginTxCalls++
t.Log("BeginTx")
return beginner.BeginTx(ctx, opts)
},
}

driverName += "+interceptor"
Expand Down Expand Up @@ -187,14 +193,17 @@ func TestIntegration(t *testing.T) {
assert.Equal[E](t, execCalls, 3)
assert.Equal[E](t, queryCalls, 5*2)
assert.Equal[E](t, prepareCalls, 1)
assert.Equal[E](t, beginTxCalls, 1)
case *mssqldb.Driver: // always uses PrepareContext.
assert.Equal[E](t, execCalls, 0)
assert.Equal[E](t, queryCalls, 0)
assert.Equal[E](t, prepareCalls, 3+5*2)
assert.Equal[E](t, beginTxCalls, 1)
default:
assert.Equal[E](t, execCalls, 3)
assert.Equal[E](t, queryCalls, 5*2)
assert.Equal[E](t, prepareCalls, 0)
assert.Equal[E](t, beginTxCalls, 1)
}
})
}
Expand Down
Loading