diff --git a/conn.go b/conn.go index 59d1abd0..bfd08bdb 100644 --- a/conn.go +++ b/conn.go @@ -942,14 +942,22 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e disableRetryAborts := false batchReadOnly := false sil := opts.Isolation >> 8 - // TODO: Fix this, the original isolation level is not correctly restored. - opts.Isolation = opts.Isolation - sil + opts.Isolation = opts.Isolation - sil<<8 + if opts.Isolation != driver.IsolationLevel(sql.LevelDefault) { + level, err := toProtoIsolationLevel(sql.IsolationLevel(opts.Isolation)) + if err != nil { + return nil, err + } + readWriteTransactionOptions.TransactionOptions.IsolationLevel = level + } if sil > 0 { switch spannerIsolationLevel(sil) { case levelDisableRetryAborts: disableRetryAborts = true case levelBatchReadOnly: batchReadOnly = true + default: + // ignore } } if batchReadOnly && !opts.ReadOnly { diff --git a/conn_with_mockserver_test.go b/conn_with_mockserver_test.go new file mode 100644 index 00000000..6d6a1133 --- /dev/null +++ b/conn_with_mockserver_test.go @@ -0,0 +1,117 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spannerdriver + +import ( + "context" + "database/sql" + "reflect" + "testing" + + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/go-sql-spanner/testutil" +) + +func TestBeginTx(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + tx, _ := db.BeginTx(ctx, &sql.TxOptions{}) + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + _ = tx.Rollback() + + requests := drainRequestsFromServer(server.TestSpanner) + beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 1; g != w { + t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w) + } + request := beginRequests[0].(*spannerpb.BeginTransactionRequest) + if g, w := request.Options.GetIsolationLevel(), spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED; g != w { + t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w) + } +} + +func TestBeginTxWithIsolationLevel(t *testing.T) { + t.Parallel() + + db, server, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + for _, level := range []sql.IsolationLevel{ + sql.LevelDefault, + sql.LevelSnapshot, + sql.LevelRepeatableRead, + sql.LevelSerializable, + } { + originalLevel := level + for _, disableRetryAborts := range []bool{true, false} { + if disableRetryAborts { + level = WithDisableRetryAborts(originalLevel) + } else { + level = originalLevel + } + tx, _ := db.BeginTx(ctx, &sql.TxOptions{ + Isolation: level, + }) + _, _ = tx.ExecContext(ctx, testutil.UpdateBarSetFoo) + _ = tx.Rollback() + + requests := drainRequestsFromServer(server.TestSpanner) + beginRequests := requestsOfType(requests, reflect.TypeOf(&spannerpb.BeginTransactionRequest{})) + if g, w := len(beginRequests), 1; g != w { + t.Fatalf("begin requests count mismatch\n Got: %v\nWant: %v", g, w) + } + request := beginRequests[0].(*spannerpb.BeginTransactionRequest) + wantIsolationLevel, _ := toProtoIsolationLevel(originalLevel) + if g, w := request.Options.GetIsolationLevel(), wantIsolationLevel; g != w { + t.Fatalf("begin isolation level mismatch\n Got: %v\nWant: %v", g, w) + } + } + } +} + +func TestBeginTxWithInvalidIsolationLevel(t *testing.T) { + t.Parallel() + + db, _, teardown := setupTestDBConnection(t) + defer teardown() + ctx := context.Background() + + for _, level := range []sql.IsolationLevel{ + sql.LevelReadUncommitted, + sql.LevelReadCommitted, + sql.LevelWriteCommitted, + sql.LevelLinearizable, + } { + originalLevel := level + for _, disableRetryAborts := range []bool{true, false} { + if disableRetryAborts { + level = WithDisableRetryAborts(originalLevel) + } else { + level = originalLevel + } + _, err := db.BeginTx(ctx, &sql.TxOptions{ + Isolation: level, + }) + if err == nil { + t.Fatalf("BeginTx should have failed with invalid isolation level: %v", level) + } + } + } +} diff --git a/driver.go b/driver.go index b4452570..7b8b84bc 100644 --- a/driver.go +++ b/driver.go @@ -1058,6 +1058,27 @@ func checkIsValidType(v driver.Value) bool { return true } +func toProtoIsolationLevel(level sql.IsolationLevel) (spannerpb.TransactionOptions_IsolationLevel, error) { + switch level { + case sql.LevelSerializable: + return spannerpb.TransactionOptions_SERIALIZABLE, nil + case sql.LevelRepeatableRead: + return spannerpb.TransactionOptions_REPEATABLE_READ, nil + case sql.LevelSnapshot: + return spannerpb.TransactionOptions_REPEATABLE_READ, nil + case sql.LevelDefault: + return spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, nil + + // Unsupported and unknown isolation levels. + case sql.LevelReadUncommitted: + case sql.LevelReadCommitted: + case sql.LevelWriteCommitted: + case sql.LevelLinearizable: + default: + } + return spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "invalid or unsupported isolation level: %v", level)) +} + type spannerIsolationLevel sql.IsolationLevel const ( diff --git a/driver_test.go b/driver_test.go index 50a6dd35..b24c068e 100644 --- a/driver_test.go +++ b/driver_test.go @@ -234,7 +234,61 @@ func TestExtractDnsParts(t *testing.T) { } }) } +} +func TestToProtoIsolationLevel(t *testing.T) { + tests := []struct { + input sql.IsolationLevel + want spannerpb.TransactionOptions_IsolationLevel + wantErr bool + }{ + { + input: sql.LevelSerializable, + want: spannerpb.TransactionOptions_SERIALIZABLE, + }, + { + input: sql.LevelRepeatableRead, + want: spannerpb.TransactionOptions_REPEATABLE_READ, + }, + { + input: sql.LevelSnapshot, + want: spannerpb.TransactionOptions_REPEATABLE_READ, + }, + { + input: sql.LevelDefault, + want: spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, + }, + { + input: sql.LevelReadUncommitted, + wantErr: true, + }, + { + input: sql.LevelReadCommitted, + wantErr: true, + }, + { + input: sql.LevelWriteCommitted, + wantErr: true, + }, + { + input: sql.LevelLinearizable, + wantErr: true, + }, + { + input: sql.IsolationLevel(1000), + wantErr: true, + }, + } + for i, test := range tests { + g, err := toProtoIsolationLevel(test.input) + if test.wantErr && err == nil { + t.Errorf("test %d: expected error for input %v, got none", i, test.input) + } else if !test.wantErr && err != nil { + t.Errorf("test %d: unexpected error for input %v: %v", i, test.input, err) + } else if g != test.want { + t.Errorf("test %d:\n Got: %v\nWant: %v", i, g, test.want) + } + } } func ExampleCreateConnector() {