Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: prohibit more than MaxConcurrentStreams handlers from running at once #6703

Merged
merged 3 commits into from Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
39 changes: 39 additions & 0 deletions benchmark/primitives/primitives_test.go
Expand Up @@ -425,3 +425,42 @@ func BenchmarkRLockUnlock(b *testing.B) {
}
})
}

type ifNop interface {
nop()
}

type alwaysNop struct{}

func (alwaysNop) nop() {}

type concreteNop struct {
isNop atomic.Bool
i int
}

func (c *concreteNop) nop() {
if c.isNop.Load() {
return
}
c.i++
}

func BenchmarkInterfaceNop(b *testing.B) {
n := ifNop(alwaysNop{})
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
n.nop()
}
})
}

func BenchmarkConcreteNop(b *testing.B) {
n := &concreteNop{}
n.isNop.Store(true)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
n.nop()
}
})
}
11 changes: 3 additions & 8 deletions internal/transport/http2_server.go
Expand Up @@ -171,15 +171,10 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
ID: http2.SettingMaxFrameSize,
Val: http2MaxFrameLen,
}}
// TODO(zhaoq): Have a better way to signal "no limit" because 0 is
// permitted in the HTTP2 spec.
maxStreams := config.MaxStreams
if maxStreams == 0 {
maxStreams = math.MaxUint32
} else {
if config.MaxStreams != math.MaxUint32 {
isettings = append(isettings, http2.Setting{
ID: http2.SettingMaxConcurrentStreams,
Val: maxStreams,
Val: config.MaxStreams,
})
}
dynamicWindow := true
Expand Down Expand Up @@ -258,7 +253,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
framer: framer,
readerDone: make(chan struct{}),
writerDone: make(chan struct{}),
maxStreams: maxStreams,
maxStreams: config.MaxStreams,
inTapHandle: config.InTapHandle,
fc: &trInFlow{limit: uint32(icwz)},
state: reachable,
Expand Down
35 changes: 19 additions & 16 deletions internal/transport/transport_test.go
Expand Up @@ -337,6 +337,9 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
return
}
rawConn := conn
if serverConfig.MaxStreams == 0 {
serverConfig.MaxStreams = math.MaxUint32
}
transport, err := NewServerTransport(conn, serverConfig)
if err != nil {
return
Expand Down Expand Up @@ -425,8 +428,8 @@ func setUpServerOnly(t *testing.T, port int, sc *ServerConfig, ht hType) *server
return server
}

func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, *http2Client, func()) {
return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{})
func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) {
return setUpWithOptions(t, port, &ServerConfig{}, ht, ConnectOptions{})
}

func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) {
Expand Down Expand Up @@ -521,7 +524,7 @@ func (s) TestInflightStreamClosing(t *testing.T) {

// Tests that when streamID > MaxStreamId, the current client transport drains.
func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
server, ct, cancel := setUp(t, 0, normal)
defer cancel()
defer server.stop()
callHdr := &CallHdr{
Expand Down Expand Up @@ -566,7 +569,7 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
}

func (s) TestClientSendAndReceive(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
server, ct, cancel := setUp(t, 0, normal)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
Expand Down Expand Up @@ -606,7 +609,7 @@ func (s) TestClientSendAndReceive(t *testing.T) {
}

func (s) TestClientErrorNotify(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
server, ct, cancel := setUp(t, 0, normal)
defer cancel()
go server.stop()
// ct.reader should detect the error and activate ct.Error().
Expand Down Expand Up @@ -640,7 +643,7 @@ func performOneRPC(ct ClientTransport) {
}

func (s) TestClientMix(t *testing.T) {
s, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
s, ct, cancel := setUp(t, 0, normal)
defer cancel()
time.AfterFunc(time.Second, s.stop)
go func(ct ClientTransport) {
Expand All @@ -654,7 +657,7 @@ func (s) TestClientMix(t *testing.T) {
}

func (s) TestLargeMessage(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
server, ct, cancel := setUp(t, 0, normal)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
Expand Down Expand Up @@ -789,7 +792,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) {
// proceed until they complete naturally, while not allowing creation of new
// streams during this window.
func (s) TestGracefulClose(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, pingpong)
server, ct, cancel := setUp(t, 0, pingpong)
defer cancel()
defer func() {
// Stop the server's listener to make the server's goroutines terminate
Expand Down Expand Up @@ -855,7 +858,7 @@ func (s) TestGracefulClose(t *testing.T) {
}

func (s) TestLargeMessageSuspension(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
server, ct, cancel := setUp(t, 0, suspended)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
Expand Down Expand Up @@ -963,7 +966,7 @@ func (s) TestMaxStreams(t *testing.T) {
}

func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
server, ct, cancel := setUp(t, 0, suspended)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
Expand Down Expand Up @@ -1435,7 +1438,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) {
var encodingTestStatus = status.New(codes.Internal, "\n")

func (s) TestEncodingRequiredStatus(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
server, ct, cancel := setUp(t, 0, encodingRequiredStatus)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
Expand Down Expand Up @@ -1463,7 +1466,7 @@ func (s) TestEncodingRequiredStatus(t *testing.T) {
}

func (s) TestInvalidHeaderField(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
server, ct, cancel := setUp(t, 0, invalidHeaderField)
defer cancel()
callHdr := &CallHdr{
Host: "localhost",
Expand All @@ -1485,7 +1488,7 @@ func (s) TestInvalidHeaderField(t *testing.T) {
}

func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) {
server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
server, ct, cancel := setUp(t, 0, invalidHeaderField)
defer cancel()
defer server.stop()
defer ct.Close(fmt.Errorf("closed manually by test"))
Expand Down Expand Up @@ -2153,7 +2156,7 @@ func (s) TestPingPong1MB(t *testing.T) {

// This is a stress-test of flow control logic.
func runPingPongTest(t *testing.T, msgSize int) {
server, client, cancel := setUp(t, 0, 0, pingpong)
server, client, cancel := setUp(t, 0, pingpong)
defer cancel()
defer server.stop()
defer client.Close(fmt.Errorf("closed manually by test"))
Expand Down Expand Up @@ -2235,7 +2238,7 @@ func (s) TestHeaderTblSize(t *testing.T) {
}
}()

server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
server, ct, cancel := setUp(t, 0, normal)
defer cancel()
defer ct.Close(fmt.Errorf("closed manually by test"))
defer server.stop()
Expand Down Expand Up @@ -2594,7 +2597,7 @@ func TestConnectionError_Unwrap(t *testing.T) {

func (s) TestPeerSetInServerContext(t *testing.T) {
// create client and server transports.
server, client, cancel := setUp(t, 0, math.MaxUint32, normal)
server, client, cancel := setUp(t, 0, normal)
defer cancel()
defer server.stop()
defer client.Close(fmt.Errorf("closed manually by test"))
Expand Down
71 changes: 50 additions & 21 deletions server.go
Expand Up @@ -115,12 +115,6 @@
mdata any
}

type serverWorkerData struct {
st transport.ServerTransport
wg *sync.WaitGroup
stream *transport.Stream
}

// Server is a gRPC server to serve RPC requests.
type Server struct {
opts serverOptions
Expand All @@ -145,7 +139,7 @@
channelzID *channelz.Identifier
czData *channelzData

serverWorkerChannel chan *serverWorkerData
serverWorkerChannel chan func()
}

type serverOptions struct {
Expand Down Expand Up @@ -179,6 +173,7 @@
}

var defaultServerOptions = serverOptions{
maxConcurrentStreams: math.MaxUint32,
easwars marked this conversation as resolved.
Show resolved Hide resolved
maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
maxSendMessageSize: defaultServerMaxSendMessageSize,
connectionTimeout: 120 * time.Second,
Expand Down Expand Up @@ -404,6 +399,9 @@
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
// of concurrent streams to each ServerTransport.
func MaxConcurrentStreams(n uint32) ServerOption {
if n == 0 {
easwars marked this conversation as resolved.
Show resolved Hide resolved
n = math.MaxUint32
}

Check warning on line 404 in server.go

View check run for this annotation

Codecov / codecov/patch

server.go#L403-L404

Added lines #L403 - L404 were not covered by tests
return newFuncServerOption(func(o *serverOptions) {
o.maxConcurrentStreams = n
})
Expand Down Expand Up @@ -605,24 +603,19 @@
// [1] https://github.com/golang/go/issues/18138
func (s *Server) serverWorker() {
for completed := 0; completed < serverWorkerResetThreshold; completed++ {
data, ok := <-s.serverWorkerChannel
f, ok := <-s.serverWorkerChannel

Check warning on line 606 in server.go

View check run for this annotation

Codecov / codecov/patch

server.go#L606

Added line #L606 was not covered by tests
if !ok {
return
}
s.handleSingleStream(data)
f()

Check warning on line 610 in server.go

View check run for this annotation

Codecov / codecov/patch

server.go#L610

Added line #L610 was not covered by tests
}
go s.serverWorker()
}

func (s *Server) handleSingleStream(data *serverWorkerData) {
defer data.wg.Done()
s.handleStream(data.st, data.stream)
}

// initServerWorkers creates worker goroutines and a channel to process incoming
// connections to reduce the time spent overall on runtime.morestack.
func (s *Server) initServerWorkers() {
s.serverWorkerChannel = make(chan *serverWorkerData)
s.serverWorkerChannel = make(chan func())

Check warning on line 618 in server.go

View check run for this annotation

Codecov / codecov/patch

server.go#L618

Added line #L618 was not covered by tests
for i := uint32(0); i < s.opts.numServerWorkers; i++ {
go s.serverWorker()
}
Expand Down Expand Up @@ -982,21 +975,26 @@
defer st.Close(errors.New("finished serving streams for the server transport"))
var wg sync.WaitGroup

streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
st.HandleStreams(func(stream *transport.Stream) {
wg.Add(1)

streamQuota.acquire()
f := func() {
defer streamQuota.release()
defer wg.Done()
s.handleStream(st, stream)
}

if s.opts.numServerWorkers > 0 {
data := &serverWorkerData{st: st, wg: &wg, stream: stream}
select {
case s.serverWorkerChannel <- data:
case s.serverWorkerChannel <- f:

Check warning on line 991 in server.go

View check run for this annotation

Codecov / codecov/patch

server.go#L991

Added line #L991 was not covered by tests
return
default:
// If all stream workers are busy, fallback to the default code path.
}
}
go func() {
defer wg.Done()
s.handleStream(st, stream)
}()
go f()
})
wg.Wait()
}
Expand Down Expand Up @@ -2077,3 +2075,34 @@
}
return fmt.Errorf("client does not support compressor %q", name)
}

// atomicSemaphore implements a blocking, counting semaphore. acquire should be
// called synchronously; release may be called asynchronously.
type atomicSemaphore struct {
n atomic.Int64
wait chan struct{}
}

func (q *atomicSemaphore) acquire() {
if q.n.Add(-1) < 0 {
// We ran out of quota. Block until a release happens.
<-q.wait
}
}

func (q *atomicSemaphore) release() {
// N.B. the "<= 0" check below should allow for this to work with multiple
// concurrent calls to acquire, but also note that with synchronous calls to
// acquire, as our system does, n will never be less than -1. There are
// fairness issues (queuing) to consider if this was to be generalized.
if q.n.Add(1) <= 0 {
// An acquire was waiting on us. Unblock it.
q.wait <- struct{}{}
}
}

func newHandlerQuota(n uint32) *atomicSemaphore {
a := &atomicSemaphore{wait: make(chan struct{}, 1)}
a.n.Store(int64(n))
return a
}