diff --git a/server.go b/server.go index 2fa694d555e..682fa1831ec 100644 --- a/server.go +++ b/server.go @@ -144,7 +144,8 @@ type Server struct { channelzID *channelz.Identifier czData *channelzData - serverWorkerChannel chan func() + serverWorkerChannel chan func() + serverWorkerChannelClose func() } type serverOptions struct { @@ -623,15 +624,14 @@ func (s *Server) serverWorker() { // connections to reduce the time spent overall on runtime.morestack. func (s *Server) initServerWorkers() { s.serverWorkerChannel = make(chan func()) + s.serverWorkerChannelClose = grpcsync.OnceFunc(func() { + close(s.serverWorkerChannel) + }) for i := uint32(0); i < s.opts.numServerWorkers; i++ { go s.serverWorker() } } -func (s *Server) stopServerWorkers() { - close(s.serverWorkerChannel) -} - // NewServer creates a gRPC server which has no service registered and has not // started to accept requests yet. func NewServer(opt ...ServerOption) *Server { @@ -1898,15 +1898,19 @@ func (s *Server) stop(graceful bool) { s.closeServerTransportsLocked() } - if s.opts.numServerWorkers > 0 { - s.stopServerWorkers() - } - for len(s.conns) != 0 { s.cv.Wait() } s.conns = nil + if s.opts.numServerWorkers > 0 { + // Closing the channel (only once, via grpcsync.OnceFunc) after all the + // connections have been closed above ensures that there are no + // goroutines executing the callback passed to st.HandleStreams (where + // the channel is written to). 
+ s.serverWorkerChannelClose() + } + if s.events != nil { s.events.Finish() s.events = nil diff --git a/server_ext_test.go b/server_ext_test.go index df79755f325..c065e4ad42a 100644 --- a/server_ext_test.go +++ b/server_ext_test.go @@ -21,14 +21,20 @@ package grpc_test import ( "context" "io" + "runtime" + "sync" "testing" "time" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/status" testgrpc "google.golang.org/grpc/interop/grpc_testing" + testpb "google.golang.org/grpc/interop/grpc_testing" ) // TestServer_MaxHandlers ensures that no more than MaxConcurrentStreams server @@ -97,3 +103,85 @@ func (s) TestServer_MaxHandlers(t *testing.T) { t.Fatal("Received unexpected RPC error:", err) } } + +// Tests the case where the stream worker goroutine option is enabled, and a +// number of RPCs are initiated around the same time that Stop() is called. This +// used to result in a write to a closed channel. This test verifies that there +// is no panic. +func (s) TestStreamWorkers_RPCsAndStop(t *testing.T) { + ss := stubserver.StartTestService(t, nil, grpc.NumStreamWorkers(uint32(runtime.NumCPU()))) + // This deferred stop takes care of stopping the server when one of the + // below grpc.Dials fails, and the test exits early. + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + const numChannels = 20 + const numRPCLoops = 20 + + // Create a bunch of clientconns and ensure that they are READY by making an + // RPC on them. 
+ ccs := make([]*grpc.ClientConn, numChannels) + for i := 0; i < numChannels; i++ { + var err error + ccs[i], err = grpc.Dial(ss.Address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + t.Fatalf("[iteration: %d] grpc.Dial(%s) failed: %v", i, ss.Address, err) + } + defer ccs[i].Close() + client := testgrpc.NewTestServiceClient(ccs[i]) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("EmptyCall() failed: %v", err) + } + } + + // Make a bunch of concurrent RPCs on the above clientconns. These will + // eventually race with Stop(), and will start to fail. + var wg sync.WaitGroup + for i := 0; i < numChannels; i++ { + client := testgrpc.NewTestServiceClient(ccs[i]) + for j := 0; j < numRPCLoops; j++ { + wg.Add(1) + go func(client testgrpc.TestServiceClient) { + defer wg.Done() + for { + _, err := client.EmptyCall(ctx, &testpb.Empty{}) + if err == nil { + continue + } + if code := status.Code(err); code == codes.Unavailable { + // Once Stop() has been called on the server, we expect + // subsequent calls to fail with Unavailable. + return + } + t.Errorf("EmptyCall() failed: %v", err) + return + } + }(client) + } + } + + // Call Stop() concurrently with the above RPC attempts. + ss.Stop() + wg.Wait() +} + +// Tests the case where the stream worker goroutine option is enabled, and both +// Stop() and GracefulStop() are called. This used to result in a close of a +// closed channel. This test verifies that there is no panic. 
+func (s) TestStreamWorkers_GracefulStopAndStop(t *testing.T) { + ss := stubserver.StartTestService(t, nil, grpc.NumStreamWorkers(uint32(runtime.NumCPU()))) + defer ss.Stop() + + if err := ss.StartClient(grpc.WithTransportCredentials(insecure.NewCredentials())); err != nil { + t.Fatalf("Failed to create client to stub server: %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + client := testgrpc.NewTestServiceClient(ss.CC) + if _, err := client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("EmptyCall() failed: %v", err) + } + + ss.S.GracefulStop() +}