Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions aborted_transactions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ func TestQueryWithError_CommitAborted(t *testing.T) {
server.PutExecutionTime(testutil.MethodCommitTransaction, testutil.SimulatedExecutionTime{
Errors: []error{status.Error(codes.Aborted, "Aborted")},
})
}, codes.NotFound, 0, 2, 2)
}, codes.NotFound, 0, 3, 2)
}

func TestQueryWithErrorHalfway_CommitAborted(t *testing.T) {
Expand Down Expand Up @@ -1080,7 +1080,7 @@ func TestBatchUpdateAbortedWithError_DifferentErrorDuringRetry(t *testing.T) {
t.Fatalf("dml statement failed: %v", err)
}
if _, err := tx.ExecContext(ctx, "RUN BATCH"); spanner.ErrCode(err) != codes.NotFound {
t.Fatalf("error code mismatch\nGot: %v\nWant: %v", spanner.ErrCode(err), codes.NotFound)
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", spanner.ErrCode(err), codes.NotFound)
}

// Remove the error for the DML statement and cause a retry. The missing
Expand All @@ -1094,19 +1094,37 @@ func TestBatchUpdateAbortedWithError_DifferentErrorDuringRetry(t *testing.T) {
})
err = tx.Commit()
if err != ErrAbortedDueToConcurrentModification {
t.Fatalf("commit error mismatch\nGot: %v\nWant: %v", err, ErrAbortedDueToConcurrentModification)
t.Fatalf("commit error mismatch\n Got: %v\nWant: %v", err, ErrAbortedDueToConcurrentModification)
}
reqs := drainRequestsFromServer(server.TestSpanner)
execReqs := requestsOfType(reqs, reflect.TypeOf(&sppb.ExecuteBatchDmlRequest{}))
if g, w := len(execReqs), 2; g != w {
t.Fatalf("batch request count mismatch\nGot: %v\nWant: %v", g, w)
// There are 3 ExecuteBatchDmlRequests sent to Spanner:
// 1. An initial attempt with a BeginTransaction RPC, but this returns a NotFound error.
// This causes the transaction to be retried with an explicit BeginTransaction request.
// 2. Another attempt with a transaction ID.
// 3. A third attempt after the initial transaction is aborted.
if g, w := len(execReqs), 3; g != w {
t.Fatalf("batch request count mismatch\n Got: %v\nWant: %v", g, w)
}
commitReqs := requestsOfType(reqs, reflect.TypeOf(&sppb.CommitRequest{}))
// The commit should be attempted only once.
if g, w := len(commitReqs), 1; g != w {
t.Fatalf("commit request count mismatch\nGot: %v\nWant: %v", g, w)
t.Fatalf("commit request count mismatch\n Got: %v\nWant: %v", g, w)
}
// The first ExecuteBatchDml request should try to use an inline-begin.
// After that, we should have two BeginTransaction requests.
req1 := execReqs[0].(*sppb.ExecuteBatchDmlRequest)
if req1.GetTransaction() == nil || req1.GetTransaction().GetBegin() == nil {
t.Fatal("the first ExecuteBatchDmlRequest should have a BeginTransaction")
}
req2 := execReqs[1].(*sppb.ExecuteBatchDmlRequest)
if req2.GetTransaction() == nil || req2.GetTransaction().GetId() == nil {
t.Fatal("the second ExecuteBatchDmlRequest should have a transaction id")
}
beginRequests := requestsOfType(reqs, reflect.TypeOf(&sppb.BeginTransactionRequest{}))
if g, w := len(beginRequests), 2; g != w {
t.Fatalf("begin request count mismatch\n Got: %v\nWant: %v", g, w)
}

// Verify that the db is still usable.
if _, err := db.ExecContext(ctx, testutil.UpdateSingersSetLastName); err != nil {
t.Fatalf("failed to execute statement after transaction: %v", err)
Expand Down
7 changes: 6 additions & 1 deletion auto_dml_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,13 @@ func TestAutoBatchDml_FollowedByRollback(t *testing.T) {
if g, w := len(commitRequests), 0; g != w {
t.Fatalf("num commit requests mismatch\n Got: %v\nWant: %v", g, w)
}
beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{}))
if g, w := len(beginRequests), 0; g != w {
t.Fatalf("num BeginTransaction requests mismatch\n Got: %v\nWant: %v", g, w)
}
rollbackRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.RollbackRequest{}))
if g, w := len(rollbackRequests), 1; g != w {
// There are no rollback requests sent to Spanner, as the transaction is never started.
if g, w := len(rollbackRequests), 0; g != w {
t.Fatalf("num rollback requests mismatch\n Got: %v\nWant: %v", g, w)
}
}
Expand Down
28 changes: 24 additions & 4 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ type conn struct {
// transactions on this connection. This default is ignored if the BeginTx function is
// called with an isolation level other than sql.LevelDefault.
isolationLevel sql.IsolationLevel
// beginTransactionOption determines the default transactions start mode.
beginTransactionOption spanner.BeginTransactionOption

// execOptions are applied to the next statement or transaction that is executed
// on this connection. It can also be set by passing it in as an argument to
Expand Down Expand Up @@ -660,6 +662,7 @@ func (c *conn) ResetSession(_ context.Context) error {
c.autoBatchDmlUpdateCountVerification = !c.connector.connectorConfig.DisableAutoBatchDmlUpdateCountVerification
c.retryAborts = c.connector.retryAbortsInternally
c.isolationLevel = c.connector.connectorConfig.IsolationLevel
c.beginTransactionOption = c.connector.connectorConfig.BeginTransactionOption
// TODO: Reset the following fields to the connector default
c.autocommitDMLMode = Transactional
c.readOnlyStaleness = spanner.TimestampBound{}
Expand Down Expand Up @@ -927,7 +930,9 @@ func (c *conn) withTempTransactionOptions(options *ReadWriteTransactionOptions)
func (c *conn) getTransactionOptions() ReadWriteTransactionOptions {
if c.tempTransactionOptions != nil {
defer func() { c.tempTransactionOptions = nil }()
return *c.tempTransactionOptions
opts := *c.tempTransactionOptions
opts.TransactionOptions.BeginTransactionOption = c.convertDefaultBeginTransactionOption(opts.TransactionOptions.BeginTransactionOption)
return opts
}
// Clear the transaction tag that has been set on the connection after returning
// from this function.
Expand All @@ -947,6 +952,9 @@ func (c *conn) getTransactionOptions() ReadWriteTransactionOptions {
txOpts.TransactionOptions.IsolationLevel = level
}
}
if txOpts.TransactionOptions.BeginTransactionOption == spanner.DefaultBeginTransaction {
txOpts.TransactionOptions.BeginTransactionOption = c.convertDefaultBeginTransactionOption(c.beginTransactionOption)
}
return txOpts
}

Expand All @@ -957,9 +965,11 @@ func (c *conn) withTempReadOnlyTransactionOptions(options *ReadOnlyTransactionOp
func (c *conn) getReadOnlyTransactionOptions() ReadOnlyTransactionOptions {
if c.tempReadOnlyTransactionOptions != nil {
defer func() { c.tempReadOnlyTransactionOptions = nil }()
return *c.tempReadOnlyTransactionOptions
opts := *c.tempReadOnlyTransactionOptions
opts.BeginTransactionOption = c.convertDefaultBeginTransactionOption(opts.BeginTransactionOption)
return opts
}
return ReadOnlyTransactionOptions{TimestampBound: c.readOnlyStaleness}
return ReadOnlyTransactionOptions{TimestampBound: c.readOnlyStaleness, BeginTransactionOption: c.convertDefaultBeginTransactionOption(c.beginTransactionOption)}
}

func (c *conn) withTempBatchReadOnlyTransactionOptions(options *BatchReadOnlyTransactionOptions) {
Expand Down Expand Up @@ -1034,7 +1044,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
ro = &bo.ReadOnlyTransaction
} else {
logger = c.logger.With("tx", "ro")
ro = c.client.ReadOnlyTransaction().WithTimestampBound(readOnlyTxOpts.TimestampBound)
ro = c.client.ReadOnlyTransaction().WithBeginTransactionOption(readOnlyTxOpts.BeginTransactionOption).WithTimestampBound(readOnlyTxOpts.TimestampBound)
}
c.tx = &readOnlyTransaction{
roTx: ro,
Expand Down Expand Up @@ -1080,6 +1090,16 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
return c.tx, nil
}

func (c *conn) convertDefaultBeginTransactionOption(opt spanner.BeginTransactionOption) spanner.BeginTransactionOption {
if opt == spanner.DefaultBeginTransaction {
if c.beginTransactionOption == spanner.DefaultBeginTransaction {
return spanner.InlinedBeginTransaction
}
return c.beginTransactionOption
}
return opt
}

func (c *conn) inTransaction() bool {
return c.tx != nil
}
Expand Down
86 changes: 77 additions & 9 deletions conn_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,69 @@ func TestBeginTx(t *testing.T) {

requests := drainRequestsFromServer(server.TestSpanner)
beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{}))
if g, w := len(beginRequests), 1; g != w {
if g, w := len(beginRequests), 0; g != w {
t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w)
}
request := beginRequests[0].(*spannerpb.BeginTransactionRequest)
if g, w := request.Options.GetIsolationLevel(), spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED; g != w {
executeRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
if g, w := len(executeRequests), 1; g != w {
t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w)
}
request := executeRequests[0].(*spannerpb.ExecuteSqlRequest)
if request.GetTransaction() == nil || request.GetTransaction().GetBegin() == nil {
t.Fatal("missing begin transaction on ExecuteSqlRequest")
}
if g, w := request.GetTransaction().GetBegin().GetIsolationLevel(), spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED; g != w {
t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w)
}
}

func TestExplicitBeginTx(t *testing.T) {
t.Parallel()

db, server, teardown := setupTestDBConnectionWithConnectorConfig(t, ConnectorConfig{
Project: "p",
Instance: "i",
Database: "d",

BeginTransactionOption: spanner.ExplicitBeginTransaction,
})
defer teardown()
ctx := context.Background()

for _, readOnly := range []bool{true, false} {
tx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: readOnly})
if err != nil {
t.Fatal(err)
}
res, err := tx.QueryContext(ctx, testutil.SelectFooFromBar)
if err != nil {
t.Fatal(err)
}
for res.Next() {
}
if err := res.Err(); err != nil {
t.Fatal(err)
}
if err := tx.Rollback(); err != nil {
t.Fatal(err)
}

requests := drainRequestsFromServer(server.TestSpanner)
beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{}))
if g, w := len(beginRequests), 1; g != w {
t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w)
}
executeRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
if g, w := len(executeRequests), 1; g != w {
t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w)
}
request := executeRequests[0].(*spannerpb.ExecuteSqlRequest)
if request.GetTransaction() == nil || request.GetTransaction().GetId() == nil {
t.Fatal("missing transaction id on ExecuteSqlRequest")
}
}
}

func TestBeginTxWithIsolationLevel(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -76,12 +130,19 @@ func TestBeginTxWithIsolationLevel(t *testing.T) {

requests := drainRequestsFromServer(server.TestSpanner)
beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{}))
if g, w := len(beginRequests), 1; g != w {
if g, w := len(beginRequests), 0; g != w {
t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w)
}
request := beginRequests[0].(*spannerpb.BeginTransactionRequest)
executeRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
if g, w := len(executeRequests), 1; g != w {
t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w)
}
request := executeRequests[0].(*spannerpb.ExecuteSqlRequest)
if request.GetTransaction() == nil || request.GetTransaction().GetBegin() == nil {
t.Fatalf("execute request does not have a begin transaction")
}
wantIsolationLevel, _ := toProtoIsolationLevel(originalLevel)
if g, w := request.Options.GetIsolationLevel(), wantIsolationLevel; g != w {
if g, w := request.GetTransaction().GetBegin().GetIsolationLevel(), wantIsolationLevel; g != w {
t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w)
}
}
Expand Down Expand Up @@ -162,12 +223,19 @@ func TestDefaultIsolationLevel(t *testing.T) {

requests := drainRequestsFromServer(server.TestSpanner)
beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{}))
if g, w := len(beginRequests), 1; g != w {
if g, w := len(beginRequests), 0; g != w {
t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w)
}
request := beginRequests[0].(*spannerpb.BeginTransactionRequest)
executeRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
if g, w := len(executeRequests), 1; g != w {
t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w)
}
request := executeRequests[0].(*spannerpb.ExecuteSqlRequest)
if request.GetTransaction() == nil || request.GetTransaction().GetBegin() == nil {
t.Fatalf("ExecuteSqlRequest should have a Begin transaction")
}
wantIsolationLevel, _ := toProtoIsolationLevel(originalLevel)
if g, w := request.Options.GetIsolationLevel(), wantIsolationLevel; g != w {
if g, w := request.GetTransaction().GetBegin().GetIsolationLevel(), wantIsolationLevel; g != w {
t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w)
}
}
Expand Down
25 changes: 24 additions & 1 deletion driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,11 @@ type ConnectorConfig struct {
// IsolationLevel is the default isolation level for read/write transactions.
IsolationLevel sql.IsolationLevel

// BeginTransactionOption determines the default for how to begin transactions.
// The Spanner database/sql driver uses spanner.InlinedBeginTransaction by default
// for both read-only and read/write transactions.
BeginTransactionOption spanner.BeginTransactionOption

// DecodeToNativeArrays determines whether arrays that have a Go native
// type should be decoded to those types rather than the corresponding
// spanner.NullTypeName type.
Expand Down Expand Up @@ -551,6 +556,11 @@ func createConnector(d *Driver, connectorConfig ConnectorConfig) (*connector, er
connectorConfig.IsolationLevel = val
}
}
if strval, ok := connectorConfig.Params[strings.ToLower("BeginTransactionOption")]; ok {
if val, err := parseBeginTransactionOption(strval); err == nil {
connectorConfig.BeginTransactionOption = val
}
}
if strval, ok := connectorConfig.Params[strings.ToLower("StatementCacheSize")]; ok {
if val, err := strconv.Atoi(strval); err == nil {
connectorConfig.StatementCacheSize = val
Expand Down Expand Up @@ -1026,7 +1036,8 @@ func clearTempReadWriteTransactionOptions(conn *sql.Conn) {
// ReadOnlyTransactionOptions can be used to create a read-only transaction
// on a Spanner connection.
type ReadOnlyTransactionOptions struct {
TimestampBound spanner.TimestampBound
TimestampBound spanner.TimestampBound
BeginTransactionOption spanner.BeginTransactionOption

close func()
}
Expand Down Expand Up @@ -1283,6 +1294,18 @@ func checkIsValidType(v driver.Value) bool {
return true
}

func parseBeginTransactionOption(val string) (spanner.BeginTransactionOption, error) {
switch strings.ToLower(val) {
case strings.ToLower("DefaultBeginTransaction"):
return spanner.DefaultBeginTransaction, nil
case strings.ToLower("InlinedBeginTransaction"):
return spanner.InlinedBeginTransaction, nil
case strings.ToLower("ExplicitBeginTransaction"):
return spanner.ExplicitBeginTransaction, nil
}
return spanner.DefaultBeginTransaction, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "invalid or unsupported BeginTransactionOption: %v", val))
}

func parseIsolationLevel(val string) (sql.IsolationLevel, error) {
switch strings.Replace(strings.ToLower(strings.TrimSpace(val)), " ", "_", 1) {
case "default":
Expand Down
Loading
Loading