Skip to content

Commit

Permalink
grpc: Change server stream context handling (#6598)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasweq authored Sep 1, 2023
1 parent e498bbc commit 8eb4ac4
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 117 deletions.
2 changes: 1 addition & 1 deletion internal/transport/handler_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
return err
}

func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) {
func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
// With this transport type there will be exactly 1 stream: this HTTP request.

ctx := ht.req.Context()
Expand Down
5 changes: 0 additions & 5 deletions internal/transport/handler_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,6 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
}
st.ht.HandleStreams(
func(s *Stream) { go handleStream(s) },
func(ctx context.Context, method string) context.Context { return ctx },
)
wantHeader := http.Header{
"Date": nil,
Expand Down Expand Up @@ -349,7 +348,6 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string)
}
st.ht.HandleStreams(
func(s *Stream) { go handleStream(s) },
func(ctx context.Context, method string) context.Context { return ctx },
)
wantHeader := http.Header{
"Date": nil,
Expand Down Expand Up @@ -399,7 +397,6 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
}
ht.HandleStreams(
func(s *Stream) { go runStream(s) },
func(ctx context.Context, method string) context.Context { return ctx },
)
wantHeader := http.Header{
"Date": nil,
Expand Down Expand Up @@ -452,7 +449,6 @@ func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handl
st := newHandleStreamTest(t)
st.ht.HandleStreams(
func(s *Stream) { go handleStream(st, s) },
func(ctx context.Context, method string) context.Context { return ctx },
)
}

Expand Down Expand Up @@ -486,7 +482,6 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
}
hst.ht.HandleStreams(
func(s *Stream) { go handleStream(s) },
func(ctx context.Context, method string) context.Context { return ctx },
)
wantHeader := http.Header{
"Date": nil,
Expand Down
7 changes: 3 additions & 4 deletions internal/transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,

// operateHeaders takes action on the decoded headers. Returns an error if fatal
// error encountered and transport needs to close, otherwise returns nil.
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) error {
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) error {
// Acquire max stream ID lock for entire duration
t.maxStreamMu.Lock()
defer t.maxStreamMu.Unlock()
Expand Down Expand Up @@ -597,7 +597,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n))
}
s.ctx = traceCtx(s.ctx, s.method)
for _, sh := range t.stats {
s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
inHeader := &stats.InHeader{
Expand Down Expand Up @@ -635,7 +634,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
// HandleStreams receives incoming streams using the given handler. This is
// typically run in a separate goroutine.
// traceCtx attaches trace to ctx and returns the new context.
func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) {
func (t *http2Server) HandleStreams(handle func(*Stream)) {
defer close(t.readerDone)
for {
t.controlBuf.throttle()
Expand Down Expand Up @@ -670,7 +669,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.
}
switch frame := frame.(type) {
case *http2.MetaHeadersFrame:
if err := t.operateHeaders(frame, handle, traceCtx); err != nil {
if err := t.operateHeaders(frame, handle); err != nil {
t.Close(err)
break
}
Expand Down
2 changes: 1 addition & 1 deletion internal/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ type ClientTransport interface {
// Write methods for a given Stream will be called serially.
type ServerTransport interface {
// HandleStreams receives incoming streams using the given handler.
HandleStreams(func(*Stream), func(context.Context, string) context.Context)
HandleStreams(func(*Stream))

// WriteHeader sends the header metadata for the given stream.
// WriteHeader may not be called on all streams.
Expand Down
22 changes: 2 additions & 20 deletions internal/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,32 +353,20 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
s.mu.Unlock()
switch ht {
case notifyCall:
go transport.HandleStreams(h.handleStreamAndNotify,
func(ctx context.Context, _ string) context.Context {
return ctx
})
go transport.HandleStreams(h.handleStreamAndNotify)
case suspended:
go transport.HandleStreams(func(*Stream) {}, // Do nothing to handle the stream.
func(ctx context.Context, method string) context.Context {
return ctx
})
go transport.HandleStreams(func(*Stream) {})
case misbehaved:
go transport.HandleStreams(func(s *Stream) {
go h.handleStreamMisbehave(t, s)
}, func(ctx context.Context, method string) context.Context {
return ctx
})
case encodingRequiredStatus:
go transport.HandleStreams(func(s *Stream) {
go h.handleStreamEncodingRequiredStatus(s)
}, func(ctx context.Context, method string) context.Context {
return ctx
})
case invalidHeaderField:
go transport.HandleStreams(func(s *Stream) {
go h.handleStreamInvalidHeaderField(s)
}, func(ctx context.Context, method string) context.Context {
return ctx
})
case delayRead:
h.notify = make(chan struct{})
Expand All @@ -388,20 +376,14 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
s.mu.Unlock()
go transport.HandleStreams(func(s *Stream) {
go h.handleStreamDelayRead(t, s)
}, func(ctx context.Context, method string) context.Context {
return ctx
})
case pingpong:
go transport.HandleStreams(func(s *Stream) {
go h.handleStreamPingPong(t, s)
}, func(ctx context.Context, method string) context.Context {
return ctx
})
default:
go transport.HandleStreams(func(s *Stream) {
go h.handleStream(t, s)
}, func(ctx context.Context, method string) context.Context {
return ctx
})
}
}
Expand Down
Loading

0 comments on commit 8eb4ac4

Please sign in to comment.