Skip to content

Commit

Permalink
Use only one channel for the two type of callbacks.
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlovic committed Jan 22, 2016
1 parent dd2b1cc commit a7f6ae4
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 51 deletions.
69 changes: 32 additions & 37 deletions nats.go
Expand Up @@ -135,12 +135,6 @@ type Options struct {
SubChanLen int
}

type errHandlerInfo struct {
cb ErrHandler
sub *Subscription
err error
}

const (
// Scratch storage for assembling protocol headers
scratchSize = 512
Expand All @@ -153,6 +147,9 @@ const (

// Default server pool size
srvPoolSize = 4

// The size of the callback channels
cbChanSize = 8
)

// A Conn represents a bare connection to a nats-server.
Expand Down Expand Up @@ -185,9 +182,8 @@ type Conn struct {
ptmr *time.Timer
pout int

connCbs chan ConnHandler
errCbs chan errHandlerInfo
cbDone chan struct{}
cbsChan chan func()
cbsDone chan struct{}
}

// A Subscription represents interest in a given subject.
Expand Down Expand Up @@ -470,9 +466,8 @@ func (o Options) Connect() (*Conn, error) {
nc := &Conn{Opts: o}

// Start the connection and error handlers "dispatching" routine.
nc.connCbs = make(chan ConnHandler, 8)
nc.errCbs = make(chan errHandlerInfo, 8)
nc.cbDone = make(chan struct{})
nc.cbsChan = make(chan func(), cbChanSize)
nc.cbsDone = make(chan struct{})
go nc.asyncCBDispatcher()

if nc.Opts.MaxPingsOut == 0 {
Expand Down Expand Up @@ -1064,7 +1059,9 @@ func (nc *Conn) doReconnect() {
dcb := nc.Opts.DisconnectedCB
if dcb != nil {
nc.mu.Unlock()
nc.connCbs <- dcb
nc.cbsChan <- func() {
dcb(nc)
}
nc.mu.Lock()
}

Expand Down Expand Up @@ -1146,7 +1143,9 @@ func (nc *Conn) doReconnect() {

// Call reconnectedCB if appropriate.
if rcb != nil {
nc.connCbs <- rcb
nc.cbsChan <- func() {
rcb(nc)
}
}

return
Expand Down Expand Up @@ -1381,7 +1380,9 @@ slowConsumer:
func (nc *Conn) processSlowConsumer(s *Subscription) {
nc.err = ErrSlowConsumer
if nc.Opts.AsyncErrorCB != nil && !s.sc {
nc.errCbs <- errHandlerInfo{cb: nc.Opts.AsyncErrorCB, sub: s, err: ErrSlowConsumer}
nc.cbsChan <- func() {
nc.Opts.AsyncErrorCB(nc, s, ErrSlowConsumer)
}
}
s.sc = true
}
Expand Down Expand Up @@ -2185,19 +2186,23 @@ func (nc *Conn) close(status Status, doCBs bool) {

// Perform appropriate callback if needed
if dcb != nil {
nc.connCbs <- dcb
nc.cbsChan <- func() {
dcb(nc)
}
}
if ccb != nil {
nc.connCbs <- ccb
nc.cbsChan <- func() {
ccb(nc)
}
}

nc.mu.Lock()

if nc.cbDone != nil {
if nc.cbsDone != nil {
// Close this channel to notify that we are done. The go routine will
// drain all remaining callbacks from their respective channels.
close(nc.cbDone)
nc.cbDone = nil
// drain all remaining callbacks from cbsChan.
close(nc.cbsDone)
nc.cbsDone = nil
}

nc.status = status
Expand Down Expand Up @@ -2272,28 +2277,18 @@ func (nc *Conn) asyncCBDispatcher() {

// Capture things under lock
nc.mu.Lock()
ch := nc.connCbs
eh := nc.errCbs
done := nc.cbDone
ch := nc.cbsChan
done := nc.cbsDone
nc.mu.Unlock()

for {
select {
case ecb := <-eh:
ecb.cb(nc, ecb.sub, ecb.err)
break
case c := <-ch:
c(nc)
break
case cb := <-ch:
cb()
case <-done:
if len(eh) > 0 {
for ecb := range eh {
ecb.cb(nc, ecb.sub, ecb.err)
}
}
if len(ch) > 0 {
for c := range ch {
c(nc)
for cb := range ch {
cb()
}
}
return
Expand Down
57 changes: 43 additions & 14 deletions test/conn_test.go
Expand Up @@ -368,9 +368,10 @@ func TestCallbacksOrder(t *testing.T) {
dtime1 := time.Time{}
dtime2 := time.Time{}
rtime := time.Time{}
atime1 := time.Time{}
atime2 := time.Time{}
ctime := time.Time{}

disconnected := make(chan bool)
reconnected := make(chan bool)
closed := make(chan bool)

Expand All @@ -382,7 +383,6 @@ func TestCallbacksOrder(t *testing.T) {
} else {
dtime2 = time.Now()
}
disconnected <- true
}

rch := func(nc *nats.Conn) {
Expand All @@ -391,6 +391,15 @@ func TestCallbacksOrder(t *testing.T) {
reconnected <- true
}

ech := func(nc *nats.Conn, sub *nats.Subscription, err error) {
time.Sleep(20 * time.Millisecond)
if sub.Subject == "foo" {
atime1 = time.Now()
} else {
atime2 = time.Now()
}
}

cch := func(nc *nats.Conn) {
ctime = time.Now()
closed <- true
Expand All @@ -400,35 +409,55 @@ func TestCallbacksOrder(t *testing.T) {
nats.DisconnectHandler(dch),
nats.ReconnectHandler(rch),
nats.ClosedHandler(cch),
nats.ErrorHandler(ech),
nats.ReconnectWait(50*time.Millisecond))
if err != nil {
t.Fatalf("Unable to connect: %v\n", err)
}

s.Shutdown()

if err := Wait(disconnected); err != nil {
t.Fatalf("Did not get the disconnected callback")
}

s = RunDefaultServer()
defer s.Shutdown()

if err := Wait(reconnected); err != nil {
t.Fatalf("Did not get the reconnected callback")
var sub1 *nats.Subscription

recv := func(m *nats.Msg) {
time.Sleep(time.Second)
m.Sub.Unsubscribe()
}

nc.Close()
sub1, err = nc.Subscribe("foo", recv)
if err != nil {
t.Fatalf("Unable to create subscription: %v\n", err)
}
sub1.SetPendingLimits(1, 100000)
for i := 0; i < 2; i++ {
nc.Publish("foo", []byte("test"))
}

var sub2 *nats.Subscription

sub2, err = nc.Subscribe("bar", recv)
if err != nil {
t.Fatalf("Unable to create subscription: %v\n", err)
}
sub2.SetPendingLimits(1, 100000)
for i := 0; i < 2; i++ {
nc.Publish("bar", []byte("test"))
}

if err := Wait(disconnected); err != nil {
t.Fatalf("Did not get the disconnected callback")
if err := Wait(reconnected); err != nil {
t.Fatal("Did not get the reconnected callback")
}

nc.Close()

if err := Wait(closed); err != nil {
t.Fatalf("Did not get the close callback")
t.Fatal("Did not get the close callback")
}

if rtime.Before(dtime1) || dtime2.Before(rtime) || ctime.Before(rtime) {
t.Fatalf("Wrong callback order")
if rtime.Before(dtime1) || dtime2.Before(rtime) || atime2.Before(atime1) || ctime.Before(atime2) {
t.Fatalf("Wrong callback order:\n%v\n%v\n%v\n%v\n%v\n%v", dtime1, rtime, atime1, atime2, dtime2, ctime)
}
}

0 comments on commit a7f6ae4

Please sign in to comment.