diff --git a/Makefile b/Makefile index b0daed8..def111d 100644 --- a/Makefile +++ b/Makefile @@ -1,22 +1,33 @@ .POSIX: .SUFFIXES: -fmt: - @golangci-lint fmt +bench: + @go test -run='^$$' -bench=. -cpuprofile=profile.cpu -memprofile=profile.mem -gen: - @go generate ./... +clean: + @rm -rf tests/coverdata tests/coverage.out tests/test.sqlite deps: @go mod tidy @cd tests && go mod tidy +fmt: + @golangci-lint fmt + +gen: + @go generate ./... + lint: @golangci-lint run -test: - @rm -rf tests/coverdata tests/coverage.out tests/test.sqlite && mkdir tests/coverdata +test: test/unit test/integration + +test/unit: clean + @mkdir -p tests/coverdata @go test -race -shuffle=on -cover . -args -test.gocoverdir=$$PWD/tests/coverdata + +test/integration: clean + @mkdir -p tests/coverdata @$(CONTAINER_RUNNER) compose --file=tests/compose.yaml up --detach @go test -v -race -coverpkg=go-simpler.org/queries ./tests -args -test.gocoverdir=$$PWD/tests/coverdata @$(CONTAINER_RUNNER) compose --file=tests/compose.yaml down @@ -24,6 +35,3 @@ test: test/cover: test @go tool cover -html=tests/coverage.out - -bench: - @go test -run='^$$' -bench=. -cpuprofile=profile.cpu -memprofile=profile.mem diff --git a/interceptor.go b/interceptor.go index a5c1387..65c38ac 100644 --- a/interceptor.go +++ b/interceptor.go @@ -26,7 +26,6 @@ var ( // Otherwise, it prepares a [driver.Stmt] using [driver.ConnPrepareContext], executes it, and closes it. // In such cases, you may want to implement both the PrepareContext and ExecContext/QueryContext callbacks, // even if you don't prepare statements manually via [sql.DB.PrepareContext]. -// TODO: provide an example of such an implementation. // // [go-sql-driver/mysql]: https://github.com/go-sql-driver/mysql type Interceptor struct { @@ -49,6 +48,10 @@ type Interceptor struct { // PrepareContext is a callback for [sql.DB.PrepareContext] and [sql.Tx.PrepareContext]. // The implementation must call preparer.ConnPrepareContext(ctx, query) and return the result. PrepareContext func(ctx context.Context, query string, preparer driver.ConnPrepareContext) (driver.Stmt, error) + + // BeginTx is a callback for [sql.DB.BeginTx]. + // The implementation must call beginner.BeginTx(ctx, opts) and return the result. + BeginTx func(ctx context.Context, opts driver.TxOptions, beginner driver.ConnBeginTx) (driver.Tx, error) } // Open implements [driver.Driver]. @@ -134,6 +137,9 @@ func (c wrappedConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver if !ok { panic("queries: driver does not implement driver.ConnBeginTx") } + if c.interceptor.BeginTx != nil { + return c.interceptor.BeginTx(ctx, opts, beginner) + } return beginner.BeginTx(ctx, opts) } diff --git a/interceptor_test.go b/interceptor_test.go index b456c82..b1f9179 100644 --- a/interceptor_test.go +++ b/interceptor_test.go @@ -18,6 +18,7 @@ func TestInterceptor(t *testing.T) { var execCalled bool var queryCalled bool var prepareCalled bool + var beginTxCalled bool interceptor := queries.Interceptor{ Driver: mockDriver{conn: spyConn{}}, @@ -33,9 +34,13 @@ func TestInterceptor(t *testing.T) { prepareCalled = true return preparer.PrepareContext(ctx, query) }, + BeginTx: func(ctx context.Context, opts driver.TxOptions, beginner driver.ConnBeginTx) (driver.Tx, error) { + beginTxCalled = true + return beginner.BeginTx(ctx, opts) + }, } - driverName := t.Name() + "_interceptor" + driverName := t.Name() sql.Register(driverName, interceptor) db, err := sql.Open(driverName, "") @@ -53,6 +58,10 @@ func TestInterceptor(t *testing.T) { _, err = db.PrepareContext(ctx, "") assert.IsErr[E](t, err, errCalled) assert.Equal[E](t, prepareCalled, true) + + _, err = db.BeginTx(ctx, nil) + assert.IsErr[E](t, err, errCalled) + assert.Equal[E](t, beginTxCalled, true) } func TestInterceptor_passthrough(t *testing.T) { @@ -62,7 +71,7 @@ func TestInterceptor_passthrough(t *testing.T) { Driver: mockDriver{conn: spyConn{}}, } - driverName := t.Name() + "_interceptor" + driverName := t.Name() sql.Register(driverName, interceptor) db, err := sql.Open(driverName, "") @@ -77,6 +86,9 @@ func TestInterceptor_passthrough(t *testing.T) { _, err = db.PrepareContext(ctx, "") assert.IsErr[E](t, err, errCalled) + + _, err = db.BeginTx(ctx, nil) + assert.IsErr[E](t, err, errCalled) } func TestInterceptor_unimplemented(t *testing.T) { @@ -86,7 +98,7 @@ func TestInterceptor_unimplemented(t *testing.T) { Driver: mockDriver{conn: unimplementedConn{}}, } - driverName := t.Name() + "_interceptor" + driverName := t.Name() sql.Register(driverName, interceptor) db, err := sql.Open(driverName, "") @@ -113,7 +125,7 @@ func TestInterceptor_driver(t *testing.T) { mdriver := mockDriver{} interceptor := queries.Interceptor{Driver: mdriver} - driverName := t.Name() + "_interceptor" + driverName := t.Name() sql.Register(driverName, interceptor) db, err := sql.Open(driverName, "") @@ -148,3 +160,7 @@ func (spyConn) QueryContext(context.Context, string, []driver.NamedValue) (drive func (spyConn) PrepareContext(context.Context, string) (driver.Stmt, error) { return nil, errCalled } + +func (spyConn) BeginTx(context.Context, driver.TxOptions) (driver.Tx, error) { + return nil, errCalled +} diff --git a/tests/integration_test.go b/tests/integration_test.go index c01b164..f2cc812 100644 --- a/tests/integration_test.go +++ b/tests/integration_test.go @@ -99,6 +99,7 @@ func TestIntegration(t *testing.T) { var execCalls int var queryCalls int var prepareCalls int + var beginTxCalls int interceptor := queries.Interceptor{ Driver: driverIface, @@ -117,6 +118,11 @@ func TestIntegration(t *testing.T) { t.Logf("PrepareContext: %s", query) return preparer.PrepareContext(ctx, query) }, + BeginTx: func(ctx context.Context, opts driver.TxOptions, beginner driver.ConnBeginTx) (driver.Tx, error) { + beginTxCalls++ + t.Log("BeginTx") + return beginner.BeginTx(ctx, opts) + }, } driverName += "+interceptor" @@ -187,14 +193,17 @@ func TestIntegration(t *testing.T) { assert.Equal[E](t, execCalls, 3) assert.Equal[E](t, queryCalls, 5*2) assert.Equal[E](t, prepareCalls, 1) + assert.Equal[E](t, beginTxCalls, 1) case *mssqldb.Driver: // always uses PrepareContext. assert.Equal[E](t, execCalls, 0) assert.Equal[E](t, queryCalls, 0) assert.Equal[E](t, prepareCalls, 3+5*2) + assert.Equal[E](t, beginTxCalls, 1) default: assert.Equal[E](t, execCalls, 3) assert.Equal[E](t, queryCalls, 5*2) assert.Equal[E](t, prepareCalls, 0) + assert.Equal[E](t, beginTxCalls, 1) } }) }