diff --git a/driver.go b/driver.go index a359701b..147b78b4 100644 --- a/driver.go +++ b/driver.go @@ -431,7 +431,9 @@ func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func if err != nil { return err } - defer conn.Close() + defer func() { + _ = conn.Close() + }() // We don't need to keep track of a running checksum for retries when using // this method, so we disable internal retries. @@ -452,7 +454,9 @@ func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func } // Reset the flag for internal retries after the transaction (if applicable). if origRetryAborts { - defer func() { _ = spannerConn.SetRetryAbortsInternally(origRetryAborts) }() + defer func() { + _ = spannerConn.SetRetryAbortsInternally(origRetryAborts) + }() } tx, err := conn.BeginTx(ctx, opts) @@ -461,11 +465,13 @@ func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func } for { err = f(ctx, tx) + errDuringCommit := false if err == nil { err = tx.Commit() if err == nil { return nil } + errDuringCommit = true } // Rollback and return the error if: // 1. The connection is not a Spanner connection. @@ -493,12 +499,23 @@ func RunTransaction(ctx context.Context, db *sql.DB, opts *sql.TxOptions, f func } } - // TODO: Reset the existing transaction for retry instead of creating a new one. - _ = tx.Rollback() - tx, err = conn.BeginTx(ctx, opts) + // Reset the transaction after it was aborted. + err = spannerConn.resetTransactionForRetry(ctx, errDuringCommit) if err != nil { + _ = tx.Rollback() return err } + // This does not actually start a new transaction, instead it + // continues with the previous transaction that was already reset. + // We need to do this, because the sql package registers the + // transaction as 'done' when Commit has been called, also if the + // commit fails. + if errDuringCommit { + tx, err = conn.BeginTx(ctx, opts) + if err != nil { + return err + } + } } } @@ -596,17 +613,25 @@ type SpannerConn interface { // this function on different connections to the same database, can // return the same Spanner client. UnderlyingClient() (client *spanner.Client, err error) + + // resetTransactionForRetry resets the current transaction after it has + // been aborted by Spanner. Calling this function on a transaction that + // has not been aborted is not supported and will cause an error to be + // returned. + resetTransactionForRetry(ctx context.Context, errDuringCommit bool) error } type conn struct { - connector *connector - closed bool - client *spanner.Client - adminClient *adminapi.DatabaseAdminClient - tx contextTransaction - commitTs *time.Time - database string - retryAborts bool + connector *connector + closed bool + client *spanner.Client + adminClient *adminapi.DatabaseAdminClient + tx contextTransaction + prevTx contextTransaction + resetForRetry bool + commitTs *time.Time + database string + retryAborts bool execSingleQuery func(ctx context.Context, c *spanner.Client, statement spanner.Statement, bound spanner.TimestampBound) *spanner.RowIterator execSingleDMLTransactional func(ctx context.Context, c *spanner.Client, statement spanner.Statement, transactionOptions spanner.TransactionOptions) (int64, time.Time, error) @@ -1169,11 +1194,32 @@ func (c *conn) Close() error { return c.connector.decreaseConnCount() } +func noTransaction() error { + return status.Errorf(codes.FailedPrecondition, "connection does not have a transaction") +} + +func (c *conn) resetTransactionForRetry(ctx context.Context, errDuringCommit bool) error { + if errDuringCommit { + if c.prevTx == nil { + return noTransaction() + } + c.tx = c.prevTx + c.resetForRetry = true + } else if c.tx == nil { + return noTransaction() + } + return c.tx.resetForRetry(ctx) +} + func (c *conn) Begin() (driver.Tx, error) { return c.BeginTx(context.Background(), driver.TxOptions{}) } func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if c.resetForRetry { + c.resetForRetry = false + return c.tx, nil + } if c.inTransaction() { return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "already in a transaction")) } @@ -1202,6 +1248,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e client: c.client, rwTx: tx, close: func(commitTs *time.Time, commitErr error) { + c.prevTx = c.tx c.tx = nil if commitErr == nil { c.commitTs = commitTs diff --git a/transaction.go b/transaction.go index 7f6aaec0..465024a5 100644 --- a/transaction.go +++ b/transaction.go @@ -32,6 +32,7 @@ import ( type contextTransaction interface { Commit() error Rollback() error + resetForRetry(ctx context.Context) error Query(ctx context.Context, stmt spanner.Statement) rowIterator ExecContext(ctx context.Context, stmt spanner.Statement) (int64, error) @@ -89,6 +90,11 @@ func (tx *readOnlyTransaction) Rollback() error { return nil } +func (tx *readOnlyTransaction) resetForRetry(ctx context.Context) error { + // no-op + return nil +} + func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement) rowIterator { return &readOnlyRowIterator{tx.roTx.Query(ctx, stmt)} } @@ -295,6 +301,15 @@ func (tx *readWriteTransaction) Rollback() error { return nil } +func (tx *readWriteTransaction) resetForRetry(ctx context.Context) error { + t, err := tx.rwTx.ResetForRetry(ctx) + if err != nil { + return err + } + tx.rwTx = t + return nil +} + // Query executes a query using the read/write transaction and returns a // rowIterator that will automatically retry the read/write transaction if the // transaction is aborted during the query or while iterating the returned rows.