Skip to content

Commit

Permalink
[FIXED] Dispatching of asynchronous callbacks
Browse files Browse the repository at this point in the history
This is a follow up on #365. Working on another issue and running
tests, I still had a race on the original ach or a panic on send
to close channel.

Move from a channel to a linked list with dedicated lock.

Signed-off-by: Ivan Kozlovic <ivan@synadia.com>
  • Loading branch information
kozlovic committed Jun 12, 2018
1 parent d4841f2 commit f9649c3
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 92 deletions.
4 changes: 2 additions & 2 deletions enc.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,9 @@ func (c *EncodedConn) subscribe(subject, queue string, cb Handler) (*Subscriptio
}
if err := c.Enc.Decode(m.Subject, m.Data, oPtr.Interface()); err != nil {
if c.Conn.Opts.AsyncErrorCB != nil {
c.Conn.ach <- func() {
c.Conn.ach.push(func() {
c.Conn.Opts.AsyncErrorCB(c.Conn, m.Sub, errors.New("nats: Got an error trying to unmarshal: "+err.Error()))
}
})
}
return
}
Expand Down
157 changes: 93 additions & 64 deletions nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,17 @@ type ConnHandler func(*Conn)
type ErrHandler func(*Conn, *Subscription, error)

// asyncCB is used to preserve order for async callbacks.
type asyncCB func()
type asyncCB struct {
f func()
next *asyncCB
}

type asyncCallbacksHandler struct {
mu sync.Mutex
cond *sync.Cond
head *asyncCB
tail *asyncCB
}

// Option is a function on the options for a connection.
type Option func(*Options) error
Expand Down Expand Up @@ -269,9 +279,6 @@ const (
// Default server pool size
srvPoolSize = 4

// Channel size for the async callback handler.
asyncCBChanSize = 32

// NUID size
nuidSize = 22
)
Expand Down Expand Up @@ -300,7 +307,7 @@ type Conn struct {
ssid int64
subsMu sync.RWMutex
subs map[int64]*Subscription
ach chan asyncCB
ach *asyncCallbacksHandler
pongs []chan struct{}
scratch [scratchSize]byte
status Status
Expand Down Expand Up @@ -753,15 +760,15 @@ func (o Options) Connect() (*Conn, error) {
return nil, err
}

// Create the async callback channel.
nc.ach = make(chan asyncCB, asyncCBChanSize)
// Create the async callback handler.
nc.ach = newAsyncCallbacksHandler()

if err := nc.connect(); err != nil {
return nil, err
}

// Spin up the async cb dispatcher on success
go nc.asyncDispatch()
go nc.ach.asyncCBDispatcher()

return nc, nil
}
Expand Down Expand Up @@ -1393,13 +1400,10 @@ func (nc *Conn) doReconnect() {

// Clear any errors.
nc.err = nil
disconnectedCB := nc.Opts.DisconnectedCB
nc.mu.Unlock()
// Perform appropriate callback if needed for a disconnect.
if disconnectedCB != nil {
nc.ach <- func() { disconnectedCB(nc) }
if nc.Opts.DisconnectedCB != nil {
nc.ach.push(func() { nc.Opts.DisconnectedCB(nc) })
}
nc.mu.Lock()

for len(nc.srvPool) > 0 {
cur, err := nc.selectNextServer()
Expand Down Expand Up @@ -1480,14 +1484,12 @@ func (nc *Conn) doReconnect() {
// This is where we are truly connected.
nc.status = CONNECTED

reconnectedCB := nc.Opts.ReconnectedCB
// Release lock here, we will return below.
nc.mu.Unlock()

// Queue up the reconnect callback.
if reconnectedCB != nil {
nc.ach <- func() { reconnectedCB(nc) }
if nc.Opts.ReconnectedCB != nil {
nc.ach.push(func() { nc.Opts.ReconnectedCB(nc) })
}
// Release lock here, we will return below.
nc.mu.Unlock()

// Make sure to flush everything
nc.Flush()
Expand Down Expand Up @@ -1542,35 +1544,70 @@ func (nc *Conn) processOpErr(err error) {
nc.Close()
}

// Marker to close the channel to kick out the Go routine.
func (nc *Conn) closeAsyncFunc() asyncCB {
return func() {
nc.mu.Lock()
if nc.ach != nil {
close(nc.ach)
nc.ach = nil
}
nc.mu.Unlock()
}
// Returns an initialized asyncCallbacksHandler object
func newAsyncCallbacksHandler() *asyncCallbacksHandler {
ac := &asyncCallbacksHandler{}
ac.cond = sync.NewCond(&ac.mu)
return ac
}

// asyncDispatch is responsible for calling any async callbacks
func (nc *Conn) asyncDispatch() {
// snapshot since they can change from underneath of us.
nc.mu.Lock()
ach := nc.ach
nc.mu.Unlock()

// Loop on the channel and process async callbacks.
// dispatch is responsible for calling any async callbacks
func (ac *asyncCallbacksHandler) asyncCBDispatcher() {
for {
if f, ok := <-ach; !ok {
ac.mu.Lock()
// Protect for spurious wakeups. We should get out of the
// wait only if there is an element to pop from the list.
for ac.head == nil {
ac.cond.Wait()
}
cur := ac.head
ac.head = cur.next
if cur == ac.tail {
ac.tail = nil
}
ac.mu.Unlock()

// This signals that the dispatcher has been closed and all
// previous callbacks have been dispatched.
if cur.f == nil {
return
} else {
f()
}
// Invoke callback outside of handler's lock
cur.f()
}
}

// Add the given function to the tail of the list and
// signals the dispatcher.
func (ac *asyncCallbacksHandler) push(f func()) {
ac.pushOrClose(f, false)
}

// Signals that we are closing...
func (ac *asyncCallbacksHandler) close() {
ac.pushOrClose(nil, true)
}

// Add the given function to the tail of the list and
// signals the dispatcher.
func (ac *asyncCallbacksHandler) pushOrClose(f func(), close bool) {
ac.mu.Lock()
defer ac.mu.Unlock()
// Make sure that library is not calling push with nil function,
// since this is used to notify the dispatcher that it should stop.
if !close && f == nil {
panic("pushing a nil callback")
}
cb := &asyncCB{f: f}
if ac.tail != nil {
ac.tail.next = cb
} else {
ac.head = cb
}
ac.tail = cb
ac.cond.Signal()
}

// readLoop() will sit on the socket reading and processing the
// protocol from the server. It will dispatch appropriately based
// on the op type.
Expand Down Expand Up @@ -1785,11 +1822,10 @@ slowConsumer:
// is already experiencing client-side slow consumer situation.
nc.mu.Lock()
nc.err = ErrSlowConsumer
asyncErrorCB := nc.Opts.AsyncErrorCB
nc.mu.Unlock()
if asyncErrorCB != nil {
nc.ach <- func() { asyncErrorCB(nc, sub, ErrSlowConsumer) }
if nc.Opts.AsyncErrorCB != nil {
nc.ach.push(func() { nc.Opts.AsyncErrorCB(nc, sub, ErrSlowConsumer) })
}
nc.mu.Unlock()
}
}

Expand All @@ -1800,23 +1836,21 @@ func (nc *Conn) processPermissionsViolation(err string) {
// create error here so we can pass it as a closure to the async cb dispatcher.
e := errors.New("nats: " + err)
nc.err = e
asyncErrorCB := nc.Opts.AsyncErrorCB
nc.mu.Unlock()
if asyncErrorCB != nil {
nc.ach <- func() { asyncErrorCB(nc, nil, e) }
if nc.Opts.AsyncErrorCB != nil {
nc.ach.push(func() { nc.Opts.AsyncErrorCB(nc, nil, e) })
}
nc.mu.Unlock()
}

// processAuthorizationViolation is called when the server signals a user
// authorization violation.
func (nc *Conn) processAuthorizationViolation(err string) {
nc.mu.Lock()
nc.err = ErrAuthorization
asyncErrorCB := nc.Opts.AsyncErrorCB
nc.mu.Unlock()
if asyncErrorCB != nil {
nc.ach <- func() { asyncErrorCB(nc, nil, ErrAuthorization) }
if nc.Opts.AsyncErrorCB != nil {
nc.ach.push(func() { nc.Opts.AsyncErrorCB(nc, nil, ErrAuthorization) })
}
nc.mu.Unlock()
}

// flusher is a separate Go routine that will process flush requests for the write
Expand Down Expand Up @@ -1961,7 +1995,7 @@ func (nc *Conn) processInfo(info string) error {
nc.addURLToPool(fmt.Sprintf("nats://%s", curl), true)
}
if hasNew && !nc.initc && nc.Opts.DiscoveredServersCB != nil {
nc.ach <- func() { nc.Opts.DiscoveredServersCB(nc) }
nc.ach.push(func() { nc.Opts.DiscoveredServersCB(nc) })
}
return nil
}
Expand Down Expand Up @@ -2991,22 +3025,17 @@ func (nc *Conn) close(status Status, doCBs bool) {

nc.status = status

disconnectedCB := nc.Opts.DisconnectedCB
closedCB := nc.Opts.ClosedCB
conn := nc.conn
asyncFunc := nc.closeAsyncFunc()
nc.mu.Unlock()

// Perform appropriate callback if needed for a disconnect.
if doCBs {
if disconnectedCB != nil && conn != nil {
nc.ach <- func() { disconnectedCB(nc) }
if nc.Opts.DisconnectedCB != nil && nc.conn != nil {
nc.ach.push(func() { nc.Opts.DisconnectedCB(nc) })
}
if closedCB != nil {
nc.ach <- func() { closedCB(nc) }
if nc.Opts.ClosedCB != nil {
nc.ach.push(func() { nc.Opts.ClosedCB(nc) })
}
nc.ach <- asyncFunc
nc.ach.close()
}
nc.mu.Unlock()
}

// Close will close the connection to the server. This call will release
Expand Down
20 changes: 7 additions & 13 deletions nats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1104,41 +1104,35 @@ func TestConnServers(t *testing.T) {
}

func TestProcessErrAuthorizationError(t *testing.T) {
ach := make(chan asyncCB, 1)
called := make(chan error, 1)
c := &Conn{
ach: ach,
ach: newAsyncCallbacksHandler(),
Opts: Options{
AsyncErrorCB: func(nc *Conn, sub *Subscription, err error) {
called <- err
},
},
}
go c.ach.asyncCBDispatcher()
defer c.ach.close()
c.processErr("Authorization Violation")
select {
case cb := <-ach:
cb()
default:
t.Fatal("Expected callback on channel")
}

select {
case err := <-called:
if err != ErrAuthorization {
t.Fatalf("Expected ErrAuthorization, got: %v", err)
}
default:
case <-time.After(2 * time.Second):
t.Fatal("Expected error on channel")
}
}

func TestConnAsyncCBDeadlock(t *testing.T) {
s := RunServerOnPort(DefaultPort)
s := RunServerOnPort(TEST_PORT)
defer s.Shutdown()

ch := make(chan bool)
o := GetDefaultOptions()
o.Url = DefaultURL
o.Url = fmt.Sprintf("nats://127.0.0.1:%d", TEST_PORT)
o.ClosedCB = func(_ *Conn) {
ch <- true
}
Expand All @@ -1152,7 +1146,7 @@ func TestConnAsyncCBDeadlock(t *testing.T) {
}

wg := &sync.WaitGroup{}
for i := 0; i < cap(nc.ach)*10; i++ {
for i := 0; i < 300; i++ {
wg.Add(1)
go func() {
// overwhelm asyncCB with errors
Expand Down
4 changes: 2 additions & 2 deletions netchan.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func chPublish(c *EncodedConn, chVal reflect.Value, subject string) {
if c.Conn.isClosed() {
go c.Conn.Opts.AsyncErrorCB(c.Conn, nil, e)
} else {
c.Conn.ach <- func() { c.Conn.Opts.AsyncErrorCB(c.Conn, nil, e) }
c.Conn.ach.push(func() { c.Conn.Opts.AsyncErrorCB(c.Conn, nil, e) })
}
}
return
Expand Down Expand Up @@ -88,7 +88,7 @@ func (c *EncodedConn) bindRecvChan(subject, queue string, channel interface{}) (
if err := c.Enc.Decode(m.Subject, m.Data, oPtr.Interface()); err != nil {
c.Conn.err = errors.New("nats: Got an error trying to unmarshal: " + err.Error())
if c.Conn.Opts.AsyncErrorCB != nil {
c.Conn.ach <- func() { c.Conn.Opts.AsyncErrorCB(c.Conn, m.Sub, c.Conn.err) }
c.Conn.ach.push(func() { c.Conn.Opts.AsyncErrorCB(c.Conn, m.Sub, c.Conn.err) })
}
return
}
Expand Down
Loading

0 comments on commit f9649c3

Please sign in to comment.