Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 179 additions & 14 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,69 @@ func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, err
return &stmt{conn: c, query: parsedSQL, numArgs: len(args), execOptions: execOptions}, nil
}

// Adds any statement or transaction timeout to the given context. The deadline of the returned
// context will be the earliest of:
// 1. Any existing deadline on the input context.
// 2. Any existing transaction deadline.
// 3. A deadline calculated from the current time + the value of statement_timeout.
func (c *conn) addStatementAndTransactionTimeout(ctx context.Context) (context.Context, context.CancelFunc, error) {
var statementDeadline time.Time
var transactionDeadline time.Time
var deadline time.Time
var hasStatementDeadline bool
var hasTransactionDeadline bool

// Check if the connection has a value for statement_timeout.
statementTimeout := propertyStatementTimeout.GetValueOrDefault(c.state)
if statementTimeout != time.Duration(0) {
hasStatementDeadline = true
statementDeadline = time.Now().Add(statementTimeout)
}
// Check if the current transaction has a deadline.
transactionDeadline, hasTransactionDeadline, err := c.transactionDeadline()
if err != nil {
return nil, nil, err
}

// If there is no statement_timeout and no current transaction deadline,
// then can just use the input context as-is.
if !hasStatementDeadline && !hasTransactionDeadline {
return ctx, func() {}, nil
}

// If there is both a transaction and a statement deadline, then we use the earliest
// of those two.
if hasTransactionDeadline && hasStatementDeadline {
if statementDeadline.Before(transactionDeadline) {
deadline = statementDeadline
} else {
deadline = transactionDeadline
}
} else if hasStatementDeadline {
deadline = statementDeadline
} else {
deadline = transactionDeadline
}
// context.WithDeadline automatically selects the earliest deadline of
// the existing deadline on the context and the given deadline.
newCtx, cancel := context.WithDeadline(ctx, deadline)
return newCtx, cancel, nil
}

// transactionDeadline returns the deadline of the current transaction
// on the connection. This also activates the transaction if it is not
// yet activated.
func (c *conn) transactionDeadline() (time.Time, bool, error) {
if c.tx == nil {
return time.Time{}, false, nil
}
if err := c.tx.ensureActivated(); err != nil {
return time.Time{}, false, err
}
deadline, hasDeadline := c.tx.deadline()
return deadline, hasDeadline, nil
}

func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
// Execute client side statement if it is one.
clientStmt, err := c.parser.ParseClientSideStatement(query)
Expand All @@ -849,13 +912,22 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
return c.queryContext(ctx, query, execOptions, args)
}

func (c *conn) queryContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Rows, error) {
func (c *conn) queryContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (returnedRows driver.Rows, returnedErr error) {
ctx, cancelCause := context.WithCancelCause(ctx)
cancel := func() {
cancelCause(nil)
}
defer func() {
if returnedErr != nil {
cancel()
}
}()
// Clear the commit timestamp of this connection before we execute the query.
c.clearCommitResponse()
// Check if the execution options contains an instruction to execute
// a specific partition of a PartitionedQuery.
if pq := execOptions.PartitionedQueryOptions.ExecutePartition.PartitionedQuery; pq != nil {
return pq.execute(ctx, execOptions.PartitionedQueryOptions.ExecutePartition.Index)
return pq.execute(ctx, cancel, execOptions.PartitionedQueryOptions.ExecutePartition.Index)
}

stmt, err := prepareSpannerStmt(c.parser, query, args)
Expand All @@ -869,7 +941,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
if err != nil {
return nil, err
}
return createDriverResultRows(res, execOptions), nil
return createDriverResultRows(res, cancel, execOptions), nil
}
var iter rowIterator
if c.tx == nil {
Expand All @@ -884,7 +956,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
} else if execOptions.PartitionedQueryOptions.PartitionQuery {
return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "PartitionQuery is only supported in batch read-only transactions"))
} else if execOptions.PartitionedQueryOptions.AutoPartitionQuery {
return c.executeAutoPartitionedQuery(ctx, query, execOptions, args)
return c.executeAutoPartitionedQuery(ctx, cancel, query, execOptions, args)
} else {
// The statement was either detected as being a query, or potentially not recognized at all.
// In that case, just default to using a single-use read-only transaction and let Spanner
Expand All @@ -893,25 +965,75 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
}
} else {
if execOptions.PartitionedQueryOptions.PartitionQuery {
// The driver.Rows instance that is returned for partitionQuery does not
// contain a context, and therefore also does not cancel the context when it is closed.
defer cancel()
return c.tx.partitionQuery(ctx, stmt, execOptions)
}
iter, err = c.tx.Query(ctx, stmt, statementInfo.StatementType, execOptions)
if err != nil {
return nil, err
}
}
res := createRows(iter, execOptions)
res := createRows(iter, cancel, execOptions)
if execOptions.DirectExecuteQuery {
// This call to res.getColumns() triggers the execution of the statement, as it needs to fetch the metadata.
res.getColumns()
if res.dirtyErr != nil && !errors.Is(res.dirtyErr, iterator.Done) {
_ = res.Close()
return nil, res.dirtyErr
if err := c.directExecuteQuery(ctx, cancelCause, res, execOptions); err != nil {
return nil, err
}
}
return res, nil
}

// directExecuteQuery blocks until the first PartialResultSet has been returned by Spanner. Any statement_timeout and/or
// transaction_timeout is used while waiting for the first result to be returned.
func (c *conn) directExecuteQuery(ctx context.Context, cancelQuery context.CancelCauseFunc, res *rows, execOptions *ExecOptions) error {
statementCtx := ctx
if execOptions.DirectExecuteContext != nil {
statementCtx = execOptions.DirectExecuteContext
}
// Add the statement or transaction deadline to the context.
statementCtx, cancelStatement, err := c.addStatementAndTransactionTimeout(statementCtx)
if err != nil {
return err
}
defer cancelStatement()

// Asynchronously fetch the first partial result set from Spanner.
done := make(chan struct{})
go func() {
// Calling res.getColumns() ensures that the first PartialResultSet has been returned, as it contains the
// metadata of the query.
defer close(done)
res.getColumns()
}()
// Wait until either the done channel is closed or the context is done.
var statementErr error
select {
case <-statementCtx.Done():
statementErr = statementCtx.Err()
// Cancel the query execution.
cancelQuery(statementCtx.Err())
case <-done:
}

// Now wait until done channel is closed. This could be because the execution finished
// successfully, or because the context was cancelled, which again causes the execution
// to (eventually) fail.
<-done
if res.dirtyErr != nil && !errors.Is(res.dirtyErr, iterator.Done) {
_ = res.Close()
if statementErr != nil {
// Create a status error from the statement error and wrap both the Spanner error and the status error into
// one error. This will preserve the DeadlineExceeded error code from statementErr, and include the request
// ID from the Spanner error.
s := status.FromContextError(statementErr)
return errors.Join(s.Err(), res.dirtyErr)
}
return res.dirtyErr
}
return nil
}

func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
// Execute client side statement if it is one.
stmt, err := c.parser.ParseClientSideStatement(query)
Expand All @@ -929,7 +1051,13 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
return c.execContext(ctx, query, execOptions, args)
}

func (c *conn) execContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Result, error) {
func (c *conn) execContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (returnedResult driver.Result, returnedErr error) {
// Add the statement/transaction deadline to the context.
ctx, cancel, err := c.addStatementAndTransactionTimeout(ctx)
if err != nil {
return nil, err
}
defer cancel()
// Clear the commit timestamp of this connection before we execute the statement.
c.clearCommitResponse()

Expand Down Expand Up @@ -1041,6 +1169,18 @@ func (c *conn) resetTransactionForRetry(ctx context.Context, errDuringCommit boo
return noTransaction()
}
c.tx = c.prevTx
// If the aborted error happened during the Commit, then the transaction
// context has been cancelled, and we need to create a new one.
if rwTx, ok := c.tx.contextTransaction.(*readWriteTransaction); ok {
newCtx, cancel := c.addTransactionTimeout(c.tx.ctx)
rwTx.ctx = newCtx
// Make sure that we cancel the new context when the transaction is closed.
origClose := rwTx.close
rwTx.close = func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) {
origClose(result, commitResponse, commitErr)
cancel()
}
}
c.resetForRetry = true
} else if c.tx == nil {
return noTransaction()
Expand Down Expand Up @@ -1248,6 +1388,17 @@ func (c *conn) beginTx(ctx context.Context, driverOpts driver.TxOptions, closeFu
return c.tx, nil
}

// addTransactionTimeout creates a new derived context with the current transaction_timeout.
func (c *conn) addTransactionTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
timeout := propertyTransactionTimeout.GetValueOrDefault(c.state)
if timeout == time.Duration(0) {
return ctx, func() {}
}
// Note that this will set the actual deadline to the earliest of the existing deadline on ctx and the calculated
// deadline based on the timeout.
return context.WithTimeout(ctx, timeout)
}

func (c *conn) activateTransaction() (contextTransaction, error) {
closeFunc := c.tx.close
if propertyTransactionReadOnly.GetValueOrDefault(c.state) {
Expand Down Expand Up @@ -1283,19 +1434,23 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
opts := spanner.TransactionOptions{}
opts.BeginTransactionOption = c.convertDefaultBeginTransactionOption(propertyBeginTransactionOption.GetValueOrDefault(c.state))

tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(c.tx.ctx, c.client, opts, func() spanner.TransactionOptions {
// Add the current value of transaction_timeout to the context that is registered
// on the transaction.
ctx, cancel := c.addTransactionTimeout(c.tx.ctx)
tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(ctx, c.client, opts, func() spanner.TransactionOptions {
defer func() {
// Reset the transaction_tag after starting the transaction.
_ = propertyTransactionTag.ResetValue(c.state, connectionstate.ContextUser)
}()
return c.effectiveTransactionOptions(spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, c.options( /*reset=*/ true))
})
if err != nil {
cancel()
return nil, err
}
logger := c.logger.With("tx", "rw")
return &readWriteTransaction{
ctx: c.tx.ctx,
ctx: ctx,
conn: c,
logger: logger,
rwTx: tx,
Expand All @@ -1307,6 +1462,7 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
} else {
closeFunc(txResultRollback)
}
cancel()
},
retryAborts: sync.OnceValue(func() bool {
return c.RetryAbortsInternally()
Expand Down Expand Up @@ -1371,7 +1527,15 @@ func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.
return c.Single().WithTimestampBound(tb).QueryWithOptions(ctx, statement, options.QueryOptions)
}

func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Rows, error) {
func (c *conn) executeAutoPartitionedQuery(ctx context.Context, cancel context.CancelFunc, query string, execOptions *ExecOptions, args []driver.NamedValue) (returnedRows driver.Rows, returnedErr error) {
// The cancel() function is called by the returned Rows object when it is closed.
// However, if an error is returned instead of a Rows instance, we need to cancel
// the context when we return from this function.
defer func() {
if returnedErr != nil {
cancel()
}
}()
tx, err := c.BeginTx(ctx, driver.TxOptions{ReadOnly: true, Isolation: withBatchReadOnly(driver.IsolationLevel(sql.LevelDefault))})
if err != nil {
return nil, err
Expand All @@ -1383,6 +1547,7 @@ func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, ex
}
if rows, ok := r.(*rows); ok {
rows.close = func() error {
defer cancel()
return tx.Commit()
}
}
Expand Down
Loading
Loading