diff --git a/Makefile b/Makefile index ad1ba257..f9f31c49 100644 --- a/Makefile +++ b/Makefile @@ -2,12 +2,6 @@ all: fmt lint test .SILENT: -.PHONY: * - -.ONESHELL: -SHELL = bash -.SHELLFLAGS = -ceuo pipefail - include ci/fmt.mk include ci/lint.mk include ci/test.mk diff --git a/README.md b/README.md index 631a14c9..5dddf84a 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,6 @@ # websocket -[![release](https://img.shields.io/github/v/release/nhooyr/websocket?color=6b9ded&sort=semver)](https://github.com/nhooyr/websocket/releases) [![godoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://godoc.org/nhooyr.io/websocket) -[![coverage](https://img.shields.io/coveralls/github/nhooyr/websocket?color=65d6a4)](https://coveralls.io/github/nhooyr/websocket) -[![ci](https://github.com/nhooyr/websocket/workflows/ci/badge.svg)](https://github.com/nhooyr/websocket/actions) websocket is a minimal and idiomatic WebSocket library for Go. @@ -17,7 +14,8 @@ go get nhooyr.io/websocket - Minimal and idiomatic API - First class [context.Context](https://blog.golang.org/context) support -- Thorough tests, fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) +- Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) +- Thorough unit tests with [90% coverage](https://coveralls.io/github/nhooyr/websocket) - [Minimal dependencies](https://godoc.org/nhooyr.io/websocket?imports) - JSON and protobuf helpers in the [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages - Zero alloc reads and writes @@ -111,8 +109,7 @@ Advantages of nhooyr.io/websocket: - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - Gorilla only supports no context takeover mode - - Uses [klauspost/compress](https://github.com/klauspost/compress) for optimized compression - - See [gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203) + - We use [klauspost/compress](https://github.com/klauspost/compress) for much lower memory usage ([gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203)) - [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) - Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) diff --git a/accept.go b/accept.go index 75d6d643..5a162de0 100644 --- a/accept.go +++ b/accept.go @@ -6,14 +6,15 @@ import ( "bytes" "crypto/sha1" "encoding/base64" + "errors" + "fmt" "io" "net/http" "net/textproto" "net/url" + "strconv" "strings" - "golang.org/x/xerrors" - "nhooyr.io/websocket/internal/errd" ) @@ -85,7 +86,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con hj, ok := w.(http.Hijacker) if !ok { - err = xerrors.New("http.ResponseWriter does not implement http.Hijacker") + err = errors.New("http.ResponseWriter does not implement http.Hijacker") http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) return nil, err } @@ -110,7 +111,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con netConn, brw, err := hj.Hijack() if err != nil { - err = xerrors.Errorf("failed to hijack connection: %w", err) + err = fmt.Errorf("failed to hijack connection: %w", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return nil, err } @@ -133,32 +134,32 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { if !r.ProtoAtLeast(1, 1) { - return http.StatusUpgradeRequired, xerrors.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) + return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) } if !headerContainsToken(r.Header, "Connection", "Upgrade") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") - return http.StatusUpgradeRequired, xerrors.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) + return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) } if !headerContainsToken(r.Header, "Upgrade", "websocket") { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "websocket") - return http.StatusUpgradeRequired, xerrors.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) + return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) } if r.Method != "GET" { - return http.StatusMethodNotAllowed, xerrors.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) + return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) } if r.Header.Get("Sec-WebSocket-Version") != "13" { w.Header().Set("Sec-WebSocket-Version", "13") - return http.StatusBadRequest, xerrors.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) + return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) } if r.Header.Get("Sec-WebSocket-Key") == "" { - return http.StatusBadRequest, xerrors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") + return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") } return 0, nil @@ -169,10 +170,10 @@ func authenticateOrigin(r *http.Request) error { if origin != "" { u, err := url.Parse(origin) if err != nil { - return xerrors.Errorf("failed to parse Origin header %q: %w", origin, err) + return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) } if !strings.EqualFold(u.Host, r.Host) { - return xerrors.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) + return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) } } return nil @@ -208,6 +209,7 @@ func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionM func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { copts := mode.opts() + copts.serverMaxWindowBits = 8 for _, p := range ext.params { switch p { @@ -219,11 +221,31 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi continue } - if strings.HasPrefix(p, "client_max_window_bits") || strings.HasPrefix(p, "server_max_window_bits") { + if strings.HasPrefix(p, "client_max_window_bits") { + continue + + // bits, ok := parseExtensionParameter(p, 15) + // if !ok || bits < 8 || bits > 16 { + // err := fmt.Errorf("invalid client_max_window_bits: %q", p) + // http.Error(w, err.Error(), http.StatusBadRequest) + // return nil, err + // } + // copts.clientMaxWindowBits = bits + // continue + } + + if false && strings.HasPrefix(p, "server_max_window_bits") { + // We always send back 8 but make sure to validate. + bits, ok := parseExtensionParameter(p, 0) + if !ok || bits < 8 || bits > 16 { + err := fmt.Errorf("invalid server_max_window_bits: %q", p) + http.Error(w, err.Error(), http.StatusBadRequest) + return nil, err + } continue } - err := xerrors.Errorf("unsupported permessage-deflate parameter: %q", p) + err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p) http.Error(w, err.Error(), http.StatusBadRequest) return nil, err } @@ -233,6 +255,21 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi return copts, nil } +// parseExtensionParameter parses the value in the extension parameter p. +// It falls back to defaultVal if there is no value. +// If defaultVal == 0, then ok == false if there is no value. +func parseExtensionParameter(p string, defaultVal int) (int, bool) { + ps := strings.Split(p, "=") + if len(ps) == 1 { + if defaultVal > 0 { + return defaultVal, true + } + return 0, false + } + i, e := strconv.Atoi(strings.Trim(ps[1], `"`)) + return i, e == nil +} + func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { copts := mode.opts() // The peer must explicitly request it. @@ -253,7 +290,7 @@ func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode Com // // Either way, we're only implementing this for webkit which never sends the max_window_bits // parameter so we don't need to worry about it. - err := xerrors.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) + err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) http.Error(w, err.Error(), http.StatusBadRequest) return nil, err } diff --git a/accept_js.go b/accept_js.go index 5db12d7b..724b35b5 100644 --- a/accept_js.go +++ b/accept_js.go @@ -1,9 +1,8 @@ package websocket import ( + "errors" "net/http" - - "golang.org/x/xerrors" ) // AcceptOptions represents Accept's options. @@ -16,5 +15,5 @@ type AcceptOptions struct { // Accept is stubbed out for Wasm. func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { - return nil, xerrors.New("unimplemented") + return nil, errors.New("unimplemented") } diff --git a/accept_test.go b/accept_test.go index 53338e17..555f0dc0 100644 --- a/accept_test.go +++ b/accept_test.go @@ -4,14 +4,13 @@ package websocket import ( "bufio" + "errors" "net" "net/http" "net/http/httptest" "strings" "testing" - "golang.org/x/xerrors" - "nhooyr.io/websocket/internal/test/assert" ) @@ -80,7 +79,7 @@ func TestAccept(t *testing.T) { w := mockHijacker{ ResponseWriter: httptest.NewRecorder(), hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { - return nil, nil, xerrors.New("haha") + return nil, nil, errors.New("haha") }, } @@ -328,6 +327,7 @@ func Test_acceptCompression(t *testing.T) { expCopts: &compressionOptions{ clientNoContextTakeover: true, serverNoContextTakeover: true, + serverMaxWindowBits: 8, }, }, { diff --git a/autobahn_test.go b/autobahn_test.go index fb24a06b..50473534 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -15,8 +15,6 @@ import ( "testing" "time" - "golang.org/x/xerrors" - "nhooyr.io/websocket" "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/test/assert" @@ -108,7 +106,7 @@ func wstestClientServer(ctx context.Context) (url string, closeFn func(), err er "exclude-cases": excludedAutobahnCases, }) if err != nil { - return "", nil, xerrors.Errorf("failed to write spec: %w", err) + return "", nil, fmt.Errorf("failed to write spec: %w", err) } ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15) @@ -126,7 +124,7 @@ func wstestClientServer(ctx context.Context) (url string, closeFn func(), err er wstest := exec.CommandContext(ctx, "wstest", args...) err = wstest.Start() if err != nil { - return "", nil, xerrors.Errorf("failed to start wstest: %w", err) + return "", nil, fmt.Errorf("failed to start wstest: %w", err) } return url, func() { @@ -209,7 +207,7 @@ func unusedListenAddr() (_ string, err error) { func tempJSONFile(v interface{}) (string, error) { f, err := ioutil.TempFile("", "temp.json") if err != nil { - return "", xerrors.Errorf("temp file: %w", err) + return "", fmt.Errorf("temp file: %w", err) } defer f.Close() @@ -217,12 +215,12 @@ func tempJSONFile(v interface{}) (string, error) { e.SetIndent("", "\t") err = e.Encode(v) if err != nil { - return "", xerrors.Errorf("json encode: %w", err) + return "", fmt.Errorf("json encode: %w", err) } err = f.Close() if err != nil { - return "", xerrors.Errorf("close temp file: %w", err) + return "", fmt.Errorf("close temp file: %w", err) } return f.Name(), nil diff --git a/ci/ensure_fmt.sh b/ci/ensure_fmt.sh new file mode 100755 index 00000000..6fe9cb18 --- /dev/null +++ b/ci/ensure_fmt.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +set -euo pipefail + +main() { + local files + mapfile -t files < <(git ls-files --other --modified --exclude-standard) + if [[ ${files[*]} == "" ]]; then + return + fi + + echo "Files need generation or are formatted incorrectly:" + for f in "${files[@]}"; do + echo " $f" + done + + echo + echo "Please run the following locally:" + echo " make fmt" + exit 1 +} + +main "$@" diff --git a/ci/fmt.mk b/ci/fmt.mk index f3969721..f313562c 100644 --- a/ci/fmt.mk +++ b/ci/fmt.mk @@ -1,12 +1,6 @@ -fmt: modtidy gofmt goimports prettier +fmt: modtidy gofmt goimports prettier shfmt ifdef CI - if [[ $$(git ls-files --other --modified --exclude-standard) != "" ]]; then - echo "Files need generation or are formatted incorrectly:" - git -c color.ui=always status | grep --color=no '\e\[31m' - echo "Please run the following locally:" - echo " make fmt" - exit 1 - fi + ./ci/ensure_fmt.sh endif modtidy: gen @@ -23,3 +17,6 @@ prettier: gen: stringer -type=opcode,MessageType,StatusCode -output=stringer.go + +shfmt: + shfmt -i 2 -w -s -sr $$(git ls-files "*.sh") diff --git a/ci/image/Dockerfile b/ci/image/Dockerfile index 88c96502..ed408eda 100644 --- a/ci/image/Dockerfile +++ b/ci/image/Dockerfile @@ -1,10 +1,12 @@ FROM golang:1 RUN apt-get update -RUN apt-get install -y chromium npm +RUN apt-get install -y chromium npm shellcheck + +ARG SHFMT_URL=https://github.com/mvdan/sh/releases/download/v3.0.1/shfmt_v3.0.1_linux_amd64 +RUN curl -L $SHFMT_URL > /usr/local/bin/shfmt && chmod +x /usr/local/bin/shfmt ENV GOFLAGS="-mod=readonly" -ENV PAGER=cat ENV CI=true ENV MAKEFLAGS="--jobs=16 --output-sync=target" diff --git a/ci/lint.mk b/ci/lint.mk index 031f0de3..4335e7b1 100644 --- a/ci/lint.mk +++ b/ci/lint.mk @@ -1,4 +1,4 @@ -lint: govet golint +lint: govet golint govet-wasm golint-wasm shellcheck govet: go vet ./... @@ -11,3 +11,6 @@ golint: golint-wasm: GOOS=js GOARCH=wasm golint -set_exit_status ./... + +shellcheck: + shellcheck $$(git ls-files "*.sh") diff --git a/ci/test.mk b/ci/test.mk index 3d1f0ed1..c62a25b6 100644 --- a/ci/test.mk +++ b/ci/test.mk @@ -7,7 +7,6 @@ ci/out/coverage.html: gotest go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html coveralls: gotest - # https://github.com/coverallsapp/github-action/blob/master/src/run.ts echo "--- coveralls" goveralls -coverprofile=ci/out/coverage.prof diff --git a/close.go b/close.go index 20073233..7cbc19e9 100644 --- a/close.go +++ b/close.go @@ -1,9 +1,8 @@ package websocket import ( + "errors" "fmt" - - "golang.org/x/xerrors" ) // StatusCode represents a WebSocket status code. @@ -53,7 +52,7 @@ const ( // CloseError is returned when the connection is closed with a status and reason. // -// Use Go 1.13's xerrors.As to check for this error. +// Use Go 1.13's errors.As to check for this error. // Also see the CloseStatus helper. type CloseError struct { Code StatusCode @@ -64,13 +63,13 @@ func (ce CloseError) Error() string { return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) } -// CloseStatus is a convenience wrapper around Go 1.13's xerrors.As to grab +// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab // the status code from a CloseError. // // -1 will be returned if the passed error is nil or not a CloseError. func CloseStatus(err error) StatusCode { var ce CloseError - if xerrors.As(err, &ce) { + if errors.As(err, &ce) { return ce.Code } return -1 diff --git a/close_notjs.go b/close_notjs.go index 3367ea01..c25b088f 100644 --- a/close_notjs.go +++ b/close_notjs.go @@ -5,11 +5,11 @@ package websocket import ( "context" "encoding/binary" + "errors" + "fmt" "log" "time" - "golang.org/x/xerrors" - "nhooyr.io/websocket/internal/errd" ) @@ -46,7 +46,7 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { return nil } -var errAlreadyWroteClose = xerrors.New("already wrote close") +var errAlreadyWroteClose = errors.New("already wrote close") func (c *Conn) writeClose(code StatusCode, reason string) error { c.closeMu.Lock() @@ -62,7 +62,7 @@ func (c *Conn) writeClose(code StatusCode, reason string) error { Reason: reason, } - c.setCloseErr(xerrors.Errorf("sent close frame: %w", ce)) + c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) var p []byte var err error @@ -119,7 +119,7 @@ func parseClosePayload(p []byte) (CloseError, error) { } if len(p) < 2 { - return CloseError{}, xerrors.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) + return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) } ce := CloseError{ @@ -128,7 +128,7 @@ func parseClosePayload(p []byte) (CloseError, error) { } if !validWireCloseCode(ce.Code) { - return CloseError{}, xerrors.Errorf("invalid status code %v", ce.Code) + return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) } return ce, nil @@ -155,7 +155,7 @@ func validWireCloseCode(code StatusCode) bool { func (ce CloseError) bytes() ([]byte, error) { p, err := ce.bytesErr() if err != nil { - err = xerrors.Errorf("failed to marshal close frame: %w", err) + err = fmt.Errorf("failed to marshal close frame: %w", err) ce = CloseError{ Code: StatusInternalError, } @@ -168,11 +168,11 @@ const maxCloseReason = maxControlPayload - 2 func (ce CloseError) bytesErr() ([]byte, error) { if len(ce.Reason) > maxCloseReason { - return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) + return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) } if !validWireCloseCode(ce.Code) { - return nil, xerrors.Errorf("status code %v cannot be set", ce.Code) + return nil, fmt.Errorf("status code %v cannot be set", ce.Code) } buf := make([]byte, 2+len(ce.Reason)) @@ -189,7 +189,7 @@ func (c *Conn) setCloseErr(err error) { func (c *Conn) setCloseErrLocked(err error) { if c.closeErr == nil { - c.closeErr = xerrors.Errorf("WebSocket closed: %w", err) + c.closeErr = fmt.Errorf("WebSocket closed: %w", err) } } diff --git a/compress_notjs.go b/compress_notjs.go index a6911056..ef82eb4d 100644 --- a/compress_notjs.go +++ b/compress_notjs.go @@ -3,6 +3,7 @@ package websocket import ( + "fmt" "io" "net/http" "sync" @@ -19,7 +20,10 @@ func (m CompressionMode) opts() *compressionOptions { type compressionOptions struct { clientNoContextTakeover bool + clientMaxWindowBits int + serverNoContextTakeover bool + serverMaxWindowBits int } func (copts *compressionOptions) setHeader(h http.Header) { @@ -30,6 +34,12 @@ func (copts *compressionOptions) setHeader(h http.Header) { if copts.serverNoContextTakeover { s += "; server_no_context_takeover" } + if false && copts.serverMaxWindowBits > 0 { + s += fmt.Sprintf("; server_max_window_bits=%v", copts.serverMaxWindowBits) + } + if false && copts.clientMaxWindowBits > 0 { + s += fmt.Sprintf("; client_max_window_bits=%v", copts.clientMaxWindowBits) + } h.Set("Sec-WebSocket-Extensions", s) } @@ -152,9 +162,8 @@ func (sw *slidingWindow) close() { } swPoolMu.Lock() - defer swPoolMu.Unlock() - swPool[cap(sw.buf)].Put(sw.buf) + swPoolMu.Unlock() sw.buf = nil } diff --git a/conn_notjs.go b/conn_notjs.go index 8598ded3..7ee60fbc 100644 --- a/conn_notjs.go +++ b/conn_notjs.go @@ -5,13 +5,13 @@ package websocket import ( "bufio" "context" + "errors" + "fmt" "io" "runtime" "strconv" "sync" "sync/atomic" - - "golang.org/x/xerrors" ) // Conn represents a WebSocket connection. @@ -108,7 +108,7 @@ func newConn(cfg connConfig) *Conn { } runtime.SetFinalizer(c, func(c *Conn) { - c.close(xerrors.New("connection garbage collected")) + c.close(errors.New("connection garbage collected")) }) go c.timeoutLoop() @@ -165,10 +165,10 @@ func (c *Conn) timeoutLoop() { case readCtx = <-c.readTimeout: case <-readCtx.Done(): - c.setCloseErr(xerrors.Errorf("read timed out: %w", readCtx.Err())) - go c.writeError(StatusPolicyViolation, xerrors.New("timed out")) + c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) + go c.writeError(StatusPolicyViolation, errors.New("timed out")) case <-writeCtx.Done(): - c.close(xerrors.Errorf("write timed out: %w", writeCtx.Err())) + c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) return } } @@ -190,7 +190,7 @@ func (c *Conn) Ping(ctx context.Context) error { err := c.ping(ctx, strconv.Itoa(int(p))) if err != nil { - return xerrors.Errorf("failed to ping: %w", err) + return fmt.Errorf("failed to ping: %w", err) } return nil } @@ -217,7 +217,7 @@ func (c *Conn) ping(ctx context.Context, p string) error { case <-c.closed: return c.closeErr case <-ctx.Done(): - err := xerrors.Errorf("failed to wait for pong: %w", ctx.Err()) + err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) c.close(err) return err case <-pong: @@ -242,7 +242,7 @@ func (m *mu) Lock(ctx context.Context) error { case <-m.c.closed: return m.c.closeErr case <-ctx.Done(): - err := xerrors.Errorf("failed to acquire lock: %w", ctx.Err()) + err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) m.c.close(err) return err case m.ch <- struct{}{}: diff --git a/conn_test.go b/conn_test.go index 7755048c..64e6736f 100644 --- a/conn_test.go +++ b/conn_test.go @@ -19,7 +19,6 @@ import ( "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/duration" - "golang.org/x/xerrors" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/test/assert" @@ -115,13 +114,21 @@ func TestConn(t *testing.T) { for i := 0; i < count; i++ { go func() { - errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg) + select { + case errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg): + case <-tt.ctx.Done(): + return + } }() } for i := 0; i < count; i++ { - err := <-errs - assert.Success(t, err) + select { + case err := <-errs: + assert.Success(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } } err := c1.Close(websocket.StatusNormalClosure, "") @@ -172,8 +179,12 @@ func TestConn(t *testing.T) { _, err = n1.Read(nil) assert.Equal(t, "read error", err, io.EOF) - err = <-errs - assert.Success(t, err) + select { + case err := <-errs: + assert.Success(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } assert.Equal(t, "read msg", []byte("hello"), b) }) @@ -196,8 +207,12 @@ func TestConn(t *testing.T) { _, err := ioutil.ReadAll(n1) assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`) - err = <-errs - assert.Success(t, err) + select { + case err := <-errs: + assert.Success(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } }) t.Run("wsjson", func(t *testing.T) { @@ -219,8 +234,12 @@ func TestConn(t *testing.T) { assert.Success(t, err) assert.Equal(t, "read msg", exp, act) - err = <-werr - assert.Success(t, err) + select { + case err := <-werr: + assert.Success(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } err = c1.Close(websocket.StatusNormalClosure, "") assert.Success(t, err) @@ -289,10 +308,10 @@ func TestWasm(t *testing.T) { func assertCloseStatus(exp websocket.StatusCode, err error) error { if websocket.CloseStatus(err) == -1 { - return xerrors.Errorf("expected websocket.CloseError: %T %v", err, err) + return fmt.Errorf("expected websocket.CloseError: %T %v", err, err) } if websocket.CloseStatus(err) != exp { - return xerrors.Errorf("expected close status %v but got ", exp, err) + return fmt.Errorf("expected close status %v but got %v", exp, err) } return nil } @@ -412,14 +431,22 @@ func BenchmarkConn(b *testing.B) { go func() { for range writes { - werrs <- c1.Write(bb.ctx, websocket.MessageText, msg) + select { + case werrs <- c1.Write(bb.ctx, websocket.MessageText, msg): + case <-bb.ctx.Done(): + return + } } }() b.SetBytes(int64(len(msg))) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - writes <- struct{}{} + select { + case writes <- struct{}{}: + case <-bb.ctx.Done(): + b.Fatal(bb.ctx.Err()) + } typ, r, err := c1.Reader(bb.ctx) if err != nil { @@ -446,7 +473,11 @@ func BenchmarkConn(b *testing.B) { assert.Equal(b, "msg", msg, readBuf) } - err = <-werrs + select { + case err = <-werrs: + case <-bb.ctx.Done(): + b.Fatal(bb.ctx.Err()) + } if err != nil { b.Fatal(err) } diff --git a/dial.go b/dial.go index 09546ac6..8ff39597 100644 --- a/dial.go +++ b/dial.go @@ -8,14 +8,15 @@ import ( "context" "crypto/rand" "encoding/base64" + "errors" + "fmt" "io" "io/ioutil" "net/http" "net/url" "strings" "sync" - - "golang.org/x/xerrors" + "time" "nhooyr.io/websocket/internal/errd" ) @@ -78,7 +79,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( secWebSocketKey, err := secWebSocketKey(rand) if err != nil { - return nil, nil, xerrors.Errorf("failed to generate Sec-WebSocket-Key: %w", err) + return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) } resp, err := handshakeRequest(ctx, urls, opts, secWebSocketKey) @@ -91,6 +92,12 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( if err != nil { // We read a bit of the body for easier debugging. r := io.LimitReader(respBody, 1024) + + timer := time.AfterFunc(time.Second*3, func() { + respBody.Close() + }) + defer timer.Stop() + b, _ := ioutil.ReadAll(r) respBody.Close() resp.Body = ioutil.NopCloser(bytes.NewReader(b)) @@ -104,7 +111,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( rwc, ok := respBody.(io.ReadWriteCloser) if !ok { - return nil, resp, xerrors.Errorf("response body is not a io.ReadWriteCloser: %T", respBody) + return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody) } return newConn(connConfig{ @@ -120,12 +127,12 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWebSocketKey string) (*http.Response, error) { if opts.HTTPClient.Timeout > 0 { - return nil, xerrors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") + return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") } u, err := url.Parse(urls) if err != nil { - return nil, xerrors.Errorf("failed to parse url: %w", err) + return nil, fmt.Errorf("failed to parse url: %w", err) } switch u.Scheme { @@ -134,7 +141,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe case "wss": u.Scheme = "https" default: - return nil, xerrors.Errorf("unexpected url scheme: %q", u.Scheme) + return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) } req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil) @@ -148,12 +155,13 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe } if opts.CompressionMode != CompressionDisabled { copts := opts.CompressionMode.opts() + copts.clientMaxWindowBits = 8 copts.setHeader(req.Header) } resp, err := opts.HTTPClient.Do(req) if err != nil { - return nil, xerrors.Errorf("failed to send handshake request: %w", err) + return nil, fmt.Errorf("failed to send handshake request: %w", err) } return resp, nil } @@ -165,26 +173,26 @@ func secWebSocketKey(rr io.Reader) (string, error) { b := make([]byte, 16) _, err := io.ReadFull(rr, b) if err != nil { - return "", xerrors.Errorf("failed to read random data from rand.Reader: %w", err) + return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) } return base64.StdEncoding.EncodeToString(b), nil } func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) { if resp.StatusCode != http.StatusSwitchingProtocols { - return nil, xerrors.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) + return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) } if !headerContainsToken(resp.Header, "Connection", "Upgrade") { - return nil, xerrors.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) + return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) } if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") { - return nil, xerrors.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) + return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) } if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) { - return nil, xerrors.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", + return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", resp.Header.Get("Sec-WebSocket-Accept"), secWebSocketKey, ) @@ -210,7 +218,7 @@ func verifySubprotocol(subprotos []string, resp *http.Response) error { } } - return xerrors.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) + return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) } func verifyServerExtensions(h http.Header) (*compressionOptions, error) { @@ -221,19 +229,40 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) { ext := exts[0] if ext.name != "permessage-deflate" || len(exts) > 1 { - return nil, xerrors.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) + return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) } copts := &compressionOptions{} + copts.clientMaxWindowBits = 8 for _, p := range ext.params { switch p { case "client_no_context_takeover": copts.clientNoContextTakeover = true + continue case "server_no_context_takeover": copts.serverNoContextTakeover = true - default: - return nil, xerrors.Errorf("unsupported permessage-deflate parameter: %q", p) + continue + } + + if false && strings.HasPrefix(p, "server_max_window_bits") { + bits, ok := parseExtensionParameter(p, 0) + if !ok || bits < 8 || bits > 16 { + return nil, fmt.Errorf("invalid server_max_window_bits: %q", p) + } + copts.serverMaxWindowBits = bits + continue + } + + if false && strings.HasPrefix(p, "client_max_window_bits") { + bits, ok := parseExtensionParameter(p, 0) + if !ok || bits < 8 || bits > 16 { + return nil, fmt.Errorf("invalid client_max_window_bits: %q", p) + } + copts.clientMaxWindowBits = 8 + continue } + + return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) } return copts, nil diff --git a/example_echo_test.go b/example_echo_test.go index 1daec8a5..cd195d2e 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -4,6 +4,7 @@ package websocket_test import ( "context" + "errors" "fmt" "io" "log" @@ -12,7 +13,6 @@ import ( "time" "golang.org/x/time/rate" - "golang.org/x/xerrors" "nhooyr.io/websocket" "nhooyr.io/websocket/wsjson" @@ -78,7 +78,7 @@ func echoServer(w http.ResponseWriter, r *http.Request) error { if c.Subprotocol() != "echo" { c.Close(websocket.StatusPolicyViolation, "client must speak the echo subprotocol") - return xerrors.New("client does not speak echo sub protocol") + return errors.New("client does not speak echo sub protocol") } l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10) @@ -88,7 +88,7 @@ func echoServer(w http.ResponseWriter, r *http.Request) error { return nil } if err != nil { - return xerrors.Errorf("failed to echo with %v: %w", r.RemoteAddr, err) + return fmt.Errorf("failed to echo with %v: %w", r.RemoteAddr, err) } } } @@ -117,7 +117,7 @@ func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error { _, err = io.Copy(w, r) if err != nil { - return xerrors.Errorf("failed to io.Copy: %w", err) + return fmt.Errorf("failed to io.Copy: %w", err) } err = w.Close() diff --git a/frame.go b/frame.go index 4acaecf4..2a036f94 100644 --- a/frame.go +++ b/frame.go @@ -3,12 +3,11 @@ package websocket import ( "bufio" "encoding/binary" + "fmt" "io" "math" "math/bits" - "golang.org/x/xerrors" - "nhooyr.io/websocket/internal/errd" ) @@ -87,7 +86,7 @@ func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) { } if h.payloadLength < 0 { - return header{}, xerrors.Errorf("received negative payload length: %v", h.payloadLength) + return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength) } if h.masked { diff --git a/go.mod b/go.mod index a10c7b1e..801d6be6 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module nhooyr.io/websocket -go 1.12 +go 1.13 require ( github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect @@ -11,5 +11,4 @@ require ( github.com/gorilla/websocket v1.4.1 github.com/klauspost/compress v1.10.0 golang.org/x/time v0.0.0-20191024005414-555d28b269f0 - golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 ) diff --git a/internal/errd/wrap.go b/internal/errd/wrap.go index ed0b7754..6e779131 100644 --- a/internal/errd/wrap.go +++ b/internal/errd/wrap.go @@ -2,41 +2,13 @@ package errd import ( "fmt" - - "golang.org/x/xerrors" ) -type wrapError struct { - msg string - err error - frame xerrors.Frame -} - -func (e *wrapError) Error() string { - return fmt.Sprint(e) -} - -func (e *wrapError) Format(s fmt.State, v rune) { xerrors.FormatError(e, s, v) } - -func (e *wrapError) FormatError(p xerrors.Printer) (next error) { - p.Print(e.msg) - e.frame.Format(p) - return e.err -} - -func (e *wrapError) Unwrap() error { - return e.err -} - -// Wrap wraps err with xerrors.Errorf if err is non nil. +// Wrap wraps err with fmt.Errorf if err is non nil. // Intended for use with defer and a named error return. // Inspired by https://github.com/golang/go/issues/32676. func Wrap(err *error, f string, v ...interface{}) { if *err != nil { - *err = &wrapError{ - msg: fmt.Sprintf(f, v...), - err: *err, - frame: xerrors.Caller(1), - } + *err = fmt.Errorf(f+": %w", append(v, *err)...) } } diff --git a/internal/test/assert/assert.go b/internal/test/assert/assert.go index 602b887e..6eaf7fc3 100644 --- a/internal/test/assert/assert.go +++ b/internal/test/assert/assert.go @@ -2,17 +2,27 @@ package assert import ( "fmt" + "reflect" "strings" "testing" - "nhooyr.io/websocket/internal/test/cmp" + "github.com/golang/protobuf/proto" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" ) +// Diff returns a human readable diff between v1 and v2 +func Diff(v1, v2 interface{}) string { + return cmp.Diff(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool { + return true + }), cmp.Comparer(proto.Equal)) +} + // Equal asserts exp == act. func Equal(t testing.TB, name string, exp, act interface{}) { t.Helper() - if diff := cmp.Diff(exp, act); diff != "" { + if diff := Diff(exp, act); diff != "" { t.Fatalf("unexpected %v: %v", name, diff) } } diff --git a/internal/test/cmp/cmp.go b/internal/test/cmp/cmp.go deleted file mode 100644 index eadcb5d9..00000000 --- a/internal/test/cmp/cmp.go +++ /dev/null @@ -1,16 +0,0 @@ -package cmp - -import ( - "reflect" - - "github.com/golang/protobuf/proto" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" -) - -// Diff returns a human readable diff between v1 and v2 -func Diff(v1, v2 interface{}) string { - return cmp.Diff(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool { - return true - }), cmp.Comparer(proto.Equal)) -} diff --git a/internal/test/wstest/echo.go b/internal/test/wstest/echo.go index 714767fc..8f4e47c8 100644 --- a/internal/test/wstest/echo.go +++ b/internal/test/wstest/echo.go @@ -3,13 +3,12 @@ package wstest import ( "bytes" "context" + "fmt" "io" "time" - "golang.org/x/xerrors" - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/test/cmp" + "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/xrand" "nhooyr.io/websocket/internal/xsync" ) @@ -73,11 +72,11 @@ func Echo(ctx context.Context, c *websocket.Conn, max int) error { } if expType != actType { - return xerrors.Errorf("unexpected message typ (%v): %v", expType, actType) + return fmt.Errorf("unexpected message typ (%v): %v", expType, actType) } if !bytes.Equal(msg, act) { - return xerrors.Errorf("unexpected msg read: %v", cmp.Diff(msg, act)) + return fmt.Errorf("unexpected msg read: %v", assert.Diff(msg, act)) } return nil diff --git a/internal/test/wstest/pipe.go b/internal/test/wstest/pipe.go index 81705a8a..0a2899ee 100644 --- a/internal/test/wstest/pipe.go +++ b/internal/test/wstest/pipe.go @@ -5,12 +5,11 @@ package wstest import ( "bufio" "context" + "fmt" "net" "net/http" "net/http/httptest" - "golang.org/x/xerrors" - "nhooyr.io/websocket" "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/test/xrand" @@ -39,11 +38,11 @@ func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) clientConn, _, err := websocket.Dial(context.Background(), "ws://example.com", dialOpts) if err != nil { - return nil, nil, xerrors.Errorf("failed to dial with fake transport: %w", err) + return nil, nil, fmt.Errorf("failed to dial with fake transport: %w", err) } if serverConn == nil { - return nil, nil, xerrors.Errorf("failed to get server conn from fake transport: %w", acceptErr) + return nil, nil, fmt.Errorf("failed to get server conn from fake transport: %w", acceptErr) } if xrand.Bool() { diff --git a/internal/xsync/go.go b/internal/xsync/go.go index 712739aa..7a61f27f 100644 --- a/internal/xsync/go.go +++ b/internal/xsync/go.go @@ -1,7 +1,7 @@ package xsync import ( - "golang.org/x/xerrors" + "fmt" ) // Go allows running a function in another goroutine @@ -13,7 +13,7 @@ func Go(fn func() error) <-chan error { r := recover() if r != nil { select { - case errs <- xerrors.Errorf("panic in go fn: %v", r): + case errs <- fmt.Errorf("panic in go fn: %v", r): default: } } diff --git a/netconn.go b/netconn.go index a2d8f4f3..64aadf0b 100644 --- a/netconn.go +++ b/netconn.go @@ -2,13 +2,12 @@ package websocket import ( "context" + "fmt" "io" "math" "net" "sync" "time" - - "golang.org/x/xerrors" ) // NetConn converts a *websocket.Conn into a net.Conn. @@ -108,7 +107,7 @@ func (c *netConn) Read(p []byte) (int, error) { return 0, err } if typ != c.msgType { - err := xerrors.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ) + err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ) c.c.Close(StatusUnsupportedData, err.Error()) return 0, err } diff --git a/read.go b/read.go index bbad30d1..a1efecab 100644 --- a/read.go +++ b/read.go @@ -5,13 +5,13 @@ package websocket import ( "bufio" "context" + "errors" + "fmt" "io" "io/ioutil" "strings" "time" - "golang.org/x/xerrors" - "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/xsync" ) @@ -144,13 +144,13 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) { } if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 { - err := xerrors.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) + err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) c.writeError(StatusProtocolError, err) return header{}, err } if !c.client && !h.masked { - return header{}, xerrors.New("received unmasked frame from client") + return header{}, errors.New("received unmasked frame from client") } switch h.opcode { @@ -161,12 +161,12 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) { if h.opcode == opClose && CloseStatus(err) != -1 { return header{}, err } - return header{}, xerrors.Errorf("failed to handle control frame %v: %w", h.opcode, err) + return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err) } case opContinuation, opText, opBinary: return h, nil default: - err := xerrors.Errorf("received unknown opcode %v", h.opcode) + err := fmt.Errorf("received unknown opcode %v", h.opcode) c.writeError(StatusProtocolError, err) return header{}, err } @@ -217,7 +217,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { case <-ctx.Done(): return n, ctx.Err() default: - err = xerrors.Errorf("failed to read frame payload: %w", err) + err = fmt.Errorf("failed to read frame payload: %w", err) c.close(err) return n, err } @@ -234,13 +234,13 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { func (c *Conn) handleControl(ctx context.Context, h header) (err error) { if h.payloadLength < 0 || h.payloadLength > maxControlPayload { - err := xerrors.Errorf("received control frame payload with invalid length: %d", h.payloadLength) + err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength) c.writeError(StatusProtocolError, err) return err } if !h.fin { - err := xerrors.New("received fragmented control frame") + err := errors.New("received fragmented control frame") c.writeError(StatusProtocolError, err) return err } @@ -277,12 +277,12 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { ce, err := parseClosePayload(b) if err != nil { - err = xerrors.Errorf("received invalid close payload: %w", err) + err = fmt.Errorf("received invalid close payload: %w", err) c.writeError(StatusProtocolError, err) return err } - err = xerrors.Errorf("received close frame: %w", ce) + err = fmt.Errorf("received close frame: %w", ce) c.setCloseErr(err) c.writeClose(ce.Code, ce.Reason) c.close(err) @@ -299,7 +299,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro defer c.readMu.Unlock() if !c.msgReader.fin { - return 0, nil, xerrors.New("previous message not read to completion") + return 0, nil, errors.New("previous message not read to completion") } h, err := c.readLoop(ctx) @@ -308,7 +308,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro } if h.opcode == opContinuation { - err := xerrors.New("received continuation frame without text or binary frame") + err := errors.New("received continuation frame without text or binary frame") c.writeError(StatusProtocolError, err) return 0, nil, err } @@ -357,10 +357,10 @@ func (mr *msgReader) setFrame(h header) { func (mr *msgReader) Read(p []byte) (n int, err error) { defer func() { - if xerrors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate { + if errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate { err = io.EOF } - if xerrors.Is(err, io.EOF) { + if errors.Is(err, io.EOF) { err = io.EOF mr.putFlateReader() return @@ -397,7 +397,7 @@ func (mr *msgReader) read(p []byte) (int, error) { return 0, err } if h.opcode != opContinuation { - err := xerrors.New("received new data message without finishing the previous message") + err := errors.New("received new data message without finishing the previous message") mr.c.writeError(StatusProtocolError, err) return 0, err } @@ -448,7 +448,7 @@ func (lr *limitReader) reset(r io.Reader) { func (lr *limitReader) Read(p []byte) (int, error) { if lr.n <= 0 { - err := xerrors.Errorf("read limited at %v bytes", lr.limit.Load()) + err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) lr.c.writeError(StatusMessageTooBig, err) return 0, err } diff --git a/write.go b/write.go index b560b44c..81b9141a 100644 --- a/write.go +++ b/write.go @@ -7,12 +7,13 @@ import ( "context" "crypto/rand" "encoding/binary" + "errors" + "fmt" "io" "sync" "time" "github.com/klauspost/compress/flate" - "golang.org/x/xerrors" "nhooyr.io/websocket/internal/errd" ) @@ -27,7 +28,7 @@ import ( func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { w, err := c.writer(ctx, typ) if err != nil { - return nil, xerrors.Errorf("failed to get writer: %w", err) + return nil, fmt.Errorf("failed to get writer: %w", err) } return w, nil } @@ -41,7 +42,7 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { _, err := c.write(ctx, typ, p) if err != nil { - return xerrors.Errorf("failed to write msg: %w", err) + return fmt.Errorf("failed to write msg: %w", err) } return nil } @@ -53,14 +54,14 @@ type msgWriter struct { func (mw *msgWriter) Write(p []byte) (int, error) { if mw.closed { - return 0, xerrors.New("cannot use closed writer") + return 0, errors.New("cannot use closed writer") } return mw.mw.Write(p) } func (mw *msgWriter) Close() error { if mw.closed { - return xerrors.New("cannot use closed writer") + return errors.New("cannot use closed writer") } mw.closed = true return mw.mw.Close() @@ -182,7 +183,7 @@ func (mw *msgWriterState) Write(p []byte) (_ int, err error) { func (mw *msgWriterState) write(p []byte) (int, error) { n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) if err != nil { - return n, xerrors.Errorf("failed to write data frame: %w", err) + return n, fmt.Errorf("failed to write data frame: %w", err) } mw.opcode = opContinuation return n, nil @@ -197,7 +198,7 @@ func (mw *msgWriterState) Close() (err error) { _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) if err != nil { - return xerrors.Errorf("failed to write fin frame: %w", err) + return fmt.Errorf("failed to write fin frame: %w", err) } if mw.flate && !mw.flateContextTakeover() { @@ -218,7 +219,7 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error _, err := c.writeFrame(ctx, true, false, opcode, p) if err != nil { - return xerrors.Errorf("failed to write control frame %v: %w", opcode, err) + return fmt.Errorf("failed to write control frame %v: %w", opcode, err) } return nil } @@ -245,7 +246,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco c.writeHeader.masked = true _, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4]) if err != nil { - return 0, xerrors.Errorf("failed to generate masking key: %w", err) + return 0, fmt.Errorf("failed to generate masking key: %w", err) } c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:]) } @@ -268,7 +269,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco if c.writeHeader.fin { err = c.bw.Flush() if err != nil { - return n, xerrors.Errorf("failed to flush: %w", err) + return n, fmt.Errorf("failed to flush: %w", err) } } diff --git a/ws_js.go b/ws_js.go index ecf3d78c..2b560ce8 100644 --- a/ws_js.go +++ b/ws_js.go @@ -3,6 +3,8 @@ package websocket // import "nhooyr.io/websocket" import ( "bytes" "context" + "errors" + "fmt" "io" "net/http" "reflect" @@ -10,8 +12,6 @@ import ( "sync" "syscall/js" - "golang.org/x/xerrors" - "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/wsjs" "nhooyr.io/websocket/internal/xsync" @@ -45,7 +45,7 @@ func (c *Conn) close(err error, wasClean bool) { runtime.SetFinalizer(c, nil) if !wasClean { - err = xerrors.Errorf("unclean connection close: %w", err) + err = fmt.Errorf("unclean connection close: %w", err) } c.setCloseErr(err) c.closeWasClean = wasClean @@ -87,7 +87,7 @@ func (c *Conn) init() { }) runtime.SetFinalizer(c, func(c *Conn) { - c.setCloseErr(xerrors.New("connection garbage collected")) + c.setCloseErr(errors.New("connection garbage collected")) c.closeWithInternal() }) } @@ -100,15 +100,15 @@ func (c *Conn) closeWithInternal() { // The maximum time spent waiting is bounded by the context. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { if c.isReadClosed.Load() == 1 { - return 0, nil, xerrors.New("WebSocket connection read closed") + return 0, nil, errors.New("WebSocket connection read closed") } typ, p, err := c.read(ctx) if err != nil { - return 0, nil, xerrors.Errorf("failed to read: %w", err) + return 0, nil, fmt.Errorf("failed to read: %w", err) } if int64(len(p)) > c.msgReadLimit.Load() { - err := xerrors.Errorf("read limited at %v bytes", c.msgReadLimit.Load()) + err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load()) c.Close(StatusMessageTooBig, err.Error()) return 0, nil, err } @@ -166,7 +166,7 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { // to match the Go API. It can only error if the message type // is unexpected or the passed bytes contain invalid UTF-8 for // MessageText. - err := xerrors.Errorf("failed to write: %w", err) + err := fmt.Errorf("failed to write: %w", err) c.setCloseErr(err) c.closeWithInternal() return err @@ -184,7 +184,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { case MessageText: return c.ws.SendText(string(p)) default: - return xerrors.Errorf("unexpected message type: %v", typ) + return fmt.Errorf("unexpected message type: %v", typ) } } @@ -195,7 +195,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { func (c *Conn) Close(code StatusCode, reason string) error { err := c.exportedClose(code, reason) if err != nil { - return xerrors.Errorf("failed to close WebSocket: %w", err) + return fmt.Errorf("failed to close WebSocket: %w", err) } return nil } @@ -204,13 +204,13 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { c.closingMu.Lock() defer c.closingMu.Unlock() - ce := xerrors.Errorf("sent close: %w", CloseError{ + ce := fmt.Errorf("sent close: %w", CloseError{ Code: code, Reason: reason, }) if c.isClosed() { - return xerrors.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) + return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) } c.setCloseErr(ce) @@ -245,7 +245,7 @@ type DialOptions struct { func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { c, resp, err := dial(ctx, url, opts) if err != nil { - return nil, nil, xerrors.Errorf("failed to WebSocket dial %q: %w", url, err) + return nil, nil, fmt.Errorf("failed to WebSocket dial %q: %w", url, err) } return c, resp, nil } @@ -318,25 +318,25 @@ type writer struct { func (w writer) Write(p []byte) (int, error) { if w.closed { - return 0, xerrors.New("cannot write to closed writer") + return 0, errors.New("cannot write to closed writer") } n, err := w.b.Write(p) if err != nil { - return n, xerrors.Errorf("failed to write message: %w", err) + return n, fmt.Errorf("failed to write message: %w", err) } return n, nil } func (w writer) Close() error { if w.closed { - return xerrors.New("cannot close closed writer") + return errors.New("cannot close closed writer") } w.closed = true defer bpool.Put(w.b) err := w.c.Write(w.ctx, w.typ, w.b.Bytes()) if err != nil { - return xerrors.Errorf("failed to close writer: %w", err) + return fmt.Errorf("failed to close writer: %w", err) } return nil } @@ -361,7 +361,7 @@ func (c *Conn) SetReadLimit(n int64) { func (c *Conn) setCloseErr(err error) { c.closeErrOnce.Do(func() { - c.closeErr = xerrors.Errorf("WebSocket closed: %w", err) + c.closeErr = fmt.Errorf("WebSocket closed: %w", err) }) } diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index e6f06a2f..99996a69 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -4,8 +4,7 @@ package wsjson // import "nhooyr.io/websocket/wsjson" import ( "context" "encoding/json" - - "golang.org/x/xerrors" + "fmt" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" @@ -28,7 +27,7 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { if typ != websocket.MessageText { c.Close(websocket.StatusUnsupportedData, "expected text message") - return xerrors.Errorf("expected text message for JSON but got: %v", typ) + return fmt.Errorf("expected text message for JSON but got: %v", typ) } b := bpool.Get() @@ -42,7 +41,7 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { err = json.Unmarshal(b.Bytes(), v) if err != nil { c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON") - return xerrors.Errorf("failed to unmarshal JSON: %w", err) + return fmt.Errorf("failed to unmarshal JSON: %w", err) } return nil @@ -66,7 +65,7 @@ func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { // a copy of the byte slice but Encoder does as it directly writes to w. err = json.NewEncoder(w).Encode(v) if err != nil { - return xerrors.Errorf("failed to marshal JSON: %w", err) + return fmt.Errorf("failed to marshal JSON: %w", err) } return w.Close() diff --git a/wspb/wspb.go b/wspb/wspb.go index 06ac3368..e43042d5 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -4,9 +4,9 @@ package wspb // import "nhooyr.io/websocket/wspb" import ( "bytes" "context" + "fmt" "github.com/golang/protobuf/proto" - "golang.org/x/xerrors" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" @@ -29,7 +29,7 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { if typ != websocket.MessageBinary { c.Close(websocket.StatusUnsupportedData, "expected binary message") - return xerrors.Errorf("expected binary message for protobuf but got: %v", typ) + return fmt.Errorf("expected binary message for protobuf but got: %v", typ) } b := bpool.Get() @@ -43,7 +43,7 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { err = proto.Unmarshal(b.Bytes(), v) if err != nil { c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf") - return xerrors.Errorf("failed to unmarshal protobuf: %w", err) + return fmt.Errorf("failed to unmarshal protobuf: %w", err) } return nil @@ -66,7 +66,7 @@ func write(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) err = pb.Marshal(v) if err != nil { - return xerrors.Errorf("failed to marshal protobuf: %w", err) + return fmt.Errorf("failed to marshal protobuf: %w", err) } return c.Write(ctx, websocket.MessageBinary, pb.Bytes())