Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IMPROVED] Websocket: generating INFO to send to clients #5405

Merged
merged 1 commit into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 5 additions & 18 deletions server/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2383,24 +2383,11 @@ func (c *client) generateClientInfoJSON(info Info) []byte {
info.MaxPayload = c.mpay
if c.isWebsocket() {
info.ClientConnectURLs = info.WSConnectURLs
if c.srv != nil { // Otherwise lame duck info can panic
c.srv.websocket.mu.RLock()
info.TLSAvailable = c.srv.websocket.tls
if c.srv.websocket.tls && c.srv.websocket.server != nil {
if tc := c.srv.websocket.server.TLSConfig; tc != nil {
info.TLSRequired = !tc.InsecureSkipVerify
}
}
if c.srv.websocket.listener != nil {
laddr := c.srv.websocket.listener.Addr().String()
if h, p, err := net.SplitHostPort(laddr); err == nil {
if p, err := strconv.Atoi(p); err == nil {
info.Host = h
info.Port = p
}
}
}
c.srv.websocket.mu.RUnlock()
// Otherwise lame duck info can panic
if c.srv != nil {
ws := &c.srv.websocket
info.TLSAvailable, info.TLSRequired = ws.tls, ws.tls
info.Host, info.Port = ws.host, ws.port
}
}
info.WSConnectURLs = nil
Expand Down
16 changes: 14 additions & 2 deletions server/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,18 @@ type srvWebsocket struct {
server *http.Server
listener net.Listener
listenerErr error
tls bool
allowedOrigins map[string]*allowedOrigin // host will be the key
sameOrigin bool
connectURLs []string
connectURLsMap refCountedUrlSet
authOverride bool // indicate if there is auth override in websocket config
rawHeaders string // raw headers to be used in the upgrade response.

// These are immutable and can be accessed without lock.
// This is the case when generating the client INFO.
tls bool // True if TLS is required (TLSConfig is specified).
host string // Host/IP the webserver is listening on (shortcut to opts.Websocket.Host).
port int // Port the webserver is listening on. This is after an ephemeral port may have been selected (shortcut to opts.Websocket.Port).
}

type allowedOrigin struct {
Expand Down Expand Up @@ -1153,7 +1158,12 @@ func (s *Server) startWebsocketServer() {
s.Warnf("Websocket not configured with TLS. DO NOT USE IN PRODUCTION!")
}

s.websocket.tls = proto == "wss"
// These 3 are immutable and will be accessed without lock by the client
// when generating/sending the INFO protocols.
s.websocket.tls = proto == wsSchemePrefixTLS
s.websocket.host, s.websocket.port = o.Host, o.Port

// This will be updated when/if the cluster changes.
s.websocket.connectURLs, err = s.getConnectURLs(o.Advertise, o.Host, o.Port)
if err != nil {
s.Fatalf("Unable to get websocket connect URLs: %v", err)
Expand Down Expand Up @@ -1192,8 +1202,10 @@ func (s *Server) startWebsocketServer() {
ReadTimeout: o.HandshakeTimeout,
ErrorLog: log.New(&captureHTTPServerLog{s, "websocket: "}, _EMPTY_, 0),
}
s.websocket.mu.Lock()
s.websocket.server = hs
s.websocket.listener = hl
s.websocket.mu.Unlock()
go func() {
if err := hs.Serve(hl); err != http.ErrServerClosed {
s.Fatalf("websocket listener error: %v", err)
Expand Down
134 changes: 105 additions & 29 deletions server/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4061,59 +4061,135 @@ func TestWSJWTCookieUser(t *testing.T) {
}

func TestWSReloadTLSConfig(t *testing.T) {
tlsBlock := `
tls {
cert_file: '%s'
key_file: '%s'
ca_file: '../test/configs/certs/ca.pem'
verify: %v
}
`
template := `
listen: "127.0.0.1:-1"
websocket {
listen: "127.0.0.1:-1"
tls {
cert_file: '%s'
key_file: '%s'
ca_file: '../test/configs/certs/ca.pem'
}
%s
no_tls: %v
}
`
conf := createConfFile(t, []byte(fmt.Sprintf(template,
"../test/configs/certs/server-noip.pem",
"../test/configs/certs/server-key-noip.pem")))
fmt.Sprintf(tlsBlock,
"../test/configs/certs/server-noip.pem",
"../test/configs/certs/server-key-noip.pem",
false), false)))

s, o := RunServerWithConfig(conf)
defer s.Shutdown()

addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port)
wsc, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Error creating ws connection: %v", err)

check := func(tlsConfig *tls.Config, handshakeFail bool, errTxt string) {
t.Helper()

wsc, err := net.Dial("tcp", addr)
require_NoError(t, err)
defer wsc.Close()

wsc = tls.Client(wsc, tlsConfig)
err = wsc.(*tls.Conn).Handshake()
if handshakeFail {
require_True(t, err != nil)
require_Contains(t, err.Error(), errTxt)
return
}
require_NoError(t, err)

req := testWSCreateValidReq()
req.URL, _ = url.Parse(wsSchemePrefixTLS + "://" + addr)
err = req.Write(wsc)
require_NoError(t, err)

br := bufio.NewReader(wsc)
resp, err := http.ReadResponse(br, req)
if errTxt == _EMPTY_ {
require_NoError(t, err)
} else {
require_True(t, err != nil)
require_Contains(t, err.Error(), errTxt)
return
}
defer resp.Body.Close()
l := testWSReadFrame(t, br)
require_True(t, bytes.HasPrefix(l, []byte("INFO {")))
var info Info
err = json.Unmarshal(l[5:], &info)
require_NoError(t, err)
require_True(t, info.TLSAvailable)
require_True(t, info.TLSRequired)
require_Equal[string](t, info.Host, "127.0.0.1")
require_Equal[int](t, info.Port, o.Websocket.Port)
}
defer wsc.Close()

tc := &TLSConfigOpts{CaFile: "../test/configs/certs/ca.pem"}
tlsConfig, err := GenTLSConfig(tc)
if err != nil {
t.Fatalf("Error generating TLS config: %v", err)
}
require_NoError(t, err)
tlsConfig.ServerName = "127.0.0.1"
tlsConfig.RootCAs = tlsConfig.ClientCAs
tlsConfig.ClientCAs = nil
wsc = tls.Client(wsc, tlsConfig.Clone())
if err := wsc.(*tls.Conn).Handshake(); err == nil || !strings.Contains(err.Error(), "SAN") {
t.Fatalf("Unexpected error: %v", err)
}
wsc.Close()

// Handshake should fail with error regarding SANs
check(tlsConfig.Clone(), true, "SAN")

// Replace certs with ones that allow IP.
reloadUpdateConfig(t, s, conf, fmt.Sprintf(template,
"../test/configs/certs/server-cert.pem",
"../test/configs/certs/server-key.pem"))
fmt.Sprintf(tlsBlock,
"../test/configs/certs/server-cert.pem",
"../test/configs/certs/server-key.pem",
false), false))

wsc, err = net.Dial("tcp", addr)
if err != nil {
t.Fatalf("Error creating ws connection: %v", err)
}
defer wsc.Close()
// Connection should succeed
check(tlsConfig.Clone(), false, _EMPTY_)

wsc = tls.Client(wsc, tlsConfig.Clone())
if err := wsc.(*tls.Conn).Handshake(); err != nil {
t.Fatalf("Error on TLS handshake: %v", err)
// Udpate config to require client cert.
reloadUpdateConfig(t, s, conf, fmt.Sprintf(template,
fmt.Sprintf(tlsBlock,
"../test/configs/certs/server-cert.pem",
"../test/configs/certs/server-key.pem",
true), false))

// Connection should fail saying that a tls cert is required
check(tlsConfig.Clone(), false, "required")

// Add a client cert
tc = &TLSConfigOpts{
CertFile: "../test/configs/certs/client-cert.pem",
KeyFile: "../test/configs/certs/client-key.pem",
}
tlsConfig, err = GenTLSConfig(tc)
require_NoError(t, err)
tlsConfig.InsecureSkipVerify = true

// Connection should succeed
check(tlsConfig.Clone(), false, _EMPTY_)

// Removing the tls{} block but with no_tls still false should fail
changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, _EMPTY_, false)))
err = s.Reload()
require_True(t, err != nil)
require_Contains(t, err.Error(), "TLS configuration")

// We should still be able to connect a TLS client
check(tlsConfig.Clone(), false, _EMPTY_)

// Now remove the tls{} block and set no_tls: true and that should fail
// since this is not supported.
changeCurrentConfigContentWithNewContent(t, conf, []byte(fmt.Sprintf(template, _EMPTY_, true)))
err = s.Reload()
require_True(t, err != nil)
require_Contains(t, err.Error(), "not supported")

// We should still be able to connect a TLS client
check(tlsConfig.Clone(), false, _EMPTY_)
}

type captureClientConnectedLogger struct {
Expand Down