Skip to content

Commit

Permalink
Merge 30607b5 into 7c468cd
Browse files Browse the repository at this point in the history
  • Loading branch information
piotrpio committed May 20, 2023
2 parents 7c468cd + 30607b5 commit 4f66949
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 37 deletions.
137 changes: 100 additions & 37 deletions nats.go
Expand Up @@ -524,31 +524,32 @@ type Conn struct {
mu sync.RWMutex
// Opts holds the configuration of the Conn.
// Modifying the configuration of a running Conn is a race.
Opts Options
wg sync.WaitGroup
srvPool []*srv
current *srv
urls map[string]struct{} // Keep track of all known URLs (used by processInfo)
conn net.Conn
bw *natsWriter
br *natsReader
fch chan struct{}
info serverInfo
ssid int64
subsMu sync.RWMutex
subs map[int64]*Subscription
ach *asyncCallbacksHandler
pongs []chan struct{}
scratch [scratchSize]byte
status Status
initc bool // true if the connection is performing the initial connect
err error
ps *parseState
ptmr *time.Timer
pout int
ar bool // abort reconnect
rqch chan struct{}
ws bool // true if a websocket connection
Opts Options
wg sync.WaitGroup
srvPool []*srv
current *srv
urls map[string]struct{} // Keep track of all known URLs (used by processInfo)
conn net.Conn
bw *natsWriter
br *natsReader
fch chan struct{}
info serverInfo
ssid int64
subsMu sync.RWMutex
subs map[int64]*Subscription
ach *asyncCallbacksHandler
pongs []chan struct{}
scratch [scratchSize]byte
status Status
statListeners map[Status][]chan Status
initc bool // true if the connection is performing the initial connect
err error
ps *parseState
ptmr *time.Timer
pout int
ar bool // abort reconnect
rqch chan struct{}
ws bool // true if a websocket connection

// New style response handler
respSub string // The wildcard subject
Expand Down Expand Up @@ -2219,7 +2220,7 @@ func (nc *Conn) processConnectInit() error {
defer nc.conn.SetDeadline(time.Time{})

// Set our status to connecting.
nc.status = CONNECTING
nc.changeConnStatus(CONNECTING)

// Process the INFO protocol received from the server
err := nc.processExpectedInfo()
Expand Down Expand Up @@ -2311,7 +2312,7 @@ func (nc *Conn) connect() (bool, error) {
nc.initc = false
} else if nc.Opts.RetryOnFailedConnect {
nc.setup()
nc.status = RECONNECTING
nc.changeConnStatus(RECONNECTING)
nc.bw.switchToPending()
go nc.doReconnect(ErrNoServers)
err = nil
Expand Down Expand Up @@ -2551,7 +2552,7 @@ func (nc *Conn) sendConnect() error {
}

// This is where we are truly connected.
nc.status = CONNECTED
nc.changeConnStatus(CONNECTED)

return nil
}
Expand Down Expand Up @@ -2726,7 +2727,7 @@ func (nc *Conn) doReconnect(err error) {
if nc.ar {
break
}
nc.status = RECONNECTING
nc.changeConnStatus(RECONNECTING)
continue
}

Expand All @@ -2744,7 +2745,7 @@ func (nc *Conn) doReconnect(err error) {
// Now send off and clear pending buffer
nc.err = nc.flushReconnectPendingItems()
if nc.err != nil {
nc.status = RECONNECTING
nc.changeConnStatus(RECONNECTING)
// Stop the ping timer (if set)
nc.stopPingTimer()
// Since processConnectInit() returned without error, the
Expand Down Expand Up @@ -2797,7 +2798,7 @@ func (nc *Conn) processOpErr(err error) {

if nc.Opts.AllowReconnect && nc.status == CONNECTED {
// Set our new status
nc.status = RECONNECTING
nc.changeConnStatus(RECONNECTING)
// Stop ping timer if set
nc.stopPingTimer()
if nc.conn != nil {
Expand All @@ -2816,7 +2817,7 @@ func (nc *Conn) processOpErr(err error) {
return
}

nc.status = DISCONNECTED
nc.changeConnStatus(DISCONNECTED)
nc.err = err
nc.mu.Unlock()
nc.close(CLOSED, true, nil)
Expand Down Expand Up @@ -5002,11 +5003,11 @@ func (nc *Conn) clearPendingRequestCalls() {
func (nc *Conn) close(status Status, doCBs bool, err error) {
nc.mu.Lock()
if nc.isClosed() {
nc.status = status
nc.changeConnStatus(CLOSED)
nc.mu.Unlock()
return
}
nc.status = CLOSED
nc.changeConnStatus(CLOSED)

// Kick the Go routines so they fall out.
nc.kickFlusher()
Expand Down Expand Up @@ -5065,7 +5066,7 @@ func (nc *Conn) close(status Status, doCBs bool, err error) {
nc.subs = nil
nc.subsMu.Unlock()

nc.status = status
nc.changeConnStatus(status)

// Perform appropriate callback if needed for a disconnect.
if doCBs {
Expand Down Expand Up @@ -5210,7 +5211,7 @@ func (nc *Conn) drainConnection() {

// Flip State
nc.mu.Lock()
nc.status = DRAINING_PUBS
nc.changeConnStatus(DRAINING_PUBS)
nc.mu.Unlock()

// Do publish drain via Flush() call.
Expand Down Expand Up @@ -5245,7 +5246,7 @@ func (nc *Conn) Drain() error {
nc.mu.Unlock()
return nil
}
nc.status = DRAINING_SUBS
nc.changeConnStatus(DRAINING_SUBS)
go nc.drainConnection()
nc.mu.Unlock()

Expand Down Expand Up @@ -5455,6 +5456,68 @@ func (nc *Conn) GetClientID() (uint64, error) {
return nc.info.CID, nil
}

// StatusChanged returns a channel on which given list of connection status changes will be reported.
// If no statuses are provided, defaults will be used: CONNECTED, RECONNECTING, DISCONNECTED, CLOSED.
func (nc *Conn) StatusChanged(statuses ...Status) chan Status {
if len(statuses) == 0 {
statuses = []Status{CONNECTED, RECONNECTING, DISCONNECTED, CLOSED}
}
ch := make(chan Status)
for _, s := range statuses {
nc.registerStatusChangeListener(s, ch)
}
return ch
}

// registerStatusChangeListener registers a channel waiting for a specific status change event.
// Status change events are non-blocking - if no receiver is waiting for the status change,
// it will not be sent on the channel. Closed channels are ignored.
func (nc *Conn) registerStatusChangeListener(status Status, ch chan Status) {
nc.mu.Lock()
defer nc.mu.Unlock()
if nc.statListeners == nil {
nc.statListeners = make(map[Status][]chan Status)
}
if _, ok := nc.statListeners[status]; !ok {
nc.statListeners[status] = make([]chan Status, 0)
}
nc.statListeners[status] = append(nc.statListeners[status], ch)
}

// sendStatusEvent sends connection status event to all channels.
// If channel is closed, or there is no listener, sendStatusEvent
// will not block. Lock should be held entering.
func (nc *Conn) sendStatusEvent(s Status) {
Loop:
for i := 0; i < len(nc.statListeners[s]); i++ {
// make sure channel is not closed
select {
case <-nc.statListeners[s][i]:
// if chan is closed, remove it
nc.statListeners[s][i] = nc.statListeners[s][len(nc.statListeners[s])-1]
nc.statListeners[s] = nc.statListeners[s][:len(nc.statListeners[s])-1]
i--
continue Loop
default:
}
// only send event if someone's listening
select {
case nc.statListeners[s][i] <- s:
default:
}
}
}

// changeConnStatus changes connections status and sends events
// to all listeners. Lock should be held entering.
func (nc *Conn) changeConnStatus(status Status) {
if nc == nil {
return
}
nc.sendStatusEvent(status)
nc.status = status
}

// NkeyOptionFromSeed will load an nkey pair from a seed file.
// It will return the NKey Option and will handle
// signing of nonce challenges from the server. It will take
Expand Down
90 changes: 90 additions & 0 deletions test/conn_test.go
Expand Up @@ -2761,3 +2761,93 @@ func TestRetryOnFailedConnectWithTLSError(t *testing.T) {
t.Fatal("Should have connected")
}
}

func TestConnStatusChangedEvents(t *testing.T) {
waitForStatus := func(t *testing.T, ch chan nats.Status, expected nats.Status) {
select {
case s := <-ch:
if s != expected {
t.Fatalf("Expected status: %s; got: %s", expected, s)
}
case <-time.After(5 * time.Second):
t.Fatalf("Timeout waiting for status %q", expected)
}
}
t.Run("default events", func(t *testing.T) {
s := RunDefaultServer()
nc, err := nats.Connect(s.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
statusCh := nc.StatusChanged()
defer close(statusCh)
newStatus := make(chan nats.Status, 10)
// non-blocking channel, so we need to be constantly listening
go func() {
for {
s, ok := <-statusCh
if !ok {
return
}
newStatus <- s
}
}()
time.Sleep(50 * time.Millisecond)

s.Shutdown()
waitForStatus(t, newStatus, nats.RECONNECTING)

s = RunDefaultServer()
defer s.Shutdown()

waitForStatus(t, newStatus, nats.CONNECTED)

nc.Close()
waitForStatus(t, newStatus, nats.CLOSED)
})

t.Run("custom event only", func(t *testing.T) {
s := RunDefaultServer()
nc, err := nats.Connect(s.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
statusCh := nc.StatusChanged(nats.CLOSED)
defer close(statusCh)
newStatus := make(chan nats.Status, 10)
// non-blocking channel, so we need to be constantly listening
go func() {
for {
s, ok := <-statusCh
if !ok {
return
}
fmt.Println(s)
newStatus <- s
}
}()
time.Sleep(50 * time.Millisecond)
s.Shutdown()
s = RunDefaultServer()
defer s.Shutdown()
nc.Close()
waitForStatus(t, newStatus, nats.CLOSED)
})
t.Run("do not block on channel if it's not used", func(t *testing.T) {
s := RunDefaultServer()
nc, err := nats.Connect(s.ClientURL())
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}
defer nc.Close()
// do not use the returned channel, client should never block
_ = nc.StatusChanged()
s.Shutdown()
s = RunDefaultServer()
defer s.Shutdown()

if err := nc.Publish("foo", []byte("msg")); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
})
}

0 comments on commit 4f66949

Please sign in to comment.