diff --git a/aborted_transactions_test.go b/aborted_transactions_test.go index 6962e96f..835f971d 100644 --- a/aborted_transactions_test.go +++ b/aborted_transactions_test.go @@ -41,6 +41,9 @@ func TestCommitAborted(t *testing.T) { if err != nil { t.Fatalf("begin failed: %v", err) } + if _, err := tx.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil { + t.Fatal(err) + } server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, testutil.SimulatedExecutionTime{ Errors: []error{status.Error(codes.Aborted, "Aborted")}, }) @@ -51,7 +54,7 @@ func TestCommitAborted(t *testing.T) { reqs := server.TestSpanner.DrainRequestsFromServer() commitReqs := testutil.RequestsOfType(reqs, reflect.TypeOf(&sppb.CommitRequest{})) if g, w := len(commitReqs), 2; g != w { - t.Fatalf("commit request count mismatch\nGot: %v\nWant: %v", g, w) + t.Fatalf("commit request count mismatch\n Got: %v\nWant: %v", g, w) } // Verify that the db is still usable. @@ -117,6 +120,9 @@ func TestCommitAbortedWithInternalRetriesDisabled(t *testing.T) { if err != nil { t.Fatalf("begin failed: %v", err) } + if _, err := tx.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil { + t.Fatal(err) + } server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, testutil.SimulatedExecutionTime{ Errors: []error{status.Error(codes.Aborted, "Aborted")}, }) diff --git a/client_side_statement_test.go b/client_side_statement_test.go index 8d7130d1..45bc1d7a 100644 --- a/client_side_statement_test.go +++ b/client_side_statement_test.go @@ -65,7 +65,7 @@ func TestStatementExecutor_StartBatchDdl(t *testing.T) { } // Starting a DDL batch while the connection is in a transaction is not allowed. - c.tx = &readWriteTransaction{} + c.tx = &delegatingTransaction{conn: c, ctx: ctx} if _, err := c.ExecContext(ctx, "start batch ddl", []driver.NamedValue{}); spanner.ErrCode(err) != codes.FailedPrecondition { t.Fatalf("error mismatch for starting a DDL batch while in a transaction\nGot: %v\nWant: %v", spanner.ErrCode(err), codes.FailedPrecondition) } @@ -102,13 +102,13 @@ func TestStatementExecutor_StartBatchDml(t *testing.T) { } // Starting a DML batch while the connection is in a read-only transaction is not allowed. - c.tx = &readOnlyTransaction{logger: noopLogger} + c.tx = &delegatingTransaction{conn: c, contextTransaction: &readOnlyTransaction{logger: noopLogger}} if _, err := c.ExecContext(ctx, "start batch dml", []driver.NamedValue{}); spanner.ErrCode(err) != codes.FailedPrecondition { t.Fatalf("error mismatch for starting a DML batch while in a read-only transaction\nGot: %v\nWant: %v", spanner.ErrCode(err), codes.FailedPrecondition) } // Starting a DML batch while the connection is in a read/write transaction is allowed. - c.tx = &readWriteTransaction{logger: noopLogger} + c.tx = &delegatingTransaction{conn: c, contextTransaction: &readWriteTransaction{logger: noopLogger}} if _, err := c.ExecContext(ctx, "start batch dml", []driver.NamedValue{}); err != nil { t.Fatalf("could not start a DML batch while in a read/write transaction: %v", err) } diff --git a/conn.go b/conn.go index d35b238d..278cff0c 100644 --- a/conn.go +++ b/conn.go @@ -231,20 +231,25 @@ type SpannerConn interface { // returned. resetTransactionForRetry(ctx context.Context, errDuringCommit bool) error - // withTempTransactionOptions sets the TransactionOptions that should be used - // for the next read/write transaction. This method should only be called - // directly before starting a new read/write transaction. - withTempTransactionOptions(options *ReadWriteTransactionOptions) + // withTransactionCloseFunc sets the close function that should be registered + // on the next transaction on this connection. This method should only be called + // directly before starting a new transaction. + withTransactionCloseFunc(close func()) - // withTempReadOnlyTransactionOptions sets the options that should be used - // for the next read-only transaction. This method should only be called - // directly before starting a new read-only transaction. - withTempReadOnlyTransactionOptions(options *ReadOnlyTransactionOptions) + // setReadWriteTransactionOptions sets the ReadWriteTransactionOptions that should be + // used for the current read/write transaction. This method should be called right + // after starting a new read/write transaction. + setReadWriteTransactionOptions(options *ReadWriteTransactionOptions) - // withTempBatchReadOnlyTransactionOptions sets the options that should be used - // for the next batch read-only transaction. This method should only be called - // directly before starting a new batch read-only transaction. - withTempBatchReadOnlyTransactionOptions(options *BatchReadOnlyTransactionOptions) + // setReadOnlyTransactionOptions sets the options that should be used + // for the current read-only transaction. This method should be called + // right after starting a new read-only transaction. + setReadOnlyTransactionOptions(options *ReadOnlyTransactionOptions) + + // setBatchReadOnlyTransactionOptions sets the options that should be used + // for the current batch read-only transaction. This method should be called + // right after starting a new batch read-only transaction. + setBatchReadOnlyTransactionOptions(options *BatchReadOnlyTransactionOptions) } var _ SpannerConn = &conn{} @@ -257,8 +262,8 @@ type conn struct { adminClient *adminapi.DatabaseAdminClient connId string logger *slog.Logger - tx contextTransaction - prevTx contextTransaction + tx *delegatingTransaction + prevTx *delegatingTransaction resetForRetry bool database string @@ -536,7 +541,7 @@ func (c *conn) InDDLBatch() bool { } func (c *conn) InDMLBatch() bool { - return (c.batch != nil && c.batch.tp == parser.BatchTypeDml) || (c.inReadWriteTransaction() && c.tx.(*readWriteTransaction).batch != nil) + return (c.batch != nil && c.batch.tp == parser.BatchTypeDml) || (c.inTransaction() && c.tx.IsInBatch()) } func (c *conn) GetBatchedStatements() []spanner.Statement { @@ -572,9 +577,6 @@ func (c *conn) startBatchDML(automatic bool) (driver.Result, error) { if c.batch != nil { return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "This connection already has an active batch.")) } - if c.inReadOnlyTransaction() { - return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "This connection has an active read-only transaction. Read-only transactions cannot execute DML batches.")) - } c.logger.Debug("starting dml batch outside transaction") c.batch = &batch{tp: parser.BatchTypeDml, options: execOptions} return driver.ResultNoRows, nil @@ -660,8 +662,8 @@ func (c *conn) execBatchDML(ctx context.Context, statements []spanner.Statement, var affected []int64 var err error - if c.inTransaction() { - tx, ok := c.tx.(*readWriteTransaction) + if c.inTransaction() && c.tx.contextTransaction != nil { + tx, ok := c.tx.contextTransaction.(*readWriteTransaction) if !ok { return nil, status.Errorf(codes.FailedPrecondition, "connection is in a transaction that is not a read/write transaction") } @@ -949,7 +951,7 @@ func (c *conn) execContext(ctx context.Context, query string, execOptions *ExecO } // Start an automatic DML batch. - if c.AutoBatchDml() && !c.inBatch() && c.inReadWriteTransaction() { + if c.AutoBatchDml() && !c.inBatch() && c.inTransaction() && statementInfo.StatementType == parser.StatementTypeDml { if _, err := c.startBatchDML( /* automatic = */ true); err != nil { return nil, err } @@ -1046,14 +1048,14 @@ func (c *conn) resetTransactionForRetry(ctx context.Context, errDuringCommit boo return c.tx.resetForRetry(ctx) } -func (c *conn) withTempTransactionOptions(options *ReadWriteTransactionOptions) { +func (c *conn) withTransactionCloseFunc(close func()) { + c.tempTransactionCloseFunc = close +} + +func (c *conn) setReadWriteTransactionOptions(options *ReadWriteTransactionOptions) { if options == nil { return } - c.tempTransactionCloseFunc = options.close - // Start a transaction for the connection state, so we can set the transaction options - // as local options in the current transaction. - _ = c.state.Begin() if options.DisableInternalRetries { _ = propertyRetryAbortsInternally.SetLocalValue(c.state, !options.DisableInternalRetries) } @@ -1103,14 +1105,10 @@ func (c *conn) getTransactionOptions(execOptions *ExecOptions) ReadWriteTransact return txOpts } -func (c *conn) withTempReadOnlyTransactionOptions(options *ReadOnlyTransactionOptions) { +func (c *conn) setReadOnlyTransactionOptions(options *ReadOnlyTransactionOptions) { if options == nil { return } - c.tempTransactionCloseFunc = options.close - // Start a transaction for the connection state, so we can set the transaction options - // as local options in the current transaction. - _ = c.state.Begin() if options.BeginTransactionOption != spanner.DefaultBeginTransaction { _ = propertyBeginTransactionOption.SetLocalValue(c.state, options.BeginTransactionOption) } @@ -1123,14 +1121,10 @@ func (c *conn) getReadOnlyTransactionOptions() ReadOnlyTransactionOptions { return ReadOnlyTransactionOptions{TimestampBound: c.ReadOnlyStaleness(), BeginTransactionOption: c.convertDefaultBeginTransactionOption(propertyBeginTransactionOption.GetValueOrDefault(c.state))} } -func (c *conn) withTempBatchReadOnlyTransactionOptions(options *BatchReadOnlyTransactionOptions) { +func (c *conn) setBatchReadOnlyTransactionOptions(options *BatchReadOnlyTransactionOptions) { if options == nil { return } - c.tempTransactionCloseFunc = options.close - // Start a transaction for the connection state, so we can set the transaction options - // as local options in the current transaction. - _ = c.state.Begin() if options.TimestampBound.String() != "(strong)" { _ = propertyReadOnlyStaleness.SetLocalValue(c.state, options.TimestampBound) } @@ -1144,9 +1138,9 @@ func (c *conn) getBatchReadOnlyTransactionOptions() BatchReadOnlyTransactionOpti // It is exported for internal reasons, and may receive breaking changes without prior notice. // // BeginReadOnlyTransaction starts a new read-only transaction on this connection. -func (c *conn) BeginReadOnlyTransaction(ctx context.Context, options *ReadOnlyTransactionOptions) (driver.Tx, error) { - c.withTempReadOnlyTransactionOptions(options) - tx, err := c.BeginTx(ctx, driver.TxOptions{ReadOnly: true}) +func (c *conn) BeginReadOnlyTransaction(ctx context.Context, options *ReadOnlyTransactionOptions, close func()) (driver.Tx, error) { + tx, err := c.beginTx(ctx, driver.TxOptions{ReadOnly: true}, close) + c.setReadOnlyTransactionOptions(options) if err != nil { return nil, err } @@ -1157,9 +1151,9 @@ func (c *conn) BeginReadOnlyTransaction(ctx context.Context, options *ReadOnlyTr // It is exported for internal reasons, and may receive breaking changes without prior notice. // // BeginReadWriteTransaction starts a new read/write transaction on this connection. -func (c *conn) BeginReadWriteTransaction(ctx context.Context, options *ReadWriteTransactionOptions) (driver.Tx, error) { - c.withTempTransactionOptions(options) - tx, err := c.BeginTx(ctx, driver.TxOptions{}) +func (c *conn) BeginReadWriteTransaction(ctx context.Context, options *ReadWriteTransactionOptions, close func()) (driver.Tx, error) { + tx, err := c.beginTx(ctx, driver.TxOptions{}, close) + c.setReadWriteTransactionOptions(options) if err != nil { return nil, err } @@ -1182,21 +1176,6 @@ func (c *conn) beginTx(ctx context.Context, driverOpts driver.TxOptions, closeFu c.resetForRetry = false return c.tx, nil } - // Also start a transaction on the ConnectionState if the BeginTx call was successful. - defer func() { - if c.tx != nil { - _ = c.state.Begin() - } else { - // Rollback in case the connection state transaction was started before this function - // was called, for example if the caller set temporary transaction options. - _ = c.state.Rollback() - } - }() - - // TODO: Delay the actual determination of the transaction type until the first query. - // This is required in order to support SET TRANSACTION READ {ONLY | WRITE} - readOnlyTxOpts := c.getReadOnlyTransactionOptions() - batchReadOnlyTxOpts := c.getBatchReadOnlyTransactionOptions() if c.inTransaction() { return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "already in a transaction")) } @@ -1234,94 +1213,105 @@ func (c *conn) beginTx(ctx context.Context, driverOpts driver.TxOptions, closeFu if closeFunc == nil { closeFunc = func() {} } + if err := c.state.Begin(); err != nil { + return nil, err + } + c.clearCommitResponse() + if isolationLevelFromTxOpts != spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED { + _ = propertyIsolationLevel.SetLocalValue(c.state, sql.IsolationLevel(driverOpts.Isolation)) + } + // TODO: Figure out how to distinguish between 'use the default' and 'use read/write'. if driverOpts.ReadOnly { + _ = propertyTransactionReadOnly.SetLocalValue(c.state, true) + } + if batchReadOnly { + _ = propertyTransactionBatchReadOnly.SetLocalValue(c.state, true) + } + if disableRetryAborts { + _ = propertyRetryAbortsInternally.SetLocalValue(c.state, false) + } + + c.tx = &delegatingTransaction{ + conn: c, + ctx: ctx, + close: func(result txResult) { + closeFunc() + if result == txResultCommit { + _ = c.state.Commit() + } else { + _ = c.state.Rollback() + } + c.tx = nil + }, + } + return c.tx, nil +} + +func (c *conn) activateTransaction() (contextTransaction, error) { + closeFunc := c.tx.close + if propertyTransactionReadOnly.GetValueOrDefault(c.state) { var logger *slog.Logger var ro *spanner.ReadOnlyTransaction var bo *spanner.BatchReadOnlyTransaction - if batchReadOnly { + if propertyTransactionBatchReadOnly.GetValueOrDefault(c.state) { logger = c.logger.With("tx", "batchro") var err error // BatchReadOnly transactions (currently) do not support inline-begin. // This means that the transaction options must be supplied here, and not through a callback. - bo, err = c.client.BatchReadOnlyTransaction(ctx, batchReadOnlyTxOpts.TimestampBound) + bo, err = c.client.BatchReadOnlyTransaction(c.tx.ctx, propertyReadOnlyStaleness.GetValueOrDefault(c.state)) if err != nil { return nil, err } ro = &bo.ReadOnlyTransaction } else { logger = c.logger.With("tx", "ro") - ro = c.client.ReadOnlyTransaction().WithBeginTransactionOption(readOnlyTxOpts.BeginTransactionOption) + beginTxOpt := c.convertDefaultBeginTransactionOption(propertyBeginTransactionOption.GetValueOrDefault(c.state)) + ro = c.client.ReadOnlyTransaction().WithBeginTransactionOption(beginTxOpt) } - c.tx = &readOnlyTransaction{ + return &readOnlyTransaction{ roTx: ro, boTx: bo, logger: logger, - close: func(result txResult) { - closeFunc() - if result == txResultCommit { - _ = c.state.Commit() - } else { - _ = c.state.Rollback() - } - c.tx = nil - }, + close: closeFunc, timestampBoundCallback: func() spanner.TimestampBound { return propertyReadOnlyStaleness.GetValueOrDefault(c.state) }, - } - return c.tx, nil + }, nil } - // These options are only used to determine how to start the transaction. - // All other options are fetched in a callback that is called when the transaction is actually started. - // That callback reads all transaction options from the connection state at that moment. This allows - // applications to execute a series of statement like this: - // BEGIN TRANSACTION; - // SET LOCAL transaction_tag='my_tag'; - // SET LOCAL commit_priority=LOW; - // INSERT INTO my_table ... -- This starts the transaction with the options above included. opts := spanner.TransactionOptions{} opts.BeginTransactionOption = c.convertDefaultBeginTransactionOption(propertyBeginTransactionOption.GetValueOrDefault(c.state)) - tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(ctx, c.client, opts, func() spanner.TransactionOptions { + tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(c.tx.ctx, c.client, opts, func() spanner.TransactionOptions { defer func() { // Reset the transaction_tag after starting the transaction. _ = propertyTransactionTag.ResetValue(c.state, connectionstate.ContextUser) }() - return c.effectiveTransactionOptions(isolationLevelFromTxOpts, c.options( /*reset=*/ true)) + return c.effectiveTransactionOptions(spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, c.options( /*reset=*/ true)) }) if err != nil { return nil, err } logger := c.logger.With("tx", "rw") - c.tx = &readWriteTransaction{ - ctx: ctx, + return &readWriteTransaction{ + ctx: c.tx.ctx, conn: c, logger: logger, rwTx: tx, close: func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) { - closeFunc() c.prevTx = c.tx - c.tx = nil if commitErr == nil { c.setCommitResponse(commitResponse) - if result == txResultCommit { - _ = c.state.Commit() - } else { - _ = c.state.Rollback() - } + closeFunc(result) } else { - _ = c.state.Rollback() + closeFunc(txResultRollback) } }, - // Disable internal retries if any of these options have been set. retryAborts: sync.OnceValue(func() bool { - return c.RetryAbortsInternally() && !disableRetryAborts + return c.RetryAbortsInternally() }), - } - c.clearCommitResponse() - return c.tx, nil + }, nil } func (c *conn) effectiveTransactionOptions(isolationLevelFromTxOpts spannerpb.TransactionOptions_IsolationLevel, execOptions *ExecOptions) spanner.TransactionOptions { @@ -1347,22 +1337,6 @@ func (c *conn) inTransaction() bool { return c.tx != nil } -func (c *conn) inReadOnlyTransaction() bool { - if c.tx != nil { - _, ok := c.tx.(*readOnlyTransaction) - return ok - } - return false -} - -func (c *conn) inReadWriteTransaction() bool { - if c.tx != nil { - _, ok := c.tx.(*readWriteTransaction) - return ok - } - return false -} - // Commit is not part of the public API of the database/sql driver. // It is exported for internal reasons, and may receive breaking changes without prior notice. // diff --git a/conn_with_mockserver_test.go b/conn_with_mockserver_test.go index 5ceba7ed..ffe0df58 100644 --- a/conn_with_mockserver_test.go +++ b/conn_with_mockserver_test.go @@ -20,6 +20,7 @@ import ( "fmt" "reflect" "testing" + "time" "cloud.google.com/go/longrunning/autogen/longrunningpb" "cloud.google.com/go/spanner" @@ -82,6 +83,79 @@ func TestTwoTransactionsOnOneConn(t *testing.T) { } } +func TestEmptyTransaction(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + db.SetMaxOpenConns(1) + + c, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer silentClose(c) + // Run twice to ensure that there is no connection leak. + for range 2 { + tx, err := c.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + t.Fatal(err) + } + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + + // An empty transaction should be a no-op and not lead to any requests being sent to Spanner. + requests := server.TestSpanner.DrainRequestsFromServer() + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 0; g != w { + t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 0; g != w { + t.Fatalf("commit requests count mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + +func TestEmptyTransactionUsingSql(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + db.SetMaxOpenConns(1) + + c, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer silentClose(c) + // Run twice to ensure that there is no connection leak. + for range 2 { + if _, err := c.ExecContext(ctx, "begin"); err != nil { + t.Fatal(err) + } + if _, err := c.ExecContext(ctx, "commit"); err != nil { + t.Fatal(err) + } + + // An empty transaction should be a no-op and not lead to any requests being sent to Spanner. + requests := server.TestSpanner.DrainRequestsFromServer() + beginRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 0; g != w { + t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w) + } + commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{})) + if g, w := len(commitRequests), 0; g != w { + t.Fatalf("commit requests count mismatch\n Got: %v\nWant: %v", g, w) + } + } +} + func TestTwoQueriesOnOneConn(t *testing.T) { t.Parallel() diff --git a/connection_properties.go b/connection_properties.go index 6d65e02e..31218cd7 100644 --- a/connection_properties.go +++ b/connection_properties.go @@ -296,6 +296,15 @@ var propertyReturnCommitStats = createConnectionProperty( connectionstate.ContextUser, connectionstate.ConvertBool, ) +var propertyTransactionBatchReadOnly = createConnectionProperty( + "transaction_batch_read_only", + "transaction_batch_read_only indicates whether read-only transactions on this connection should use a batch read-only transaction.", + false, + false, + nil, + connectionstate.ContextUser, + connectionstate.ConvertBool, +) // ------------------------------------------------------------------------------------------------ // Statement connection properties. diff --git a/driver.go b/driver.go index 7ac7400f..5359f5ab 100644 --- a/driver.go +++ b/driver.go @@ -996,6 +996,7 @@ func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOpti _ = conn.Close() }() + tx, err := conn.BeginTx(ctx, opts) // We don't need to keep track of a running checksum for retries when using // this method, so we disable internal retries. // Retries will instead be handled by the loop below. @@ -1011,13 +1012,12 @@ func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOpti // It is not a Spanner connection, so just ignore and continue without any special handling. return nil } - spannerConn.withTempTransactionOptions(transactionOptions) + spannerConn.setReadWriteTransactionOptions(transactionOptions) return nil }); err != nil { return nil, err } - tx, err := conn.BeginTx(ctx, opts) if err != nil { return nil, err } @@ -1130,8 +1130,6 @@ type ReadWriteTransactionOptions struct { // disabled, and any Aborted error from Spanner is propagated to the // application. DisableInternalRetries bool - - close func() } // BeginReadWriteTransaction begins a read/write transaction on a Spanner database. @@ -1146,15 +1144,17 @@ func BeginReadWriteTransaction(ctx context.Context, db *sql.DB, options ReadWrit if err != nil { return nil, err } - options.close = func() { + if err := withTransactionCloseFunc(conn, func() { // Close the connection asynchronously, as the transaction will still // be active when we hit this point. go conn.Close() + }); err != nil { + return nil, err } + tx, err := conn.BeginTx(ctx, &sql.TxOptions{}) if err := withTempReadWriteTransactionOptions(conn, &options); err != nil { return nil, err } - tx, err := conn.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return nil, err } @@ -1168,7 +1168,7 @@ func withTempReadWriteTransactionOptions(conn *sql.Conn, options *ReadWriteTrans // It is not a Spanner connection. return spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "This function can only be used with a Spanner connection")) } - spannerConn.withTempTransactionOptions(options) + spannerConn.setReadWriteTransactionOptions(options) return nil }) } @@ -1178,8 +1178,6 @@ func withTempReadWriteTransactionOptions(conn *sql.Conn, options *ReadWriteTrans type ReadOnlyTransactionOptions struct { TimestampBound spanner.TimestampBound BeginTransactionOption spanner.BeginTransactionOption - - close func() } // BeginReadOnlyTransaction begins a read-only transaction on a Spanner database. @@ -1192,15 +1190,17 @@ func BeginReadOnlyTransaction(ctx context.Context, db *sql.DB, options ReadOnlyT if err != nil { return nil, err } - options.close = func() { + if err := withTransactionCloseFunc(conn, func() { // Close the connection asynchronously, as the transaction will still // be active when we hit this point. go conn.Close() + }); err != nil { + return nil, err } + tx, err := conn.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) if err := withTempReadOnlyTransactionOptions(conn, &options); err != nil { return nil, err } - tx, err := conn.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) if err != nil { clearTempReadOnlyTransactionOptions(conn) return nil, err @@ -1208,6 +1208,18 @@ func BeginReadOnlyTransaction(ctx context.Context, db *sql.DB, options ReadOnlyT return tx, nil } +func withTransactionCloseFunc(conn *sql.Conn, close func()) error { + return conn.Raw(func(driverConn any) error { + spannerConn, ok := driverConn.(SpannerConn) + if !ok { + // It is not a Spanner connection. + return spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "This function can only be used with a Spanner connection")) + } + spannerConn.withTransactionCloseFunc(close) + return nil + }) +} + func withTempReadOnlyTransactionOptions(conn *sql.Conn, options *ReadOnlyTransactionOptions) error { return conn.Raw(func(driverConn any) error { spannerConn, ok := driverConn.(SpannerConn) @@ -1215,7 +1227,7 @@ func withTempReadOnlyTransactionOptions(conn *sql.Conn, options *ReadOnlyTransac // It is not a Spanner connection. return spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "This function can only be used with a Spanner connection")) } - spannerConn.withTempReadOnlyTransactionOptions(options) + spannerConn.setReadOnlyTransactionOptions(options) return nil }) } diff --git a/driver_test.go b/driver_test.go index dc09940f..746df7b0 100644 --- a/driver_test.go +++ b/driver_test.go @@ -492,12 +492,12 @@ func TestConnection_Reset(t *testing.T) { propertyCommitResponse.Key(): propertyCommitResponse.CreateTypedInitialValue(nil), }), batch: &batch{tp: parser.BatchTypeDml}, - tx: &readOnlyTransaction{ + tx: &delegatingTransaction{contextTransaction: &readOnlyTransaction{ logger: noopLogger, close: func(_ txResult) { txClosed = true }, - }, + }}, } c.setCommitResponse(&spanner.CommitResponse{}) @@ -525,7 +525,7 @@ func TestConnection_NoNestedTransactions(t *testing.T) { c := conn{ logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{}), - tx: &readOnlyTransaction{}, + tx: &delegatingTransaction{}, } _, err := c.BeginTx(context.Background(), driver.TxOptions{}) if err == nil { @@ -571,9 +571,9 @@ func TestConn_StartBatchDdl(t *testing.T) { {"Default", &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{})}, false}, {"In DDL batch", &conn{logger: noopLogger, batch: &batch{tp: parser.BatchTypeDdl}, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{})}, true}, {"In DML batch", &conn{logger: noopLogger, batch: &batch{tp: parser.BatchTypeDml}, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{})}, true}, - {"In read/write transaction", &conn{logger: noopLogger, tx: &readWriteTransaction{}, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{})}, true}, - {"In read-only transaction", &conn{logger: noopLogger, tx: &readOnlyTransaction{}, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{})}, true}, - {"In read/write transaction with a DML batch", &conn{logger: noopLogger, tx: &readWriteTransaction{batch: &batch{tp: parser.BatchTypeDml}}, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{})}, true}, + {"In read/write transaction", &conn{logger: noopLogger, tx: &delegatingTransaction{contextTransaction: &readWriteTransaction{}}, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{})}, true}, + {"In read-only transaction", &conn{logger: noopLogger, tx: &delegatingTransaction{contextTransaction: &readOnlyTransaction{}}, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{})}, true}, + {"In read/write transaction with a DML batch", &conn{logger: noopLogger, tx: &delegatingTransaction{contextTransaction: &readWriteTransaction{batch: &batch{tp: parser.BatchTypeDml}}}, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{})}, true}, } { err := test.c.StartBatchDDL() if test.wantErr { @@ -600,9 +600,9 @@ func TestConn_StartBatchDml(t *testing.T) { {"Default", &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{})}, false}, {"In DDL batch", &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{}), batch: &batch{tp: parser.BatchTypeDdl}}, true}, {"In DML batch", &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{}), batch: &batch{tp: parser.BatchTypeDml}}, true}, - {"In read/write transaction", &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{}), tx: &readWriteTransaction{logger: noopLogger}}, false}, - {"In read-only transaction", &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{}), tx: &readOnlyTransaction{logger: noopLogger}}, true}, - {"In read/write transaction with a DML batch", &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{}), tx: &readWriteTransaction{logger: noopLogger, batch: &batch{tp: parser.BatchTypeDml}}}, true}, + {"In read/write transaction", &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{}), tx: &delegatingTransaction{contextTransaction: &readWriteTransaction{logger: noopLogger}}}, false}, + {"In read-only transaction", &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{}), tx: &delegatingTransaction{contextTransaction: &readOnlyTransaction{logger: noopLogger}}}, true}, + {"In read/write transaction with a DML batch", &conn{logger: noopLogger, state: createInitialConnectionState(connectionstate.TypeTransactional, map[string]connectionstate.ConnectionPropertyValue{}), tx: &delegatingTransaction{contextTransaction: &readWriteTransaction{logger: noopLogger, batch: &batch{tp: parser.BatchTypeDml}}}}, true}, } { err := test.c.StartBatchDML() if test.wantErr { diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 209c1fa1..b059aa63 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -570,7 +570,7 @@ func TestReadOnlyTransactionWithOptions(t *testing.T) { requests = server.TestSpanner.DrainRequestsFromServer() beginReadOnlyRequests = filterBeginReadOnlyRequests(testutil.RequestsOfType(requests, reflect.TypeOf(&sppb.BeginTransactionRequest{}))) if g, w := len(beginReadOnlyRequests), 0; g != w { - t.Fatalf("begin requests count mismatch\nGot: %v\nWant: %v", g, w) + t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w) } executeRequests = testutil.RequestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) if g, w := len(executeRequests), 1; g != w { @@ -2723,6 +2723,9 @@ func TestShowAndSetVariableRetryAbortsInternally(t *testing.T) { // Check that the behavior matches the setting. tx, _ := c.BeginTx(ctx, nil) + if _, err := tx.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil { + t.Fatal(err) + } server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, testutil.SimulatedExecutionTime{ Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}, }) @@ -3242,6 +3245,9 @@ func TestCommitResponse(t *testing.T) { if err != nil { t.Fatalf("failed to start transaction: %v", err) } + if _, err := tx.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil { + t.Fatal(err) + } if err := tx.Commit(); err != nil { t.Fatalf("commit failed: %v", err) } @@ -3410,6 +3416,9 @@ func TestShowVariableCommitTimestamp(t *testing.T) { t.Fatalf("failed to get a connection: %v", err) } tx, err := conn.BeginTx(ctx, nil) + if _, err := conn.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil { + t.Fatal(err) + } if err != nil { t.Fatalf("failed to start transaction: %v", err) } @@ -4562,7 +4571,8 @@ func TestRunTransaction(t *testing.T) { defer silentClose(rows) // Verify that internal retries are disabled during RunTransaction txi := reflect.ValueOf(tx).Elem().FieldByName("txi") - rwTx := (*readWriteTransaction)(txi.Elem().UnsafePointer()) + delegatingTx := (*delegatingTransaction)(txi.Elem().UnsafePointer()) + rwTx := delegatingTx.contextTransaction.(*readWriteTransaction) // Verify that getting the transaction through reflection worked. if g, w := rwTx.ctx, ctx; g != w { return fmt.Errorf("getting the transaction through reflection failed") @@ -5023,7 +5033,8 @@ func TestBeginReadWriteTransaction(t *testing.T) { } // Verify that internal retries are disabled during this transaction. txi := reflect.ValueOf(tx).Elem().FieldByName("txi") - rwTx := (*readWriteTransaction)(txi.Elem().UnsafePointer()) + delegatingTx := (*delegatingTransaction)(txi.Elem().UnsafePointer()) + rwTx := delegatingTx.contextTransaction.(*readWriteTransaction) // Verify that getting the transaction through reflection worked. if g, w := rwTx.ctx, ctx; g != w { t.Fatal("getting the transaction through reflection failed") diff --git a/partitioned_query.go b/partitioned_query.go index 018f7b7d..30c2f6fb 100644 --- a/partitioned_query.go +++ b/partitioned_query.go @@ -29,8 +29,6 @@ import ( type BatchReadOnlyTransactionOptions struct { TimestampBound spanner.TimestampBound - - close func() } // PartitionedQueryOptions are used for queries that use the AutoPartitionQuery @@ -183,15 +181,17 @@ func BeginBatchReadOnlyTransaction(ctx context.Context, db *sql.DB, options Batc if err != nil { return nil, err } - options.close = func() { + if err := withTransactionCloseFunc(conn, func() { // Close the connection asynchronously, as the transaction will still // be active when we hit this point. go conn.Close() + }); err != nil { + return nil, err } + tx, err := conn.BeginTx(ctx, &sql.TxOptions{ReadOnly: true, Isolation: WithBatchReadOnly(sql.LevelDefault)}) if err := withTempBatchReadOnlyTransactionOptions(conn, &options); err != nil { return nil, err } - tx, err := conn.BeginTx(ctx, &sql.TxOptions{ReadOnly: true, Isolation: WithBatchReadOnly(sql.LevelDefault)}) if err != nil { clearTempBatchReadOnlyTransactionOptions(conn) return nil, err @@ -206,7 +206,7 @@ func withTempBatchReadOnlyTransactionOptions(conn *sql.Conn, options *BatchReadO // It is not a Spanner connection. return spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "This function can only be used with a Spanner connection")) } - spannerConn.withTempBatchReadOnlyTransactionOptions(options) + spannerConn.setBatchReadOnlyTransactionOptions(options) return nil }) } diff --git a/spannerlib/api/connection.go b/spannerlib/api/connection.go index 907212ec..ab256879 100644 --- a/spannerlib/api/connection.go +++ b/spannerlib/api/connection.go @@ -121,8 +121,8 @@ type Connection struct { // It is implemented by the spannerdriver.conn struct. type spannerConn interface { WriteMutations(ctx context.Context, ms []*spanner.Mutation) (*spanner.CommitResponse, error) - BeginReadOnlyTransaction(ctx context.Context, options *spannerdriver.ReadOnlyTransactionOptions) (driver.Tx, error) - BeginReadWriteTransaction(ctx context.Context, options *spannerdriver.ReadWriteTransactionOptions) (driver.Tx, error) + BeginReadOnlyTransaction(ctx context.Context, options *spannerdriver.ReadOnlyTransactionOptions, close func()) (driver.Tx, error) + BeginReadWriteTransaction(ctx context.Context, options *spannerdriver.ReadWriteTransactionOptions, close func()) (driver.Tx, error) Commit(ctx context.Context) (*spanner.CommitResponse, error) Rollback(ctx context.Context) error } @@ -155,7 +155,10 @@ func (conn *Connection) writeMutations(ctx context.Context, mutation *spannerpb. } var commitResponse *spanner.CommitResponse if err := conn.backend.Raw(func(driverConn any) (err error) { - sc, _ := driverConn.(spannerConn) + sc, ok := driverConn.(spannerConn) + if !ok { + return status.Error(codes.Internal, "spanner driver connection does not implement spannerConn") + } commitResponse, err = sc.WriteMutations(ctx, mutations) return err }); err != nil { @@ -189,16 +192,22 @@ func (conn *Connection) BeginTransaction(ctx context.Context, txOpts *spannerpb. func (conn *Connection) beginReadOnlyTransaction(ctx context.Context, opts *spannerdriver.ReadOnlyTransactionOptions) error { return conn.backend.Raw(func(driverConn any) (err error) { - sc, _ := driverConn.(spannerConn) - _, err = sc.BeginReadOnlyTransaction(ctx, opts) + sc, ok := driverConn.(spannerConn) + if !ok { + return status.Error(codes.Internal, "driver connection does not implement spannerConn") + } + _, err = sc.BeginReadOnlyTransaction(ctx, opts, func() {}) return err }) } func (conn *Connection) beginReadWriteTransaction(ctx context.Context, opts *spannerdriver.ReadWriteTransactionOptions) error { return conn.backend.Raw(func(driverConn any) (err error) { - sc, _ := driverConn.(spannerConn) - _, err = sc.BeginReadWriteTransaction(ctx, opts) + sc, ok := driverConn.(spannerConn) + if !ok { + return status.Error(codes.Internal, "driver connection does not implement spannerConn") + } + _, err = sc.BeginReadWriteTransaction(ctx, opts, func() {}) return err }) } @@ -206,8 +215,11 @@ func (conn *Connection) beginReadWriteTransaction(ctx context.Context, opts *spa func (conn *Connection) commit(ctx context.Context) (*spannerpb.CommitResponse, error) { var response *spanner.CommitResponse if err := conn.backend.Raw(func(driverConn any) (err error) { - spannerConn, _ := driverConn.(spannerConn) - response, err = spannerConn.Commit(ctx) + sc, ok := driverConn.(spannerConn) + if !ok { + return status.Error(codes.Internal, "driver connection does not implement spannerConn") + } + response, err = sc.Commit(ctx) if err != nil { return err } @@ -226,8 +238,11 @@ func (conn *Connection) commit(ctx context.Context) (*spannerpb.CommitResponse, func (conn *Connection) rollback(ctx context.Context) error { return conn.backend.Raw(func(driverConn any) (err error) { - spannerConn, _ := driverConn.(spannerConn) - return spannerConn.Rollback(ctx) + sc, ok := driverConn.(spannerConn) + if !ok { + return status.Error(codes.Internal, "driver connection does not implement spannerConn") + } + return sc.Rollback(ctx) }) } diff --git a/spannerlib/lib/connection_test.go b/spannerlib/lib/connection_test.go index 193e3fd9..c9ccac88 100644 --- a/spannerlib/lib/connection_test.go +++ b/spannerlib/lib/connection_test.go @@ -196,12 +196,8 @@ func TestBeginAndCommit(t *testing.T) { if g, w := commitMsg.Code, int32(0); g != w { t.Fatalf("Commit result mismatch\n Got: %v\nWant: %v", g, w) } - if commitMsg.Length() == 0 { - t.Fatal("Commit return zero length") - } - resp := &spannerpb.CommitResponse{} - if err := proto.Unmarshal(commitMsg.Res, resp); err != nil { - t.Fatalf("Failed to unmarshal commit response: %v", err) + if commitMsg.Length() != 0 { + t.Fatal("Commit returned non-zero length") } closeMsg := CloseConnection(ctx, poolMsg.ObjectId, connMsg.ObjectId) diff --git a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/TransactionTest.java b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/TransactionTest.java index f2eef624..2619fd70 100644 --- a/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/TransactionTest.java +++ b/spannerlib/wrappers/spannerlib-java/src/test/java/com/google/cloud/spannerlib/TransactionTest.java @@ -53,9 +53,8 @@ public void testBeginAndCommit() { connection.beginTransaction(TransactionOptions.getDefaultInstance()); connection.commit(); - // TODO: The library should take a shortcut and just skip committing empty transactions. - assertEquals(1, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); - assertEquals(1, mockSpanner.countRequestsOfType(CommitRequest.class)); + assertEquals(0, mockSpanner.countRequestsOfType(BeginTransactionRequest.class)); + assertEquals(0, mockSpanner.countRequestsOfType(CommitRequest.class)); } } diff --git a/transaction.go b/transaction.go index 297c6b60..8b92cb24 100644 --- a/transaction.go +++ b/transaction.go @@ -53,6 +53,7 @@ type contextTransaction interface { RunBatch(ctx context.Context) (driver.Result, error) RunDmlBatch(ctx context.Context) (SpannerResult, error) AbortBatch() (driver.Result, error) + IsInBatch() bool BufferWrite(ms []*spanner.Mutation) error } @@ -106,6 +107,116 @@ const ( txResultRollback ) +var _ contextTransaction = &delegatingTransaction{} + +// delegatingTransaction wraps a read/write or read-only transaction and delegates +// all calls to the underlying transaction. The underlying transaction is automatically +// created when the first query or DML statement is executed. The type of transaction is +// determined at the moment that the underlying transaction is created. This allows an +// application to execute statements like `set transaction read only` at the start of a +// transaction to set the type of transaction. +type delegatingTransaction struct { + conn *conn + ctx context.Context + close func(result txResult) + contextTransaction contextTransaction +} + +func (d *delegatingTransaction) ensureActivated() error { + if d.contextTransaction != nil { + return nil + } + tx, err := d.conn.activateTransaction() + if err != nil { + return err + } + d.contextTransaction = tx + return nil +} + +func (d *delegatingTransaction) Commit() error { + if d.contextTransaction == nil { + d.close(txResultCommit) + return nil + } + return d.contextTransaction.Commit() +} + +func (d *delegatingTransaction) Rollback() error { + if d.contextTransaction == nil { + d.close(txResultRollback) + return nil + } + return d.contextTransaction.Rollback() +} + +func (d *delegatingTransaction) resetForRetry(ctx context.Context) error { + if d.contextTransaction == nil { + return status.Error(codes.FailedPrecondition, "a transaction can only be reset after it has been activated") + } + return d.contextTransaction.resetForRetry(ctx) +} + +func (d *delegatingTransaction) Query(ctx context.Context, stmt spanner.Statement, stmtType parser.StatementType, execOptions *ExecOptions) (rowIterator, error) { + if err := d.ensureActivated(); err != nil { + return nil, err + } + return d.contextTransaction.Query(ctx, stmt, stmtType, execOptions) +} + +func (d *delegatingTransaction) partitionQuery(ctx context.Context, stmt spanner.Statement, execOptions *ExecOptions) (driver.Rows, error) { + if err := d.ensureActivated(); err != nil { + return nil, err + } + return d.contextTransaction.partitionQuery(ctx, stmt, execOptions) +} + +func (d *delegatingTransaction) ExecContext(ctx context.Context, stmt spanner.Statement, statementInfo *parser.StatementInfo, options spanner.QueryOptions) (*result, error) { + if err := d.ensureActivated(); err != nil { + return nil, err + } + return d.contextTransaction.ExecContext(ctx, stmt, statementInfo, options) +} + +func (d *delegatingTransaction) StartBatchDML(options spanner.QueryOptions, automatic bool) (driver.Result, error) { + if err := d.ensureActivated(); err != nil { + return nil, err + } + return d.contextTransaction.StartBatchDML(options, automatic) +} + +func (d *delegatingTransaction) RunBatch(ctx context.Context) (driver.Result, error) { + if err := d.ensureActivated(); err != nil { + return nil, err + } + return d.contextTransaction.RunBatch(ctx) +} + +func (d *delegatingTransaction) RunDmlBatch(ctx context.Context) (SpannerResult, error) { + if err := d.ensureActivated(); err != nil { + return nil, err + } + return d.contextTransaction.RunDmlBatch(ctx) +} + +func (d *delegatingTransaction) AbortBatch() (driver.Result, error) { + if err := d.ensureActivated(); err != nil { + return nil, err + } + return d.contextTransaction.AbortBatch() +} + +func (d *delegatingTransaction) IsInBatch() bool { + return d.contextTransaction != nil && d.contextTransaction.IsInBatch() +} + +func (d *delegatingTransaction) BufferWrite(ms []*spanner.Mutation) error { + if err := d.ensureActivated(); err != nil { + return err + } + return d.contextTransaction.BufferWrite(ms) +} + var _ contextTransaction = &readOnlyTransaction{} type readOnlyTransaction struct { @@ -220,6 +331,10 @@ func (tx *readOnlyTransaction) AbortBatch() (driver.Result, error) { return driver.ResultNoRows, nil } +func (tx *readOnlyTransaction) IsInBatch() bool { + return false +} + func (tx *readOnlyTransaction) BufferWrite([]*spanner.Mutation) error { return spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "read-only transactions cannot write")) } @@ -428,7 +543,6 @@ func (tx *readWriteTransaction) Commit() (err error) { return err } var commitResponse spanner.CommitResponse - // TODO: Optimize this to skip the Commit also if the transaction has not yet been used. if tx.rwTx != nil { if !tx.retryAborts() { ts, err := tx.rwTx.CommitWithReturnResp(tx.ctx) @@ -586,6 +700,10 @@ func (tx *readWriteTransaction) AbortBatch() (driver.Result, error) { return driver.ResultNoRows, nil } +func (tx *readWriteTransaction) IsInBatch() bool { + return tx.batch != nil +} + func (tx *readWriteTransaction) maybeRunAutoDmlBatch(ctx context.Context) error { if tx.batch == nil || !tx.batch.automatic { return nil