diff --git a/conn.go b/conn.go index e11cec54..57fae247 100644 --- a/conn.go +++ b/conn.go @@ -248,6 +248,8 @@ type conn struct { // transactions on this connection. This default is ignored if the BeginTx function is // called with an isolation level other than sql.LevelDefault. isolationLevel sql.IsolationLevel + // beginTransactionOption determines the default transactions start mode. + beginTransactionOption spanner.BeginTransactionOption // execOptions are applied to the next statement or transaction that is executed // on this connection. It can also be set by passing it in as an argument to @@ -660,6 +662,7 @@ func (c *conn) ResetSession(_ context.Context) error { c.autoBatchDmlUpdateCountVerification = !c.connector.connectorConfig.DisableAutoBatchDmlUpdateCountVerification c.retryAborts = c.connector.retryAbortsInternally c.isolationLevel = c.connector.connectorConfig.IsolationLevel + c.beginTransactionOption = c.connector.connectorConfig.BeginTransactionOption // TODO: Reset the following fields to the connector default c.autocommitDMLMode = Transactional c.readOnlyStaleness = spanner.TimestampBound{} @@ -927,7 +930,9 @@ func (c *conn) withTempTransactionOptions(options *ReadWriteTransactionOptions) func (c *conn) getTransactionOptions() ReadWriteTransactionOptions { if c.tempTransactionOptions != nil { defer func() { c.tempTransactionOptions = nil }() - return *c.tempTransactionOptions + opts := *c.tempTransactionOptions + opts.TransactionOptions.BeginTransactionOption = c.convertDefaultBeginTransactionOption(opts.TransactionOptions.BeginTransactionOption) + return opts } // Clear the transaction tag that has been set on the connection after returning // from this function. @@ -948,7 +953,7 @@ func (c *conn) getTransactionOptions() ReadWriteTransactionOptions { } } if txOpts.TransactionOptions.BeginTransactionOption == spanner.DefaultBeginTransaction { - txOpts.TransactionOptions.BeginTransactionOption = spanner.InlinedBeginTransaction + txOpts.TransactionOptions.BeginTransactionOption = c.convertDefaultBeginTransactionOption(c.beginTransactionOption) } return txOpts } @@ -960,9 +965,11 @@ func (c *conn) withTempReadOnlyTransactionOptions(options *ReadOnlyTransactionOp func (c *conn) getReadOnlyTransactionOptions() ReadOnlyTransactionOptions { if c.tempReadOnlyTransactionOptions != nil { defer func() { c.tempReadOnlyTransactionOptions = nil }() - return *c.tempReadOnlyTransactionOptions + opts := *c.tempReadOnlyTransactionOptions + opts.BeginTransactionOption = c.convertDefaultBeginTransactionOption(opts.BeginTransactionOption) + return opts } - return ReadOnlyTransactionOptions{TimestampBound: c.readOnlyStaleness, BeginTransactionOption: spanner.InlinedBeginTransaction} + return ReadOnlyTransactionOptions{TimestampBound: c.readOnlyStaleness, BeginTransactionOption: c.convertDefaultBeginTransactionOption(c.beginTransactionOption)} } func (c *conn) withTempBatchReadOnlyTransactionOptions(options *BatchReadOnlyTransactionOptions) { @@ -1083,6 +1090,16 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e return c.tx, nil } +func (c *conn) convertDefaultBeginTransactionOption(opt spanner.BeginTransactionOption) spanner.BeginTransactionOption { + if opt == spanner.DefaultBeginTransaction { + if c.beginTransactionOption == spanner.DefaultBeginTransaction { + return spanner.InlinedBeginTransaction + } + return c.beginTransactionOption + } + return opt +} + func (c *conn) inTransaction() bool { return c.tx != nil } diff --git a/conn_with_mockserver_test.go b/conn_with_mockserver_test.go index 70107511..b6be1708 100644 --- a/conn_with_mockserver_test.go +++ b/conn_with_mockserver_test.go @@ -55,6 +55,53 @@ func TestBeginTx(t *testing.T) { } } +func TestExplicitBeginTx(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnectionWithConnectorConfig(t, ConnectorConfig{ + Project: "p", + Instance: "i", + Database: "d", + + BeginTransactionOption: spanner.ExplicitBeginTransaction, + }) + defer teardown() + ctx := context.Background() + + for _, readOnly := range []bool{true, false} { + tx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: readOnly}) + if err != nil { + t.Fatal(err) + } + res, err := tx.QueryContext(ctx, testutil.SelectFooFromBar) + if err != nil { + t.Fatal(err) + } + for res.Next() { + } + if err := res.Err(); err != nil { + t.Fatal(err) + } + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + + requests := drainRequestsFromServer(server.TestSpanner) + beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 1; g != w { + t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w) + } + executeRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w) + } + request := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + if request.GetTransaction() == nil || request.GetTransaction().GetId() == nil { + t.Fatal("missing transaction id on ExecuteSqlRequest") + } + } +} + func TestBeginTxWithIsolationLevel(t *testing.T) { t.Parallel() diff --git a/driver.go b/driver.go index 55953181..b50b1aa1 100644 --- a/driver.go +++ b/driver.go @@ -304,6 +304,11 @@ type ConnectorConfig struct { // IsolationLevel is the default isolation level for read/write transactions. IsolationLevel sql.IsolationLevel + // BeginTransactionOption determines the default for how to begin transactions. + // The Spanner database/sql driver uses spanner.InlinedBeginTransaction by default + // for both read-only and read/write transactions. + BeginTransactionOption spanner.BeginTransactionOption + // DecodeToNativeArrays determines whether arrays that have a Go native // type should be decoded to those types rather than the corresponding // spanner.NullTypeName type. @@ -551,6 +556,11 @@ func createConnector(d *Driver, connectorConfig ConnectorConfig) (*connector, er connectorConfig.IsolationLevel = val } } + if strval, ok := connectorConfig.Params[strings.ToLower("BeginTransactionOption")]; ok { + if val, err := parseBeginTransactionOption(strval); err == nil { + connectorConfig.BeginTransactionOption = val + } + } if strval, ok := connectorConfig.Params[strings.ToLower("StatementCacheSize")]; ok { if val, err := strconv.Atoi(strval); err == nil { connectorConfig.StatementCacheSize = val @@ -1284,6 +1294,18 @@ func checkIsValidType(v driver.Value) bool { return true } +func parseBeginTransactionOption(val string) (spanner.BeginTransactionOption, error) { + switch strings.ToLower(val) { + case strings.ToLower("DefaultBeginTransaction"): + return spanner.DefaultBeginTransaction, nil + case strings.ToLower("InlinedBeginTransaction"): + return spanner.InlinedBeginTransaction, nil + case strings.ToLower("ExplicitBeginTransaction"): + return spanner.ExplicitBeginTransaction, nil + } + return spanner.DefaultBeginTransaction, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "invalid or unsupported BeginTransactionOption: %v", val)) +} + func parseIsolationLevel(val string) (sql.IsolationLevel, error) { switch strings.Replace(strings.ToLower(strings.TrimSpace(val)), " ", "_", 1) { case "default": diff --git a/driver_test.go b/driver_test.go index e9454843..8873ad0e 100644 --- a/driver_test.go +++ b/driver_test.go @@ -171,6 +171,23 @@ func TestExtractDnsParts(t *testing.T) { Params: map[string]string{ "isolationlevel": "repeatable_read", }, + IsolationLevel: sql.LevelRepeatableRead, + }, + wantSpannerConfig: spanner.ClientConfig{ + SessionPoolConfig: spanner.DefaultSessionPoolConfig, + UserAgent: userAgent, + }, + }, + { + input: "projects/p/instances/i/databases/d?beginTransactionOption=ExplicitBeginTransaction", + wantConnectorConfig: ConnectorConfig{ + Project: "p", + Instance: "i", + Database: "d", + Params: map[string]string{ + "begintransactionoption": "ExplicitBeginTransaction", + }, + BeginTransactionOption: spanner.ExplicitBeginTransaction, }, wantSpannerConfig: spanner.ClientConfig{ SessionPoolConfig: spanner.DefaultSessionPoolConfig, @@ -186,6 +203,7 @@ func TestExtractDnsParts(t *testing.T) { Params: map[string]string{ "statementcachesize": "100", }, + StatementCacheSize: 100, }, wantSpannerConfig: spanner.ClientConfig{ SessionPoolConfig: spanner.DefaultSessionPoolConfig, @@ -252,8 +270,7 @@ func TestExtractDnsParts(t *testing.T) { if tc.wantErr { t.Error("did not encounter expected error") } - tc.wantConnectorConfig.name = tc.input - if diff := cmp.Diff(config, tc.wantConnectorConfig, cmp.AllowUnexported(ConnectorConfig{})); diff != "" { + if diff := cmp.Diff(config.Params, tc.wantConnectorConfig.Params); diff != "" { t.Errorf("connector config mismatch for %q\n%v", tc.input, diff) } conn, err := newOrCachedConnector(&Driver{connectors: make(map[string]*connector)}, tc.input) @@ -263,6 +280,47 @@ func TestExtractDnsParts(t *testing.T) { if diff := cmp.Diff(conn.spannerClientConfig, tc.wantSpannerConfig, cmpopts.IgnoreUnexported(spanner.ClientConfig{}, spanner.SessionPoolConfig{}, spanner.InactiveTransactionRemovalOptions{}, spannerpb.ExecuteSqlRequest_QueryOptions{})); diff != "" { t.Errorf("connector Spanner client config mismatch for %q\n%v", tc.input, diff) } + actualConfig := conn.connectorConfig + actualConfig.name = "" + if diff := cmp.Diff(actualConfig, tc.wantConnectorConfig, cmp.AllowUnexported(ConnectorConfig{})); diff != "" { + t.Errorf("actual connector config mismatch for %q\n%v", tc.input, diff) + } + } + }) + } +} + +func TestParseBeginTransactionOption(t *testing.T) { + tests := []struct { + input string + want spanner.BeginTransactionOption + wantErr bool + }{ + { + input: "DefaultBeginTransaction", + want: spanner.DefaultBeginTransaction, + }, + { + input: "InlinedBeginTransaction", + want: spanner.InlinedBeginTransaction, + }, + { + input: "ExplicitBeginTransaction", + want: spanner.ExplicitBeginTransaction, + }, + { + input: "invalid", + wantErr: true, + }, + } + for i, test := range tests { + t.Run(test.input, func(t *testing.T) { + val, err := parseBeginTransactionOption(test.input) + if (err != nil) != test.wantErr { + t.Errorf("%d: parseBeginTransactionOption(%q) error = %v, wantErr %v", i, err, test.wantErr, err) + } + if g, w := val, test.want; g != w { + t.Errorf("%d: parseBeginTransactionOption(%q) = %v, want %v", i, g, w, g) } }) } diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 537db19a..edb382c5 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -4418,17 +4418,13 @@ func TestRunTransaction(t *testing.T) { if req.Transaction == nil { t.Fatalf("missing transaction for ExecuteSqlRequest") } - if req.Transaction.GetId() == nil { - t.Fatalf("missing id selector for ExecuteSqlRequest") + if req.Transaction.GetBegin() == nil { + t.Fatalf("missing begin selector for ExecuteSqlRequest") } commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{})) if g, w := len(commitRequests), 1; g != w { t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w) } - commitReq := commitRequests[0].(*sppb.CommitRequest) - if c, e := commitReq.GetTransactionId(), req.Transaction.GetId(); !cmp.Equal(c, e) { - t.Fatalf("transaction id mismatch\nCommit: %c\nExecute: %v", c, e) - } } func TestRunTransactionCommitAborted(t *testing.T) { @@ -4494,12 +4490,19 @@ func TestRunTransactionCommitAborted(t *testing.T) { if req.Transaction == nil { t.Fatalf("missing transaction for ExecuteSqlRequest") } - if req.Transaction.GetId() == nil { - t.Fatalf("missing id selector for ExecuteSqlRequest") - } - commitReq := commitRequests[i].(*sppb.CommitRequest) - if c, e := commitReq.GetTransactionId(), req.Transaction.GetId(); !cmp.Equal(c, e) { - t.Fatalf("transaction id mismatch\nCommit: %c\nExecute: %v", c, e) + if i == 0 { + if req.Transaction.GetBegin() == nil { + t.Fatalf("missing begin selector for ExecuteSqlRequest") + } + } else { + // The retried transaction uses an explicit BeginTransaction RPC. + if req.Transaction.GetId() == nil { + t.Fatalf("missing id selector for ExecuteSqlRequest") + } + commitReq := commitRequests[i].(*sppb.CommitRequest) + if c, e := commitReq.GetTransactionId(), req.Transaction.GetId(); !cmp.Equal(c, e) { + t.Fatalf("transaction id mismatch\nCommit: %c\nExecute: %v", c, e) + } } } } @@ -4633,9 +4636,11 @@ func TestRunTransactionQueryError(t *testing.T) { if g, w := len(commitRequests), 0; g != w { t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w) } - // There should be a RollbackRequest, as the transaction failed. + // There is no RollbackRequest, as the transaction was never started. + // The ExecuteSqlRequest included a BeginTransaction option, but because that + // request failed, the transaction was not started. rollbackRequests := requestsOfType(requests, reflect.TypeOf(&sppb.RollbackRequest{})) - if g, w := len(rollbackRequests), 1; g != w { + if g, w := len(rollbackRequests), 0; g != w { t.Fatalf("rollback requests count mismatch\nGot: %v\nWant: %v", g, w) } } @@ -4806,8 +4811,9 @@ func TestBeginReadWriteTransaction(t *testing.T) { tx, err := BeginReadWriteTransaction(ctx, db, ReadWriteTransactionOptions{ DisableInternalRetries: true, TransactionOptions: spanner.TransactionOptions{ - TransactionTag: tag, - CommitPriority: sppb.RequestOptions_PRIORITY_LOW, + TransactionTag: tag, + CommitPriority: sppb.RequestOptions_PRIORITY_LOW, + BeginTransactionOption: spanner.ExplicitBeginTransaction, }, }) if err != nil { @@ -4873,7 +4879,7 @@ func TestBeginReadWriteTransaction(t *testing.T) { t.Fatalf("missing transaction for ExecuteSqlRequest") } if req.Transaction.GetId() == nil { - t.Fatalf("missing id selector for ExecuteSqlRequest") + t.Fatalf("missing begin selector for ExecuteSqlRequest") } if g, w := req.RequestOptions.TransactionTag, tag; g != w { t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w)