From 995f67177dcee19ea5c303d54c339cab54f70467 Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Tue, 6 Apr 2021 10:46:14 -0600 Subject: [PATCH] Added websocket support Signed-off-by: Ivan Kozlovic --- nats.go | 66 +++- ws.go | 697 ++++++++++++++++++++++++++++++++++++++ ws_test.go | 977 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 1732 insertions(+), 8 deletions(-) create mode 100644 ws.go create mode 100644 ws_test.go diff --git a/nats.go b/nats.go index 6050717e9..158c29ae7 100644 --- a/nats.go +++ b/nats.go @@ -426,6 +426,10 @@ type Options struct { // is established, and if a ClosedHandler is set, it will be invoked if // it fails to connect (after exhausting the MaxReconnect attempts). RetryOnFailedConnect bool + + // For websocket connections, indicates to the server that the connection + // supports compression. If the server does too, then data will be compressed. + Compression bool } const ( @@ -484,6 +488,7 @@ type Conn struct { pout int ar bool // abort reconnect rqch chan struct{} + ws bool // true if a websocket connection // New style response handler respSub string // The wildcard subject @@ -691,6 +696,8 @@ type MsgHandler func(msg *Msg) // The url can contain username/password semantics. e.g. nats://derek:pass@localhost:4222 // Comma separated arrays are also supported, e.g. urlA, urlB. // Options start with the defaults but can be overridden. +// To connect to a NATS Server's websocket port, use the `ws` or `wss` scheme, such as +// `ws://localhost:8080`. Note that websocket schemes cannot be mixed with others (nats/tls). func Connect(url string, options ...Option) (*Conn, error) { opts := GetDefaultOptions() opts.Servers = processUrlString(url) @@ -1085,6 +1092,15 @@ func RetryOnFailedConnect(retry bool) Option { } } +// Compression is an Option to indicate if this connection supports +// compression. Currently only supported for Websocket connections. +func Compression(enabled bool) Option { + return func(o *Options) error { + o.Compression = enabled + return nil + } +} + // Handler processing // SetDisconnectHandler will set the disconnect event handler. @@ -1375,6 +1391,12 @@ func (nc *Conn) setupServerPool() error { // Helper function to return scheme func (nc *Conn) connScheme() string { + if nc.ws { + if nc.Opts.Secure { + return wsSchemeTLS + } + return wsScheme + } if nc.Opts.Secure { return tlsScheme } @@ -1411,6 +1433,16 @@ func (nc *Conn) addURLToPool(sURL string, implicit, saveTLSName bool) error { sURL += defaultPortString } + isWS := isWebsocketScheme(u) + // We don't support mix and match of websocket and non websocket URLs. + // If this is the first URL, then we accept and switch the global state + // to websocket. After that, we will know how to reject mixed URLs. + if len(nc.srvPool) == 0 { + nc.ws = isWS + } else if isWS && !nc.ws || !isWS && nc.ws { + return fmt.Errorf("mixing of websocket and non websocket URLs is not allowed") + } + var tlsName string if implicit { curl := nc.current.url @@ -1506,11 +1538,14 @@ func (w *natsWriter) writeDirect(strs ...string) error { func (w *natsWriter) flush() error { // If a pending buffer is set, we don't flush. Code that needs to - // write directly to the socket, by-passing buffers during (re)connect - // use the writeDirect() API. - if w.pending != nil || len(w.bufs) == 0 { + // write directly to the socket, by-passing buffers during (re)connect, + // will use the writeDirect() API. + if w.pending != nil { return nil } + // Do not skip calling w.w.Write() here if len(w.bufs) is 0 because + // the actual writer (if websocket for instance) may have things + // to do such as sending control frames, etc.. _, err := w.w.Write(w.bufs) w.bufs = w.bufs[:0] return err @@ -1638,6 +1673,11 @@ func (nc *Conn) createConn() (err error) { return err } + // If scheme starts with "ws" then branch out to websocket code. + if isWebsocketScheme(u) { + return nc.wsInitHandshake(u) + } + // Reset reader/writer to this new TCP connection nc.bindToNewConn() return nil @@ -1926,6 +1966,12 @@ func (nc *Conn) processExpectedInfo() error { return ErrNkeysNotSupported } + // For websocket connections, we already switched to TLS if need be, + // so we are done here. + if nc.ws { + return nil + } + return nc.checkForSecure() } @@ -4417,6 +4463,12 @@ func (nc *Conn) close(status Status, doCBs bool, err error) { // all blocking calls, such as Flush() and NextMsg() func (nc *Conn) Close() { if nc != nil { + // This will be a no-op if the connection was not websocket. + // We do this here as opposed to inside close() because we want + // to do this only for the final user-driven close of the client. + // Otherwise, we would need to change close() to pass a boolean + // indicating that this is the case. + nc.wsClose() nc.close(CLOSED, !nc.Opts.NoCallbacksAfterClientClose, nil) } } @@ -4456,7 +4508,7 @@ func (nc *Conn) drainConnection() { if nc.isConnecting() || nc.isReconnecting() { nc.mu.Unlock() // Move to closed state. - nc.close(CLOSED, true, nil) + nc.Close() return } @@ -4536,12 +4588,10 @@ func (nc *Conn) drainConnection() { err := nc.FlushTimeout(5 * time.Second) if err != nil { pushErr(err) - nc.close(CLOSED, true, nil) - return } // Move to closed state. - nc.close(CLOSED, true, nil) + nc.Close() } // Drain will put a connection into a drain state. All subscriptions will @@ -4557,7 +4607,7 @@ func (nc *Conn) Drain() error { } if nc.isConnecting() || nc.isReconnecting() { nc.mu.Unlock() - nc.close(CLOSED, true, nil) + nc.Close() return ErrConnectionReconnecting } if nc.isDraining() { diff --git a/ws.go b/ws.go new file mode 100644 index 000000000..14a4d1f3b --- /dev/null +++ b/ws.go @@ -0,0 +1,697 @@ +// Copyright 2021 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nats + +import ( + "bufio" + "bytes" + "compress/flate" + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" + mrand "math/rand" + "net/http" + "net/url" + "strings" + "time" + "unicode/utf8" +) + +type wsOpCode int + +const ( + // From https://tools.ietf.org/html/rfc6455#section-5.2 + wsTextMessage = wsOpCode(1) + wsBinaryMessage = wsOpCode(2) + wsCloseMessage = wsOpCode(8) + wsPingMessage = wsOpCode(9) + wsPongMessage = wsOpCode(10) + + wsFinalBit = 1 << 7 + wsRsv1Bit = 1 << 6 // Used for compression, from https://tools.ietf.org/html/rfc7692#section-6 + wsRsv2Bit = 1 << 5 + wsRsv3Bit = 1 << 4 + + wsMaskBit = 1 << 7 + + wsContinuationFrame = 0 + wsMaxFrameHeaderSize = 14 // Since LeafNode may need to behave as a client + wsMaxControlPayloadSize = 125 + + // From https://tools.ietf.org/html/rfc6455#section-11.7 + wsCloseStatusNormalClosure = 1000 + wsCloseStatusNoStatusReceived = 1005 + wsCloseStatusAbnormalClosure = 1006 + wsCloseStatusInvalidPayloadData = 1007 + + wsScheme = "ws" + wsSchemeTLS = "wss" + + wsPMCExtension = "permessage-deflate" // per-message compression + wsPMCSrvNoCtx = "server_no_context_takeover" + wsPMCCliNoCtx = "client_no_context_takeover" + wsPMCReqHeaderValue = wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx +) + +// From https://tools.ietf.org/html/rfc6455#section-1.3 +var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +// As per https://tools.ietf.org/html/rfc7692#section-7.2.2 +// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader +// does not report unexpected EOF. +var compressFinalBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff} + +type websocketReader struct { + r io.Reader + pending [][]byte + ib []byte + ff bool + fc bool + dc io.ReadCloser + nc *Conn +} + +type websocketWriter struct { + w io.Writer + compress bool + compressor *flate.Writer + ctrlFrames [][]byte // pending frames that should be sent at the next Write() + cm []byte // close message that needs to be sent when everything else has been sent + cmDone bool // a close message has been added or sent (never going back to false) + noMoreSend bool // if true, even if there is a Write() call, we should not send anything +} + +type decompressorBuffer struct { + buf []byte + rem int + off int + final bool +} + +func newDecompressorBuffer(buf []byte) *decompressorBuffer { + return &decompressorBuffer{buf: buf, rem: len(buf)} +} + +func (d *decompressorBuffer) Read(p []byte) (int, error) { + if d.buf == nil { + return 0, io.EOF + } + lim := d.rem + if len(p) < lim { + lim = len(p) + } + n := copy(p, d.buf[d.off:d.off+lim]) + d.off += n + d.rem -= n + d.checkRem() + return n, nil +} + +func (d *decompressorBuffer) checkRem() { + if d.rem != 0 { + return + } + if !d.final { + d.buf = compressFinalBlock + d.off = 0 + d.rem = len(d.buf) + d.final = true + } else { + d.buf = nil + } +} + +func (d *decompressorBuffer) ReadByte() (byte, error) { + if d.buf == nil { + return 0, io.EOF + } + b := d.buf[d.off] + d.off++ + d.rem-- + d.checkRem() + return b, nil +} + +func wsNewReader(r io.Reader) *websocketReader { + return &websocketReader{r: r, ff: true} +} + +func (r *websocketReader) Read(p []byte) (int, error) { + var err error + var buf []byte + + if l := len(r.ib); l > 0 { + buf = r.ib + r.ib = nil + } else { + if len(r.pending) > 0 { + return r.drainPending(p), nil + } + + // Get some data from the underlying reader. + n, err := r.r.Read(p) + if err != nil { + return 0, err + } + buf = p[:n] + } + + // Now parse this and decode frames. We will possibly read more to + // ensure that we get a full frame. + var ( + tmpBuf []byte + pos int + max = len(buf) + rem = 0 + ) + for pos < max { + b0 := buf[pos] + frameType := wsOpCode(b0 & 0xF) + final := b0&wsFinalBit != 0 + compressed := b0&wsRsv1Bit != 0 + pos++ + + tmpBuf, pos, err = wsGet(r.r, buf, pos, 1) + if err != nil { + return 0, err + } + b1 := tmpBuf[0] + + // Store size in case it is < 125 + rem = int(b1 & 0x7F) + + switch frameType { + case wsPingMessage, wsPongMessage, wsCloseMessage: + if rem > wsMaxControlPayloadSize { + return 0, fmt.Errorf( + fmt.Sprintf("control frame length bigger than maximum allowed of %v bytes", + wsMaxControlPayloadSize)) + } + if compressed { + return 0, errors.New("control frame should not be compressed") + } + if !final { + return 0, errors.New("control frame does not have final bit set") + } + case wsTextMessage, wsBinaryMessage: + if !r.ff { + return 0, errors.New("new message started before final frame for previous message was received") + } + r.ff = final + r.fc = compressed + case wsContinuationFrame: + // Compressed bit must be only set in the first frame + if r.ff || compressed { + return 0, errors.New("invalid continuation frame") + } + r.ff = final + default: + return 0, fmt.Errorf("unknown opcode %v", frameType) + } + + switch rem { + case 126: + tmpBuf, pos, err = wsGet(r.r, buf, pos, 2) + if err != nil { + return 0, err + } + rem = int(binary.BigEndian.Uint16(tmpBuf)) + case 127: + tmpBuf, pos, err = wsGet(r.r, buf, pos, 8) + if err != nil { + return 0, err + } + rem = int(binary.BigEndian.Uint64(tmpBuf)) + } + + // Handle control messages in place... + if wsIsControlFrame(frameType) { + pos, err = r.handleControlFrame(frameType, buf, pos, rem) + if err != nil { + return 0, err + } + rem = 0 + continue + } + + var b []byte + b, pos, err = wsGet(r.r, buf, pos, rem) + if err != nil { + return 0, err + } + rem = 0 + if r.fc { + br := newDecompressorBuffer(b) + if r.dc == nil { + r.dc = flate.NewReader(br) + } else { + r.dc.(flate.Resetter).Reset(br, nil) + } + // TODO: When Go 1.15 support is dropped, replace with io.ReadAll() + b, err = ioutil.ReadAll(r.dc) + if err != nil { + return 0, err + } + r.fc = false + } + r.pending = append(r.pending, b) + } + // At this point we should have pending slices. + return r.drainPending(p), nil +} + +func (r *websocketReader) drainPending(p []byte) int { + var n int + var max = len(p) + + for i, buf := range r.pending { + if n+len(buf) <= max { + copy(p[n:], buf) + n += len(buf) + } else { + // Is there room left? + if n < max { + // Write the partial and update this slice. + rem := max - n + copy(p[n:], buf[:rem]) + n += rem + r.pending[i] = buf[rem:] + } + // These are the remaining slices that will need to be used at + // the next Read() call. + r.pending = r.pending[i:] + return n + } + } + r.pending = r.pending[:0] + return n +} + +func wsGet(r io.Reader, buf []byte, pos, needed int) ([]byte, int, error) { + avail := len(buf) - pos + if avail >= needed { + return buf[pos : pos+needed], pos + needed, nil + } + b := make([]byte, needed) + start := copy(b, buf[pos:]) + for start != needed { + n, err := r.Read(b[start:cap(b)]) + start += n + if err != nil { + return b, start, err + } + } + return b, pos + avail, nil +} + +func (r *websocketReader) handleControlFrame(frameType wsOpCode, buf []byte, pos, rem int) (int, error) { + var payload []byte + var err error + + statusPos := pos + if rem > 0 { + payload, pos, err = wsGet(r.r, buf, pos, rem) + if err != nil { + return pos, err + } + } + switch frameType { + case wsCloseMessage: + status := wsCloseStatusNoStatusReceived + body := "" + // If there is a payload, it should contain 2 unsigned bytes + // that represent the status code and then optional payload. + if len(payload) >= 2 { + status = int(binary.BigEndian.Uint16(buf[statusPos : statusPos+2])) + body = string(buf[statusPos+2 : statusPos+len(payload)]) + if body != "" && !utf8.ValidString(body) { + // https://tools.ietf.org/html/rfc6455#section-5.5.1 + // If body is present, it must be a valid utf8 + status = wsCloseStatusInvalidPayloadData + body = "invalid utf8 body in close frame" + } + } + r.nc.wsEnqueueCloseMsg(status, body) + // Return io.EOF so that readLoop will close the connection as ClientClosed + // after processing pending buffers. + return pos, io.EOF + case wsPingMessage: + r.nc.wsEnqueueControlMsg(wsPongMessage, payload) + case wsPongMessage: + // Nothing to do.. + } + return pos, nil +} + +func (w *websocketWriter) Write(p []byte) (int, error) { + if w.noMoreSend { + return 0, nil + } + var total int + var n int + var err error + // If there are control frames, they can be sent now. Actually spec says + // that they should be sent ASAP, so we will send before any application data. + if len(w.ctrlFrames) > 0 { + n, err = w.writeCtrlFrames() + if err != nil { + return n, err + } + total += n + } + // Do the following only if there is something to send. + // We will end with checking for need to send close message. + if len(p) > 0 { + if w.compress { + buf := &bytes.Buffer{} + if w.compressor == nil { + w.compressor, _ = flate.NewWriter(buf, flate.BestSpeed) + } else { + w.compressor.Reset(buf) + } + w.compressor.Write(p) + w.compressor.Close() + b := buf.Bytes() + p = b[:len(b)-4] + } + fh, key := wsCreateFrameHeader(w.compress, wsBinaryMessage, len(p)) + wsMaskBuf(key, p) + n, err = w.w.Write(fh) + total += n + if err == nil { + n, err = w.w.Write(p) + total += n + } + } + if err == nil && w.cm != nil { + n, err = w.writeCloseMsg() + total += n + } + return total, err +} + +func (w *websocketWriter) writeCtrlFrames() (int, error) { + var ( + n int + total int + i int + err error + ) + for ; i < len(w.ctrlFrames); i++ { + buf := w.ctrlFrames[i] + n, err = w.w.Write(buf) + total += n + if err != nil { + break + } + } + if i != len(w.ctrlFrames) { + w.ctrlFrames = w.ctrlFrames[i+1:] + } else { + w.ctrlFrames = w.ctrlFrames[:0] + } + return total, err +} + +func (w *websocketWriter) writeCloseMsg() (int, error) { + n, err := w.w.Write(w.cm) + w.cm, w.noMoreSend = nil, true + return n, err +} + +func wsMaskBuf(key, buf []byte) { + for i := 0; i < len(buf); i++ { + buf[i] ^= key[i&3] + } +} + +// Create the frame header. +// Encodes the frame type and optional compression flag, and the size of the payload. +func wsCreateFrameHeader(compressed bool, frameType wsOpCode, l int) ([]byte, []byte) { + fh := make([]byte, wsMaxFrameHeaderSize) + n, key := wsFillFrameHeader(fh, compressed, frameType, l) + return fh[:n], key +} + +func wsFillFrameHeader(fh []byte, compressed bool, frameType wsOpCode, l int) (int, []byte) { + var n int + b := byte(frameType) + b |= wsFinalBit + if compressed { + b |= wsRsv1Bit + } + b1 := byte(wsMaskBit) + switch { + case l <= 125: + n = 2 + fh[0] = b + fh[1] = b1 | byte(l) + case l < 65536: + n = 4 + fh[0] = b + fh[1] = b1 | 126 + binary.BigEndian.PutUint16(fh[2:], uint16(l)) + default: + n = 10 + fh[0] = b + fh[1] = b1 | 127 + binary.BigEndian.PutUint64(fh[2:], uint64(l)) + } + var key []byte + var keyBuf [4]byte + if _, err := io.ReadFull(rand.Reader, keyBuf[:4]); err != nil { + kv := mrand.Int31() + binary.LittleEndian.PutUint32(keyBuf[:4], uint32(kv)) + } + copy(fh[n:], keyBuf[:4]) + key = fh[n : n+4] + n += 4 + return n, key +} + +func (nc *Conn) wsInitHandshake(u *url.URL) error { + compress := nc.Opts.Compression + tlsRequired := u.Scheme == wsSchemeTLS || nc.Opts.Secure || nc.Opts.TLSConfig != nil + // Do TLS here as needed. + if tlsRequired { + if err := nc.makeTLSConn(); err != nil { + return err + } + } else { + nc.bindToNewConn() + } + + var err error + + // For http request, we need the passed URL to contain either http or https scheme. + scheme := "http" + if tlsRequired { + scheme = "https" + } + ustr := fmt.Sprintf("%s://%s", scheme, u.Host) + u, err = url.Parse(ustr) + if err != nil { + return err + } + req := &http.Request{ + Method: "GET", + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: u.Host, + } + wsKey, err := wsMakeChallengeKey() + if err != nil { + return err + } + + req.Header["Upgrade"] = []string{"websocket"} + req.Header["Connection"] = []string{"Upgrade"} + req.Header["Sec-WebSocket-Key"] = []string{wsKey} + req.Header["Sec-WebSocket-Version"] = []string{"13"} + if compress { + req.Header.Add("Sec-WebSocket-Extensions", wsPMCReqHeaderValue) + } + if err := req.Write(nc.conn); err != nil { + return err + } + + var resp *http.Response + + br := bufio.NewReaderSize(nc.conn, 4096) + nc.conn.SetReadDeadline(time.Now().Add(nc.Opts.Timeout)) + resp, err = http.ReadResponse(br, req) + if err == nil && + (resp.StatusCode != 101 || + !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || + !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != wsAcceptKey(wsKey)) { + + err = fmt.Errorf("invalid websocket connection") + } + // Check compression extension... + if err == nil && compress { + // Check that not only permessage-deflate extension is present, but that + // we also have server and client no context take over. + srvCompress, noCtxTakeover := wsPMCExtensionSupport(resp.Header) + + // If server does not support compression, then simply disable it in our side. + if !srvCompress { + compress = false + } else if !noCtxTakeover { + err = fmt.Errorf("compression negotiation error") + } + } + if resp != nil { + resp.Body.Close() + } + nc.conn.SetReadDeadline(time.Time{}) + if err != nil { + return err + } + + wsr := wsNewReader(nc.br.r) + wsr.nc = nc + // We have to slurp whatever is in the bufio reader and copy to br.r + if n := br.Buffered(); n != 0 { + wsr.ib, _ = br.Peek(n) + } + nc.br.r = wsr + nc.bw.w = &websocketWriter{w: nc.bw.w, compress: compress} + nc.ws = true + return nil +} + +func (nc *Conn) wsClose() { + nc.mu.Lock() + defer nc.mu.Unlock() + if !nc.ws { + return + } + nc.wsEnqueueCloseMsgLocked(wsCloseStatusNormalClosure, _EMPTY_) +} + +func (nc *Conn) wsEnqueueCloseMsg(status int, payload string) { + // In some low-level unit tests it will happen... + if nc == nil { + return + } + nc.mu.Lock() + nc.wsEnqueueCloseMsgLocked(status, payload) + nc.mu.Unlock() +} + +func (nc *Conn) wsEnqueueCloseMsgLocked(status int, payload string) { + wr, ok := nc.bw.w.(*websocketWriter) + if !ok || wr.cmDone { + return + } + statusAndPayloadLen := 2 + len(payload) + frame := make([]byte, 2+4+statusAndPayloadLen) + n, key := wsFillFrameHeader(frame, false, wsCloseMessage, statusAndPayloadLen) + // Set the status + binary.BigEndian.PutUint16(frame[n:], uint16(status)) + // If there is a payload, copy + if len(payload) > 0 { + copy(frame[n+2:], payload) + } + // Mask status + payload + wsMaskBuf(key, frame[n:n+statusAndPayloadLen]) + wr.cm = frame + wr.cmDone = true + nc.bw.flush() +} + +func (nc *Conn) wsEnqueueControlMsg(frameType wsOpCode, payload []byte) { + // In some low-level unit tests it will happen... + if nc == nil { + return + } + fh, key := wsCreateFrameHeader(false, frameType, len(payload)) + nc.mu.Lock() + wr, ok := nc.bw.w.(*websocketWriter) + if !ok { + nc.mu.Unlock() + return + } + wr.ctrlFrames = append(wr.ctrlFrames, fh) + if len(payload) > 0 { + wsMaskBuf(key, payload) + wr.ctrlFrames = append(wr.ctrlFrames, payload) + } + nc.bw.flush() + nc.mu.Unlock() +} + +func wsPMCExtensionSupport(header http.Header) (bool, bool) { + for _, extensionList := range header["Sec-Websocket-Extensions"] { + extensions := strings.Split(extensionList, ",") + for _, extension := range extensions { + extension = strings.Trim(extension, " \t") + params := strings.Split(extension, ";") + for i, p := range params { + p = strings.Trim(p, " \t") + if strings.EqualFold(p, wsPMCExtension) { + var snc bool + var cnc bool + for j := i + 1; j < len(params); j++ { + p = params[j] + p = strings.Trim(p, " \t") + if strings.EqualFold(p, wsPMCSrvNoCtx) { + snc = true + } else if strings.EqualFold(p, wsPMCCliNoCtx) { + cnc = true + } + if snc && cnc { + return true, true + } + } + return true, false + } + } + } + } + return false, false +} + +func wsMakeChallengeKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +} + +func wsAcceptKey(key string) string { + h := sha1.New() + h.Write([]byte(key)) + h.Write(wsGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +// Returns true if the op code corresponds to a control frame. +func wsIsControlFrame(frameType wsOpCode) bool { + return frameType >= wsCloseMessage +} + +func isWebsocketScheme(u *url.URL) bool { + return u.Scheme == wsScheme || u.Scheme == wsSchemeTLS +} diff --git a/ws_test.go b/ws_test.go new file mode 100644 index 000000000..1caded1a6 --- /dev/null +++ b/ws_test.go @@ -0,0 +1,977 @@ +// Copyright 2021 The NATS Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package nats + +import ( + "bytes" + "crypto/tls" + "encoding/binary" + "fmt" + "io" + "math/rand" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/nats-io/nats-server/v2/server" + serverTest "github.com/nats-io/nats-server/v2/test" + "github.com/nats-io/nuid" +) + +func testWSGetDefaultOptions(t *testing.T, tls bool) *server.Options { + t.Helper() + sopts := serverTest.DefaultTestOptions + sopts.Host = "127.0.0.1" + sopts.Port = -1 + sopts.Websocket.Host = "127.0.0.1" + sopts.Websocket.Port = -1 + sopts.Websocket.NoTLS = !tls + if tls { + tc := &server.TLSConfigOpts{ + CertFile: "./test/configs/certs/server.pem", + KeyFile: "./test/configs/certs/key.pem", + CaFile: "./test/configs/certs/ca.pem", + } + tlsConfig, err := server.GenTLSConfig(tc) + if err != nil { + t.Fatalf("Can't build TLCConfig: %v", err) + } + sopts.Websocket.TLSConfig = tlsConfig + } + return &sopts +} + +type fakeReader struct { + mu sync.Mutex + buf bytes.Buffer + ch chan []byte + closed bool +} + +func (f *fakeReader) Read(p []byte) (int, error) { + f.mu.Lock() + closed := f.closed + f.mu.Unlock() + if closed { + return 0, io.EOF + } + for { + if f.buf.Len() > 0 { + n, err := f.buf.Read(p) + return n, err + } + buf, ok := <-f.ch + if !ok { + return 0, io.EOF + } + f.buf.Write(buf) + } +} + +func (f *fakeReader) close() { + f.mu.Lock() + defer f.mu.Unlock() + if f.closed { + return + } + f.closed = true + close(f.ch) +} + +func TestWSReader(t *testing.T) { + mr := &fakeReader{ch: make(chan []byte, 1)} + defer mr.close() + r := wsNewReader(mr) + + p := make([]byte, 100) + checkRead := func(limit int, expected []byte, lenPending int) { + t.Helper() + n, err := r.Read(p[:limit]) + if err != nil { + t.Fatalf("Error reading: %v", err) + } + if !bytes.Equal(p[:n], expected) { + t.Fatalf("Expected %q, got %q", expected, p[:n]) + } + if len(r.pending) != lenPending { + t.Fatalf("Expected len(r.pending) to be %v, got %v", lenPending, len(r.pending)) + } + } + + // Test with a buffer that contains a single pending with all data that + // fits in the read buffer. + mr.buf.Write([]byte{130, 10}) + mr.buf.WriteString("ABCDEFGHIJ") + checkRead(100, []byte("ABCDEFGHIJ"), 0) + + // Write 2 frames in the buffer. Since we will call with a read buffer + // that can fit both, we will create 2 pending and consume them at once. + mr.buf.Write([]byte{130, 5}) + mr.buf.WriteString("ABCDE") + mr.buf.Write([]byte{130, 5}) + mr.buf.WriteString("FGHIJ") + checkRead(100, []byte("ABCDEFGHIJ"), 0) + + // We also write 2 frames, but this time we will call the first read + // with a read buffer that can accommodate only the first frame. + // So internally only a single frame is going to be read in pending. + mr.buf.Write([]byte{130, 5}) + mr.buf.WriteString("ABCDE") + mr.buf.Write([]byte{130, 5}) + mr.buf.WriteString("FGHIJ") + checkRead(6, []byte("ABCDE"), 0) + checkRead(100, []byte("FGHIJ"), 0) + + // To test partials, we need to directly set the pending buffers. + r.pending = append(r.pending, []byte("ABCDE")) + r.pending = append(r.pending, []byte("FGHIJ")) + // Now check that the first read cannot get the full first pending + // buffer and gets only a partial. + checkRead(3, []byte("ABC"), 2) + // Since the read buffer is big enough to get everything else, after + // this call we should have no pending. + checkRead(7, []byte("DEFGHIJ"), 0) + + // Similar to above but with both partials. + r.pending = append(r.pending, []byte("ABCDE")) + r.pending = append(r.pending, []byte("FGHIJ")) + checkRead(3, []byte("ABC"), 2) + // Exact amount of the partial of 1st pending + checkRead(2, []byte("DE"), 1) + checkRead(3, []byte("FGH"), 1) + // More space in read buffer than last partial + checkRead(10, []byte("IJ"), 0) + + // This test the fact that read will return only when a frame is complete. + mr.buf.Write([]byte{130, 5}) + mr.buf.WriteString("AB") + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + time.Sleep(100 * time.Millisecond) + mr.ch <- []byte{'C', 'D', 'E', 130, 2, 'F', 'G'} + wg.Done() + }() + // Read() will get "load" only the first frame, so after this call there + // should be no pending. + checkRead(100, []byte("ABCDE"), 0) + // This will load the second frame. + checkRead(100, []byte("FG"), 0) + wg.Wait() + + // Set the buffer that may be populated during the init handshake. + // Make sure that we process that one first. + r.ib = []byte{130, 4, 'A', 'B'} + mr.buf.WriteString("CD") + mr.buf.Write([]byte{130, 2}) + mr.buf.WriteString("EF") + // This will only read up to ABCD and have no pending after the call. + checkRead(100, []byte("ABCD"), 0) + // We need another Read() call to read/load the second frame. + checkRead(100, []byte("EF"), 0) + + // Close the underlying reader while reading. + mr.buf.Write([]byte{130, 4, 'A', 'B'}) + wg.Add(1) + go func() { + time.Sleep(100 * time.Millisecond) + mr.close() + wg.Done() + }() + if _, err := r.Read(p); err != io.EOF { + t.Fatalf("Expected EOF, got %v", err) + } + wg.Wait() +} + +func TestWSParseControlFrames(t *testing.T) { + mr := &fakeReader{ch: make(chan []byte, 1)} + defer mr.close() + r := wsNewReader(mr) + + p := make([]byte, 100) + + // Write a PING + mr.buf.Write([]byte{137, 0}) + n, err := r.Read(p) + if err != nil || n != 0 { + t.Fatalf("Error on read: n=%v err=%v", n, err) + } + + // Write a PONG + mr.buf.Write([]byte{138, 0}) + n, err = r.Read(p) + if err != nil || n != 0 { + t.Fatalf("Error on read: n=%v err=%v", n, err) + } + + // Write a CLOSE + mr.buf.Write([]byte{136, 6, 3, 232, 't', 'e', 's', 't'}) + n, err = r.Read(p) + if err != io.EOF || n != 0 { + t.Fatalf("Error on read: n=%v err=%v", n, err) + } +} + +func TestWSParseInvalidFrames(t *testing.T) { + + newReader := func() (*fakeReader, *websocketReader) { + mr := &fakeReader{} + r := wsNewReader(mr) + return mr, r + } + + p := make([]byte, 100) + + // Invalid utf-8 of close message + mr, r := newReader() + mr.buf.Write([]byte{136, 6, 3, 232, 't', 'e', 0xF1, 't'}) + n, err := r.Read(p) + if err != io.EOF || n != 0 { + t.Fatalf("Error on read: n=%v err=%v", n, err) + } + + // control frame length too long + mr, r = newReader() + mr.buf.Write([]byte{137, 126, 0, wsMaxControlPayloadSize + 10}) + for i := 0; i < wsMaxControlPayloadSize+10; i++ { + mr.buf.WriteByte('a') + } + n, err = r.Read(p) + if n != 0 || err == nil || !strings.Contains(err.Error(), "maximum") { + t.Fatalf("Unexpected error: n=%v err=%v", n, err) + } + + // Not a final frame + mr, r = newReader() + mr.buf.Write([]byte{byte(wsPingMessage), 0}) + n, err = r.Read(p[:2]) + if n != 0 || err == nil || !strings.Contains(err.Error(), "final") { + t.Fatalf("Unexpected error: n=%v err=%v", n, err) + } + + // Marked as compressed + mr, r = newReader() + mr.buf.Write([]byte{byte(wsPingMessage) | wsRsv1Bit, 0}) + n, err = r.Read(p[:2]) + if n != 0 || err == nil || !strings.Contains(err.Error(), "compressed") { + t.Fatalf("Unexpected error: n=%v err=%v", n, err) + } + + // Continuation frame marked as compressed + mr, r = newReader() + mr.buf.Write([]byte{2, 3}) + mr.buf.WriteString("ABC") + mr.buf.Write([]byte{0 | wsRsv1Bit, 3}) + mr.buf.WriteString("DEF") + n, err = r.Read(p) + if n != 0 || err == nil || !strings.Contains(err.Error(), "invalid continuation frame") { + t.Fatalf("Unexpected error: n=%v err=%v", n, err) + } + + // Continuation frame after a final frame + mr, r = newReader() + mr.buf.Write([]byte{130, 3}) + mr.buf.WriteString("ABC") + mr.buf.Write([]byte{0, 3}) + mr.buf.WriteString("DEF") + n, err = r.Read(p) + if n != 0 || err == nil || !strings.Contains(err.Error(), "invalid continuation frame") { + t.Fatalf("Unexpected error: n=%v err=%v", n, err) + } + + // New message received before previous ended + mr, r = newReader() + mr.buf.Write([]byte{2, 3}) + mr.buf.WriteString("ABC") + mr.buf.Write([]byte{0, 3}) + mr.buf.WriteString("DEF") + mr.buf.Write([]byte{130, 3}) + mr.buf.WriteString("GHI") + n, err = r.Read(p) + if n != 0 || err == nil || !strings.Contains(err.Error(), "started before final frame") { + t.Fatalf("Unexpected error: n=%v err=%v", n, err) + } + + // Unknown frame type + mr, r = newReader() + mr.buf.Write([]byte{99, 3}) + mr.buf.WriteString("ABC") + n, err = r.Read(p) + if n != 0 || err == nil || !strings.Contains(err.Error(), "unknown opcode") { + t.Fatalf("Unexpected error: n=%v err=%v", n, err) + } +} + +func TestWSControlFrameBetweenDataFrames(t *testing.T) { + mr := &fakeReader{ch: make(chan []byte, 1)} + defer mr.close() + r := wsNewReader(mr) + + p := make([]byte, 100) + + // Write a frame that will continue after the PONG + mr.buf.Write([]byte{2, 3}) + mr.buf.WriteString("ABC") + // Write a PONG + mr.buf.Write([]byte{138, 0}) + // Continuation of the frame + mr.buf.Write([]byte{0, 3}) + mr.buf.WriteString("DEF") + // Another PONG + mr.buf.Write([]byte{138, 0}) + // End of frame + mr.buf.Write([]byte{128, 3}) + mr.buf.WriteString("GHI") + + n, err := r.Read(p) + if err != nil { + t.Fatalf("Error on read: %v", err) + } + if string(p[:n]) != "ABCDEFGHI" { + t.Fatalf("Unexpected result: %q", p[:n]) + } +} + +func TestWSDecompressorBuffer(t *testing.T) { + br := newDecompressorBuffer([]byte("ABCDE")) + + p := make([]byte, 100) + checkRead := func(limit int, expected []byte) { + t.Helper() + n, err := br.Read(p[:limit]) + if err != nil { + t.Fatalf("Error on read: %v", err) + } + if got := p[:n]; !bytes.Equal(expected, got) { + t.Fatalf("Expected %v, got %v", expected, got) + } + } + checkEOF := func() { + t.Helper() + n, err := br.Read(p) + if err != io.EOF || n > 0 { + t.Fatalf("Unexpected result: n=%v err=%v", n, err) + } + } + checkReadByte := func(expected byte) { + t.Helper() + b, err := br.ReadByte() + if err != nil { + t.Fatalf("Error on read: %v", err) + } + if b != expected { + t.Fatalf("Expected %c, got %c", expected, b) + } + } + checkEOFWithReadByte := func() { + t.Helper() + n, err := br.ReadByte() + if err != io.EOF || n > 0 { + t.Fatalf("Unexpected result: n=%v err=%v", n, err) + } + } + + // Read with enough room + checkRead(100, []byte("ABCDE")) + checkRead(100, compressFinalBlock) + checkEOF() + checkEOFWithReadByte() + + // Read with a partial from our buffer + br = newDecompressorBuffer([]byte("FGHIJ")) + checkRead(2, []byte("FG")) + // Call with more than the end of our buffer. We will have to + // call again to start with the final block + checkRead(10, []byte("HIJ")) + checkRead(10, compressFinalBlock) + checkEOF() + checkEOFWithReadByte() + + // Read with a partial from our buffer + br = newDecompressorBuffer([]byte("KLMNO")) + checkRead(2, []byte("KL")) + // Call with exact number of bytes left for our buffer. + checkRead(3, []byte("MNO")) + checkRead(10, compressFinalBlock) + checkEOF() + checkEOFWithReadByte() + + // Now check partial of the final block + br = newDecompressorBuffer([]byte("PQRST")) + checkRead(10, []byte("PQRST")) + checkRead(2, compressFinalBlock[:2]) + checkRead(4, compressFinalBlock[2:6]) + checkRead(3, compressFinalBlock[6:9]) + checkEOF() + checkEOFWithReadByte() + + // Finally, check ReadByte. + br = newDecompressorBuffer([]byte("UVWXYZ")) + checkRead(4, []byte("UVWX")) + checkReadByte('Y') + checkReadByte('Z') + checkReadByte(compressFinalBlock[0]) + checkReadByte(compressFinalBlock[1]) + checkRead(5, compressFinalBlock[2:7]) + checkReadByte(compressFinalBlock[7]) + checkReadByte(compressFinalBlock[8]) + checkEOFWithReadByte() + checkEOF() +} + +func TestWSNoMixingScheme(t *testing.T) { + // Check opts.Connect() first + for _, test := range []struct { + url string + servers []string + }{ + {"ws://127.0.0.1:1234", []string{"nats://127.0.0.1:1235"}}, + {"ws://127.0.0.1:1234", []string{"ws://127.0.0.1:1235", "nats://127.0.0.1:1236"}}, + {"ws://127.0.0.1:1234", []string{"wss://127.0.0.1:1235", "nats://127.0.0.1:1236"}}, + {"wss://127.0.0.1:1234", []string{"nats://127.0.0.1:1235"}}, + {"wss://127.0.0.1:1234", []string{"wss://127.0.0.1:1235", "nats://127.0.0.1:1236"}}, + {"wss://127.0.0.1:1234", []string{"ws://127.0.0.1:1235", "nats://127.0.0.1:1236"}}, + } { + t.Run("Options", func(t *testing.T) { + opts := GetDefaultOptions() + opts.Url = test.url + opts.Servers = test.servers + nc, err := opts.Connect() + if err == nil || !strings.Contains(err.Error(), "mixing") { + if nc != nil { + nc.Close() + } + t.Fatalf("Expected error about mixing, got %v", err) + } + }) + } + // Check Connect() now. + for _, test := range []struct { + urls string + servers []string + }{ + {"ws://127.0.0.1:1234,nats://127.0.0.1:1235", nil}, + {"ws://127.0.0.1:1234,tcp://127.0.0.1:1235", nil}, + {"ws://127.0.0.1:1234,tls://127.0.0.1:1235", nil}, + {"nats://127.0.0.1:1234,ws://127.0.0.1:1235", nil}, + {"nats://127.0.0.1:1234,wss://127.0.0.1:1235", nil}, + {"nats://127.0.0.1:1234,tls://127.0.0.1:1235,ws://127.0.0.1:1236", nil}, + {"nats://127.0.0.1:1234,tls://127.0.0.1:1235,wss://127.0.0.1:1236", nil}, + // In Connect(), the URL is ignored when Servers() is provided. + {"", []string{"nats://127.0.0.1:1235", "ws://127.0.0.1:1236"}}, + {"", []string{"nats://127.0.0.1:1235", "wss://127.0.0.1:1236"}}, + {"", []string{"ws://127.0.0.1:1235", "nats://127.0.0.1:1236"}}, + {"", []string{"wss://127.0.0.1:1235", "nats://127.0.0.1:1236"}}, + } { + t.Run("Connect", func(t *testing.T) { + var opt Option + if len(test.servers) > 0 { + opt = func(o *Options) error { + o.Servers = test.servers + return nil + } + } + nc, err := Connect(test.urls, opt) + if err == nil || !strings.Contains(err.Error(), "mixing") { + if nc != nil { + nc.Close() + } + t.Fatalf("Expected error about mixing, got %v", err) + } + }) + } +} + +func TestWSBasic(t *testing.T) { + sopts := testWSGetDefaultOptions(t, false) + s := RunServerWithOptions(sopts) + defer s.Shutdown() + + url := fmt.Sprintf("ws://127.0.0.1:%d", sopts.Websocket.Port) + nc, err := Connect(url) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + sub, err := nc.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + + msgs := make([][]byte, 100) + for i := 0; i < len(msgs); i++ { + msg := make([]byte, rand.Intn(70000)) + for j := 0; j < len(msg); j++ { + msg[j] = 'A' + byte(rand.Intn(26)) + } + msgs[i] = msg + } + for i, msg := range msgs { + if err := nc.Publish("foo", msg); err != nil { + t.Fatalf("Error on publish: %v", err) + } + // Make sure that masking does not overwrite user data + if !bytes.Equal(msgs[i], msg) { + t.Fatalf("User content has been changed: %v, got %v", msgs[i], msg) + } + } + + for i := 0; i < len(msgs); i++ { + msg, err := sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting next message: %v", err) + } + if !bytes.Equal(msgs[i], msg.Data) { + t.Fatalf("Expected message: %v, got %v", msgs[i], msg) + } + } +} + +func TestWSControlFrames(t *testing.T) { + sopts := testWSGetDefaultOptions(t, false) + s := RunServerWithOptions(sopts) + defer s.Shutdown() + + rch := make(chan bool, 10) + ncSub, err := Connect(s.ClientURL(), + ReconnectWait(50*time.Millisecond), + ReconnectHandler(func(_ *Conn) { rch <- true }), + ) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer ncSub.Close() + + sub, err := ncSub.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := ncSub.Flush(); err != nil { + t.Fatalf("Error on flush: %v", err) + } + + dch := make(chan error, 10) + url := fmt.Sprintf("ws://127.0.0.1:%d", sopts.Websocket.Port) + nc, err := Connect(url, + ReconnectWait(50*time.Millisecond), + DisconnectErrHandler(func(_ *Conn, err error) { dch <- err }), + ) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + // Enqueue a PING and make sure that we don't break + nc.wsEnqueueControlMsg(wsPingMessage, []byte("this is a ping payload")) + select { + case e := <-dch: + t.Fatal(e) + case <-time.After(250 * time.Millisecond): + // OK + } + + // Shutdown the server, which should send a close message, which by + // spec the client will try to echo back. + s.Shutdown() + + select { + case <-dch: + // OK + case <-time.After(time.Second): + t.Fatal("Should have been disconnected") + } + + s = RunServerWithOptions(sopts) + defer s.Shutdown() + + // Wait to reconnect + if err := Wait(rch); err != nil { + t.Fatalf("Should have reconnected: %v", err) + } + + // Publish and close connection. + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + if err := nc.Flush(); err != nil { + t.Fatalf("Error on flush: %v", err) + } + nc.Close() + + if _, err := sub.NextMsg(time.Second); err != nil { + t.Fatalf("Did not get message: %v", err) + } +} + +func TestWSConcurrentConns(t *testing.T) { + sopts := testWSGetDefaultOptions(t, false) + s := RunServerWithOptions(sopts) + defer s.Shutdown() + + url := fmt.Sprintf("ws://127.0.0.1:%d", sopts.Websocket.Port) + + total := 50 + errCh := make(chan error, total) + wg := sync.WaitGroup{} + wg.Add(total) + for i := 0; i < total; i++ { + go func() { + defer wg.Done() + + nc, err := Connect(url) + if err != nil { + errCh <- fmt.Errorf("Error on connect: %v", err) + return + } + defer nc.Close() + + sub, err := nc.SubscribeSync(nuid.Next()) + if err != nil { + errCh <- fmt.Errorf("Error on subscribe: %v", err) + return + } + nc.Publish(sub.Subject, []byte("here")) + if _, err := sub.NextMsg(time.Second); err != nil { + errCh <- err + } + }() + } + wg.Wait() + select { + case e := <-errCh: + t.Fatal(e.Error()) + default: + } +} + +func TestWSCompression(t *testing.T) { + msgSize := rand.Intn(40000) + for _, test := range []struct { + name string + srvCompression bool + cliCompression bool + }{ + {"srv_off_cli_off", false, false}, + {"srv_off_cli_on", false, true}, + {"srv_on_cli_off", true, false}, + {"srv_on_cli_on", true, true}, + } { + t.Run(test.name, func(t *testing.T) { + sopts := testWSGetDefaultOptions(t, false) + sopts.Websocket.Compression = test.srvCompression + s := RunServerWithOptions(sopts) + defer s.Shutdown() + + url := fmt.Sprintf("ws://127.0.0.1:%d", sopts.Websocket.Port) + var opts []Option + if test.cliCompression { + opts = append(opts, Compression(true)) + } + nc, err := Connect(url, opts...) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + sub, err := nc.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + + msgs := make([][]byte, 100) + for i := 0; i < len(msgs); i++ { + msg := make([]byte, msgSize) + for j := 0; j < len(msg); j++ { + msg[j] = 'A' + } + msgs[i] = msg + } + for i, msg := range msgs { + if err := nc.Publish("foo", msg); err != nil { + t.Fatalf("Error on publish: %v", err) + } + // Make sure that compression/masking does not touch user data + if !bytes.Equal(msgs[i], msg) { + t.Fatalf("User content has been changed: %v, got %v", msgs[i], msg) + } + } + + for i := 0; i < len(msgs); i++ { + msg, err := sub.NextMsg(time.Second) + if err != nil { + t.Fatalf("Error getting next message (%d): %v", i+1, err) + } + if !bytes.Equal(msgs[i], msg.Data) { + t.Fatalf("Expected message (%d): %v, got %v", i+1, msgs[i], msg) + } + } + }) + } +} + +func TestWSWithTLS(t *testing.T) { + for _, test := range []struct { + name string + compression bool + }{ + {"without compression", false}, + {"with compression", true}, + } { + t.Run(test.name, func(t *testing.T) { + sopts := testWSGetDefaultOptions(t, true) + sopts.Websocket.Compression = test.compression + s := RunServerWithOptions(sopts) + defer s.Shutdown() + + var copts []Option + if test.compression { + copts = append(copts, Compression(true)) + } + + // Check that we fail to connect without proper TLS configuration. + nc, err := Connect(fmt.Sprintf("ws://localhost:%d", sopts.Websocket.Port), copts...) + if err == nil { + if nc != nil { + nc.Close() + } + t.Fatal("Expected error, got none") + } + + // Same but with wss protocol, which should translate to TLS, however, + // since we used self signed certificates, this should fail without + // asking to skip server cert verification. + nc, err = Connect(fmt.Sprintf("wss://localhost:%d", sopts.Websocket.Port), copts...) + if err == nil || !strings.Contains(err.Error(), "authority") { + if nc != nil { + nc.Close() + } + t.Fatalf("Expected error about unknown authority: %v", err) + } + + // Skip server verification and we should be good. + copts = append(copts, Secure(&tls.Config{InsecureSkipVerify: true})) + nc, err = Connect(fmt.Sprintf("wss://localhost:%d", sopts.Websocket.Port), copts...) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + sub, err := nc.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := nc.Publish("foo", []byte("hello")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + if msg, err := sub.NextMsg(time.Second); err != nil { + t.Fatalf("Did not get message: %v", err) + } else if got := string(msg.Data); got != "hello" { + t.Fatalf("Expected %q, got %q", "hello", got) + } + }) + } +} + +func TestWSGossipAndReconnect(t *testing.T) { + o1 := testWSGetDefaultOptions(t, false) + o1.ServerName = "A" + o1.Cluster.Host = "127.0.0.1" + o1.Cluster.Name = "abc" + o1.Cluster.Port = -1 + s1 := RunServerWithOptions(o1) + defer s1.Shutdown() + + o2 := testWSGetDefaultOptions(t, false) + o2.ServerName = "B" + o2.Cluster.Host = "127.0.0.1" + o2.Cluster.Name = "abc" + o2.Cluster.Port = -1 + o2.Routes = server.RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", o1.Cluster.Port)) + s2 := RunServerWithOptions(o2) + defer s2.Shutdown() + + rch := make(chan bool, 10) + url := fmt.Sprintf("ws://127.0.0.1:%d", o1.Websocket.Port) + nc, err := Connect(url, + ReconnectWait(50*time.Millisecond), + ReconnectHandler(func(_ *Conn) { rch <- true }), + ) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer nc.Close() + + timeout := time.Now().Add(time.Second) + for time.Now().Before(timeout) { + if len(nc.Servers()) > 1 { + break + } + time.Sleep(15 * time.Millisecond) + } + if len(nc.Servers()) == 1 { + t.Fatal("Did not discover server 2") + } + s1.Shutdown() + + // Wait for reconnect + if err := Wait(rch); err != nil { + t.Fatalf("Did not reconnect: %v", err) + } + + // Now check that connection is still WS + nc.mu.Lock() + isWS := nc.ws + _, ok := nc.bw.w.(*websocketWriter) + nc.mu.Unlock() + + if !isWS { + t.Fatal("Connection is not marked as websocket") + } + if !ok { + t.Fatal("Connection writer is not websocket") + } +} + +func TestWSStress(t *testing.T) { + // Enable this test only when wanting to stress test the system, say after + // some changes in the library or if a bug is found. Also, don't run it + // with the `-race` flag! + t.SkipNow() + // Total producers (there will be 2 per subject) + prods := 4 + // Total messages sent + total := int64(1000000) + // Total messages received, there is 2 consumer per subject + totalRecv := 2 * total + // We will create a "golden" slice from which sent messages + // will be a subset of. Receivers will check that the content + // match the expected content. + maxPayloadSize := 100000 + mainPayload := make([]byte, maxPayloadSize) + for i := 0; i < len(mainPayload); i++ { + mainPayload[i] = 'A' + byte(rand.Intn(26)) + } + for _, test := range []struct { + name string + compress bool + }{ + {"no_compression", false}, + {"with_compression", true}, + } { + t.Run(test.name, func(t *testing.T) { + sopts := testWSGetDefaultOptions(t, false) + sopts.Websocket.Compression = test.compress + s := RunServerWithOptions(sopts) + defer s.Shutdown() + + createConn := func() *Conn { + t.Helper() + nc, err := Connect(fmt.Sprintf("ws://127.0.0.1:%d", sopts.Websocket.Port), + Compression(test.compress)) + if err != nil { + t.Fatalf("Error connecting: %v", err) + } + return nc + } + + var count int64 + consDoneCh := make(chan struct{}, 1) + errCh := make(chan error, 1) + prodDoneCh := make(chan struct{}, prods) + + pushErr := func(e error) { + select { + case errCh <- e: + default: + } + } + + cb := func(m *Msg) { + if len(m.Data) < 4 { + pushErr(fmt.Errorf("Message payload too small: %+v", m.Data)) + return + } + ps := int(binary.BigEndian.Uint32(m.Data[:4])) + if ps > maxPayloadSize { + pushErr(fmt.Errorf("Invalid message size: %v", ps)) + return + } + if !bytes.Equal(m.Data[4:4+ps], mainPayload[:ps]) { + pushErr(fmt.Errorf("invalid content")) + return + } + if atomic.AddInt64(&count, 1) == totalRecv { + consDoneCh <- struct{}{} + } + } + + subjects := []string{"foo", "bar"} + for _, subj := range subjects { + for i := 0; i < 2; i++ { + nc := createConn() + defer nc.Close() + if _, err := nc.Subscribe(subj, cb); err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := nc.Flush(); err != nil { + t.Fatalf("Error on flush: %v", err) + } + } + } + + msgsPerProd := int(total / int64(prods)) + prodPerSubj := prods / len(subjects) + for _, subj := range subjects { + for i := 0; i < prodPerSubj; i++ { + go func(subj string) { + defer func() { prodDoneCh <- struct{}{} }() + + nc := createConn() + defer nc.Close() + + for i := 0; i < msgsPerProd; i++ { + // Have 80% of messages being rather small (<=1024) + maxSize := 1024 + if rand.Intn(100) > 80 { + maxSize = maxPayloadSize + } + ps := rand.Intn(maxSize) + msg := make([]byte, 4+ps) + binary.BigEndian.PutUint32(msg, uint32(ps)) + copy(msg[4:], mainPayload[:ps]) + if err := nc.Publish(subj, msg); err != nil { + pushErr(err) + return + } + } + }(subj) + } + } + + for i := 0; i < prods; i++ { + select { + case <-prodDoneCh: + case e := <-errCh: + t.Fatal(e) + } + } + // Now wait for all consumers to be done. + <-consDoneCh + }) + } +}