From a3a8d4aa4be9ab18d92f4941d880f3d8670c5a31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Tue, 1 Jul 2025 17:29:09 +0200 Subject: [PATCH 1/2] perf: inline BeginTransaction with first statement Inline the BeginTransaction option with the first statement in the transaction, instead of executing a separate BeginTransaction RPC. This reduces the number of round-trips to Spanner by one for all transactions that have at least one SQL statement. Using line-begin improves performance for most transaction shapes, as it requires one less round-trip to Spanner. Some transaction shapes do not benefit from this. These are: 1. Transactions that only write mutations still need an explicit BeginTransaction RPC to be executed, as mutations are included in the Commit RPC. The Commit RPC can also start a transaction, but such transactions are not guaranteed to be applied only once to Spanner. 2. Transactions that execute multiple parallel queries at the start of the transaction can see higher end-to-end execution times, as only one query can include the BeginTransaction option. All other queries must wait for the first query to return at least one result, which also includes the transaction identifier, before they can proceed. The default for the database/sql driver is to use inline-begin. A follow-up pull request will add an option to the driver to set a different default for a connection. --- aborted_transactions_test.go | 32 +++++++++--- auto_dml_batch_test.go | 7 ++- conn.go | 7 ++- conn_with_mockserver_test.go | 39 ++++++++++---- driver.go | 3 +- driver_with_mockserver_test.go | 95 ++++++++++++++++++++++------------ 6 files changed, 129 insertions(+), 54 deletions(-) diff --git a/aborted_transactions_test.go b/aborted_transactions_test.go index ed6ae430..3df08d65 100644 --- a/aborted_transactions_test.go +++ b/aborted_transactions_test.go @@ -325,7 +325,7 @@ func TestQueryWithError_CommitAborted(t *testing.T) { server.PutExecutionTime(testutil.MethodCommitTransaction, testutil.SimulatedExecutionTime{ Errors: []error{status.Error(codes.Aborted, "Aborted")}, }) - }, codes.NotFound, 0, 2, 2) + }, codes.NotFound, 0, 3, 2) } func TestQueryWithErrorHalfway_CommitAborted(t *testing.T) { @@ -1080,7 +1080,7 @@ func TestBatchUpdateAbortedWithError_DifferentErrorDuringRetry(t *testing.T) { t.Fatalf("dml statement failed: %v", err) } if _, err := tx.ExecContext(ctx, "RUN BATCH"); spanner.ErrCode(err) != codes.NotFound { - t.Fatalf("error code mismatch\nGot: %v\nWant: %v", spanner.ErrCode(err), codes.NotFound) + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", spanner.ErrCode(err), codes.NotFound) } // Remove the error for the DML statement and cause a retry. The missing @@ -1094,19 +1094,37 @@ func TestBatchUpdateAbortedWithError_DifferentErrorDuringRetry(t *testing.T) { }) err = tx.Commit() if err != ErrAbortedDueToConcurrentModification { - t.Fatalf("commit error mismatch\nGot: %v\nWant: %v", err, ErrAbortedDueToConcurrentModification) + t.Fatalf("commit error mismatch\n Got: %v\nWant: %v", err, ErrAbortedDueToConcurrentModification) } reqs := drainRequestsFromServer(server.TestSpanner) execReqs := requestsOfType(reqs, reflect.TypeOf(&sppb.ExecuteBatchDmlRequest{})) - if g, w := len(execReqs), 2; g != w { - t.Fatalf("batch request count mismatch\nGot: %v\nWant: %v", g, w) + // There are 3 ExecuteBatchDmlRequests sent to Spanner: + // 1. An initial attempt with a BeginTransaction RPC, but this returns a NotFound error. + // This causes the transaction to be retried with an explicit BeginTransaction request. + // 2. Another attempt with a transaction ID. + // 3. A third attempt after the initial transaction is aborted. + if g, w := len(execReqs), 3; g != w { + t.Fatalf("batch request count mismatch\n Got: %v\nWant: %v", g, w) } commitReqs := requestsOfType(reqs, reflect.TypeOf(&sppb.CommitRequest{})) // The commit should be attempted only once. if g, w := len(commitReqs), 1; g != w { - t.Fatalf("commit request count mismatch\nGot: %v\nWant: %v", g, w) + t.Fatalf("commit request count mismatch\n Got: %v\nWant: %v", g, w) + } + // The first ExecuteBatchDml request should try to use an inline-begin. + // After that, we should have two BeginTransaction requests. + req1 := execReqs[0].(*sppb.ExecuteBatchDmlRequest) + if req1.GetTransaction() == nil || req1.GetTransaction().GetBegin() == nil { + t.Fatal("the first ExecuteBatchDmlRequest should have a BeginTransaction") + } + req2 := execReqs[1].(*sppb.ExecuteBatchDmlRequest) + if req2.GetTransaction() == nil || req2.GetTransaction().GetId() == nil { + t.Fatal("the second ExecuteBatchDmlRequest should have a transaction id") + } + beginRequests := requestsOfType(reqs, reflect.TypeOf(&sppb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 2; g != w { + t.Fatalf("begin request count mismatch\n Got: %v\nWant: %v", g, w) } - // Verify that the db is still usable. if _, err := db.ExecContext(ctx, testutil.UpdateSingersSetLastName); err != nil { t.Fatalf("failed to execute statement after transaction: %v", err) diff --git a/auto_dml_batch_test.go b/auto_dml_batch_test.go index 0941dbb5..b6f14937 100644 --- a/auto_dml_batch_test.go +++ b/auto_dml_batch_test.go @@ -303,8 +303,13 @@ func TestAutoBatchDml_FollowedByRollback(t *testing.T) { if g, w := len(commitRequests), 0; g != w { t.Fatalf("num commit requests mismatch\n Got: %v\nWant: %v", g, w) } + beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 0; g != w { + t.Fatalf("num BeginTransaction requests mismatch\n Got: %v\nWant: %v", g, w) + } rollbackRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.RollbackRequest{})) - if g, w := len(rollbackRequests), 1; g != w { + // There are no rollback requests sent to Spanner, as the transaction is never started. + if g, w := len(rollbackRequests), 0; g != w { t.Fatalf("num rollback requests mismatch\n Got: %v\nWant: %v", g, w) } } diff --git a/conn.go b/conn.go index 7208fc53..e11cec54 100644 --- a/conn.go +++ b/conn.go @@ -947,6 +947,9 @@ func (c *conn) getTransactionOptions() ReadWriteTransactionOptions { txOpts.TransactionOptions.IsolationLevel = level } } + if txOpts.TransactionOptions.BeginTransactionOption == spanner.DefaultBeginTransaction { + txOpts.TransactionOptions.BeginTransactionOption = spanner.InlinedBeginTransaction + } return txOpts } @@ -959,7 +962,7 @@ func (c *conn) getReadOnlyTransactionOptions() ReadOnlyTransactionOptions { defer func() { c.tempReadOnlyTransactionOptions = nil }() return *c.tempReadOnlyTransactionOptions } - return ReadOnlyTransactionOptions{TimestampBound: c.readOnlyStaleness} + return ReadOnlyTransactionOptions{TimestampBound: c.readOnlyStaleness, BeginTransactionOption: spanner.InlinedBeginTransaction} } func (c *conn) withTempBatchReadOnlyTransactionOptions(options *BatchReadOnlyTransactionOptions) { @@ -1034,7 +1037,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e ro = &bo.ReadOnlyTransaction } else { logger = c.logger.With("tx", "ro") - ro = c.client.ReadOnlyTransaction().WithTimestampBound(readOnlyTxOpts.TimestampBound) + ro = c.client.ReadOnlyTransaction().WithBeginTransactionOption(readOnlyTxOpts.BeginTransactionOption).WithTimestampBound(readOnlyTxOpts.TimestampBound) } c.tx = &readOnlyTransaction{ roTx: ro, diff --git a/conn_with_mockserver_test.go b/conn_with_mockserver_test.go index b64af250..70107511 100644 --- a/conn_with_mockserver_test.go +++ b/conn_with_mockserver_test.go @@ -39,11 +39,18 @@ func TestBeginTx(t *testing.T) { requests := drainRequestsFromServer(server.TestSpanner) beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) - if g, w := len(beginRequests), 1; g != w { + if g, w := len(beginRequests), 0; g != w { t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w) } - request := beginRequests[0].(*spannerpb.BeginTransactionRequest) - if g, w := request.Options.GetIsolationLevel(), spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED; g != w { + executeRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w) + } + request := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + if request.GetTransaction() == nil || request.GetTransaction().GetBegin() == nil { + t.Fatal("missing begin transaction on ExecuteSqlRequest") + } + if g, w := request.GetTransaction().GetBegin().GetIsolationLevel(), spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED; g != w { t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w) } } @@ -76,12 +83,19 @@ func TestBeginTxWithIsolationLevel(t *testing.T) { requests := drainRequestsFromServer(server.TestSpanner) beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) - if g, w := len(beginRequests), 1; g != w { + if g, w := len(beginRequests), 0; g != w { t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w) } - request := beginRequests[0].(*spannerpb.BeginTransactionRequest) + executeRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w) + } + request := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + if request.GetTransaction() == nil || request.GetTransaction().GetBegin() == nil { + t.Fatalf("execute request does not have a begin transaction") + } wantIsolationLevel, _ := toProtoIsolationLevel(originalLevel) - if g, w := request.Options.GetIsolationLevel(), wantIsolationLevel; g != w { + if g, w := request.GetTransaction().GetBegin().GetIsolationLevel(), wantIsolationLevel; g != w { t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w) } } @@ -162,12 +176,19 @@ func TestDefaultIsolationLevel(t *testing.T) { requests := drainRequestsFromServer(server.TestSpanner) beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) - if g, w := len(beginRequests), 1; g != w { + if g, w := len(beginRequests), 0; g != w { t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w) } - request := beginRequests[0].(*spannerpb.BeginTransactionRequest) + executeRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w) + } + request := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + if request.GetTransaction() == nil || request.GetTransaction().GetBegin() == nil { + t.Fatalf("ExecuteSqlRequest should have a Begin transaction") + } wantIsolationLevel, _ := toProtoIsolationLevel(originalLevel) - if g, w := request.Options.GetIsolationLevel(), wantIsolationLevel; g != w { + if g, w := request.GetTransaction().GetBegin().GetIsolationLevel(), wantIsolationLevel; g != w { t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w) } } diff --git a/driver.go b/driver.go index f29f7151..55953181 100644 --- a/driver.go +++ b/driver.go @@ -1026,7 +1026,8 @@ func clearTempReadWriteTransactionOptions(conn *sql.Conn) { // ReadOnlyTransactionOptions can be used to create a read-only transaction // on a Spanner connection. type ReadOnlyTransactionOptions struct { - TimestampBound spanner.TimestampBound + TimestampBound spanner.TimestampBound + BeginTransactionOption spanner.BeginTransactionOption close func() } diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 60c539e5..537db19a 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -441,8 +441,8 @@ func TestSimpleReadOnlyTransaction(t *testing.T) { if req.Transaction == nil { t.Fatalf("missing transaction for ExecuteSqlRequest") } - if req.Transaction.GetId() == nil { - t.Fatalf("missing id selector for ExecuteSqlRequest") + if req.Transaction.GetBegin() == nil { + t.Fatalf("missing begin selector for ExecuteSqlRequest") } // Read-only transactions are not really committed on Cloud Spanner, so // there should be no commit request on the server. @@ -451,7 +451,7 @@ func TestSimpleReadOnlyTransaction(t *testing.T) { t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w) } beginReadOnlyRequests := filterBeginReadOnlyRequests(requestsOfType(requests, reflect.TypeOf(&sppb.BeginTransactionRequest{}))) - if g, w := len(beginReadOnlyRequests), 1; g != w { + if g, w := len(beginReadOnlyRequests), 0; g != w { t.Fatalf("begin requests count mismatch\nGot: %v\nWant: %v", g, w) } } @@ -491,12 +491,19 @@ func TestReadOnlyTransactionWithStaleness(t *testing.T) { requests := drainRequestsFromServer(server.TestSpanner) beginReadOnlyRequests := filterBeginReadOnlyRequests(requestsOfType(requests, reflect.TypeOf(&sppb.BeginTransactionRequest{}))) - if g, w := len(beginReadOnlyRequests), 1; g != w { - t.Fatalf("begin requests count mismatch\nGot: %v\nWant: %v", g, w) + if g, w := len(beginReadOnlyRequests), 0; g != w { + t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w) + } + executeRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w) } - beginReq := beginReadOnlyRequests[0] - if beginReq.GetOptions().GetReadOnly().GetExactStaleness() == nil { - t.Fatalf("missing exact_staleness option on BeginTransaction request") + executeReq := executeRequests[0].(*sppb.ExecuteSqlRequest) + if executeReq.GetTransaction() == nil || executeReq.GetTransaction().GetBegin() == nil { + t.Fatalf("missing begin selector for ExecuteSqlRequest") + } + if executeReq.GetTransaction().GetBegin().GetReadOnly().GetExactStaleness() == nil { + t.Fatalf("missing exact_staleness option on BeginTransaction option") } } @@ -510,8 +517,10 @@ func TestReadOnlyTransactionWithOptions(t *testing.T) { // Set max open connections to 1 to force a failure if there is a connection leak. db.SetMaxOpenConns(1) - tx, err := BeginReadOnlyTransaction(ctx, db, - ReadOnlyTransactionOptions{TimestampBound: spanner.ExactStaleness(10 * time.Second)}) + tx, err := BeginReadOnlyTransaction(ctx, db, ReadOnlyTransactionOptions{ + TimestampBound: spanner.ExactStaleness(10 * time.Second), + BeginTransactionOption: spanner.InlinedBeginTransaction, + }) if err != nil { t.Fatal(err) } @@ -535,12 +544,19 @@ func TestReadOnlyTransactionWithOptions(t *testing.T) { requests := drainRequestsFromServer(server.TestSpanner) beginReadOnlyRequests := filterBeginReadOnlyRequests(requestsOfType(requests, reflect.TypeOf(&sppb.BeginTransactionRequest{}))) - if g, w := len(beginReadOnlyRequests), 1; g != w { + if g, w := len(beginReadOnlyRequests), 0; g != w { t.Fatalf("begin requests count mismatch\nGot: %v\nWant: %v", g, w) } - beginReq := beginReadOnlyRequests[0] - if beginReq.GetOptions().GetReadOnly().GetExactStaleness() == nil { - t.Fatalf("missing exact_staleness option on BeginTransaction request") + executeRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w) + } + executeReq := executeRequests[0].(*sppb.ExecuteSqlRequest) + if executeReq.GetTransaction() == nil || executeReq.GetTransaction().GetBegin() == nil { + t.Fatalf("missing begin selector for ExecuteSqlRequest") + } + if executeReq.GetTransaction().GetBegin().GetReadOnly().GetExactStaleness() == nil { + t.Fatalf("missing exact_staleness option on BeginTransaction option") } // Verify that the staleness option is not 'sticky' on the database. @@ -554,11 +570,18 @@ func TestReadOnlyTransactionWithOptions(t *testing.T) { requests = drainRequestsFromServer(server.TestSpanner) beginReadOnlyRequests = filterBeginReadOnlyRequests(requestsOfType(requests, reflect.TypeOf(&sppb.BeginTransactionRequest{}))) - if g, w := len(beginReadOnlyRequests), 1; g != w { + if g, w := len(beginReadOnlyRequests), 0; g != w { t.Fatalf("begin requests count mismatch\nGot: %v\nWant: %v", g, w) } - beginReq = beginReadOnlyRequests[0] - if beginReq.GetOptions().GetReadOnly().GetExactStaleness() != nil { + executeRequests = requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w) + } + executeReq = executeRequests[0].(*sppb.ExecuteSqlRequest) + if executeReq.GetTransaction() == nil || executeReq.GetTransaction().GetBegin() == nil { + t.Fatalf("missing begin selector for ExecuteSqlRequest") + } + if executeReq.GetTransaction().GetBegin().GetReadOnly().GetExactStaleness() != nil { t.Fatalf("got unexpected exact_staleness option on BeginTransaction request") } } @@ -615,28 +638,29 @@ func TestSimpleReadWriteTransaction(t *testing.T) { } requests := drainRequestsFromServer(server.TestSpanner) + beginRequests := requestsOfType(requests, reflect.TypeOf(&sppb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 0; g != w { + t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w) + } sqlRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) if g, w := len(sqlRequests), 1; g != w { - t.Fatalf("ExecuteSqlRequests count mismatch\nGot: %v\nWant: %v", g, w) + t.Fatalf("ExecuteSqlRequests count mismatch\n Got: %v\nWant: %v", g, w) } req := sqlRequests[0].(*sppb.ExecuteSqlRequest) if req.Transaction == nil { t.Fatalf("missing transaction for ExecuteSqlRequest") } - if req.Transaction.GetId() == nil { - t.Fatalf("missing id selector for ExecuteSqlRequest") + if req.Transaction.GetBegin() == nil { + t.Fatalf("missing begin selector for ExecuteSqlRequest") } if req.LastStatement { t.Fatalf("last statement set for ExecuteSqlRequest") } commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) if g, w := len(commitRequests), 1; g != w { - t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w) + t.Fatalf("commit requests count mismatch\n Got: %v\nWant: %v", g, w) } commitReq := commitRequests[0].(*sppb.CommitRequest) - if c, e := commitReq.GetTransactionId(), req.Transaction.GetId(); !cmp.Equal(c, e) { - t.Fatalf("transaction id mismatch\nCommit: %c\nExecute: %v", c, e) - } if g, w := commitReq.MaxCommitDelay.Nanos, int32(time.Millisecond*10); g != w { t.Fatalf("max_commit_delay mismatch\n Got: %v\nWant: %v", g, w) } @@ -3659,11 +3683,18 @@ func TestExcludeTxnFromChangeStreams_Transaction(t *testing.T) { requests := drainRequestsFromServer(server.TestSpanner) beginRequests := requestsOfType(requests, reflect.TypeOf(&sppb.BeginTransactionRequest{})) - if g, w := len(beginRequests), 1; g != w { - t.Fatalf("BeginTransactionRequest count mismatch\nGot: %v\nWant: %v", g, w) + if g, w := len(beginRequests), 0; g != w { + t.Fatalf("BeginTransactionRequest count mismatch\n Got: %v\nWant: %v", g, w) } - req := beginRequests[0].(*sppb.BeginTransactionRequest) - if !req.Options.ExcludeTxnFromChangeStreams { + executeRequests := requestsOfType(requests, reflect.TypeOf(&sppb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("ExecuteSqlRequest count mismatch\n Got: %v\nWant: %v", g, w) + } + req := executeRequests[0].(*sppb.ExecuteSqlRequest) + if req.GetTransaction() == nil || req.GetTransaction().GetBegin() == nil { + t.Fatal("missing BeginTransaction option on ExecuteSqlRequest") + } + if !req.GetTransaction().GetBegin().ExcludeTxnFromChangeStreams { t.Fatalf("missing ExcludeTxnFromChangeStreams option on BeginTransaction option") } @@ -4752,17 +4783,13 @@ func TestTransactionWithLevelDisableRetryAborts(t *testing.T) { if req.Transaction == nil { t.Fatalf("missing transaction for ExecuteSqlRequest") } - if req.Transaction.GetId() == nil { - t.Fatalf("missing id selector for ExecuteSqlRequest") + if req.Transaction.GetBegin() == nil { + t.Fatalf("missing begin selector for ExecuteSqlRequest") } commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) if g, w := len(commitRequests), 1; g != w { t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w) } - commitReq := commitRequests[0].(*sppb.CommitRequest) - if c, e := commitReq.GetTransactionId(), req.Transaction.GetId(); !cmp.Equal(c, e) { - t.Fatalf("transaction id mismatch\nCommit: %c\nExecute: %v", c, e) - } } func TestBeginReadWriteTransaction(t *testing.T) { From 3d26ee4d1cacf12995a79e0ed29069d5e3acee2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Tue, 1 Jul 2025 21:24:30 +0200 Subject: [PATCH 2/2] feat: add BeginTransactionOption (#467) Adds a BeginTransactionOption configuration field that can be used to determine how the database/sql driver should begin transactions. The default is to inline the BeginTransaction option with the first SQL statement of the transaction. This reduces the number of round-trips needed per transaction by one for most transaction shapes. --- conn.go | 25 +++++++++++--- conn_with_mockserver_test.go | 47 ++++++++++++++++++++++++++ driver.go | 22 ++++++++++++ driver_test.go | 62 ++++++++++++++++++++++++++++++++-- driver_with_mockserver_test.go | 40 ++++++++++++---------- 5 files changed, 173 insertions(+), 23 deletions(-) diff --git a/conn.go b/conn.go index e11cec54..57fae247 100644 --- a/conn.go +++ b/conn.go @@ -248,6 +248,8 @@ type conn struct { // transactions on this connection. This default is ignored if the BeginTx function is // called with an isolation level other than sql.LevelDefault. isolationLevel sql.IsolationLevel + // beginTransactionOption determines the default transactions start mode. + beginTransactionOption spanner.BeginTransactionOption // execOptions are applied to the next statement or transaction that is executed // on this connection. It can also be set by passing it in as an argument to @@ -660,6 +662,7 @@ func (c *conn) ResetSession(_ context.Context) error { c.autoBatchDmlUpdateCountVerification = !c.connector.connectorConfig.DisableAutoBatchDmlUpdateCountVerification c.retryAborts = c.connector.retryAbortsInternally c.isolationLevel = c.connector.connectorConfig.IsolationLevel + c.beginTransactionOption = c.connector.connectorConfig.BeginTransactionOption // TODO: Reset the following fields to the connector default c.autocommitDMLMode = Transactional c.readOnlyStaleness = spanner.TimestampBound{} @@ -927,7 +930,9 @@ func (c *conn) withTempTransactionOptions(options *ReadWriteTransactionOptions) func (c *conn) getTransactionOptions() ReadWriteTransactionOptions { if c.tempTransactionOptions != nil { defer func() { c.tempTransactionOptions = nil }() - return *c.tempTransactionOptions + opts := *c.tempTransactionOptions + opts.TransactionOptions.BeginTransactionOption = c.convertDefaultBeginTransactionOption(opts.TransactionOptions.BeginTransactionOption) + return opts } // Clear the transaction tag that has been set on the connection after returning // from this function. @@ -948,7 +953,7 @@ func (c *conn) getTransactionOptions() ReadWriteTransactionOptions { } } if txOpts.TransactionOptions.BeginTransactionOption == spanner.DefaultBeginTransaction { - txOpts.TransactionOptions.BeginTransactionOption = spanner.InlinedBeginTransaction + txOpts.TransactionOptions.BeginTransactionOption = c.convertDefaultBeginTransactionOption(c.beginTransactionOption) } return txOpts } @@ -960,9 +965,11 @@ func (c *conn) withTempReadOnlyTransactionOptions(options *ReadOnlyTransactionOp func (c *conn) getReadOnlyTransactionOptions() ReadOnlyTransactionOptions { if c.tempReadOnlyTransactionOptions != nil { defer func() { c.tempReadOnlyTransactionOptions = nil }() - return *c.tempReadOnlyTransactionOptions + opts := *c.tempReadOnlyTransactionOptions + opts.BeginTransactionOption = c.convertDefaultBeginTransactionOption(opts.BeginTransactionOption) + return opts } - return ReadOnlyTransactionOptions{TimestampBound: c.readOnlyStaleness, BeginTransactionOption: spanner.InlinedBeginTransaction} + return ReadOnlyTransactionOptions{TimestampBound: c.readOnlyStaleness, BeginTransactionOption: c.convertDefaultBeginTransactionOption(c.beginTransactionOption)} } func (c *conn) withTempBatchReadOnlyTransactionOptions(options *BatchReadOnlyTransactionOptions) { @@ -1083,6 +1090,16 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e return c.tx, nil } +func (c *conn) convertDefaultBeginTransactionOption(opt spanner.BeginTransactionOption) spanner.BeginTransactionOption { + if opt == spanner.DefaultBeginTransaction { + if c.beginTransactionOption == spanner.DefaultBeginTransaction { + return spanner.InlinedBeginTransaction + } + return c.beginTransactionOption + } + return opt +} + func (c *conn) inTransaction() bool { return c.tx != nil } diff --git a/conn_with_mockserver_test.go b/conn_with_mockserver_test.go index 70107511..b6be1708 100644 --- a/conn_with_mockserver_test.go +++ b/conn_with_mockserver_test.go @@ -55,6 +55,53 @@ func TestBeginTx(t *testing.T) { } } +func TestExplicitBeginTx(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnectionWithConnectorConfig(t, ConnectorConfig{ + Project: "p", + Instance: "i", + Database: "d", + + BeginTransactionOption: spanner.ExplicitBeginTransaction, + }) + defer teardown() + ctx := context.Background() + + for _, readOnly := range []bool{true, false} { + tx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: readOnly}) + if err != nil { + t.Fatal(err) + } + res, err := tx.QueryContext(ctx, testutil.SelectFooFromBar) + if err != nil { + t.Fatal(err) + } + for res.Next() { + } + if err := res.Err(); err != nil { + t.Fatal(err) + } + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + + requests := drainRequestsFromServer(server.TestSpanner) + beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 1; g != w { + t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w) + } + executeRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w) + } + request := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + if request.GetTransaction() == nil || request.GetTransaction().GetId() == nil { + t.Fatal("missing transaction id on ExecuteSqlRequest") + } + } +} + func TestBeginTxWithIsolationLevel(t *testing.T) { t.Parallel() diff --git a/driver.go b/driver.go index 55953181..b50b1aa1 100644 --- a/driver.go +++ b/driver.go @@ -304,6 +304,11 @@ type ConnectorConfig struct { // IsolationLevel is the default isolation level for read/write transactions. IsolationLevel sql.IsolationLevel + // BeginTransactionOption determines the default for how to begin transactions. + // The Spanner database/sql driver uses spanner.InlinedBeginTransaction by default + // for both read-only and read/write transactions. + BeginTransactionOption spanner.BeginTransactionOption + // DecodeToNativeArrays determines whether arrays that have a Go native // type should be decoded to those types rather than the corresponding // spanner.NullTypeName type. @@ -551,6 +556,11 @@ func createConnector(d *Driver, connectorConfig ConnectorConfig) (*connector, er connectorConfig.IsolationLevel = val } } + if strval, ok := connectorConfig.Params[strings.ToLower("BeginTransactionOption")]; ok { + if val, err := parseBeginTransactionOption(strval); err == nil { + connectorConfig.BeginTransactionOption = val + } + } if strval, ok := connectorConfig.Params[strings.ToLower("StatementCacheSize")]; ok { if val, err := strconv.Atoi(strval); err == nil { connectorConfig.StatementCacheSize = val @@ -1284,6 +1294,18 @@ func checkIsValidType(v driver.Value) bool { return true } +func parseBeginTransactionOption(val string) (spanner.BeginTransactionOption, error) { + switch strings.ToLower(val) { + case strings.ToLower("DefaultBeginTransaction"): + return spanner.DefaultBeginTransaction, nil + case strings.ToLower("InlinedBeginTransaction"): + return spanner.InlinedBeginTransaction, nil + case strings.ToLower("ExplicitBeginTransaction"): + return spanner.ExplicitBeginTransaction, nil + } + return spanner.DefaultBeginTransaction, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "invalid or unsupported BeginTransactionOption: %v", val)) +} + func parseIsolationLevel(val string) (sql.IsolationLevel, error) { switch strings.Replace(strings.ToLower(strings.TrimSpace(val)), " ", "_", 1) { case "default": diff --git a/driver_test.go b/driver_test.go index e9454843..8873ad0e 100644 --- a/driver_test.go +++ b/driver_test.go @@ -171,6 +171,23 @@ func TestExtractDnsParts(t *testing.T) { Params: map[string]string{ "isolationlevel": "repeatable_read", }, + IsolationLevel: sql.LevelRepeatableRead, + }, + wantSpannerConfig: spanner.ClientConfig{ + SessionPoolConfig: spanner.DefaultSessionPoolConfig, + UserAgent: userAgent, + }, + }, + { + input: "projects/p/instances/i/databases/d?beginTransactionOption=ExplicitBeginTransaction", + wantConnectorConfig: ConnectorConfig{ + Project: "p", + Instance: "i", + Database: "d", + Params: map[string]string{ + "begintransactionoption": "ExplicitBeginTransaction", + }, + BeginTransactionOption: spanner.ExplicitBeginTransaction, }, wantSpannerConfig: spanner.ClientConfig{ SessionPoolConfig: spanner.DefaultSessionPoolConfig, @@ -186,6 +203,7 @@ func TestExtractDnsParts(t *testing.T) { Params: map[string]string{ "statementcachesize": "100", }, + StatementCacheSize: 100, }, wantSpannerConfig: spanner.ClientConfig{ SessionPoolConfig: spanner.DefaultSessionPoolConfig, @@ -252,8 +270,7 @@ func TestExtractDnsParts(t *testing.T) { if tc.wantErr { t.Error("did not encounter expected error") } - tc.wantConnectorConfig.name = tc.input - if diff := cmp.Diff(config, tc.wantConnectorConfig, cmp.AllowUnexported(ConnectorConfig{})); diff != "" { + if diff := cmp.Diff(config.Params, tc.wantConnectorConfig.Params); diff != "" { t.Errorf("connector config mismatch for %q\n%v", tc.input, diff) } conn, err := newOrCachedConnector(&Driver{connectors: make(map[string]*connector)}, tc.input) @@ -263,6 +280,47 @@ func TestExtractDnsParts(t *testing.T) { if diff := cmp.Diff(conn.spannerClientConfig, tc.wantSpannerConfig, cmpopts.IgnoreUnexported(spanner.ClientConfig{}, spanner.SessionPoolConfig{}, spanner.InactiveTransactionRemovalOptions{}, spannerpb.ExecuteSqlRequest_QueryOptions{})); diff != "" { t.Errorf("connector Spanner client config mismatch for %q\n%v", tc.input, diff) } + actualConfig := conn.connectorConfig + actualConfig.name = "" + if diff := cmp.Diff(actualConfig, tc.wantConnectorConfig, cmp.AllowUnexported(ConnectorConfig{})); diff != "" { + t.Errorf("actual connector config mismatch for %q\n%v", tc.input, diff) + } + } + }) + } +} + +func TestParseBeginTransactionOption(t *testing.T) { + tests := []struct { + input string + want spanner.BeginTransactionOption + wantErr bool + }{ + { + input: "DefaultBeginTransaction", + want: spanner.DefaultBeginTransaction, + }, + { + input: "InlinedBeginTransaction", + want: spanner.InlinedBeginTransaction, + }, + { + input: "ExplicitBeginTransaction", + want: spanner.ExplicitBeginTransaction, + }, + { + input: "invalid", + wantErr: true, + }, + } + for i, test := range tests { + t.Run(test.input, func(t *testing.T) { + val, err := parseBeginTransactionOption(test.input) + if (err != nil) != test.wantErr { + t.Errorf("%d: parseBeginTransactionOption(%q) error = %v, wantErr %v", i, err, test.wantErr, err) + } + if g, w := val, test.want; g != w { + t.Errorf("%d: parseBeginTransactionOption(%q) = %v, want %v", i, g, w, g) } }) } diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 537db19a..edb382c5 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -4418,17 +4418,13 @@ func TestRunTransaction(t *testing.T) { if req.Transaction == nil { t.Fatalf("missing transaction for ExecuteSqlRequest") } - if req.Transaction.GetId() == nil { - t.Fatalf("missing id selector for ExecuteSqlRequest") + if req.Transaction.GetBegin() == nil { + t.Fatalf("missing begin selector for ExecuteSqlRequest") } commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) if g, w := len(commitRequests), 1; g != w { t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w) } - commitReq := commitRequests[0].(*sppb.CommitRequest) - if c, e := commitReq.GetTransactionId(), req.Transaction.GetId(); !cmp.Equal(c, e) { - t.Fatalf("transaction id mismatch\nCommit: %c\nExecute: %v", c, e) - } } func TestRunTransactionCommitAborted(t *testing.T) { @@ -4494,12 +4490,19 @@ func TestRunTransactionCommitAborted(t *testing.T) { if req.Transaction == nil { t.Fatalf("missing transaction for ExecuteSqlRequest") } - if req.Transaction.GetId() == nil { - t.Fatalf("missing id selector for ExecuteSqlRequest") - } - commitReq := commitRequests[i].(*sppb.CommitRequest) - if c, e := commitReq.GetTransactionId(), req.Transaction.GetId(); !cmp.Equal(c, e) { - t.Fatalf("transaction id mismatch\nCommit: %c\nExecute: %v", c, e) + if i == 0 { + if req.Transaction.GetBegin() == nil { + t.Fatalf("missing begin selector for ExecuteSqlRequest") + } + } else { + // The retried transaction uses an explicit BeginTransaction RPC. + if req.Transaction.GetId() == nil { + t.Fatalf("missing id selector for ExecuteSqlRequest") + } + commitReq := commitRequests[i].(*sppb.CommitRequest) + if c, e := commitReq.GetTransactionId(), req.Transaction.GetId(); !cmp.Equal(c, e) { + t.Fatalf("transaction id mismatch\nCommit: %c\nExecute: %v", c, e) + } } } } @@ -4633,9 +4636,11 @@ func TestRunTransactionQueryError(t *testing.T) { if g, w := len(commitRequests), 0; g != w { t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w) } - // There should be a RollbackRequest, as the transaction failed. + // There is no RollbackRequest, as the transaction was never started. + // The ExecuteSqlRequest included a BeginTransaction option, but because that + // request failed, the transaction was not started. rollbackRequests := requestsOfType(requests, reflect.TypeOf(&sppb.RollbackRequest{})) - if g, w := len(rollbackRequests), 1; g != w { + if g, w := len(rollbackRequests), 0; g != w { t.Fatalf("rollback requests count mismatch\nGot: %v\nWant: %v", g, w) } } @@ -4806,8 +4811,9 @@ func TestBeginReadWriteTransaction(t *testing.T) { tx, err := BeginReadWriteTransaction(ctx, db, ReadWriteTransactionOptions{ DisableInternalRetries: true, TransactionOptions: spanner.TransactionOptions{ - TransactionTag: tag, - CommitPriority: sppb.RequestOptions_PRIORITY_LOW, + TransactionTag: tag, + CommitPriority: sppb.RequestOptions_PRIORITY_LOW, + BeginTransactionOption: spanner.ExplicitBeginTransaction, }, }) if err != nil { @@ -4873,7 +4879,7 @@ func TestBeginReadWriteTransaction(t *testing.T) { t.Fatalf("missing transaction for ExecuteSqlRequest") } if req.Transaction.GetId() == nil { - t.Fatalf("missing id selector for ExecuteSqlRequest") + t.Fatalf("missing begin selector for ExecuteSqlRequest") } if g, w := req.RequestOptions.TransactionTag, tag; g != w { t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w)