Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 60 additions & 13 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,9 @@ func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func
if err != nil {
return err
}
defer conn.Close()
defer func() {
_ = conn.Close()
}()

// We don't need to keep track of a running checksum for retries when using
// this method, so we disable internal retries.
Expand All @@ -452,7 +454,9 @@ func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func
}
// Reset the flag for internal retries after the transaction (if applicable).
if origRetryAborts {
defer func() { _ = spannerConn.SetRetryAbortsInternally(origRetryAborts) }()
defer func() {
_ = spannerConn.SetRetryAbortsInternally(origRetryAborts)
}()
}

tx, err := conn.BeginTx(ctx, opts)
Expand All @@ -461,11 +465,13 @@ func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func
}
for {
err = f(ctx, tx)
errDuringCommit := false
if err == nil {
err = tx.Commit()
if err == nil {
return nil
}
errDuringCommit = true
}
// Rollback and return the error if:
// 1. The connection is not a Spanner connection.
Expand Down Expand Up @@ -493,12 +499,23 @@ func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func
}
}

// TODO: Reset the existing transaction for retry instead of creating a new one.
_ = tx.Rollback()
tx, err = conn.BeginTx(ctx, opts)
// Reset the transaction after it was aborted.
err = spannerConn.resetTransactionForRetry(ctx, errDuringCommit)
if err != nil {
_ = tx.Rollback()
return err
}
// This does not actually start a new transaction, instead it
// continues with the previous transaction that was already reset.
// We need to do this, because the sql package registers the
// transaction as 'done' when Commit has been called, also if the
// commit fails.
if errDuringCommit {
tx, err = conn.BeginTx(ctx, opts)
if err != nil {
return err
}
}
}
}

Expand Down Expand Up @@ -596,17 +613,25 @@ type SpannerConn interface {
// this function on different connections to the same database, can
// return the same Spanner client.
UnderlyingClient() (client *spanner.Client, err error)

// resetTransactionForRetry resets the current transaction after it has
// been aborted by Spanner. Calling this function on a transaction that
// has not been aborted is not supported and will cause an error to be
// returned.
resetTransactionForRetry(ctx context.Context, errDuringCommit bool) error
}

type conn struct {
connector *connector
closed bool
client *spanner.Client
adminClient *adminapi.DatabaseAdminClient
tx contextTransaction
commitTs *time.Time
database string
retryAborts bool
connector *connector
closed bool
client *spanner.Client
adminClient *adminapi.DatabaseAdminClient
tx contextTransaction
prevTx contextTransaction
resetForRetry bool
commitTs *time.Time
database string
retryAborts bool

execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound) *spanner.RowIterator
execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions) (int64, time.Time, error)
Expand Down Expand Up @@ -1169,11 +1194,32 @@ func (c *conn) Close() error {
return c.connector.decreaseConnCount()
}

func noTransaction() error {
return status.Errorf(codes.FailedPrecondition, "connection does not have a transaction")
}

func (c *conn) resetTransactionForRetry(ctx context.Context, errDuringCommit bool) error {
if errDuringCommit {
if c.prevTx == nil {
return noTransaction()
}
c.tx = c.prevTx
c.resetForRetry = true
} else if c.tx == nil {
return noTransaction()
}
return c.tx.resetForRetry(ctx)
}

func (c *conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}

func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if c.resetForRetry {
c.resetForRetry = false
return c.tx, nil
}
if c.inTransaction() {
return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "already in a transaction"))
}
Expand Down Expand Up @@ -1202,6 +1248,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
client: c.client,
rwTx: tx,
close: func(commitTs *time.Time, commitErr error) {
c.prevTx = c.tx
c.tx = nil
if commitErr == nil {
c.commitTs = commitTs
Expand Down
15 changes: 15 additions & 0 deletions transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
type contextTransaction interface {
Commit() error
Rollback() error
resetForRetry(ctx context.Context) error
Query(ctx context.Context, stmt spanner.Statement) rowIterator
ExecContext(ctx context.Context, stmt spanner.Statement) (int64, error)

Expand Down Expand Up @@ -89,6 +90,11 @@ func (tx *readOnlyTransaction) Rollback() error {
return nil
}

func (tx *readOnlyTransaction) resetForRetry(ctx context.Context) error {
// no-op
return nil
}

func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement) rowIterator {
return &readOnlyRowIterator{tx.roTx.Query(ctx, stmt)}
}
Expand Down Expand Up @@ -295,6 +301,15 @@ func (tx *readWriteTransaction) Rollback() error {
return nil
}

func (tx *readWriteTransaction) resetForRetry(ctx context.Context) error {
t, err := tx.rwTx.ResetForRetry(ctx)
if err != nil {
return err
}
tx.rwTx = t
return nil
}

// Query executes a query using the read/write transaction and returns a
// rowIterator that will automatically retry the read/write transaction if the
// transaction is aborted during the query or while iterating the returned rows.
Expand Down
Loading