From 57729691cd4b6a1bc8b0656b58fe79df15b1b08a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Mon, 24 Mar 2025 10:10:00 +0100 Subject: [PATCH 1/2] feat: support isolation level REPEATABLE READ Add support for isolation level REPEATABLE READ with the BeginTx function. This allows the caller to specify the isolation level for a single transaction. A follow-up pull request will add support for setting the default isolation level that should be used by a connection. --- conn.go | 12 +++- conn_with_mockserver_test.go | 117 +++++++++++++++++++++++++++++++++++ driver.go | 21 +++++++ driver_test.go | 54 ++++++++++++++++ 4 files changed, 202 insertions(+), 2 deletions(-) create mode 100644 conn_with_mockserver_test.go diff --git a/conn.go b/conn.go index 59d1abd0..bfd08bdb 100644 --- a/conn.go +++ b/conn.go @@ -942,14 +942,22 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e disableRetryAborts := false batchReadOnly := false sil := opts.Isolation >> 8 - // TODO: Fix this, the original isolation level is not correctly restored. - opts.Isolation = opts.Isolation - sil + opts.Isolation = opts.Isolation - sil<<8 + if opts.Isolation != driver.IsolationLevel(sql.LevelDefault) { + level, err := toProtoIsolationLevel(sql.IsolationLevel(opts.Isolation)) + if err != nil { + return nil, err + } + readWriteTransactionOptions.TransactionOptions.IsolationLevel = level + } if sil > 0 { switch spannerIsolationLevel(sil) { case levelDisableRetryAborts: disableRetryAborts = true case levelBatchReadOnly: batchReadOnly = true + default: + // ignore } } if batchReadOnly && !opts.ReadOnly { diff --git a/conn_with_mockserver_test.go b/conn_with_mockserver_test.go new file mode 100644 index 00000000..6d6a1133 --- /dev/null +++ b/conn_with_mockserver_test.go @@ -0,0 +1,117 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spannerdriver + +import ( + "context" + "database/sql" + "reflect" + "testing" + + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/go-sql-spanner/testutil" +) + +func TestBeginTx(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + tx, _ := db.BeginTx(ctx, &sql.TxOptions{}) + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + _ = tx.Rollback() + + requests := drainRequestsFromServer(server.TestSpanner) + beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 1; g != w { + t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w) + } + request := beginRequests[0].(*spannerpb.BeginTransactionRequest) + if g, w := request.Options.GetIsolationLevel(), spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED; g != w { + t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestBeginTxWithIsolationLevel(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + for _, level := range []sql.IsolationLevel{ + sql.LevelDefault, + sql.LevelSnapshot, + sql.LevelRepeatableRead, + sql.LevelSerializable, + } { + originalLevel := level + for _, disableRetryAborts := range []bool{true, false} { + if disableRetryAborts { + level = WithDisableRetryAborts(originalLevel) + } else { + level = originalLevel + } + tx, _ := db.BeginTx(ctx, &sql.TxOptions{ + Isolation: level, + }) + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + _ = tx.Rollback() + + requests := drainRequestsFromServer(server.TestSpanner) + beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 1; g != w { + t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w) + } + request := beginRequests[0].(*spannerpb.BeginTransactionRequest) + wantIsolationLevel, _ := toProtoIsolationLevel(originalLevel) + if g, w := request.Options.GetIsolationLevel(), wantIsolationLevel; g != w { + t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w) + } + } + } +} + +func TestBeginTxWithInvalidIsolationLevel(t *testing.T) { + t.Parallel() + + db, _, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + for _, level := range []sql.IsolationLevel{ + sql.LevelReadUncommitted, + sql.LevelReadCommitted, + sql.LevelWriteCommitted, + sql.LevelLinearizable, + } { + originalLevel := level + for _, disableRetryAborts := range []bool{true, false} { + if disableRetryAborts { + level = WithDisableRetryAborts(originalLevel) + } else { + level = originalLevel + } + _, err := db.BeginTx(ctx, &sql.TxOptions{ + Isolation: level, + }) + if err == nil { + t.Fatalf("BeginTx should have failed with invalid isolation level: %v", level) + } + } + } +} diff --git a/driver.go b/driver.go index b4452570..7b8b84bc 100644 --- a/driver.go +++ b/driver.go @@ -1058,6 +1058,27 @@ func checkIsValidType(v driver.Value) bool { return true } +func toProtoIsolationLevel(level sql.IsolationLevel) (spannerpb.TransactionOptions_IsolationLevel, error) { + switch level { + case sql.LevelSerializable: + return spannerpb.TransactionOptions_SERIALIZABLE, nil + case sql.LevelRepeatableRead: + return spannerpb.TransactionOptions_REPEATABLE_READ, nil + case sql.LevelSnapshot: + return spannerpb.TransactionOptions_REPEATABLE_READ, nil + case sql.LevelDefault: + return spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, nil + + // Unsupported and unknown isolation levels. + case sql.LevelReadUncommitted: + case sql.LevelReadCommitted: + case sql.LevelWriteCommitted: + case sql.LevelLinearizable: + default: + } + return spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "invalid or unsupported isolation level: %v", level)) +} + type spannerIsolationLevel sql.IsolationLevel const ( diff --git a/driver_test.go b/driver_test.go index 50a6dd35..b24c068e 100644 --- a/driver_test.go +++ b/driver_test.go @@ -234,7 +234,61 @@ func TestExtractDnsParts(t *testing.T) { } }) } +} +func TestToProtoIsolationLevel(t *testing.T) { + tests := []struct { + input sql.IsolationLevel + want spannerpb.TransactionOptions_IsolationLevel + wantErr bool + }{ + { + input: sql.LevelSerializable, + want: spannerpb.TransactionOptions_SERIALIZABLE, + }, + { + input: sql.LevelRepeatableRead, + want: spannerpb.TransactionOptions_REPEATABLE_READ, + }, + { + input: sql.LevelSnapshot, + want: spannerpb.TransactionOptions_REPEATABLE_READ, + }, + { + input: sql.LevelDefault, + want: spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, + }, + { + input: sql.LevelReadUncommitted, + wantErr: true, + }, + { + input: sql.LevelReadCommitted, + wantErr: true, + }, + { + input: sql.LevelWriteCommitted, + wantErr: true, + }, + { + input: sql.LevelLinearizable, + wantErr: true, + }, + { + input: sql.IsolationLevel(1000), + wantErr: true, + }, + } + for i, test := range tests { + g, err := toProtoIsolationLevel(test.input) + if test.wantErr && err == nil { + t.Errorf("test %d: expected error for input %v, got none", i, test.input) + } else if !test.wantErr && err != nil { + t.Errorf("test %d: unexpected error for input %v: %v", i, test.input, err) + } else if g != test.want { + t.Errorf("test %d:\n Got: %v\nWant: %v", i, g, test.want) + } + } } func ExampleCreateConnector() { From 39b49611360b643a25653bec9ba4f0f7d00ccec1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Mon, 24 Mar 2025 14:35:03 +0100 Subject: [PATCH 2/2] feat: support default isolation level for connection Support setting a default isolation level for a connection and connector. All read/write ransactions on a connection will use the default isolation level, unless an isolation level is specified in the BeginTx function call. --- conn.go | 33 +++++++++- conn_with_mockserver_test.go | 124 +++++++++++++++++++++++++++++++++++ driver.go | 30 +++++++++ driver_test.go | 86 ++++++++++++++++++++++++ 4 files changed, 272 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index bfd08bdb..b9bf1f12 100644 --- a/conn.go +++ b/conn.go @@ -114,6 +114,13 @@ type SpannerConn interface { // mode and for read-only transaction. SetReadOnlyStaleness(staleness spanner.TimestampBound) error + // IsolationLevel returns the current default isolation level that is + // used for read/write transactions on this connection. + IsolationLevel() sql.IsolationLevel + // SetIsolationLevel sets the default isolation level to use for read/write + // transactions on this connection. + SetIsolationLevel(level sql.IsolationLevel) error + // TransactionTag returns the transaction tag that will be applied to the next // read/write transaction on this connection. The transaction tag that is set // on the connection is cleared when a read/write transaction is started. @@ -235,6 +242,10 @@ type conn struct { autocommitDMLMode AutocommitDMLMode // readOnlyStaleness is used for queries in autocommit mode and for read-only transactions. readOnlyStaleness spanner.TimestampBound + // isolationLevel determines the default isolation level that is used for read/write + // transactions on this connection. This default is ignored if the BeginTx function is + // called with an isolation level other than sql.LevelDefault. + isolationLevel sql.IsolationLevel // execOptions are applied to the next statement or transaction that is executed // on this connection. It can also be set by passing it in as an argument to @@ -309,6 +320,15 @@ func (c *conn) setReadOnlyStaleness(staleness spanner.TimestampBound) (driver.Re return driver.ResultNoRows, nil } +func (c *conn) IsolationLevel() sql.IsolationLevel { + return c.isolationLevel +} + +func (c *conn) SetIsolationLevel(level sql.IsolationLevel) error { + c.isolationLevel = level + return nil +} + func (c *conn) MaxCommitDelay() time.Duration { return *c.execOptions.TransactionOptions.CommitOptions.MaxCommitDelay } @@ -633,6 +653,7 @@ func (c *conn) ResetSession(_ context.Context) error { c.autoBatchDmlUpdateCount = c.connector.connectorConfig.AutoBatchDmlUpdateCount c.autoBatchDmlUpdateCountVerification = !c.connector.connectorConfig.DisableAutoBatchDmlUpdateCountVerification c.retryAborts = c.connector.retryAbortsInternally + c.isolationLevel = c.connector.connectorConfig.IsolationLevel // TODO: Reset the following fields to the connector default c.autocommitDMLMode = Transactional c.readOnlyStaleness = spanner.TimestampBound{} @@ -888,10 +909,20 @@ func (c *conn) getTransactionOptions() ReadWriteTransactionOptions { defer func() { c.execOptions.TransactionOptions.TransactionTag = "" }() - return ReadWriteTransactionOptions{ + txOpts := ReadWriteTransactionOptions{ TransactionOptions: c.execOptions.TransactionOptions, DisableInternalRetries: !c.retryAborts, } + // Only use the default isolation level from the connection if the ExecOptions + // did not contain a more specific isolation level. + if txOpts.TransactionOptions.IsolationLevel == spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED { + // This should never really return an error, but we check just to be absolutely sure. + level, err := toProtoIsolationLevel(c.isolationLevel) + if err == nil { + txOpts.TransactionOptions.IsolationLevel = level + } + } + return txOpts } func (c *conn) withTempReadOnlyTransactionOptions(options *ReadOnlyTransactionOptions) { diff --git a/conn_with_mockserver_test.go b/conn_with_mockserver_test.go index 6d6a1133..05ad707c 100644 --- a/conn_with_mockserver_test.go +++ b/conn_with_mockserver_test.go @@ -17,9 +17,11 @@ package spannerdriver import ( "context" "database/sql" + "fmt" "reflect" "testing" + "cloud.google.com/go/spanner" "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/googleapis/go-sql-spanner/testutil" ) @@ -115,3 +117,125 @@ func TestBeginTxWithInvalidIsolationLevel(t *testing.T) { } } } + +func TestDefaultIsolationLevel(t *testing.T) { + t.Parallel() + + for _, level := range []sql.IsolationLevel{ + sql.LevelDefault, + sql.LevelSnapshot, + sql.LevelRepeatableRead, + sql.LevelSerializable, + } { + db, server, teardown := setupTestDBConnectionWithParams(t, fmt.Sprintf("isolationLevel=%v", level)) + defer teardown() + ctx := context.Background() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + if err := conn.Raw(func(driverConn interface{}) error { + spannerConn, ok := driverConn.(SpannerConn) + if !ok { + return fmt.Errorf("expected spanner conn, got %T", driverConn) + } + if spannerConn.IsolationLevel() != level { + return fmt.Errorf("expected isolation level %v, got %v", level, spannerConn.IsolationLevel()) + } + return nil + }); err != nil { + t.Fatal(err) + } + + originalLevel := level + for _, disableRetryAborts := range []bool{true, false} { + if disableRetryAborts { + level = WithDisableRetryAborts(originalLevel) + } else { + level = originalLevel + } + // Note: No isolation level is passed in here, so it will use the default. + tx, _ := db.BeginTx(ctx, &sql.TxOptions{}) + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + _ = tx.Rollback() + + requests := drainRequestsFromServer(server.TestSpanner) + beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 1; g != w { + t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w) + } + request := beginRequests[0].(*spannerpb.BeginTransactionRequest) + wantIsolationLevel, _ := toProtoIsolationLevel(originalLevel) + if g, w := request.Options.GetIsolationLevel(), wantIsolationLevel; g != w { + t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w) + } + } + } +} + +func TestSetIsolationLevel(t *testing.T) { + t.Parallel() + + db, _, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + // Repeat twice to ensure that the state is reset after closing the connection. + for i := 0; i < 2; i++ { + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + var level sql.IsolationLevel + _ = conn.Raw(func(driverConn interface{}) error { + level = driverConn.(SpannerConn).IsolationLevel() + return nil + }) + if g, w := level, sql.LevelDefault; g != w { + t.Fatalf("isolation level mismatch\n Got: %v\nWant: %v", g, w) + } + _ = conn.Raw(func(driverConn interface{}) error { + return driverConn.(SpannerConn).SetIsolationLevel(sql.LevelSnapshot) + }) + _ = conn.Raw(func(driverConn interface{}) error { + level = driverConn.(SpannerConn).IsolationLevel() + return nil + }) + if g, w := level, sql.LevelSnapshot; g != w { + t.Fatalf("isolation level mismatch\n Got: %v\nWant: %v", g, w) + } + conn.Close() + } +} + +func TestIsolationLevelAutoCommit(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + for _, level := range []sql.IsolationLevel{ + sql.LevelDefault, + sql.LevelSnapshot, + sql.LevelRepeatableRead, + sql.LevelSerializable, + } { + spannerLevel, _ := toProtoIsolationLevel(level) + _, _ = db.ExecContext(ctx, testutil.UpdateBarSetFoo, ExecOptions{TransactionOptions: spanner.TransactionOptions{ + IsolationLevel: spannerLevel, + }}) + + requests := drainRequestsFromServer(server.TestSpanner) + executeRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{})) + if g, w := len(executeRequests), 1; g != w { + t.Fatalf("execute requests count mismatch\n Got: %v\nWant: %v", g, w) + } + request := executeRequests[0].(*spannerpb.ExecuteSqlRequest) + wantIsolationLevel, _ := toProtoIsolationLevel(level) + if g, w := request.Transaction.GetBegin().GetIsolationLevel(), wantIsolationLevel; g != w { + t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w) + } + } +} diff --git a/driver.go b/driver.go index 7b8b84bc..791e2644 100644 --- a/driver.go +++ b/driver.go @@ -231,6 +231,9 @@ type ConnectorConfig struct { AutoBatchDmlUpdateCount int64 DisableAutoBatchDmlUpdateCountVerification bool + // IsolationLevel is the default isolation level for read/write transactions. + IsolationLevel sql.IsolationLevel + // DecodeToNativeArrays determines whether arrays that have a Go native // type should be decoded to those types rather than the corresponding // spanner.NullTypeName type. @@ -472,6 +475,11 @@ func createConnector(d *Driver, connectorConfig ConnectorConfig) (*connector, er connectorConfig.AutoConfigEmulator = val } } + if strval, ok := connectorConfig.Params[strings.ToLower("IsolationLevel")]; ok { + if val, err := parseIsolationLevel(strval); err == nil { + connectorConfig.IsolationLevel = val + } + } // Check if it is Spanner gorm that is creating the connection. // If so, we should set a different user-agent header than the @@ -1058,6 +1066,28 @@ func checkIsValidType(v driver.Value) bool { return true } +func parseIsolationLevel(val string) (sql.IsolationLevel, error) { + switch strings.Replace(strings.ToLower(strings.TrimSpace(val)), " ", "_", 1) { + case "default": + return sql.LevelDefault, nil + case "read_uncommitted": + return sql.LevelReadUncommitted, nil + case "read_committed": + return sql.LevelReadCommitted, nil + case "write_committed": + return sql.LevelWriteCommitted, nil + case "repeatable_read": + return sql.LevelRepeatableRead, nil + case "snapshot": + return sql.LevelSnapshot, nil + case "serializable": + return sql.LevelSerializable, nil + case "linearizable": + return sql.LevelLinearizable, nil + } + return sql.LevelDefault, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "invalid or unsupported isolation level: %v", val)) +} + func toProtoIsolationLevel(level sql.IsolationLevel) (spannerpb.TransactionOptions_IsolationLevel, error) { switch level { case sql.LevelSerializable: diff --git a/driver_test.go b/driver_test.go index b24c068e..3b6a7354 100644 --- a/driver_test.go +++ b/driver_test.go @@ -160,6 +160,21 @@ func TestExtractDnsParts(t *testing.T) { DisableNativeMetrics: true, }, }, + { + input: "projects/p/instances/i/databases/d?isolationLevel=repeatable_read;", + wantConnectorConfig: ConnectorConfig{ + Project: "p", + Instance: "i", + Database: "d", + Params: map[string]string{ + "isolationlevel": "repeatable_read", + }, + }, + wantSpannerConfig: spanner.ClientConfig{ + SessionPoolConfig: spanner.DefaultSessionPoolConfig, + UserAgent: userAgent, + }, + }, { input: "spanner.googleapis.com/projects/p/instances/i/databases/d?minSessions=200;maxSessions=1000;numChannels=10;disableRouteToLeader=true;enableEndToEndTracing=true;disableNativeMetrics=true;rpcPriority=Medium;optimizerVersion=1;optimizerStatisticsPackage=latest;databaseRole=child", wantConnectorConfig: ConnectorConfig{ @@ -291,6 +306,77 @@ func TestToProtoIsolationLevel(t *testing.T) { } } +func TestParseIsolationLevel(t *testing.T) { + tests := []struct { + input string + want sql.IsolationLevel + wantErr bool + }{ + { + input: "default", + want: sql.LevelDefault, + }, + { + input: " DEFAULT ", + want: sql.LevelDefault, + }, + { + input: "read uncommitted", + want: sql.LevelReadUncommitted, + }, + { + input: " read_uncommitted\n", + want: sql.LevelReadUncommitted, + }, + { + input: "read committed", + want: sql.LevelReadCommitted, + }, + { + input: "write committed", + want: sql.LevelWriteCommitted, + }, + { + input: "repeatable read", + want: sql.LevelRepeatableRead, + }, + { + input: "snapshot", + want: sql.LevelSnapshot, + }, + { + input: "serializable", + want: sql.LevelSerializable, + }, + { + input: "linearizable", + want: sql.LevelLinearizable, + }, + { + input: "read serializable", + wantErr: true, + }, + { + input: "", + wantErr: true, + }, + { + input: "read-committed", + wantErr: true, + }, + } + for _, tc := range tests { + level, err := parseIsolationLevel(tc.input) + if tc.wantErr && err == nil { + t.Errorf("parseIsolationLevel(%q): expected error", tc.input) + } else if !tc.wantErr && err != nil { + t.Errorf("parseIsolationLevel(%q): unexpected error: %v", tc.input, err) + } else if level != tc.want { + t.Errorf("parseIsolationLevel(%q): got %v, want %v", tc.input, level, tc.want) + } + } +} + func ExampleCreateConnector() { connectorConfig := ConnectorConfig{ Project: "my-project",