[ADDED] Websocket support #1309

Merged: 1 commit, May 20, 2020
158 changes: 113 additions & 45 deletions server/client.go
@@ -232,6 +232,7 @@ type client struct {
route *route
gw *gateway
leaf *leaf
ws *websocket

// To keep track of gateway replies mapping
gwrm map[string]*gwReplyMap
@@ -270,7 +271,6 @@ type outbound struct {
mp int64 // Snapshot of max pending for client.
lft time.Duration // Last flush time for Write.
stc chan struct{} // Stall chan we create to slow down producers on overrun, e.g. fan-in.
lwb int32 // Last byte size of Write.
}

type perm struct {
@@ -484,16 +484,23 @@ func (c *client) initClient() {

// snapshot the string version of the connection
var conn string
if ip, ok := c.nc.(*net.TCPConn); ok {
conn = ip.RemoteAddr().String()
host, port, _ := net.SplitHostPort(conn)
iPort, _ := strconv.Atoi(port)
c.host, c.port = host, uint16(iPort)
if c.nc != nil {
if addr := c.nc.RemoteAddr(); addr != nil {
if conn = addr.String(); conn != _EMPTY_ {
host, port, _ := net.SplitHostPort(conn)
iPort, _ := strconv.Atoi(port)
c.host, c.port = host, uint16(iPort)
}
}
}

switch c.kind {
case CLIENT:
c.ncs = fmt.Sprintf("%s - cid:%d", conn, c.cid)
name := "cid"
if c.ws != nil {
name = "wid"
}
c.ncs = fmt.Sprintf("%s - %s:%d", conn, name, c.cid)
case ROUTER:
c.ncs = fmt.Sprintf("%s - rid:%d", conn, c.cid)
case GATEWAY:
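
The rewritten host/port snapshot above no longer type-asserts c.nc to *net.TCPConn; it asks the generic net.Conn for its RemoteAddr and guards against nil, presumably so wrapped or non-TCP connections (including test stubs, hence the new RemoteAddr method on testConnWritePartial further down) still work. A minimal sketch of the same pattern, using an illustrative splitHostPort helper that is not part of the server:

```go
// Minimal sketch: derive host/port from the net.Conn interface instead of a
// *net.TCPConn type assertion, tolerating nil conns and odd address formats.
package main

import (
	"fmt"
	"net"
	"strconv"
)

// splitHostPort is an illustrative helper mirroring the snapshot logic above.
func splitHostPort(nc net.Conn) (host string, port uint16) {
	if nc == nil {
		return
	}
	if addr := nc.RemoteAddr(); addr != nil {
		if s := addr.String(); s != "" {
			h, p, _ := net.SplitHostPort(s)
			iPort, _ := strconv.Atoi(p)
			host, port = h, uint16(iPort)
		}
	}
	return
}

func main() {
	// Works for any net.Conn, not just *net.TCPConn. A pipe-backed conn has
	// no host:port form, so the fields simply keep their zero values.
	c1, c2 := net.Pipe()
	defer c1.Close()
	defer c2.Close()
	host, port := splitHostPort(c1)
	fmt.Printf("host=%q port=%d\n", host, port)
}
```
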
@@ -873,6 +880,7 @@ func (c *client) readLoop() {
return
}
nc := c.nc
ws := c.ws != nil
c.in.rsz = startBufSize
// Snapshot max control line since currently can not be changed on reload and we
// were checking it on each call to parse. If this changes and we allow MaxControlLine
@@ -898,13 +906,39 @@ func (c *client) readLoop() {
// Start read buffer.
b := make([]byte, c.in.rsz)

// Websocket clients will return several slices if there are multiple
// websocket frames in the blind read. For non WS clients though, we
// will always have 1 slice per loop iteration. So we define this here
// so non WS clients will use bufs[0] = b[:n].
var _bufs [1][]byte
bufs := _bufs[:1]

var wsr *wsReadInfo
if ws {
wsr = &wsReadInfo{}
wsr.init()
}

for {
n, err := nc.Read(b)
// If we have any data we will try to parse and exit at the end.
if n == 0 && err != nil {
c.closeConnection(closedStateForErr(err))
return
}
if ws {
bufs, err = c.wsRead(wsr, nc, b[:n])
if bufs == nil && err != nil {
if err != io.EOF {
c.Errorf("read error: %v", err)
}
c.closeConnection(closedStateForErr(err))
} else if bufs == nil {
continue
}
} else {
bufs[0] = b[:n]
}
start := time.Now()

// Clear inbound stats cache
@@ -914,20 +948,22 @@

// Main call into parser for inbound data. This will generate callouts
// to process messages, etc.
if err := c.parse(b[:n]); err != nil {
if dur := time.Since(start); dur >= readLoopReportThreshold {
c.Warnf("Readloop processing time: %v", dur)
}
// Need to call flushClients because some of the clients have been
// assigned messages and their "fsp" incremented, and need now to be
// decremented and their writeLoop signaled.
c.flushClients(0)
// handled inline
if err != ErrMaxPayload && err != ErrAuthentication {
c.Error(err)
c.closeConnection(ProtocolViolation)
for i := 0; i < len(bufs); i++ {
if err := c.parse(bufs[i]); err != nil {
if dur := time.Since(start); dur >= readLoopReportThreshold {
c.Warnf("Readloop processing time: %v", dur)
}
// Need to call flushClients because some of the clients have been
// assigned messages and their "fsp" incremented, and need now to be
// decremented and their writeLoop signaled.
c.flushClients(0)
// handled inline
if err != ErrMaxPayload && err != ErrAuthentication {
c.Error(err)
c.closeConnection(ProtocolViolation)
}
return
}
return
}

// Updates stats for client and server that were collected
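
The bufs indirection above is the core of the websocket read path: a single blind Read can contain several complete websocket frames plus a trailing partial one, so wsRead hands back one payload slice per complete frame and the parser runs once per slice, while plain clients keep the single bufs[0] = b[:n] fast path. The sketch below is a simplified stand-in that uses toy length-prefixed frames and a hypothetical splitFrames helper instead of real websocket framing, purely to illustrate that shape:

```go
// Simplified stand-in for the "many frames per read" idea: real websocket
// framing has opcodes, masking and fragmentation; 2-byte length prefixes
// are used here only to show why the read loop iterates over a slice of
// buffers instead of parsing b[:n] directly.
package main

import (
	"encoding/binary"
	"fmt"
)

// splitFrames returns one slice per complete frame found in buf, plus any
// trailing partial frame that must be kept for the next read.
func splitFrames(buf []byte) (frames [][]byte, rest []byte) {
	for len(buf) >= 2 {
		sz := int(binary.BigEndian.Uint16(buf[:2]))
		if len(buf) < 2+sz {
			break // partial frame, wait for more data
		}
		frames = append(frames, buf[2:2+sz])
		buf = buf[2+sz:]
	}
	return frames, buf
}

func main() {
	// One "blind read" that happens to contain two complete frames and the
	// beginning of a third.
	read := []byte{0, 4, 'P', 'I', 'N', 'G', 0, 4, 'P', 'O', 'N', 'G', 0, 9, 'S', 'U'}
	frames, rest := splitFrames(read)
	for _, f := range frames {
		fmt.Printf("parse(%q)\n", f) // the real loop calls c.parse(bufs[i])
	}
	fmt.Printf("%d bytes of a partial frame carried over\n", len(rest))
}
```
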
@@ -1011,19 +1047,26 @@ func closedStateForErr(err error) ClosedState {

// collapsePtoNB will place primary onto nb buffer as needed in prep for WriteTo.
// This will return a copy on purpose.
func (c *client) collapsePtoNB() net.Buffers {
func (c *client) collapsePtoNB() (net.Buffers, int64) {
if c.ws != nil {
return c.wsCollapsePtoNB()
}
if c.out.p != nil {
p := c.out.p
c.out.p = nil
return append(c.out.nb, p)
return append(c.out.nb, p), c.out.pb
}
return c.out.nb
return c.out.nb, c.out.pb
}

// This will handle the fixup needed on a partial write.
// Assume pending has been already calculated correctly.
func (c *client) handlePartialWrite(pnb net.Buffers) {
nb := c.collapsePtoNB()
if c.ws != nil {
c.ws.frames = append(pnb, c.ws.frames...)
return
}
nb, _ := c.collapsePtoNB()
// The partial needs to be first, so append nb to pnb
c.out.nb = append(pnb, nb...)
}
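
collapsePtoNB now also reports how many bytes it is handing over (plain clients return c.out.pb, websocket clients defer to wsCollapsePtoNB for their framed output), and handlePartialWrite pushes leftovers back to the front of either c.ws.frames or the regular outbound list. In both branches the invariant is the same: the unwritten tail must go out before anything queued afterwards. A tiny sketch of that ordering, with illustrative names (leftover, queued):

```go
// Sketch of the partial-write ordering: whatever a previous writev did not
// send has to be re-queued ahead of buffers added since then.
package main

import (
	"fmt"
	"net"
)

func main() {
	// Bytes a previous writev did not manage to send.
	leftover := net.Buffers{[]byte("tail-of-old-message ")}

	// Bytes producers queued while that write was in flight.
	queued := net.Buffers{[]byte("newer "), []byte("messages")}

	// The partial needs to be first, so append queued to leftover and make
	// the result the new outbound list.
	outbound := append(leftover, queued...)

	for _, b := range outbound {
		fmt.Print(string(b))
	}
	fmt.Println()
}
```
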
Expand All @@ -1050,8 +1093,11 @@ func (c *client) flushOutbound() bool {
}

// Place primary on nb, assign primary to secondary, nil out nb and secondary.
nb := c.collapsePtoNB()
nb, attempted := c.collapsePtoNB()
c.out.p, c.out.nb, c.out.s = c.out.s, nil, nil
if nb == nil {
return true
}

// For selecting primary replacement.
cnb := nb
@@ -1062,7 +1108,6 @@

// In case it goes away after releasing the lock.
nc := c.nc
attempted := c.out.pb
apm := c.out.pm

// Capture this (we change the value in some tests)
@@ -1086,7 +1131,8 @@
// Re-acquire client lock.
c.mu.Lock()

if err != nil {
// Ignore ErrShortWrite errors, they will be handled as partials.
if err != nil && err != io.ErrShortWrite {
// Handle timeout error (slow consumer) differently
if ne, ok := err.(net.Error); ok && ne.Timeout() {
if closed := c.handleWriteTimeout(n, attempted, len(cnb)); closed {
@@ -1100,29 +1146,31 @@
report = c.Errorf
}
report("Error flushing: %v", err)
c.markConnAsClosed(WriteError, true)
c.markConnAsClosed(WriteError)
return true
}
}

// Update flush time statistics.
c.out.lft = lft
c.out.lwb = int32(n)

// Subtract from pending bytes and messages.
c.out.pb -= int64(c.out.lwb)
c.out.pb -= n
if c.ws != nil {
c.ws.fs -= n
}
c.out.pm -= apm // FIXME(dlc) - this will not be totally accurate on partials.

// Check for partial writes
// TODO(dlc) - zero write with no error will cause lost message and the writeloop to spin.
if int64(c.out.lwb) != attempted && n > 0 {
if n != attempted && n > 0 {
c.handlePartialWrite(nb)
} else if c.out.lwb >= c.out.sz {
} else if int32(n) >= c.out.sz {
c.out.sws = 0
}

// Adjust based on what we wrote plus any pending.
pt := int64(c.out.lwb) + c.out.pb
pt := n + c.out.pb

// Adjust sz as needed downward, keeping power of 2.
// We do this at a slower rate.
@@ -1158,7 +1206,7 @@

// Check if we have a stalled gate and if so and we are recovering release
// any stalled producers. Only kind==CLIENT will stall.
if c.out.stc != nil && (int64(c.out.lwb) == attempted || c.out.pb < c.out.mp/2) {
if c.out.stc != nil && (n == attempted || c.out.pb < c.out.mp/2) {
close(c.out.stc)
c.out.stc = nil
}
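
The flush accounting above changes in three related ways: the attempted size now comes from collapsePtoNB instead of a separate read of c.out.pb, io.ErrShortWrite is tolerated and treated as a partial write, and only the bytes actually written (n) are subtracted from the pending counters, including the websocket frame counter c.ws.fs. A reduced, self-contained sketch of that flow, with made-up flusher and limitedWriter types standing in for the client and for a connection that stalls mid-flush:

```go
// Sketch under assumed names (flusher, limitedWriter): tolerate short
// writes, subtract only what was written, and keep the remainder queued.
package main

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"net"
)

type flusher struct {
	pending int64       // bytes queued but not yet written
	nb      net.Buffers // queued buffers, oldest first
}

// flushOnce mirrors the accounting shape above: attempted is snapshotted
// before the write, io.ErrShortWrite and timeouts count as partials, and
// pending is reduced by the bytes actually written.
func (f *flusher) flushOnce(w io.Writer) error {
	attempted := f.pending
	n, err := f.nb.WriteTo(w) // WriteTo drops fully written data from f.nb
	f.pending -= n
	if err != nil && !errors.Is(err, io.ErrShortWrite) {
		var ne net.Error
		if !(errors.As(err, &ne) && ne.Timeout()) {
			return err // a real write error
		}
	}
	if n != attempted && n > 0 {
		fmt.Printf("partial write: %d/%d bytes, %d still pending\n", n, attempted, f.pending)
	}
	return nil
}

// limitedWriter accepts at most max bytes in total, then reports a short
// write, mimicking a connection that hits its write deadline mid-flush.
type limitedWriter struct {
	buf bytes.Buffer
	max int
}

func (l *limitedWriter) Write(p []byte) (int, error) {
	if l.buf.Len()+len(p) > l.max {
		room := l.max - l.buf.Len()
		l.buf.Write(p[:room])
		return room, io.ErrShortWrite
	}
	return l.buf.Write(p)
}

func main() {
	f := &flusher{}
	for _, s := range []string{"MSG one ", "MSG two ", "MSG three"} {
		f.nb = append(f.nb, []byte(s))
		f.pending += int64(len(s))
	}
	f.flushOnce(&limitedWriter{max: 12}) // stops short; remainder stays queued
	f.flushOnce(&bytes.Buffer{})         // the next flush drains the rest
	fmt.Println("pending after second flush:", f.pending)
}
```
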
@@ -1173,7 +1221,7 @@ func (c *client) handleWriteTimeout(written, attempted int64, numChunks int) bool {
if tlsConn, ok := c.nc.(*tls.Conn); ok {
if !tlsConn.ConnectionState().HandshakeComplete {
// Likely a TLSTimeout error instead...
c.markConnAsClosed(TLSHandshakeError, true)
c.markConnAsClosed(TLSHandshakeError)
// Would need to coordinate with tlstimeout()
// to avoid double logging, so skip logging
// here, and don't report a slow consumer error.
@@ -1184,7 +1232,7 @@ func (c *client) handleWriteTimeout(written, attempted int64, numChunks int) bool {
// before the authorization timeout. If that is the case, then we handle
// as slow consumer though we do not increase the counter as that can be
// misleading.
c.markConnAsClosed(SlowConsumerWriteDeadline, true)
c.markConnAsClosed(SlowConsumerWriteDeadline)
return true
}

@@ -1195,26 +1243,40 @@ func (c *client) handleWriteTimeout(written, attempted int64, numChunks int) bool {

// We always close CLIENT connections, or when nothing was written at all...
if c.kind == CLIENT || written == 0 {
c.markConnAsClosed(SlowConsumerWriteDeadline, true)
c.markConnAsClosed(SlowConsumerWriteDeadline)
return true
}
return false
}

// Marks this connection as closed with the given reason.
// Sets the closeConnection flag and skipFlushOnClose flag if asked.
// Sets the closeConnection flag and skipFlushOnClose depending on the reason.
// Depending on the kind of connection, the connection will be saved.
// If a writeLoop has been started, the final flush/close/teardown will
// be done there, otherwise flush and close of TCP connection is done here in place.
// Returns true if closed in place, false otherwise.
// Lock is held on entry.
func (c *client) markConnAsClosed(reason ClosedState, skipFlush bool) bool {
func (c *client) markConnAsClosed(reason ClosedState) bool {
// Possibly set skipFlushOnClose flag even if connection has already been
// marked as closed. The rationale is that a connection may be closed with
// a reason that justifies a flush (say after sending an -ERR), but then
// the flushOutbound() gets a write error. If that happens, the connection
// is lost and there is no reason to attempt to flush again during the
// teardown when the writeLoop exits.
var skipFlush bool
switch reason {
case ReadError, WriteError, SlowConsumerPendingBytes, SlowConsumerWriteDeadline, TLSHandshakeError:
c.flags.set(skipFlushOnClose)
skipFlush = true
}
if c.flags.isSet(closeConnection) {
return false
}
c.flags.set(closeConnection)
if skipFlush {
c.flags.set(skipFlushOnClose)
// For a websocket client, unless we are told not to flush, enqueue
// a websocket CloseMessage based on the reason.
if !skipFlush && c.ws != nil && !c.ws.closeSent {
c.wsEnqueueCloseMessage(reason)
}
// Be consistent with the creation: for routes and gateways,
// we use Noticef on create, so use that too for delete.
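
markConnAsClosed drops its skipFlush parameter: whether to skip the final flush is now derived from the close reason (on read, write, slow consumer and TLS handshake errors the peer is effectively gone, so flushing on close would only fail again), the close stays idempotent via the closeConnection flag, and for websocket clients a Close message is enqueued when a flush is still worthwhile. A compact sketch of that logic, with local stand-ins for ClosedState and the client flags:

```go
// Sketch with stand-in types: derive skipFlush from the reason, keep the
// close idempotent, and only enqueue a websocket Close message when a
// final flush still makes sense.
package main

import "fmt"

type closedState int

const (
	clientClosed closedState = iota
	readError
	writeError
	slowConsumerPendingBytes
	slowConsumerWriteDeadline
	tlsHandshakeError
)

type conn struct {
	closed    bool
	skipFlush bool
	websocket bool
	closeSent bool
}

func (c *conn) markConnAsClosed(reason closedState) {
	// On these errors the peer is effectively gone, so a final flush on
	// close would only fail again.
	var skipFlush bool
	switch reason {
	case readError, writeError, slowConsumerPendingBytes,
		slowConsumerWriteDeadline, tlsHandshakeError:
		c.skipFlush = true
		skipFlush = true
	}
	if c.closed {
		return // already closed; we may still have set skipFlush above
	}
	c.closed = true
	if !skipFlush && c.websocket && !c.closeSent {
		// A websocket peer expects a Close frame describing the reason.
		fmt.Println("enqueue websocket close message, reason:", reason)
		c.closeSent = true
	}
}

func main() {
	c := &conn{websocket: true}
	c.markConnAsClosed(clientClosed) // graceful close: Close message enqueued
	c.markConnAsClosed(writeError)   // later error: only tightens skipFlush
	fmt.Printf("closed=%v skipFlush=%v closeSent=%v\n", c.closed, c.skipFlush, c.closeSent)
}
```
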
@@ -1610,7 +1672,7 @@ func (c *client) queueOutbound(data []byte) bool {
c.out.pb -= int64(len(data))
atomic.AddInt64(&c.srv.slowConsumers, 1)
c.Noticef("Slow Consumer Detected: MaxPending of %d Exceeded", c.out.mp)
c.markConnAsClosed(SlowConsumerPendingBytes, true)
c.markConnAsClosed(SlowConsumerPendingBytes)
return referenced
}

@@ -1755,6 +1817,10 @@ func (c *client) generateClientInfoJSON(info Info) []byte {
info.CID = c.cid
info.ClientIP = c.host
info.MaxPayload = c.mpay
if c.ws != nil {
info.ClientConnectURLs = info.WSConnectURLs
}
info.WSConnectURLs = nil
// Generate the info json
b, _ := json.Marshal(info)
pcs := [][]byte{[]byte("INFO"), b, []byte(CR_LF)}
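
For a websocket client the INFO sent to that client advertises the websocket URLs in the usual ClientConnectURLs field, and the websocket-specific WSConnectURLs field is cleared before marshalling so it never shows up in a client-facing INFO. A small sketch with a stripped-down Info stand-in (the struct and its JSON tags here are illustrative, not the server's full Info):

```go
// Sketch: swap in the websocket URLs for websocket clients and drop the
// websocket-specific field from the client-facing JSON either way.
package main

import (
	"encoding/json"
	"fmt"
)

type Info struct {
	ClientConnectURLs []string `json:"connect_urls,omitempty"`
	WSConnectURLs     []string `json:"ws_connect_urls,omitempty"`
}

// clientInfoJSON receives info by value, so the adjustments stay local.
func clientInfoJSON(info Info, isWebsocket bool) []byte {
	if isWebsocket {
		info.ClientConnectURLs = info.WSConnectURLs
	}
	info.WSConnectURLs = nil
	b, _ := json.Marshal(info)
	return b
}

func main() {
	info := Info{
		ClientConnectURLs: []string{"nats://a:4222", "nats://b:4222"},
		WSConnectURLs:     []string{"wss://a:8080", "wss://b:8080"},
	}
	fmt.Println(string(clientInfoJSON(info, false))) // plain client
	fmt.Println(string(clientInfoJSON(info, true)))  // websocket client
}
```
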
@@ -3803,7 +3869,7 @@ func (c *client) closeConnection(reason ClosedState) {
// This will set the closeConnection flag and save the connection, etc..
// Will return true if no writeLoop was started and TCP connection was
// closed in place, in which case we need to do the teardown.
teardownNow := c.markConnAsClosed(reason, false)
teardownNow := c.markConnAsClosed(reason)
c.mu.Unlock()

if teardownNow {
@@ -3841,6 +3907,7 @@ func (c *client) teardownConn() {
var (
retryImplicit bool
connectURLs []string
wsConnectURLs []string
gwName string
gwIsOutbound bool
gwCfg *gatewayCfg
@@ -3870,6 +3937,7 @@ func (c *client) teardownConn() {
retryImplicit = c.route.retry
}
connectURLs = c.route.connectURLs
wsConnectURLs = c.route.wsConnURLs
}
if kind == GATEWAY {
gwName = c.gw.name
@@ -3894,11 +3962,11 @@

if srv != nil {
// This is a route that disconnected, but we are not in lame duck mode...
if len(connectURLs) > 0 && !srv.isLameDuckMode() {
if (len(connectURLs) > 0 || len(wsConnectURLs) > 0) && !srv.isLameDuckMode() {
// Unless disabled, possibly update the server's INFO protocol
// and send to clients that know how to handle async INFOs.
if !srv.getOpts().Cluster.NoAdvertise {
srv.removeClientConnectURLsAndSendINFOToClients(connectURLs)
srv.removeConnectURLsAndSendINFOToClients(connectURLs, wsConnectURLs)
}
}

10 changes: 7 additions & 3 deletions server/client_test.go
@@ -1,4 +1,4 @@
// Copyright 2012-2019 The NATS Authors
// Copyright 2012-2020 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
@@ -72,7 +72,7 @@ func (c *testAsyncClient) parseAndClose(proto []byte) {
func createClientAsync(ch chan *client, s *Server, cli net.Conn) {
s.grWG.Add(1)
go func() {
c := s.createClient(cli)
c := s.createClient(cli, nil)
// Must be here to suppress +OK
c.opts.Verbose = false
go c.writeLoop()
@@ -2163,6 +2163,10 @@ func (c *testConnWritePartial) Write(p []byte) (int, error) {
return c.buf.Write(p[:n])
}

func (c *testConnWritePartial) RemoteAddr() net.Addr {
return nil
}

func (c *testConnWritePartial) SetWriteDeadline(_ time.Time) error {
return nil
}
@@ -2279,7 +2283,7 @@ func TestCloseConnectionVeryEarly(t *testing.T) {
// Call again with this closed connection. Alternatively, we
// would have to call with a fake connection that implements
// net.Conn but returns an error on Write.
s.createClient(c)
s.createClient(c, nil)

// This connection should not have been added to the server.
checkClientsCount(t, s, 0)
Expand Down