diff --git a/conn.go b/conn.go index 5161ef8..8b416fb 100644 --- a/conn.go +++ b/conn.go @@ -336,28 +336,46 @@ func (c *Conn) setReadRemaining(n int64) error { // Subprotocol returns the negotiated protocol for the connection. func (c *Conn) Subprotocol() string { + if c == nil { + return "" + } return c.subprotocol } // Close closes the underlying network connection without sending or waiting // for a close message. func (c *Conn) Close() error { + if c == nil { + return ErrNilConn + } + if c.conn == nil { + return ErrNilNetConn + } return c.conn.Close() } // LocalAddr returns the local network address. func (c *Conn) LocalAddr() net.Addr { + if c == nil || c.conn == nil { + return nil + } return c.conn.LocalAddr() } // RemoteAddr returns the remote network address. func (c *Conn) RemoteAddr() net.Addr { + if c == nil || c.conn == nil { + return nil + } return c.conn.RemoteAddr() } // Write methods func (c *Conn) writeFatal(err error) error { + if c == nil { + return ErrNilConn + } err = hideTempErr(err) c.writeErrMu.Lock() if c.writeErr == nil { @@ -368,6 +386,9 @@ func (c *Conn) writeFatal(err error) error { } func (c *Conn) read(n int) ([]byte, error) { + if c == nil { + return nil, ErrNilConn + } p, err := c.br.Peek(n) if err == io.EOF { err = errUnexpectedEOF @@ -377,6 +398,9 @@ func (c *Conn) read(n int) ([]byte, error) { } func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error { + if c == nil { + return ErrNilConn + } <-c.mu defer func() { c.mu <- struct{}{} }() @@ -386,7 +410,9 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error if err != nil { return err } - + if c.conn == nil { + return ErrNilNetConn + } c.conn.SetWriteDeadline(deadline) if len(buf1) == 0 { _, err = c.conn.Write(buf0) @@ -411,6 +437,9 @@ func (c *Conn) writeBufs(bufs ...[]byte) error { // WriteControl writes a control message with the given deadline. The allowed // message types are CloseMessage, PingMessage and PongMessage. func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { + if c == nil { + return ErrNilConn + } if !isControl(messageType) { return errBadWriteOpCode } @@ -459,7 +488,9 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er if err != nil { return err } - + if c.conn == nil { + return ErrNilNetConn + } c.conn.SetWriteDeadline(deadline) _, err = c.conn.Write(buf) if err != nil { @@ -473,6 +504,9 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er // beginMessage prepares a connection and message writer for a new message. func (c *Conn) beginMessage(mw *messageWriter, messageType int) error { + if c == nil { + return ErrNilConn + } // Close previous writer if not already closed by the application. It's // probably better to return an error in this situation, but we cannot // change this without breaking existing applications. @@ -516,6 +550,9 @@ func (c *Conn) beginMessage(mw *messageWriter, messageType int) error { // All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and // PongMessage) are supported. func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { + if c == nil { + return nil, ErrNilConn + } var mw messageWriter if err := c.beginMessage(&mw, messageType); err != nil { return nil, err @@ -555,6 +592,9 @@ func (w *messageWriter) endMessage(err error) error { // final argument indicates that this is the last frame in the message. func (w *messageWriter) flushFrame(final bool, extra []byte) error { c := w.c + if c == nil { + return ErrNilConn + } length := w.pos - maxFrameHeaderSize + len(extra) // Check for invalid control frames. @@ -733,6 +773,9 @@ func (w *messageWriter) Close() error { // WritePreparedMessage writes prepared message into connection. func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { + if c == nil { + return ErrNilConn + } frameType, frameData, err := pm.frame(prepareKey{ isServer: c.isServer, compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), @@ -756,7 +799,9 @@ func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { // WriteMessage is a helper method for getting a writer using NextWriter, // writing the message and closing the writer. func (c *Conn) WriteMessage(messageType int, data []byte) error { - + if c == nil { + return ErrNilConn + } if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) { // Fast path with no allocations and single frame. @@ -785,6 +830,9 @@ func (c *Conn) WriteMessage(messageType int, data []byte) error { // all future writes will return an error. A zero value for t means writes will // not time out. func (c *Conn) SetWriteDeadline(t time.Time) error { + if c == nil { + return ErrNilConn + } c.writeDeadline = t return nil } @@ -977,6 +1025,9 @@ func (c *Conn) advanceFrame() (int, error) { } func (c *Conn) handleProtocolError(message string) error { + if c == nil { + return ErrNilConn + } data := FormatCloseMessage(CloseProtocolError, message) if len(data) > maxControlFramePayloadSize { data = data[:maxControlFramePayloadSize] @@ -996,6 +1047,9 @@ func (c *Conn) handleProtocolError(message string) error { // permanent. Once this method returns a non-nil error, all subsequent calls to // this method return the same error. func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { + if c == nil { + return 0, nil, ErrNilConn + } // Close previous reader, only relevant for decompression. if c.reader != nil { c.reader.Close() @@ -1037,6 +1091,9 @@ type messageReader struct{ c *Conn } func (r *messageReader) Read(b []byte) (int, error) { c := r.c + if c == nil { + return 0, ErrNilConn + } if c.messageReader != r { return 0, io.EOF } @@ -1089,6 +1146,9 @@ func (r *messageReader) Close() error { // ReadMessage is a helper method for getting a reader using NextReader and // reading from that reader to a buffer. func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { + if c == nil { + return 0, nil, ErrNilConn + } var r io.Reader messageType, r, err = c.NextReader() if err != nil { @@ -1103,6 +1163,12 @@ func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { // all future reads will return an error. A zero value for t means reads will // not time out. func (c *Conn) SetReadDeadline(t time.Time) error { + if c == nil { + return ErrNilConn + } + if c.conn == nil { + return ErrNilNetConn + } return c.conn.SetReadDeadline(t) } @@ -1110,11 +1176,17 @@ func (c *Conn) SetReadDeadline(t time.Time) error { // message exceeds the limit, the connection sends a close message to the peer // and returns ErrReadLimit to the application. func (c *Conn) SetReadLimit(limit int64) { + if c == nil { + return + } c.readLimit = limit } // CloseHandler returns the current close handler func (c *Conn) CloseHandler() func(code int, text string) error { + if c == nil { + return nil + } return c.handleClose } @@ -1133,6 +1205,9 @@ func (c *Conn) CloseHandler() func(code int, text string) error { // application must perform some action before sending a close message back to // the peer. func (c *Conn) SetCloseHandler(h func(code int, text string) error) { + if c == nil { + return + } if h == nil { h = func(code int, text string) error { message := FormatCloseMessage(code, "") @@ -1145,6 +1220,9 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) { // PingHandler returns the current ping handler func (c *Conn) PingHandler() func(appData string) error { + if c == nil { + return nil + } return c.handlePing } @@ -1156,6 +1234,9 @@ func (c *Conn) PingHandler() func(appData string) error { // reader Read methods. The application must read the connection to process // ping messages as described in the section on Control Messages above. func (c *Conn) SetPingHandler(h func(appData string) error) { + if c == nil { + return + } if h == nil { h = func(message string) error { err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) @@ -1172,6 +1253,9 @@ func (c *Conn) SetPingHandler(h func(appData string) error) { // PongHandler returns the current pong handler func (c *Conn) PongHandler() func(appData string) error { + if c == nil { + return nil + } return c.handlePong } @@ -1183,6 +1267,9 @@ func (c *Conn) PongHandler() func(appData string) error { // reader Read methods. The application must read the connection to process // pong messages as described in the section on Control Messages above. func (c *Conn) SetPongHandler(h func(appData string) error) { + if c == nil { + return + } if h == nil { h = func(string) error { return nil } } @@ -1193,6 +1280,9 @@ func (c *Conn) SetPongHandler(h func(appData string) error) { // Note that writing to or reading from this connection directly will corrupt the // WebSocket connection. func (c *Conn) NetConn() net.Conn { + if c == nil { + return nil + } return c.conn } @@ -1200,6 +1290,9 @@ func (c *Conn) NetConn() net.Conn { // modifications to connection specific flags. // Deprecated: Use the NetConn method. func (c *Conn) UnderlyingConn() net.Conn { + if c == nil { + return nil + } return c.conn } @@ -1207,6 +1300,9 @@ func (c *Conn) UnderlyingConn() net.Conn { // subsequent text and binary messages. This function is a noop if // compression was not negotiated with the peer. func (c *Conn) EnableWriteCompression(enable bool) { + if c == nil { + return + } c.enableWriteCompression = enable } @@ -1215,6 +1311,9 @@ func (c *Conn) EnableWriteCompression(enable bool) { // with the peer. See the compress/flate package for a description of // compression levels. func (c *Conn) SetCompressionLevel(level int) error { + if c == nil { + return ErrNilConn + } if !isValidCompressionLevel(level) { return errors.New("websocket: invalid compression level") } diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..865a843 --- /dev/null +++ b/errors.go @@ -0,0 +1,8 @@ +package websocket + +import "errors" + +var ( + ErrNilConn = errors.New("nil *Conn") + ErrNilNetConn = errors.New("nil net.Conn") +)