diff --git a/spanner/errors.go b/spanner/errors.go
index af03c321ae4c..547a6af0ca59 100644
--- a/spanner/errors.go
+++ b/spanner/errors.go
@@ -41,6 +41,9 @@ type Error struct {
 	Desc string
 	// trailers are the trailers returned in the response, if any.
 	trailers metadata.MD
+	// additionalInformation optionally contains any additional information
+	// about the error.
+	additionalInformation string
 }
 
 // Error implements error.Error.
@@ -49,7 +52,10 @@ func (e *Error) Error() string {
 		return fmt.Sprintf("spanner: OK")
 	}
 	code := ErrCode(e)
-	return fmt.Sprintf("spanner: code = %q, desc = %q", code, e.Desc)
+	if e.additionalInformation == "" {
+		return fmt.Sprintf("spanner: code = %q, desc = %q", code, e.Desc)
+	}
+	return fmt.Sprintf("spanner: code = %q, desc = %q, additional information = %s", code, e.Desc, e.additionalInformation)
 }
 
 // Unwrap returns the wrapped error (if any).
@@ -115,11 +121,11 @@ func toSpannerErrorWithMetadata(err error, trailers metadata.MD) error {
 	}
 	switch {
 	case err == context.DeadlineExceeded || err == context.Canceled:
-		return &Error{status.FromContextError(err).Code(), status.FromContextError(err).Err(), err.Error(), trailers}
+		return &Error{status.FromContextError(err).Code(), status.FromContextError(err).Err(), err.Error(), trailers, ""}
 	case status.Code(err) == codes.Unknown:
-		return &Error{codes.Unknown, err, err.Error(), trailers}
+		return &Error{codes.Unknown, err, err.Error(), trailers, ""}
 	default:
-		return &Error{status.Convert(err).Code(), err, status.Convert(err).Message(), trailers}
+		return &Error{status.Convert(err).Code(), err, status.Convert(err).Message(), trailers, ""}
 	}
 }
 
diff --git a/spanner/session.go b/spanner/session.go
index 573a00801f91..6545c8afc5aa 100644
--- a/spanner/session.go
+++ b/spanner/session.go
@@ -24,6 +24,7 @@ import (
 	"log"
 	"math"
 	"math/rand"
+	"runtime/debug"
 	"strings"
 	"sync"
 	"time"
@@ -44,19 +45,39 @@ type sessionHandle struct {
 	// session is a pointer to a session object. Transactions never need to
 	// access it directly.
 	session *session
+	// checkoutTime is the time the session was checked out of the pool.
+	checkoutTime time.Time
+	// trackedSessionHandle is the linked list node which links the session to
+	// the list of tracked session handles. trackedSessionHandle is only set if
+	// TrackSessionHandles has been enabled in the session pool configuration.
+	trackedSessionHandle *list.Element
+	// stack is the call stack of the goroutine that checked out the session
+	// from the pool. This can be used to track down session leak problems.
+	stack []byte
 }
 
 // recycle gives the inner session object back to its home session pool. It is
 // safe to call recycle multiple times but only the first one would take effect.
 func (sh *sessionHandle) recycle() {
 	sh.mu.Lock()
-	defer sh.mu.Unlock()
 	if sh.session == nil {
 		// sessionHandle has already been recycled.
+		sh.mu.Unlock()
 		return
 	}
+	p := sh.session.pool
+	tracked := sh.trackedSessionHandle
 	sh.session.recycle()
 	sh.session = nil
+	sh.trackedSessionHandle = nil
+	sh.checkoutTime = time.Time{}
+	sh.stack = nil
+	sh.mu.Unlock()
+	if tracked != nil {
+		p.mu.Lock()
+		p.trackedSessionHandles.Remove(tracked)
+		p.mu.Unlock()
+	}
 }
 
 // getID gets the Cloud Spanner session ID from the internal session object.
@@ -109,8 +130,18 @@ func (sh *sessionHandle) getTransactionID() transactionID {
 func (sh *sessionHandle) destroy() {
 	sh.mu.Lock()
 	s := sh.session
+	tracked := sh.trackedSessionHandle
 	sh.session = nil
+	sh.trackedSessionHandle = nil
+	sh.checkoutTime = time.Time{}
+	sh.stack = nil
 	sh.mu.Unlock()
+	if tracked != nil {
+		p := s.pool
+		p.mu.Lock()
+		p.trackedSessionHandles.Remove(tracked)
+		p.mu.Unlock()
+	}
 	if s == nil {
 		// sessionHandle has already been destroyed..
 		return
@@ -376,6 +407,13 @@ type SessionPoolConfig struct {
 	// Defaults to 5m.
 	HealthCheckInterval time.Duration
 
+	// TrackSessionHandles determines whether the session pool will keep track
+	// of the stacktrace of the goroutines that take sessions from the pool.
+	// This setting can be used to track down session leak problems.
+	//
+	// Defaults to false.
+	TrackSessionHandles bool
+
 	// healthCheckSampleInterval is how often the health checker samples live
 	// session (for use in maintaining session pool size).
 	//
@@ -450,6 +488,10 @@ type sessionPool struct {
 	valid bool
 	// sc is used to create the sessions for the pool.
 	sc *sessionClient
+	// trackedSessionHandles contains all session handles that have been
+	// checked out of the pool. The list is only filled if TrackSessionHandles
+	// has been enabled.
+	trackedSessionHandles list.List
 	// idleList caches idle session IDs. Session IDs in this list can be
 	// allocated for use.
 	idleList list.List
@@ -621,6 +663,68 @@ var errInvalidSessionPool = spannerErrorf(codes.InvalidArgument, "invalid sessio
 // sessionPool.take().
 var errGetSessionTimeout = spannerErrorf(codes.Canceled, "timeout / context canceled during getting session")
 
+// newSessionHandle creates a new session handle for the given session for this
+// session pool. The session handle will also hold a copy of the current call
+// stack if the session pool has been configured to track the call stacks of
+// sessions being checked out of the pool.
+func (p *sessionPool) newSessionHandle(s *session) (sh *sessionHandle) {
+	sh = &sessionHandle{session: s, checkoutTime: time.Now()}
+	if p.TrackSessionHandles {
+		p.mu.Lock()
+		sh.trackedSessionHandle = p.trackedSessionHandles.PushBack(sh)
+		p.mu.Unlock()
+		sh.stack = debug.Stack()
+	}
+	return sh
+}
+
+// errGetSessionTimeout returns an error for a context timeout during
+// sessionPool.take().
+func (p *sessionPool) errGetSessionTimeout() error {
+	if p.TrackSessionHandles {
+		return p.errGetSessionTimeoutWithTrackedSessionHandles()
+	}
+	return p.errGetBasicSessionTimeout()
+}
+
+// errGetBasicSessionTimeout returns an error for a context timeout during
+// sessionPool.take() without any tracked sessionHandles.
+func (p *sessionPool) errGetBasicSessionTimeout() error {
+	return spannerErrorf(codes.Canceled, "timeout / context canceled during getting session.\n"+
+		"Enable SessionPoolConfig.TrackSessionHandles if you suspect a session leak to get more information about the checked out sessions.")
+}
+
+// errGetSessionTimeoutWithTrackedSessionHandles returns an error for a context
+// timeout during sessionPool.take() that includes a stacktrace of each checked
+// out session handle.
+func (p *sessionPool) errGetSessionTimeoutWithTrackedSessionHandles() error {
+	err := spannerErrorf(codes.Canceled, "timeout / context canceled during getting session.")
+	err.(*Error).additionalInformation = p.getTrackedSessionHandleStacksLocked()
+	return err
+}
+
+// getTrackedSessionHandleStacksLocked returns a string containing the
+// stacktrace of all currently checked out sessions of the pool.
+// This method acquires p.mu itself; the caller must not hold p.mu.
+func (p *sessionPool) getTrackedSessionHandleStacksLocked() string {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	stackTraces := ""
+	i := 1
+	element := p.trackedSessionHandles.Front()
+	for element != nil {
+		sh := element.Value.(*sessionHandle)
+		sh.mu.Lock()
+		if sh.stack != nil {
+			stackTraces = fmt.Sprintf("%s\n\nSession %d checked out of pool at %s by goroutine:\n%s", stackTraces, i, sh.checkoutTime.Format(time.RFC3339), sh.stack)
+		}
+		sh.mu.Unlock()
+		element = element.Next()
+		i++
+	}
+	return stackTraces
+}
+
 // shouldPrepareWriteLocked returns true if we should prepare more sessions for write.
 func (p *sessionPool) shouldPrepareWriteLocked() bool {
 	return !p.disableBackgroundPrepareSessions && float64(p.numOpened)*p.WriteSessions > float64(p.idleWriteList.Len()+int(p.prepareReqs))
@@ -710,7 +814,7 @@ func (p *sessionPool) take(ctx context.Context) (*sessionHandle, error) {
 			if !p.isHealthy(s) {
 				continue
 			}
-			return &sessionHandle{session: s}, nil
+			return p.newSessionHandle(s), nil
 		}
 
 		// Idle list is empty, block if session pool has reached max session
@@ -722,7 +826,7 @@ func (p *sessionPool) take(ctx context.Context) (*sessionHandle, error) {
 			select {
 			case <-ctx.Done():
 				trace.TracePrintf(ctx, nil, "Context done waiting for session")
-				return nil, errGetSessionTimeout
+				return nil, p.errGetSessionTimeout()
 			case <-mayGetSession:
 			}
 			continue
@@ -743,7 +847,7 @@ func (p *sessionPool) take(ctx context.Context) (*sessionHandle, error) {
 		}
 		trace.TracePrintf(ctx, map[string]interface{}{"sessionID": s.getID()},
 			"Created session")
-		return &sessionHandle{session: s}, nil
+		return p.newSessionHandle(s), nil
 	}
 }
 
@@ -795,7 +899,7 @@ func (p *sessionPool) takeWriteSession(ctx context.Context) (*sessionHandle, err
 			select {
 			case <-ctx.Done():
 				trace.TracePrintf(ctx, nil, "Context done waiting for session")
-				return nil, errGetSessionTimeout
+				return nil, p.errGetSessionTimeout()
 			case <-mayGetSession:
 			}
 			continue
@@ -825,7 +929,7 @@ func (p *sessionPool) takeWriteSession(ctx context.Context) (*sessionHandle, err
 				return nil, toSpannerError(err)
 			}
 		}
-		return &sessionHandle{session: s}, nil
+		return p.newSessionHandle(s), nil
 	}
 }
 
diff --git a/spanner/session_test.go b/spanner/session_test.go
index 0016709289d8..b1377840e040 100644
--- a/spanner/session_test.go
+++ b/spanner/session_test.go
@@ -25,10 +25,12 @@ import (
 	"log"
 	"math/rand"
 	"os"
+	"strings"
 	"testing"
 	"time"
 
 	. "cloud.google.com/go/spanner/internal/testutil"
+	"google.golang.org/api/iterator"
 	sppb "google.golang.org/genproto/googleapis/spanner/v1"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"
@@ -464,6 +466,86 @@ func TestTakeFromIdleWriteListChecked(t *testing.T) {
 	}
 }
 
+// TestSessionLeak tests leaking a session and getting the stack of the
+// goroutine that leaked it.
+func TestSessionLeak(t *testing.T) {
+	t.Parallel()
+	ctx := context.Background()
+
+	_, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
+		SessionPoolConfig: SessionPoolConfig{
+			TrackSessionHandles: true,
+			MinOpened:           0,
+			MaxOpened:           1,
+		},
+	})
+	defer teardown()
+
+	// Execute a query without calling rowIterator.Stop. This will cause the
+	// session not to be returned to the pool.
+	single := client.Single()
+	iter := single.Query(ctx, NewStatement(SelectFooFromBar))
+	for {
+		_, err := iter.Next()
+		if err == iterator.Done {
+			break
+		}
+		if err != nil {
+			t.Fatalf("Got unexpected error while iterating results: %v\n", err)
+		}
+	}
+	// The session should not have been returned to the pool.
+	if g, w := client.idleSessions.idleList.Len(), 0; g != w {
+		t.Fatalf("Idle sessions count mismatch\nGot: %d\nWant: %d\n", g, w)
+	}
+	// The checked out session should contain a stack trace.
+	if single.sh.stack == nil {
+		t.Fatalf("Missing stacktrace from session handle")
+	}
+	stack := fmt.Sprintf("%s", single.sh.stack)
+	testMethod := "TestSessionLeak"
+	if !strings.Contains(stack, testMethod) {
+		t.Fatalf("Stacktrace does not contain '%s'\nGot: %s", testMethod, stack)
+	}
+	// Return the session to the pool.
+	iter.Stop()
+	// The stack should now have been removed from the session handle.
+	if single.sh.stack != nil {
+		t.Fatalf("Got unexpected stacktrace in session handle: %s", single.sh.stack)
+	}
+
+	// Do another query and hold on to the session.
+	single = client.Single()
+	iter = single.Query(ctx, NewStatement(SelectFooFromBar))
+	for {
+		_, err := iter.Next()
+		if err == iterator.Done {
+			break
+		}
+		if err != nil {
+			t.Fatalf("Got unexpected error while iterating results: %v\n", err)
+		}
+	}
+	// Try to do another query. This will fail as MaxOpened=1.
+	ctxWithTimeout, cancel := context.WithTimeout(ctx, time.Millisecond*10)
+	defer cancel()
+	single2 := client.Single()
+	iter2 := single2.Query(ctxWithTimeout, NewStatement(SelectFooFromBar))
+	_, gotErr := iter2.Next()
+	wantErr := client.idleSessions.errGetSessionTimeoutWithTrackedSessionHandles()
+	// The error should contain the stacktraces of all the checked out
+	// sessions.
+	if !testEqual(gotErr, wantErr) {
+		t.Fatalf("Error mismatch on iterating result set.\nGot: %v\nWant: %v\n", gotErr, wantErr)
+	}
+	if !strings.Contains(gotErr.Error(), testMethod) {
+		t.Fatalf("Error does not contain '%s'\nGot: %s", testMethod, gotErr.Error())
+	}
+	// Close iterators to check sessions back into the pool before closing.
+	iter2.Stop()
+	iter.Stop()
+}
+
 // TestMaxOpenedSessions tests max open sessions constraint.
 func TestMaxOpenedSessions(t *testing.T) {
 	t.Parallel()
@@ -486,7 +568,7 @@ func TestMaxOpenedSessions(t *testing.T) {
 	ctx2, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
 	defer cancel()
 	_, gotErr := sp.take(ctx2)
-	if wantErr := errGetSessionTimeout; gotErr != wantErr {
+	if wantErr := sp.errGetBasicSessionTimeout(); !testEqual(gotErr, wantErr) {
 		t.Fatalf("the second session retrival returns error %v, want %v", gotErr, wantErr)
 	}
 	doneWaiting := make(chan struct{})
@@ -619,7 +701,7 @@ func TestMaxBurst(t *testing.T) {
 	_, gotErr := sp.take(ctx2)
 
 	// Since MaxBurst == 1, the second session request should block.
-	if wantErr := errGetSessionTimeout; gotErr != wantErr {
+	if wantErr := sp.errGetBasicSessionTimeout(); !testEqual(gotErr, wantErr) {
 		t.Fatalf("session retrival returns error %v, want %v", gotErr, wantErr)
 	}
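For reviewers who want to try the new option: below is a minimal sketch of how an application opts in to session handle tracking. The database path and the `SELECT 1` query are placeholders and not part of this change; the only new surface here is `SessionPoolConfig.TrackSessionHandles`, everything else is the existing public API.

```go
package main

import (
	"context"
	"log"
	"time"

	"cloud.google.com/go/spanner"
	"google.golang.org/api/iterator"
)

func main() {
	ctx := context.Background()

	// Opt in to session handle tracking. The database path below is a
	// placeholder for a real "projects/.../instances/.../databases/..." path.
	client, err := spanner.NewClientWithConfig(ctx, "projects/p/instances/i/databases/d",
		spanner.ClientConfig{
			SessionPoolConfig: spanner.SessionPoolConfig{
				TrackSessionHandles: true,
			},
		})
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	iter := client.Single().Query(ctx, spanner.NewStatement("SELECT 1"))
	// Forgetting iter.Stop() keeps the session checked out. With
	// TrackSessionHandles enabled, a later pool-exhaustion timeout reports
	// the checkout time and creation stacktrace of every such session in the
	// new "additional information" part of the error message.
	defer iter.Stop()
	for {
		row, err := iter.Next()
		if err == iterator.Done {
			break
		}
		if err != nil {
			log.Fatal(err)
		}
		_ = row
	}
}
```

Capturing `debug.Stack()` on every checkout has a cost, which is presumably why the option defaults to false and the basic timeout error merely suggests enabling it.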