From 5ae41870250b9dba3fb637c42a4bd357bf98daca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 23 Oct 2025 17:57:30 +0200 Subject: [PATCH 1/2] feat: support statement_timeout and transaction_timeout property Add a statement_timeout connection property that is used as the default timeout for the execution of all statements that are executed on a connection. The timeout is only used for the actual execution, and not attached to the iterator that is returned for a query. This also means that a query that is executed without the DirectExecuteQuery option, will ignore the statement_timeout value. Also adds a transaction_timeout property that is additionally used for all statements in a read/write transaction. The deadline of the transaction is calculated at the start of the transaction, and all statements in the transaction get this deadline, unless the statement already has an earlier deadline from for example a statement_timeout or a context deadline. This change also fixes some issues with deadlines when using the gRPC API of SpannerLib. The context that is used for an RPC invocation is cancelled after the RPC has finished. This context should therefore not be used as the context for any query execution, as the context is attached to the row iterator, and would cancel the query execution halfway. Fixes #574 Fixes #575 --- conn.go | 194 ++++++++++++++++++++++++-- connection_leak_test.go | 71 +++++++++- connection_properties.go | 20 +++ driver.go | 9 +- driver_with_mockserver_test.go | 8 +- go.mod | 1 + merged_row_iterator.go | 2 +- partitioned_query.go | 4 +- rows.go | 19 ++- rows_test.go | 4 +- spannerlib/api/connection.go | 23 ++- spannerlib/grpc-server/server.go | 31 +--- spannerlib/grpc-server/server_test.go | 175 +++++++++++++++++++++++ statements.go | 6 +- stmt_with_mockserver_test.go | 90 ++++++++++++ transaction.go | 26 +++- transaction_test.go | 83 +++++++++++ 17 files changed, 699 insertions(+), 67 deletions(-) diff --git a/conn.go b/conn.go index 278cff0c..da89689d 100644 --- a/conn.go +++ b/conn.go @@ -19,6 +19,7 @@ import ( "database/sql" "database/sql/driver" "errors" + "fmt" "log/slog" "slices" "sync" @@ -831,6 +832,69 @@ func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, err return &stmt{conn: c, query: parsedSQL, numArgs: len(args), execOptions: execOptions}, nil } +// Adds any statement or transaction timeout to the given context. The deadline of the returned +// context will be the earliest of: +// 1. Any existing deadline on the input context. +// 2. Any existing transaction deadline. +// 3. A deadline calculated from the current time + the value of statement_timeout. +func (c *conn) addStatementAndTransactionTimeout(ctx context.Context) (context.Context, context.CancelFunc, error) { + var statementDeadline time.Time + var transactionDeadline time.Time + var deadline time.Time + var hasStatementDeadline bool + var hasTransactionDeadline bool + + // Check if the connection has a value for statement_timeout. + statementTimeout := propertyStatementTimeout.GetValueOrDefault(c.state) + if statementTimeout != time.Duration(0) { + hasStatementDeadline = true + statementDeadline = time.Now().Add(statementTimeout) + } + // Check if the current transaction has a deadline. + transactionDeadline, hasTransactionDeadline, err := c.transactionDeadline() + if err != nil { + return nil, nil, err + } + + // If there is no statement_timeout and no current transaction deadline, + // then can just use the input context as-is. + if !hasStatementDeadline && !hasTransactionDeadline { + return ctx, func() {}, nil + } + + // If there is both a transaction and a statement deadline, then we use the earliest + // of those two. + if hasTransactionDeadline && hasStatementDeadline { + if statementDeadline.Before(transactionDeadline) { + deadline = statementDeadline + } else { + deadline = transactionDeadline + } + } else if hasStatementDeadline { + deadline = statementDeadline + } else { + deadline = transactionDeadline + } + // context.WithDeadline automatically selects the earliest deadline of + // the existing deadline on the context and the given deadline. + newCtx, cancel := context.WithDeadline(ctx, deadline) + return newCtx, cancel, nil +} + +// transactionDeadline returns the deadline of the current transaction +// on the connection. This also activates the transaction if it is not +// yet activated. +func (c *conn) transactionDeadline() (time.Time, bool, error) { + if c.tx == nil { + return time.Time{}, false, nil + } + if err := c.tx.ensureActivated(); err != nil { + return time.Time{}, false, err + } + deadline, hasDeadline := c.tx.deadline() + return deadline, hasDeadline, nil +} + func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { // Execute client side statement if it is one. clientStmt, err := c.parser.ParseClientSideStatement(query) @@ -849,13 +913,22 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam return c.queryContext(ctx, query, execOptions, args) } -func (c *conn) queryContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Rows, error) { +func (c *conn) queryContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (returnedRows driver.Rows, returnedErr error) { + ctx, cancelCause := context.WithCancelCause(ctx) + cancel := func() { + cancelCause(nil) + } + defer func() { + if returnedErr != nil { + cancel() + } + }() // Clear the commit timestamp of this connection before we execute the query. c.clearCommitResponse() // Check if the execution options contains an instruction to execute // a specific partition of a PartitionedQuery. if pq := execOptions.PartitionedQueryOptions.ExecutePartition.PartitionedQuery; pq != nil { - return pq.execute(ctx, execOptions.PartitionedQueryOptions.ExecutePartition.Index) + return pq.execute(ctx, cancel, execOptions.PartitionedQueryOptions.ExecutePartition.Index) } stmt, err := prepareSpannerStmt(c.parser, query, args) @@ -869,7 +942,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec if err != nil { return nil, err } - return createDriverResultRows(res, execOptions), nil + return createDriverResultRows(res, cancel, execOptions), nil } var iter rowIterator if c.tx == nil { @@ -884,7 +957,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec } else if execOptions.PartitionedQueryOptions.PartitionQuery { return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "PartitionQuery is only supported in batch read-only transactions")) } else if execOptions.PartitionedQueryOptions.AutoPartitionQuery { - return c.executeAutoPartitionedQuery(ctx, query, execOptions, args) + return c.executeAutoPartitionedQuery(ctx, cancel, query, execOptions, args) } else { // The statement was either detected as being a query, or potentially not recognized at all. // In that case, just default to using a single-use read-only transaction and let Spanner @@ -893,6 +966,9 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec } } else { if execOptions.PartitionedQueryOptions.PartitionQuery { + // The driver.Rows instance that is returned for partitionQuery does not + // contain a context, and therefore also does not cancel the context when it is closed. + defer cancel() return c.tx.partitionQuery(ctx, stmt, execOptions) } iter, err = c.tx.Query(ctx, stmt, statementInfo.StatementType, execOptions) @@ -900,18 +976,65 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec return nil, err } } - res := createRows(iter, execOptions) + res := createRows(iter, cancel, execOptions) if execOptions.DirectExecuteQuery { - // This call to res.getColumns() triggers the execution of the statement, as it needs to fetch the metadata. - res.getColumns() - if res.dirtyErr != nil && !errors.Is(res.dirtyErr, iterator.Done) { - _ = res.Close() - return nil, res.dirtyErr + if err := c.directExecuteQuery(ctx, cancelCause, res, execOptions); err != nil { + return nil, err } } return res, nil } +// directExecuteQuery blocks until the first PartialResultSet has been returned by Spanner. Any statement_timeout and/or +// transaction_timeout is used while waiting for the first result to be returned. +func (c *conn) directExecuteQuery(ctx context.Context, cancelQuery context.CancelCauseFunc, res *rows, execOptions *ExecOptions) error { + statementCtx := ctx + if execOptions.DirectExecuteContext != nil { + statementCtx = execOptions.DirectExecuteContext + } + // Add the statement or transaction deadline to the context. + statementCtx, cancelStatement, err := c.addStatementAndTransactionTimeout(statementCtx) + if err != nil { + return err + } + defer cancelStatement() + + // Asynchronously fetch the first partial result set from Spanner. + done := make(chan struct{}) + go func() { + // Calling res.getColumns() ensures that the first PartialResultSet has been returned, as it contains the + // metadata of the query. + defer close(done) + res.getColumns() + }() + // Wait until either the done channel is closed or the context is done. + var statementErr error + select { + case <-statementCtx.Done(): + statementErr = statementCtx.Err() + // Cancel the query execution. + cancelQuery(statementCtx.Err()) + case <-done: + } + + // Now wait until done channel is closed. This could be because the execution finished + // successfully, or because the context was cancelled, which again causes the execution + // to (eventually) fail. + <-done + if res.dirtyErr != nil && !errors.Is(res.dirtyErr, iterator.Done) { + _ = res.Close() + if statementErr != nil { + // Create a status error from the statement error and wrap both the Spanner error and the status error into + // one error. This will preserve the DeadlineExceeded error code from statementErr, and include the request + // ID from the Spanner error. + s := status.FromContextError(statementErr) + return fmt.Errorf("%w: %w", s.Err(), res.dirtyErr) + } + return res.dirtyErr + } + return nil +} + func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { // Execute client side statement if it is one. stmt, err := c.parser.ParseClientSideStatement(query) @@ -929,7 +1052,13 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name return c.execContext(ctx, query, execOptions, args) } -func (c *conn) execContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Result, error) { +func (c *conn) execContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (returnedResult driver.Result, returnedErr error) { + // Add the statement/transaction deadline to the context. + ctx, cancel, err := c.addStatementAndTransactionTimeout(ctx) + if err != nil { + return nil, err + } + defer cancel() // Clear the commit timestamp of this connection before we execute the statement. c.clearCommitResponse() @@ -1041,6 +1170,18 @@ func (c *conn) resetTransactionForRetry(ctx context.Context, errDuringCommit boo return noTransaction() } c.tx = c.prevTx + // If the aborted error happened during the Commit, then the transaction + // context has been cancelled, and we need to create a new one. + if rwTx, ok := c.tx.contextTransaction.(*readWriteTransaction); ok { + newCtx, cancel := c.addTransactionTimeout(c.tx.ctx) + rwTx.ctx = newCtx + // Make sure that we cancel the new context when the transaction is closed. + origClose := rwTx.close + rwTx.close = func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) { + origClose(result, commitResponse, commitErr) + cancel() + } + } c.resetForRetry = true } else if c.tx == nil { return noTransaction() @@ -1248,6 +1389,17 @@ func (c *conn) beginTx(ctx context.Context, driverOpts driver.TxOptions, closeFu return c.tx, nil } +// addTransactionTimeout creates a new derived context with the current transaction_timeout. +func (c *conn) addTransactionTimeout(ctx context.Context) (context.Context, context.CancelFunc) { + timeout := propertyTransactionTimeout.GetValueOrDefault(c.state) + if timeout == time.Duration(0) { + return ctx, func() {} + } + // Note that this will set the actual deadline to the earliest of the existing deadline on ctx and the calculated + // deadline based on the timeout. + return context.WithTimeout(ctx, timeout) +} + func (c *conn) activateTransaction() (contextTransaction, error) { closeFunc := c.tx.close if propertyTransactionReadOnly.GetValueOrDefault(c.state) { @@ -1283,7 +1435,10 @@ func (c *conn) activateTransaction() (contextTransaction, error) { opts := spanner.TransactionOptions{} opts.BeginTransactionOption = c.convertDefaultBeginTransactionOption(propertyBeginTransactionOption.GetValueOrDefault(c.state)) - tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(c.tx.ctx, c.client, opts, func() spanner.TransactionOptions { + // Add the current value of transaction_timeout to the context that is registered + // on the transaction. + ctx, cancel := c.addTransactionTimeout(c.tx.ctx) + tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(ctx, c.client, opts, func() spanner.TransactionOptions { defer func() { // Reset the transaction_tag after starting the transaction. _ = propertyTransactionTag.ResetValue(c.state, connectionstate.ContextUser) @@ -1291,11 +1446,12 @@ func (c *conn) activateTransaction() (contextTransaction, error) { return c.effectiveTransactionOptions(spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, c.options( /*reset=*/ true)) }) if err != nil { + cancel() return nil, err } logger := c.logger.With("tx", "rw") return &readWriteTransaction{ - ctx: c.tx.ctx, + ctx: ctx, conn: c, logger: logger, rwTx: tx, @@ -1307,6 +1463,7 @@ func (c *conn) activateTransaction() (contextTransaction, error) { } else { closeFunc(txResultRollback) } + cancel() }, retryAborts: sync.OnceValue(func() bool { return c.RetryAbortsInternally() @@ -1371,7 +1528,15 @@ func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner. return c.Single().WithTimestampBound(tb).QueryWithOptions(ctx, statement, options.QueryOptions) } -func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Rows, error) { +func (c *conn) executeAutoPartitionedQuery(ctx context.Context, cancel context.CancelFunc, query string, execOptions *ExecOptions, args []driver.NamedValue) (returnedRows driver.Rows, returnedErr error) { + // The cancel() function is called by the returned Rows object when it is closed. + // However, if an error is returned instead of a Rows instance, we need to cancel + // the context when we return from this function. + defer func() { + if returnedErr != nil { + cancel() + } + }() tx, err := c.BeginTx(ctx, driver.TxOptions{ReadOnly: true, Isolation: withBatchReadOnly(driver.IsolationLevel(sql.LevelDefault))}) if err != nil { return nil, err @@ -1383,6 +1548,7 @@ func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, ex } if rows, ok := r.(*rows); ok { rows.close = func() error { + defer cancel() return tx.Commit() } } diff --git a/connection_leak_test.go b/connection_leak_test.go index b4a07416..544fff58 100644 --- a/connection_leak_test.go +++ b/connection_leak_test.go @@ -24,14 +24,15 @@ import ( "cloud.google.com/go/spanner" "github.com/googleapis/go-sql-spanner/testutil" + "go.uber.org/goleak" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" ) func TestNoLeak(t *testing.T) { - t.Parallel() + // Not parallel, as it checks for leaked goroutines. - db, server, teardown := setupTestDBConnection(t) + db, server, teardown := setupTestDBConnectionWithParams(t, "statement_timeout=10s;transaction_timeout=20s") defer teardown() // Set MaxOpenConns to 1 to force an error if anything leaks a connection. db.SetMaxOpenConns(1) @@ -39,7 +40,7 @@ func TestNoLeak(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() - for i := 0; i < 2; i++ { + runTests := func() { pingContext(ctx, t, db) pingFailed(ctx, t, server, db) simpleQuery(ctx, t, db) @@ -50,8 +51,28 @@ func TestNoLeak(t *testing.T) { readOnlyTxWithStaleness(ctx, t, db) simpleReadWriteTx(ctx, t, db) runTransactionRetry(ctx, t, server, db) + runTransactionRetryAbortedHalfway(ctx, t, server, db) readOnlyTxWithOptions(ctx, t, db) } + + for i := 0; i < 2; i++ { + runTests() + } + ignoreCurrent := goleak.IgnoreCurrent() + + for i := 0; i < 10; i++ { + runTests() + } + goleak.VerifyNone(t, ignoreCurrent, + goleak.IgnoreTopFunction("cloud.google.com/go/spanner.(*healthChecker).worker"), + goleak.IgnoreTopFunction("cloud.google.com/go/spanner.(*healthChecker).multiplexSessionWorker"), + goleak.IgnoreTopFunction("cloud.google.com/go/spanner.(*healthChecker).maintainer"), + goleak.IgnoreTopFunction("google.golang.org/grpc/internal/transport.(*controlBuffer).get"), + goleak.IgnoreTopFunction("google.golang.org/grpc/internal/transport.(*http2Server).keepalive"), + goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"), + goleak.IgnoreTopFunction("google.golang.org/grpc/internal/grpcsync.(*CallbackSerializer).run"), + goleak.IgnoreTopFunction("cloud.google.com/go/spanner.(*sessionPool).createMultiplexedSession"), + ) } func pingContext(ctx context.Context, t *testing.T, db *sql.DB) { @@ -308,6 +329,50 @@ func runTransactionRetry(ctx context.Context, t *testing.T, server *testutil.Moc } } +func runTransactionRetryAbortedHalfway(ctx context.Context, t *testing.T, server *testutil.MockedSpannerInMemTestServer, db *sql.DB) { + var attempts int + err := RunTransactionWithOptions(ctx, db, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error { + attempts++ + rows, err := tx.QueryContext(ctx, testutil.SelectFooFromBar, ExecOptions{QueryOptions: spanner.QueryOptions{RequestTag: "tag_1"}}) + if err != nil { + t.Fatal(err) + } + for rows.Next() { + } + if err := rows.Close(); err != nil { + t.Fatal(err) + } + + if attempts == 1 { + server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{ + Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}, + }) + } + if _, err := tx.ExecContext(ctx, testutil.UpdateBarSetFoo, ExecOptions{QueryOptions: spanner.QueryOptions{RequestTag: "tag_2"}}); err != nil { + return err + } + + if attempts == 2 { + server.TestSpanner.PutExecutionTime(testutil.MethodExecuteBatchDml, testutil.SimulatedExecutionTime{ + Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}, + }) + } + if _, err := tx.ExecContext(ctx, "start batch dml", ExecOptions{QueryOptions: spanner.QueryOptions{RequestTag: "tag_3"}}); err != nil { + return err + } + if _, err := tx.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil { + return err + } + if _, err := tx.ExecContext(ctx, "run batch"); err != nil { + return err + } + return nil + }, spanner.TransactionOptions{TransactionTag: "my_transaction_tag"}) + if err != nil { + t.Fatalf("failed to run transaction: %v", err) + } +} + func readOnlyTxWithOptions(ctx context.Context, t *testing.T, db *sql.DB) { tx, err := BeginReadOnlyTransaction(ctx, db, ReadOnlyTransactionOptions{TimestampBound: spanner.ExactStaleness(10 * time.Second)}) diff --git a/connection_properties.go b/connection_properties.go index 31218cd7..218552c1 100644 --- a/connection_properties.go +++ b/connection_properties.go @@ -305,6 +305,16 @@ var propertyTransactionBatchReadOnly = createConnectionProperty( connectionstate.ContextUser, connectionstate.ConvertBool, ) +var propertyTransactionTimeout = createConnectionProperty( + "transaction_timeout", + "The timeout to apply to all read/write transactions on this connection. "+ + "Setting the timeout to zero means no timeout.", + time.Duration(0), + false, + nil, + connectionstate.ContextUser, + connectionstate.ConvertDuration, +) // ------------------------------------------------------------------------------------------------ // Statement connection properties. @@ -318,6 +328,16 @@ var propertyStatementTag = createConnectionProperty( connectionstate.ContextUser, connectionstate.ConvertString, ) +var propertyStatementTimeout = createConnectionProperty( + "statement_timeout", + "The timeout to apply to all statements on this connection. "+ + "Setting the timeout to zero means no timeout.", + time.Duration(0), + false, + nil, + connectionstate.ContextUser, + connectionstate.ConvertDuration, +) // ------------------------------------------------------------------------------------------------ // Startup connection properties. diff --git a/driver.go b/driver.go index 5359f5ab..6db6ba9b 100644 --- a/driver.go +++ b/driver.go @@ -204,12 +204,16 @@ type ExecOptions struct { // order to move to the result set that contains the spannerpb.ResultSetStats. ReturnResultSetStats bool - // DirectExecute determines whether a query is executed directly when the + // DirectExecuteQuery determines whether a query is executed directly when the // [sql.DB.QueryContext] method is called, or whether the actual query execution // is delayed until the first call to [sql.Rows.Next]. The default is to delay // the execution. Set this flag to true to execute the query directly when // [sql.DB.QueryContext] is called. DirectExecuteQuery bool + + // DirectExecuteContext is the context that is used for the execution of a query + // when DirectExecuteQuery is enabled. + DirectExecuteContext context.Context } func (dest *ExecOptions) merge(src *ExecOptions) { @@ -231,6 +235,9 @@ func (dest *ExecOptions) merge(src *ExecOptions) { if src.DirectExecuteQuery { dest.DirectExecuteQuery = src.DirectExecuteQuery } + if src.DirectExecuteContext != nil { + dest.DirectExecuteContext = src.DirectExecuteContext + } if src.AutocommitDMLMode != Unspecified { dest.AutocommitDMLMode = src.AutocommitDMLMode } diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index b059aa63..75eb6df1 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -4572,9 +4572,9 @@ func TestRunTransaction(t *testing.T) { // Verify that internal retries are disabled during RunTransaction txi := reflect.ValueOf(tx).Elem().FieldByName("txi") delegatingTx := (*delegatingTransaction)(txi.Elem().UnsafePointer()) - rwTx := delegatingTx.contextTransaction.(*readWriteTransaction) + rwTx, ok := delegatingTx.contextTransaction.(*readWriteTransaction) // Verify that getting the transaction through reflection worked. - if g, w := rwTx.ctx, ctx; g != w { + if !ok { return fmt.Errorf("getting the transaction through reflection failed") } if rwTx.retryAborts() { @@ -5034,9 +5034,9 @@ func TestBeginReadWriteTransaction(t *testing.T) { // Verify that internal retries are disabled during this transaction. txi := reflect.ValueOf(tx).Elem().FieldByName("txi") delegatingTx := (*delegatingTransaction)(txi.Elem().UnsafePointer()) - rwTx := delegatingTx.contextTransaction.(*readWriteTransaction) + rwTx, ok := delegatingTx.contextTransaction.(*readWriteTransaction) // Verify that getting the transaction through reflection worked. - if g, w := rwTx.ctx, ctx; g != w { + if !ok { t.Fatal("getting the transaction through reflection failed") } if rwTx.retryAborts() { diff --git a/go.mod b/go.mod index 55e0714d..21907208 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/google/uuid v1.6.0 github.com/googleapis/gax-go/v2 v2.15.0 github.com/hashicorp/golang-lru/v2 v2.0.7 + go.uber.org/goleak v1.3.0 google.golang.org/api v0.252.0 google.golang.org/genproto/googleapis/rpc v0.0.0-20251007200510-49b9836ed3ff google.golang.org/grpc v1.76.0 diff --git a/merged_row_iterator.go b/merged_row_iterator.go index 79e72f0a..a2d840f6 100644 --- a/merged_row_iterator.go +++ b/merged_row_iterator.go @@ -125,7 +125,7 @@ func (m *mergedRowIterator) nextIndex() int { func (m *mergedRowIterator) produceRowsFromPartition(ctx context.Context, index int) { m.logger.DebugContext(ctx, "merged row iterator producing rows from partition", "index", index) - r, err := m.partitionedQuery.execute(ctx, index) + r, err := m.partitionedQuery.execute(ctx /*cancel=*/, nil, index) if err != nil { m.registerErr(err) return diff --git a/partitioned_query.go b/partitioned_query.go index 30c2f6fb..e159a914 100644 --- a/partitioned_query.go +++ b/partitioned_query.go @@ -227,13 +227,13 @@ func (pq *PartitionedQuery) Execute(ctx context.Context, index int, db *sql.DB) }) } -func (pq *PartitionedQuery) execute(ctx context.Context, index int) (*rows, error) { +func (pq *PartitionedQuery) execute(ctx context.Context, cancel context.CancelFunc, index int) (*rows, error) { if index < 0 || index >= len(pq.Partitions) { return nil, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "invalid partition index: %d", index)) } spannerIter := pq.tx.Execute(ctx, pq.Partitions[index]) iter := &readOnlyRowIterator{spannerIter, parser.StatementTypeQuery} - return &rows{it: iter, decodeOption: pq.execOptions.DecodeOption}, nil + return &rows{it: iter, cancel: cancel, decodeOption: pq.execOptions.DecodeOption}, nil } func (pq *PartitionedQuery) Close() { diff --git a/rows.go b/rows.go index b878d96d..72f80318 100644 --- a/rows.go +++ b/rows.go @@ -15,6 +15,7 @@ package spannerdriver import ( + "context" "database/sql/driver" "errors" "fmt" @@ -41,9 +42,10 @@ const ( var _ driver.RowsNextResultSet = &rows{} -func createRows(it rowIterator, opts *ExecOptions) *rows { +func createRows(it rowIterator, cancel context.CancelFunc, opts *ExecOptions) *rows { return &rows{ it: it, + cancel: cancel, decodeOption: opts.DecodeOption, decodeToNativeArrays: opts.DecodeToNativeArrays, returnResultSetMetadata: opts.ReturnResultSetMetadata, @@ -52,8 +54,9 @@ func createRows(it rowIterator, opts *ExecOptions) *rows { } type rows struct { - it rowIterator - close func() error + it rowIterator + close func() error + cancel context.CancelFunc colsOnce sync.Once dirtyErr error @@ -119,6 +122,9 @@ func (r *rows) Close() error { return err } } + if r.cancel != nil { + r.cancel() + } return nil } @@ -487,6 +493,7 @@ var emptyRowsMetadata = &sppb.ResultSetMetadata{ var emptyRowsStats = &sppb.ResultSetStats{} type emptyRows struct { + cancel context.CancelFunc currentResultSetType resultSetType returnResultSetMetadata bool returnResultSetStats bool @@ -495,8 +502,9 @@ type emptyRows struct { hasReturnedResultSetStats bool } -func createDriverResultRows(_ driver.Result, opts *ExecOptions) *emptyRows { +func createDriverResultRows(_ driver.Result, cancel context.CancelFunc, opts *ExecOptions) *emptyRows { res := &emptyRows{ + cancel: cancel, returnResultSetMetadata: opts.ReturnResultSetMetadata, returnResultSetStats: opts.ReturnResultSetStats, } @@ -539,6 +547,9 @@ func (e *emptyRows) Columns() []string { } func (e *emptyRows) Close() error { + if e.cancel != nil { + e.cancel() + } return nil } diff --git a/rows_test.go b/rows_test.go index 824fab2a..6e595ff7 100644 --- a/rows_test.go +++ b/rows_test.go @@ -158,7 +158,7 @@ func TestRows_Next_Unsupported(t *testing.T) { } func TestEmptyRows(t *testing.T) { - r := createDriverResultRows(&result{}, &ExecOptions{}) + r := createDriverResultRows(&result{}, func() {}, &ExecOptions{}) if g, w := r.Columns(), []string{"affected_rows"}; !cmp.Equal(g, w) { t.Fatalf("columns mismatch\n Got: %v\nWant: %v", g, w) @@ -169,7 +169,7 @@ func TestEmptyRows(t *testing.T) { } func TestEmptyRowsWithMetadataAndStats(t *testing.T) { - r := createDriverResultRows(&result{}, &ExecOptions{ReturnResultSetMetadata: true, ReturnResultSetStats: true}) + r := createDriverResultRows(&result{}, func() {}, &ExecOptions{ReturnResultSetMetadata: true, ReturnResultSetStats: true}) // The first result set should contain ResultSetMetadata. if g, w := r.Columns(), []string{"metadata"}; !cmp.Equal(g, w) { diff --git a/spannerlib/api/connection.go b/spannerlib/api/connection.go index ab256879..22f911cd 100644 --- a/spannerlib/api/connection.go +++ b/spannerlib/api/connection.go @@ -66,6 +66,10 @@ func WriteMutations(ctx context.Context, poolId, connId int64, mutations *spanne // BeginTransaction starts a new transaction on the given connection. // A connection can have at most one transaction at any time. This function therefore returns an error if the // connection has an active transaction. +// +// NOTE: The context that is passed in to this function is registered as the transaction context. The transaction is +// invalidated if the context is cancelled. The context that is passed in to this function should therefore not be a +// context that is cancelled right after calling this function. func BeginTransaction(ctx context.Context, poolId, connId int64, txOpts *spannerpb.TransactionOptions) error { conn, err := findConnection(poolId, connId) if err != nil { @@ -93,11 +97,15 @@ func Rollback(ctx context.Context, poolId, connId int64) error { } func Execute(ctx context.Context, poolId, connId int64, executeSqlRequest *spannerpb.ExecuteSqlRequest) (int64, error) { + return ExecuteWithDirectExecuteContext(ctx, nil, poolId, connId, executeSqlRequest) +} + +func ExecuteWithDirectExecuteContext(ctx, directExecuteContext context.Context, poolId, connId int64, executeSqlRequest *spannerpb.ExecuteSqlRequest) (int64, error) { conn, err := findConnection(poolId, connId) if err != nil { return 0, err } - return conn.Execute(ctx, executeSqlRequest) + return conn.Execute(ctx, directExecuteContext, executeSqlRequest) } func ExecuteBatch(ctx context.Context, poolId, connId int64, statements *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) { @@ -300,16 +308,16 @@ func (conn *Connection) closeResults(ctx context.Context) { }) } -func (conn *Connection) Execute(ctx context.Context, statement *spannerpb.ExecuteSqlRequest) (int64, error) { - return execute(ctx, conn, conn.backend, statement) +func (conn *Connection) Execute(ctx, directExecuteContext context.Context, statement *spannerpb.ExecuteSqlRequest) (int64, error) { + return execute(ctx, directExecuteContext, conn, conn.backend, statement) } func (conn *Connection) ExecuteBatch(ctx context.Context, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { return executeBatch(ctx, conn, conn.backend, statements) } -func execute(ctx context.Context, conn *Connection, executor queryExecutor, statement *spannerpb.ExecuteSqlRequest) (int64, error) { - params := extractParams(statement) +func execute(ctx, directExecuteContext context.Context, conn *Connection, executor queryExecutor, statement *spannerpb.ExecuteSqlRequest) (int64, error) { + params := extractParams(directExecuteContext, statement) it, err := executor.QueryContext(ctx, statement.Sql, params...) if err != nil { return 0, err @@ -397,7 +405,7 @@ func executeBatchDml(ctx context.Context, conn *Connection, executor queryExecut Params: statement.Params, ParamTypes: statement.ParamTypes, } - params := extractParams(request) + params := extractParams(nil, request) _, err := executor.ExecContext(ctx, statement.Sql, params...) if err != nil { return nil, err @@ -423,7 +431,7 @@ func executeBatchDml(ctx context.Context, conn *Connection, executor queryExecut return &response, nil } -func extractParams(statement *spannerpb.ExecuteSqlRequest) []any { +func extractParams(directExecuteContext context.Context, statement *spannerpb.ExecuteSqlRequest) []any { paramsLen := 1 if statement.Params != nil { paramsLen = 1 + len(statement.Params.Fields) @@ -436,6 +444,7 @@ func extractParams(statement *spannerpb.ExecuteSqlRequest) []any { ReturnResultSetMetadata: true, ReturnResultSetStats: true, DirectExecuteQuery: true, + DirectExecuteContext: directExecuteContext, }) if statement.Params != nil { if statement.ParamTypes == nil { diff --git a/spannerlib/grpc-server/server.go b/spannerlib/grpc-server/server.go index 24c82786..206237c4 100644 --- a/spannerlib/grpc-server/server.go +++ b/spannerlib/grpc-server/server.go @@ -102,22 +102,10 @@ func (s *spannerLibServer) CloseConnection(ctx context.Context, connection *pb.C return &emptypb.Empty{}, nil } -func contextWithSameDeadline(ctx context.Context) context.Context { - newContext := context.Background() - if deadline, ok := ctx.Deadline(); ok { - // Ignore the returned cancel function here, as the context will be closed when the Rows object is closed. - //goland:noinspection GoVetLostCancel - newContext, _ = context.WithDeadline(newContext, deadline) - } - return newContext -} - -func (s *spannerLibServer) Execute(ctx context.Context, request *pb.ExecuteRequest) (*pb.Rows, error) { - // Create a new context that is used for the query. We need to do this, because the context that is passed in to - // this function will be cancelled once the RPC call finishes. That again would cause further calls to Next on the - // underlying rows object to fail with a 'Context cancelled' error. - queryContext := contextWithSameDeadline(ctx) - id, err := api.Execute(queryContext, request.Connection.Pool.Id, request.Connection.Id, request.ExecuteSqlRequest) +func (s *spannerLibServer) Execute(ctx context.Context, request *pb.ExecuteRequest) (returnedRows *pb.Rows, returnedErr error) { + // Only use the context of the gRPC invocation for the DirectExecute option. That is: It is only used + // for fetching the first results, and can be cancelled after that. + id, err := api.ExecuteWithDirectExecuteContext(context.Background(), ctx, request.Connection.Pool.Id, request.Connection.Id, request.ExecuteSqlRequest) if err != nil { return nil, err } @@ -125,12 +113,12 @@ func (s *spannerLibServer) Execute(ctx context.Context, request *pb.ExecuteReque } func (s *spannerLibServer) ExecuteStreaming(request *pb.ExecuteRequest, stream grpc.ServerStreamingServer[pb.RowData]) error { - queryContext := contextWithSameDeadline(stream.Context()) + queryContext := stream.Context() id, err := api.Execute(queryContext, request.Connection.Pool.Id, request.Connection.Id, request.ExecuteSqlRequest) if err != nil { return err } - defer func() { _ = api.CloseRows(queryContext, request.Connection.Pool.Id, request.Connection.Id, id) }() + defer func() { _ = api.CloseRows(context.Background(), request.Connection.Pool.Id, request.Connection.Id, id) }() rows := &pb.Rows{Connection: request.Connection, Id: id} metadata, err := api.Metadata(queryContext, request.Connection.Pool.Id, request.Connection.Id, id) if err != nil { @@ -214,12 +202,7 @@ func (s *spannerLibServer) BeginTransaction(ctx context.Context, request *pb.Beg // Create a new context that is used for the transaction. We need to do this, because the context that is passed in // to this function will be cancelled once the RPC call finishes. That again would cause further calls on // the underlying transaction to fail with a 'Context cancelled' error. - txContext := context.Background() - if deadline, ok := ctx.Deadline(); ok { - // Ignore the returned cancel function here, as the context will be closed when the transaction is closed. - //goland:noinspection GoVetLostCancel - txContext, _ = context.WithDeadline(txContext, deadline) - } + txContext := context.WithoutCancel(ctx) err := api.BeginTransaction(txContext, request.Connection.Pool.Id, request.Connection.Id, request.TransactionOptions) if err != nil { return nil, err diff --git a/spannerlib/grpc-server/server_test.go b/spannerlib/grpc-server/server_test.go index 35cc3396..8c6879bb 100644 --- a/spannerlib/grpc-server/server_test.go +++ b/spannerlib/grpc-server/server_test.go @@ -2,20 +2,26 @@ package main import ( "context" + "errors" "fmt" + "io" "net" "os" "path/filepath" "reflect" "runtime" "testing" + "time" + "cloud.google.com/go/spanner" "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/google/uuid" "github.com/googleapis/go-sql-spanner/testutil" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/structpb" pb "spannerlib/grpc-server/google/spannerlib/v1" ) @@ -142,6 +148,42 @@ func TestExecute(t *testing.T) { } } +func TestExecuteWithTimeout(t *testing.T) { + t.Parallel() + ctx := context.Background() + + server, teardown := setupMockSpannerServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + client, cleanup := startTestSpannerLibServer(t) + defer cleanup() + + pool, err := client.CreatePool(ctx, &pb.CreatePoolRequest{ConnectionString: dsn}) + if err != nil { + t.Fatalf("failed to create pool: %v", err) + } + connection, err := client.CreateConnection(ctx, &pb.CreateConnectionRequest{Pool: pool}) + if err != nil { + t.Fatalf("failed to create connection: %v", err) + } + + server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{MinimumExecutionTime: 2 * time.Millisecond}) + withTimeout, cancel := context.WithTimeout(ctx, time.Millisecond) + defer cancel() + _, err = client.Execute(withTimeout, &pb.ExecuteRequest{ + Connection: connection, + ExecuteSqlRequest: &sppb.ExecuteSqlRequest{Sql: testutil.SelectFooFromBar}, + }) + if g, w := status.Code(err), codes.DeadlineExceeded; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + if _, err := client.ClosePool(ctx, pool); err != nil { + t.Fatalf("failed to close pool: %v", err) + } +} + func TestExecuteStreaming(t *testing.T) { t.Parallel() ctx := context.Background() @@ -194,6 +236,50 @@ func TestExecuteStreaming(t *testing.T) { } } +func TestExecuteStreamingWithTimeout(t *testing.T) { + t.Parallel() + ctx := context.Background() + + server, teardown := setupMockSpannerServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + client, cleanup := startTestSpannerLibServer(t) + defer cleanup() + + pool, err := client.CreatePool(ctx, &pb.CreatePoolRequest{ConnectionString: dsn}) + if err != nil { + t.Fatalf("failed to create pool: %v", err) + } + connection, err := client.CreateConnection(ctx, &pb.CreateConnectionRequest{Pool: pool}) + if err != nil { + t.Fatalf("failed to create connection: %v", err) + } + + server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{MinimumExecutionTime: 2 * time.Millisecond}) + withTimeout, cancel := context.WithTimeout(ctx, time.Millisecond) + defer cancel() + stream, err := client.ExecuteStreaming(withTimeout, &pb.ExecuteRequest{ + Connection: connection, + ExecuteSqlRequest: &sppb.ExecuteSqlRequest{Sql: testutil.SelectFooFromBar}, + }) + // The timeout can happen here or while waiting for the first response. + if err != nil { + if g, w := spanner.ErrCode(err), codes.DeadlineExceeded; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + } else { + _, err = stream.Recv() + if g, w := spanner.ErrCode(err), codes.DeadlineExceeded; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + } + + if _, err := client.ClosePool(ctx, pool); err != nil { + t.Fatalf("failed to close pool: %v", err) + } +} + func TestExecuteStreamingClientSideStatement(t *testing.T) { t.Parallel() ctx := context.Background() @@ -251,6 +337,95 @@ func TestExecuteStreamingClientSideStatement(t *testing.T) { } } +func TestExecuteStreamingCustomSql(t *testing.T) { + t.Parallel() + ctx := context.Background() + + server, teardown := setupMockSpannerServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + client, cleanup := startTestSpannerLibServer(t) + defer cleanup() + + pool, err := client.CreatePool(ctx, &pb.CreatePoolRequest{ConnectionString: dsn}) + if err != nil { + t.Fatalf("failed to create pool: %v", err) + } + connection, err := client.CreateConnection(ctx, &pb.CreateConnectionRequest{Pool: pool}) + if err != nil { + t.Fatalf("failed to create connection: %v", err) + } + + stream, err := client.ExecuteStreaming(ctx, &pb.ExecuteRequest{ + Connection: connection, + ExecuteSqlRequest: &sppb.ExecuteSqlRequest{Sql: "begin"}, + }) + if err != nil { + t.Fatalf("failed to execute: %v", err) + } + row, err := stream.Recv() + if err != nil { + t.Fatalf("failed to receive row: %v", err) + } + if g, w := len(row.Data), 0; g != w { + t.Fatalf("row data length mismatch\n Got: %v\nWant: %v", g, w) + } + if _, err := stream.Recv(); !errors.Is(err, io.EOF) { + t.Fatalf("expected io.EOF, got: %v", err) + } + + stream, err = client.ExecuteStreaming(ctx, &pb.ExecuteRequest{ + Connection: connection, + ExecuteSqlRequest: &sppb.ExecuteSqlRequest{Sql: testutil.SelectFooFromBar}, + }) + if err != nil { + t.Fatalf("failed to execute: %v", err) + } + numRows := 0 + for { + row, err := stream.Recv() + if err != nil { + t.Fatalf("failed to receive row: %v", err) + } + if len(row.Data) == 0 { + break + } + if g, w := len(row.Data), 1; g != w { + t.Fatalf("num rows mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := len(row.Data[0].Values), 1; g != w { + t.Fatalf("num values mismatch\n Got: %v\nWant: %v", g, w) + } + numRows++ + } + if g, w := numRows, 2; g != w { + t.Fatalf("num rows mismatch\n Got: %v\nWant: %v", g, w) + } + + stream, err = client.ExecuteStreaming(ctx, &pb.ExecuteRequest{ + Connection: connection, + ExecuteSqlRequest: &sppb.ExecuteSqlRequest{Sql: "commit"}, + }) + if err != nil { + t.Fatalf("failed to execute: %v", err) + } + row, err = stream.Recv() + if err != nil { + t.Fatalf("failed to receive row: %v", err) + } + if g, w := len(row.Data), 0; g != w { + t.Fatalf("row data length mismatch\n Got: %v\nWant: %v", g, w) + } + if _, err := stream.Recv(); !errors.Is(err, io.EOF) { + t.Fatalf("expected io.EOF, got: %v", err) + } + + if _, err := client.ClosePool(ctx, pool); err != nil { + t.Fatalf("failed to close pool: %v", err) + } +} + func TestExecuteBatch(t *testing.T) { t.Parallel() ctx := context.Background() diff --git a/statements.go b/statements.go index 147146b5..eb9277ac 100644 --- a/statements.go +++ b/statements.go @@ -103,7 +103,7 @@ func (s *executableShowStatement) queryContext(ctx context.Context, c *conn, opt if err != nil { return nil, err } - return createRows(it, opts), nil + return createRows(it /*cancel=*/, nil, opts), nil } // SET [SESSION | LOCAL] [my_extension.]my_property {=|to} @@ -282,6 +282,10 @@ func (s *executableBeginStatement) execContext(ctx context.Context, c *conn, opt if len(s.stmt.Identifiers) != len(s.stmt.Literals) { return nil, status.Errorf(codes.InvalidArgument, "statement contains %d identifiers, but %d values given", len(s.stmt.Identifiers), len(s.stmt.Literals)) } + // The context that is passed in to c.BeginTx(..) becomes the transaction context. The transaction is automatically + // rolled back when that context is cancelled. We therefore create a derived context here that is not cancelled when + // the parent is cancelled. + ctx = context.WithoutCancel(ctx) _, err := c.BeginTx(ctx, driver.TxOptions{}) if err != nil { return nil, err diff --git a/stmt_with_mockserver_test.go b/stmt_with_mockserver_test.go index 27bd7800..301a0759 100644 --- a/stmt_with_mockserver_test.go +++ b/stmt_with_mockserver_test.go @@ -20,12 +20,15 @@ import ( "reflect" "strings" "testing" + "time" + "cloud.google.com/go/spanner" "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/googleapis/go-sql-spanner/testdata/protos/concertspb" "github.com/googleapis/go-sql-spanner/testutil" + "google.golang.org/grpc/codes" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/structpb" ) @@ -179,3 +182,90 @@ func TestPrepareWithValuerScanner(t *testing.T) { t.Fatalf("param value mismatch\n Got: %v\nWant: %v", g, w) } } + +func TestStatementTimeout(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnectionWithParams(t, "statement_timeout=1ms") + defer teardown() + ctx := context.Background() + + // The database/sql driver uses ExecuteStreamingSql for all statements. + server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{MinimumExecutionTime: 10 * time.Millisecond}) + + _, err := db.ExecContext(ctx, testutil.UpdateBarSetFoo) + if g, w := spanner.ErrCode(err), codes.DeadlineExceeded; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + if !strings.Contains(err.Error(), "requestID =") { + t.Fatalf("missing requestID in error: %v", err) + } + _, err = db.QueryContext(ctx, testutil.SelectFooFromBar, ExecOptions{DirectExecuteQuery: true}) + if g, w := spanner.ErrCode(err), codes.DeadlineExceeded; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + if !strings.Contains(err.Error(), "requestID =") { + t.Fatalf("missing requestID in error: %v", err) + } + + // Get a connection and remove the timeout and verify that the statements work. + c, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + if _, err := c.ExecContext(ctx, "set statement_timeout = null"); err != nil { + t.Fatalf("failed to remove statement_timeout: %v", err) + } + _, err = c.ExecContext(ctx, testutil.UpdateBarSetFoo) + if g, w := spanner.ErrCode(err), codes.OK; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + r, err := c.QueryContext(ctx, testutil.SelectFooFromBar, ExecOptions{DirectExecuteQuery: true}) + if r != nil { + _ = r.Close() + } + if g, w := spanner.ErrCode(err), codes.OK; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + // Add the timeout again on the connection and verify that the statements time out again. + if _, err := c.ExecContext(ctx, "set statement_timeout = 1ms"); err != nil { + t.Fatalf("failed to set statement_timeout: %v", err) + } + _, err = c.ExecContext(ctx, testutil.UpdateBarSetFoo) + if g, w := spanner.ErrCode(err), codes.DeadlineExceeded; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + if !strings.Contains(err.Error(), "requestID =") { + t.Fatalf("missing requestID in error: %v", err) + } + _, err = c.QueryContext(ctx, testutil.SelectFooFromBar, ExecOptions{DirectExecuteQuery: true}) + if g, w := spanner.ErrCode(err), codes.DeadlineExceeded; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + if !strings.Contains(err.Error(), "requestID =") { + t.Fatalf("missing requestID in error: %v", err) + } + + // Set a longer timeout and verify that executing a query and iterating its results works. + if _, err := c.ExecContext(ctx, "set statement_timeout = 1s"); err != nil { + t.Fatalf("failed to set statement_timeout: %v", err) + } + r, err = c.QueryContext(ctx, testutil.SelectFooFromBar) + if err != nil { + t.Fatal(err) + } + for r.Next() { + var val int64 + if err := r.Scan(&val); err != nil { + t.Fatal(err) + } + } + if err := r.Close(); err != nil { + t.Fatal(err) + } + + if err := c.Close(); err != nil { + t.Fatal(err) + } +} diff --git a/transaction.go b/transaction.go index 31a63083..36b1fb49 100644 --- a/transaction.go +++ b/transaction.go @@ -42,6 +42,7 @@ type spannerTransaction interface { // contextTransaction is the combination of both read/write and read-only // transactions. type contextTransaction interface { + deadline() (time.Time, bool) Commit() error Rollback() error resetForRetry(ctx context.Context) error @@ -122,6 +123,13 @@ type delegatingTransaction struct { contextTransaction contextTransaction } +func (d *delegatingTransaction) deadline() (time.Time, bool) { + if d.contextTransaction != nil { + return d.contextTransaction.deadline() + } + return d.ctx.Deadline() +} + func (d *delegatingTransaction) ensureActivated() error { if d.contextTransaction != nil { return nil @@ -230,6 +238,10 @@ type readOnlyTransaction struct { timestampBoundCallback func() spanner.TimestampBound } +func (tx *readOnlyTransaction) deadline() (time.Time, bool) { + return time.Time{}, false +} + func (tx *readOnlyTransaction) Commit() error { tx.logger.Debug("committing transaction") // Read-only transactions don't really commit, but closing the transaction @@ -472,6 +484,10 @@ func (ru *retriableBatchUpdate) retry(ctx context.Context, tx *spanner.ReadWrite return nil } +func (tx *readWriteTransaction) deadline() (time.Time, bool) { + return tx.ctx.Deadline() +} + // runWithRetry executes a statement on a go/sql read/write transaction and // automatically retries the entire transaction if the statement returns an // Aborted error. The method will return ErrAbortedDueToConcurrentModification @@ -539,7 +555,7 @@ func (tx *readWriteTransaction) Commit() (err error) { tx.logger.Debug("committing transaction") tx.active = true if err := tx.maybeRunAutoDmlBatch(tx.ctx); err != nil { - _ = tx.rollback(tx.ctx) + _ = tx.rollback() return err } var commitResponse spanner.CommitResponse @@ -569,12 +585,14 @@ func (tx *readWriteTransaction) Rollback() error { if tx.batch != nil && tx.batch.automatic { _, _ = tx.AbortBatch() } - return tx.rollback(tx.ctx) + return tx.rollback() } -func (tx *readWriteTransaction) rollback(ctx context.Context) error { +func (tx *readWriteTransaction) rollback() error { if tx.rwTx != nil { - tx.rwTx.Rollback(ctx) + // Always use context.Background() for rollback invocations to allow them + // to be executed, even if the transaction has timed out or been cancelled. + tx.rwTx.Rollback(context.Background()) } tx.close(txResultRollback, nil, nil) return nil diff --git a/transaction_test.go b/transaction_test.go index 6bf35785..35e738da 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -6,10 +6,13 @@ import ( "fmt" "reflect" "testing" + "time" "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/googleapis/go-sql-spanner/testutil" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func TestSetTransactionIsolationLevel(t *testing.T) { @@ -271,3 +274,83 @@ func TestDmlBatchReturnsBatchUpdateCounts(t *testing.T) { } } } + +func TestTransactionTimeout(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{MinimumExecutionTime: 20 * time.Millisecond}) + tx, _ := db.BeginTx(ctx, &sql.TxOptions{}) + if _, err := tx.ExecContext(ctx, "set local transaction_timeout=10ms"); err != nil { + t.Fatal(err) + } + _, err := tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + if g, w := status.Code(err), codes.DeadlineExceeded; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + + requests := server.TestSpanner.DrainRequestsFromServer() + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 0; g != w { + t.Fatalf("commit requests count mismatch\n Got: %v\nWant: %v", g, w) + } + rollbackRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.RollbackRequest{})) + // There should be no rollback requests on the server, because the initial ExecuteSqlRequest timed out. + // That means that no transaction ID was returned to the client, so there is nothing to rollback. + if g, w := len(rollbackRequests), 0; g != w { + t.Fatalf("rollback requests count mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestTransactionTimeoutSecondStatement(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + tx, _ := db.BeginTx(ctx, &sql.TxOptions{}) + if _, err := tx.ExecContext(ctx, "set local transaction_timeout=20ms"); err != nil { + t.Fatal(err) + } + if _, err := tx.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil { + t.Fatal(err) + } + + server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{MinimumExecutionTime: 30 * time.Millisecond}) + rows, err := tx.QueryContext(ctx, testutil.SelectFooFromBar, ExecOptions{DirectExecuteQuery: true}) + if rows != nil { + _ = rows.Close() + } + if g, w := status.Code(err), codes.DeadlineExceeded; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + + requests := server.TestSpanner.DrainRequestsFromServer() + executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + // The server should receive 1 or 2 requests, depending on exactly when the deadline exceeded error happens. + if g, w1, w2 := len(executeRequests), 1, 2; g != w1 && g != w2 { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v\n Or: %v", g, w1, w2) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 0; g != w { + t.Fatalf("commit requests count mismatch\n Got: %v\nWant: %v", g, w) + } + rollbackRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.RollbackRequest{})) + if g, w := len(rollbackRequests), 1; g != w { + t.Fatalf("rollback requests count mismatch\n Got: %v\nWant: %v", g, w) + } +} From b3f7cfe46dd8599ac0429bc9d4612cd484a1d775 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Fri, 31 Oct 2025 10:59:08 +0100 Subject: [PATCH 2/2] chore: use errors.Join instead of two %w verbs --- conn.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/conn.go b/conn.go index da89689d..a0e886b2 100644 --- a/conn.go +++ b/conn.go @@ -19,7 +19,6 @@ import ( "database/sql" "database/sql/driver" "errors" - "fmt" "log/slog" "slices" "sync" @@ -1028,7 +1027,7 @@ func (c *conn) directExecuteQuery(ctx context.Context, cancelQuery context.Cance // one error. This will preserve the DeadlineExceeded error code from statementErr, and include the request // ID from the Spanner error. s := status.FromContextError(statementErr) - return fmt.Errorf("%w: %w", s.Err(), res.dirtyErr) + return errors.Join(s.Err(), res.dirtyErr) } return res.dirtyErr }