diff --git a/server/client.go b/server/client.go index 66115a1be6..2bd9dae9ec 100644 --- a/server/client.go +++ b/server/client.go @@ -16,12 +16,14 @@ package server import ( "bytes" "crypto/tls" + "crypto/x509" "encoding/json" "fmt" "io" "math/rand" "net" "net/http" + "net/url" "regexp" "runtime" "strconv" @@ -597,7 +599,11 @@ func (c *client) initClient() { case GATEWAY: c.ncs.Store(fmt.Sprintf("%s - gid:%d", conn, c.cid)) case LEAF: - c.ncs.Store(fmt.Sprintf("%s - lid:%d", conn, c.cid)) + var ws string + if c.isWebsocket() { + ws = "_ws" + } + c.ncs.Store(fmt.Sprintf("%s - lid%s:%d", conn, ws, c.cid)) case SYSTEM: c.ncs.Store("SYSTEM") case JETSTREAM: @@ -1018,6 +1024,10 @@ func (c *client) readLoop(pre []byte) { // Last per-account-cache check for closed subscriptions lpacc := time.Now() acc := c.acc + var masking bool + if ws { + masking = c.ws.maskread + } c.mu.Unlock() defer func() { @@ -1041,21 +1051,26 @@ func (c *client) readLoop(pre []byte) { var wsr *wsReadInfo if ws { - wsr = &wsReadInfo{} + wsr = &wsReadInfo{mask: masking} wsr.init() } - // If we have a pre buffer parse that first. - if len(pre) > 0 { - c.parse(pre) - } - for { - n, err := nc.Read(b) - // If we have any data we will try to parse and exit at the end. - if n == 0 && err != nil { - c.closeConnection(closedStateForErr(err)) - return + var n int + var err error + + // If we have a pre buffer parse that first. + if len(pre) > 0 { + b = pre + n = len(pre) + pre = nil + } else { + n, err = nc.Read(b) + // If we have any data we will try to parse and exit at the end. + if n == 0 && err != nil { + c.closeConnection(closedStateForErr(err)) + return + } } if ws { bufs, err = c.wsRead(wsr, nc, b[:n]) @@ -1154,8 +1169,15 @@ func (c *client) readLoop(pre []byte) { } // re-snapshot the account since it can change during reload, etc. acc = c.acc + // Refresh nc because in some cases, we have upgraded c.nc to TLS. + nc = c.nc c.mu.Unlock() + // Connection was closed + if nc == nil { + return + } + if dur := time.Since(start); dur >= readLoopReportThreshold { c.Warnf("Readloop processing time: %v", dur) } @@ -1415,12 +1437,12 @@ func (c *client) markConnAsClosed(reason ClosedState) { if !skipFlush && c.isWebsocket() && !c.ws.closeSent { c.wsEnqueueCloseMessage(reason) } - // Be consistent with the creation: for routes and gateways, + // Be consistent with the creation: for routes, gateways and leaf, // we use Noticef on create, so use that too for delete. if c.srv != nil { - if c.kind == ROUTER || c.kind == GATEWAY { + if c.kind == ROUTER || c.kind == GATEWAY || c.kind == LEAF { c.Noticef("%s connection closed: %s", c.typeString(), reason) - } else { // Client, System, Jetstream, Account and Leafnode connections. + } else { // Client, System, Jetstream, and Account connections. c.Debugf("%s connection closed: %s", c.typeString(), reason) } } @@ -1501,7 +1523,7 @@ func (c *client) processInfo(arg []byte) error { case GATEWAY: c.processGatewayInfo(&info) case LEAF: - return c.processLeafnodeInfo(&info) + c.processLeafnodeInfo(&info) } return nil } @@ -4672,6 +4694,102 @@ func (c *client) getClientInfo(detailed bool) *ClientInfo { return &ci } +func (c *client) doTLSServerHandshake(typ string, tlsConfig *tls.Config, timeout float64) error { + _, err := c.doTLSHandshake(typ, false, nil, tlsConfig, _EMPTY_, timeout) + return err +} + +func (c *client) doTLSClientHandshake(typ string, url *url.URL, tlsConfig *tls.Config, tlsName string, timeout float64) (bool, error) { + return c.doTLSHandshake(typ, true, url, tlsConfig, tlsName, timeout) +} + +// Performs eithe server or client side (if solicit is true) TLS Handshake. +// On error, the TLS handshake error has been logged and the connection +// has been closed. +// +// Lock is held on entry. +func (c *client) doTLSHandshake(typ string, solicit bool, url *url.URL, tlsConfig *tls.Config, tlsName string, timeout float64) (bool, error) { + var host string + var resetTLSName bool + var err error + + // Capture kind for some debug/error statements. + kind := c.kind + + // If we solicited, we will act like the client, otherwise the server. + if solicit { + c.Debugf("Starting TLS %s client handshake", typ) + if tlsConfig.ServerName == _EMPTY_ { + // If the given url is a hostname, use this hostname for the + // ServerName. If it is an IP, use the cfg's tlsName. If none + // is available, resort to current IP. + host = url.Hostname() + if tlsName != _EMPTY_ && net.ParseIP(host) != nil { + host = tlsName + } + tlsConfig.ServerName = host + } + c.nc = tls.Client(c.nc, tlsConfig) + } else { + if kind == CLIENT { + c.Debugf("Starting TLS client connection handshake") + } else { + c.Debugf("Starting TLS %s server handshake", typ) + } + c.nc = tls.Server(c.nc, tlsConfig) + } + + conn := c.nc.(*tls.Conn) + + // Setup the timeout + ttl := secondsToDuration(timeout) + time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) + conn.SetReadDeadline(time.Now().Add(ttl)) + + c.mu.Unlock() + if err = conn.Handshake(); err != nil { + if solicit { + // Based on type of error, possibly clear the saved tlsName + // See: https://github.com/nats-io/nats-server/issues/1256 + if _, ok := err.(x509.HostnameError); ok { + if host == tlsName { + resetTLSName = true + } + } + } + if kind == CLIENT { + c.Errorf("TLS handshake error: %v", err) + } else { + c.Errorf("TLS %s handshake error: %v", typ, err) + } + c.closeConnection(TLSHandshakeError) + + // Grab the lock before returning since the caller was holding the lock on entry + c.mu.Lock() + // Returning any error is fine. Since the connection is closed ErrConnectionClosed + // is appropriate. + return resetTLSName, ErrConnectionClosed + } + + // Reset the read deadline + conn.SetReadDeadline(time.Time{}) + + // Re-Grab lock + c.mu.Lock() + + // To be consistent with client, set this flag to indicate that handshake is done + c.flags.set(handshakeComplete) + + // The connection still may have been closed on success handshake due + // to a race with tls timeout. If that the case, return error indicating + // that the connection is closed. + if err == nil && c.isClosed() { + err = ErrConnectionClosed + } + + return false, err +} + // getRAwAuthUser returns the raw auth user for the client. // Lock should be held. func (c *client) getRawAuthUser() string { diff --git a/server/gateway.go b/server/gateway.go index c6f337f735..d661aaa357 100644 --- a/server/gateway.go +++ b/server/gateway.go @@ -17,7 +17,6 @@ import ( "bytes" "crypto/sha256" "crypto/tls" - "crypto/x509" "encoding/json" "fmt" "math/rand" @@ -733,11 +732,6 @@ func (s *Server) createGateway(cfg *gatewayCfg, url *url.URL, conn net.Conn) { // Are we creating the gateway based on the configuration solicit := cfg != nil var tlsRequired bool - if solicit { - tlsRequired = cfg.TLSConfig != nil - } else { - tlsRequired = opts.Gateway.TLSConfig != nil - } s.gateway.RLock() infoJSON := s.gateway.infoJSON @@ -749,86 +743,51 @@ func (s *Server) createGateway(cfg *gatewayCfg, url *url.URL, conn net.Conn) { c.gw = &gateway{} if solicit { // This is an outbound gateway connection + cfg.RLock() + tlsRequired = cfg.TLSConfig != nil + cfgName := cfg.Name + cfg.RUnlock() c.gw.outbound = true - c.gw.name = cfg.Name + c.gw.name = cfgName c.gw.cfg = cfg cfg.bumpConnAttempts() // Since we are delaying the connect until after receiving // the remote's INFO protocol, save the URL we need to connect to. c.gw.connectURL = url - c.Noticef("Creating outbound gateway connection to %q", cfg.Name) + c.Noticef("Creating outbound gateway connection to %q", cfgName) } else { c.flags.set(expectConnect) // Inbound gateway connection c.Noticef("Processing inbound gateway connection") + // Check if TLS is required for inbound GW connections. + tlsRequired = opts.Gateway.TLSConfig != nil } // Check for TLS if tlsRequired { - var host string + var tlsConfig *tls.Config + var tlsName string var timeout float64 - // If we solicited, we will act like the client, otherwise the server. + if solicit { - c.Debugf("Starting TLS gateway client handshake") cfg.RLock() - tlsName := cfg.tlsName - tlsConfig := cfg.TLSConfig.Clone() + tlsName = cfg.tlsName + tlsConfig = cfg.TLSConfig.Clone() timeout = cfg.TLSTimeout cfg.RUnlock() - if tlsConfig.ServerName == "" { - // If the given url is a hostname, use this hostname for the - // ServerName. If it is an IP, use the cfg's tlsName. If none - // is available, resort to current IP. - host = url.Hostname() - if tlsName != "" && net.ParseIP(host) != nil { - host = tlsName - } - tlsConfig.ServerName = host - } - c.nc = tls.Client(c.nc, tlsConfig) } else { - c.Debugf("Starting TLS gateway server handshake") - c.nc = tls.Server(c.nc, opts.Gateway.TLSConfig) + tlsConfig = opts.Gateway.TLSConfig timeout = opts.Gateway.TLSTimeout } - conn := c.nc.(*tls.Conn) - - // Setup the timeout - ttl := secondsToDuration(timeout) - time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) - conn.SetReadDeadline(time.Now().Add(ttl)) - - c.mu.Unlock() - if err := conn.Handshake(); err != nil { - if solicit { - // Based on type of error, possibly clear the saved tlsName - // See: https://github.com/nats-io/nats-server/issues/1256 - if _, ok := err.(x509.HostnameError); ok { - cfg.Lock() - if host == cfg.tlsName { - cfg.tlsName = "" - } - cfg.Unlock() - } + // Perform (either server or client side) TLS handshake. + if resetTLSName, err := c.doTLSHandshake("gateway", solicit, url, tlsConfig, tlsName, timeout); err != nil { + if resetTLSName { + cfg.Lock() + cfg.tlsName = _EMPTY_ + cfg.Unlock() } - c.Errorf("TLS gateway handshake error: %v", err) - c.sendErr("Secure Connection - TLS Required") - c.closeConnection(TLSHandshakeError) - return - } - // Reset the read deadline - conn.SetReadDeadline(time.Time{}) - - // Re-Grab lock - c.mu.Lock() - - // To be consistent with client, set this flag to indicate that handshake is done - c.flags.set(handshakeComplete) - - // Verify that the connection did not go away while we released the lock. - if c.isClosed() { c.mu.Unlock() return } diff --git a/server/leafnode.go b/server/leafnode.go index 87a54b3afe..d1e1196027 100644 --- a/server/leafnode.go +++ b/server/leafnode.go @@ -17,13 +17,12 @@ import ( "bufio" "bytes" "crypto/tls" - "crypto/x509" "encoding/base64" "encoding/json" "fmt" - "io" "io/ioutil" "net" + "net/http" "net/url" "reflect" "regexp" @@ -51,6 +50,10 @@ const leafNodeReconnectAfterPermViolation = 30 * time.Second // Prefix for loop detection subject const leafNodeLoopDetectionSubjectPrefix = "$LDS." +// Path added to URL to indicate to WS server that the connection is a +// LEAF connection as opposed to a CLIENT. +const leafNodeWSPath = "/leafnode" + type leaf struct { // We have any auth stuff here for solicited connections. remote *leafNodeCfg @@ -180,6 +183,23 @@ func validateLeafNode(o *Options) error { } } + // If a remote has a websocket scheme, all need to have it. + for _, rcfg := range o.LeafNode.Remotes { + if len(rcfg.URLs) >= 2 { + firstIsWS, ok := isWSURL(rcfg.URLs[0]), true + for i := 1; i < len(rcfg.URLs); i++ { + u := rcfg.URLs[i] + if isWS := isWSURL(u); isWS && !firstIsWS || !isWS && firstIsWS { + ok = false + break + } + } + if !ok { + return fmt.Errorf("remote leaf node configuration cannot have a mix of websocket and non-websocket urls: %q", rcfg.URLs) + } + } + } + if o.LeafNode.Port == 0 { return nil } @@ -394,13 +414,7 @@ func (s *Server) connectToRemoteLeafNode(remote *leafNodeCfg, firstConnect bool) // We have a connection here to a remote server. // Go ahead and create our leaf node and return. - s.createLeafNode(conn, remote) - - // We will put this in the normal log if first connect, does not force -DV mode to know - // that the connect worked. - if firstConnect { - s.Noticef("Connected leafnode to %q", rURL.Host) - } + s.createLeafNode(conn, rURL, remote, nil) return } } @@ -508,7 +522,7 @@ func (s *Server) startLeafNodeAcceptLoop() { if warn { s.Warnf(leafnodeTLSInsecureWarning) } - go s.acceptConnections(l, "Leafnode", func(conn net.Conn) { s.createLeafNode(conn, nil) }, nil) + go s.acceptConnections(l, "Leafnode", func(conn net.Conn) { s.createLeafNode(conn, nil, nil, nil) }, nil) s.mu.Unlock() } @@ -621,6 +635,7 @@ func (s *Server) removeLeafNodeURL(urlStr string) bool { // Server lock is held on entry func (s *Server) generateLeafNodeInfoJSON() { s.leafNodeInfo.LeafNodeURLs = s.leafURLsMap.getAsStringSlice() + s.leafNodeInfo.WSConnectURLs = s.websocket.connectURLsMap.getAsStringSlice() b, _ := json.Marshal(s.leafNodeInfo) pcs := [][]byte{[]byte("INFO"), b, []byte(CR_LF)} s.leafNodeInfoJSON = bytes.Join(pcs, []byte(" ")) @@ -637,7 +652,7 @@ func (s *Server) sendAsyncLeafNodeInfo() { } // Called when an inbound leafnode connection is accepted or we create one for a solicited leafnode. -func (s *Server) createLeafNode(conn net.Conn, remote *leafNodeCfg) *client { +func (s *Server) createLeafNode(conn net.Conn, rURL *url.URL, remote *leafNodeCfg, ws *websocket) *client { // Snapshot server options. opts := s.getOpts() @@ -653,15 +668,27 @@ func (s *Server) createLeafNode(conn net.Conn, remote *leafNodeCfg) *client { // Do not update the smap here, we need to do it in initLeafNodeSmapAndSendSubs c.leaf = &leaf{} + // For accepted LN connections, ws will be != nil if it was accepted + // through the Websocket port. + c.ws = ws + // For remote, check if the scheme starts with "ws", if so, we will initiate + // a remote Leaf Node connection as a websocket connection. + if remote != nil && rURL != nil && isWSURL(rURL) { + remote.RLock() + c.ws = &websocket{compress: remote.Websocket.Compression, maskwrite: !remote.Websocket.NoMasking} + remote.RUnlock() + } + // Determines if we are soliciting the connection or not. var solicited bool - var sendSysConnectEvent bool var acc *Account c.mu.Lock() c.initClient() + c.Noticef("Leafnode connection created") if remote != nil { solicited = true + remote.Lock() // Users can bind to any local account, if its empty // we will assume the $G account. if remote.LocalAccount == "" { @@ -669,19 +696,19 @@ func (s *Server) createLeafNode(conn net.Conn, remote *leafNodeCfg) *client { } c.leaf.remote = remote c.setPermissions(remote.perms) - if c.leaf.remote.Hub { - sendSysConnectEvent = true - } else { + if !c.leaf.remote.Hub { c.leaf.isSpoke = true } + lacc := remote.LocalAccount + remote.Unlock() c.mu.Unlock() // TODO: Decide what should be the optimal behavior here. // For now, if lookup fails, we will constantly try // to recreate this LN connection. var err error - acc, err = s.LookupAccount(remote.LocalAccount) + acc, err = s.LookupAccount(lacc) if err != nil { - c.Errorf("No local account %q for leafnode: %v", remote.LocalAccount, err) + c.Errorf("No local account %q for leafnode: %v", lacc, err) c.closeConnection(MissingAccount) return nil } @@ -689,160 +716,51 @@ func (s *Server) createLeafNode(conn net.Conn, remote *leafNodeCfg) *client { c.acc = acc } else { c.flags.set(expectConnect) + if ws != nil { + c.Debugf("Leafnode compression=%v", c.ws.compress) + } } c.mu.Unlock() var nonce [nonceLen]byte + var info *Info - // Grab server variables - s.mu.Lock() - info := s.copyLeafNodeInfo() if !solicited { + // Grab server variables + s.mu.Lock() + info = s.copyLeafNodeInfo() s.generateNonce(nonce[:]) + s.mu.Unlock() } - clusterName := s.info.Cluster - headers := s.supportsHeaders() - s.mu.Unlock() // Grab lock c.mu.Lock() - // If connection has been closed, this function will unlock and call - // closeConnection() to ensure proper clean-up. - isClosedUnlock := func() bool { - if c.isClosed() { - c.mu.Unlock() - c.closeConnection(WriteError) - return true - } - return false - } - - // I don't think that the connection can be closed this early (since it isn't - // registered anywhere and doesn't have read/write loops running), but let's - // check in case code is changed in the future and there is such possibility. - if isClosedUnlock() { - return nil - } - + var preBuf []byte if solicited { - // We need to wait here for the info, but not for too long. - c.nc.SetReadDeadline(time.Now().Add(DEFAULT_LEAFNODE_INFO_WAIT)) - br := bufio.NewReaderSize(c.nc, MAX_CONTROL_LINE_SIZE) - info, err := br.ReadString('\n') - if err != nil { - c.mu.Unlock() - if err == io.EOF { - c.closeConnection(ClientClosed) - } else { - c.closeConnection(ReadError) - } - return nil - } - c.nc.SetReadDeadline(time.Time{}) - - c.mu.Unlock() - // Handle only connection to wrong port here, others will be handled below. - if err := c.parse([]byte(info)); err == ErrConnectedToWrongPort { - c.Errorf(err.Error()) - c.closeConnection(WrongPort) - return nil - } - c.mu.Lock() - - if !c.flags.isSet(infoReceived) { - c.mu.Unlock() - c.Errorf("Did not get the remote leafnode's INFO, timed-out") - c.closeConnection(ReadError) - return nil - } - - // Not sure that can happen, but in case the connection was marked - // as closed during the call to parse... - if isClosedUnlock() { - return nil - } - - // Do TLS here as needed. - remote.RLock() - remoteTLSConfig := remote.TLSConfig - tlsRequired := remote.TLS || remoteTLSConfig != nil - remote.RUnlock() - if tlsRequired { - c.Debugf("Starting TLS leafnode client handshake") - // Specify the ServerName we are expecting. - var tlsConfig *tls.Config - if remoteTLSConfig != nil { - tlsConfig = remoteTLSConfig.Clone() - } else { - tlsConfig = &tls.Config{MinVersion: tls.VersionTLS12} - } - - var host string - // If ServerName was given to us from the option, use that, always. - if tlsConfig.ServerName == "" { - url := remote.getCurrentURL() - host = url.Hostname() - // We need to check if this host is an IP. If so, we probably - // had this advertised to us and should use the configured host - // name for the TLS server name. - if remote.tlsName != "" && net.ParseIP(host) != nil { - host = remote.tlsName + // For websocket connection, we need to send an HTTP request, + // and get the response before starting the readLoop to get + // the INFO, etc.. + if c.isWebsocket() { + var err error + var closeReason ClosedState + + preBuf, closeReason, err = c.leafNodeSolicitWSConnection(opts, rURL, remote) + if err != nil { + c.Errorf("Error soliciting websocket connection: %v", err) + c.mu.Unlock() + if closeReason != 0 { + c.closeConnection(closeReason) } - tlsConfig.ServerName = host - } - - c.nc = tls.Client(c.nc, tlsConfig) - - conn := c.nc.(*tls.Conn) - - // Setup the timeout - var wait time.Duration - if remote.TLSTimeout == 0 { - wait = TLS_TIMEOUT - } else { - wait = secondsToDuration(remote.TLSTimeout) - } - time.AfterFunc(wait, func() { tlsTimeout(c, conn) }) - conn.SetReadDeadline(time.Now().Add(wait)) - - // Force handshake - c.mu.Unlock() - if err = conn.Handshake(); err != nil { - // If we overrode and used the saved tlsName but that failed - // we will clear that here. This is for the case that another server - // does not have the same tlsName, maybe only IPs. - // https://github.com/nats-io/nats-server/issues/1256 - if _, ok := err.(x509.HostnameError); ok { - remote.Lock() - if host == remote.tlsName { - remote.tlsName = "" - } - remote.Unlock() - } - c.Errorf("TLS handshake error: %v", err) - c.closeConnection(TLSHandshakeError) - return nil - } - // Reset the read deadline - conn.SetReadDeadline(time.Time{}) - - // Re-Grab lock - c.mu.Lock() - - // Timeout may have closed the connection while the lock was released. - if isClosedUnlock() { return nil } + } else { + // We need to wait for the info, but not for too long. + c.nc.SetReadDeadline(time.Now().Add(DEFAULT_LEAFNODE_INFO_WAIT)) } - if err := c.sendLeafConnect(clusterName, tlsRequired, headers); err != nil { - c.mu.Unlock() - c.closeConnection(ProtocolViolation) - return nil - } - c.Debugf("Remote leafnode connect msg sent") - + // We will process the INFO from the readloop and finish by + // sending the CONNECT and finish registration later. } else { // Send our info to the other side. // Remember the nonce we sent here for signatures, etc. @@ -858,41 +776,18 @@ func (s *Server) createLeafNode(conn net.Conn, remote *leafNodeCfg) *client { // this before it can initiate the TLS handshake.. c.sendProtoNow(bytes.Join(pcs, []byte(" "))) - // The above call could have marked the connection as closed (due to - // TCP error), so if that is the case, bail out here. - if isClosedUnlock() { + // The above call could have marked the connection as closed (due to TCP error). + if c.isClosed() { + c.mu.Unlock() + c.closeConnection(WriteError) return nil } // Check to see if we need to spin up TLS. - if info.TLSRequired { - c.Debugf("Starting TLS leafnode server handshake") - c.nc = tls.Server(c.nc, opts.LeafNode.TLSConfig) - conn := c.nc.(*tls.Conn) - - // Setup the timeout - ttl := secondsToDuration(opts.LeafNode.TLSTimeout) - time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) - conn.SetReadDeadline(time.Now().Add(ttl)) - - // Force handshake - c.mu.Unlock() - if err := conn.Handshake(); err != nil { - c.Errorf("TLS handshake error: %v", err) - c.closeConnection(TLSHandshakeError) - return nil - } - // Reset the read deadline - conn.SetReadDeadline(time.Time{}) - - // Re-Grab lock - c.mu.Lock() - - // Indicate that handshake is complete (used in monitoring) - c.flags.set(handshakeComplete) - - // Timeout may have closed the connection while the lock was released. - if isClosedUnlock() { + if !c.isWebsocket() && info.TLSRequired { + // Perform server-side TLS handshake. + if err := c.doTLSServerHandshake("leafnode", opts.LeafNode.TLSConfig, opts.LeafNode.TLSTimeout); err != nil { + c.mu.Unlock() return nil } } @@ -900,6 +795,9 @@ func (s *Server) createLeafNode(conn net.Conn, remote *leafNodeCfg) *client { // Leaf nodes will always require a CONNECT to let us know // when we are properly bound to an account. c.setAuthTimer(secondsToDuration(opts.LeafNode.AuthTimeout)) + + // Set the Ping timer + s.setFirstPingTimer(c) } // Keep track in case server is shutdown before we can successfully register. @@ -911,62 +809,33 @@ func (s *Server) createLeafNode(conn net.Conn, remote *leafNodeCfg) *client { } // Spin up the read loop. - s.startGoRoutine(func() { c.readLoop(nil) }) - - // Spin up the write loop. - s.startGoRoutine(func() { c.writeLoop() }) + s.startGoRoutine(func() { c.readLoop(preBuf) }) - // Set the Ping timer - s.setFirstPingTimer(c) + // We will sping the write loop for solicited connections only + // when processing the INFO and after switching to TLS if needed. + if !solicited { + s.startGoRoutine(func() { c.writeLoop() }) + } c.mu.Unlock() - c.Debugf("Leafnode connection created") - - // Update server's accounting here if we solicited. - // Also send our local subs. - if solicited { - // Make sure we register with the account here. - c.registerWithAccount(acc) - s.addLeafNodeConnection(c, _EMPTY_, false) - s.initLeafNodeSmapAndSendSubs(c) - if sendSysConnectEvent { - s.sendLeafNodeConnect(acc) - } - - // The above functions are not atomically under the client - // lock doing those operations. It is possible - since we - // have started the read/write loops - that the connection - // is closed before or in between. This would leave the - // closed LN connection possible registered with the account - // and/or the server's leafs map. So check if connection - // is closed, and if so, manually cleanup. - c.mu.Lock() - closed := c.isClosed() - c.mu.Unlock() - if closed { - s.removeLeafNodeConnection(c) - if prev := acc.removeClient(c); prev == 1 { - s.decActiveAccounts() - } - } - } - return c } -func (c *client) processLeafnodeInfo(info *Info) error { +func (c *client) processLeafnodeInfo(info *Info) { c.mu.Lock() - defer c.mu.Unlock() - if c.leaf == nil || c.isClosed() { - return nil + c.mu.Unlock() + return } + var firstINFO bool + // Mark that the INFO protocol has been received. // Note: For now, only the initial INFO has a nonce. We // will probably do auto key rotation at some point. if c.flags.setIfNotSet(infoReceived) { + firstINFO = true // Prevent connecting to non leafnode port. Need to do this only for // the first INFO, not for async INFO updates... // @@ -987,7 +856,10 @@ func (c *client) processLeafnodeInfo(info *Info) error { // from the remote server an INFO with CID and LeafNodeURLs. Anything // else should be considered an attempt to connect to a wrong port. if c.leaf.remote != nil && (info.CID == 0 || info.LeafNodeURLs == nil) { - return ErrConnectedToWrongPort + c.mu.Unlock() + c.Errorf(ErrConnectedToWrongPort.Error()) + c.closeConnection(WrongPort) + return } // Capture a nonce here. c.nonce = []byte(info.Nonce) @@ -1009,7 +881,7 @@ func (c *client) processLeafnodeInfo(info *Info) error { } // For both initial INFO and async INFO protocols, Possibly // update our list of remote leafnode URLs we can connect to. - if c.leaf.remote != nil && len(info.LeafNodeURLs) > 0 { + if c.leaf.remote != nil && (len(info.LeafNodeURLs) > 0 || len(info.WSConnectURLs) > 0) { // Consider the incoming array as the most up-to-date // representation of the remote cluster's list of URLs. c.updateLeafNodeURLs(info) @@ -1033,7 +905,22 @@ func (c *client) processLeafnodeInfo(info *Info) error { c.setPermissions(perms) } - return nil + var finishConnect bool + var s *Server + + // If this is a remote connection and this is the first INFO protocol, + // then we need to finish the connect process by sending CONNECT, etc.. + if firstINFO && c.leaf.remote != nil { + // Clear deadline that was set in createLeafNode while waiting for the INFO. + c.nc.SetDeadline(time.Time{}) + finishConnect = true + s = c.srv + } + c.mu.Unlock() + + if finishConnect && s != nil { + s.leafNodeFinishConnectProcess(c) + } } // When getting a leaf node INFO protocol, use the provided @@ -1043,10 +930,23 @@ func (c *client) updateLeafNodeURLs(info *Info) { cfg.Lock() defer cfg.Unlock() - cfg.urls = make([]*url.URL, 0, 1+len(info.LeafNodeURLs)) + // We have ensured that if a remote has a WS scheme, then all are. + // So check if first is WS, then add WS URLs, otherwise, add non WS ones. + if len(cfg.URLs) > 0 && isWSURL(cfg.URLs[0]) { + // We use wsSchemePrefix. It does not matter if TLS or not since + // the distinction is done when creating the LN connection based + // on presence of TLS config, etc.. + c.doUpdateLNURLs(cfg, wsSchemePrefix, info.WSConnectURLs) + return + } + c.doUpdateLNURLs(cfg, "nats-leaf", info.LeafNodeURLs) +} + +func (c *client) doUpdateLNURLs(cfg *leafNodeCfg, scheme string, URLs []string) { + cfg.urls = make([]*url.URL, 0, 1+len(URLs)) // Add the ones we receive in the protocol - for _, surl := range info.LeafNodeURLs { - url, err := url.Parse("nats-leaf://" + surl) + for _, surl := range URLs { + url, err := url.Parse(fmt.Sprintf("%s://%s", scheme, surl)) if err != nil { c.Errorf("Error parsing url %q: %v", surl, err) continue @@ -2039,3 +1939,259 @@ func (c *client) setLeafConnectDelayIfSoliciting(delay time.Duration) (string, t c.mu.Unlock() return accName, delay } + +// For the given remote Leafnode configuration, this function returns +// if TLS is required, and if so, will return a clone of the TLS Config +// (since some fields will be changed during handshake), the TLS server +// name that is remembered, and the TLS timeout. +func (c *client) leafNodeGetTLSConfigForSolicit(remote *leafNodeCfg, needsLock bool) (bool, *tls.Config, string, float64) { + var ( + tlsConfig *tls.Config + tlsName string + tlsTimeout float64 + ) + + if needsLock { + remote.RLock() + } + tlsRequired := remote.TLS || remote.TLSConfig != nil + if tlsRequired { + if remote.TLSConfig != nil { + tlsConfig = remote.TLSConfig.Clone() + } else { + tlsConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + tlsName = remote.tlsName + tlsTimeout = remote.TLSTimeout + if tlsTimeout == 0 { + tlsTimeout = float64(TLS_TIMEOUT / time.Second) + } + } + if needsLock { + remote.RUnlock() + } + + return tlsRequired, tlsConfig, tlsName, tlsTimeout +} + +// Initiates the LeafNode Websocket connection by: +// - doing the TLS handshake if needed +// - sending the HTTP request +// - waiting for the HTTP response +// +// Since some bufio reader is used to consume the HTTP response, this function +// returns the slice of buffered bytes (if any) so that the readLoop that will +// be started after that consume those first before reading from the socket. +// The boolean +// +// Lock held on entry. +func (c *client) leafNodeSolicitWSConnection(opts *Options, rURL *url.URL, remote *leafNodeCfg) ([]byte, ClosedState, error) { + remote.RLock() + compress := remote.Websocket.Compression + // By default the server will mask outbound frames, but it can be disabled with this option. + noMasking := remote.Websocket.NoMasking + tlsRequired, tlsConfig, tlsName, tlsTimeout := c.leafNodeGetTLSConfigForSolicit(remote, false) + remote.RUnlock() + // Do TLS here as needed. + if tlsRequired { + // Perform the client-side TLS handshake. + if resetTLSName, err := c.doTLSClientHandshake("leafnode", rURL, tlsConfig, tlsName, tlsTimeout); err != nil { + // Check if we need to reset the remote's TLS name. + if resetTLSName { + remote.Lock() + remote.tlsName = _EMPTY_ + remote.Unlock() + } + // 0 will indicate that the connection was already closed + return nil, 0, err + } + } + + var req *http.Request + var wsKey string + + // For http request, we need the passed URL to contain either http or https scheme. + scheme := "http" + if tlsRequired { + scheme = "https" + } + // We will use the `/leafnode` path to tell the accepting WS server that it should + // create a LEAF connection, not a CLIENT. + // In case we use the user's URL path in the future, make sure we append the user's + // path to our `/leafnode` path. + path := leafNodeWSPath + if curPath := rURL.EscapedPath(); curPath != _EMPTY_ { + if curPath[0] == '/' { + curPath = curPath[1:] + } + path += curPath + } + ustr := fmt.Sprintf("%s://%s%s", scheme, rURL.Host, path) + u, _ := url.Parse(ustr) + req = &http.Request{ + Method: "GET", + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: u.Host, + } + wsKey, err := wsMakeChallengeKey() + if err != nil { + return nil, WriteError, err + } + + req.Header["Upgrade"] = []string{"websocket"} + req.Header["Connection"] = []string{"Upgrade"} + req.Header["Sec-WebSocket-Key"] = []string{wsKey} + req.Header["Sec-WebSocket-Version"] = []string{"13"} + if compress { + req.Header.Add("Sec-WebSocket-Extensions", wsPMCExtension+wsNoCtxTakeOver) + } + if noMasking { + req.Header.Add("Sec-WebSocket-Extensions", wsNoMaskingExtension) + } + if err := req.Write(c.nc); err != nil { + return nil, WriteError, err + } + + var resp *http.Response + + br := bufio.NewReaderSize(c.nc, MAX_CONTROL_LINE_SIZE) + c.nc.SetReadDeadline(time.Now().Add(DEFAULT_LEAFNODE_INFO_WAIT)) + resp, err = http.ReadResponse(br, req) + if err == nil && + (resp.StatusCode != 101 || + !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || + !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != wsAcceptKey(wsKey)) { + + err = fmt.Errorf("invalid websocket connection") + } + if err == nil && (c.ws.compress || noMasking) { + // Check extensions... + + srvCompress, srvNoMasking := wsClientWantedExtensions(resp.Header) + + // We said to the otherside that we support compression. Now check that + // the other side said that it supports compression too. + if c.ws.compress && !srvCompress { + // No extension, or does not contain the indication that per-message + // compression is supported, so disable on our side. + c.ws.compress = false + } + + // Same for no masking... + if noMasking && !srvNoMasking { + // Need to mask our writes as any client would do. + c.ws.maskwrite = true + } + } + if resp != nil { + resp.Body.Close() + } + if err != nil { + return nil, ReadError, err + } + c.Debugf("Leafnode compression=%v masking=%v", c.ws.compress, c.ws.maskwrite) + + var preBuf []byte + // We have to slurp whatever is in the bufio reader and pass that to the readloop. + if n := br.Buffered(); n != 0 { + preBuf, _ = br.Peek(n) + } + return preBuf, 0, nil +} + +// This is invoked for remote LEAF remote connections after processing the INFO +// protocol. This will do the TLS handshake (if needed be), send the CONNECT protocol +// and register the leaf node. +func (s *Server) leafNodeFinishConnectProcess(c *client) { + clusterName := s.ClusterName() + + c.mu.Lock() + if c.isClosed() { + c.mu.Unlock() + return + } + remote := c.leaf.remote + + // Check if we will need to send the system connect event. + remote.RLock() + sendSysConnectEvent := remote.Hub + remote.RUnlock() + + var tlsRequired bool + + // In case of websocket, the TLS handshake has been already done. + // So check only for non websocket connections. + if !c.isWebsocket() { + var tlsConfig *tls.Config + var tlsName string + var tlsTimeout float64 + + // Check if TLS is required and gather TLS config variables. + tlsRequired, tlsConfig, tlsName, tlsTimeout = c.leafNodeGetTLSConfigForSolicit(remote, true) + + // If TLS required, peform handshake. + if tlsRequired { + // Get the URL that was used to connect to the remote server. + rURL := remote.getCurrentURL() + + // Perform the client-side TLS handshake. + if resetTLSName, err := c.doTLSClientHandshake("leafnode", rURL, tlsConfig, tlsName, tlsTimeout); err != nil { + // Check if we need to reset the remote's TLS name. + if resetTLSName { + remote.Lock() + remote.tlsName = _EMPTY_ + remote.Unlock() + } + c.mu.Unlock() + return + } + } + } + if err := c.sendLeafConnect(clusterName, tlsRequired, c.headers); err != nil { + c.mu.Unlock() + c.closeConnection(WriteError) + return + } + + // Spin up the write loop. + s.startGoRoutine(func() { c.writeLoop() }) + + c.Debugf("Remote leafnode connect msg sent") + + // Capture account before releasing lock + acc := c.acc + c.mu.Unlock() + + // Make sure we register with the account here. + c.registerWithAccount(acc) + s.addLeafNodeConnection(c, _EMPTY_, false) + s.initLeafNodeSmapAndSendSubs(c) + if sendSysConnectEvent { + s.sendLeafNodeConnect(acc) + } + + // The above functions are not atomically under the client + // lock doing those operations. It is possible - since we + // have started the read/write loops - that the connection + // is closed before or in between. This would leave the + // closed LN connection possible registered with the account + // and/or the server's leafs map. So check if connection + // is closed, and if so, manually cleanup. + c.mu.Lock() + closed := c.isClosed() + if !closed { + s.setFirstPingTimer(c) + } + c.mu.Unlock() + if closed { + s.removeLeafNodeConnection(c) + if prev := acc.removeClient(c); prev == 1 { + s.decActiveAccounts() + } + } +} diff --git a/server/leafnode_test.go b/server/leafnode_test.go index 9c389d70f0..f279f94e89 100644 --- a/server/leafnode_test.go +++ b/server/leafnode_test.go @@ -15,10 +15,12 @@ package server import ( "bufio" + "bytes" "context" "crypto/tls" "fmt" "io/ioutil" + "math/rand" "net" "net/url" "os" @@ -28,6 +30,7 @@ import ( "testing" "time" + jwt "github.com/nats-io/jwt/v2" "github.com/nats-io/nats.go" ) @@ -953,7 +956,7 @@ func TestLeafCloseTLSConnection(t *testing.T) { ch <- true } -func TestLeafCloseTLSSaveName(t *testing.T) { +func TestLeafNodeTLSSaveName(t *testing.T) { opts := DefaultOptions() opts.LeafNode.Host = "127.0.0.1" opts.LeafNode.Port = -1 @@ -2359,3 +2362,465 @@ func TestLeafNodeTLSConfigReloadForRemote(t *testing.T) { return nil }) } + +func testDefaultLeafNodeWSOptions() *Options { + o := DefaultOptions() + o.Websocket.Host = "127.0.0.1" + o.Websocket.Port = -1 + o.Websocket.NoTLS = true + o.LeafNode.Host = "127.0.0.1" + o.LeafNode.Port = -1 + return o +} + +func testDefaultRemoteLeafNodeWSOptions(t *testing.T, o *Options, tls bool) *Options { + // Use some path in the URL.. we don't use that, but internally + // the server will prefix the path with /leafnode so that the + // WS webserver knows that it needs to create a LEAF connection. + u, _ := url.Parse(fmt.Sprintf("ws://127.0.0.1:%d/some/path", o.Websocket.Port)) + lo := DefaultOptions() + lo.Cluster.Name = "LN" + remote := &RemoteLeafOpts{URLs: []*url.URL{u}} + if tls { + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/server-cert.pem", + KeyFile: "../test/configs/certs/server-key.pem", + CaFile: "../test/configs/certs/ca.pem", + } + tlsConf, err := GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error generating TLS config: %v", err) + } + // GenTLSConfig sets the CA in ClientCAs, but since here we act + // as a client, set RootCAs... + tlsConf.RootCAs = tlsConf.ClientCAs + remote.TLSConfig = tlsConf + } + lo.LeafNode.Remotes = []*RemoteLeafOpts{remote} + return lo +} + +func TestLeafNodeWSMixURLs(t *testing.T) { + for _, test := range []struct { + name string + urls []string + }{ + {"mix 1", []string{"nats://127.0.0.1:1234", "ws://127.0.0.1:5678", "wss://127.0.0.1:9012"}}, + {"mix 2", []string{"ws://127.0.0.1:1234", "nats://127.0.0.1:5678", "wss://127.0.0.1:9012"}}, + {"mix 3", []string{"wss://127.0.0.1:1234", "ws://127.0.0.1:5678", "nats://127.0.0.1:9012"}}, + {"mix 4", []string{"ws://127.0.0.1:1234", "nats://127.0.0.1:9012"}}, + {"mix 5", []string{"nats://127.0.0.1:1234", "ws://127.0.0.1:9012"}}, + {"mix 6", []string{"wss://127.0.0.1:1234", "nats://127.0.0.1:9012"}}, + {"mix 7", []string{"nats://127.0.0.1:1234", "wss://127.0.0.1:9012"}}, + } { + t.Run(test.name, func(t *testing.T) { + o := DefaultOptions() + remote := &RemoteLeafOpts{} + urls := make([]*url.URL, 0, 3) + for _, ustr := range test.urls { + u, err := url.Parse(ustr) + if err != nil { + t.Fatalf("Error parsing url: %v", err) + } + urls = append(urls, u) + } + remote.URLs = urls + o.LeafNode.Remotes = []*RemoteLeafOpts{remote} + s, err := NewServer(o) + if err == nil || !strings.Contains(err.Error(), "mix") { + if s != nil { + s.Shutdown() + } + t.Fatalf("Unexpected error: %v", err) + } + }) + } +} + +type testConnTrackSize struct { + sync.Mutex + net.Conn + sz int +} + +func (c *testConnTrackSize) Write(p []byte) (int, error) { + c.Lock() + defer c.Unlock() + n, err := c.Conn.Write(p) + c.sz += n + return n, err +} + +func TestLeafNodeWSBasic(t *testing.T) { + for _, test := range []struct { + name string + masking bool + tls bool + acceptCompression bool + remoteCompression bool + }{ + {"masking plain no compression", true, false, false, false}, + {"masking plain compression", true, false, true, true}, + {"masking plain compression disagree", true, false, false, true}, + {"masking plain compression disagree 2", true, false, true, false}, + {"masking tls no compression", true, true, false, false}, + {"masking tls compression", true, true, true, true}, + {"masking tls compression disagree", true, true, false, true}, + {"masking tls compression disagree 2", true, true, true, false}, + {"no masking plain no compression", false, false, false, false}, + {"no masking plain compression", false, false, true, true}, + {"no masking plain compression disagree", false, false, false, true}, + {"no masking plain compression disagree 2", false, false, true, false}, + {"no masking tls no compression", false, true, false, false}, + {"no masking tls compression", false, true, true, true}, + {"no masking tls compression disagree", false, true, false, true}, + {"no masking tls compression disagree 2", false, true, true, false}, + } { + t.Run(test.name, func(t *testing.T) { + o := testDefaultLeafNodeWSOptions() + o.Websocket.NoTLS = !test.tls + if test.tls { + tc := &TLSConfigOpts{ + CertFile: "../test/configs/certs/server-cert.pem", + KeyFile: "../test/configs/certs/server-key.pem", + CaFile: "../test/configs/certs/ca.pem", + } + tlsConf, err := GenTLSConfig(tc) + if err != nil { + t.Fatalf("Error generating TLS config: %v", err) + } + o.Websocket.TLSConfig = tlsConf + } + o.Websocket.Compression = test.acceptCompression + s := RunServer(o) + defer s.Shutdown() + + lo := testDefaultRemoteLeafNodeWSOptions(t, o, test.tls) + lo.LeafNode.Remotes[0].Websocket.Compression = test.remoteCompression + lo.LeafNode.Remotes[0].Websocket.NoMasking = !test.masking + ln := RunServer(lo) + defer ln.Shutdown() + + checkLeafNodeConnected(t, s) + checkLeafNodeConnected(t, ln) + + var trackSizeConn *testConnTrackSize + if !test.tls { + var cln *client + ln.mu.Lock() + for _, l := range ln.leafs { + cln = l + break + } + ln.mu.Unlock() + cln.mu.Lock() + trackSizeConn = &testConnTrackSize{Conn: cln.nc} + cln.nc = trackSizeConn + cln.mu.Unlock() + } + + nc1 := natsConnect(t, s.ClientURL()) + defer nc1.Close() + sub1 := natsSubSync(t, nc1, "foo") + natsFlush(t, nc1) + + checkSubInterest(t, ln, globalAccountName, "foo", time.Second) + + nc2 := natsConnect(t, ln.ClientURL()) + msg1Payload := make([]byte, 512) + for i := 0; i < len(msg1Payload); i++ { + msg1Payload[i] = 'A' + } + natsPub(t, nc2, "foo", msg1Payload) + + msg := natsNexMsg(t, sub1, time.Second) + if !bytes.Equal(msg.Data, msg1Payload) { + t.Fatalf("Invalid message: %q", msg.Data) + } + + sub2 := natsSubSync(t, nc2, "bar") + natsFlush(t, nc2) + + checkSubInterest(t, s, globalAccountName, "bar", time.Second) + + msg2Payload := make([]byte, 512) + for i := 0; i < len(msg2Payload); i++ { + msg2Payload[i] = 'B' + } + natsPub(t, nc1, "bar", msg2Payload) + + msg = natsNexMsg(t, sub2, time.Second) + if !bytes.Equal(msg.Data, msg2Payload) { + t.Fatalf("Invalid message: %q", msg.Data) + } + + if !test.tls { + trackSizeConn.Lock() + size := trackSizeConn.sz + trackSizeConn.Unlock() + + if test.acceptCompression && test.remoteCompression { + if size >= 100 { + t.Fatalf("Seems that there was no compression: size=%v", size) + } + } else if size < 500 { + t.Fatalf("Seems compression was on while it should not: size=%v", size) + } + } + }) + } +} + +func TestLeafNodeWSRemoteCompressAndMaskingOptions(t *testing.T) { + for _, test := range []struct { + name string + compress bool + compStr string + noMasking bool + noMaskStr string + }{ + {"compression masking", true, "true", false, "false"}, + {"compression no masking", true, "true", true, "true"}, + {"no compression masking", false, "false", false, "false"}, + {"no compression no masking", false, "false", true, "true"}, + } { + t.Run(test.name, func(t *testing.T) { + conf := createConfFile(t, []byte(fmt.Sprintf(` + port: -1 + leafnodes { + remotes [ + {url: "ws://127.0.0.1:1234", ws_compression: %s, ws_no_masking: %s} + ] + } + `, test.compStr, test.noMaskStr))) + defer os.Remove(conf) + o, err := ProcessConfigFile(conf) + if err != nil { + t.Fatalf("Error loading conf: %v", err) + } + if nr := len(o.LeafNode.Remotes); nr != 1 { + t.Fatalf("Expected 1 remote, got %v", nr) + } + r := o.LeafNode.Remotes[0] + if cur := r.Websocket.Compression; cur != test.compress { + t.Fatalf("Expected compress to be %v, got %v", test.compress, cur) + } + if cur := r.Websocket.NoMasking; cur != test.noMasking { + t.Fatalf("Expected ws_masking to be %v, got %v", test.noMasking, cur) + } + }) + } +} + +func TestLeafNodeWSNoMaskingRejected(t *testing.T) { + wsTestRejectNoMasking = true + defer func() { wsTestRejectNoMasking = false }() + + o := testDefaultLeafNodeWSOptions() + s := RunServer(o) + defer s.Shutdown() + + lo := testDefaultRemoteLeafNodeWSOptions(t, o, false) + ln := RunServer(lo) + defer ln.Shutdown() + + checkLeafNodeConnected(t, s) + checkLeafNodeConnected(t, ln) + + var cln *client + ln.mu.Lock() + for _, l := range ln.leafs { + cln = l + break + } + ln.mu.Unlock() + + cln.mu.Lock() + maskWrite := cln.ws.maskwrite + cln.mu.Unlock() + + if !maskWrite { + t.Fatal("Leafnode remote connection should mask writes, it does not") + } +} + +func TestLeafNodeWSFailedConnection(t *testing.T) { + o := testDefaultLeafNodeWSOptions() + s := RunServer(o) + defer s.Shutdown() + + lo := testDefaultRemoteLeafNodeWSOptions(t, o, true) + lo.LeafNode.ReconnectInterval = 100 * time.Millisecond + ln := RunServer(lo) + defer ln.Shutdown() + + el := &captureErrorLogger{errCh: make(chan string, 100)} + ln.SetLogger(el, false, false) + + select { + case err := <-el.errCh: + if !strings.Contains(err, "handshake error") { + t.Fatalf("Unexpected error: %v", err) + } + case <-time.After(time.Second): + t.Fatal("No error reported!") + } + ln.Shutdown() + s.Shutdown() + + lst, err := natsListen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Error starting listener: %v", err) + } + defer lst.Close() + + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + defer wg.Done() + + for i := 0; i < 10; i++ { + c, err := lst.Accept() + if err != nil { + return + } + time.Sleep(time.Duration(rand.Intn(100)) * time.Millisecond) + if rand.Intn(2) == 1 { + c.Write([]byte("something\r\n")) + } + c.Close() + } + }() + + time.Sleep(100 * time.Millisecond) + + port := lst.Addr().(*net.TCPAddr).Port + u, _ := url.Parse(fmt.Sprintf("ws://127.0.0.1:%d", port)) + lo = DefaultOptions() + lo.LeafNode.Remotes = []*RemoteLeafOpts{{URLs: []*url.URL{u}}} + lo.LeafNode.ReconnectInterval = 10 * time.Millisecond + ln, _ = NewServer(lo) + el = &captureErrorLogger{errCh: make(chan string, 100)} + ln.SetLogger(el, false, false) + + go func() { + ln.Start() + wg.Done() + }() + + timeout := time.NewTimer(time.Second) + for i := 0; i < 10; i++ { + select { + case err := <-el.errCh: + if !strings.Contains(err, "Error soliciting") { + t.Fatalf("Unexpected error: %v", err) + } + case <-timeout.C: + t.Fatal("No error reported!") + } + } + ln.Shutdown() + lst.Close() + wg.Wait() +} + +func TestLeafNodeWSAuth(t *testing.T) { + conf := createConfFile(t, []byte(fmt.Sprintf(` + port: -1 + authorization { + users [ + {user: "user", pass: "puser", connection_types: ["%s"]} + {user: "leaf", pass: "pleaf", connection_types: ["%s"]} + ] + } + websocket { + port: -1 + no_tls: true + } + leafnodes { + port: -1 + } + `, jwt.ConnectionTypeStandard, jwt.ConnectionTypeLeafnode))) + defer os.Remove(conf) + o, err := ProcessConfigFile(conf) + if err != nil { + t.Fatalf("Error processing config file: %v", err) + } + o.NoLog, o.NoSigs = true, true + s := RunServer(o) + defer s.Shutdown() + + lo := testDefaultRemoteLeafNodeWSOptions(t, o, false) + ln := RunServer(lo) + defer ln.Shutdown() + + checkLeafNodeConnected(t, s) + checkLeafNodeConnected(t, ln) + + nc1 := natsConnect(t, fmt.Sprintf("nats://user:puser@127.0.0.1:%d", o.Port)) + defer nc1.Close() + + sub := natsSubSync(t, nc1, "foo") + natsFlush(t, nc1) + + checkSubInterest(t, ln, globalAccountName, "foo", time.Second) + + nc2 := natsConnect(t, ln.ClientURL()) + defer nc2.Close() + + natsPub(t, nc2, "foo", []byte("msg1")) + msg := natsNexMsg(t, sub, time.Second) + + if md := string(msg.Data); md != "msg1" { + t.Fatalf("Invalid message: %q", md) + } +} + +func TestLeafNodeWSGossip(t *testing.T) { + o1 := testDefaultLeafNodeWSOptions() + s1 := RunServer(o1) + defer s1.Shutdown() + + // Now connect from a server that knows only about s1 + lo := testDefaultRemoteLeafNodeWSOptions(t, o1, false) + lo.LeafNode.ReconnectInterval = 15 * time.Millisecond + ln := RunServer(lo) + defer ln.Shutdown() + + checkLeafNodeConnected(t, s1) + checkLeafNodeConnected(t, ln) + + // Now add a routed server to s1 + o2 := testDefaultLeafNodeWSOptions() + o2.Routes = RoutesFromStr(fmt.Sprintf("nats://127.0.0.1:%d", o1.Cluster.Port)) + s2 := RunServer(o2) + defer s2.Shutdown() + + // Wait for cluster to form + checkClusterFormed(t, s1, s2) + + // Now shutdown s1 and check that ln is able to reconnect to s2. + s1.Shutdown() + + checkLeafNodeConnected(t, s2) + checkLeafNodeConnected(t, ln) + + // Make sure that the reconnection was as a WS connection, not simply to + // the regular LN port. + var s2lc *client + s2.mu.Lock() + for _, l := range s2.leafs { + s2lc = l + break + } + s2.mu.Unlock() + + s2lc.mu.Lock() + isWS := s2lc.isWebsocket() + s2lc.mu.Unlock() + + if !isWS { + t.Fatal("Leafnode connection is not websocket!") + } +} diff --git a/server/mqtt.go b/server/mqtt.go index c2daaa12d1..cebe4a3118 100644 --- a/server/mqtt.go +++ b/server/mqtt.go @@ -349,10 +349,15 @@ func (s *Server) createMQTTClient(conn net.Conn) *client { c.mu.Lock() - isClosed := c.isClosed() + // In case connection has already been closed + if c.isClosed() { + c.mu.Unlock() + c.closeConnection(WriteError) + return nil + } var pre []byte - if !isClosed && tlsRequired && opts.AllowNonTLS { + if tlsRequired && opts.AllowNonTLS { pre = make([]byte, 4) c.nc.SetReadDeadline(time.Now().Add(secondsToDuration(opts.MQTT.TLSTimeout))) n, _ := io.ReadFull(c.nc, pre[:]) @@ -365,39 +370,17 @@ func (s *Server) createMQTTClient(conn net.Conn) *client { } } - if !isClosed && tlsRequired { - c.Debugf("Starting TLS client connection handshake") + if tlsRequired { if len(pre) > 0 { c.nc = &tlsMixConn{c.nc, bytes.NewBuffer(pre)} pre = nil } - c.nc = tls.Server(c.nc, opts.MQTT.TLSConfig) - conn := c.nc.(*tls.Conn) - - ttl := secondsToDuration(opts.MQTT.TLSTimeout) - time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) - conn.SetReadDeadline(time.Now().Add(ttl)) - - c.mu.Unlock() - if err := conn.Handshake(); err != nil { - c.Errorf("TLS handshake error: %v", err) - c.closeConnection(TLSHandshakeError) + // Perform server-side TLS handshake. + if err := c.doTLSServerHandshake("mqtt", opts.MQTT.TLSConfig, opts.MQTT.TLSTimeout); err != nil { + c.mu.Unlock() return nil } - conn.SetReadDeadline(time.Time{}) - - c.mu.Lock() - - c.flags.set(handshakeComplete) - - isClosed = c.isClosed() - } - - if isClosed { - c.mu.Unlock() - c.closeConnection(WriteError) - return nil } if authRequired { diff --git a/server/opts.go b/server/opts.go index 17f6a74f22..e14c44e7e9 100644 --- a/server/opts.go +++ b/server/opts.go @@ -147,6 +147,15 @@ type RemoteLeafOpts struct { Hub bool `json:"hub,omitempty"` DenyImports []string `json:"-"` DenyExports []string `json:"-"` + + // When an URL has the "ws" (or "wss") scheme, then the server will initiate the + // connection as a websocket connection. By default, the websocket frames will be + // masked (as if this server was a websocket client to the remote server). The + // NoMasking option will change this behavior and will send umasked frames. + Websocket struct { + Compression bool `json:"-"` + NoMasking bool `json:"-"` + } } // Options block for nats-server. @@ -1798,6 +1807,10 @@ func parseRemoteLeafNodes(v interface{}, errors *[]error, warnings *[]error) ([] continue } remote.DenyExports = subjects + case "ws_compress", "ws_compression", "websocket_compress", "websocket_compression": + remote.Websocket.Compression = v.(bool) + case "ws_no_masking", "websocket_no_masking": + remote.Websocket.NoMasking = v.(bool) default: if !tk.IsUsedVariable() { err := &unknownConfigFieldErr{ @@ -3541,7 +3554,7 @@ func parseWebsocket(v interface{}, o *Options, errors *[]error, warnings *[]erro *errors = append(*errors, err) } o.Websocket.HandshakeTimeout = ht - case "compression": + case "compress", "compression": o.Websocket.Compression = mv.(bool) case "authorization", "authentication": auth := parseSimpleAuth(tk, errors, warnings) diff --git a/server/route.go b/server/route.go index 385b95149d..cf36b862ec 100644 --- a/server/route.go +++ b/server/route.go @@ -674,6 +674,14 @@ func (c *client) processRouteInfo(info *Info) { if !s.getOpts().Cluster.NoAdvertise { s.addConnectURLsAndSendINFOToClients(info.ClientConnectURLs, info.WSConnectURLs) } + // Add the remote's leafnodeURL to our list of URLs and send the update + // to all LN connections. (Note that when coming from a route, LeafNodeURLs + // is an array of size 1 max). + s.mu.Lock() + if len(info.LeafNodeURLs) == 1 && s.addLeafNodeURL(info.LeafNodeURLs[0]) { + s.sendAsyncLeafNodeInfo() + } + s.mu.Unlock() } else { c.Debugf("Detected duplicate remote route %q", info.ID) c.closeConnection(DuplicateRoute) @@ -1314,46 +1322,13 @@ func (s *Server) createRoute(conn net.Conn, rURL *url.URL) *client { // Check for TLS if tlsRequired { - // Copy off the config to add in ServerName if we need to. - tlsConfig := opts.Cluster.TLSConfig.Clone() - - // If we solicited, we will act like the client, otherwise the server. + tlsConfig := opts.Cluster.TLSConfig if didSolicit { - c.Debugf("Starting TLS route client handshake") - // Specify the ServerName we are expecting. - host, _, _ := net.SplitHostPort(rURL.Host) - tlsConfig.ServerName = host - c.nc = tls.Client(c.nc, tlsConfig) - } else { - c.Debugf("Starting TLS route server handshake") - c.nc = tls.Server(c.nc, tlsConfig) - } - - conn := c.nc.(*tls.Conn) - - // Setup the timeout - ttl := secondsToDuration(opts.Cluster.TLSTimeout) - time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) - conn.SetReadDeadline(time.Now().Add(ttl)) - - c.mu.Unlock() - if err := conn.Handshake(); err != nil { - c.Errorf("TLS route handshake error: %v", err) - c.sendErr("Secure Connection - TLS Required") - c.closeConnection(TLSHandshakeError) - return nil + // Copy off the config to add in ServerName if we need to. + tlsConfig = tlsConfig.Clone() } - // Reset the read deadline - conn.SetReadDeadline(time.Time{}) - - // Re-Grab lock - c.mu.Lock() - - // To be consistent with client, set this flag to indicate that handshake is done - c.flags.set(handshakeComplete) - - // Verify that the connection did not go away while we released the lock. - if c.isClosed() { + // Perform (server or client side) TLS handshake. + if _, err := c.doTLSHandshake("route", didSolicit, rURL, tlsConfig, _EMPTY_, opts.Cluster.TLSTimeout); err != nil { c.mu.Unlock() return nil } @@ -1463,13 +1438,6 @@ func (s *Server) addRoute(c *client, info *Info) (bool, bool) { if info.GatewayURL != "" && s.addGatewayURL(info.GatewayURL) { s.sendAsyncGatewayInfo() } - - // Add the remote's leafnodeURL to our list of URLs and send the update - // to all LN connections. (Note that when coming from a route, LeafNodeURLs - // is an array of size 1 max). - if len(info.LeafNodeURLs) == 1 && s.addLeafNodeURL(info.LeafNodeURLs[0]) { - s.sendAsyncLeafNodeInfo() - } } s.mu.Unlock() diff --git a/server/server.go b/server/server.go index 6096267496..4607ec765e 100644 --- a/server/server.go +++ b/server/server.go @@ -1553,6 +1553,13 @@ func (s *Server) Start() { s.startGateways() } + // Start websocket server if needed. Do this before starting the routes, and + // leaf node because we want to resolve the gateway host:port so that this + // information can be sent to other routes. + if opts.Websocket.Port != 0 { + s.startWebsocketServer() + } + // Start up listen if we want to accept leaf node connections. if opts.LeafNode.Port != 0 { // Will resolve or assign the advertise address for the leafnode listener. @@ -1578,13 +1585,6 @@ func (s *Server) Start() { // port to be opened and potential ephemeral port selected. clientListenReady := make(chan struct{}) - // Start websocket server if needed. Do this before starting the routes, - // because we want to resolve the gateway host:port so that this information - // can be sent to other routes. - if opts.Websocket.Port != 0 { - s.startWebsocketServer() - } - // MQTT if opts.MQTT.Port != 0 { s.startMQTT() @@ -2262,40 +2262,17 @@ func (s *Server) createClient(conn net.Conn) *client { // Check for TLS if !isClosed && tlsRequired { - c.Debugf("Starting TLS client connection handshake") // If we have a prebuffer create a multi-reader. if len(pre) > 0 { c.nc = &tlsMixConn{c.nc, bytes.NewBuffer(pre)} // Clear pre so it is not parsed. pre = nil } - - c.nc = tls.Server(c.nc, opts.TLSConfig) - conn := c.nc.(*tls.Conn) - - // Setup the timeout - ttl := secondsToDuration(opts.TLSTimeout) - time.AfterFunc(ttl, func() { tlsTimeout(c, conn) }) - conn.SetReadDeadline(time.Now().Add(ttl)) - - // Force handshake - c.mu.Unlock() - if err := conn.Handshake(); err != nil { - c.Errorf("TLS handshake error: %v", err) - c.closeConnection(TLSHandshakeError) + // Performs server-side TLS handshake. + if err := c.doTLSServerHandshake(_EMPTY_, opts.TLSConfig, opts.TLSTimeout); err != nil { + c.mu.Unlock() return nil } - // Reset the read deadline - conn.SetReadDeadline(time.Time{}) - - // Re-Grab lock - c.mu.Lock() - - // Indicate that handshake is complete (used in monitoring) - c.flags.set(handshakeComplete) - - // The connection may have been closed - isClosed = c.isClosed() } // If connection is marked as closed, bail out. @@ -2922,9 +2899,9 @@ func (s *Server) PortsInfo(maxWait time.Duration) *Ports { } if wsListener != nil { - protocol := "ws" + protocol := wsSchemePrefix if wss { - protocol = "wss" + protocol = wsSchemePrefixTLS } ports.WebSocket = formatURL(protocol, wsListener) } diff --git a/server/websocket.go b/server/websocket.go index 901b4017ee..abf7bc6db5 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -16,6 +16,7 @@ package server import ( "bytes" "compress/flate" + "crypto/rand" "crypto/sha1" "crypto/tls" "encoding/base64" @@ -25,6 +26,7 @@ import ( "io" "io/ioutil" "log" + mrand "math/rand" "net" "net/http" "net/url" @@ -53,7 +55,7 @@ const ( wsMaskBit = 1 << 7 wsContinuationFrame = 0 - wsMaxFrameHeaderSize = 10 // For a server-to-client frame + wsMaxFrameHeaderSize = 14 // Since LeafNode may need to behave as a client wsMaxControlPayloadSize = 125 wsFrameSizeForBrowsers = 4096 // From experiment, webrowsers behave better with limited frame size @@ -75,6 +77,13 @@ const ( wsFinalFrame = true wsCompressedFrame = true wsUncompressedFrame = false + + wsSchemePrefix = "ws" + wsSchemePrefixTLS = "wss" + + wsNoMaskingExtension = "no-masking" + wsPMCExtension = "permessage-deflate" // per-message compression + wsNoCtxTakeOver = "; server_no_context_takeover; client_no_context_takeover; " ) var decompressorPool sync.Pool @@ -82,6 +91,9 @@ var decompressorPool sync.Pool // From https://tools.ietf.org/html/rfc6455#section-1.3 var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") +// Test can enable this so that server does not support "no-masking" requests. +var wsTestRejectNoMasking = false + type websocket struct { frames net.Buffers fs int64 @@ -89,6 +101,8 @@ type websocket struct { compress bool closeSent bool browser bool + maskread bool + maskwrite bool compressor *flate.Writer cookieJwt string } @@ -113,6 +127,7 @@ type allowedOrigin struct { type wsUpgradeResult struct { conn net.Conn ws *websocket + kind int } type wsReadInfo struct { @@ -120,6 +135,7 @@ type wsReadInfo struct { fs bool ff bool fc bool + mask bool // Incoming leafnode connections may not have masking. mkpos byte mkey [4]byte buf []byte @@ -188,7 +204,8 @@ func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, err b1 := tmpBuf[0] // Clients MUST set the mask bit. If not set, reject. - if b1&wsMaskBit == 0 { + // However, LEAF by default will not have masking, unless they are forced to, by configuration. + if r.mask && b1&wsMaskBit == 0 { return bufs, c.wsHandleProtocolError("mask bit missing") } @@ -236,13 +253,15 @@ func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, err r.rem = int(binary.BigEndian.Uint64(tmpBuf)) } - // Read masking key - tmpBuf, pos, err = wsGet(ior, buf, pos, 4) - if err != nil { - return bufs, err + if r.mask { + // Read masking key + tmpBuf, pos, err = wsGet(ior, buf, pos, 4) + if err != nil { + return bufs, err + } + copy(r.mkey[:], tmpBuf) + r.mkpos = 0 } - copy(r.mkey[:], tmpBuf) - r.mkpos = 0 // Handle control messages in place... if wsIsControlFrame(frameType) { @@ -272,7 +291,9 @@ func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, err b = r.buf } if !r.fc || r.rem == 0 { - r.unmask(b) + if r.mask { + r.unmask(b) + } if r.fc { // As per https://tools.ietf.org/html/rfc7692#section-7.2.2 // add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader @@ -314,7 +335,9 @@ func (c *client) wsHandleControlFrame(r *wsReadInfo, frameType wsOpCode, nc io.R if err != nil { return pos, err } - r.unmask(payload) + if r.mask { + r.unmask(payload) + } r.rem = 0 } switch frameType { @@ -382,13 +405,13 @@ func wsIsControlFrame(frameType wsOpCode) bool { // Create the frame header. // Encodes the frame type and optional compression flag, and the size of the payload. -func wsCreateFrameHeader(compressed bool, frameType wsOpCode, l int) []byte { +func wsCreateFrameHeader(useMasking, compressed bool, frameType wsOpCode, l int) ([]byte, []byte) { fh := make([]byte, wsMaxFrameHeaderSize) - n := wsFillFrameHeader(fh, wsFirstFrame, wsFinalFrame, compressed, frameType, l) - return fh[:n] + n, key := wsFillFrameHeader(fh, useMasking, wsFirstFrame, wsFinalFrame, compressed, frameType, l) + return fh[:n], key } -func wsFillFrameHeader(fh []byte, first, final, compressed bool, frameType wsOpCode, l int) int { +func wsFillFrameHeader(fh []byte, useMasking, first, final, compressed bool, frameType wsOpCode, l int) (int, []byte) { var n int var b byte if first { @@ -400,23 +423,38 @@ func wsFillFrameHeader(fh []byte, first, final, compressed bool, frameType wsOpC if compressed { b |= wsRsv1Bit } + b1 := byte(0) + if useMasking { + b1 |= wsMaskBit + } switch { case l <= 125: n = 2 fh[0] = b - fh[1] = byte(l) + fh[1] = b1 | byte(l) case l < 65536: n = 4 fh[0] = b - fh[1] = 126 + fh[1] = b1 | 126 binary.BigEndian.PutUint16(fh[2:], uint16(l)) default: n = 10 fh[0] = b - fh[1] = 127 + fh[1] = b1 | 127 binary.BigEndian.PutUint64(fh[2:], uint64(l)) } - return n + var key []byte + if useMasking { + var keyBuf [4]byte + if _, err := io.ReadFull(rand.Reader, keyBuf[:4]); err != nil { + kv := mrand.Int31() + binary.LittleEndian.PutUint32(keyBuf[:4], uint32(kv)) + } + copy(fh[n:], keyBuf[:4]) + key = fh[n : n+4] + n += 4 + } + return n, key } // Invokes wsEnqueueControlMessageLocked under client lock. @@ -428,6 +466,25 @@ func (c *client) wsEnqueueControlMessage(controlMsg wsOpCode, payload []byte) { c.mu.Unlock() } +// Mask the buffer with the given key +func wsMaskBuf(key, buf []byte) { + for i := 0; i < len(buf); i++ { + buf[i] ^= key[i&3] + } +} + +// Mask the buffers, as if they were contiguous, with the given key +func wsMaskBufs(key []byte, bufs [][]byte) { + pos := 0 + for i := 0; i < len(bufs); i++ { + buf := bufs[i] + for j := 0; j < len(buf); j++ { + buf[j] ^= key[pos&3] + pos++ + } + } +} + // Enqueues a websocket control message. // If the control message is a wsCloseMessage, then marks this client // has having sent the close message (since only one should be sent). @@ -437,12 +494,20 @@ func (c *client) wsEnqueueControlMessage(controlMsg wsOpCode, payload []byte) { func (c *client) wsEnqueueControlMessageLocked(controlMsg wsOpCode, payload []byte) { // Control messages are never compressed and their size will be // less than wsMaxControlPayloadSize, which means the frame header - // will be only 2 bytes. - cm := make([]byte, 2+len(payload)) - wsFillFrameHeader(cm, wsFirstFrame, wsFinalFrame, wsUncompressedFrame, controlMsg, len(payload)) + // will be only 2 or 6 bytes. + useMasking := c.ws.maskwrite + sz := 2 + if useMasking { + sz += 4 + } + cm := make([]byte, sz+len(payload)) + n, key := wsFillFrameHeader(cm, useMasking, wsFirstFrame, wsFinalFrame, wsUncompressedFrame, controlMsg, len(payload)) // Note that payload is optional. if len(payload) > 0 { - copy(cm[2:], payload) + copy(cm[n:], payload) + if useMasking { + wsMaskBuf(key, cm[n:]) + } } c.out.pb += int64(len(cm)) if controlMsg == wsCloseMessage { @@ -522,6 +587,14 @@ func wsCreateCloseMessage(status int, body string) []byte { // will be used to create a *client object. // Invoked from the HTTP server listening on websocket port. func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeResult, error) { + kind := CLIENT + if r.URL != nil { + ep := r.URL.EscapedPath() + if strings.HasPrefix(ep, leafNodeWSPath) { + kind = LEAF + } + } + opts := s.getOpts() // From https://tools.ietf.org/html/rfc6455#section-4.2.1 @@ -558,11 +631,12 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe // Point 8. // We don't have protocols, so ignore. // Point 9. - // Extensions, only support for compression at the moment - compress := opts.Websocket.Compression - if compress { - compress = wsClientSupportsCompression(r.Header) - } + // Extensions, only support for compression and no-masking at the moment + wantsCompress, wantsNoMasking := wsClientWantedExtensions(r.Header) + // We will use compression only if both agree + compress := opts.Websocket.Compression && wantsCompress + // We will do masking if asked (unless we reject for tests) + noMasking := wantsNoMasking && !wsTestRejectNoMasking h := w.(http.Hijacker) conn, brw, err := h.Hijack() @@ -584,8 +658,16 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) p = append(p, wsAcceptKey(key)...) p = append(p, _CRLF_...) - if compress { - p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) + if compress || noMasking { + p = append(p, "Sec-WebSocket-Extensions: "...) + if compress { + p = append(p, wsPMCExtension...) + p = append(p, wsNoCtxTakeOver...) + } + if noMasking { + p = append(p, wsNoMaskingExtension...) + } + p = append(p, CR_LF...) } p = append(p, _CRLF_...) @@ -597,17 +679,21 @@ func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeRe if opts.Websocket.HandshakeTimeout > 0 { conn.SetDeadline(time.Time{}) } - ws := &websocket{compress: compress} - // Indicate if this is likely coming from a browser. - if ua := r.Header.Get("User-Agent"); ua != "" && strings.HasPrefix(ua, "Mozilla/") { - ws.browser = true - } - if opts.Websocket.JWTCookie != "" { - if c, err := r.Cookie(opts.Websocket.JWTCookie); err == nil && c != nil { - ws.cookieJwt = c.Value + // Server always expect "clients" to send masked payload, unless the option + // "no-masking" has been enabled. + ws := &websocket{compress: compress, maskread: !noMasking} + if kind == CLIENT { + // Indicate if this is likely coming from a browser. + if ua := r.Header.Get("User-Agent"); ua != "" && strings.HasPrefix(ua, "Mozilla/") { + ws.browser = true + } + if opts.Websocket.JWTCookie != "" { + if c, err := r.Cookie(opts.Websocket.JWTCookie); err == nil && c != nil { + ws.cookieJwt = c.Value + } } } - return &wsUpgradeResult{conn: conn, ws: ws}, nil + return &wsUpgradeResult{conn: conn, ws: ws, kind: kind}, nil } // Returns true if the header named `name` contains a token with value `value`. @@ -624,22 +710,28 @@ func wsHeaderContains(header http.Header, name string, value string) bool { return false } -// Return true if the client has "permessage-deflate" in its extensions. -func wsClientSupportsCompression(header http.Header) bool { +// Return if known extensions are wanted by the client. +func wsClientWantedExtensions(header http.Header) (bool, bool) { + var compress bool + var noMasking bool + for _, extensionList := range header["Sec-Websocket-Extensions"] { extensions := strings.Split(extensionList, ",") for _, extension := range extensions { extension = strings.Trim(extension, " \t") params := strings.Split(extension, ";") for _, p := range params { - p = strings.Trim(p, " \t") - if strings.EqualFold(p, "permessage-deflate") { - return true + p = strings.ToLower(strings.Trim(p, " \t")) + switch p { + case wsPMCExtension: + compress = true + case wsNoMaskingExtension: + noMasking = true } } } } - return false + return compress, noMasking } // Send an HTTP error with the given `status`` to the given http response writer `w`. @@ -728,6 +820,14 @@ func wsAcceptKey(key string) string { return base64.StdEncoding.EncodeToString(h.Sum(nil)) } +func wsMakeChallengeKey() (string, error) { + p := make([]byte, 16) + if _, err := io.ReadFull(rand.Reader, p); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(p), nil +} + // Validate the websocket related options. func validateWebsocketOptions(o *Options) error { wo := &o.Websocket @@ -841,11 +941,11 @@ func (s *Server) startWebsocketServer() { // user has configured NoTLS because otherwise the server would have failed // to start due to options validation. if o.TLSConfig != nil { - proto = "wss" + proto = wsSchemePrefixTLS config := o.TLSConfig.Clone() hl, err = tls.Listen("tcp", hp, config) } else { - proto = "ws" + proto = wsSchemePrefix hl, err = net.Listen("tcp", hp) } if err != nil { @@ -854,7 +954,7 @@ func (s *Server) startWebsocketServer() { return } s.Noticef("Listening for websocket clients on %s://%s:%d", proto, o.Host, port) - if proto == "ws" { + if proto == wsSchemePrefix { s.Warnf("Websocket not configured with TLS. DO NOT USE IN PRODUCTION!") } @@ -862,7 +962,6 @@ func (s *Server) startWebsocketServer() { if port == 0 { s.opts.Websocket.Port = hl.Addr().(*net.TCPAddr).Port } - s.Noticef("Listening for websocket clients on %s://%s:%d", proto, o.Host, port) s.websocket.connectURLs, err = s.getConnectURLs(o.Advertise, o.Host, o.Port) if err != nil { s.Fatalf("Unable to get websocket connect URLs: %v", err) @@ -870,6 +969,7 @@ func (s *Server) startWebsocketServer() { s.mu.Unlock() return } + hasLeaf := sopts.LeafNode.Port != 0 mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { res, err := s.wsUpgrade(w, r) @@ -877,7 +977,20 @@ func (s *Server) startWebsocketServer() { s.Errorf(err.Error()) return } - s.createWSClient(res.conn, res.ws) + switch res.kind { + case CLIENT: + s.createWSClient(res.conn, res.ws) + case LEAF: + if !hasLeaf { + s.Errorf("Not configured to accept leaf node connections") + // Silently close for now. If we want to send an error back, we would + // need to create the leafnode client anyway, so that is is handling websocket + // frames, then send the error to the remote. + res.conn.Close() + return + } + s.createLeafNode(res.conn, nil, nil, res.ws) + } }) hs := &http.Server{ Addr: hp, @@ -1031,6 +1144,7 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) { } else if len(c.out.nb) > 0 { nb = c.out.nb } + mask := c.ws.maskwrite // Start with possible already framed buffers (that we could have // got from partials or control messages such as ws pings or pongs). bufs := c.ws.frames @@ -1062,13 +1176,19 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) { final = true } fh := make([]byte, wsMaxFrameHeaderSize) - n := wsFillFrameHeader(fh, first, final, wsCompressedFrame, wsBinaryMessage, lp) + n, key := wsFillFrameHeader(fh, mask, first, final, wsCompressedFrame, wsBinaryMessage, lp) + if mask { + wsMaskBuf(key, p[:lp]) + } bufs = append(bufs, fh[:n], p[:lp]) csz += n + lp p = p[lp:] } } else { - h := wsCreateFrameHeader(true, wsBinaryMessage, len(p)) + h, key := wsCreateFrameHeader(mask, true, wsBinaryMessage, len(p)) + if mask { + wsMaskBuf(key, p) + } bufs = append(bufs, h, p) csz = len(h) + len(p) } @@ -1085,10 +1205,13 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) { return len(bufs) - 1 } endFrame := func(idx, size int) { - n := wsFillFrameHeader(bufs[idx], wsFirstFrame, wsFinalFrame, wsUncompressedFrame, wsBinaryMessage, size) + n, key := wsFillFrameHeader(bufs[idx], mask, wsFirstFrame, wsFinalFrame, wsUncompressedFrame, wsBinaryMessage, size) c.out.pb += int64(n) c.ws.fs += int64(n + size) bufs[idx] = bufs[idx][:n] + if mask { + wsMaskBufs(key, bufs[idx+1:]) + } } fhIdx := startFrame() @@ -1119,10 +1242,14 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) { for _, b := range nb { total += len(b) } - wsfh := wsCreateFrameHeader(false, wsBinaryMessage, total) + wsfh, key := wsCreateFrameHeader(mask, false, wsBinaryMessage, total) c.out.pb += int64(len(wsfh)) bufs = append(bufs, wsfh) + idx := len(bufs) bufs = append(bufs, nb...) + if mask { + wsMaskBufs(key, bufs[idx:]) + } c.ws.fs += int64(len(wsfh) + total) } } @@ -1134,3 +1261,7 @@ func (c *client) wsCollapsePtoNB() (net.Buffers, int64) { c.ws.frames = nil return bufs, c.ws.fs } + +func isWSURL(u *url.URL) bool { + return strings.HasPrefix(strings.ToLower(u.Scheme), wsSchemePrefix) +} diff --git a/server/websocket_test.go b/server/websocket_test.go index 101d1ba1fe..7d0a24762e 100644 --- a/server/websocket_test.go +++ b/server/websocket_test.go @@ -161,7 +161,7 @@ func TestWSUnmask(t *testing.T) { return buf } - ri := &wsReadInfo{} + ri := &wsReadInfo{mask: true} ri.init() copy(ri.mkey[:], key) @@ -239,7 +239,7 @@ func TestWSCreateFrameHeader(t *testing.T) { {"compressed 100000", wsTextMessage, true, 100000}, } { t.Run(test.name, func(t *testing.T) { - res := wsCreateFrameHeader(test.compressed, test.frameType, test.len) + res, _ := wsCreateFrameHeader(false, test.compressed, test.frameType, test.len) // The server is always sending the message has a single frame, // so the "final" bit should be set. expected := byte(test.frameType) | wsFinalBit @@ -329,7 +329,7 @@ func testWSCreateClientMsg(frameType wsOpCode, frameNum int, final, compressed b } func testWSSetupForRead() (*client, *wsReadInfo, *testReader) { - ri := &wsReadInfo{} + ri := &wsReadInfo{mask: true} ri.init() tr := &testReader{} opts := DefaultOptions() @@ -2399,147 +2399,217 @@ func TestWSAdvertise(t *testing.T) { } func TestWSFrameOutbound(t *testing.T) { - c, _, _ := testWSSetupForRead() - - var bufs net.Buffers - bufs = append(bufs, []byte("this ")) - bufs = append(bufs, []byte("is ")) - bufs = append(bufs, []byte("a ")) - bufs = append(bufs, []byte("set ")) - bufs = append(bufs, []byte("of ")) - bufs = append(bufs, []byte("buffers")) - en := 2 - for _, b := range bufs { - en += len(b) - } - c.mu.Lock() - c.out.nb = bufs - res, n := c.collapsePtoNB() - c.mu.Unlock() - if n != int64(en) { - t.Fatalf("Expected size to be %v, got %v", en, n) - } - if eb := 1 + len(bufs); eb != len(res) { - t.Fatalf("Expected %v buffers, got %v", eb, len(res)) - } - var ob []byte - for i := 1; i < len(res); i++ { - ob = append(ob, res[i]...) - } - if !bytes.Equal(ob, []byte("this is a set of buffers")) { - t.Fatalf("Unexpected outbound: %q", ob) - } - - bufs = nil - c.out.pb = 0 - c.ws.fs = 0 - c.ws.frames = nil - c.ws.browser = true - bufs = append(bufs, []byte("some smaller ")) - bufs = append(bufs, []byte("buffers")) - bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers+10)) - bufs = append(bufs, []byte("then some more")) - en = 2 + len(bufs[0]) + len(bufs[1]) - en += 4 + len(bufs[2]) - 10 - en += 2 + len(bufs[3]) + 10 - c.mu.Lock() - c.out.nb = bufs - res, n = c.collapsePtoNB() - c.mu.Unlock() - if n != int64(en) { - t.Fatalf("Expected size to be %v, got %v", en, n) - } - if len(res) != 8 { - t.Fatalf("Unexpected number of outbound buffers: %v", len(res)) - } - if len(res[4]) != wsFrameSizeForBrowsers { - t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4])) - } - if len(res[6]) != 10 { - t.Fatalf("Frame 6 should have the partial of 10 bytes, got %v", len(res[6])) - } - - bufs = nil - c.out.pb = 0 - c.ws.fs = 0 - c.ws.frames = nil - c.ws.browser = true - bufs = append(bufs, []byte("some smaller ")) - bufs = append(bufs, []byte("buffers")) - // Have one of the exact max size - bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers)) - bufs = append(bufs, []byte("then some more")) - en = 2 + len(bufs[0]) + len(bufs[1]) - en += 4 + len(bufs[2]) - en += 2 + len(bufs[3]) - c.mu.Lock() - c.out.nb = bufs - res, n = c.collapsePtoNB() - c.mu.Unlock() - if n != int64(en) { - t.Fatalf("Expected size to be %v, got %v", en, n) - } - if len(res) != 7 { - t.Fatalf("Unexpected number of outbound buffers: %v", len(res)) - } - if len(res[4]) != wsFrameSizeForBrowsers { - t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4])) - } - if string(res[6]) != string(bufs[3]) { - t.Fatalf("Frame 6 should be %q, got %q", bufs[3], res[6]) - } - - bufs = nil - c.out.pb = 0 - c.ws.fs = 0 - c.ws.frames = nil - c.ws.browser = true - bufs = append(bufs, []byte("some smaller ")) - bufs = append(bufs, []byte("buffers")) - // Have one of the exact max size, and last in the list - bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers)) - en = 2 + len(bufs[0]) + len(bufs[1]) - en += 4 + len(bufs[2]) - c.mu.Lock() - c.out.nb = bufs - res, n = c.collapsePtoNB() - c.mu.Unlock() - if n != int64(en) { - t.Fatalf("Expected size to be %v, got %v", en, n) - } - if len(res) != 5 { - t.Fatalf("Unexpected number of outbound buffers: %v", len(res)) - } - if len(res[4]) != wsFrameSizeForBrowsers { - t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4])) - } - - bufs = nil - c.out.pb = 0 - c.ws.fs = 0 - c.ws.frames = nil - c.ws.browser = true - bufs = append(bufs, []byte("some smaller buffer")) - bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers-5)) - bufs = append(bufs, []byte("then some more")) - en = 2 + len(bufs[0]) - en += 4 + len(bufs[1]) - en += 2 + len(bufs[2]) - c.mu.Lock() - c.out.nb = bufs - res, n = c.collapsePtoNB() - c.mu.Unlock() - if n != int64(en) { - t.Fatalf("Expected size to be %v, got %v", en, n) - } - if len(res) != 6 { - t.Fatalf("Unexpected number of outbound buffers: %v", len(res)) - } - if len(res[3]) != wsFrameSizeForBrowsers-5 { - t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4])) - } - if string(res[5]) != string(bufs[2]) { - t.Fatalf("Frame 6 should be %q, got %q", bufs[2], res[5]) + for _, test := range []struct { + name string + maskingWrite bool + }{ + {"no write masking", false}, + {"write masking", true}, + } { + t.Run(test.name, func(t *testing.T) { + c, _, _ := testWSSetupForRead() + c.ws.maskwrite = test.maskingWrite + + getKey := func(buf []byte) []byte { + return buf[len(buf)-4:] + } + + var bufs net.Buffers + bufs = append(bufs, []byte("this ")) + bufs = append(bufs, []byte("is ")) + bufs = append(bufs, []byte("a ")) + bufs = append(bufs, []byte("set ")) + bufs = append(bufs, []byte("of ")) + bufs = append(bufs, []byte("buffers")) + en := 2 + for _, b := range bufs { + en += len(b) + } + if test.maskingWrite { + en += 4 + } + c.mu.Lock() + c.out.nb = bufs + res, n := c.collapsePtoNB() + c.mu.Unlock() + if n != int64(en) { + t.Fatalf("Expected size to be %v, got %v", en, n) + } + if eb := 1 + len(bufs); eb != len(res) { + t.Fatalf("Expected %v buffers, got %v", eb, len(res)) + } + var ob []byte + for i := 1; i < len(res); i++ { + ob = append(ob, res[i]...) + } + if test.maskingWrite { + wsMaskBuf(getKey(res[0]), ob) + } + if !bytes.Equal(ob, []byte("this is a set of buffers")) { + t.Fatalf("Unexpected outbound: %q", ob) + } + + bufs = nil + c.out.pb = 0 + c.ws.fs = 0 + c.ws.frames = nil + c.ws.browser = true + bufs = append(bufs, []byte("some smaller ")) + bufs = append(bufs, []byte("buffers")) + bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers+10)) + bufs = append(bufs, []byte("then some more")) + en = 2 + len(bufs[0]) + len(bufs[1]) + en += 4 + len(bufs[2]) - 10 + en += 2 + len(bufs[3]) + 10 + c.mu.Lock() + c.out.nb = bufs + res, n = c.collapsePtoNB() + c.mu.Unlock() + if test.maskingWrite { + en += 3 * 4 + } + if n != int64(en) { + t.Fatalf("Expected size to be %v, got %v", en, n) + } + if len(res) != 8 { + t.Fatalf("Unexpected number of outbound buffers: %v", len(res)) + } + if len(res[4]) != wsFrameSizeForBrowsers { + t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4])) + } + if len(res[6]) != 10 { + t.Fatalf("Frame 6 should have the partial of 10 bytes, got %v", len(res[6])) + } + if test.maskingWrite { + b := &bytes.Buffer{} + key := getKey(res[0]) + b.Write(res[1]) + b.Write(res[2]) + ud := b.Bytes() + wsMaskBuf(key, ud) + if string(ud) != "some smaller buffers" { + t.Fatalf("Unexpected result: %q", ud) + } + + b.Reset() + key = getKey(res[3]) + b.Write(res[4]) + ud = b.Bytes() + wsMaskBuf(key, ud) + for i := 0; i < len(ud); i++ { + if ud[i] != 0 { + t.Fatalf("Unexpected result: %v", ud) + } + } + + b.Reset() + key = getKey(res[5]) + b.Write(res[6]) + b.Write(res[7]) + ud = b.Bytes() + wsMaskBuf(key, ud) + for i := 0; i < len(ud[:10]); i++ { + if ud[i] != 0 { + t.Fatalf("Unexpected result: %v", ud[:10]) + } + } + if string(ud[10:]) != "then some more" { + t.Fatalf("Unexpected result: %q", ud[10:]) + } + } + + bufs = nil + c.out.pb = 0 + c.ws.fs = 0 + c.ws.frames = nil + c.ws.browser = true + bufs = append(bufs, []byte("some smaller ")) + bufs = append(bufs, []byte("buffers")) + // Have one of the exact max size + bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers)) + bufs = append(bufs, []byte("then some more")) + en = 2 + len(bufs[0]) + len(bufs[1]) + en += 4 + len(bufs[2]) + en += 2 + len(bufs[3]) + c.mu.Lock() + c.out.nb = bufs + res, n = c.collapsePtoNB() + c.mu.Unlock() + if test.maskingWrite { + en += 3 * 4 + } + if n != int64(en) { + t.Fatalf("Expected size to be %v, got %v", en, n) + } + if len(res) != 7 { + t.Fatalf("Unexpected number of outbound buffers: %v", len(res)) + } + if len(res[4]) != wsFrameSizeForBrowsers { + t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4])) + } + if string(res[6]) != string(bufs[3]) { + t.Fatalf("Frame 6 should be %q, got %q", bufs[3], res[6]) + } + + bufs = nil + c.out.pb = 0 + c.ws.fs = 0 + c.ws.frames = nil + c.ws.browser = true + bufs = append(bufs, []byte("some smaller ")) + bufs = append(bufs, []byte("buffers")) + // Have one of the exact max size, and last in the list + bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers)) + en = 2 + len(bufs[0]) + len(bufs[1]) + en += 4 + len(bufs[2]) + c.mu.Lock() + c.out.nb = bufs + res, n = c.collapsePtoNB() + c.mu.Unlock() + if test.maskingWrite { + en += 2 * 4 + } + if n != int64(en) { + t.Fatalf("Expected size to be %v, got %v", en, n) + } + if len(res) != 5 { + t.Fatalf("Unexpected number of outbound buffers: %v", len(res)) + } + if len(res[4]) != wsFrameSizeForBrowsers { + t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4])) + } + + bufs = nil + c.out.pb = 0 + c.ws.fs = 0 + c.ws.frames = nil + c.ws.browser = true + bufs = append(bufs, []byte("some smaller buffer")) + bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers-5)) + bufs = append(bufs, []byte("then some more")) + en = 2 + len(bufs[0]) + en += 4 + len(bufs[1]) + en += 2 + len(bufs[2]) + c.mu.Lock() + c.out.nb = bufs + res, n = c.collapsePtoNB() + c.mu.Unlock() + if test.maskingWrite { + en += 3 * 4 + } + if n != int64(en) { + t.Fatalf("Expected size to be %v, got %v", en, n) + } + if len(res) != 6 { + t.Fatalf("Unexpected number of outbound buffers: %v", len(res)) + } + if len(res[3]) != wsFrameSizeForBrowsers-5 { + t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4])) + } + if string(res[5]) != string(bufs[2]) { + t.Fatalf("Frame 6 should be %q, got %q", bufs[2], res[5]) + } + }) } } @@ -2781,45 +2851,62 @@ func TestWSCompressionWithPartialWrite(t *testing.T) { } func TestWSCompressionFrameSizeLimit(t *testing.T) { - opts := testWSOptions() - opts.MaxPending = MAX_PENDING_SIZE - s := &Server{opts: opts} - c := &client{srv: s, ws: &websocket{compress: true, browser: true}} - c.initClient() + for _, test := range []struct { + name string + maskWrite bool + }{ + {"no write masking", false}, + {"write masking", true}, + } { + t.Run(test.name, func(t *testing.T) { + opts := testWSOptions() + opts.MaxPending = MAX_PENDING_SIZE + s := &Server{opts: opts} + c := &client{srv: s, ws: &websocket{compress: true, browser: true, maskwrite: test.maskWrite}} + c.initClient() - // uncompressedPayload := []byte("abcdefghijklmnopqrstuvwxyz") - uncompressedPayload := make([]byte, 2*wsFrameSizeForBrowsers) - for i := 0; i < len(uncompressedPayload); i++ { - uncompressedPayload[i] = byte(rand.Intn(256)) - } + uncompressedPayload := make([]byte, 2*wsFrameSizeForBrowsers) + for i := 0; i < len(uncompressedPayload); i++ { + uncompressedPayload[i] = byte(rand.Intn(256)) + } - c.mu.Lock() - c.out.nb = append(net.Buffers(nil), uncompressedPayload) - nb, _ := c.collapsePtoNB() - c.mu.Unlock() + c.mu.Lock() + c.out.nb = append(net.Buffers(nil), uncompressedPayload) + nb, _ := c.collapsePtoNB() + c.mu.Unlock() - bb := &bytes.Buffer{} - for i, b := range nb { - // frame header buffer are always very small. The payload should not be more - // than 10 bytes since that is what we passed as the limit. - if len(b) > wsFrameSizeForBrowsers { - t.Fatalf("Frame size too big: %v (%q)", len(b), b) - } - // Check frame headers for the proper formatting. - if i%2 == 1 { - bb.Write(b) - } - } - buf := bb.Bytes() - buf = append(buf, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff) - dbr := bytes.NewBuffer(buf) - d := flate.NewReader(dbr) - uncompressed, err := ioutil.ReadAll(d) - if err != nil { - t.Fatalf("Error reading frame: %v", err) - } - if !bytes.Equal(uncompressed, uncompressedPayload) { - t.Fatalf("Unexpected uncomressed data: %q", uncompressed) + bb := &bytes.Buffer{} + var key []byte + for i, b := range nb { + // frame header buffer are always very small. The payload should not be more + // than 10 bytes since that is what we passed as the limit. + if len(b) > wsFrameSizeForBrowsers { + t.Fatalf("Frame size too big: %v (%q)", len(b), b) + } + if test.maskWrite { + if i%2 == 0 { + key = b[len(b)-4:] + } else { + wsMaskBuf(key, b) + } + } + // Check frame headers for the proper formatting. + if i%2 == 1 { + bb.Write(b) + } + } + buf := bb.Bytes() + buf = append(buf, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff) + dbr := bytes.NewBuffer(buf) + d := flate.NewReader(dbr) + uncompressed, err := ioutil.ReadAll(d) + if err != nil { + t.Fatalf("Error reading frame: %v", err) + } + if !bytes.Equal(uncompressed, uncompressedPayload) { + t.Fatalf("Unexpected uncomressed data: %q", uncompressed) + } + }) } }