From 6c21558f75628908a70de79c62aff2851e756e7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Wed, 23 Aug 2023 12:37:17 +0200 Subject: [PATCH] fix(spanner): Transaction was started in a different session (#8467) * fix: Transaction was started in a different session Retrying a "Session not found" error could cause a "Transaction was started in a different session" error. This happened because: 1. The detection of a "Session not found" error would remove the session from the pool, and also remove the session ID from the session handle 2. The retry mechanism would check out a new session from the pool, but not assign it to the transaction yet 3. The retry would then proceed to retry the transaction with an explicit BeginTransaction RPC. This function would however pick a new session from the pool, because step 2 had not yet assigned the transaction a new session. 4. The higher level retry loop would then after executing the BeginTransaction RPC assign the session that was picked in step 2 to the transaction. 5. The transaction would then proceed to use the session from step 2 with the transaction from step 3. * chore: remove unused code * chore: fix import order --- spanner/client.go | 6 +- spanner/client_test.go | 196 ++++++++++++++++++ spanner/integration_test.go | 51 +++++ .../internal/testutil/inmem_spanner_server.go | 18 +- spanner/transaction.go | 10 +- 5 files changed, 267 insertions(+), 14 deletions(-) diff --git a/spanner/client.go b/spanner/client.go index a95b02431ad4..bfe00c2dcec0 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -563,6 +563,10 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea } } if t.shouldExplicitBegin(attempt) { + // Make sure we set the current session handle before calling BeginTransaction. + // Note that the t.begin(ctx) call could change the session that is being used by the transaction, as the + // BeginTransaction RPC invocation will be retried on a new session if it returns SessionNotFound. + t.txReadOnly.sh = sh if err = t.begin(ctx); err != nil { trace.TracePrintf(ctx, nil, "Error while BeginTransaction during retrying a ReadWrite transaction: %v", ToSpannerError(err)) return ToSpannerError(err) @@ -571,9 +575,9 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea t = &ReadWriteTransaction{ txReadyOrClosed: make(chan struct{}), } + t.txReadOnly.sh = sh } attempt++ - t.txReadOnly.sh = sh t.txReadOnly.sp = c.idleSessions t.txReadOnly.txReadEnv = t t.txReadOnly.qo = c.qo diff --git a/spanner/client_test.go b/spanner/client_test.go index 6f963102805e..1141a6ebb5bc 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -727,6 +727,202 @@ func TestClient_ReadOnlyTransaction_SessionNotFoundOnBeginTransaction_WithMaxOne } } +func TestClient_ReadWriteTransaction_SessionNotFoundForFirstStatement(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime( + MethodExecuteStreamingSql, + SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + + expectedAttempts := 2 + var attempts int + _, err := client.ReadWriteTransaction( + ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + iter := tx.Query(ctx, NewStatement(SelectFooFromBar)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_SessionNotFoundForFirstStatement_AndThenSessionNotFoundForBeginTransaction(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime( + MethodExecuteStreamingSql, + SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + server.TestSpanner.PutExecutionTime( + MethodBeginTransaction, + SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + + expectedAttempts := 2 + var attempts int + _, err := client.ReadWriteTransaction( + ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + iter := tx.Query(ctx, NewStatement(SelectFooFromBar)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_AbortedForFirstStatement_AndThenSessionNotFoundForBeginTransaction(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime( + MethodExecuteStreamingSql, + SimulatedExecutionTime{Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}}, + ) + server.TestSpanner.PutExecutionTime( + MethodBeginTransaction, + SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + + expectedAttempts := 2 + var attempts int + _, err := client.ReadWriteTransaction( + ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + iter := tx.Query(ctx, NewStatement(SelectFooFromBar)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.BeginTransactionRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } +} + +func TestClient_ReadWriteTransaction_SessionNotFoundForFirstStatement_DoesNotLeakSession(t *testing.T) { + ctx := context.Background() + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 1, + MaxOpened: 1, + }, + }) + defer teardown() + server.TestSpanner.PutExecutionTime( + MethodExecuteStreamingSql, + SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + + expectedAttempts := 2 + var attempts int + _, err := client.ReadWriteTransaction( + ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + attempts++ + iter := tx.Query(ctx, NewStatement(SelectFooFromBar)) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return err + } + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if expectedAttempts != attempts { + t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.BatchCreateSessionsRequest{}, // We need to create more sessions, as the one used first was destroyed. + &sppb.BeginTransactionRequest{}, + &sppb.ExecuteSqlRequest{}, + &sppb.CommitRequest{}, + }, requests); err != nil { + t.Fatal(err) + } +} + func TestClient_ReadOnlyTransaction_QueryOptions(t *testing.T) { for _, tt := range queryOptionsTestCases() { t.Run(tt.name, func(t *testing.T) { diff --git a/spanner/integration_test.go b/spanner/integration_test.go index 7aa7d622df66..5d016ede945d 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -42,12 +42,14 @@ import ( adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" instance "cloud.google.com/go/spanner/admin/instance/apiv1" "cloud.google.com/go/spanner/admin/instance/apiv1/instancepb" + v1 "cloud.google.com/go/spanner/apiv1" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" "cloud.google.com/go/spanner/internal" "go.opencensus.io/stats/view" "go.opencensus.io/tag" "google.golang.org/api/iterator" "google.golang.org/api/option" + "google.golang.org/api/option/internaloption" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/peer" @@ -846,6 +848,55 @@ func TestIntegration_SingleUse_WithQueryOptions(t *testing.T) { } } +func TestIntegration_TransactionWasStartedInDifferentSession(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + // Set up testing environment. + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) + defer cleanup() + + attempts := 0 + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, transaction *ReadWriteTransaction) error { + attempts++ + if attempts == 1 { + deleteTestSession(ctx, t, transaction.sh.getID()) + } + if _, err := readAll(transaction.Query(ctx, NewStatement("select * from singers"))); err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if g, w := attempts, 2; g != w { + t.Fatalf("attempts mismatch\nGot: %v\nWant: %v", g, w) + } +} + +func deleteTestSession(ctx context.Context, t *testing.T, sessionName string) { + var opts []option.ClientOption + if emulatorAddr := os.Getenv("SPANNER_EMULATOR_HOST"); emulatorAddr != "" { + emulatorOpts := []option.ClientOption{ + option.WithEndpoint(emulatorAddr), + option.WithGRPCDialOption(grpc.WithInsecure()), + option.WithoutAuthentication(), + internaloption.SkipDialSettingsValidation(), + } + opts = append(emulatorOpts, opts...) + } + gapic, err := v1.NewClient(ctx, opts...) + if err != nil { + t.Fatalf("could not create gapic client: %v", err) + } + defer gapic.Close() + if err := gapic.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: sessionName}); err != nil { + t.Fatal(err) + } +} + func TestIntegration_SingleUse_ReadingWithLimit(t *testing.T) { t.Parallel() diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index 922ae6ad1328..b1adf02f2182 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -581,13 +581,17 @@ func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, option return res } -func (s *inMemSpannerServer) getTransactionByID(id []byte) (*spannerpb.Transaction, error) { +func (s *inMemSpannerServer) getTransactionByID(session *spannerpb.Session, id []byte) (*spannerpb.Transaction, error) { s.mu.Lock() defer s.mu.Unlock() tx, ok := s.transactions[string(id)] if !ok { return nil, gstatus.Error(codes.NotFound, "Transaction not found") } + if !strings.HasPrefix(string(id), session.Name) { + return nil, gstatus.Error(codes.InvalidArgument, "Transaction was started in a different session.") + } + aborted, ok := s.abortedTransactions[string(id)] if ok && aborted { return nil, newAbortedErrorWithMinimalRetryDelay() @@ -813,7 +817,7 @@ func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.Exec var id []byte s.updateSessionLastUseTime(session.Name) if id = s.getTransactionID(session, req.Transaction); id != nil { - _, err = s.getTransactionByID(id) + _, err = s.getTransactionByID(session, id) if err != nil { return nil, err } @@ -860,7 +864,7 @@ func (s *inMemSpannerServer) executeStreamingSQL(req *spannerpb.ExecuteSqlReques s.updateSessionLastUseTime(session.Name) var id []byte if id = s.getTransactionID(session, req.Transaction); id != nil { - _, err = s.getTransactionByID(id) + _, err = s.getTransactionByID(session, id) if err != nil { return err } @@ -932,7 +936,7 @@ func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb s.updateSessionLastUseTime(session.Name) var id []byte if id = s.getTransactionID(session, req.Transaction); id != nil { - _, err = s.getTransactionByID(id) + _, err = s.getTransactionByID(session, id) if err != nil { return nil, err } @@ -1031,7 +1035,7 @@ func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRe if req.GetSingleUseTransaction() != nil { tx = s.beginTransaction(session, req.GetSingleUseTransaction()) } else if req.GetTransactionId() != nil { - tx, err = s.getTransactionByID(req.GetTransactionId()) + tx, err = s.getTransactionByID(session, req.GetTransactionId()) if err != nil { return nil, err } @@ -1064,7 +1068,7 @@ func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.Rollba return nil, err } s.updateSessionLastUseTime(session.Name) - tx, err := s.getTransactionByID(req.TransactionId) + tx, err := s.getTransactionByID(session, req.TransactionId) if err != nil { return nil, err } @@ -1091,7 +1095,7 @@ func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb. var tx *spannerpb.Transaction s.updateSessionLastUseTime(session.Name) if id = s.getTransactionID(session, req.Transaction); id != nil { - tx, err = s.getTransactionByID(id) + tx, err = s.getTransactionByID(session, id) if err != nil { return nil, err } diff --git a/spanner/transaction.go b/spanner/transaction.go index 85de18327f89..81d7e036521c 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -1380,15 +1380,13 @@ func (t *ReadWriteTransaction) begin(ctx context.Context) error { }() // Retry the BeginTransaction call if a 'Session not found' is returned. for { - if sh == nil || sh.getID() == "" || sh.getClient() == nil { + tx, err = beginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), sh.getID(), sh.getClient(), t.txOpts) + if isSessionNotFoundError(err) { + sh.destroy() sh, err = t.sp.take(ctx) if err != nil { return err } - } - tx, err = beginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), sh.getID(), sh.getClient(), t.txOpts) - if isSessionNotFoundError(err) { - sh.destroy() continue } else { err = ToSpannerError(err) @@ -1399,7 +1397,7 @@ func (t *ReadWriteTransaction) begin(ctx context.Context) error { t.mu.Lock() t.tx = tx t.sh = sh - // State transite to txActive. + // Transition state to txActive. t.state = txActive t.mu.Unlock() }