Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion client_side_statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,11 @@ func TestShowCommitTimestamp(t *testing.T) {
{&ts},
{nil},
} {
c.commitTs = test.wantValue
if test.wantValue == nil {
c.commitResponse = nil
} else {
c.commitResponse = &spanner.CommitResponse{CommitTs: *test.wantValue}
}

it, err := s.ShowCommitTimestamp(ctx, c, "", ExecOptions{}, nil)
if err != nil {
Expand Down
81 changes: 46 additions & 35 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ type SpannerConn interface {
// was executed on the connection, or an error if the connection has not executed a read/write transaction
// that committed successfully. The timestamp is in the local timezone.
CommitTimestamp() (commitTimestamp time.Time, err error)
// CommitResponse returns the commit response of the last implicit or explicit read/write transaction that
// was executed on the connection, or an error if the connection has not executed a read/write transaction
// that committed successfully.
CommitResponse() (commitResponse *spanner.CommitResponse, err error)

// UnderlyingClient returns the underlying Spanner client for the database.
// The client cannot be used to access the current transaction or batch on
Expand Down Expand Up @@ -208,23 +212,23 @@ type SpannerConn interface {
var _ SpannerConn = &conn{}

type conn struct {
parser *statementParser
connector *connector
closed bool
client *spanner.Client
adminClient *adminapi.DatabaseAdminClient
connId string
logger *slog.Logger
tx contextTransaction
prevTx contextTransaction
resetForRetry bool
commitTs *time.Time
database string
retryAborts bool
parser *statementParser
connector *connector
closed bool
client *spanner.Client
adminClient *adminapi.DatabaseAdminClient
connId string
logger *slog.Logger
tx contextTransaction
prevTx contextTransaction
resetForRetry bool
commitResponse *spanner.CommitResponse
database string
retryAborts bool

execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound, options ExecOptions) *spanner.RowIterator
execSingleQueryTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (rowIterator, time.Time, error)
execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *statementInfo, options ExecOptions) (*result, time.Time, error)
execSingleQueryTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (rowIterator, *spanner.CommitResponse, error)
execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *statementInfo, options ExecOptions) (*result, *spanner.CommitResponse, error)
execSingleDMLPartitioned func(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (int64, error)

// batch is the currently active DDL or DML batch on this connection.
Expand Down Expand Up @@ -273,10 +277,17 @@ func (c *conn) UnderlyingClient() (*spanner.Client, error) {
}

func (c *conn) CommitTimestamp() (time.Time, error) {
if c.commitTs == nil {
if c.commitResponse == nil {
return time.Time{}, spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "this connection has not executed a read/write transaction that committed successfully"))
}
return *c.commitTs, nil
return c.commitResponse.CommitTs, nil
}

func (c *conn) CommitResponse() (commitResponse *spanner.CommitResponse, err error) {
if c.commitResponse == nil {
return nil, spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "this connection has not executed a read/write transaction that committed successfully"))
}
return c.commitResponse, nil
}

func (c *conn) RetryAbortsInternally() bool {
Expand Down Expand Up @@ -670,7 +681,7 @@ func (c *conn) ResetSession(_ context.Context) error {
return driver.ErrBadConn
}
}
c.commitTs = nil
c.commitResponse = nil
c.batch = nil
c.autoBatchDml = c.connector.connectorConfig.AutoBatchDml
c.autoBatchDmlUpdateCount = c.connector.connectorConfig.AutoBatchDmlUpdateCount
Expand Down Expand Up @@ -771,7 +782,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam

func (c *conn) queryContext(ctx context.Context, query string, execOptions ExecOptions, args []driver.NamedValue) (driver.Rows, error) {
// Clear the commit timestamp of this connection before we execute the query.
c.commitTs = nil
c.commitResponse = nil
// Check if the execution options contains an instruction to execute
// a specific partition of a PartitionedQuery.
if pq := execOptions.PartitionedQueryOptions.ExecutePartition.PartitionedQuery; pq != nil {
Expand All @@ -791,12 +802,12 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions ExecO
if c.tx == nil {
if statementType.statementType == statementTypeDml {
// Use a read/write transaction to execute the statement.
var commitTs time.Time
iter, commitTs, err = c.execSingleQueryTransactional(ctx, c.client, stmt, execOptions)
var commitResponse *spanner.CommitResponse
iter, commitResponse, err = c.execSingleQueryTransactional(ctx, c.client, stmt, execOptions)
if err != nil {
return nil, err
}
c.commitTs = &commitTs
c.commitResponse = commitResponse
} else if execOptions.PartitionedQueryOptions.PartitionQuery {
return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "PartitionQuery is only supported in batch read-only transactions"))
} else if execOptions.PartitionedQueryOptions.AutoPartitionQuery {
Expand Down Expand Up @@ -843,7 +854,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name

func (c *conn) execContext(ctx context.Context, query string, execOptions ExecOptions, args []driver.NamedValue) (driver.Result, error) {
// Clear the commit timestamp of this connection before we execute the statement.
c.commitTs = nil
c.commitResponse = nil

statementInfo := c.parser.detectStatementType(query)
// Use admin API if DDL statement is provided.
Expand All @@ -870,7 +881,7 @@ func (c *conn) execContext(ctx context.Context, query string, execOptions ExecOp
}

var res *result
var commitTs time.Time
var commitResponse *spanner.CommitResponse
if c.tx == nil {
if c.InDMLBatch() {
c.batch.statements = append(c.batch.statements, ss)
Expand All @@ -881,9 +892,9 @@ func (c *conn) execContext(ctx context.Context, query string, execOptions ExecOp
dmlMode = execOptions.AutocommitDMLMode
}
if dmlMode == Transactional {
res, commitTs, err = c.execSingleDMLTransactional(ctx, c.client, ss, statementInfo, execOptions)
res, commitResponse, err = c.execSingleDMLTransactional(ctx, c.client, ss, statementInfo, execOptions)
if err == nil {
c.commitTs = &commitTs
c.commitResponse = commitResponse
}
} else if dmlMode == PartitionedNonAtomic {
var rowsAffected int64
Expand Down Expand Up @@ -1084,20 +1095,20 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
conn: c,
logger: logger,
rwTx: tx,
close: func(commitTs *time.Time, commitErr error) {
close: func(commitResponse *spanner.CommitResponse, commitErr error) {
if readWriteTransactionOptions.close != nil {
readWriteTransactionOptions.close()
}
c.prevTx = c.tx
c.tx = nil
if commitErr == nil {
c.commitTs = commitTs
c.commitResponse = commitResponse
}
},
// Disable internal retries if any of these options have been set.
retryAborts: !readWriteTransactionOptions.DisableInternalRetries && !disableRetryAborts,
}
c.commitTs = nil
c.commitResponse = nil
return c.tx, nil
}

Expand Down Expand Up @@ -1153,7 +1164,7 @@ func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, ar
return r, nil
}

func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (rowIterator, time.Time, error) {
func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, options ExecOptions) (rowIterator, *spanner.CommitResponse, error) {
var result *wrappedRowIterator
options.QueryOptions.LastStatement = true
fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error {
Expand All @@ -1177,14 +1188,14 @@ func queryInNewRWTransaction(ctx context.Context, c *spanner.Client, statement s
}
resp, err := c.ReadWriteTransactionWithOptions(ctx, fn, options.TransactionOptions)
if err != nil {
return nil, time.Time{}, err
return nil, nil, err
}
return result, resp.CommitTs, nil
return result, &resp, nil
}

var errInvalidDmlForExecContext = spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "Exec and ExecContext can only be used with INSERT statements with a THEN RETURN clause that return exactly one row with one column of type INT64. Use Query or QueryContext for DML statements other than INSERT and/or with THEN RETURN clauses that return other/more data."))

func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *statementInfo, options ExecOptions) (*result, time.Time, error) {
func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement spanner.Statement, statementInfo *statementInfo, options ExecOptions) (*result, *spanner.CommitResponse, error) {
var res *result
options.QueryOptions.LastStatement = true
fn := func(ctx context.Context, tx *spanner.ReadWriteTransaction) error {
Expand All @@ -1197,9 +1208,9 @@ func execInNewRWTransaction(ctx context.Context, c *spanner.Client, statement sp
}
resp, err := c.ReadWriteTransactionWithOptions(ctx, fn, options.TransactionOptions)
if err != nil {
return &result{}, time.Time{}, err
return &result{}, nil, err
}
return res, resp.CommitTs, nil
return res, &resp, nil
}

func execTransactionalDML(ctx context.Context, tx spannerTransaction, statement spanner.Statement, statementInfo *statementInfo, options spanner.QueryOptions) (*result, error) {
Expand Down
66 changes: 56 additions & 10 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,8 @@ func (c *connector) closeClients() (err error) {
//
// This function will never return ErrAbortedDueToConcurrentModification.
func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func(ctx context.Context, tx *sql.Tx) error) error {
return runTransactionWithOptions(ctx, db, opts, f, spanner.TransactionOptions{})
_, err := runTransactionWithOptions(ctx, db, opts, f, spanner.TransactionOptions{})
return err
}

// RunTransactionWithOptions runs the given function in a transaction on the given database.
Expand All @@ -873,18 +874,44 @@ func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func
//
// This function will never return ErrAbortedDueToConcurrentModification.
func RunTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func(ctx context.Context, tx *sql.Tx) error, spannerOptions spanner.TransactionOptions) error {
_, err := runTransactionWithOptions(ctx, db, opts, f, spannerOptions)
return err
}

// RunTransactionWithCommitResponse runs the given function in a transaction on
// the given database. If the connection is a connection to a Spanner database,
// the transaction will automatically be retried if the transaction is aborted
// by Spanner. Any other errors will be propagated to the caller and the
// transaction will be rolled back. The transaction will be committed if the
// supplied function did not return an error.
//
// If the connection is to a non-Spanner database, no retries will be attempted,
// and any error that occurs during the transaction will be propagated to the
// caller.
//
// The application should *NOT* call tx.Commit() or tx.Rollback(). This is done
// automatically by this function, depending on whether the transaction function
// returned an error or not.
//
// The given spanner.TransactionOptions will be used for the transaction.
//
// This function returns a spanner.CommitResponse if the transaction committed
// successfully.
//
// This function will never return ErrAbortedDueToConcurrentModification.
func RunTransactionWithCommitResponse(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func(ctx context.Context, tx *sql.Tx) error, spannerOptions spanner.TransactionOptions) (*spanner.CommitResponse, error) {
return runTransactionWithOptions(ctx, db, opts, f, spannerOptions)
}

func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func(ctx context.Context, tx *sql.Tx) error, spannerOptions spanner.TransactionOptions) error {
func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func(ctx context.Context, tx *sql.Tx) error, spannerOptions spanner.TransactionOptions) (*spanner.CommitResponse, error) {
// Get a connection from the pool that we can use to run a transaction.
// Getting a connection here already makes sure that we can reserve this
// connection exclusively for the duration of this method. That again
// allows us to temporarily change the state of the connection (e.g. set
// the retryAborts flag to false).
conn, err := db.Conn(ctx)
if err != nil {
return err
return nil, err
}
defer func() {
_ = conn.Close()
Expand All @@ -908,20 +935,24 @@ func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOpti
spannerConn.withTempTransactionOptions(transactionOptions)
return nil
}); err != nil {
return err
return nil, err
}

tx, err := conn.BeginTx(ctx, opts)
if err != nil {
return err
return nil, err
}
for {
err = protected(ctx, tx, f)
errDuringCommit := false
if err == nil {
err = tx.Commit()
if err == nil {
return nil
resp, err := getCommitResponse(conn)
if err != nil {
return nil, err
}
return resp, nil
}
errDuringCommit = true
}
Expand All @@ -934,7 +965,7 @@ func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOpti
// and just returns an ErrTxDone if we do, so this is simpler than
// keeping track of where the error happened.
_ = tx.Rollback()
return err
return nil, err
}

// The transaction was aborted by Spanner.
Expand All @@ -947,15 +978,15 @@ func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOpti
// anymore. It does not actually roll back the transaction, as it
// has already been aborted by Spanner.
_ = tx.Rollback()
return err
return nil, err
}
}

// Reset the transaction after it was aborted.
err = resetTransactionForRetry(ctx, conn, errDuringCommit)
if err != nil {
_ = tx.Rollback()
return err
return nil, err
}
// This does not actually start a new transaction, instead it
// continues with the previous transaction that was already reset.
Expand All @@ -965,12 +996,13 @@ func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOpti
if errDuringCommit {
tx, err = conn.BeginTx(ctx, opts)
if err != nil {
return err
return nil, err
}
}
}

}

func protected(ctx context.Context, tx *sql.Tx, f func(ctx context.Context, tx *sql.Tx) error) (err error) {
defer func() {
if x := recover(); x != nil {
Expand All @@ -990,6 +1022,20 @@ func resetTransactionForRetry(ctx context.Context, conn *sql.Conn, errDuringComm
})
}

func getCommitResponse(conn *sql.Conn) (resp *spanner.CommitResponse, err error) {
if err := conn.Raw(func(driverConn any) error {
spannerConn, ok := driverConn.(SpannerConn)
if !ok {
return spanner.ToSpannerError(status.Error(codes.InvalidArgument, "not a Spanner connection"))
}
resp, err = spannerConn.CommitResponse()
return err
}); err != nil {
return nil, err
}
return resp, nil
}

type ReadWriteTransactionOptions struct {
// TransactionOptions are passed through to the Spanner client to use for
// the read/write transaction.
Expand Down
Loading
Loading