Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand stream's flow control in case of an active read. #1248

Merged
merged 13 commits into from May 23, 2017
4 changes: 2 additions & 2 deletions rpc_util.go
Expand Up @@ -260,7 +260,7 @@ type parser struct {
// that the underlying io.Reader must not return an incompatible
// error.
func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) {
if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
if _, err := p.r.Read(p.header[:]); err != nil {
return 0, nil, err
}

Expand All @@ -276,7 +276,7 @@ func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err erro
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
// of making it for each message:
msg = make([]byte, int(length))
if _, err := io.ReadFull(p.r, msg); err != nil {
if _, err := p.r.Read(msg); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
Expand Down
12 changes: 10 additions & 2 deletions rpc_util_test.go
Expand Up @@ -47,6 +47,14 @@ import (
"google.golang.org/grpc/transport"
)

type fullReader struct {
reader io.Reader
}

func (f fullReader) Read(p []byte) (int, error) {
return io.ReadFull(f.reader, p)
}

var _ CallOption = EmptyCallOption{} // ensure EmptyCallOption implements the interface

func TestSimpleParsing(t *testing.T) {
Expand All @@ -67,7 +75,7 @@ func TestSimpleParsing(t *testing.T) {
// Check that messages with length >= 2^24 are parsed.
{append([]byte{0, 1, 0, 0, 0}, bigMsg...), nil, bigMsg, compressionNone},
} {
buf := bytes.NewReader(test.p)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think the use of fullReader and other changes to this file can be reverted, since the transport is still an io.Reader.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These test here rely on recvMsg's behavior of reading the full message, (which is true when it interacts with the transport stream). However, here the parser is given a "fake" buffer instead of the stream. Given that we changed recvMsg to go from io.ReadFull to p.r.Read, the "fake" buffer here needs to read the full message just like the transport stream does.

buf := fullReader{bytes.NewReader(test.p)}
parser := &parser{r: buf}
pt, b, err := parser.recvMsg(math.MaxInt32)
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
Expand All @@ -79,7 +87,7 @@ func TestSimpleParsing(t *testing.T) {
func TestMultipleParsing(t *testing.T) {
// Set a byte stream consists of 3 messages with their headers.
p := []byte{0, 0, 0, 0, 1, 'a', 0, 0, 0, 0, 2, 'b', 'c', 0, 0, 0, 0, 1, 'd'}
b := bytes.NewReader(p)
b := fullReader{bytes.NewReader(p)}
parser := &parser{r: b}

wantRecvs := []struct {
Expand Down
42 changes: 37 additions & 5 deletions test/end2end_test.go
Expand Up @@ -445,6 +445,7 @@ type test struct {
streamServerInt grpc.StreamServerInterceptor
unknownHandler grpc.StreamHandler
sc <-chan grpc.ServiceConfig
customCodec grpc.Codec
serverInitialWindowSize int32
serverInitialConnWindowSize int32
clientInitialWindowSize int32
Expand Down Expand Up @@ -545,6 +546,9 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
case "clientTimeoutCreds":
sopts = append(sopts, grpc.Creds(&clientTimeoutCreds{}))
}
if te.customCodec != nil {
sopts = append(sopts, grpc.CustomCodec(te.customCodec))
}
s := grpc.NewServer(sopts...)
te.srv = s
if te.e.httpHandler {
Expand Down Expand Up @@ -625,6 +629,9 @@ func (te *test) clientConn() *grpc.ClientConn {
if te.perRPCCreds != nil {
opts = append(opts, grpc.WithPerRPCCredentials(te.perRPCCreds))
}
if te.customCodec != nil {
opts = append(opts, grpc.WithCodec(te.customCodec))
}
var err error
te.cc, err = grpc.Dial(te.srvAddr, opts...)
if err != nil {
Expand Down Expand Up @@ -2634,26 +2641,51 @@ func testServerStreamingConcurrent(t *testing.T, e env) {

}

func generatePayloadSizes() [][]int {
reqSizes := [][]int{
{27182, 8, 1828, 45904},
}

num8KPayloads := 1024
eightKPayloads := []int{}
for i := 0; i < num8KPayloads; i++ {
eightKPayloads = append(eightKPayloads, (1 << 13))
}
reqSizes = append(reqSizes, eightKPayloads)

num2MPayloads := 8
twoMPayloads := []int{}
for i := 0; i < num2MPayloads; i++ {
twoMPayloads = append(twoMPayloads, (1 << 21))
}
reqSizes = append(reqSizes, twoMPayloads)

return reqSizes
}

func TestClientStreaming(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testClientStreaming(t, e)
for _, s := range generatePayloadSizes() {
for _, e := range listTestEnv() {
testClientStreaming(t, e, s)
}
}
}

func testClientStreaming(t *testing.T, e env) {
func testClientStreaming(t *testing.T, e env, sizes []int) {
te := newTest(t, e)
te.startServer(&testServer{security: e.security})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())

stream, err := tc.StreamingInputCall(te.ctx)
ctx, _ := context.WithTimeout(te.ctx, time.Minute*3)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: 30s

stream, err := tc.StreamingInputCall(ctx)
if err != nil {
t.Fatalf("%v.StreamingInputCall(_) = _, %v, want <nil>", tc, err)
}

var sum int
for _, s := range reqSizes {
for _, s := range sizes {
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(s))
if err != nil {
t.Fatal(err)
Expand Down
6 changes: 6 additions & 0 deletions test/servertester.go
Expand Up @@ -287,3 +287,9 @@ func (st *serverTester) writeRSTStream(streamID uint32, code http2.ErrCode) {
st.t.Fatalf("Error writing RST_STREAM: %v", err)
}
}

func (st *serverTester) writeDataPadded(streamID uint32, endStream bool, data, padding []byte) {
if err := st.fr.WriteDataPadded(streamID, endStream, data, padding); err != nil {
st.t.Fatalf("Error writing DATA with padding: %v", err)
}
}
35 changes: 34 additions & 1 deletion transport/control.go
Expand Up @@ -58,6 +58,8 @@ const (
defaultServerKeepaliveTime = time.Duration(2 * time.Hour)
defaultServerKeepaliveTimeout = time.Duration(20 * time.Second)
defaultKeepalivePolicyMinTime = time.Duration(5 * time.Minute)
// max window limit set by HTTP2 Specs.
maxWindowSize = math.MaxInt32
)

// The following defines various control items which could flow through
Expand Down Expand Up @@ -167,14 +169,37 @@ type inFlow struct {
// The amount of data the application has consumed but grpc has not sent
// window update for them. Used to reduce window update frequency.
pendingUpdate uint32
// delta is the extra window update given by receiver when an application
// is reading data bigger in size than the inFlow limit.
delta int32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be unsigned since it can't be negative? (Also, you cast it to uint32 when reading and cast to int32 when assigning.)

}

func (f *inFlow) maybeAdjust(n uint32) uint32 {
if n > uint32(math.MaxInt32) {
n = uint32(math.MaxInt32)
}
f.mu.Lock()
defer f.mu.Unlock()
senderQuota := int32(f.limit - (f.pendingData + f.pendingUpdate))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small wording nit: rename to estimatedSenderQuota and estimatedUntransmittedData, for clarity here?

untransmittedData := int32(n - f.pendingData) // Casting into int32 since it could be negative.
if untransmittedData > senderQuota {
// Sender's window shouldn't go more than 2^31 - 1.
if f.limit+n > uint32(maxWindowSize) {
f.delta = maxWindowSize - int32(f.limit)
} else {
f.delta = int32(n)
}
return uint32(f.delta)
}
return 0
}

// onData is invoked when some data frame is received. It updates pendingData.
func (f *inFlow) onData(n uint32) error {
f.mu.Lock()
defer f.mu.Unlock()
f.pendingData += n
if f.pendingData+f.pendingUpdate > f.limit {
if f.pendingData+f.pendingUpdate > f.limit+uint32(f.delta) {
return fmt.Errorf("received %d-bytes data exceeding the limit %d bytes", f.pendingData+f.pendingUpdate, f.limit)
}
return nil
Expand All @@ -189,6 +214,14 @@ func (f *inFlow) onRead(n uint32) uint32 {
return 0
}
f.pendingData -= n
if f.delta > 0 {
f.delta -= int32(n)
Copy link
Member

@dfawley dfawley May 19, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this would be easier to read without the casts and taking the negative of a negative number:

if n > f.delta {
  n -= f.delta
  f.delta = 0
} else {
  f.delta -= n
  n = 0
}

(Also would allow f.delta to be uint32 like it seems like it wants to be.)

n = 0
if f.delta < 0 {
n = uint32(-f.delta)
f.delta = 0
}
}
f.pendingUpdate += n
if f.pendingUpdate >= f.limit/4 {
wu := f.pendingUpdate
Expand Down
15 changes: 7 additions & 8 deletions transport/handler_server.go
Expand Up @@ -316,13 +316,12 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
req := ht.req

s := &Stream{
id: 0, // irrelevant
windowHandler: func(int) {}, // nothing
cancel: cancel,
buf: newRecvBuffer(),
st: ht,
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
id: 0, // irrelevant
cancel: cancel,
buf: newRecvBuffer(),
st: ht,
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
}
pr := &peer.Peer{
Addr: ht.RemoteAddr(),
Expand All @@ -333,7 +332,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
ctx = peer.NewContext(ctx, pr)
s.ctx = newContextWithStream(ctx, s)
s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf}
s.trReader = &recvBufferReader{ctx: s.ctx, recv: s.buf}

// readerDone is closed when the Body.Read-ing goroutine exits.
readerDone := make(chan struct{})
Expand Down
44 changes: 32 additions & 12 deletions transport/http2_client.go
Expand Up @@ -173,9 +173,9 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
conn, err := dial(ctx, opts.Dialer, addr.Addr)
if err != nil {
if opts.FailOnNonTempDialError {
return nil, connectionErrorf(isTemporary(err), err, "transport: %v", err)
return nil, connectionErrorf(isTemporary(err), err, "transport: Error while dialing %v", err)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: lowercase "e" in "error", and add a colon before the error being appended.

}
return nil, connectionErrorf(true, err, "transport: %v", err)
return nil, connectionErrorf(true, err, "transport: Error while dialing %v", err)
}
// Any further errors will close the underlying connection
defer func(conn net.Conn) {
Expand All @@ -194,7 +194,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
// Credentials handshake errors are typically considered permanent
// to avoid retrying on e.g. bad certificates.
temp := isTemporary(err)
return nil, connectionErrorf(temp, err, "transport: %v", err)
return nil, connectionErrorf(temp, err, "transport: authentication handshake failed %v", err)
}
isSecure = true
}
Expand Down Expand Up @@ -269,7 +269,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
n, err := t.conn.Write(clientPreface)
if err != nil {
t.Close()
return nil, connectionErrorf(true, err, "transport: %v", err)
return nil, connectionErrorf(true, err, "transport: failed to write client preface %v", err)
}
if n != len(clientPreface) {
t.Close()
Expand All @@ -285,13 +285,13 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (
}
if err != nil {
t.Close()
return nil, connectionErrorf(true, err, "transport: %v", err)
return nil, connectionErrorf(true, err, "transport: failed to write initial settings frame %v", err)
}
// Adjust the connection flow control window if needed.
if delta := uint32(icwz - defaultWindowSize); delta > 0 {
if err := t.framer.writeWindowUpdate(true, 0, delta); err != nil {
t.Close()
return nil, connectionErrorf(true, err, "transport: %v", err)
return nil, connectionErrorf(true, err, "transport: failed to write window update %v", err)
}
}
go t.controller()
Expand All @@ -316,18 +316,24 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
headerChan: make(chan struct{}),
}
t.nextID += 2
s.windowHandler = func(n int) {
t.updateWindow(s, uint32(n))
s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n))
}
// The client side stream context should have exactly the same life cycle with the user provided context.
// That means, s.ctx should be read-only. And s.ctx is done iff ctx is done.
// So we use the original context here instead of creating a copy.
s.ctx = ctx
s.dec = &recvBufferReader{
ctx: s.ctx,
goAway: s.goAway,
recv: s.buf,
s.trReader = &transportReader{
reader: &recvBufferReader{
ctx: s.ctx,
goAway: s.goAway,
recv: s.buf,
},
windowHandler: func(n int) {
t.updateWindow(s, uint32(n))
},
}

return s
}

Expand Down Expand Up @@ -802,6 +808,20 @@ func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) {
return s, ok
}

// adjustWindow sends out extra window update over the initial window size
// of stream if the application is requesting data larger in size than
// the window.
func (t *http2Client) adjustWindow(s *Stream, n uint32) {
s.mu.Lock()
defer s.mu.Unlock()
if s.state == streamDone {
return
}
if w := s.fc.maybeAdjust(n); n > 0 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should the if be checking w?

If not, check n first.

t.controlBuf.put(&windowUpdate{s.id, w})
}
}

// updateWindow adjusts the inbound quota for the stream and the transport.
// Window updates will deliver to the controller for sending when
// the cumulative quota exceeds the corresponding threshold.
Expand Down
32 changes: 25 additions & 7 deletions transport/http2_server.go
Expand Up @@ -274,10 +274,14 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
if len(state.mdata) > 0 {
s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata)
}

s.dec = &recvBufferReader{
ctx: s.ctx,
recv: s.buf,
s.trReader = &transportReader{
reader: &recvBufferReader{
ctx: s.ctx,
recv: s.buf,
},
windowHandler: func(n int) {
t.updateWindow(s, uint32(n))
},
}
s.recvCompress = state.encoding
s.method = state.method
Expand Down Expand Up @@ -316,8 +320,8 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
t.idle = time.Time{}
}
t.mu.Unlock()
s.windowHandler = func(n int) {
t.updateWindow(s, uint32(n))
s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n))
}
s.ctx = traceCtx(s.ctx, s.method)
if t.stats != nil {
Expand Down Expand Up @@ -358,7 +362,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.
return
}
if err != nil {
grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err)
grpclog.Printf("transport: http2Server.HandleStreams failed to read initial settings frame: %v", err)
t.Close()
return
}
Expand Down Expand Up @@ -432,6 +436,20 @@ func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) {
return s, true
}

// adjustWindow sends out extra window update over the initial window size
// of stream if the application is requesting data larger in size than
// the window.
func (t *http2Server) adjustWindow(s *Stream, n uint32) {
s.mu.Lock()
defer s.mu.Unlock()
if s.state == streamDone {
return
}
if w := s.fc.maybeAdjust(n); n > 0 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above re: w and n.

Can this function be shared somehow? Should it be a method on the Stream instead of server/client?

t.controlBuf.put(&windowUpdate{s.id, w})
}
}

// updateWindow adjusts the inbound quota for the stream and the transport.
// Window updates will deliver to the controller for sending when
// the cumulative quota exceeds the corresponding threshold.
Expand Down