diff --git a/spanner/client.go b/spanner/client.go
index 1f15cb01e2d0..2a30b8eb186c 100644
--- a/spanner/client.go
+++ b/spanner/client.go
@@ -21,7 +21,6 @@ import (
 	"fmt"
 	"os"
 	"regexp"
-	"sync/atomic"
 	"time"
 
 	"cloud.google.com/go/internal/trace"
@@ -64,16 +63,8 @@ func validDatabaseName(db string) error {
 // Client is a client for reading and writing data to a Cloud Spanner database.
 // A client is safe to use concurrently, except for its Close method.
 type Client struct {
-	// rr must be accessed through atomic operations.
-	rr uint32
-	clients []*vkit.Client
-
-	database string
-	// Metadata to be sent with each request.
-	md metadata.MD
+	sc *sessionClient
 	idleSessions *sessionPool
-	// sessionLabels for the sessions created by this client.
-	sessionLabels map[string]string
 }
 
 // ClientConfig has configurations for the client.
@@ -110,23 +101,12 @@ func contextWithOutgoingMetadata(ctx context.Context, md metadata.MD) context.Co
 // form projects/PROJECT_ID/instances/INSTANCE_ID/databases/DATABASE_ID. It uses
 // a default configuration.
 func NewClient(ctx context.Context, database string, opts ...option.ClientOption) (*Client, error) {
-	return NewClientWithConfig(ctx, database, ClientConfig{}, opts...)
+	return NewClientWithConfig(ctx, database, ClientConfig{SessionPoolConfig: DefaultSessionPoolConfig}, opts...)
 }
 
 // NewClientWithConfig creates a client to a database. A valid database name has
 // the form projects/PROJECT_ID/instances/INSTANCE_ID/databases/DATABASE_ID.
 func NewClientWithConfig(ctx context.Context, database string, config ClientConfig, opts ...option.ClientOption) (c *Client, err error) {
-	c = &Client{
-		database: database,
-		md:       metadata.Pairs(resourcePrefixHeader, database),
-	}
-
-	// Make a copy of labels.
-	c.sessionLabels = make(map[string]string)
-	for k, v := range config.SessionLabels {
-		c.sessionLabels[k] = v
-	}
-
 	// Prepare gRPC channels.
 	if config.NumChannels == 0 {
 		config.NumChannels = numChannels
@@ -137,7 +117,7 @@ func NewClientWithConfig(ctx context.Context, database string, config ClientConf
 		config.MaxOpened = uint64(config.NumChannels * 100)
 	}
 	if config.MaxBurst == 0 {
-		config.MaxBurst = 10
+		config.MaxBurst = DefaultSessionPoolConfig.MaxBurst
 	}
 
 	// Validate database path.
@@ -174,43 +154,44 @@ func NewClientWithConfig(ctx context.Context, database string, config ClientConf
 	// TODO(deklerk): This should be replaced with a balancer with
 	// config.NumChannels connections, instead of config.NumChannels
 	// clients.
+	var clients []*vkit.Client
 	for i := 0; i < config.NumChannels; i++ {
 		client, err := vkit.NewClient(ctx, allOpts...)
 		if err != nil {
 			return nil, errDial(i, err)
 		}
-		c.clients = append(c.clients, client)
+		clients = append(clients, client)
 	}
 
-	// Prepare session pool.
-	// TODO: support more loadbalancing options.
-	config.SessionPoolConfig.getRPCClient = func() (*vkit.Client, error) {
-		return c.rrNext(), nil
+	// TODO(loite): Remove as the original map cannot be changed by the user
+	// anyways, and the client library is also not changing it.
+	// Make a copy of labels.
+	sessionLabels := make(map[string]string)
+	for k, v := range config.SessionLabels {
+		sessionLabels[k] = v
 	}
-	config.SessionPoolConfig.sessionLabels = c.sessionLabels
-	sp, err := newSessionPool(database, config.SessionPoolConfig, c.md)
+	// Create a session client.
+	sc := newSessionClient(clients, database, sessionLabels, metadata.Pairs(resourcePrefixHeader, database))
+	// Create a session pool.
+ config.SessionPoolConfig.sessionLabels = sessionLabels + sp, err := newSessionPool(sc, config.SessionPoolConfig) if err != nil { - c.Close() + sc.close() return nil, err } - c.idleSessions = sp + c = &Client{ + sc: sc, + idleSessions: sp, + } return c, nil } -// rrNext returns the next available vkit Cloud Spanner RPC client in a -// round-robin manner. -func (c *Client) rrNext() *vkit.Client { - return c.clients[atomic.AddUint32(&c.rr, 1)%uint32(len(c.clients))] -} - // Close closes the client. func (c *Client) Close() { if c.idleSessions != nil { c.idleSessions.close() } - for _, gpc := range c.clients { - gpc.Close() - } + c.sc.close() } // Single provides a read-only snapshot transaction optimized for the case @@ -273,8 +254,7 @@ func (c *Client) BatchReadOnlyTransaction(ctx context.Context, tb TimestampBound }() // Create session. - sc := c.rrNext() - s, err = createSession(ctx, sc, c.database, c.sessionLabels, c.md) + s, err = c.sc.createSession(ctx) if err != nil { return nil, err } @@ -318,8 +298,7 @@ func (c *Client) BatchReadOnlyTransaction(ctx context.Context, tb TimestampBound // BatchReadOnlyTransactionFromID reconstruct a BatchReadOnlyTransaction from // BatchReadOnlyTransactionID func (c *Client) BatchReadOnlyTransactionFromID(tid BatchReadOnlyTransactionID) *BatchReadOnlyTransaction { - sc := c.rrNext() - s := &session{valid: true, client: sc, id: tid.sid, createTime: time.Now(), md: c.md} + s := c.sc.sessionWithID(tid.sid) sh := &sessionHandle{session: s} t := &BatchReadOnlyTransaction{ diff --git a/spanner/integration_test.go b/spanner/integration_test.go index 6e1384be20e6..70c19c96c23a 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -213,6 +213,32 @@ func initIntegrationTests() (cleanup func()) { } } +func TestIntegration_InitSessionPool(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + // Set up testing environment. + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) + defer cleanup() + sp := client.idleSessions + sp.mu.Lock() + want := sp.MinOpened + sp.mu.Unlock() + var numOpened int + for { + select { + case <-ctx.Done(): + t.Fatalf("timed out, got %d session(s), want %d", numOpened, want) + default: + sp.mu.Lock() + numOpened = sp.idleList.Len() + sp.idleWriteList.Len() + sp.mu.Unlock() + if uint64(numOpened) == want { + return + } + } + } +} + // Test SingleUse transaction. func TestIntegration_SingleUse(t *testing.T) { t.Parallel() @@ -220,7 +246,7 @@ func TestIntegration_SingleUse(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // Set up testing environment. - client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) defer cleanup() writes := []struct { @@ -420,7 +446,7 @@ func TestIntegration_SingleUse_ReadingWithLimit(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // Set up testing environment. - client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) defer cleanup() writes := []struct { @@ -463,7 +489,7 @@ func TestIntegration_ReadOnlyTransaction(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // Set up testing environment. 
- client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) defer cleanup() writes := []struct { @@ -649,7 +675,7 @@ func TestIntegration_UpdateDuringRead(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) defer cleanup() for i, tb := range []TimestampBound{ @@ -682,7 +708,7 @@ func TestIntegration_ReadWriteTransaction(t *testing.T) { // Give a longer deadline because of transaction backoffs. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) defer cleanup() // Set up two accounts @@ -773,7 +799,7 @@ func TestIntegration_Reads(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // Set up testing environment. - client, _, cleanup := prepareIntegrationTest(ctx, t, readDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, readDBStatements) defer cleanup() // Includes k0..k14. Strings sort lexically, eg "k1" < "k10" < "k2". @@ -841,7 +867,7 @@ func TestIntegration_EarlyTimestamp(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() // Set up testing environment. - client, _, cleanup := prepareIntegrationTest(ctx, t, readDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, readDBStatements) defer cleanup() var ms []*Mutation @@ -887,7 +913,7 @@ func TestIntegration_NestedTransaction(t *testing.T) { // You cannot use a transaction from inside a read-write transaction. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) defer cleanup() _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { @@ -919,7 +945,9 @@ func TestIntegration_DbRemovalRecovery(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, dbPath, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements) + // Create a client with MinOpened=0 to prevent the session pool maintainer + // from repeatedly trying to create sessions for the invalid database. + client, dbPath, cleanup := prepareIntegrationTest(ctx, t, SessionPoolConfig{}, singerDBStatements) defer cleanup() // Drop the testing database. 
@@ -970,7 +998,7 @@ func TestIntegration_BasicTypes(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) defer cleanup() t1, _ := time.Parse(time.RFC3339Nano, "2016-11-15T15:04:05.999999999Z") @@ -1118,7 +1146,7 @@ func TestIntegration_StructTypes(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) defer cleanup() tests := []struct { @@ -1205,7 +1233,7 @@ func TestIntegration_StructParametersUnsupported(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, nil) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, nil) defer cleanup() for _, test := range []struct { @@ -1250,7 +1278,7 @@ func TestIntegration_QueryExpressions(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, nil) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, nil) defer cleanup() newRow := func(vals []interface{}) *Row { @@ -1306,7 +1334,7 @@ func TestIntegration_QueryStats(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) defer cleanup() accounts := []*Mutation{ @@ -1348,12 +1376,12 @@ func TestIntegration_QueryStats(t *testing.T) { func TestIntegration_InvalidDatabase(t *testing.T) { t.Parallel() - if testProjectID == "" { - t.Skip("Integration tests skipped: GCLOUD_TESTS_GOLANG_PROJECT_ID is missing") + if databaseAdmin == nil { + t.Skip("Integration tests skipped") } ctx := context.Background() dbPath := fmt.Sprintf("projects/%v/instances/%v/databases/invalid", testProjectID, testInstanceID) - c, err := createClient(ctx, dbPath) + c, err := createClient(ctx, dbPath, SessionPoolConfig{}) // Client creation should succeed even if the database is invalid. if err != nil { t.Fatal(err) @@ -1369,7 +1397,7 @@ func TestIntegration_ReadErrors(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, readDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, readDBStatements) defer cleanup() // Read over invalid table fails @@ -1415,7 +1443,7 @@ func TestIntegration_TransactionRunner(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) defer cleanup() // Test 1: User error should abort the transaction. 
@@ -1556,13 +1584,13 @@ func TestIntegration_BatchQuery(t *testing.T) { ) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, dbPath, cleanup := prepareIntegrationTest(ctx, t, simpleDBStatements) + client, dbPath, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, simpleDBStatements) defer cleanup() if err = populate(ctx, client); err != nil { t.Fatal(err) } - if client2, err = createClient(ctx, dbPath); err != nil { + if client2, err = createClient(ctx, dbPath, SessionPoolConfig{}); err != nil { t.Fatal(err) } defer client2.Close() @@ -1642,13 +1670,13 @@ func TestIntegration_BatchRead(t *testing.T) { ) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, dbPath, cleanup := prepareIntegrationTest(ctx, t, simpleDBStatements) + client, dbPath, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, simpleDBStatements) defer cleanup() if err = populate(ctx, client); err != nil { t.Fatal(err) } - if client2, err = createClient(ctx, dbPath); err != nil { + if client2, err = createClient(ctx, dbPath, SessionPoolConfig{}); err != nil { t.Fatal(err) } defer client2.Close() @@ -1729,7 +1757,7 @@ func TestIntegration_BROTNormal(t *testing.T) { ) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, simpleDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, simpleDBStatements) defer cleanup() if txn, err = client.BatchReadOnlyTransaction(ctx, StrongRead()); err != nil { @@ -1758,7 +1786,7 @@ func TestIntegration_CommitTimestamp(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, ctsDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, ctsDBStatements) defer cleanup() type testTableRow struct { @@ -1828,7 +1856,7 @@ func TestIntegration_DML(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) defer cleanup() // Function that reads a single row's first name from within a transaction. 
@@ -1995,7 +2023,7 @@ func TestIntegration_StructParametersBind(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, nil) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, nil) defer cleanup() type tRow []interface{} @@ -2165,7 +2193,7 @@ func TestIntegration_PDML(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) defer cleanup() columns := []string{"SingerId", "FirstName", "LastName"} @@ -2219,7 +2247,7 @@ func TestBatchDML(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) defer cleanup() columns := []string{"SingerId", "FirstName", "LastName"} @@ -2272,7 +2300,7 @@ func TestBatchDML_NoStatements(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) defer cleanup() _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) (err error) { @@ -2296,7 +2324,7 @@ func TestBatchDML_TwoStatements(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) defer cleanup() columns := []string{"SingerId", "FirstName", "LastName"} @@ -2347,7 +2375,7 @@ func TestBatchDML_Error(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - client, _, cleanup := prepareIntegrationTest(ctx, t, singerDBStatements) + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, singerDBStatements) defer cleanup() columns := []string{"SingerId", "FirstName", "LastName"} @@ -2399,7 +2427,7 @@ func TestBatchDML_Error(t *testing.T) { } // Prepare initializes Cloud Spanner testing DB and clients. -func prepareIntegrationTest(ctx context.Context, t *testing.T, statements []string) (*Client, string, func()) { +func prepareIntegrationTest(ctx context.Context, t *testing.T, spc SessionPoolConfig, statements []string) (*Client, string, func()) { if databaseAdmin == nil { t.Skip("Integration tests skipped") } @@ -2419,7 +2447,7 @@ func prepareIntegrationTest(ctx context.Context, t *testing.T, statements []stri if _, err := op.Wait(ctx); err != nil { t.Fatalf("cannot create testing DB %v: %v", dbPath, err) } - client, err := createClient(ctx, dbPath) + client, err := createClient(ctx, dbPath, spc) if err != nil { t.Fatalf("cannot create data client on DB %v: %v", dbPath, err) } @@ -2563,9 +2591,9 @@ func isNaN(x interface{}) bool { } // createClient creates Cloud Spanner data client. 
-func createClient(ctx context.Context, dbPath string) (client *Client, err error) { +func createClient(ctx context.Context, dbPath string, spc SessionPoolConfig) (client *Client, err error) { client, err = NewClientWithConfig(ctx, dbPath, ClientConfig{ - SessionPoolConfig: SessionPoolConfig{WriteSessions: 0.2}, + SessionPoolConfig: spc, }, option.WithTokenSource(testutil.TokenSource(ctx, Scope)), option.WithEndpoint(endpoint)) if err != nil { return nil, fmt.Errorf("cannot create data client on DB %v: %v", dbPath, err) diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index 965e5c29d66a..8c61acce96a0 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -56,6 +56,7 @@ const ( const ( MethodBeginTransaction string = "BEGIN_TRANSACTION" MethodCommitTransaction string = "COMMIT_TRANSACTION" + MethodBatchCreateSession string = "BATCH_CREATE_SESSION" MethodCreateSession string = "CREATE_SESSION" MethodDeleteSession string = "DELETE_SESSION" MethodGetSession string = "GET_SESSION" @@ -206,6 +207,8 @@ type InMemSpannerServer interface { TotalSessionsCreated() uint TotalSessionsDeleted() uint + SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32) + SetMaxSessionsReturnedByServerInTotal(sessionCount int32) ReceivedRequests() chan interface{} DumpSessions() map[string]bool @@ -249,7 +252,10 @@ type inMemSpannerServer struct { totalSessionsCreated uint totalSessionsDeleted uint - receivedRequests chan interface{} + // The maximum number of sessions that will be created per batch request. + maxSessionsReturnedByServerPerBatchRequest int32 + maxSessionsReturnedByServerInTotal int32 + receivedRequests chan interface{} // Session ping history. 
pings []string @@ -362,6 +368,18 @@ func (s *inMemSpannerServer) TotalSessionsDeleted() uint { return s.totalSessionsDeleted } +func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerPerBatchRequest(sessionCount int32) { + s.mu.Lock() + defer s.mu.Unlock() + s.maxSessionsReturnedByServerPerBatchRequest = sessionCount +} + +func (s *inMemSpannerServer) SetMaxSessionsReturnedByServerInTotal(sessionCount int32) { + s.mu.Lock() + defer s.mu.Unlock() + s.maxSessionsReturnedByServerInTotal = sessionCount +} + func (s *inMemSpannerServer) ReceivedRequests() chan interface{} { return s.receivedRequests } @@ -393,6 +411,7 @@ func (s *inMemSpannerServer) DumpSessions() map[string]bool { func (s *inMemSpannerServer) initDefaults() { s.sessionCounter = 0 + s.maxSessionsReturnedByServerPerBatchRequest = 100 s.sessions = make(map[string]*spannerpb.Session) s.sessionLastUseTime = make(map[string]time.Time) s.transactions = make(map[string]*spannerpb.Transaction) @@ -401,9 +420,7 @@ func (s *inMemSpannerServer) initDefaults() { s.transactionCounters = make(map[string]*uint64) } -func (s *inMemSpannerServer) generateSessionName(database string) string { - s.mu.Lock() - defer s.mu.Unlock() +func (s *inMemSpannerServer) generateSessionNameLocked(database string) string { s.sessionCounter++ return fmt.Sprintf("%s/sessions/%d", database, s.sessionCounter) } @@ -524,13 +541,16 @@ func (s *inMemSpannerServer) simulateExecutionTime(method string, req interface{ } totalExecutionTime := time.Duration(int64(executionTime.MinimumExecutionTime) + randTime) <-time.After(totalExecutionTime) + s.mu.Lock() if executionTime.Errors != nil && len(executionTime.Errors) > 0 { err := executionTime.Errors[0] if !executionTime.KeepError { executionTime.Errors = executionTime.Errors[1:] } + s.mu.Unlock() return err } + s.mu.Unlock() } return nil } @@ -542,16 +562,52 @@ func (s *inMemSpannerServer) CreateSession(ctx context.Context, req *spannerpb.C if req.Database == "" { return nil, gstatus.Error(codes.InvalidArgument, "Missing database") } - sessionName := s.generateSessionName(req.Database) + s.mu.Lock() + defer s.mu.Unlock() + if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) == s.maxSessionsReturnedByServerInTotal { + return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available") + } + sessionName := s.generateSessionNameLocked(req.Database) ts := getCurrentTimestamp() session := &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts} - s.mu.Lock() s.totalSessionsCreated++ s.sessions[sessionName] = session - s.mu.Unlock() return session, nil } +func (s *inMemSpannerServer) BatchCreateSessions(ctx context.Context, req *spannerpb.BatchCreateSessionsRequest) (*spannerpb.BatchCreateSessionsResponse, error) { + if err := s.simulateExecutionTime(MethodBatchCreateSession, req); err != nil { + return nil, err + } + if req.Database == "" { + return nil, gstatus.Error(codes.InvalidArgument, "Missing database") + } + if req.SessionCount <= 0 { + return nil, gstatus.Error(codes.InvalidArgument, "Session count must be >= 0") + } + sessionsToCreate := req.SessionCount + s.mu.Lock() + defer s.mu.Unlock() + if s.maxSessionsReturnedByServerInTotal > int32(0) && int32(len(s.sessions)) >= s.maxSessionsReturnedByServerInTotal { + return nil, gstatus.Error(codes.ResourceExhausted, "No more sessions available") + } + if sessionsToCreate > s.maxSessionsReturnedByServerPerBatchRequest { + sessionsToCreate = s.maxSessionsReturnedByServerPerBatchRequest + } + if 
s.maxSessionsReturnedByServerInTotal > int32(0) && (sessionsToCreate+int32(len(s.sessions))) > s.maxSessionsReturnedByServerInTotal { + sessionsToCreate = s.maxSessionsReturnedByServerInTotal - int32(len(s.sessions)) + } + sessions := make([]*spannerpb.Session, sessionsToCreate) + for i := int32(0); i < sessionsToCreate; i++ { + sessionName := s.generateSessionNameLocked(req.Database) + ts := getCurrentTimestamp() + sessions[i] = &spannerpb.Session{Name: sessionName, CreateTime: ts, ApproximateLastUseTime: ts} + s.totalSessionsCreated++ + s.sessions[sessionName] = sessions[i] + } + return &spannerpb.BatchCreateSessionsResponse{Session: sessions}, nil +} + func (s *inMemSpannerServer) GetSession(ctx context.Context, req *spannerpb.GetSessionRequest) (*spannerpb.Session, error) { if err := s.simulateExecutionTime(MethodGetSession, req); err != nil { return nil, err diff --git a/spanner/oc_test.go b/spanner/oc_test.go index 7e957139e36f..3e3d9f8ad4ee 100644 --- a/spanner/oc_test.go +++ b/spanner/oc_test.go @@ -33,7 +33,10 @@ func TestOCStats(t *testing.T) { ms := stestutil.NewMockCloudSpanner(t, trxTs) ms.Serve() ctx := context.Background() - c, err := NewClient(ctx, "projects/P/instances/I/databases/D", + c, err := NewClientWithConfig(ctx, "projects/P/instances/I/databases/D", + ClientConfig{SessionPoolConfig: SessionPoolConfig{ + MinOpened: 0, + }}, option.WithEndpoint(ms.Addr()), option.WithGRPCDialOption(grpc.WithInsecure()), option.WithoutAuthentication()) diff --git a/spanner/pdml.go b/spanner/pdml.go index 242a48edcfe9..6f160a21aec9 100644 --- a/spanner/pdml.go +++ b/spanner/pdml.go @@ -41,9 +41,8 @@ func (c *Client) PartitionedUpdate(ctx context.Context, statement Statement) (co s *session sh *sessionHandle ) - // Create a session that will be used only for this request. - sc := c.rrNext() - s, err = createSession(ctx, sc, c.database, c.sessionLabels, c.md) + // Create session. + s, err = c.sc.createSession(ctx) if err != nil { return 0, toSpannerError(err) } diff --git a/spanner/session.go b/spanner/session.go index 20329c398f9e..3010b44294a9 100644 --- a/spanner/session.go +++ b/spanner/session.go @@ -22,6 +22,7 @@ import ( "context" "fmt" "log" + "math" "math/rand" "strings" "sync" @@ -185,7 +186,7 @@ func (s *session) ping() error { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() // s.getID is safe even when s is invalid. - _, err := s.client.GetSession(contextWithOutgoingMetadata(ctx, s.pool.md), &sppb.GetSessionRequest{Name: s.getID()}) + _, err := s.client.GetSession(contextWithOutgoingMetadata(ctx, s.md), &sppb.GetSessionRequest{Name: s.getID()}) return err } @@ -303,7 +304,7 @@ func (s *session) prepareForWrite(ctx context.Context) error { if s.isWritePrepared() { return nil } - tx, err := beginTransaction(ctx, s.getID(), s.client) + tx, err := beginTransaction(contextWithOutgoingMetadata(ctx, s.md), s.getID(), s.client) if err != nil { return err } @@ -313,10 +314,6 @@ func (s *session) prepareForWrite(ctx context.Context) error { // SessionPoolConfig stores configurations of a session pool. type SessionPoolConfig struct { - // getRPCClient is the caller supplied method for getting a gRPC client to - // Cloud Spanner, this makes session pool able to use client pooling. - getRPCClient func() (*vkit.Client, error) - // MaxOpened is the maximum number of opened sessions allowed by the session // pool. 
If the client tries to open a session and there are already // MaxOpened sessions, it will block until one becomes available or the @@ -332,7 +329,7 @@ type SessionPoolConfig struct { // therefore it is posssible that the number of opened sessions drops below // MinOpened. // - // Defaults to 0. + // Defaults to 100. MinOpened uint64 // MaxIdle is the maximum number of idle sessions, pool is allowed to keep. @@ -348,7 +345,7 @@ type SessionPoolConfig struct { // WriteSessions is the fraction of sessions we try to keep prepared for // write. // - // Defaults to 0. + // Defaults to 0.2. WriteSessions float64 // HealthCheckWorkers is number of workers used by health checker for this @@ -372,25 +369,59 @@ type SessionPoolConfig struct { sessionLabels map[string]string } -// errNoRPCGetter returns error for SessionPoolConfig missing getRPCClient method. -func errNoRPCGetter() error { - return spannerErrorf(codes.InvalidArgument, "require SessionPoolConfig.getRPCClient != nil, got nil") +// DefaultSessionPoolConfig is the default configuration for the session pool +// that will be used for a Spanner client, unless the user supplies a specific +// session pool config. +var DefaultSessionPoolConfig = SessionPoolConfig{ + MinOpened: 100, + MaxOpened: numChannels * 100, + MaxBurst: 10, + WriteSessions: 0.2, + HealthCheckWorkers: 10, + HealthCheckInterval: 5 * time.Minute, } // errMinOpenedGTMapOpened returns error for SessionPoolConfig.MaxOpened < SessionPoolConfig.MinOpened when SessionPoolConfig.MaxOpened is set. func errMinOpenedGTMaxOpened(maxOpened, minOpened uint64) error { return spannerErrorf(codes.InvalidArgument, - "require SessionPoolConfig.MaxOpened >= SessionPoolConfig.MinOpened, got %v and %v", maxOpened, minOpened) + "require SessionPoolConfig.MaxOpened >= SessionPoolConfig.MinOpened, got %d and %d", maxOpened, minOpened) +} + +// errWriteFractionOutOfRange returns error for +// SessionPoolConfig.WriteFraction < 0 or SessionPoolConfig.WriteFraction > 1 +func errWriteFractionOutOfRange(writeFraction float64) error { + return spannerErrorf(codes.InvalidArgument, + "require SessionPoolConfig.WriteSessions >= 0.0 && SessionPoolConfig.WriteSessions <= 1.0, got %.2f", writeFraction) +} + +// errHealthCheckWorkersNegative returns error for +// SessionPoolConfig.HealthCheckWorkers < 0 +func errHealthCheckWorkersNegative(workers int) error { + return spannerErrorf(codes.InvalidArgument, + "require SessionPoolConfig.HealthCheckWorkers >= 0, got %d", workers) +} + +// errHealthCheckIntervalNegative returns error for +// SessionPoolConfig.HealthCheckInterval < 0 +func errHealthCheckIntervalNegative(interval time.Duration) error { + return spannerErrorf(codes.InvalidArgument, + "require SessionPoolConfig.HealthCheckInterval >= 0, got %v", interval) } // validate verifies that the SessionPoolConfig is good for use. func (spc *SessionPoolConfig) validate() error { - if spc.getRPCClient == nil { - return errNoRPCGetter() - } if spc.MinOpened > spc.MaxOpened && spc.MaxOpened > 0 { return errMinOpenedGTMaxOpened(spc.MaxOpened, spc.MinOpened) } + if spc.WriteSessions < 0.0 || spc.WriteSessions > 1.0 { + return errWriteFractionOutOfRange(spc.WriteSessions) + } + if spc.HealthCheckWorkers < 0 { + return errHealthCheckWorkersNegative(spc.HealthCheckWorkers) + } + if spc.HealthCheckInterval < 0 { + return errHealthCheckIntervalNegative(spc.HealthCheckInterval) + } return nil } @@ -400,8 +431,8 @@ type sessionPool struct { mu sync.Mutex // valid marks the validity of the session pool. 
valid bool - // db is the database name that all sessions in the pool are associated with. - db string + // sc is used to create the sessions for the pool. + sc *sessionClient // idleList caches idle session IDs. Session IDs in this list can be // allocated for use. idleList list.List @@ -418,23 +449,20 @@ type sessionPool struct { prepareReqs uint64 // configuration of the session pool. SessionPoolConfig - // Metadata to be sent with each request - md metadata.MD // hc is the health checker hc *healthChecker } // newSessionPool creates a new session pool. -func newSessionPool(db string, config SessionPoolConfig, md metadata.MD) (*sessionPool, error) { +func newSessionPool(sc *sessionClient, config SessionPoolConfig) (*sessionPool, error) { if err := config.validate(); err != nil { return nil, err } pool := &sessionPool{ - db: db, + sc: sc, valid: true, mayGetSession: make(chan struct{}), SessionPoolConfig: config, - md: md, } if config.HealthCheckWorkers == 0 { // With 10 workers and assuming average latency of 5ms for @@ -456,10 +484,76 @@ func newSessionPool(db string, config SessionPoolConfig, md metadata.MD) (*sessi // healthChecker can effectively mantain // 100 checks_per_worker/sec * 10 workers * 300 seconds = 300K sessions. pool.hc = newHealthChecker(config.HealthCheckInterval, config.HealthCheckWorkers, config.healthCheckSampleInterval, pool) + + // First initialize the pool before we indicate that the healthchecker is + // ready. This prevents the maintainer from starting before the pool has + // been initialized, which means that we guarantee that the initial + // sessions are created using BatchCreateSessions. + if config.MinOpened > 0 { + numSessions := minUint64(config.MinOpened, math.MaxInt32) + if err := pool.initPool(int32(numSessions)); err != nil { + return nil, err + } + } close(pool.hc.ready) return pool, nil } +func (p *sessionPool) initPool(numSessions int32) error { + p.mu.Lock() + // Take budget before the actual session creation. + p.numOpened += uint64(numSessions) + recordStat(context.Background(), OpenSessionCount, int64(p.numOpened)) + p.createReqs += uint64(numSessions) + p.mu.Unlock() + // Asynchronously create the initial sessions for the pool. + return p.sc.batchCreateSessions(numSessions, p) +} + +// sessionReady is executed by the SessionClient when a session has been +// created and is ready to use. This method will add the new session to the +// pool and decrease the number of sessions that is being created. +func (p *sessionPool) sessionReady(s *session) { + p.mu.Lock() + defer p.mu.Unlock() + // Set this pool as the home pool of the session and register it with the + // health checker. + s.pool = p + p.hc.register(s) + p.createReqs-- + // Insert the session at a random position in the pool to prevent all + // sessions affiliated with a channel to be placed at sequentially in the + // pool. + if p.idleList.Len() > 0 { + pos := rand.Intn(p.idleList.Len()) + before := p.idleList.Front() + for i := 0; i < pos; i++ { + before = before.Next() + } + s.setIdleList(p.idleList.InsertBefore(s, before)) + } else { + s.setIdleList(p.idleList.PushBack(s)) + } + // Notify other waiters blocking on session creation. + close(p.mayGetSession) + p.mayGetSession = make(chan struct{}) +} + +// sessionCreationFailed is called by the SessionClient when the creation of one +// or more requested sessions finished with an error. sessionCreationFailed will +// decrease the number of sessions being created and notify any waiters that +// the session creation failed. 
+func (p *sessionPool) sessionCreationFailed(err error, numSessions int32) { + p.mu.Lock() + defer p.mu.Unlock() + p.createReqs -= uint64(numSessions) + p.numOpened -= uint64(numSessions) + recordStat(context.Background(), OpenSessionCount, int64(p.numOpened)) + // Notify other waiters blocking on session creation. + close(p.mayGetSession) + p.mayGetSession = make(chan struct{}) +} + // isValid checks if the session pool is still valid. func (p *sessionPool) isValid() bool { if p == nil { @@ -524,12 +618,7 @@ func (p *sessionPool) createSession(ctx context.Context) (*session, error) { p.mayGetSession = make(chan struct{}) p.mu.Unlock() } - sc, err := p.getRPCClient() - if err != nil { - doneCreate(false) - return nil, err - } - s, err := createSession(ctx, sc, p.db, p.sessionLabels, p.md) + s, err := p.sc.createSession(ctx) if err != nil { doneCreate(false) // Should return error directly because of the previous retries on @@ -545,20 +634,6 @@ func (p *sessionPool) createSession(ctx context.Context) (*session, error) { return s, nil } -func createSession(ctx context.Context, sc *vkit.Client, db string, labels map[string]string, md metadata.MD) (*session, error) { - var s *session - sid, e := sc.CreateSession(ctx, &sppb.CreateSessionRequest{ - Database: db, - Session: &sppb.Session{Labels: labels}, - }) - if e != nil { - return nil, toSpannerError(e) - } - // If no error, construct the new session. - s = &session{valid: true, client: sc, id: sid.Name, createTime: time.Now(), md: md} - return s, nil -} - func (p *sessionPool) isHealthy(s *session) bool { if s.getNextCheck().Add(2 * p.hc.getInterval()).Before(time.Now()) { // TODO: figure out if we need to schedule a new healthcheck worker here. @@ -577,7 +652,6 @@ func (p *sessionPool) isHealthy(s *session) bool { // for read operations. func (p *sessionPool) take(ctx context.Context) (*sessionHandle, error) { trace.TracePrintf(ctx, nil, "Acquiring a read-only session") - ctx = contextWithOutgoingMetadata(ctx, p.md) for { var ( s *session @@ -649,7 +723,6 @@ func (p *sessionPool) take(ctx context.Context) (*sessionHandle, error) { // returned should be used for read write transactions. func (p *sessionPool) takeWriteSession(ctx context.Context) (*sessionHandle, error) { trace.TracePrintf(ctx, nil, "Acquiring a read-write session") - ctx = contextWithOutgoingMetadata(ctx, p.md) for { var ( s *session @@ -1004,7 +1077,7 @@ func (hc *healthChecker) worker(i int) { ws := getNextForTx() if ws != nil { ctx, cancel := context.WithTimeout(context.Background(), time.Second) - err := ws.prepareForWrite(contextWithOutgoingMetadata(ctx, hc.pool.md)) + err := ws.prepareForWrite(ctx) cancel() if err != nil { // Skip handling prepare error, session can be prepared in next @@ -1044,6 +1117,9 @@ func (hc *healthChecker) maintainer() { // Wait so that pool is ready. <-hc.ready + // A maintenance window is 10 iterations. The maintainer executes a loop + // every hc.sampleInterval, which defaults to 1 minute, which means that + // the default maintenance window is 10 minutes. windowSize := uint64(10) for iteration := uint64(0); ; iteration++ { diff --git a/spanner/session_test.go b/spanner/session_test.go index 449ab5bf3ded..17b8ea2f5638 100644 --- a/spanner/session_test.go +++ b/spanner/session_test.go @@ -25,8 +25,8 @@ import ( "testing" "time" - vkit "cloud.google.com/go/spanner/apiv1" . 
"cloud.google.com/go/spanner/internal/testutil" + sppb "google.golang.org/genproto/googleapis/spanner/v1" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -41,22 +41,39 @@ func TestSessionPoolConfigValidation(t *testing.T) { spc SessionPoolConfig err error }{ - { - SessionPoolConfig{}, - errNoRPCGetter(), - }, { SessionPoolConfig{ - getRPCClient: func() (*vkit.Client, error) { - return client.clients[0], nil - }, MinOpened: 10, MaxOpened: 5, }, errMinOpenedGTMaxOpened(5, 10), }, + { + SessionPoolConfig{ + WriteSessions: -0.1, + }, + errWriteFractionOutOfRange(-0.1), + }, + { + SessionPoolConfig{ + WriteSessions: 2.0, + }, + errWriteFractionOutOfRange(2.0), + }, + { + SessionPoolConfig{ + HealthCheckWorkers: -1, + }, + errHealthCheckWorkersNegative(-1), + }, + { + SessionPoolConfig{ + HealthCheckInterval: -time.Second, + }, + errHealthCheckIntervalNegative(-time.Second), + }, } { - if _, err := newSessionPool("mockdb", test.spc, nil); !testEqual(err, test.err) { + if _, err := newSessionPool(client.sc, test.spc); !testEqual(err, test.err) { t.Fatalf("want %v, got %v", test.err, err) } } @@ -459,8 +476,8 @@ func TestMinOpenedSessions(t *testing.T) { defer sp.mu.Unlock() // There should be still one session left in idle list due to the min open // sessions constraint. - if sp.idleList.Len() != 1 { - t.Fatalf("got %v sessions in idle list, want 1 %d", sp.idleList.Len(), sp.numOpened) + if sp.idleList.Len() != int(sp.MinOpened) { + t.Fatalf("got %v sessions in idle list, want %d", sp.idleList.Len(), sp.MinOpened) } } @@ -1100,7 +1117,7 @@ func TestMaintainer(t *testing.T) { }) } -// Tests that maintainer creates up to MinOpened connections. +// Tests that the session pool creates up to MinOpened connections. // // Historical context: This test also checks that a low // healthCheckSampleInterval does not prevent it from opening connections. @@ -1108,11 +1125,57 @@ func TestMaintainer(t *testing.T) { // creations to time out. That should not be considered a problem, but it // could cause the test case to fail if it happens too often. // See: https://github.com/googleapis/google-cloud-go/issues/1259 -func TestMaintainer_CreatesSessions(t *testing.T) { +func TestInit_CreatesSessions(t *testing.T) { t.Parallel() spc := SessionPoolConfig{ MinOpened: 10, MaxIdle: 10, + WriteSessions: 0.0, + healthCheckSampleInterval: 20 * time.Millisecond, + } + server, client, teardown := setupMockedTestServerWithConfig(t, + ClientConfig{ + SessionPoolConfig: spc, + NumChannels: 4, + }) + defer teardown() + sp := client.idleSessions + + timeout := time.After(4 * time.Second) + var numOpened int +loop: + for { + select { + case <-timeout: + t.Fatalf("timed out, got %d session(s), want %d", numOpened, spc.MinOpened) + default: + sp.mu.Lock() + numOpened = sp.idleList.Len() + sp.idleWriteList.Len() + sp.mu.Unlock() + if numOpened == 10 { + break loop + } + } + } + _, err := shouldHaveReceived(server.TestSpanner, []interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.BatchCreateSessionsRequest{}, + &sppb.BatchCreateSessionsRequest{}, + &sppb.BatchCreateSessionsRequest{}, + }) + if err != nil { + t.Fatal(err) + } +} + +// Tests that the session pool with a MinSessions>0 also prepares WriteSessions +// sessions. 
+func TestInit_PreparesSessions(t *testing.T) { + t.Parallel() + spc := SessionPoolConfig{ + MinOpened: 10, + MaxIdle: 10, + WriteSessions: 0.5, healthCheckSampleInterval: 20 * time.Millisecond, } _, client, teardown := setupMockedTestServerWithConfig(t, @@ -1124,17 +1187,18 @@ func TestMaintainer_CreatesSessions(t *testing.T) { timeoutAmt := 4 * time.Second timeout := time.After(timeoutAmt) - var numOpened uint64 + var numPrepared int + want := int(spc.WriteSessions * float64(spc.MinOpened)) loop: for { select { case <-timeout: - t.Fatalf("timed out after %v, got %d session(s), want %d", timeoutAmt, numOpened, spc.MinOpened) + t.Fatalf("timed out after %v, got %d write-prepared session(s), want %d", timeoutAmt, numPrepared, want) default: sp.mu.Lock() - numOpened = sp.numOpened + numPrepared = sp.idleWriteList.Len() sp.mu.Unlock() - if numOpened == 10 { + if numPrepared == want { break loop } } diff --git a/spanner/sessionclient.go b/spanner/sessionclient.go new file mode 100644 index 000000000000..b3857a0ae6a5 --- /dev/null +++ b/spanner/sessionclient.go @@ -0,0 +1,233 @@ +/* +Copyright 2019 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "context" + "fmt" + "sync" + "time" + + "cloud.google.com/go/internal/trace" + vkit "cloud.google.com/go/spanner/apiv1" + sppb "google.golang.org/genproto/googleapis/spanner/v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" +) + +// sessionConsumer is passed to the batchCreateSessions method and will receive +// the sessions that are created as they become available. A sessionConsumer +// implementation must be safe for concurrent use. +// +// The interface is implemented by sessionPool and is used for testing the +// sessionClient. +type sessionConsumer interface { + // sessionReady is called when a session has been created and is ready for + // use. + sessionReady(s *session) + + // sessionCreationFailed is called when the creation of a sub-batch of + // sessions failed. The numSessions argument specifies the number of + // sessions that could not be created as a result of this error. A + // consumer may receive multiple errors per batch. + sessionCreationFailed(err error, numSessions int32) +} + +// sessionClient creates sessions for a database, either in batches or one at a +// time. Each session will be affiliated with a gRPC channel. sessionClient +// will ensure that the sessions that are created are evenly distributed over +// all available channels. +type sessionClient struct { + mu sync.Mutex + rr int + closed bool + + gapicClients []*vkit.Client + database string + sessionLabels map[string]string + md metadata.MD + batchTimeout time.Duration +} + +// newSessionClient creates a session client to use for a database. 
+func newSessionClient(gapicClients []*vkit.Client, database string, sessionLabels map[string]string, md metadata.MD) *sessionClient { + return &sessionClient{ + gapicClients: gapicClients, + database: database, + sessionLabels: sessionLabels, + md: md, + batchTimeout: time.Minute, + } +} + +func (sc *sessionClient) close() error { + sc.mu.Lock() + defer sc.mu.Unlock() + sc.closed = true + var errs []error + for _, gpc := range sc.gapicClients { + if err := gpc.Close(); err != nil { + errs = append(errs, err) + } + } + switch len(errs) { + case 0: + return nil + case 1: + return errs[0] + default: + return fmt.Errorf("closing gapic clients returned multiple errors: %v", errs) + } +} + +// createSession creates one session for the database of the sessionClient. The +// session is created using one synchronous RPC. +func (sc *sessionClient) createSession(ctx context.Context) (*session, error) { + ctx = contextWithOutgoingMetadata(ctx, sc.md) + sc.mu.Lock() + if sc.closed { + return nil, spannerErrorf(codes.FailedPrecondition, "SessionClient is closed") + } + client := sc.rrNextGapicClientLocked() + sc.mu.Unlock() + sid, err := client.CreateSession(ctx, &sppb.CreateSessionRequest{ + Database: sc.database, + Session: &sppb.Session{Labels: sc.sessionLabels}, + }) + if err != nil { + return nil, toSpannerError(err) + } + return &session{valid: true, client: client, id: sid.Name, createTime: time.Now(), md: sc.md}, nil +} + +// batchCreateSessions creates a batch of sessions for the database of the +// sessionClient and returns these to the given sessionConsumer. +// +// createSessionCount is the number of sessions that should be created. The +// sessionConsumer is guaranteed to receive the requested number of sessions if +// no error occurs. If one or more errors occur, the sessionConsumer will +// receive any number of sessions + any number of errors, where each error will +// include the number of sessions that could not be created as a result of the +// error. The sum of returned sessions and errored sessions will be equal to +// the number of requested sessions. +func (sc *sessionClient) batchCreateSessions(createSessionCount int32, consumer sessionConsumer) error { + // The sessions that we create should be evenly distributed over all the + // channels (gapic clients) that are used by the client. Each gapic client + // will do a request for a fraction of the total. + sessionCountPerChannel := createSessionCount / int32(len(sc.gapicClients)) + // The remainder of the calculation will be added to the number of sessions + // that will be created for the first channel, to ensure that we create the + // exact number of requested sessions. + remainder := createSessionCount % int32(len(sc.gapicClients)) + sc.mu.Lock() + defer sc.mu.Unlock() + if sc.closed { + return spannerErrorf(codes.FailedPrecondition, "SessionClient is closed") + } + // Spread the session creation over all available gRPC channels. Spanner + // will maintain server side caches for a session on the gRPC channel that + // is used by the session. A session should therefore always use the same + // channel, and the sessions should be as evenly distributed as possible + // over the channels. + for i := 0; i < len(sc.gapicClients); i++ { + client := sc.rrNextGapicClientLocked() + // Determine the number of sessions that should be created for this + // channel. The createCount for the first channel will be increased + // with the remainder of the division of the total number of sessions + // with the number of channels. 
All other channels will just use the + // result of the division over all channels. + createCountForChannel := sessionCountPerChannel + if i == 0 { + // We add the remainder to the first gRPC channel we use. We could + // also spread the remainder over all channels, but this ensures + // that small batches of sessions (i.e. less than numChannels) are + // created in one RPC. + createCountForChannel += remainder + } + if createCountForChannel > 0 { + go sc.executeBatchCreateSessions(client, createCountForChannel, sc.sessionLabels, sc.md, consumer) + } + } + return nil +} + +// executeBatchCreateSessions executes the gRPC call for creating a batch of +// sessions. +func (sc *sessionClient) executeBatchCreateSessions(client *vkit.Client, createCount int32, labels map[string]string, md metadata.MD, consumer sessionConsumer) { + ctx, cancel := context.WithTimeout(context.Background(), sc.batchTimeout) + defer cancel() + ctx = contextWithOutgoingMetadata(ctx, sc.md) + + ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.BatchCreateSessions") + defer func() { trace.EndSpan(ctx, nil) }() + trace.TracePrintf(ctx, nil, "Creating a batch of %d sessions", createCount) + remainingCreateCount := createCount + for { + sc.mu.Lock() + closed := sc.closed + sc.mu.Unlock() + if closed { + err := spannerErrorf(codes.Canceled, "Session client closed") + trace.TracePrintf(ctx, nil, "Session client closed while creating a batch of %d sessions: %v", createCount, err) + consumer.sessionCreationFailed(err, remainingCreateCount) + break + } + if ctx.Err() != nil { + trace.TracePrintf(ctx, nil, "Context error while creating a batch of %d sessions: %v", createCount, ctx.Err()) + consumer.sessionCreationFailed(ctx.Err(), remainingCreateCount) + break + } + response, err := client.BatchCreateSessions(ctx, &sppb.BatchCreateSessionsRequest{ + SessionCount: remainingCreateCount, + Database: sc.database, + SessionTemplate: &sppb.Session{Labels: labels}, + }) + if err != nil { + trace.TracePrintf(ctx, nil, "Error creating a batch of %d sessions: %v", remainingCreateCount, err) + consumer.sessionCreationFailed(err, remainingCreateCount) + break + } + actuallyCreated := int32(len(response.Session)) + trace.TracePrintf(ctx, nil, "Received a batch of %d sessions", actuallyCreated) + for _, s := range response.Session { + consumer.sessionReady(&session{valid: true, client: client, id: s.Name, createTime: time.Now(), md: md}) + } + if actuallyCreated < remainingCreateCount { + // Spanner could return less sessions than requested. In that case, we + // should do another call using the same gRPC channel. + remainingCreateCount -= actuallyCreated + } else { + trace.TracePrintf(ctx, nil, "Finished creating %d sessions", createCount) + break + } + } +} + +func (sc *sessionClient) sessionWithID(id string) *session { + sc.mu.Lock() + defer sc.mu.Unlock() + return &session{valid: true, client: sc.rrNextGapicClientLocked(), id: id, createTime: time.Now(), md: sc.md} +} + +// rrNextGapicClientLocked returns the next gRPC client to use for session creation. The +// client is set on the session, and used by all subsequent gRPC calls on the +// session. Using the same channel for all gRPC calls for a session ensures the +// optimal usage of server side caches. 
+func (sc *sessionClient) rrNextGapicClientLocked() *vkit.Client { + sc.rr = (sc.rr + 1) % len(sc.gapicClients) + return sc.gapicClients[sc.rr] +} diff --git a/spanner/sessionclient_test.go b/spanner/sessionclient_test.go new file mode 100644 index 000000000000..82be7fd85a23 --- /dev/null +++ b/spanner/sessionclient_test.go @@ -0,0 +1,315 @@ +/* +Copyright 2019 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package spanner + +import ( + "context" + "sync" + "testing" + "time" + + vkit "cloud.google.com/go/spanner/apiv1" + . "cloud.google.com/go/spanner/internal/testutil" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type testSessionCreateError struct { + err error + num int32 +} + +type testConsumer struct { + numExpected int32 + + mu sync.Mutex + sessions []*session + errors []*testSessionCreateError + numErr int32 + + receivedAll chan struct{} +} + +func (tc *testConsumer) sessionReady(s *session) { + tc.mu.Lock() + defer tc.mu.Unlock() + tc.sessions = append(tc.sessions, s) + tc.checkReceivedAll() +} + +func (tc *testConsumer) sessionCreationFailed(err error, num int32) { + tc.mu.Lock() + defer tc.mu.Unlock() + tc.errors = append(tc.errors, &testSessionCreateError{ + err: err, + num: num, + }) + tc.numErr += num + tc.checkReceivedAll() +} + +func (tc *testConsumer) checkReceivedAll() { + if int32(len(tc.sessions))+tc.numErr == tc.numExpected { + close(tc.receivedAll) + } +} + +func newTestConsumer(numExpected int32) *testConsumer { + return &testConsumer{ + numExpected: numExpected, + receivedAll: make(chan struct{}), + } +} + +func TestCreateAndCloseSession(t *testing.T) { + t.Parallel() + + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 0, + MaxOpened: 100, + }, + }) + defer teardown() + + s, err := client.sc.createSession(context.Background()) + if err != nil { + t.Fatalf("batch.next() return error mismatch\ngot: %v\nwant: nil", err) + } + if s == nil { + t.Fatalf("batch.next() return value mismatch\ngot: %v\nwant: any session", s) + } + if server.TestSpanner.TotalSessionsCreated() != 1 { + t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", server.TestSpanner.TotalSessionsCreated(), 1) + } + s.delete(context.Background()) + if server.TestSpanner.TotalSessionsDeleted() != 1 { + t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant: %v", server.TestSpanner.TotalSessionsDeleted(), 1) + } +} + +func TestBatchCreateAndCloseSession(t *testing.T) { + t.Parallel() + + numSessions := int32(100) + server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t) + defer serverTeardown() + for numChannels := 1; numChannels <= 32; numChannels *= 2 { + prevCreated := server.TestSpanner.TotalSessionsCreated() + prevDeleted := server.TestSpanner.TotalSessionsDeleted() + client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{ + NumChannels: numChannels, + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 0, + MaxOpened: 
400, + }}, opts...) + if err != nil { + t.Fatal(err) + } + consumer := newTestConsumer(numSessions) + client.sc.batchCreateSessions(numSessions, consumer) + <-consumer.receivedAll + if len(consumer.sessions) != int(numSessions) { + t.Fatalf("returned number of sessions mismatch\ngot: %v\nwant: %v", len(consumer.sessions), numSessions) + } + created := server.TestSpanner.TotalSessionsCreated() - prevCreated + if created != uint(numSessions) { + t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", created, numSessions) + } + // Check that all channels are used evenly. + channelCounts := make(map[*vkit.Client]int32) + for _, s := range consumer.sessions { + channelCounts[s.client]++ + } + if len(channelCounts) != numChannels { + t.Fatalf("number of channels used mismatch\ngot: %v\nwant: %v", len(channelCounts), numChannels) + } + for _, c := range channelCounts { + if c < numSessions/int32(numChannels) || c > numSessions/int32(numChannels)+(numSessions%int32(numChannels)) { + t.Fatalf("channel used an unexpected number of times\ngot: %v\nwant between %v and %v", c, numSessions/int32(numChannels), numSessions/int32(numChannels)+1) + } + } + // Delete the sessions. + for _, s := range consumer.sessions { + s.delete(context.Background()) + } + deleted := server.TestSpanner.TotalSessionsDeleted() - prevDeleted + if deleted != uint(numSessions) { + t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant %v", deleted, numSessions) + } + client.Close() + } +} + +func TestBatchCreateSessionsWithExceptions(t *testing.T) { + t.Parallel() + + numSessions := int32(100) + server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t) + defer serverTeardown() + + // Run the test with everything between 1 and numChannels errors. + for numErrors := int32(1); numErrors <= numChannels; numErrors++ { + // Make sure that the error is not always the first call. + for firstErrorAt := numErrors - 1; firstErrorAt < numChannels-numErrors+1; firstErrorAt++ { + client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{ + NumChannels: numChannels, + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 0, + MaxOpened: 400, + }}, opts...) + if err != nil { + t.Fatal(err) + } + // Register the errors on the server. 
+ errors := make([]error, numErrors+firstErrorAt) + for i := firstErrorAt; i < numErrors+firstErrorAt; i++ { + errors[i] = spannerErrorf(codes.FailedPrecondition, "session creation failed") + } + server.TestSpanner.PutExecutionTime(MethodBatchCreateSession, SimulatedExecutionTime{ + Errors: errors, + }) + consumer := newTestConsumer(numSessions) + client.sc.batchCreateSessions(numSessions, consumer) + <-consumer.receivedAll + + sessionsReturned := int32(len(consumer.sessions)) + if int32(len(consumer.errors)) != numErrors { + t.Fatalf("Error count mismatch\nGot: %d\nWant: %d", len(consumer.errors), numErrors) + } + for _, e := range consumer.errors { + if g, w := status.Code(e.err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\ngot: %v\nwant: %v", g, w) + } + } + maxExpectedSessions := numSessions - numErrors*(numSessions/numChannels) + minExpectedSessions := numSessions - numErrors*(numSessions/numChannels+1) + if sessionsReturned < minExpectedSessions || sessionsReturned > maxExpectedSessions { + t.Fatalf("session count mismatch\ngot: %v\nwant between %v and %v", sessionsReturned, minExpectedSessions, maxExpectedSessions) + } + client.Close() + } + } +} + +func TestBatchCreateSessions_ServerReturnsLessThanRequestedSessions(t *testing.T) { + t.Parallel() + + numChannels := 4 + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ + NumChannels: numChannels, + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 0, + MaxOpened: 100, + }, + }) + defer teardown() + // Ensure that the server will never return more than 10 sessions per batch + // create request. + server.TestSpanner.SetMaxSessionsReturnedByServerPerBatchRequest(10) + numSessions := int32(100) + // Request a batch of sessions that is larger than will be returned by the + // server in one request. The server will return at most 10 sessions per + // request. The sessionCreator will spread these requests over the 4 + // channels that are available, i.e. do requests for 25 sessions in each + // request. The batch should still return 100 sessions. + consumer := newTestConsumer(numSessions) + client.sc.batchCreateSessions(numSessions, consumer) + <-consumer.receivedAll + if len(consumer.errors) > 0 { + t.Fatalf("Error count mismatch\nGot: %d\nWant: %d", len(consumer.errors), 0) + } + returnedSessionCount := int32(len(consumer.sessions)) + if returnedSessionCount != numSessions { + t.Fatalf("Returned sessions mismatch\nGot: %v\nWant: %v", returnedSessionCount, numSessions) + } +} + +func TestBatchCreateSessions_ServerExhausted(t *testing.T) { + t.Parallel() + + numChannels := 4 + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ + NumChannels: numChannels, + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 0, + MaxOpened: 100, + }, + }) + defer teardown() + numSessions := int32(100) + maxSessions := int32(50) + // Ensure that the server will never return more than 50 sessions in total. + server.TestSpanner.SetMaxSessionsReturnedByServerInTotal(maxSessions) + consumer := newTestConsumer(numSessions) + client.sc.batchCreateSessions(numSessions, consumer) + <-consumer.receivedAll + // Session creation should end with at least one RESOURCE_EXHAUSTED error. 
+	if len(consumer.errors) == 0 {
+		t.Fatalf("Error count mismatch\nGot: %d\nWant: > %d", len(consumer.errors), 0)
+	}
+	for _, e := range consumer.errors {
+		if g, w := status.Code(e.err), codes.ResourceExhausted; g != w {
+			t.Fatalf("Error code mismatch\nGot: %v\nWant: %v", g, w)
+		}
+	}
+	// The number of returned sessions should be equal to the max of the
+	// server.
+	returnedSessionCount := int32(len(consumer.sessions))
+	if returnedSessionCount != maxSessions {
+		t.Fatalf("Returned sessions mismatch\nGot: %v\nWant: %v", returnedSessionCount, maxSessions)
+	}
+	if consumer.numErr != (numSessions - maxSessions) {
+		t.Fatalf("Num errored sessions mismatch\nGot: %v\nWant: %v", consumer.numErr, numSessions-maxSessions)
+	}
+}
+
+func TestBatchCreateSessions_WithTimeout(t *testing.T) {
+	t.Parallel()
+
+	numSessions := int32(100)
+	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
+	defer serverTeardown()
+	server.TestSpanner.PutExecutionTime(MethodBatchCreateSession, SimulatedExecutionTime{
+		MinimumExecutionTime: time.Second,
+	})
+	client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{
+		SessionPoolConfig: SessionPoolConfig{
+			MinOpened: 0,
+			MaxOpened: 400,
+		}}, opts...)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	client.sc.batchTimeout = 10 * time.Millisecond
+	consumer := newTestConsumer(numSessions)
+	client.sc.batchCreateSessions(numSessions, consumer)
+	<-consumer.receivedAll
+	if len(consumer.sessions) > 0 {
+		t.Fatalf("Returned number of sessions mismatch\ngot: %v\nwant: %v", len(consumer.sessions), 0)
+	}
+	if len(consumer.errors) != numChannels {
+		t.Fatalf("Returned number of errors mismatch\ngot: %v\nwant: %v", len(consumer.errors), numChannels)
+	}
+	for _, e := range consumer.errors {
+		if g, w := status.Code(e.err), codes.DeadlineExceeded; g != w {
+			t.Fatalf("Error code mismatch\ngot: %v\nwant: %v", g, w)
+		}
+	}
+	client.Close()
+}
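
For reference, a brief usage sketch of the session pool defaults this change introduces. This is not part of the patch: the database path, pool numbers, and error handling below are illustrative placeholders, and only identifiers that appear in this diff (NewClient, NewClientWithConfig, ClientConfig, SessionPoolConfig, DefaultSessionPoolConfig) are assumed.

// Usage sketch only; not part of this patch.
package main

import (
	"context"
	"log"

	"cloud.google.com/go/spanner"
)

func main() {
	ctx := context.Background()

	// NewClient now applies DefaultSessionPoolConfig (MinOpened: 100,
	// WriteSessions: 0.2, ...), so the pool is filled eagerly via
	// BatchCreateSessions when the client is created.
	client, err := spanner.NewClient(ctx, "projects/my-project/instances/my-instance/databases/my-db")
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	// Callers that want different pool behavior (for example no eagerly
	// created sessions, as in several tests in this change) can still
	// pass an explicit SessionPoolConfig.
	custom, err := spanner.NewClientWithConfig(ctx,
		"projects/my-project/instances/my-instance/databases/my-db",
		spanner.ClientConfig{SessionPoolConfig: spanner.SessionPoolConfig{
			MinOpened:     0,
			MaxOpened:     400,
			WriteSessions: 0.25,
		}})
	if err != nil {
		log.Fatal(err)
	}
	defer custom.Close()
}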