diff --git a/go/grpcweb/websocket_wrapper.go b/go/grpcweb/websocket_wrapper.go index 575ced42..a1513b0c 100644 --- a/go/grpcweb/websocket_wrapper.go +++ b/go/grpcweb/websocket_wrapper.go @@ -3,6 +3,7 @@ package grpcweb import ( "bufio" "bytes" + "context" "encoding/binary" "errors" "io" @@ -110,6 +111,7 @@ type webSocketWrappedReader struct { respWriter *webSocketResponseWriter remainingBuffer []byte remainingError error + cancel context.CancelFunc } func (w *webSocketWrappedReader) Close() error { @@ -152,7 +154,7 @@ func (w *webSocketWrappedReader) Read(p []byte) (int, error) { messageType, framePayload, err := w.wsConn.ReadMessage() if err == io.EOF || messageType == -1 { // The client has closed the connection. Indicate to the response writer that it should close - w.respWriter.closeNotifyChan <- true + w.cancel() return 0, io.EOF } @@ -193,12 +195,13 @@ func (w *webSocketWrappedReader) Read(p []byte) (int, error) { return len(p), nil } -func newWebsocketWrappedReader(wsConn *websocket.Conn, respWriter *webSocketResponseWriter) *webSocketWrappedReader { +func newWebsocketWrappedReader(wsConn *websocket.Conn, respWriter *webSocketResponseWriter, cancel context.CancelFunc) *webSocketWrappedReader { return &webSocketWrappedReader{ wsConn: wsConn, respWriter: respWriter, remainingBuffer: nil, remainingError: nil, + cancel: cancel, } } diff --git a/go/grpcweb/wrapper.go b/go/grpcweb/wrapper.go index ab0e1b53..f92b6a27 100644 --- a/go/grpcweb/wrapper.go +++ b/go/grpcweb/wrapper.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" + "context" "strings" "time" @@ -148,14 +149,17 @@ func (w *WrappedGrpcServer) handleWebSocket(wsConn *websocket.Conn, req *http.Re return } + ctx, cancelFunc := context.WithCancel(req.Context()) + defer cancelFunc() + respWriter := newWebSocketResponseWriter(wsConn) - wrappedReader := newWebsocketWrappedReader(wsConn, respWriter) + wrappedReader := newWebsocketWrappedReader(wsConn, respWriter, cancelFunc) req.Body = wrappedReader req.Method = http.MethodPost req.Header = headers - w.server.ServeHTTP(respWriter, hackIntoNormalGrpcRequest(req)) + w.server.ServeHTTP(respWriter, hackIntoNormalGrpcRequest(req.WithContext(ctx))) } // IsGrpcWebRequest determines if a request is a gRPC-Web request by checking that the "content-type" is