Skip to content

Commit

Permalink
fix(spanner): Transaction was started in a different session (#8467)
Browse files Browse the repository at this point in the history
* fix: Transaction was started in a different session

Retrying a "Session not found" error could cause a "Transaction was
started in a different session" error. This happened because:
1. The detection of a "Session not found" error would remove the session
   from the pool, and also remove the session ID from the session handle
2. The retry mechanism would check out a new session from the pool, but
   not assign it to the transaction yet
3. The retry would then proceed to retry the transaction with an
   explicit BeginTransaction RPC. This function would however pick a new
   session from the pool, because step 2 had not yet assigned the
   transaction a new session.
4. The higher level retry loop would then after executing the
   BeginTransaction RPC assign the session that was picked in step 2 to
   the transaction.
5. The transaction would then proceed to use the session from step 2
   with the transaction from step 3.

* chore: remove unused code

* chore: fix import order
  • Loading branch information
olavloite committed Aug 23, 2023
1 parent 911f31e commit 6c21558
Show file tree
Hide file tree
Showing 5 changed files with 267 additions and 14 deletions.
6 changes: 5 additions & 1 deletion spanner/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,10 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea
}
}
if t.shouldExplicitBegin(attempt) {
// Make sure we set the current session handle before calling BeginTransaction.
// Note that the t.begin(ctx) call could change the session that is being used by the transaction, as the
// BeginTransaction RPC invocation will be retried on a new session if it returns SessionNotFound.
t.txReadOnly.sh = sh
if err = t.begin(ctx); err != nil {
trace.TracePrintf(ctx, nil, "Error while BeginTransaction during retrying a ReadWrite transaction: %v", ToSpannerError(err))
return ToSpannerError(err)
Expand All @@ -571,9 +575,9 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea
t = &ReadWriteTransaction{
txReadyOrClosed: make(chan struct{}),
}
t.txReadOnly.sh = sh
}
attempt++
t.txReadOnly.sh = sh
t.txReadOnly.sp = c.idleSessions
t.txReadOnly.txReadEnv = t
t.txReadOnly.qo = c.qo
Expand Down
196 changes: 196 additions & 0 deletions spanner/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,202 @@ func TestClient_ReadOnlyTransaction_SessionNotFoundOnBeginTransaction_WithMaxOne
}
}

func TestClient_ReadWriteTransaction_SessionNotFoundForFirstStatement(t *testing.T) {
ctx := context.Background()
server, client, teardown := setupMockedTestServer(t)
defer teardown()
server.TestSpanner.PutExecutionTime(
MethodExecuteStreamingSql,
SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
)

expectedAttempts := 2
var attempts int
_, err := client.ReadWriteTransaction(
ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
attempts++
iter := tx.Query(ctx, NewStatement(SelectFooFromBar))
defer iter.Stop()
for {
_, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
return err
}
}
return nil
})
if err != nil {
t.Fatal(err)
}
if expectedAttempts != attempts {
t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts)
}
requests := drainRequestsFromServer(server.TestSpanner)
if err := compareRequests([]interface{}{
&sppb.BatchCreateSessionsRequest{},
&sppb.ExecuteSqlRequest{},
&sppb.BeginTransactionRequest{},
&sppb.ExecuteSqlRequest{},
&sppb.CommitRequest{},
}, requests); err != nil {
t.Fatal(err)
}
}

func TestClient_ReadWriteTransaction_SessionNotFoundForFirstStatement_AndThenSessionNotFoundForBeginTransaction(t *testing.T) {
ctx := context.Background()
server, client, teardown := setupMockedTestServer(t)
defer teardown()
server.TestSpanner.PutExecutionTime(
MethodExecuteStreamingSql,
SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
)
server.TestSpanner.PutExecutionTime(
MethodBeginTransaction,
SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
)

expectedAttempts := 2
var attempts int
_, err := client.ReadWriteTransaction(
ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
attempts++
iter := tx.Query(ctx, NewStatement(SelectFooFromBar))
defer iter.Stop()
for {
_, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
return err
}
}
return nil
})
if err != nil {
t.Fatal(err)
}
if expectedAttempts != attempts {
t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts)
}
requests := drainRequestsFromServer(server.TestSpanner)
if err := compareRequests([]interface{}{
&sppb.BatchCreateSessionsRequest{},
&sppb.ExecuteSqlRequest{},
&sppb.BeginTransactionRequest{},
&sppb.BeginTransactionRequest{},
&sppb.ExecuteSqlRequest{},
&sppb.CommitRequest{},
}, requests); err != nil {
t.Fatal(err)
}
}

func TestClient_ReadWriteTransaction_AbortedForFirstStatement_AndThenSessionNotFoundForBeginTransaction(t *testing.T) {
ctx := context.Background()
server, client, teardown := setupMockedTestServer(t)
defer teardown()
server.TestSpanner.PutExecutionTime(
MethodExecuteStreamingSql,
SimulatedExecutionTime{Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}},
)
server.TestSpanner.PutExecutionTime(
MethodBeginTransaction,
SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
)

expectedAttempts := 2
var attempts int
_, err := client.ReadWriteTransaction(
ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
attempts++
iter := tx.Query(ctx, NewStatement(SelectFooFromBar))
defer iter.Stop()
for {
_, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
return err
}
}
return nil
})
if err != nil {
t.Fatal(err)
}
if expectedAttempts != attempts {
t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts)
}
requests := drainRequestsFromServer(server.TestSpanner)
if err := compareRequests([]interface{}{
&sppb.BatchCreateSessionsRequest{},
&sppb.ExecuteSqlRequest{},
&sppb.BeginTransactionRequest{},
&sppb.BeginTransactionRequest{},
&sppb.ExecuteSqlRequest{},
&sppb.CommitRequest{},
}, requests); err != nil {
t.Fatal(err)
}
}

func TestClient_ReadWriteTransaction_SessionNotFoundForFirstStatement_DoesNotLeakSession(t *testing.T) {
ctx := context.Background()
server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
SessionPoolConfig: SessionPoolConfig{
MinOpened: 1,
MaxOpened: 1,
},
})
defer teardown()
server.TestSpanner.PutExecutionTime(
MethodExecuteStreamingSql,
SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
)

expectedAttempts := 2
var attempts int
_, err := client.ReadWriteTransaction(
ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
attempts++
iter := tx.Query(ctx, NewStatement(SelectFooFromBar))
defer iter.Stop()
for {
_, err := iter.Next()
if err == iterator.Done {
break
}
if err != nil {
return err
}
}
return nil
})
if err != nil {
t.Fatal(err)
}
if expectedAttempts != attempts {
t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts)
}
requests := drainRequestsFromServer(server.TestSpanner)
if err := compareRequests([]interface{}{
&sppb.BatchCreateSessionsRequest{},
&sppb.ExecuteSqlRequest{},
&sppb.BatchCreateSessionsRequest{}, // We need to create more sessions, as the one used first was destroyed.
&sppb.BeginTransactionRequest{},
&sppb.ExecuteSqlRequest{},
&sppb.CommitRequest{},
}, requests); err != nil {
t.Fatal(err)
}
}

func TestClient_ReadOnlyTransaction_QueryOptions(t *testing.T) {
for _, tt := range queryOptionsTestCases() {
t.Run(tt.name, func(t *testing.T) {
Expand Down
51 changes: 51 additions & 0 deletions spanner/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@ import (
adminpb "cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
instance "cloud.google.com/go/spanner/admin/instance/apiv1"
"cloud.google.com/go/spanner/admin/instance/apiv1/instancepb"
v1 "cloud.google.com/go/spanner/apiv1"
sppb "cloud.google.com/go/spanner/apiv1/spannerpb"
"cloud.google.com/go/spanner/internal"
"go.opencensus.io/stats/view"
"go.opencensus.io/tag"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"google.golang.org/api/option/internaloption"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
Expand Down Expand Up @@ -846,6 +848,55 @@ func TestIntegration_SingleUse_WithQueryOptions(t *testing.T) {
}
}

func TestIntegration_TransactionWasStartedInDifferentSession(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// Set up testing environment.
client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements])
defer cleanup()

attempts := 0
_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, transaction *ReadWriteTransaction) error {
attempts++
if attempts == 1 {
deleteTestSession(ctx, t, transaction.sh.getID())
}
if _, err := readAll(transaction.Query(ctx, NewStatement("select * from singers"))); err != nil {
return err
}
return nil
})
if err != nil {
t.Fatal(err)
}
if g, w := attempts, 2; g != w {
t.Fatalf("attempts mismatch\nGot: %v\nWant: %v", g, w)
}
}

func deleteTestSession(ctx context.Context, t *testing.T, sessionName string) {
var opts []option.ClientOption
if emulatorAddr := os.Getenv("SPANNER_EMULATOR_HOST"); emulatorAddr != "" {
emulatorOpts := []option.ClientOption{
option.WithEndpoint(emulatorAddr),
option.WithGRPCDialOption(grpc.WithInsecure()),
option.WithoutAuthentication(),
internaloption.SkipDialSettingsValidation(),
}
opts = append(emulatorOpts, opts...)
}
gapic, err := v1.NewClient(ctx, opts...)
if err != nil {
t.Fatalf("could not create gapic client: %v", err)
}
defer gapic.Close()
if err := gapic.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: sessionName}); err != nil {
t.Fatal(err)
}
}

func TestIntegration_SingleUse_ReadingWithLimit(t *testing.T) {
t.Parallel()

Expand Down
18 changes: 11 additions & 7 deletions spanner/internal/testutil/inmem_spanner_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -581,13 +581,17 @@ func (s *inMemSpannerServer) beginTransaction(session *spannerpb.Session, option
return res
}

func (s *inMemSpannerServer) getTransactionByID(id []byte) (*spannerpb.Transaction, error) {
func (s *inMemSpannerServer) getTransactionByID(session *spannerpb.Session, id []byte) (*spannerpb.Transaction, error) {
s.mu.Lock()
defer s.mu.Unlock()
tx, ok := s.transactions[string(id)]
if !ok {
return nil, gstatus.Error(codes.NotFound, "Transaction not found")
}
if !strings.HasPrefix(string(id), session.Name) {
return nil, gstatus.Error(codes.InvalidArgument, "Transaction was started in a different session.")
}

aborted, ok := s.abortedTransactions[string(id)]
if ok && aborted {
return nil, newAbortedErrorWithMinimalRetryDelay()
Expand Down Expand Up @@ -813,7 +817,7 @@ func (s *inMemSpannerServer) ExecuteSql(ctx context.Context, req *spannerpb.Exec
var id []byte
s.updateSessionLastUseTime(session.Name)
if id = s.getTransactionID(session, req.Transaction); id != nil {
_, err = s.getTransactionByID(id)
_, err = s.getTransactionByID(session, id)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -860,7 +864,7 @@ func (s *inMemSpannerServer) executeStreamingSQL(req *spannerpb.ExecuteSqlReques
s.updateSessionLastUseTime(session.Name)
var id []byte
if id = s.getTransactionID(session, req.Transaction); id != nil {
_, err = s.getTransactionByID(id)
_, err = s.getTransactionByID(session, id)
if err != nil {
return err
}
Expand Down Expand Up @@ -932,7 +936,7 @@ func (s *inMemSpannerServer) ExecuteBatchDml(ctx context.Context, req *spannerpb
s.updateSessionLastUseTime(session.Name)
var id []byte
if id = s.getTransactionID(session, req.Transaction); id != nil {
_, err = s.getTransactionByID(id)
_, err = s.getTransactionByID(session, id)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1031,7 +1035,7 @@ func (s *inMemSpannerServer) Commit(ctx context.Context, req *spannerpb.CommitRe
if req.GetSingleUseTransaction() != nil {
tx = s.beginTransaction(session, req.GetSingleUseTransaction())
} else if req.GetTransactionId() != nil {
tx, err = s.getTransactionByID(req.GetTransactionId())
tx, err = s.getTransactionByID(session, req.GetTransactionId())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1064,7 +1068,7 @@ func (s *inMemSpannerServer) Rollback(ctx context.Context, req *spannerpb.Rollba
return nil, err
}
s.updateSessionLastUseTime(session.Name)
tx, err := s.getTransactionByID(req.TransactionId)
tx, err := s.getTransactionByID(session, req.TransactionId)
if err != nil {
return nil, err
}
Expand All @@ -1091,7 +1095,7 @@ func (s *inMemSpannerServer) PartitionQuery(ctx context.Context, req *spannerpb.
var tx *spannerpb.Transaction
s.updateSessionLastUseTime(session.Name)
if id = s.getTransactionID(session, req.Transaction); id != nil {
tx, err = s.getTransactionByID(id)
tx, err = s.getTransactionByID(session, id)
if err != nil {
return nil, err
}
Expand Down
10 changes: 4 additions & 6 deletions spanner/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -1380,15 +1380,13 @@ func (t *ReadWriteTransaction) begin(ctx context.Context) error {
}()
// Retry the BeginTransaction call if a 'Session not found' is returned.
for {
if sh == nil || sh.getID() == "" || sh.getClient() == nil {
tx, err = beginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), sh.getID(), sh.getClient(), t.txOpts)
if isSessionNotFoundError(err) {
sh.destroy()
sh, err = t.sp.take(ctx)
if err != nil {
return err
}
}
tx, err = beginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), sh.getID(), sh.getClient(), t.txOpts)
if isSessionNotFoundError(err) {
sh.destroy()
continue
} else {
err = ToSpannerError(err)
Expand All @@ -1399,7 +1397,7 @@ func (t *ReadWriteTransaction) begin(ctx context.Context) error {
t.mu.Lock()
t.tx = tx
t.sh = sh
// State transite to txActive.
// Transition state to txActive.
t.state = txActive
t.mu.Unlock()
}
Expand Down

0 comments on commit 6c21558

Please sign in to comment.