Skip to content

Commit

Permalink
Description: (#31)
Browse files Browse the repository at this point in the history
1. The exported receivers on the websocket.(*Conn) will not panic
2. The websocket.(*Conn) becomes nil, in some instances, leading to unpredictable panics in the receivers i.e websocket.(*Conn).SetWriteDeadline(), websocket.(*Conn).beginMessage()
3. The websocket.(*Conn).(net.Conn) becomes nil, in some instances, leading to unpredicatble panics in i.e websocket.(*Conn).(net.Conn).SetWriteDeadline()
4. The panics are handled by nil pointer check in case of websocket.(*Conn) & nil interface check in case of websocket.(*Conn).(net.Conn) before accessing the fields.
5. The panics are handled in the exported receivers at the moment and based on the need, we can handle the panics similarly in the unexported receivers as well.
6. 2 new errors (websocket.ErrNilConn & websocket.ErrNilNetConn) are defined in errors.go and returned for all the modified receivers, in which an error could be returned.
7. Additional return of an error in all the exported receivers would break existing applications.
8. Prevent the applciation crash from this package, return errors and let the application disconnect the websocket connection gracefully.
  • Loading branch information
gokpm committed Nov 2, 2022
1 parent c56b561 commit b8ea6bd
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 3 deletions.
105 changes: 102 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,28 +336,46 @@ func (c *Conn) setReadRemaining(n int64) error {

// Subprotocol returns the negotiated protocol for the connection.
func (c *Conn) Subprotocol() string {
if c == nil {
return ""
}
return c.subprotocol
}

// Close closes the underlying network connection without sending or waiting
// for a close message.
func (c *Conn) Close() error {
if c == nil {
return ErrNilConn
}
if c.conn == nil {
return ErrNilNetConn
}
return c.conn.Close()
}

// LocalAddr returns the local network address.
func (c *Conn) LocalAddr() net.Addr {
if c == nil || c.conn == nil {
return nil
}
return c.conn.LocalAddr()
}

// RemoteAddr returns the remote network address.
func (c *Conn) RemoteAddr() net.Addr {
if c == nil || c.conn == nil {
return nil
}
return c.conn.RemoteAddr()
}

// Write methods

func (c *Conn) writeFatal(err error) error {
if c == nil {
return ErrNilConn
}
err = hideTempErr(err)
c.writeErrMu.Lock()
if c.writeErr == nil {
Expand All @@ -368,6 +386,9 @@ func (c *Conn) writeFatal(err error) error {
}

func (c *Conn) read(n int) ([]byte, error) {
if c == nil {
return nil, ErrNilConn
}
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
Expand All @@ -377,6 +398,9 @@ func (c *Conn) read(n int) ([]byte, error) {
}

func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error {
if c == nil {
return ErrNilConn
}
<-c.mu
defer func() { c.mu <- struct{}{} }()

Expand All @@ -386,7 +410,9 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
if err != nil {
return err
}

if c.conn == nil {
return ErrNilNetConn
}
c.conn.SetWriteDeadline(deadline)
if len(buf1) == 0 {
_, err = c.conn.Write(buf0)
Expand All @@ -411,6 +437,9 @@ func (c *Conn) writeBufs(bufs ...[]byte) error {
// WriteControl writes a control message with the given deadline. The allowed
// message types are CloseMessage, PingMessage and PongMessage.
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
if c == nil {
return ErrNilConn
}
if !isControl(messageType) {
return errBadWriteOpCode
}
Expand Down Expand Up @@ -459,7 +488,9 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
if err != nil {
return err
}

if c.conn == nil {
return ErrNilNetConn
}
c.conn.SetWriteDeadline(deadline)
_, err = c.conn.Write(buf)
if err != nil {
Expand All @@ -473,6 +504,9 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er

// beginMessage prepares a connection and message writer for a new message.
func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
if c == nil {
return ErrNilConn
}
// Close previous writer if not already closed by the application. It's
// probably better to return an error in this situation, but we cannot
// change this without breaking existing applications.
Expand Down Expand Up @@ -516,6 +550,9 @@ func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
// PongMessage) are supported.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
if c == nil {
return nil, ErrNilConn
}
var mw messageWriter
if err := c.beginMessage(&mw, messageType); err != nil {
return nil, err
Expand Down Expand Up @@ -555,6 +592,9 @@ func (w *messageWriter) endMessage(err error) error {
// final argument indicates that this is the last frame in the message.
func (w *messageWriter) flushFrame(final bool, extra []byte) error {
c := w.c
if c == nil {
return ErrNilConn
}
length := w.pos - maxFrameHeaderSize + len(extra)

// Check for invalid control frames.
Expand Down Expand Up @@ -733,6 +773,9 @@ func (w *messageWriter) Close() error {

// WritePreparedMessage writes prepared message into connection.
func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error {
if c == nil {
return ErrNilConn
}
frameType, frameData, err := pm.frame(prepareKey{
isServer: c.isServer,
compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType),
Expand All @@ -756,7 +799,9 @@ func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error {
// WriteMessage is a helper method for getting a writer using NextWriter,
// writing the message and closing the writer.
func (c *Conn) WriteMessage(messageType int, data []byte) error {

if c == nil {
return ErrNilConn
}
if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
// Fast path with no allocations and single frame.

Expand Down Expand Up @@ -785,6 +830,9 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error {
// all future writes will return an error. A zero value for t means writes will
// not time out.
func (c *Conn) SetWriteDeadline(t time.Time) error {
if c == nil {
return ErrNilConn
}
c.writeDeadline = t
return nil
}
Expand Down Expand Up @@ -977,6 +1025,9 @@ func (c *Conn) advanceFrame() (int, error) {
}

func (c *Conn) handleProtocolError(message string) error {
if c == nil {
return ErrNilConn
}
data := FormatCloseMessage(CloseProtocolError, message)
if len(data) > maxControlFramePayloadSize {
data = data[:maxControlFramePayloadSize]
Expand All @@ -996,6 +1047,9 @@ func (c *Conn) handleProtocolError(message string) error {
// permanent. Once this method returns a non-nil error, all subsequent calls to
// this method return the same error.
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
if c == nil {
return 0, nil, ErrNilConn
}
// Close previous reader, only relevant for decompression.
if c.reader != nil {
c.reader.Close()
Expand Down Expand Up @@ -1037,6 +1091,9 @@ type messageReader struct{ c *Conn }

func (r *messageReader) Read(b []byte) (int, error) {
c := r.c
if c == nil {
return 0, ErrNilConn
}
if c.messageReader != r {
return 0, io.EOF
}
Expand Down Expand Up @@ -1089,6 +1146,9 @@ func (r *messageReader) Close() error {
// ReadMessage is a helper method for getting a reader using NextReader and
// reading from that reader to a buffer.
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
if c == nil {
return 0, nil, ErrNilConn
}
var r io.Reader
messageType, r, err = c.NextReader()
if err != nil {
Expand All @@ -1103,18 +1163,30 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
// all future reads will return an error. A zero value for t means reads will
// not time out.
func (c *Conn) SetReadDeadline(t time.Time) error {
if c == nil {
return ErrNilConn
}
if c.conn == nil {
return ErrNilNetConn
}
return c.conn.SetReadDeadline(t)
}

// SetReadLimit sets the maximum size in bytes for a message read from the peer. If a
// message exceeds the limit, the connection sends a close message to the peer
// and returns ErrReadLimit to the application.
func (c *Conn) SetReadLimit(limit int64) {
if c == nil {
return
}
c.readLimit = limit
}

// CloseHandler returns the current close handler
func (c *Conn) CloseHandler() func(code int, text string) error {
if c == nil {
return nil
}
return c.handleClose
}

Expand All @@ -1133,6 +1205,9 @@ func (c *Conn) CloseHandler() func(code int, text string) error {
// application must perform some action before sending a close message back to
// the peer.
func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
if c == nil {
return
}
if h == nil {
h = func(code int, text string) error {
message := FormatCloseMessage(code, "")
Expand All @@ -1145,6 +1220,9 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) {

// PingHandler returns the current ping handler
func (c *Conn) PingHandler() func(appData string) error {
if c == nil {
return nil
}
return c.handlePing
}

Expand All @@ -1156,6 +1234,9 @@ func (c *Conn) PingHandler() func(appData string) error {
// reader Read methods. The application must read the connection to process
// ping messages as described in the section on Control Messages above.
func (c *Conn) SetPingHandler(h func(appData string) error) {
if c == nil {
return
}
if h == nil {
h = func(message string) error {
err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
Expand All @@ -1172,6 +1253,9 @@ func (c *Conn) SetPingHandler(h func(appData string) error) {

// PongHandler returns the current pong handler
func (c *Conn) PongHandler() func(appData string) error {
if c == nil {
return nil
}
return c.handlePong
}

Expand All @@ -1183,6 +1267,9 @@ func (c *Conn) PongHandler() func(appData string) error {
// reader Read methods. The application must read the connection to process
// pong messages as described in the section on Control Messages above.
func (c *Conn) SetPongHandler(h func(appData string) error) {
if c == nil {
return
}
if h == nil {
h = func(string) error { return nil }
}
Expand All @@ -1193,20 +1280,29 @@ func (c *Conn) SetPongHandler(h func(appData string) error) {
// Note that writing to or reading from this connection directly will corrupt the
// WebSocket connection.
func (c *Conn) NetConn() net.Conn {
if c == nil {
return nil
}
return c.conn
}

// UnderlyingConn returns the internal net.Conn. This can be used to further
// modifications to connection specific flags.
// Deprecated: Use the NetConn method.
func (c *Conn) UnderlyingConn() net.Conn {
if c == nil {
return nil
}
return c.conn
}

// EnableWriteCompression enables and disables write compression of
// subsequent text and binary messages. This function is a noop if
// compression was not negotiated with the peer.
func (c *Conn) EnableWriteCompression(enable bool) {
if c == nil {
return
}
c.enableWriteCompression = enable
}

Expand All @@ -1215,6 +1311,9 @@ func (c *Conn) EnableWriteCompression(enable bool) {
// with the peer. See the compress/flate package for a description of
// compression levels.
func (c *Conn) SetCompressionLevel(level int) error {
if c == nil {
return ErrNilConn
}
if !isValidCompressionLevel(level) {
return errors.New("websocket: invalid compression level")
}
Expand Down
8 changes: 8 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package websocket

import "errors"

var (
ErrNilConn = errors.New("nil *Conn")
ErrNilNetConn = errors.New("nil net.Conn")
)

0 comments on commit b8ea6bd

Please sign in to comment.