Skip to content

Commit

Permalink
[ADDED] Websocket support
Browse files Browse the repository at this point in the history
Websocket support can be enabled with a new websocket
configuration block:

```
websocket {
    # Specify a host and port to listen for websocket connections
    # listen: "host:port"

    # It can also be configured with individual parameters,
    # namely host and port.
    # host: "hostname"
    # port: 4443

    # TLS configuration is required
    tls {
      cert_file: "/path/to/cert.pem"
      key_file: "/path/to/key.pem"
    }

    # If same_origin is true, then the Origin header of the
    # client request must match the request's Host.
    # same_origin: true

    # This list specifies the only accepted values for
    # the client's request Origin header. The scheme,
    # host and port must match. By convention, the
    # absence of port for an http:// scheme will be 80,
    # and for https:// will be 443.
    # allowed_origins [
    #    "http://www.example.com"
    #    "https://www.other-example.com"
    # ]

    # This is the total time allowed for the server to
    # read the client request and write the response back
    # to the client. This include the time needed for the
    # TLS handshake.
    # handshake_timeout: "2s"
}
```

Use of "https://" is enforced.

Signed-off-by: Ivan Kozlovic <ivan@synadia.com>
  • Loading branch information
kozlovic committed Apr 14, 2020
1 parent 9fe4146 commit 07af9de
Show file tree
Hide file tree
Showing 12 changed files with 4,002 additions and 113 deletions.
141 changes: 102 additions & 39 deletions server/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ type client struct {
route *route
gw *gateway
leaf *leaf
ws *websocket

// To keep track of gateway replies mapping
gwrm map[string]*gwReplyMap
Expand Down Expand Up @@ -255,7 +256,6 @@ type outbound struct {
mp int64 // Snapshot of max pending for client.
lft time.Duration // Last flush time for Write.
stc chan struct{} // Stall chan we create to slow down producers on overrun, e.g. fan-in.
lwb int32 // Last byte size of Write.
}

type perm struct {
Expand Down Expand Up @@ -471,16 +471,23 @@ func (c *client) initClient() {

// snapshot the string version of the connection
var conn string
if ip, ok := c.nc.(*net.TCPConn); ok {
conn = ip.RemoteAddr().String()
host, port, _ := net.SplitHostPort(conn)
iPort, _ := strconv.Atoi(port)
c.host, c.port = host, uint16(iPort)
if c.nc != nil {
if addr := c.nc.RemoteAddr(); addr != nil {
if conn = addr.String(); conn != _EMPTY_ {
host, port, _ := net.SplitHostPort(conn)
iPort, _ := strconv.Atoi(port)
c.host, c.port = host, uint16(iPort)
}
}
}

switch c.kind {
case CLIENT:
c.ncs = fmt.Sprintf("%s - cid:%d", conn, c.cid)
name := "cid"
if c.ws != nil {
name = "wid"
}
c.ncs = fmt.Sprintf("%s - %s:%d", conn, name, c.cid)
case ROUTER:
c.ncs = fmt.Sprintf("%s - rid:%d", conn, c.cid)
case GATEWAY:
Expand Down Expand Up @@ -855,6 +862,7 @@ func (c *client) readLoop() {
return
}
nc := c.nc
ws := c.ws != nil
c.in.rsz = startBufSize
// Snapshot max control line since currently can not be changed on reload and we
// were checking it on each call to parse. If this changes and we allow MaxControlLine
Expand All @@ -880,13 +888,36 @@ func (c *client) readLoop() {
// Start read buffer.
b := make([]byte, c.in.rsz)

// Websocket clients will return several slices if there are multiple
// websocket frames in the blind read. For non WS clients though, we
// will always have 1 slice per loop iteration. So we define this here
// so non WS clients will use bufs[0] = b[:n].
var _bufs [1][]byte
bufs := _bufs[:1]

var wsr *wsReadInfo
if ws {
wsr = &wsReadInfo{}
wsr.init()
}

for {
n, err := nc.Read(b)
// If we have any data we will try to parse and exit at the end.
if n == 0 && err != nil {
c.closeConnection(closedStateForErr(err))
return
}
if ws {
bufs, err = c.wsRead(wsr, nc, b[:n])
if bufs == nil && err != nil {
c.closeConnection(closedStateForErr(err))
} else if bufs == nil {
continue
}
} else {
bufs[0] = b[:n]
}
start := time.Now()

// Clear inbound stats cache
Expand All @@ -896,20 +927,22 @@ func (c *client) readLoop() {

// Main call into parser for inbound data. This will generate callouts
// to process messages, etc.
if err := c.parse(b[:n]); err != nil {
if dur := time.Since(start); dur >= readLoopReportThreshold {
c.Warnf("Readloop processing time: %v", dur)
}
// Need to call flushClients because some of the clients have been
// assigned messages and their "fsp" incremented, and need now to be
// decremented and their writeLoop signaled.
c.flushClients(0)
// handled inline
if err != ErrMaxPayload && err != ErrAuthentication {
c.Error(err)
c.closeConnection(ProtocolViolation)
for i := 0; i < len(bufs); i++ {
if err := c.parse(bufs[i]); err != nil {
if dur := time.Since(start); dur >= readLoopReportThreshold {
c.Warnf("Readloop processing time: %v", dur)
}
// Need to call flushClients because some of the clients have been
// assigned messages and their "fsp" incremented, and need now to be
// decremented and their writeLoop signaled.
c.flushClients(0)
// handled inline
if err != ErrMaxPayload && err != ErrAuthentication {
c.Error(err)
c.closeConnection(ProtocolViolation)
}
return
}
return
}

// Updates stats for client and server that were collected
Expand Down Expand Up @@ -1005,6 +1038,10 @@ func (c *client) collapsePtoNB() net.Buffers {
// This will handle the fixup needed on a partial write.
// Assume pending has been already calculated correctly.
func (c *client) handlePartialWrite(pnb net.Buffers) {
if c.ws != nil {
c.ws.frames = pnb
return
}
nb := c.collapsePtoNB()
// The partial needs to be first, so append nb to pnb
c.out.nb = append(pnb, nb...)
Expand Down Expand Up @@ -1035,6 +1072,11 @@ func (c *client) flushOutbound() bool {
nb := c.collapsePtoNB()
c.out.p, c.out.nb, c.out.s = c.out.s, nil, nil

attempted := c.out.pb
if c.ws != nil {
nb, attempted = c.wsFrameOutbound(nb)
}

// For selecting primary replacement.
cnb := nb
var lfs int
Expand All @@ -1044,7 +1086,6 @@ func (c *client) flushOutbound() bool {

// In case it goes away after releasing the lock.
nc := c.nc
attempted := c.out.pb
apm := c.out.pm

// Capture this (we change the value in some tests)
Expand Down Expand Up @@ -1080,29 +1121,31 @@ func (c *client) flushOutbound() bool {
report = c.Errorf
}
report("Error flushing: %v", err)
c.markConnAsClosed(WriteError, true)
c.markConnAsClosed(WriteError)
return true
}
}

// Update flush time statistics.
c.out.lft = lft
c.out.lwb = int32(n)

// Subtract from pending bytes and messages.
c.out.pb -= int64(c.out.lwb)
c.out.pb -= n
if c.ws != nil {
c.ws.fs -= n
}
c.out.pm -= apm // FIXME(dlc) - this will not be totally accurate on partials.

// Check for partial writes
// TODO(dlc) - zero write with no error will cause lost message and the writeloop to spin.
if int64(c.out.lwb) != attempted && n > 0 {
if n != attempted && n > 0 {
c.handlePartialWrite(nb)
} else if c.out.lwb >= c.out.sz {
} else if int32(n) >= c.out.sz {
c.out.sws = 0
}

// Adjust based on what we wrote plus any pending.
pt := int64(c.out.lwb) + c.out.pb
pt := n + c.out.pb

// Adjust sz as needed downward, keeping power of 2.
// We do this at a slower rate.
Expand Down Expand Up @@ -1138,7 +1181,7 @@ func (c *client) flushOutbound() bool {

// Check if we have a stalled gate and if so and we are recovering release
// any stalled producers. Only kind==CLIENT will stall.
if c.out.stc != nil && (int64(c.out.lwb) == attempted || c.out.pb < c.out.mp/2) {
if c.out.stc != nil && (n == attempted || c.out.pb < c.out.mp/2) {
close(c.out.stc)
c.out.stc = nil
}
Expand All @@ -1153,7 +1196,7 @@ func (c *client) handleWriteTimeout(written, attempted int64, numChunks int) boo
if tlsConn, ok := c.nc.(*tls.Conn); ok {
if !tlsConn.ConnectionState().HandshakeComplete {
// Likely a TLSTimeout error instead...
c.markConnAsClosed(TLSHandshakeError, true)
c.markConnAsClosed(TLSHandshakeError)
// Would need to coordinate with tlstimeout()
// to avoid double logging, so skip logging
// here, and don't report a slow consumer error.
Expand All @@ -1164,7 +1207,7 @@ func (c *client) handleWriteTimeout(written, attempted int64, numChunks int) boo
// before the authorization timeout. If that is the case, then we handle
// as slow consumer though we do not increase the counter as that can be
// misleading.
c.markConnAsClosed(SlowConsumerWriteDeadline, true)
c.markConnAsClosed(SlowConsumerWriteDeadline)
return true
}

Expand All @@ -1175,26 +1218,40 @@ func (c *client) handleWriteTimeout(written, attempted int64, numChunks int) boo

// We always close CLIENT connections, or when nothing was written at all...
if c.kind == CLIENT || written == 0 {
c.markConnAsClosed(SlowConsumerWriteDeadline, true)
c.markConnAsClosed(SlowConsumerWriteDeadline)
return true
}
return false
}

// Marks this connection as closed with the given reason.
// Sets the closeConnection flag and skipFlushOnClose flag if asked.
// Sets the closeConnection flag and skipFlushOnClose depending on the reason.
// Depending on the kind of connection, the connection will be saved.
// If a writeLoop has been started, the final flush/close/teardown will
// be done there, otherwise flush and close of TCP connection is done here in place.
// Returns true if closed in place, false otherwise.
// Lock is held on entry.
func (c *client) markConnAsClosed(reason ClosedState, skipFlush bool) bool {
func (c *client) markConnAsClosed(reason ClosedState) bool {
// Possibly set skipFlushOnClose flag even if connection has already been
// marked as closed. The rationale is that a connection may be closed with
// a reason that justifies a flush (say after sending an -ERR), but then
// the flushOutbound() gets a write error. If that happens, connection
// being lost, there is no reason to attempt to flush again during the
// teardown when the writeLoop exits.
var skipFlush bool
switch reason {
case ReadError, WriteError, SlowConsumerPendingBytes, SlowConsumerWriteDeadline, TLSHandshakeError:
c.flags.set(skipFlushOnClose)
skipFlush = true
}
if c.flags.isSet(closeConnection) {
return false
}
c.flags.set(closeConnection)
if skipFlush {
c.flags.set(skipFlushOnClose)
// For a websocket client, unless we are told not to flush, enqueue
// a websocket CloseMessage based on the reason.
if !skipFlush && c.ws != nil && !c.ws.closeSent {
c.wsEnqueueCloseMessage(reason)
}
// Save off the connection if its a client or leafnode.
if c.kind == CLIENT || c.kind == LEAF {
Expand Down Expand Up @@ -1575,7 +1632,7 @@ func (c *client) queueOutbound(data []byte) bool {
c.out.pb -= int64(len(data))
atomic.AddInt64(&c.srv.slowConsumers, 1)
c.Noticef("Slow Consumer Detected: MaxPending of %d Exceeded", c.out.mp)
c.markConnAsClosed(SlowConsumerPendingBytes, true)
c.markConnAsClosed(SlowConsumerPendingBytes)
return referenced
}

Expand Down Expand Up @@ -1720,6 +1777,10 @@ func (c *client) generateClientInfoJSON(info Info) []byte {
info.CID = c.cid
info.ClientIP = c.host
info.MaxPayload = c.mpay
if c.ws != nil {
info.ClientConnectURLs = info.WSConnectURLs
}
info.WSConnectURLs = nil
// Generate the info json
b, _ := json.Marshal(info)
pcs := [][]byte{[]byte("INFO"), b, []byte(CR_LF)}
Expand Down Expand Up @@ -3502,7 +3563,7 @@ func (c *client) closeConnection(reason ClosedState) {
// This will set the closeConnection flag and save the connection, etc..
// Will return true if no writeLoop was started and TCP connection was
// closed in place, in which case we need to do the teardown.
teardownNow := c.markConnAsClosed(reason, false)
teardownNow := c.markConnAsClosed(reason)
c.mu.Unlock()

if teardownNow {
Expand Down Expand Up @@ -3537,6 +3598,7 @@ func (c *client) teardownConn() {
var (
retryImplicit bool
connectURLs []string
wsConnectURLs []string
gwName string
gwIsOutbound bool
gwCfg *gatewayCfg
Expand Down Expand Up @@ -3566,6 +3628,7 @@ func (c *client) teardownConn() {
retryImplicit = c.route.retry
}
connectURLs = c.route.connectURLs
wsConnectURLs = c.route.wsConnURLs
}
if kind == GATEWAY {
gwName = c.gw.name
Expand All @@ -3584,11 +3647,11 @@ func (c *client) teardownConn() {

if srv != nil {
// This is a route that disconnected, but we are not in lame duck mode...
if len(connectURLs) > 0 && !srv.isLameDuckMode() {
if (len(connectURLs) > 0 || len(wsConnectURLs) > 0) && !srv.isLameDuckMode() {
// Unless disabled, possibly update the server's INFO protocol
// and send to clients that know how to handle async INFOs.
if !srv.getOpts().Cluster.NoAdvertise {
srv.removeClientConnectURLsAndSendINFOToClients(connectURLs)
srv.removeConnectURLsAndSendINFOToClients(connectURLs, wsConnectURLs)
}
}

Expand Down
8 changes: 6 additions & 2 deletions server/client_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2012-2019 The NATS Authors
// Copyright 2012-2020 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
Expand Down Expand Up @@ -71,7 +71,7 @@ func (c *testAsyncClient) parseAndClose(proto []byte) {
func createClientAsync(ch chan *client, s *Server, cli net.Conn) {
s.grWG.Add(1)
go func() {
c := s.createClient(cli)
c := s.createClient(cli, nil)
// Must be here to suppress +OK
c.opts.Verbose = false
go c.writeLoop()
Expand Down Expand Up @@ -1948,6 +1948,10 @@ func (c *testConnWritePartial) Write(p []byte) (int, error) {
return c.buf.Write(p[:n])
}

// RemoteAddr returns a nil net.Addr.
// NOTE(review): presumably provided so this test connection can be passed to
// code that queries the remote address and must tolerate a nil result — confirm
// against callers.
func (c *testConnWritePartial) RemoteAddr() net.Addr {
return nil
}

// SetWriteDeadline ignores the supplied deadline and always returns a nil error.
func (c *testConnWritePartial) SetWriteDeadline(_ time.Time) error {
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions server/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func waitCh(t *testing.T, ch chan bool, errTxt string) {
}
}

func natsConnect(t *testing.T, url string, options ...nats.Option) *nats.Conn {
func natsConnect(t testing.TB, url string, options ...nats.Option) *nats.Conn {
t.Helper()
nc, err := nats.Connect(url, options...)
if err != nil {
Expand Down Expand Up @@ -213,7 +213,7 @@ func natsFlush(t *testing.T, nc *nats.Conn) {
}
}

func natsPub(t *testing.T, nc *nats.Conn, subj string, payload []byte) {
func natsPub(t testing.TB, nc *nats.Conn, subj string, payload []byte) {
t.Helper()
if err := nc.Publish(subj, payload); err != nil {
t.Fatalf("Error on publish: %v", err)
Expand Down
7 changes: 5 additions & 2 deletions server/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,7 @@ type Varz struct {
TLSVerify bool `json:"tls_verify,omitempty"`
IP string `json:"ip,omitempty"`
ClientConnectURLs []string `json:"connect_urls,omitempty"`
WSConnectURLs []string `json:"ws_connect_urls,omitempty"`
MaxConn int `json:"max_connections"`
MaxSubs int `json:"max_subscriptions,omitempty"`
PingInterval time.Duration `json:"ping_interval"`
Expand Down Expand Up @@ -1247,8 +1248,10 @@ func (s *Server) updateVarzRuntimeFields(v *Varz, forceUpdate bool, pcpu float64
v.Mem = rss
v.CPU = pcpu
if l := len(s.info.ClientConnectURLs); l > 0 {
v.ClientConnectURLs = make([]string, l)
copy(v.ClientConnectURLs, s.info.ClientConnectURLs)
v.ClientConnectURLs = append([]string(nil), s.info.ClientConnectURLs...)
}
if l := len(s.info.WSConnectURLs); l > 0 {
v.WSConnectURLs = append([]string(nil), s.info.WSConnectURLs...)
}
v.Connections = len(s.clients)
v.TotalConnections = s.totalClients
Expand Down

0 comments on commit 07af9de

Please sign in to comment.