Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ type SpannerConn interface {
// mode and for read-only transaction.
SetReadOnlyStaleness(staleness spanner.TimestampBound) error

// IsolationLevel returns the current default isolation level that is
// used for read/write transactions on this connection.
IsolationLevel() sql.IsolationLevel
// SetIsolationLevel sets the default isolation level to use for read/write
// transactions on this connection.
SetIsolationLevel(level sql.IsolationLevel) error

// TransactionTag returns the transaction tag that will be applied to the next
// read/write transaction on this connection. The transaction tag that is set
// on the connection is cleared when a read/write transaction is started.
Expand Down Expand Up @@ -235,6 +242,10 @@ type conn struct {
autocommitDMLMode AutocommitDMLMode
// readOnlyStaleness is used for queries in autocommit mode and for read-only transactions.
readOnlyStaleness spanner.TimestampBound
// isolationLevel determines the default isolation level that is used for read/write
// transactions on this connection. This default is ignored if the BeginTx function is
// called with an isolation level other than sql.LevelDefault.
isolationLevel sql.IsolationLevel

// execOptions are applied to the next statement or transaction that is executed
// on this connection. It can also be set by passing it in as an argument to
Expand Down Expand Up @@ -309,6 +320,15 @@ func (c *conn) setReadOnlyStaleness(staleness spanner.TimestampBound) (driver.Re
return driver.ResultNoRows, nil
}

func (c *conn) IsolationLevel() sql.IsolationLevel {
return c.isolationLevel
}

func (c *conn) SetIsolationLevel(level sql.IsolationLevel) error {
c.isolationLevel = level
return nil
}

func (c *conn) MaxCommitDelay() time.Duration {
return *c.execOptions.TransactionOptions.CommitOptions.MaxCommitDelay
}
Expand Down Expand Up @@ -633,6 +653,7 @@ func (c *conn) ResetSession(_ context.Context) error {
c.autoBatchDmlUpdateCount = c.connector.connectorConfig.AutoBatchDmlUpdateCount
c.autoBatchDmlUpdateCountVerification = !c.connector.connectorConfig.DisableAutoBatchDmlUpdateCountVerification
c.retryAborts = c.connector.retryAbortsInternally
c.isolationLevel = c.connector.connectorConfig.IsolationLevel
// TODO: Reset the following fields to the connector default
c.autocommitDMLMode = Transactional
c.readOnlyStaleness = spanner.TimestampBound{}
Expand Down Expand Up @@ -888,10 +909,20 @@ func (c *conn) getTransactionOptions() ReadWriteTransactionOptions {
defer func() {
c.execOptions.TransactionOptions.TransactionTag = ""
}()
return ReadWriteTransactionOptions{
txOpts := ReadWriteTransactionOptions{
TransactionOptions: c.execOptions.TransactionOptions,
DisableInternalRetries: !c.retryAborts,
}
// Only use the default isolation level from the connection if the ExecOptions
// did not contain a more specific isolation level.
if txOpts.TransactionOptions.IsolationLevel == spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED {
// This should never really return an error, but we check just to be absolutely sure.
level, err := toProtoIsolationLevel(c.isolationLevel)
if err == nil {
txOpts.TransactionOptions.IsolationLevel = level
}
}
return txOpts
}

func (c *conn) withTempReadOnlyTransactionOptions(options *ReadOnlyTransactionOptions) {
Expand Down
124 changes: 124 additions & 0 deletions conn_with_mockserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ package spannerdriver
import (
"context"
"database/sql"
"fmt"
"reflect"
"testing"

"cloud.google.com/go/spanner"
"cloud.google.com/go/spanner/apiv1/spannerpb"
"github.com/googleapis/go-sql-spanner/testutil"
)
Expand Down Expand Up @@ -115,3 +117,125 @@ func TestBeginTxWithInvalidIsolationLevel(t *testing.T) {
}
}
}

func TestDefaultIsolationLevel(t *testing.T) {
t.Parallel()

for _, level := range []sql.IsolationLevel{
sql.LevelDefault,
sql.LevelSnapshot,
sql.LevelRepeatableRead,
sql.LevelSerializable,
} {
db, server, teardown := setupTestDBConnectionWithParams(t, fmt.Sprintf("isolationLevel=%v", level))
defer teardown()
ctx := context.Background()

conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
if err := conn.Raw(func(driverConn interface{}) error {
spannerConn, ok := driverConn.(SpannerConn)
if !ok {
return fmt.Errorf("expected spanner conn, got %T", driverConn)
}
if spannerConn.IsolationLevel() != level {
return fmt.Errorf("expected isolation level %v, got %v", level, spannerConn.IsolationLevel())
}
return nil
}); err != nil {
t.Fatal(err)
}

originalLevel := level
for _, disableRetryAborts := range []bool{true, false} {
if disableRetryAborts {
level = WithDisableRetryAborts(originalLevel)
} else {
level = originalLevel
}
// Note: No isolation level is passed in here, so it will use the default.
tx, _ := db.BeginTx(ctx, &sql.TxOptions{})
_, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo)
_ = tx.Rollback()

requests := drainRequestsFromServer(server.TestSpanner)
beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{}))
if g, w := len(beginRequests), 1; g != w {
t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w)
}
request := beginRequests[0].(*spannerpb.BeginTransactionRequest)
wantIsolationLevel, _ := toProtoIsolationLevel(originalLevel)
if g, w := request.Options.GetIsolationLevel(), wantIsolationLevel; g != w {
t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w)
}
}
}
}

func TestSetIsolationLevel(t *testing.T) {
t.Parallel()

db, _, teardown := setupTestDBConnection(t)
defer teardown()
ctx := context.Background()

// Repeat twice to ensure that the state is reset after closing the connection.
for i := 0; i < 2; i++ {
conn, err := db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
var level sql.IsolationLevel
_ = conn.Raw(func(driverConn interface{}) error {
level = driverConn.(SpannerConn).IsolationLevel()
return nil
})
if g, w := level, sql.LevelDefault; g != w {
t.Fatalf("isolation level mismatch\n Got: %v\nWant: %v", g, w)
}
_ = conn.Raw(func(driverConn interface{}) error {
return driverConn.(SpannerConn).SetIsolationLevel(sql.LevelSnapshot)
})
_ = conn.Raw(func(driverConn interface{}) error {
level = driverConn.(SpannerConn).IsolationLevel()
return nil
})
if g, w := level, sql.LevelSnapshot; g != w {
t.Fatalf("isolation level mismatch\n Got: %v\nWant: %v", g, w)
}
conn.Close()
}
}

func TestIsolationLevelAutoCommit(t *testing.T) {
t.Parallel()

db, server, teardown := setupTestDBConnection(t)
defer teardown()
ctx := context.Background()

for _, level := range []sql.IsolationLevel{
sql.LevelDefault,
sql.LevelSnapshot,
sql.LevelRepeatableRead,
sql.LevelSerializable,
} {
spannerLevel, _ := toProtoIsolationLevel(level)
_, _ = db.ExecContext(ctx, testutil.UpdateBarSetFoo, ExecOptions{TransactionOptions: spanner.TransactionOptions{
IsolationLevel: spannerLevel,
}})

requests := drainRequestsFromServer(server.TestSpanner)
executeRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
if g, w := len(executeRequests), 1; g != w {
t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w)
}
request := executeRequests[0].(*spannerpb.ExecuteSqlRequest)
wantIsolationLevel, _ := toProtoIsolationLevel(level)
if g, w := request.Transaction.GetBegin().GetIsolationLevel(), wantIsolationLevel; g != w {
t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w)
}
}
}
30 changes: 30 additions & 0 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ type ConnectorConfig struct {
AutoBatchDmlUpdateCount int64
DisableAutoBatchDmlUpdateCountVerification bool

// IsolationLevel is the default isolation level for read/write transactions.
IsolationLevel sql.IsolationLevel

// DecodeToNativeArrays determines whether arrays that have a Go native
// type should be decoded to those types rather than the corresponding
// spanner.NullTypeName type.
Expand Down Expand Up @@ -472,6 +475,11 @@ func createConnector(d *Driver, connectorConfig ConnectorConfig) (*connector, er
connectorConfig.AutoConfigEmulator = val
}
}
if strval, ok := connectorConfig.Params[strings.ToLower("IsolationLevel")]; ok {
if val, err := parseIsolationLevel(strval); err == nil {
connectorConfig.IsolationLevel = val
}
}

// Check if it is Spanner gorm that is creating the connection.
// If so, we should set a different user-agent header than the
Expand Down Expand Up @@ -1058,6 +1066,28 @@ func checkIsValidType(v driver.Value) bool {
return true
}

func parseIsolationLevel(val string) (sql.IsolationLevel, error) {
switch strings.Replace(strings.ToLower(strings.TrimSpace(val)), " ", "_", 1) {
case "default":
return sql.LevelDefault, nil
case "read_uncommitted":
return sql.LevelReadUncommitted, nil
case "read_committed":
return sql.LevelReadCommitted, nil
case "write_committed":
return sql.LevelWriteCommitted, nil
case "repeatable_read":
return sql.LevelRepeatableRead, nil
case "snapshot":
return sql.LevelSnapshot, nil
case "serializable":
return sql.LevelSerializable, nil
case "linearizable":
return sql.LevelLinearizable, nil
}
return sql.LevelDefault, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "invalid or unsupported isolation level: %v", val))
}

func toProtoIsolationLevel(level sql.IsolationLevel) (spannerpb.TransactionOptions_IsolationLevel, error) {
switch level {
case sql.LevelSerializable:
Expand Down
86 changes: 86 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,21 @@ func TestExtractDnsParts(t *testing.T) {
DisableNativeMetrics: true,
},
},
{
input: "projects/p/instances/i/databases/d?isolationLevel=repeatable_read;",
wantConnectorConfig: ConnectorConfig{
Project: "p",
Instance: "i",
Database: "d",
Params: map[string]string{
"isolationlevel": "repeatable_read",
},
},
wantSpannerConfig: spanner.ClientConfig{
SessionPoolConfig: spanner.DefaultSessionPoolConfig,
UserAgent: userAgent,
},
},
{
input: "spanner.googleapis.com/projects/p/instances/i/databases/d?minSessions=200;maxSessions=1000;numChannels=10;disableRouteToLeader=true;enableEndToEndTracing=true;disableNativeMetrics=true;rpcPriority=Medium;optimizerVersion=1;optimizerStatisticsPackage=latest;databaseRole=child",
wantConnectorConfig: ConnectorConfig{
Expand Down Expand Up @@ -291,6 +306,77 @@ func TestToProtoIsolationLevel(t *testing.T) {
}
}

func TestParseIsolationLevel(t *testing.T) {
tests := []struct {
input string
want sql.IsolationLevel
wantErr bool
}{
{
input: "default",
want: sql.LevelDefault,
},
{
input: " DEFAULT ",
want: sql.LevelDefault,
},
{
input: "read uncommitted",
want: sql.LevelReadUncommitted,
},
{
input: " read_uncommitted\n",
want: sql.LevelReadUncommitted,
},
{
input: "read committed",
want: sql.LevelReadCommitted,
},
{
input: "write committed",
want: sql.LevelWriteCommitted,
},
{
input: "repeatable read",
want: sql.LevelRepeatableRead,
},
{
input: "snapshot",
want: sql.LevelSnapshot,
},
{
input: "serializable",
want: sql.LevelSerializable,
},
{
input: "linearizable",
want: sql.LevelLinearizable,
},
{
input: "read serializable",
wantErr: true,
},
{
input: "",
wantErr: true,
},
{
input: "read-committed",
wantErr: true,
},
}
for _, tc := range tests {
level, err := parseIsolationLevel(tc.input)
if tc.wantErr && err == nil {
t.Errorf("parseIsolationLevel(%q): expected error", tc.input)
} else if !tc.wantErr && err != nil {
t.Errorf("parseIsolationLevel(%q): unexpected error: %v", tc.input, err)
} else if level != tc.want {
t.Errorf("parseIsolationLevel(%q): got %v, want %v", tc.input, level, tc.want)
}
}
}

func ExampleCreateConnector() {
connectorConfig := ConnectorConfig{
Project: "my-project",
Expand Down
Loading