Skip to content

Commit

Permalink
feat(websocket): change websocket lib to nhooyr.io/websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
Hellysonrp committed Jan 4, 2021
1 parent f9ddeb1 commit 55a5494
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 50 deletions.
41 changes: 32 additions & 9 deletions Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Gopkg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ required = [
version = "1.1.0"

[[constraint]]
name = "github.com/gorilla/websocket"
version = "1.2.0"
name = "nhooyr.io/websocket"
version = "1.8.6"

[[constraint]]
branch = "master"
Expand Down
40 changes: 18 additions & 22 deletions go/grpcweb/websocket_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import (
"time"

"github.com/desertbit/timer"
"github.com/gorilla/websocket"
"golang.org/x/net/http2"
"nhooyr.io/websocket"
)

type webSocketResponseWriter struct {
Expand All @@ -24,40 +24,34 @@ type webSocketResponseWriter struct {
flushedHeaders http.Header
timeOutInterval time.Duration
timer *timer.Timer
context context.Context
}

func newWebSocketResponseWriter(wsConn *websocket.Conn) *webSocketResponseWriter {
func newWebSocketResponseWriter(ctx context.Context, wsConn *websocket.Conn) *webSocketResponseWriter {
return &webSocketResponseWriter{
writtenHeaders: false,
headers: make(http.Header),
flushedHeaders: make(http.Header),
wsConn: wsConn,
context: ctx,
}
}

func (w *webSocketResponseWriter) enablePing(timeOutInterval time.Duration) {
w.timeOutInterval = timeOutInterval
w.timer = timer.NewTimer(w.timeOutInterval)
dispose := make(chan bool)
w.wsConn.SetCloseHandler(func(code int, text string) error {
close(dispose)
return nil
})
go w.ping(dispose)
go w.ping()
}

func (w *webSocketResponseWriter) ping(dispose chan bool) {
if dispose == nil {
return
}
func (w *webSocketResponseWriter) ping() {
defer w.timer.Stop()
for {
select {
case <-dispose:
case <-w.context.Done():
return
case <-w.timer.C:
w.timer.Reset(w.timeOutInterval)
w.wsConn.WriteMessage(websocket.PingMessage, []byte{})
w.wsConn.Ping(w.context)
}
}
}
Expand All @@ -73,16 +67,16 @@ func (w *webSocketResponseWriter) Write(b []byte) (int, error) {
if w.timeOutInterval > time.Second && w.timer != nil {
w.timer.Reset(w.timeOutInterval)
}
return len(b), w.wsConn.WriteMessage(websocket.BinaryMessage, b)
return len(b), w.wsConn.Write(w.context, websocket.MessageBinary, b)
}

func (w *webSocketResponseWriter) writeHeaderFrame(headers http.Header) {
headerBuffer := new(bytes.Buffer)
headers.Write(headerBuffer)
headerGrpcDataHeader := []byte{1 << 7, 0, 0, 0, 0} // MSB=1 indicates this is a header data frame.
binary.BigEndian.PutUint32(headerGrpcDataHeader[1:5], uint32(headerBuffer.Len()))
w.wsConn.WriteMessage(websocket.BinaryMessage, headerGrpcDataHeader)
w.wsConn.WriteMessage(websocket.BinaryMessage, headerBuffer.Bytes())
w.wsConn.Write(w.context, websocket.MessageBinary, headerGrpcDataHeader)
w.wsConn.Write(w.context, websocket.MessageBinary, headerBuffer.Bytes())
}

func (w *webSocketResponseWriter) copyFlushedHeaders() {
Expand Down Expand Up @@ -127,12 +121,13 @@ type webSocketWrappedReader struct {
respWriter *webSocketResponseWriter
remainingBuffer []byte
remainingError error
context context.Context
cancel context.CancelFunc
}

func (w *webSocketWrappedReader) Close() error {
w.respWriter.FlushTrailers()
return w.wsConn.Close()
return w.wsConn.Close(websocket.StatusNormalClosure, "request body closed")
}

// First byte of a binary WebSocket frame is used for control flow:
Expand Down Expand Up @@ -167,15 +162,15 @@ func (w *webSocketWrappedReader) Read(p []byte) (int, error) {
}

// Read a whole frame from the WebSocket connection
messageType, framePayload, err := w.wsConn.ReadMessage()
if err == io.EOF || messageType == -1 {
messageType, framePayload, err := w.wsConn.Read(w.context)
if err == io.EOF || messageType == 0 {
// The client has closed the connection. Indicate to the response writer that it should close
w.cancel()
return 0, io.EOF
}

// Only Binary frames are valid
if messageType != websocket.BinaryMessage {
if messageType != websocket.MessageBinary {
return 0, errors.New("websocket frame was not a binary frame")
}

Expand Down Expand Up @@ -211,12 +206,13 @@ func (w *webSocketWrappedReader) Read(p []byte) (int, error) {
return len(p), nil
}

func newWebsocketWrappedReader(wsConn *websocket.Conn, respWriter *webSocketResponseWriter, cancel context.CancelFunc) *webSocketWrappedReader {
func newWebsocketWrappedReader(ctx context.Context, wsConn *websocket.Conn, respWriter *webSocketResponseWriter, cancel context.CancelFunc) *webSocketWrappedReader {
return &webSocketWrappedReader{
wsConn: wsConn,
respWriter: respWriter,
remainingBuffer: nil,
remainingError: nil,
context: ctx,
cancel: cancel,
}
}
Expand Down
31 changes: 14 additions & 17 deletions go/grpcweb/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ import (
"strings"
"time"

"github.com/gorilla/websocket"
"github.com/rs/cors"
"google.golang.org/grpc"
"google.golang.org/grpc/grpclog"
"nhooyr.io/websocket"
)

var (
Expand Down Expand Up @@ -147,18 +147,15 @@ func (w *WrappedGrpcServer) HandleGrpcWebRequest(resp http.ResponseWriter, req *
intResp.finishRequest(req)
}

var websocketUpgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: []string{"grpc-websockets"},
}

// HandleGrpcWebsocketRequest takes a HTTP request that is assumed to be a gRPC-Websocket request and wraps it with a
// compatibility layer to transform it to a standard gRPC request for the wrapped gRPC server and transforms the
// response to comply with the gRPC-Web protocol.
func (w *WrappedGrpcServer) HandleGrpcWebsocketRequest(resp http.ResponseWriter, req *http.Request) {
wsConn, err := websocketUpgrader.Upgrade(resp, req, nil)

wsConn, err := websocket.Accept(resp, req, &websocket.AcceptOptions{
InsecureSkipVerify: true, // managed by ServeHTTP
Subprotocols: []string{"grpc-websockets"},
})
if err != nil {
grpclog.Errorf("Unable to upgrade websocket request: %v", err)
return
Expand All @@ -170,13 +167,16 @@ func (w *WrappedGrpcServer) HandleGrpcWebsocketRequest(resp http.ResponseWriter,
}
}

messageType, readBytes, err := wsConn.ReadMessage()
ctx, cancelFunc := context.WithCancel(req.Context())
defer cancelFunc()

messageType, readBytes, err := wsConn.Read(ctx)
if err != nil {
grpclog.Errorf("Unable to read first websocket message: %v", err)
grpclog.Errorf("Unable to read first websocket message: %v %v %v", messageType, readBytes, err)
return
}

if messageType != websocket.BinaryMessage {
if messageType != websocket.MessageBinary {
grpclog.Errorf("First websocket message is non-binary")
return
}
Expand All @@ -187,14 +187,11 @@ func (w *WrappedGrpcServer) HandleGrpcWebsocketRequest(resp http.ResponseWriter,
return
}

ctx, cancelFunc := context.WithCancel(req.Context())
defer cancelFunc()

respWriter := newWebSocketResponseWriter(wsConn)
respWriter := newWebSocketResponseWriter(ctx, wsConn)
if w.opts.websocketPingInterval >= time.Second {
respWriter.enablePing(w.opts.websocketPingInterval)
}
wrappedReader := newWebsocketWrappedReader(wsConn, respWriter, cancelFunc)
wrappedReader := newWebsocketWrappedReader(ctx, wsConn, respWriter, cancelFunc)

for name, values := range wsHeaders {
headers[name] = values
Expand Down

0 comments on commit 55a5494

Please sign in to comment.