Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Track closed connections and reason for closing #692

Merged
merged 8 commits into from
Jun 27, 2018
69 changes: 54 additions & 15 deletions server/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"io"
"math/rand"
"net"
"sync"
Expand Down Expand Up @@ -99,6 +100,33 @@ func (cf *clientFlag) setIfNotSet(c clientFlag) bool {
return false
}

// ClosedState identifies why a client connection was closed. The value is
// handed to clearConnection, but is only stored in ConnInfo for monitoring.
type ClosedState int

// Reasons a connection can be closed. Numbering starts at 1, so the zero
// value of ClosedState does not match any reason listed here.
const (
	ClientClosed ClosedState = iota + 1
	AuthenticationTimeout
	AuthenticationViolation
	TLSHandshakeError
	SlowConsumerPendingBytes
	SlowConsumerWriteDeadline
	WriteError
	ReadError
	ParseError
	StaleConnection
	ProtocolViolation
	BadClientProtocolVersion
	WrongPort
	MaxConnectionsExceeded
	MaxPayloadExceeded
	MaxControlLineExceeded
	DuplicateRoute
	RouteRemoved
	ServerShutdown
)

type client struct {
// Here first because of use of atomics, and memory alignment.
stats
Expand Down Expand Up @@ -357,7 +385,11 @@ func (c *client) readLoop() {
for {
n, err := nc.Read(b)
if err != nil {
c.closeConnection()
if err == io.EOF {
c.closeConnection(ClientClosed)
} else {
c.closeConnection(ReadError)
}
return
}

Expand All @@ -375,7 +407,7 @@ func (c *client) readLoop() {
// handled inline
if err != ErrMaxPayload && err != ErrAuthorization {
c.Errorf("%s", err.Error())
c.closeConnection()
c.closeConnection(ProtocolViolation)
}
return
}
Expand Down Expand Up @@ -530,11 +562,12 @@ func (c *client) flushOutbound() bool {
if n == 0 {
c.out.pb -= attempted
}
c.clearConnection()
if ne, ok := err.(net.Error); ok && ne.Timeout() {
atomic.AddInt64(&srv.slowConsumers, 1)
c.clearConnection(SlowConsumerWriteDeadline)
c.Noticef("Slow Consumer Detected: WriteDeadline of %v Exceeded", c.out.wdl)
} else {
c.clearConnection(WriteError)
c.Debugf("Error flushing: %v", err)
}
return true
Expand Down Expand Up @@ -627,7 +660,7 @@ func (c *client) processErr(errStr string) {
case ROUTER:
c.Errorf("Route Error %s", errStr)
}
c.closeConnection()
c.closeConnection(ParseError)
}

func (c *client) processConnect(arg []byte) error {
Expand Down Expand Up @@ -689,13 +722,13 @@ func (c *client) processConnect(arg []byte) error {
// Check client protocol request if it exists.
if typ == CLIENT && (proto < ClientProtoZero || proto > ClientProtoInfo) {
c.sendErr(ErrBadClientProtocol.Error())
c.closeConnection()
c.closeConnection(BadClientProtocolVersion)
return ErrBadClientProtocol
} else if typ == ROUTER && lang != "" {
// Way to detect clients that incorrectly connect to the route listen
// port. Client provide Lang in the CONNECT protocol while ROUTEs don't.
c.sendErr(ErrClientConnectedToRoutePort.Error())
c.closeConnection()
c.closeConnection(WrongPort)
return ErrClientConnectedToRoutePort
}

Expand All @@ -715,7 +748,7 @@ func (c *client) processConnect(arg []byte) error {
// authTimeout sends the ErrAuthTimeout text to the client, logs the timeout
// at debug level, and then closes the connection.
func (c *client) authTimeout() {
c.sendErr(ErrAuthTimeout.Error())
c.Debugf("Authorization Timeout")
c.closeConnection()
c.closeConnection(AuthenticationTimeout)
}

func (c *client) authViolation() {
Expand All @@ -727,19 +760,19 @@ func (c *client) authViolation() {
c.Errorf(ErrAuthorization.Error())
}
c.sendErr("Authorization Violation")
c.closeConnection()
c.closeConnection(AuthenticationViolation)
}

// maxConnExceeded logs ErrTooManyConnections as an error, sends the same
// text to the client, and then closes the connection.
func (c *client) maxConnExceeded() {
c.Errorf(ErrTooManyConnections.Error())
c.sendErr(ErrTooManyConnections.Error())
c.closeConnection()
c.closeConnection(MaxConnectionsExceeded)
}

// maxPayloadViolation logs ErrMaxPayload together with the two sizes it was
// given (sz and max), sends "Maximum Payload Violation" to the client, and
// then closes the connection.
func (c *client) maxPayloadViolation(sz int, max int64) {
c.Errorf("%s: %d vs %d", ErrMaxPayload.Error(), sz, max)
c.sendErr("Maximum Payload Violation")
c.closeConnection()
c.closeConnection(MaxPayloadExceeded)
}

// queueOutbound queues data for client/route connections.
Expand All @@ -752,7 +785,7 @@ func (c *client) queueOutbound(data []byte) {
// Check for slow consumer via pending bytes limit.
// ok to return here, client is going away.
if c.out.pb > c.out.mp {
c.clearConnection()
c.clearConnection(SlowConsumerPendingBytes)
atomic.AddInt64(&c.srv.slowConsumers, 1)
c.Noticef("Slow Consumer Detected: MaxPending of %d Exceeded", c.out.mp)
return
Expand Down Expand Up @@ -1517,7 +1550,7 @@ func (c *client) processPingTimer() {
if c.ping.out+1 > c.srv.getOpts().MaxPingsOut {
c.Debugf("Stale Client Connection - Closing")
c.sendProto([]byte(fmt.Sprintf("-ERR '%s'\r\n", "Stale Connection")), true)
c.clearConnection()
c.clearConnection(StaleConnection)
return
}

Expand Down Expand Up @@ -1575,11 +1608,12 @@ func (c *client) isAuthTimerSet() bool {
}

// Lock should be held
func (c *client) clearConnection() {
func (c *client) clearConnection(reason ClosedState) {
if c.flags.isSet(clearConnection) {
return
}
c.flags.set(clearConnection)

nc := c.nc
if nc == nil || c.srv == nil {
return
Expand All @@ -1599,6 +1633,11 @@ func (c *client) clearConnection() {
nc.Close()
// Do this always to also kick out any IO writes.
nc.SetWriteDeadline(time.Time{})

// Save off the connection if its a client.
if c.typ == CLIENT && c.srv != nil {
go c.srv.saveClosedClient(c, nc, reason)
}
}

func (c *client) typeString() string {
Expand All @@ -1611,7 +1650,7 @@ func (c *client) typeString() string {
return "Unknown Type"
}

func (c *client) closeConnection() {
func (c *client) closeConnection(reason ClosedState) {
c.mu.Lock()
if c.nc == nil {
c.mu.Unlock()
Expand All @@ -1622,7 +1661,7 @@ func (c *client) closeConnection() {

c.clearAuthTimer()
c.clearPingTimer()
c.clearConnection()
c.clearConnection(reason)
c.nc = nil

// Snapshot for use.
Expand Down
129 changes: 126 additions & 3 deletions server/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,15 +549,15 @@ func TestClientRemoveSubsOnDisconnect(t *testing.T) {
if s.sl.Count() != 2 {
t.Fatalf("Should have 2 subscriptions, got %d\n", s.sl.Count())
}
c.closeConnection()
c.closeConnection(ClientClosed)
if s.sl.Count() != 0 {
t.Fatalf("Should have no subscriptions after close, got %d\n", s.sl.Count())
}
}

func TestClientDoesNotAddSubscriptionsWhenConnectionClosed(t *testing.T) {
s, c, _ := setupClient()
c.closeConnection()
c.closeConnection(ClientClosed)
subs := []byte("SUB foo 1\r\nSUB bar 2\r\n")

ch := make(chan bool)
Expand Down Expand Up @@ -767,7 +767,7 @@ func TestTLSCloseClientConnection(t *testing.T) {
}
}()
// Close the client
cli.closeConnection()
cli.closeConnection(ClientClosed)
ch <- true
}

Expand Down Expand Up @@ -1078,3 +1078,126 @@ func TestAvoidSlowConsumerBigMessages(t *testing.T) {
t.Fatalf("Failed to receive all large messages: %d of %d\n", r, expected)
}
}

// closedConnsEqual polls s.numClosedConns() every 5ms until it equals num or
// the wait duration has elapsed. It then reads the count one final time and
// reports whether that reading equals num.
func closedConnsEqual(s *Server, num int, wait time.Duration) bool {
	deadline := time.Now().Add(wait)
	for time.Now().Before(deadline) && s.numClosedConns() != num {
		time.Sleep(5 * time.Millisecond)
	}
	return s.numClosedConns() == num
}

// totalClosedConnsEqual polls s.totalClosedConns() every 5ms until it equals
// num or the wait duration has elapsed. It then reads the total one final
// time and reports whether that reading equals num.
func totalClosedConnsEqual(s *Server, num uint64, wait time.Duration) bool {
	deadline := time.Now().Add(wait)
	for time.Now().Before(deadline) && s.totalClosedConns() != num {
		time.Sleep(5 * time.Millisecond)
	}
	return s.totalClosedConns() == num
}

// TestClosedConnsAccounting checks closed-connection bookkeeping with
// MaxClosedClients set to 10: a single closed client is recorded with CID 1,
// and after 21 more connect/close cycles the retained list holds exactly
// MaxClosedClients entries while the running total reports all 22. The
// retained entries are expected to carry consecutive CIDs ending at 22.
func TestClosedConnsAccounting(t *testing.T) {
	opts := DefaultOptions()
	opts.MaxClosedClients = 10

	s := RunServer(opts)
	defer s.Shutdown()

	wait := 20 * time.Millisecond
	url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)

	nc, err := nats.Connect(url)
	if err != nil {
		t.Fatalf("Error on connect: %v", err)
	}
	nc.Close()

	if !closedConnsEqual(s, 1, wait) {
		t.Fatalf("Closed conns expected to be 1, got %d\n", s.numClosedConns())
	}

	conns := s.closedClients()
	if lc := len(conns); lc != 1 {
		t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc)
	}
	if conns[0].Cid != 1 {
		t.Fatalf("Expected CID to be 1, got %d\n", conns[0].Cid)
	}

	// Now create 21 more
	for i := 0; i < 21; i++ {
		nc, err = nats.Connect(url)
		if err != nil {
			t.Fatalf("Error on connect: %v", err)
		}
		nc.Close()
	}

	if !closedConnsEqual(s, opts.MaxClosedClients, wait) {
		t.Fatalf("Closed conns expected to be %d, got %d\n",
			opts.MaxClosedClients,
			s.numClosedConns())
	}

	// Report the counter that was actually compared (the running total),
	// not the size of the retained list.
	if !totalClosedConnsEqual(s, 22, wait) {
		t.Fatalf("Total closed conns expected to be 22, got %d\n",
			s.totalClosedConns())
	}

	conns = s.closedClients()
	if lc := len(conns); lc != opts.MaxClosedClients {
		t.Fatalf("len(conns) expected to be %d, got %d\n",
			opts.MaxClosedClients, lc)
	}

	// Set it to the start after overflow: cid is incremented before each
	// comparison, so the first retained entry must have CID
	// 22-MaxClosedClients+1.
	cid := uint64(22 - opts.MaxClosedClients)
	for _, ci := range conns {
		cid++
		if ci.Cid != cid {
			t.Fatalf("Expected cid of %d, got %d\n", cid, ci.Cid)
		}
	}
}

// TestClosedConnsSubsAccounting connects a client, creates numSubs
// subscriptions, closes the client, and checks that the single recorded
// closed connection carries all numSubs subscriptions.
func TestClosedConnsSubsAccounting(t *testing.T) {
	opts := DefaultOptions()
	s := RunServer(opts)
	defer s.Shutdown()

	url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)

	nc, err := nats.Connect(url)
	if err != nil {
		t.Fatalf("Error on connect: %v", err)
	}

	// Now create some subscriptions
	numSubs := 10
	for i := 0; i < numSubs; i++ {
		subj := fmt.Sprintf("foo.%d", i)
		if _, err := nc.Subscribe(subj, func(m *nats.Msg) {}); err != nil {
			t.Fatalf("Error on subscribe to %q: %v", subj, err)
		}
	}
	// Flush before closing; a failure here would make the subscription
	// count check below misleading, so surface it directly.
	if err := nc.Flush(); err != nil {
		t.Fatalf("Error on flush: %v", err)
	}
	nc.Close()

	if !closedConnsEqual(s, 1, 20*time.Millisecond) {
		t.Fatalf("Closed conns expected to be 1, got %d\n",
			s.numClosedConns())
	}
	conns := s.closedClients()
	if lc := len(conns); lc != 1 {
		t.Fatalf("len(conns) expected to be 1, got %d\n", lc)
	}
	ci := conns[0]

	if len(ci.subs) != numSubs {
		t.Fatalf("Expected number of Subs to be %d, got %d\n", numSubs, len(ci.subs))
	}
}
3 changes: 3 additions & 0 deletions server/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,7 @@ const (

// DEFAULT_REMOTE_QSUBS_SWEEPER
DEFAULT_REMOTE_QSUBS_SWEEPER = 30 * time.Second

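// DEFAULT_MAX_CLOSED_CLIENTS is the default limit on retained closed-client records (presumably the default for Options.MaxClosedClients; confirm).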
DEFAULT_MAX_CLOSED_CLIENTS = 10000
)
Loading