From fe4e66c3e166a92c108bde2d7e62998ece7cf8f3 Mon Sep 17 00:00:00 2001 From: Olav Loite Date: Sun, 18 Aug 2019 19:45:07 +0200 Subject: [PATCH] spanner: use the standard GAX retryer for stream Use a standard GAX retryer for resumableStreamDecoder and only retry on standard gRPC codes. Removes the custom error checks that were used by this decoder. Updates #1418. Change-Id: I8f339f31cf71fe3e5f9aebcb685b5444c8aa56b8 Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/44173 Reviewed-by: kokoro Reviewed-by: Jean de Klerk --- spanner/client_test.go | 264 ++++++++++++++---- .../internal/testutil/inmem_spanner_server.go | 107 +++++-- .../testutil/inmem_spanner_server_test.go | 58 ++-- spanner/read.go | 123 ++++---- spanner/read_test.go | 18 +- spanner/retry.go | 90 ------ 6 files changed, 400 insertions(+), 260 deletions(-) diff --git a/spanner/client_test.go b/spanner/client_test.go index fa36de67f9fe..9e723588c695 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -22,13 +22,14 @@ import ( "io" "strings" "testing" + "time" itestutil "cloud.google.com/go/internal/testutil" . "cloud.google.com/go/spanner/internal/testutil" "google.golang.org/api/iterator" "google.golang.org/api/option" "google.golang.org/grpc/codes" - gstatus "google.golang.org/grpc/status" + "google.golang.org/grpc/status" ) func setupMockedTestServer(t *testing.T) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) { @@ -110,7 +111,7 @@ func TestClient_Single(t *testing.T) { func TestClient_Single_Unavailable(t *testing.T) { t.Parallel() - err := testSingleQuery(t, gstatus.Error(codes.Unavailable, "Temporary unavailable")) + err := testSingleQuery(t, status.Error(codes.Unavailable, "Temporary unavailable")) if err != nil { t.Fatal(err) } @@ -118,14 +119,176 @@ func TestClient_Single_Unavailable(t *testing.T) { func TestClient_Single_InvalidArgument(t *testing.T) { t.Parallel() - err := testSingleQuery(t, gstatus.Error(codes.InvalidArgument, "Invalid argument")) - if err == nil { - t.Fatalf("missing expected error") - } else if gstatus.Code(err) != codes.InvalidArgument { + err := testSingleQuery(t, status.Error(codes.InvalidArgument, "Invalid argument")) + if status.Code(err) != codes.InvalidArgument { + t.Fatalf("got unexpected exception %v, expected InvalidArgument", err) + } +} + +func TestClient_Single_RetryableErrorOnPartialResultSet(t *testing.T) { + t.Parallel() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + + // Add two errors that will be returned by the mock server when the client + // is trying to fetch a partial result set. Both errors are retryable. + // The errors are not 'sticky' on the mocked server, i.e. once the error + // has been returned once, the next call for the same partial result set + // will succeed. + + // When the client is fetching the partial result set with resume token 2, + // the mock server will respond with an internal error with the message + // 'stream terminated by RST_STREAM'. The client will retry the call to get + // this partial result set. + server.TestSpanner.AddPartialResultSetError( + SelectSingerIDAlbumIDAlbumTitleFromAlbums, + PartialResultSetExecutionTime{ + ResumeToken: EncodeResumeToken(2), + Err: spannerErrorf(codes.Internal, "stream terminated by RST_STREAM"), + }, + ) + // When the client is fetching the partial result set with resume token 3, + // the mock server will respond with a 'Unavailable' error. The client will + // retry the call to get this partial result set. 
+ server.TestSpanner.AddPartialResultSetError( + SelectSingerIDAlbumIDAlbumTitleFromAlbums, + PartialResultSetExecutionTime{ + ResumeToken: EncodeResumeToken(3), + Err: spannerErrorf(codes.Unavailable, "server is unavailable"), + }, + ) + ctx := context.Background() + if err := executeSingerQuery(ctx, client.Single()); err != nil { t.Fatal(err) } } +func TestClient_Single_NonRetryableErrorOnPartialResultSet(t *testing.T) { + t.Parallel() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + + // Add two errors that will be returned by the mock server when the client + // is trying to fetch a partial result set. The first error is retryable, + // the second is not. + + // This error will automatically be retried. + server.TestSpanner.AddPartialResultSetError( + SelectSingerIDAlbumIDAlbumTitleFromAlbums, + PartialResultSetExecutionTime{ + ResumeToken: EncodeResumeToken(2), + Err: spannerErrorf(codes.Internal, "stream terminated by RST_STREAM"), + }, + ) + // 'Session not found' is not retryable and the error will be returned to + // the user. + server.TestSpanner.AddPartialResultSetError( + SelectSingerIDAlbumIDAlbumTitleFromAlbums, + PartialResultSetExecutionTime{ + ResumeToken: EncodeResumeToken(3), + Err: spannerErrorf(codes.NotFound, "Session not found"), + }, + ) + ctx := context.Background() + err := executeSingerQuery(ctx, client.Single()) + if status.Code(err) != codes.NotFound { + t.Fatalf("Error mismatch:\ngot: %v\nwant: %v", err, codes.NotFound) + } +} + +func TestClient_Single_DeadlineExceeded_NoErrors(t *testing.T) { + t.Parallel() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql, + SimulatedExecutionTime{ + MinimumExecutionTime: 50 * time.Millisecond, + }) + ctx := context.Background() + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(5*time.Millisecond)) + defer cancel() + err := executeSingerQuery(ctx, client.Single()) + if status.Code(err) != codes.DeadlineExceeded { + t.Fatalf("Error mismatch:\ngot: %v\nwant: %v", err, codes.DeadlineExceeded) + } +} + +func TestClient_Single_DeadlineExceeded_WithErrors(t *testing.T) { + t.Parallel() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.AddPartialResultSetError( + SelectSingerIDAlbumIDAlbumTitleFromAlbums, + PartialResultSetExecutionTime{ + ResumeToken: EncodeResumeToken(2), + Err: spannerErrorf(codes.Internal, "stream terminated by RST_STREAM"), + }, + ) + server.TestSpanner.AddPartialResultSetError( + SelectSingerIDAlbumIDAlbumTitleFromAlbums, + PartialResultSetExecutionTime{ + ResumeToken: EncodeResumeToken(3), + Err: spannerErrorf(codes.Unavailable, "server is unavailable"), + ExecutionTime: 50 * time.Millisecond, + }, + ) + ctx := context.Background() + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(25*time.Millisecond)) + defer cancel() + err := executeSingerQuery(ctx, client.Single()) + if status.Code(err) != codes.DeadlineExceeded { + t.Fatalf("got unexpected error %v, expected DeadlineExceeded", err) + } +} + +func TestClient_Single_ContextCanceled_noDeclaredServerErrors(t *testing.T) { + t.Parallel() + _, client, teardown := setupMockedTestServer(t) + defer teardown() + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + cancel() + err := executeSingerQuery(ctx, client.Single()) + if status.Code(err) != codes.Canceled { + t.Fatalf("got unexpected error %v, expected Canceled", err) + } +} + +func 
TestClient_Single_ContextCanceled_withDeclaredServerErrors(t *testing.T) { + t.Parallel() + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.AddPartialResultSetError( + SelectSingerIDAlbumIDAlbumTitleFromAlbums, + PartialResultSetExecutionTime{ + ResumeToken: EncodeResumeToken(2), + Err: spannerErrorf(codes.Internal, "stream terminated by RST_STREAM"), + }, + ) + server.TestSpanner.AddPartialResultSetError( + SelectSingerIDAlbumIDAlbumTitleFromAlbums, + PartialResultSetExecutionTime{ + ResumeToken: EncodeResumeToken(3), + Err: spannerErrorf(codes.Unavailable, "server is unavailable"), + }, + ) + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + f := func(rowCount int64) error { + if rowCount == 2 { + cancel() + } + return nil + } + iter := client.Single().Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) + defer iter.Stop() + err := executeSingerQueryWithRowFunc(ctx, client.Single(), f) + if status.Code(err) != codes.Canceled { + t.Fatalf("got unexpected error %v, expected Canceled", err) + } +} + func testSingleQuery(t *testing.T, serverError error) error { ctx := context.Background() server, client, teardown := setupMockedTestServer(t) @@ -133,8 +296,17 @@ func testSingleQuery(t *testing.T, serverError error) error { if serverError != nil { server.TestSpanner.SetError(serverError) } - iter := client.Single().Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) + return executeSingerQuery(ctx, client.Single()) +} + +func executeSingerQuery(ctx context.Context, tx *ReadOnlyTransaction) error { + return executeSingerQueryWithRowFunc(ctx, tx, nil) +} + +func executeSingerQueryWithRowFunc(ctx context.Context, tx *ReadOnlyTransaction, f func(rowCount int64) error) error { + iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) defer iter.Stop() + rowCount := int64(0) for { row, err := iter.Next() if err == iterator.Done { @@ -148,14 +320,23 @@ func testSingleQuery(t *testing.T, serverError error) error { if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil { return err } + rowCount++ + if f != nil { + if err := f(rowCount); err != nil { + return err + } + } + } + if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount { + return spannerErrorf(codes.Internal, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount) } return nil } func createSimulatedExecutionTimeWithTwoUnavailableErrors(method string) map[string]SimulatedExecutionTime { errors := make([]error, 2) - errors[0] = gstatus.Error(codes.Unavailable, "Temporary unavailable") - errors[1] = gstatus.Error(codes.Unavailable, "Temporary unavailable") + errors[0] = status.Error(codes.Unavailable, "Temporary unavailable") + errors[1] = status.Error(codes.Unavailable, "Temporary unavailable") executionTimes := make(map[string]SimulatedExecutionTime) executionTimes[method] = SimulatedExecutionTime{ Errors: errors, @@ -194,8 +375,8 @@ func TestClient_ReadOnlyTransaction_UnavailableOnExecuteStreamingSql(t *testing. 
func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndBeginTransaction(t *testing.T) { t.Parallel() exec := map[string]SimulatedExecutionTime{ - MethodCreateSession: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, - MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, + MethodCreateSession: {Errors: []error{status.Error(codes.Unavailable, "Temporary unavailable")}}, + MethodBeginTransaction: {Errors: []error{status.Error(codes.Unavailable, "Temporary unavailable")}}, } if err := testReadOnlyTransaction(t, exec); err != nil { t.Fatal(err) @@ -205,12 +386,12 @@ func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndBeginTransactio func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndInvalidArgumentOnBeginTransaction(t *testing.T) { t.Parallel() exec := map[string]SimulatedExecutionTime{ - MethodCreateSession: {Errors: []error{gstatus.Error(codes.Unavailable, "Temporary unavailable")}}, - MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.InvalidArgument, "Invalid argument")}}, + MethodCreateSession: {Errors: []error{status.Error(codes.Unavailable, "Temporary unavailable")}}, + MethodBeginTransaction: {Errors: []error{status.Error(codes.InvalidArgument, "Invalid argument")}}, } if err := testReadOnlyTransaction(t, exec); err == nil { t.Fatalf("Missing expected exception") - } else if gstatus.Code(err) != codes.InvalidArgument { + } else if status.Code(err) != codes.InvalidArgument { t.Fatalf("Got unexpected exception: %v", err) } } @@ -224,23 +405,7 @@ func testReadOnlyTransaction(t *testing.T, executionTimes map[string]SimulatedEx tx := client.ReadOnlyTransaction() defer tx.Close() ctx := context.Background() - iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) - defer iter.Stop() - for { - row, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - return err - } - var singerID, albumID int64 - var albumTitle string - if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil { - return err - } - } - return nil + return executeSingerQuery(ctx, tx) } func TestClient_ReadWriteTransaction(t *testing.T) { @@ -253,7 +418,7 @@ func TestClient_ReadWriteTransaction(t *testing.T) { func TestClient_ReadWriteTransactionCommitAborted(t *testing.T) { t.Parallel() if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ - MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}}, + MethodCommitTransaction: {Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}}, }, 2); err != nil { t.Fatal(err) } @@ -262,7 +427,7 @@ func TestClient_ReadWriteTransactionCommitAborted(t *testing.T) { func TestClient_ReadWriteTransactionExecuteStreamingSqlAborted(t *testing.T) { t.Parallel() if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ - MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}}, + MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}}, }, 2); err != nil { t.Fatal(err) } @@ -271,7 +436,7 @@ func TestClient_ReadWriteTransactionExecuteStreamingSqlAborted(t *testing.T) { func TestClient_ReadWriteTransaction_UnavailableOnBeginTransaction(t *testing.T) { t.Parallel() if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ - MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + MethodBeginTransaction: 
{Errors: []error{status.Error(codes.Unavailable, "Unavailable")}}, }, 1); err != nil { t.Fatal(err) } @@ -279,8 +444,8 @@ func TestClient_ReadWriteTransaction_UnavailableOnBeginTransaction(t *testing.T) func TestClient_ReadWriteTransaction_UnavailableOnBeginAndAbortOnCommit(t *testing.T) { if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ - MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, - MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}}, + MethodBeginTransaction: {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}}, + MethodCommitTransaction: {Errors: []error{status.Error(codes.Aborted, "Aborted")}}, }, 2); err != nil { t.Fatal(err) } @@ -289,7 +454,7 @@ func TestClient_ReadWriteTransaction_UnavailableOnBeginAndAbortOnCommit(t *testi func TestClient_ReadWriteTransaction_UnavailableOnExecuteStreamingSql(t *testing.T) { t.Parallel() if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ - MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, + MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}}, }, 1); err != nil { t.Fatal(err) } @@ -298,9 +463,9 @@ func TestClient_ReadWriteTransaction_UnavailableOnExecuteStreamingSql(t *testing func TestClient_ReadWriteTransaction_UnavailableOnBeginAndExecuteStreamingSqlAndTwiceAbortOnCommit(t *testing.T) { t.Parallel() if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ - MethodBeginTransaction: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, - MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Unavailable, "Unavailable")}}, - MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted"), gstatus.Error(codes.Aborted, "Aborted")}}, + MethodBeginTransaction: {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}}, + MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}}, + MethodCommitTransaction: {Errors: []error{status.Error(codes.Aborted, "Aborted"), status.Error(codes.Aborted, "Aborted")}}, }, 3); err != nil { t.Fatal(err) } @@ -309,8 +474,8 @@ func TestClient_ReadWriteTransaction_UnavailableOnBeginAndExecuteStreamingSqlAnd func TestClient_ReadWriteTransaction_AbortedOnExecuteStreamingSqlAndCommit(t *testing.T) { t.Parallel() if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ - MethodExecuteStreamingSql: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted")}}, - MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.Aborted, "Aborted"), gstatus.Error(codes.Aborted, "Aborted")}}, + MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Aborted, "Aborted")}}, + MethodCommitTransaction: {Errors: []error{status.Error(codes.Aborted, "Aborted"), status.Error(codes.Aborted, "Aborted")}}, }, 4); err != nil { t.Fatal(err) } @@ -321,8 +486,8 @@ func TestClient_ReadWriteTransactionCommitAbortedAndUnavailable(t *testing.T) { if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ MethodCommitTransaction: { Errors: []error{ - gstatus.Error(codes.Aborted, "Transaction aborted"), - gstatus.Error(codes.Unavailable, "Unavailable"), + status.Error(codes.Aborted, "Transaction aborted"), + status.Error(codes.Unavailable, "Unavailable"), }, }, }, 2); err != nil { @@ -333,9 +498,9 @@ func TestClient_ReadWriteTransactionCommitAbortedAndUnavailable(t *testing.T) { func 
TestClient_ReadWriteTransactionCommitAlreadyExists(t *testing.T) { t.Parallel() if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{ - MethodCommitTransaction: {Errors: []error{gstatus.Error(codes.AlreadyExists, "A row with this key already exists")}}, + MethodCommitTransaction: {Errors: []error{status.Error(codes.AlreadyExists, "A row with this key already exists")}}, }, 1); err != nil { - if gstatus.Code(err) != codes.AlreadyExists { + if status.Code(err) != codes.AlreadyExists { t.Fatalf("Got unexpected error %v, expected %v", err, codes.AlreadyExists) } } else { @@ -355,6 +520,7 @@ func testReadWriteTransaction(t *testing.T, executionTimes map[string]SimulatedE attempts++ iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) defer iter.Stop() + rowCount := int64(0) for { row, err := iter.Next() if err == iterator.Done { @@ -368,6 +534,10 @@ func testReadWriteTransaction(t *testing.T, executionTimes map[string]SimulatedE if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil { return err } + rowCount++ + } + if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount { + return spannerErrorf(codes.FailedPrecondition, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount) } return nil }) @@ -390,7 +560,7 @@ func TestClient_ApplyAtLeastOnce(t *testing.T) { } server.TestSpanner.PutExecutionTime(MethodCommitTransaction, SimulatedExecutionTime{ - Errors: []error{gstatus.Error(codes.Aborted, "Transaction aborted")}, + Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}, }) _, err := client.Apply(context.Background(), ms, ApplyAtLeastOnce()) if err != nil { diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index dbd0d2fa7f31..965e5c29d66a 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -15,6 +15,7 @@ package testutil_test import ( + "bytes" "context" "fmt" "math/rand" @@ -29,7 +30,6 @@ import ( "google.golang.org/genproto/googleapis/rpc/status" spannerpb "google.golang.org/genproto/googleapis/spanner/v1" "google.golang.org/grpc/codes" - gstatus "google.golang.org/grpc/status" ) @@ -46,6 +46,10 @@ const ( // StatementResultUpdateCount indicates that the sql statement returns an // update count. StatementResultUpdateCount StatementResultType = 2 + // MaxRowsPerPartialResultSet is the maximum number of rows returned in + // each PartialResultSet. This number is deliberately set to a low value to + // ensure that most queries return more than one PartialResultSet. + MaxRowsPerPartialResultSet = 1 ) // The method names that can be used to register execution times and errors. @@ -59,8 +63,8 @@ const ( MethodExecuteStreamingSql string = "EXECUTE_STREAMING_SQL" ) -// StatementResult represents a mocked result on the test server. Th result can -// be either a ResultSet, an update count or an error. +// StatementResult represents a mocked result on the test server. The result is +// either of: a ResultSet, an update count or an error. type StatementResult struct { Type StatementResultType Err error @@ -68,23 +72,57 @@ type StatementResult struct { UpdateCount int64 } +// PartialResultSetExecutionTime represents execution times and errors that +// should be used when a PartialResult at the specified resume token is to +// be returned. 
+type PartialResultSetExecutionTime struct { + ResumeToken []byte + ExecutionTime time.Duration + Err error +} + // Converts a ResultSet to a PartialResultSet. This method is used to convert // a mocked result to a PartialResultSet when one of the streaming methods are // called. -func (s *StatementResult) toPartialResultSet() *spannerpb.PartialResultSet { - values := make([]*structpb.Value, - len(s.ResultSet.Rows)*len(s.ResultSet.Metadata.RowType.Fields)) - var idx int - for _, row := range s.ResultSet.Rows { - for colIdx := range s.ResultSet.Metadata.RowType.Fields { - values[idx] = row.Values[colIdx] - idx++ +func (s *StatementResult) toPartialResultSets(resumeToken []byte) (result []*spannerpb.PartialResultSet, err error) { + var startIndex uint64 + if len(resumeToken) > 0 { + if startIndex, err = DecodeResumeToken(resumeToken); err != nil { + return nil, err } } - return &spannerpb.PartialResultSet{ - Metadata: s.ResultSet.Metadata, - Values: values, + + totalRows := uint64(len(s.ResultSet.Rows)) + for { + rowCount := min(totalRows-startIndex, uint64(MaxRowsPerPartialResultSet)) + rows := s.ResultSet.Rows[startIndex : startIndex+rowCount] + values := make([]*structpb.Value, + len(rows)*len(s.ResultSet.Metadata.RowType.Fields)) + var idx int + for _, row := range rows { + for colIdx := range s.ResultSet.Metadata.RowType.Fields { + values[idx] = row.Values[colIdx] + idx++ + } + } + result = append(result, &spannerpb.PartialResultSet{ + Metadata: s.ResultSet.Metadata, + Values: values, + ResumeToken: EncodeResumeToken(startIndex + rowCount), + }) + startIndex += rowCount + if startIndex == totalRows { + break + } + } + return result, nil +} + +func min(x, y uint64) uint64 { + if x > y { + return y } + return x } func (s *StatementResult) updateCountToPartialResultSet(exact bool) *spannerpb.PartialResultSet { @@ -148,6 +186,10 @@ type InMemSpannerServer interface { // expect a SQL statement, including (batch) DML methods. PutStatementResult(sql string, result *StatementResult) error + // Adds a PartialResultSetExecutionTime to the server that should be returned + // for the specified SQL string. + AddPartialResultSetError(sql string, err PartialResultSetExecutionTime) + // Removes a mocked result on the server for a specific sql statement. RemoveStatementResult(sql string) @@ -201,7 +243,10 @@ type inMemSpannerServer struct { // The mocked results for this server. statementResults map[string]*StatementResult // The simulated execution times per method. - executionTimes map[string]*SimulatedExecutionTime + executionTimes map[string]*SimulatedExecutionTime + // The simulated errors for partial result sets + partialResultSetErrors map[string][]*PartialResultSetExecutionTime + totalSessionsCreated uint totalSessionsDeleted uint receivedRequests chan interface{} @@ -218,6 +263,7 @@ func NewInMemSpannerServer() InMemSpannerServer { res.initDefaults() res.statementResults = make(map[string]*StatementResult) res.executionTimes = make(map[string]*SimulatedExecutionTime) + res.partialResultSetErrors = make(map[string][]*PartialResultSetExecutionTime) res.receivedRequests = make(chan interface{}, 1000000) // Produce a closed channel, so the default action of ready is to not block. 
res.Freeze() @@ -275,6 +321,12 @@ func (s *inMemSpannerServer) PutExecutionTime(method string, executionTime Simul s.executionTimes[method] = &executionTime } +func (s *inMemSpannerServer) AddPartialResultSetError(sql string, partialResultSetError PartialResultSetExecutionTime) { + s.mu.Lock() + defer s.mu.Unlock() + s.partialResultSetErrors[sql] = append(s.partialResultSetErrors[sql], &partialResultSetError) +} + // Freeze stalls all requests. func (s *inMemSpannerServer) Freeze() { s.mu.Lock() @@ -628,10 +680,31 @@ func (s *inMemSpannerServer) ExecuteStreamingSql(req *spannerpb.ExecuteSqlReques case StatementResultError: return statementResult.Err case StatementResultResultSet: - part := statementResult.toPartialResultSet() - if err := stream.Send(part); err != nil { + parts, err := statementResult.toPartialResultSets(req.ResumeToken) + if err != nil { return err } + var nextPartialResultSetError *PartialResultSetExecutionTime + s.mu.Lock() + pErrors := s.partialResultSetErrors[req.Sql] + if len(pErrors) > 0 { + nextPartialResultSetError = pErrors[0] + s.partialResultSetErrors[req.Sql] = pErrors[1:] + } + s.mu.Unlock() + for _, part := range parts { + if nextPartialResultSetError != nil && bytes.Equal(part.ResumeToken, nextPartialResultSetError.ResumeToken) { + if nextPartialResultSetError.ExecutionTime > 0 { + <-time.After(nextPartialResultSetError.ExecutionTime) + } + if nextPartialResultSetError.Err != nil { + return nextPartialResultSetError.Err + } + } + if err := stream.Send(part); err != nil { + return err + } + } return nil case StatementResultUpdateCount: part := statementResult.updateCountToPartialResultSet(!isPartitionedDml) diff --git a/spanner/internal/testutil/inmem_spanner_server_test.go b/spanner/internal/testutil/inmem_spanner_server_test.go index 6e7d66076c23..08c502455e86 100644 --- a/spanner/internal/testutil/inmem_spanner_server_test.go +++ b/spanner/internal/testutil/inmem_spanner_server_test.go @@ -134,7 +134,7 @@ func TestSpannerCreateSession(t *testing.T) { t.Fatal(err) } if strings.Index(resp.Name, expectedName) != 0 { - t.Errorf("wrong name %s, should start with %s)", resp.Name, expectedName) + t.Errorf("Session name mismatch\nGot: %s\nWant: Name should start with %s)", resp.Name, expectedName) } } @@ -156,7 +156,7 @@ func TestSpannerCreateSession_Unavailable(t *testing.T) { t.Fatal(err) } if strings.Index(resp.Name, expectedName) != 0 { - t.Errorf("wrong name %s, should start with %s)", resp.Name, expectedName) + t.Errorf("Session name mismatch\nGot: %s\nWant: Name should start with %s)", resp.Name, expectedName) } } @@ -183,7 +183,7 @@ func TestSpannerGetSession(t *testing.T) { t.Fatal(err) } if getResp.Name != getRequest.Name { - t.Errorf("wrong name %s, expected %s)", getResp.Name, getRequest.Name) + t.Errorf("Session name mismatch\nGot: %s\nWant: Name should start with %s)", getResp.Name, getRequest.Name) } } @@ -220,12 +220,12 @@ func TestSpannerListSessions(t *testing.T) { t.Fatal(err) } if strings.Index(session.Name, expectedName) != 0 { - t.Errorf("wrong name %s, should start with %s)", session.Name, expectedName) + t.Errorf("Session name mismatch\nGot: %s\nWant: Name should start with %s)", session.Name, expectedName) } sessionCount++ } if sessionCount != expectedNumberOfSessions { - t.Errorf("wrong number of sessions: %d, expected %d", sessionCount, expectedNumberOfSessions) + t.Errorf("Session count mismatch\nGot: %d\nWant: %d", sessionCount, expectedNumberOfSessions) } } @@ -267,7 +267,7 @@ func TestSpannerDeleteSession(t *testing.T) { 
sessionCount++ } if sessionCount != expectedNumberOfSessions { - t.Errorf("wrong number of sessions: %d, expected %d", sessionCount, expectedNumberOfSessions) + t.Errorf("Session count mismatch\nGot: %d\nWant: %d", sessionCount, expectedNumberOfSessions) } // Re-list all sessions. This should now be empty. listResp = c.ListSessions(context.Background(), listRequest) @@ -319,12 +319,12 @@ func TestSpannerExecuteSql(t *testing.T) { var rowCount int64 for _, row := range response.Rows { if len(row.Values) != selectColCount { - t.Fatalf("unexpected number of columns: %d, expected %d", len(row.Values), selectColCount) + t.Fatalf("Column count mismatch\nGot: %d\nWant: %d", len(row.Values), selectColCount) } rowCount++ } if rowCount != selectRowCount { - t.Fatalf("unexpected number of rows: %d, expected %d", rowCount, selectRowCount) + t.Fatalf("Row count mismatch\nGot: %d\nWant: %d", rowCount, selectRowCount) } } @@ -364,7 +364,7 @@ func TestSpannerExecuteSqlDml(t *testing.T) { } var rowCount int64 = response.Stats.GetRowCountExact() if rowCount != updateRowCount { - t.Fatalf("unexpected number of rows updated: %d, expected %d", rowCount, updateRowCount) + t.Fatalf("Update count mismatch\nGot: %d\nWant: %d", rowCount, updateRowCount) } } @@ -407,32 +407,38 @@ func TestSpannerExecuteStreamingSql(t *testing.T) { if err != nil { t.Fatal(err) } - partial, err := response.Recv() - if err != nil { - t.Fatal(err) - } var rowIndex int64 - colCount := len(partial.Metadata.RowType.Fields) - if colCount != selectColCount { - t.Fatalf("unexpected number of columns: %d, expected %d", colCount, selectColCount) - } + var colCount int for { - for col := 0; col < colCount; col++ { - val, err := strconv.ParseInt(partial.Values[rowIndex*int64(colCount)+int64(col)].GetStringValue(), 10, 64) + for rowIndexInPartial := int64(0); rowIndexInPartial < MaxRowsPerPartialResultSet; rowIndexInPartial++ { + partial, err := response.Recv() if err != nil { t.Fatal(err) } - if val != selectValues[rowIndex] { - t.Fatalf("Unexpected value at index %d. 
Expected %d, got %d", rowIndex, selectValues[rowIndex], val) + if rowIndex == 0 { + colCount = len(partial.Metadata.RowType.Fields) + if colCount != selectColCount { + t.Fatalf("Column count mismatch\nGot: %d\nWant: %d", colCount, selectColCount) + } + } + for col := 0; col < colCount; col++ { + pIndex := rowIndexInPartial*int64(colCount) + int64(col) + val, err := strconv.ParseInt(partial.Values[pIndex].GetStringValue(), 10, 64) + if err != nil { + t.Fatalf("Error parsing integer at #%d: %v", pIndex, err) + } + if val != selectValues[rowIndex] { + t.Fatalf("Value mismatch at index %d\nGot: %d\nWant: %d", rowIndex, val, selectValues[rowIndex]) + } } + rowIndex++ } - rowIndex++ if rowIndex == selectRowCount { break } } if rowIndex != selectRowCount { - t.Fatalf("unexpected number of rows: %d, expected %d", rowIndex, selectRowCount) + t.Fatalf("Row count mismatch\nGot: %d\nWant: %d", rowIndex, selectRowCount) } } @@ -477,12 +483,12 @@ func TestSpannerExecuteBatchDml(t *testing.T) { for _, res := range response.ResultSets { var rowCount int64 = res.Stats.GetRowCountExact() if rowCount != updateRowCount { - t.Fatalf("unexpected number of rows updated: %d, expected %d", rowCount, updateRowCount) + t.Fatalf("Update count mismatch\nGot: %d\nWant: %d", rowCount, updateRowCount) } totalRowCount += rowCount } if totalRowCount != updateRowCount*int64(len(statements)) { - t.Fatalf("unexpected number of total rows updated: %d, expected %d", totalRowCount, updateRowCount*int64(len(statements))) + t.Fatalf("Total update count mismatch\nGot: %d\nWant: %d", totalRowCount, updateRowCount*int64(len(statements))) } } @@ -515,7 +521,7 @@ func TestBeginTransaction(t *testing.T) { } expectedName := fmt.Sprintf("%s/transactions/", session.Name) if strings.Index(string(tx.Id), expectedName) != 0 { - t.Errorf("wrong name %s, should start with %s)", string(tx.Id), expectedName) + t.Errorf("Transaction name mismatch\nGot: %s\nWant: Name should start with %s)", string(tx.Id), expectedName) } } diff --git a/spanner/read.go b/spanner/read.go index b934f8f6e1b6..7856bdc4a8f3 100644 --- a/spanner/read.go +++ b/spanner/read.go @@ -26,9 +26,9 @@ import ( "cloud.google.com/go/internal/protostruct" "cloud.google.com/go/internal/trace" - "cloud.google.com/go/spanner/internal/backoff" - proto "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/proto" proto3 "github.com/golang/protobuf/ptypes/struct" + "github.com/googleapis/gax-go/v2" "google.golang.org/api/iterator" sppb "google.golang.org/genproto/googleapis/spanner/v1" "google.golang.org/grpc/codes" @@ -320,14 +320,11 @@ type resumableStreamDecoder struct { // last revealed to caller. resumeToken []byte - // retryCount is the number of retries that have been carried out so far - retryCount int - // err is the last error resumableStreamDecoder has encountered so far. err error - // backoff to compute delays between retries. - backoff backoff.ExponentialBackoff + // backoff is used for the retry settings + backoff gax.Backoff } // newResumableStreamDecoder creates a new resumeableStreamDecoder instance. 
@@ -338,7 +335,7 @@ func newResumableStreamDecoder(ctx context.Context, rpc func(ct context.Context, ctx: ctx, rpc: rpc, maxBytesBetweenResumeTokens: atomic.LoadInt32(&maxBytesBetweenResumeTokens), - backoff: backoff.DefaultBackoff, + backoff: DefaultRetryBackoff, } } @@ -421,35 +418,34 @@ var ( ) func (d *resumableStreamDecoder) next() bool { + retryer := gax.OnCodes([]codes.Code{codes.Unavailable, codes.Internal}, d.backoff) for { - select { - case <-d.ctx.Done(): - // Do context check here so that even gRPC failed to do - // so, resumableStreamDecoder can still break the loop - // as expected. - d.err = errContextCanceled(d.ctx, d.err) - d.changeState(aborted) - default: - } switch d.state { case unConnected: // If no gRPC stream is available, try to initiate one. - if d.stream, d.err = d.rpc(d.ctx, d.resumeToken); d.err != nil { - if isRetryable(d.err) { - d.doBackOff() - // Be explicit about state transition, although the - // state doesn't actually change. State transition - // will be triggered only by RPC activity, regardless of - // whether there is an actual state change or not. - d.changeState(unConnected) - continue - } + d.stream, d.err = d.rpc(d.ctx, d.resumeToken) + if d.err == nil { + d.changeState(queueingRetryable) + continue + } + delay, shouldRetry := retryer.Retry(d.err) + if !shouldRetry { d.changeState(aborted) continue } - d.resetBackOff() - d.changeState(queueingRetryable) + trace.TracePrintf(d.ctx, nil, "Backing off stream read for %s", delay) + if err := gax.Sleep(d.ctx, delay); err == nil { + // Be explicit about state transition, although the + // state doesn't actually change. State transition + // will be triggered only by RPC activity, regardless of + // whether there is an actual state change or not. + d.changeState(unConnected) + } else { + d.err = err + d.changeState(aborted) + } continue + case queueingRetryable: fallthrough case queueingUnretryable: @@ -459,7 +455,7 @@ func (d *resumableStreamDecoder) next() bool { // Only the case that receiving queue is empty could cause // peekLast to return error and in such case, we should try to // receive from stream. - d.tryRecv() + d.tryRecv(retryer) continue } if d.isNewResumeToken(last.ResumeToken) { @@ -488,7 +484,7 @@ func (d *resumableStreamDecoder) next() bool { } // Needs to receive more from gRPC stream till a new resume token // is observed. - d.tryRecv() + d.tryRecv(retryer) continue case aborted: // Discard all pending items because none of them should be yield @@ -514,51 +510,38 @@ func (d *resumableStreamDecoder) next() bool { } // tryRecv attempts to receive a PartialResultSet from gRPC stream. -func (d *resumableStreamDecoder) tryRecv() { +func (d *resumableStreamDecoder) tryRecv(retryer gax.Retryer) { var res *sppb.PartialResultSet - if res, d.err = d.stream.Recv(); d.err != nil { - if d.err == io.EOF { - d.err = nil - d.changeState(finished) - return - } - if isRetryable(d.err) && d.state == queueingRetryable { - d.err = nil - // Discard all queue items (none have resume tokens). 
- d.q.clear() - d.stream = nil - d.changeState(unConnected) - d.doBackOff() - return + res, d.err = d.stream.Recv() + if d.err == nil { + d.q.push(res) + if d.state == queueingRetryable && !d.isNewResumeToken(res.ResumeToken) { + d.bytesBetweenResumeTokens += int32(proto.Size(res)) } - d.changeState(aborted) + d.changeState(d.state) return } - d.q.push(res) - if d.state == queueingRetryable && !d.isNewResumeToken(res.ResumeToken) { - d.bytesBetweenResumeTokens += int32(proto.Size(res)) + if d.err == io.EOF { + d.err = nil + d.changeState(finished) + return } - d.resetBackOff() - d.changeState(d.state) -} - -// resetBackOff clears the internal retry counter of resumableStreamDecoder so -// that the next exponential backoff will start at a fresh state. -func (d *resumableStreamDecoder) resetBackOff() { - d.retryCount = 0 -} - -// doBackoff does an exponential backoff sleep. -func (d *resumableStreamDecoder) doBackOff() { - delay := d.backoff.Delay(d.retryCount) - trace.TracePrintf(d.ctx, nil, "Backing off stream read for %s", delay) - ticker := time.NewTicker(delay) - defer ticker.Stop() - d.retryCount++ - select { - case <-d.ctx.Done(): - case <-ticker.C: + delay, shouldRetry := retryer.Retry(d.err) + if !shouldRetry || d.state != queueingRetryable { + d.changeState(aborted) + return + } + if err := gax.Sleep(d.ctx, delay); err != nil { + d.err = err + d.changeState(aborted) + return } + // Clear error and retry the stream. + d.err = nil + // Discard all queue items (none have resume tokens). + d.q.clear() + d.stream = nil + d.changeState(unConnected) } // get returns the most recent PartialResultSet generated by a call to next. diff --git a/spanner/read_test.go b/spanner/read_test.go index b18b0938d7d1..81aa58e8ce07 100644 --- a/spanner/read_test.go +++ b/spanner/read_test.go @@ -25,10 +25,10 @@ import ( "testing" "time" - "cloud.google.com/go/spanner/internal/backoff" . "cloud.google.com/go/spanner/internal/testutil" "github.com/golang/protobuf/proto" proto3 "github.com/golang/protobuf/ptypes/struct" + "github.com/googleapis/gax-go/v2" "google.golang.org/api/iterator" sppb "google.golang.org/genproto/googleapis/spanner/v1" "google.golang.org/grpc" @@ -1080,9 +1080,10 @@ func TestRsdBlockingStates(t *testing.T) { test.rpc, ) // Override backoff to make the test run faster. - r.backoff = backoff.ExponentialBackoff{ - Min: 1 * time.Nanosecond, - Max: 1 * time.Nanosecond, + r.backoff = gax.Backoff{ + Initial: 1 * time.Nanosecond, + Max: 1 * time.Nanosecond, + Multiplier: 1.3, } // st is the set of observed state transitions. st := []resumableStreamDecoderState{} @@ -1720,12 +1721,9 @@ func TestIteratorStopEarly(t *testing.T) { } iter.Stop() // Stop sets r.err to the FailedPrecondition error "Next called after Stop". - // Override that here so this test can observe the Canceled error from the - // stream. 
- iter.err = nil - iter.Next() - if ErrCode(iter.streamd.lastErr()) != codes.Canceled { - t.Errorf("after Stop: got %v, wanted Canceled", err) + _, err = iter.Next() + if g, w := ErrCode(err), codes.FailedPrecondition; g != w { + t.Errorf("after Stop: got: %v, want: %v", g, w) } } diff --git a/spanner/retry.go b/spanner/retry.go index fa10ee5f019b..6c892156609c 100644 --- a/spanner/retry.go +++ b/spanner/retry.go @@ -18,7 +18,6 @@ package spanner import ( "context" - "strings" "time" "cloud.google.com/go/internal/trace" @@ -94,95 +93,6 @@ func runWithRetryOnAborted(ctx context.Context, f func(context.Context) error) e return funcWithRetry(ctx) } -// isErrorClosing reports whether the error is generated by gRPC layer talking -// to a closed server. -func isErrorClosing(err error) bool { - if err == nil { - return false - } - if ErrCode(err) == codes.Internal && strings.Contains(ErrDesc(err), "transport is closing") { - // Handle the case when connection is closed unexpectedly. - // TODO: once gRPC is able to categorize this as retryable error, we - // should stop parsing the error message here. - return true - } - return false -} - -// isErrorRST reports whether the error is generated by gRPC client receiving a -// RST frame from server. -func isErrorRST(err error) bool { - if err == nil { - return false - } - if ErrCode(err) == codes.Internal && strings.Contains(ErrDesc(err), "stream terminated by RST_STREAM") { - // TODO: once gRPC is able to categorize this error as "go away" or "retryable", - // we should stop parsing the error message. - return true - } - return false -} - -// isErrorUnexpectedEOF returns true if error is generated by gRPC layer -// receiving io.EOF unexpectedly. -func isErrorUnexpectedEOF(err error) bool { - if err == nil { - return false - } - // Unexpected EOF is a transport layer issue that could be recovered by - // retries. The most likely scenario is a flaky RecvMsg() call due to - // network issues. - // - // For grpc version >= 1.14.0, the error code is Internal. - // (https://github.com/grpc/grpc-go/releases/tag/v1.14.0) - if ErrCode(err) == codes.Internal && strings.Contains(ErrDesc(err), "unexpected EOF") { - return true - } - // For grpc version < 1.14.0, the error code in Unknown. - if ErrCode(err) == codes.Unknown && strings.Contains(ErrDesc(err), "unexpected EOF") { - return true - } - return false -} - -// isErrorUnavailable returns true if the error is about server being -// unavailable. -func isErrorUnavailable(err error) bool { - if err == nil { - return false - } - if ErrCode(err) == codes.Unavailable { - return true - } - return false -} - -// isRetryable returns true if the Cloud Spanner error being checked is a -// retryable error. -func isRetryable(err error) bool { - if isErrorClosing(err) { - return true - } - if isErrorUnexpectedEOF(err) { - return true - } - if isErrorRST(err) { - return true - } - if isErrorUnavailable(err) { - return true - } - return false -} - -// errContextCanceled returns *spanner.Error for canceled context. -func errContextCanceled(ctx context.Context, lastErr error) error { - if ctx.Err() == context.DeadlineExceeded { - return spannerErrorf(codes.DeadlineExceeded, "%v, lastErr is <%v>", ctx.Err(), lastErr) - } - return spannerErrorf(codes.Canceled, "%v, lastErr is <%v>", ctx.Err(), lastErr) -} - // extractRetryDelay extracts retry backoff if present. func extractRetryDelay(err error) (time.Duration, bool) { trailers := errTrailers(err)