Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ type conn struct {
// transactions on this connection. This default is ignored if the BeginTx function is
// called with an isolation level other than sql.LevelDefault.
isolationLevel sql.IsolationLevel
// beginTransactionOption determines the default transactions start mode.
beginTransactionOption spanner.BeginTransactionOption

// execOptions are applied to the next statement or transaction that is executed
// on this connection. It can also be set by passing it in as an argument to
Expand Down Expand Up @@ -660,6 +662,7 @@ func (c *conn) ResetSession(_ context.Context) error {
c.autoBatchDmlUpdateCountVerification = !c.connector.connectorConfig.DisableAutoBatchDmlUpdateCountVerification
c.retryAborts = c.connector.retryAbortsInternally
c.isolationLevel = c.connector.connectorConfig.IsolationLevel
c.beginTransactionOption = c.connector.connectorConfig.BeginTransactionOption
// TODO: Reset the following fields to the connector default
c.autocommitDMLMode = Transactional
c.readOnlyStaleness = spanner.TimestampBound{}
Expand Down Expand Up @@ -927,7 +930,9 @@ func (c *conn) withTempTransactionOptions(options *ReadWriteTransactionOptions)
func (c *conn) getTransactionOptions() ReadWriteTransactionOptions {
if c.tempTransactionOptions != nil {
defer func() { c.tempTransactionOptions = nil }()
return *c.tempTransactionOptions
opts := *c.tempTransactionOptions
opts.TransactionOptions.BeginTransactionOption = c.convertDefaultBeginTransactionOption(opts.TransactionOptions.BeginTransactionOption)
return opts
}
// Clear the transaction tag that has been set on the connection after returning
// from this function.
Expand All @@ -948,7 +953,7 @@ func (c *conn) getTransactionOptions() ReadWriteTransactionOptions {
}
}
if txOpts.TransactionOptions.BeginTransactionOption == spanner.DefaultBeginTransaction {
txOpts.TransactionOptions.BeginTransactionOption = spanner.InlinedBeginTransaction
txOpts.TransactionOptions.BeginTransactionOption = c.convertDefaultBeginTransactionOption(c.beginTransactionOption)
}
return txOpts
}
Expand All @@ -960,9 +965,11 @@ func (c *conn) withTempReadOnlyTransactionOptions(options *ReadOnlyTransactionOp
func (c *conn) getReadOnlyTransactionOptions() ReadOnlyTransactionOptions {
if c.tempReadOnlyTransactionOptions != nil {
defer func() { c.tempReadOnlyTransactionOptions = nil }()
return *c.tempReadOnlyTransactionOptions
opts := *c.tempReadOnlyTransactionOptions
opts.BeginTransactionOption = c.convertDefaultBeginTransactionOption(opts.BeginTransactionOption)
return opts
}
return ReadOnlyTransactionOptions{TimestampBound: c.readOnlyStaleness, BeginTransactionOption: spanner.InlinedBeginTransaction}
return ReadOnlyTransactionOptions{TimestampBound: c.readOnlyStaleness, BeginTransactionOption: c.convertDefaultBeginTransactionOption(c.beginTransactionOption)}
}

func (c *conn) withTempBatchReadOnlyTransactionOptions(options *BatchReadOnlyTransactionOptions) {
Expand Down Expand Up @@ -1083,6 +1090,16 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
return c.tx, nil
}

func (c *conn) convertDefaultBeginTransactionOption(opt spanner.BeginTransactionOption) spanner.BeginTransactionOption {
if opt == spanner.DefaultBeginTransaction {
if c.beginTransactionOption == spanner.DefaultBeginTransaction {
return spanner.InlinedBeginTransaction
}
return c.beginTransactionOption
}
return opt
}

func (c *conn) inTransaction() bool {
return c.tx != nil
}
Expand Down
47 changes: 47 additions & 0 deletions conn_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,53 @@ func TestBeginTx(t *testing.T) {
}
}

func TestExplicitBeginTx(t *testing.T) {
t.Parallel()

db, server, teardown := setupTestDBConnectionWithConnectorConfig(t, ConnectorConfig{
Project: "p",
Instance: "i",
Database: "d",

BeginTransactionOption: spanner.ExplicitBeginTransaction,
})
defer teardown()
ctx := context.Background()

for _, readOnly := range []bool{true, false} {
tx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: readOnly})
if err != nil {
t.Fatal(err)
}
res, err := tx.QueryContext(ctx, testutil.SelectFooFromBar)
if err != nil {
t.Fatal(err)
}
for res.Next() {
}
if err := res.Err(); err != nil {
t.Fatal(err)
}
if err := tx.Rollback(); err != nil {
t.Fatal(err)
}

requests := drainRequestsFromServer(server.TestSpanner)
beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{}))
if g, w := len(beginRequests), 1; g != w {
t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w)
}
executeRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
if g, w := len(executeRequests), 1; g != w {
t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w)
}
request := executeRequests[0].(*spannerpb.ExecuteSqlRequest)
if request.GetTransaction() == nil || request.GetTransaction().GetId() == nil {
t.Fatal("missing transaction id on ExecuteSqlRequest")
}
}
}

func TestBeginTxWithIsolationLevel(t *testing.T) {
t.Parallel()

Expand Down
22 changes: 22 additions & 0 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,11 @@ type ConnectorConfig struct {
// IsolationLevel is the default isolation level for read/write transactions.
IsolationLevel sql.IsolationLevel

// BeginTransactionOption determines the default for how to begin transactions.
// The Spanner database/sql driver uses spanner.InlinedBeginTransaction by default
// for both read-only and read/write transactions.
BeginTransactionOption spanner.BeginTransactionOption

// DecodeToNativeArrays determines whether arrays that have a Go native
// type should be decoded to those types rather than the corresponding
// spanner.NullTypeName type.
Expand Down Expand Up @@ -551,6 +556,11 @@ func createConnector(d *Driver, connectorConfig ConnectorConfig) (*connector, er
connectorConfig.IsolationLevel = val
}
}
if strval, ok := connectorConfig.Params[strings.ToLower("BeginTransactionOption")]; ok {
if val, err := parseBeginTransactionOption(strval); err == nil {
connectorConfig.BeginTransactionOption = val
}
}
if strval, ok := connectorConfig.Params[strings.ToLower("StatementCacheSize")]; ok {
if val, err := strconv.Atoi(strval); err == nil {
connectorConfig.StatementCacheSize = val
Expand Down Expand Up @@ -1284,6 +1294,18 @@ func checkIsValidType(v driver.Value) bool {
return true
}

func parseBeginTransactionOption(val string) (spanner.BeginTransactionOption, error) {
switch strings.ToLower(val) {
case strings.ToLower("DefaultBeginTransaction"):
return spanner.DefaultBeginTransaction, nil
case strings.ToLower("InlinedBeginTransaction"):
return spanner.InlinedBeginTransaction, nil
case strings.ToLower("ExplicitBeginTransaction"):
return spanner.ExplicitBeginTransaction, nil
}
return spanner.DefaultBeginTransaction, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "invalid or unsupported BeginTransactionOption: %v", val))
}

func parseIsolationLevel(val string) (sql.IsolationLevel, error) {
switch strings.Replace(strings.ToLower(strings.TrimSpace(val)), " ", "_", 1) {
case "default":
Expand Down
62 changes: 60 additions & 2 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,23 @@ func TestExtractDnsParts(t *testing.T) {
Params: map[string]string{
"isolationlevel": "repeatable_read",
},
IsolationLevel: sql.LevelRepeatableRead,
},
wantSpannerConfig: spanner.ClientConfig{
SessionPoolConfig: spanner.DefaultSessionPoolConfig,
UserAgent: userAgent,
},
},
{
input: "projects/p/instances/i/databases/d?beginTransactionOption=ExplicitBeginTransaction",
wantConnectorConfig: ConnectorConfig{
Project: "p",
Instance: "i",
Database: "d",
Params: map[string]string{
"begintransactionoption": "ExplicitBeginTransaction",
},
BeginTransactionOption: spanner.ExplicitBeginTransaction,
},
wantSpannerConfig: spanner.ClientConfig{
SessionPoolConfig: spanner.DefaultSessionPoolConfig,
Expand All @@ -186,6 +203,7 @@ func TestExtractDnsParts(t *testing.T) {
Params: map[string]string{
"statementcachesize": "100",
},
StatementCacheSize: 100,
},
wantSpannerConfig: spanner.ClientConfig{
SessionPoolConfig: spanner.DefaultSessionPoolConfig,
Expand Down Expand Up @@ -252,8 +270,7 @@ func TestExtractDnsParts(t *testing.T) {
if tc.wantErr {
t.Error("did not encounter expected error")
}
tc.wantConnectorConfig.name = tc.input
if diff := cmp.Diff(config, tc.wantConnectorConfig, cmp.AllowUnexported(ConnectorConfig{})); diff != "" {
if diff := cmp.Diff(config.Params, tc.wantConnectorConfig.Params); diff != "" {
t.Errorf("connector config mismatch for %q\n%v", tc.input, diff)
}
conn, err := newOrCachedConnector(&Driver{connectors: make(map[string]*connector)}, tc.input)
Expand All @@ -263,6 +280,47 @@ func TestExtractDnsParts(t *testing.T) {
if diff := cmp.Diff(conn.spannerClientConfig, tc.wantSpannerConfig, cmpopts.IgnoreUnexported(spanner.ClientConfig{}, spanner.SessionPoolConfig{}, spanner.InactiveTransactionRemovalOptions{}, spannerpb.ExecuteSqlRequest_QueryOptions{})); diff != "" {
t.Errorf("connector Spanner client config mismatch for %q\n%v", tc.input, diff)
}
actualConfig := conn.connectorConfig
actualConfig.name = ""
if diff := cmp.Diff(actualConfig, tc.wantConnectorConfig, cmp.AllowUnexported(ConnectorConfig{})); diff != "" {
t.Errorf("actual connector config mismatch for %q\n%v", tc.input, diff)
}
}
})
}
}

func TestParseBeginTransactionOption(t *testing.T) {
tests := []struct {
input string
want spanner.BeginTransactionOption
wantErr bool
}{
{
input: "DefaultBeginTransaction",
want: spanner.DefaultBeginTransaction,
},
{
input: "InlinedBeginTransaction",
want: spanner.InlinedBeginTransaction,
},
{
input: "ExplicitBeginTransaction",
want: spanner.ExplicitBeginTransaction,
},
{
input: "invalid",
wantErr: true,
},
}
for i, test := range tests {
t.Run(test.input, func(t *testing.T) {
val, err := parseBeginTransactionOption(test.input)
if (err != nil) != test.wantErr {
t.Errorf("%d: parseBeginTransactionOption(%q) error = %v, wantErr %v", i, err, test.wantErr, err)
}
if g, w := val, test.want; g != w {
t.Errorf("%d: parseBeginTransactionOption(%q) = %v, want %v", i, g, w, g)
}
})
}
Expand Down
40 changes: 23 additions & 17 deletions driver_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4418,17 +4418,13 @@ func TestRunTransaction(t *testing.T) {
if req.Transaction == nil {
t.Fatalf("missing transaction for ExecuteSqlRequest")
}
if req.Transaction.GetId() == nil {
t.Fatalf("missing id selector for ExecuteSqlRequest")
if req.Transaction.GetBegin() == nil {
t.Fatalf("missing begin selector for ExecuteSqlRequest")
}
commitRequests := requestsOfType(requests, reflect.TypeOf(&sppb.CommitRequest{}))
if g, w := len(commitRequests), 1; g != w {
t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w)
}
commitReq := commitRequests[0].(*sppb.CommitRequest)
if c, e := commitReq.GetTransactionId(), req.Transaction.GetId(); !cmp.Equal(c, e) {
t.Fatalf("transaction id mismatch\nCommit: %c\nExecute: %v", c, e)
}
}

func TestRunTransactionCommitAborted(t *testing.T) {
Expand Down Expand Up @@ -4494,12 +4490,19 @@ func TestRunTransactionCommitAborted(t *testing.T) {
if req.Transaction == nil {
t.Fatalf("missing transaction for ExecuteSqlRequest")
}
if req.Transaction.GetId() == nil {
t.Fatalf("missing id selector for ExecuteSqlRequest")
}
commitReq := commitRequests[i].(*sppb.CommitRequest)
if c, e := commitReq.GetTransactionId(), req.Transaction.GetId(); !cmp.Equal(c, e) {
t.Fatalf("transaction id mismatch\nCommit: %c\nExecute: %v", c, e)
if i == 0 {
if req.Transaction.GetBegin() == nil {
t.Fatalf("missing begin selector for ExecuteSqlRequest")
}
} else {
// The retried transaction uses an explicit BeginTransaction RPC.
if req.Transaction.GetId() == nil {
t.Fatalf("missing id selector for ExecuteSqlRequest")
}
commitReq := commitRequests[i].(*sppb.CommitRequest)
if c, e := commitReq.GetTransactionId(), req.Transaction.GetId(); !cmp.Equal(c, e) {
t.Fatalf("transaction id mismatch\nCommit: %c\nExecute: %v", c, e)
}
}
}
}
Expand Down Expand Up @@ -4633,9 +4636,11 @@ func TestRunTransactionQueryError(t *testing.T) {
if g, w := len(commitRequests), 0; g != w {
t.Fatalf("commit requests count mismatch\nGot: %v\nWant: %v", g, w)
}
// There should be a RollbackRequest, as the transaction failed.
// There is no RollbackRequest, as the transaction was never started.
// The ExecuteSqlRequest included a BeginTransaction option, but because that
// request failed, the transaction was not started.
rollbackRequests := requestsOfType(requests, reflect.TypeOf(&sppb.RollbackRequest{}))
if g, w := len(rollbackRequests), 1; g != w {
if g, w := len(rollbackRequests), 0; g != w {
t.Fatalf("rollback requests count mismatch\nGot: %v\nWant: %v", g, w)
}
}
Expand Down Expand Up @@ -4806,8 +4811,9 @@ func TestBeginReadWriteTransaction(t *testing.T) {
tx, err := BeginReadWriteTransaction(ctx, db, ReadWriteTransactionOptions{
DisableInternalRetries: true,
TransactionOptions: spanner.TransactionOptions{
TransactionTag: tag,
CommitPriority: sppb.RequestOptions_PRIORITY_LOW,
TransactionTag: tag,
CommitPriority: sppb.RequestOptions_PRIORITY_LOW,
BeginTransactionOption: spanner.ExplicitBeginTransaction,
},
})
if err != nil {
Expand Down Expand Up @@ -4873,7 +4879,7 @@ func TestBeginReadWriteTransaction(t *testing.T) {
t.Fatalf("missing transaction for ExecuteSqlRequest")
}
if req.Transaction.GetId() == nil {
t.Fatalf("missing id selector for ExecuteSqlRequest")
t.Fatalf("missing begin selector for ExecuteSqlRequest")
}
if g, w := req.RequestOptions.TransactionTag, tag; g != w {
t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w)
Expand Down
Loading