From 63204b79ae5a051bafacbeb05278a298dcd16c2b Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Thu, 9 May 2024 10:25:28 -0600 Subject: [PATCH] [IMPROVED] Websocket: generating INFO to send to clients PR #4255 added code in generateClientInfoJSON to set the proper info Host/Port/TLSAvailable/TLSRequired fields to send to clients. However, this was requiring a lock but more importantly was computing the listener's host/port everytime, which is not necessary since this is immutable because we don't support the change during a config reload. Also, the TLSRequired field was set based on the server TLSConfig's InsecureSkipVerify value, which is irrelevant for a server. The mere presence of a TLSConfig (c.srv.websocket.tls being true) is enough. I have modified the TestWSReloadTLSConfig test to verify that the tls block cannot be removed and no_tls set to true, which means that tls value can't change. I also added check for the info's Host/Port/TLSAvailable/TLSRequired values. Signed-off-by: Ivan Kozlovic --- server/client.go | 23 ++----- server/websocket.go | 16 ++++- server/websocket_test.go | 134 ++++++++++++++++++++++++++++++--------- 3 files changed, 124 insertions(+), 49 deletions(-) diff --git a/server/client.go b/server/client.go index 1f57839d44..dd2fe945a7 100644 --- a/server/client.go +++ b/server/client.go @@ -2383,24 +2383,11 @@ func (c *client) generateClientInfoJSON(info Info) []byte { info.MaxPayload = c.mpay if c.isWebsocket() { info.ClientConnectURLs = info.WSConnectURLs - if c.srv != nil { // Otherwise lame duck info can panic - c.srv.websocket.mu.RLock() - info.TLSAvailable = c.srv.websocket.tls - if c.srv.websocket.tls && c.srv.websocket.server != nil { - if tc := c.srv.websocket.server.TLSConfig; tc != nil { - info.TLSRequired = !tc.InsecureSkipVerify - } - } - if c.srv.websocket.listener != nil { - laddr := c.srv.websocket.listener.Addr().String() - if h, p, err := net.SplitHostPort(laddr); err == nil { - if p, err := strconv.Atoi(p); err == nil { - info.Host = h - info.Port = p - } - } - } - c.srv.websocket.mu.RUnlock() + // Otherwise lame duck info can panic + if c.srv != nil { + ws := &c.srv.websocket + info.TLSAvailable, info.TLSRequired = ws.tls, ws.tls + info.Host, info.Port = ws.host, ws.port } } info.WSConnectURLs = nil diff --git a/server/websocket.go b/server/websocket.go index 8164917a7c..521e56042a 100644 --- a/server/websocket.go +++ b/server/websocket.go @@ -128,13 +128,18 @@ type srvWebsocket struct { server *http.Server listener net.Listener listenerErr error - tls bool allowedOrigins map[string]*allowedOrigin // host will be the key sameOrigin bool connectURLs []string connectURLsMap refCountedUrlSet authOverride bool // indicate if there is auth override in websocket config rawHeaders string // raw headers to be used in the upgrade response. + + // These are immutable and can be accessed without lock. + // This is the case when generating the client INFO. + tls bool // True if TLS is required (TLSConfig is specified). + host string // Host/IP the webserver is listening on (shortcut to opts.Websocket.Host). + port int // Port the webserver is listening on. This is after an ephemeral port may have been selected (shortcut to opts.Websocket.Port). } type allowedOrigin struct { @@ -1153,7 +1158,12 @@ func (s *Server) startWebsocketServer() { s.Warnf("Websocket not configured with TLS. DO NOT USE IN PRODUCTION!") } - s.websocket.tls = proto == "wss" + // These 3 are immutable and will be accessed without lock by the client + // when generating/sending the INFO protocols. + s.websocket.tls = proto == wsSchemePrefixTLS + s.websocket.host, s.websocket.port = o.Host, o.Port + + // This will be updated when/if the cluster changes. s.websocket.connectURLs, err = s.getConnectURLs(o.Advertise, o.Host, o.Port) if err != nil { s.Fatalf("Unable to get websocket connect URLs: %v", err) @@ -1192,8 +1202,10 @@ func (s *Server) startWebsocketServer() { ReadTimeout: o.HandshakeTimeout, ErrorLog: log.New(&captureHTTPServerLog{s, "websocket: "}, _EMPTY_, 0), } + s.websocket.mu.Lock() s.websocket.server = hs s.websocket.listener = hl + s.websocket.mu.Unlock() go func() { if err := hs.Serve(hl); err != http.ErrServerClosed { s.Fatalf("websocket listener error: %v", err) diff --git a/server/websocket_test.go b/server/websocket_test.go index 723e1a7df5..5cd26c7f8c 100644 --- a/server/websocket_test.go +++ b/server/websocket_test.go @@ -4061,59 +4061,135 @@ func TestWSJWTCookieUser(t *testing.T) { } func TestWSReloadTLSConfig(t *testing.T) { + tlsBlock := ` + tls { + cert_file: '%s' + key_file: '%s' + ca_file: '../test/configs/certs/ca.pem' + verify: %v + } + ` template := ` listen: "127.0.0.1:-1" websocket { listen: "127.0.0.1:-1" - tls { - cert_file: '%s' - key_file: '%s' - ca_file: '../test/configs/certs/ca.pem' - } + %s + no_tls: %v } ` conf := createConfFile(t, []byte(fmt.Sprintf(template, - "../test/configs/certs/server-noip.pem", - "../test/configs/certs/server-key-noip.pem"))) + fmt.Sprintf(tlsBlock, + "../test/configs/certs/server-noip.pem", + "../test/configs/certs/server-key-noip.pem", + false), false))) s, o := RunServerWithConfig(conf) defer s.Shutdown() addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port) - wsc, err := net.Dial("tcp", addr) - if err != nil { - t.Fatalf("Error creating ws connection: %v", err) + + check := func(tlsConfig *tls.Config, handshakeFail bool, errTxt string) { + t.Helper() + + wsc, err := net.Dial("tcp", addr) + require_NoError(t, err) + defer wsc.Close() + + wsc = tls.Client(wsc, tlsConfig) + err = wsc.(*tls.Conn).Handshake() + if handshakeFail { + require_True(t, err != nil) + require_Contains(t, err.Error(), errTxt) + return + } + require_NoError(t, err) + + req := testWSCreateValidReq() + req.URL, _ = url.Parse(wsSchemePrefixTLS + "://" + addr) + err = req.Write(wsc) + require_NoError(t, err) + + br := bufio.NewReader(wsc) + resp, err := http.ReadResponse(br, req) + if errTxt == _EMPTY_ { + require_NoError(t, err) + } else { + require_True(t, err != nil) + require_Contains(t, err.Error(), errTxt) + return + } + defer resp.Body.Close() + l := testWSReadFrame(t, br) + require_True(t, bytes.HasPrefix(l, []byte("INFO {"))) + var info Info + err = json.Unmarshal(l[5:], &info) + require_NoError(t, err) + require_True(t, info.TLSAvailable) + require_True(t, info.TLSRequired) + require_Equal[string](t, info.Host, "127.0.0.1") + require_Equal[int](t, info.Port, o.Websocket.Port) } - defer wsc.Close() tc := &TLSConfigOpts{CaFile: "../test/configs/certs/ca.pem"} tlsConfig, err := GenTLSConfig(tc) - if err != nil { - t.Fatalf("Error generating TLS config: %v", err) - } + require_NoError(t, err) tlsConfig.ServerName = "127.0.0.1" tlsConfig.RootCAs = tlsConfig.ClientCAs tlsConfig.ClientCAs = nil - wsc = tls.Client(wsc, tlsConfig.Clone()) - if err := wsc.(*tls.Conn).Handshake(); err == nil || !strings.Contains(err.Error(), "SAN") { - t.Fatalf("Unexpected error: %v", err) - } - wsc.Close() + // Handshake should fail with error regarding SANs + check(tlsConfig.Clone(), true, "SAN") + + // Replace certs with ones that allow IP. reloadUpdateConfig(t, s, conf, fmt.Sprintf(template, - "../test/configs/certs/server-cert.pem", - "../test/configs/certs/server-key.pem")) + fmt.Sprintf(tlsBlock, + "../test/configs/certs/server-cert.pem", + "../test/configs/certs/server-key.pem", + false), false)) - wsc, err = net.Dial("tcp", addr) - if err != nil { - t.Fatalf("Error creating ws connection: %v", err) - } - defer wsc.Close() + // Connection should succeed + check(tlsConfig.Clone(), false, _EMPTY_) - wsc = tls.Client(wsc, tlsConfig.Clone()) - if err := wsc.(*tls.Conn).Handshake(); err != nil { - t.Fatalf("Error on TLS handshake: %v", err) + // Udpate config to require client cert. + reloadUpdateConfig(t, s, conf, fmt.Sprintf(template, + fmt.Sprintf(tlsBlock, + "../test/configs/certs/server-cert.pem", + "../test/configs/certs/server-key.pem", + true), false)) + + // Connection should fail saying that a tls cert is required + check(tlsConfig.Clone(), false, "required") + + // Add a client cert + tc = &TLSConfigOpts{ + CertFile: "../test/configs/certs/client-cert.pem", + KeyFile: "../test/configs/certs/client-key.pem", } + tlsConfig, err = GenTLSConfig(tc) + require_NoError(t, err) + tlsConfig.InsecureSkipVerify = true + + // Connection should succeed + check(tlsConfig.Clone(), false, _EMPTY_) + + // Removing the tls{} block but with no_tls still false should fail + changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, _EMPTY_, false))) + err = s.Reload() + require_True(t, err != nil) + require_Contains(t, err.Error(), "TLS configuration") + + // We should still be able to connect a TLS client + check(tlsConfig.Clone(), false, _EMPTY_) + + // Now remove the tls{} block and set no_tls: true and that should fail + // since this is not supported. + changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, _EMPTY_, true))) + err = s.Reload() + require_True(t, err != nil) + require_Contains(t, err.Error(), "not supported") + + // We should still be able to connect a TLS client + check(tlsConfig.Clone(), false, _EMPTY_) } type captureClientConnectedLogger struct {