diff --git a/stan.go b/stan.go
index c1ab9b5..45ac207 100644
--- a/stan.go
+++ b/stan.go
@@ -292,6 +292,9 @@ type conn struct {
 	pubAckMap       map[string]*ack
 	pubAckChan      chan (struct{})
 	pubAckCloseChan chan (struct{})
+	pubAckTimeoutCh chan (struct{})
+	pubAckHead      *ack
+	pubAckTail      *ack
 	opts            Options
 	nc              *nats.Conn
 	ncOwned         bool // NATS Streaming created the connection, so needs to close it.
@@ -299,6 +302,7 @@ type conn struct {
 	connLostCB ConnectionLostHandler
 	closed     bool
 	ping       pingInfo
+	wg         sync.WaitGroup
 }
 
 // Holds all field related to the client-to-server pings
@@ -316,22 +320,26 @@ type pingInfo struct {
 
 // Closure for ack contexts.
 type ack struct {
-	t  *time.Timer
-	ah AckHandler
-	ch chan error
+	ah     AckHandler
+	ch     chan error
+	guid   string
+	expire int64
+	prev   *ack
+	next   *ack
 }
 
 // Connect will form a connection to the NATS Streaming subsystem.
 // Note that clientID can contain only alphanumeric and `-` or `_` characters.
 func Connect(stanClusterID, clientID string, options ...Option) (Conn, error) {
 	// Process Options
-	c := conn{
+	c := &conn{
 		clientID:        clientID,
 		opts:            DefaultOptions,
 		connID:          []byte(nuid.Next()),
 		pubNUID:         nuid.New(),
 		pubAckMap:       make(map[string]*ack),
 		pubAckCloseChan: make(chan struct{}),
+		pubAckTimeoutCh: make(chan struct{}, 1),
 		subMap:          make(map[string]*subscription),
 	}
 	for _, opt := range options {
@@ -433,6 +441,10 @@ func Connect(stanClusterID, clientID string, options ...Option) (Conn, error) {
 	// Capture the connection error cb
 	c.connLostCB = c.opts.ConnectionLostCB
 
+	// Start the routine that will time out the publish calls.
+	c.wg.Add(1)
+	go c.pubAckTimeout()
+
 	unsubPingSub := true
 	// Do this with servers which are at least at protocolOne.
 	if cr.Protocol >= protocolOne {
@@ -470,7 +482,7 @@ func Connect(stanClusterID, clientID string, options ...Option) (Conn, error) {
 		p.sub = nil
 	}
 
-	return &c, nil
+	return c, nil
 }
 
 // Invoked on a failed connect.
@@ -590,18 +602,9 @@ func (sc *conn) cleanupOnClose(err error) {
 		}
 	}
 
-	// Fail all pending pubs
-	for guid, pubAck := range sc.pubAckMap {
-		delete(sc.pubAckMap, guid)
-		if pubAck.t != nil {
-			pubAck.t.Stop()
-		}
-		if pubAck.ah != nil {
-			pubAck.ah(guid, err)
-		} else if pubAck.ch != nil {
-			pubAck.ch <- err
-		}
-	}
+	// Let the pubAckTimeout routine fail the pubAcks.
+	sc.signalPubAckTimeoutCh()
+
 	// Prevent publish calls that have passed the connection close check but
 	// not yet send to pubAckChan to be possibly blocked.
 	close(sc.pubAckCloseChan)
@@ -610,12 +613,15 @@ func (sc *conn) cleanupOnClose(err error) {
 // Close a connection to the stan system.
 func (sc *conn) Close() error {
 	sc.Lock()
-	defer sc.Unlock()
-
 	if sc.closed {
+		sc.Unlock()
 		// We are already closed.
 		return nil
 	}
+	defer func() {
+		sc.Unlock()
+		sc.wg.Wait()
+	}()
 
 	// Signals we are closed.
 	sc.closed = true
@@ -676,7 +682,9 @@ func (sc *conn) processAck(m *nats.Msg) {
 	}
 
 	// Remove
-	a := sc.removeAck(pa.Guid)
+	sc.Lock()
+	a := sc.removeAck(pa.Guid, true)
+	sc.Unlock()
 	if a != nil {
 		// Capture error if it exists.
 		if pa.Error != "" {
@@ -732,7 +740,6 @@ func (sc *conn) publishAsync(subject string, data []byte, ah AckHandler, ch chan
 	sc.pubAckMap[peGUID] = a
 	// snapshot
 	ackSubject := sc.ackSubject
-	ackTimeout := sc.opts.AckTimeout
 	sc.Unlock()
 
 	// Use the buffered channel to control the number of outstanding acks.
@@ -755,60 +762,41 @@ func (sc *conn) publishAsync(subject string, data []byte, ah AckHandler, ch chan
 	// Setup the timer for expiration.
 	sc.Lock()
 	if err != nil || sc.closed {
-		sc.Unlock()
 		// If we got and error on publish or the connection has been closed,
 		// we need to return an error only if:
 		// - we can remove the pubAck from the map
 		// - we can't, but this is an async pub with no provided AckHandler
-		removed := sc.removeAck(peGUID) != nil
+		removed := sc.removeAck(peGUID, true) != nil
 		if removed || (ch == nil && ah == nil) {
 			if err == nil {
 				err = ErrConnectionClosed
 			}
-			return "", err
 		}
-		// pubAck was removed from cleanupOnClose() and error will be sent
+		// else pubAck was removed from cleanupOnClose() and error will be sent
 		// to appropriate go channel (ah or ch).
-		return peGUID, nil
+	} else {
+		a.guid = peGUID
+		a.expire = time.Now().Add(sc.opts.AckTimeout).UnixNano()
+		sc.appendPubAckToList(a)
 	}
-	a.t = time.AfterFunc(ackTimeout, func() {
-		pubAck := sc.removeAck(peGUID)
-		// processAck could get here before and handle the ack.
-		// If that's the case, we would get nil here and simply return.
-		if pubAck == nil {
-			return
-		}
-		if pubAck.ah != nil {
-			pubAck.ah(peGUID, ErrTimeout)
-		} else if a.ch != nil {
-			pubAck.ch <- ErrTimeout
-		}
-	})
 	sc.Unlock()
 
-	return peGUID, nil
+	return peGUID, err
 }
 
-// removeAck removes the ack from the pubAckMap and cancels any state, e.g. timers
-func (sc *conn) removeAck(guid string) *ack {
-	var t *time.Timer
-	sc.Lock()
+// Removes the ack from the pubAckMap and possibly from the list.
+// Lock held on entry.
+func (sc *conn) removeAck(guid string, removeFromList bool) *ack {
 	a := sc.pubAckMap[guid]
 	if a != nil {
-		t = a.t
 		delete(sc.pubAckMap, guid)
+		if removeFromList {
+			sc.removePubAckFromList(a)
+		}
 	}
-	pac := sc.pubAckChan
-	sc.Unlock()
-
-	// Cancel timer if needed.
-	if t != nil {
-		t.Stop()
-	}
-
 	// Remove from channel to unblock PublishAsync
-	if a != nil && len(pac) > 0 {
-		<-pac
+	if a != nil && len(sc.pubAckChan) > 0 {
+		<-sc.pubAckChan
 	}
 	return a
 }
@@ -858,3 +846,141 @@ func (sc *conn) processMsg(raw *nats.Msg) {
 		sc.nc.Publish(ackSubject, b)
 	}
 }
+
+// Append the pub ack to the list and signal the timeout routine if this was the first.
+// Lock held on entry.
+func (sc *conn) appendPubAckToList(a *ack) {
+	if sc.pubAckTail != nil {
+		a.prev = sc.pubAckTail
+		a.prev.next = a
+		sc.pubAckTail = a
+	} else {
+		sc.pubAckHead, sc.pubAckTail = a, a
+		sc.signalPubAckTimeoutCh()
+	}
+}
+
+// Signals the pubAckTimeout channel.
+func (sc *conn) signalPubAckTimeoutCh() {
+	select {
+	case sc.pubAckTimeoutCh <- struct{}{}:
+	default:
+	}
+}
+
+// Remove the pub ack from the list and signal the timeout routine if it was
+// the head of the list.
+// Lock held on entry.
+func (sc *conn) removePubAckFromList(a *ack) {
+	if a.prev != nil {
+		a.prev.next = a.next
+	}
+	if a.next != nil {
+		a.next.prev = a.prev
+	}
+	if a == sc.pubAckTail {
+		sc.pubAckTail = a.prev
+	}
+	if a == sc.pubAckHead {
+		sc.pubAckHead = a.next
+		sc.signalPubAckTimeoutCh()
+	}
+}
+
+// Long-lived go routine that deals with publish ack timeouts.
+func (sc *conn) pubAckTimeout() {
+	defer sc.wg.Done()
+
+	var (
+		list        *ack
+		closed      bool
+		dur         time.Duration
+		t           = time.NewTimer(time.Hour)
+		errToReport = ErrTimeout
+	)
+	for {
+		sc.Lock()
+		list = sc.pubAckHead
+		if sc.closed {
+			closed = true
+			errToReport = ErrConnectionClosed
+		} else {
+			now := time.Now().UnixNano()
+			if list != nil {
+				dur = time.Duration(list.expire - now)
+				if dur < 0 {
+					dur = 0
+				}
+			} else {
+				// Any big value would do...
+				dur = time.Hour
+			}
+		}
+		sc.Unlock()
+
+		if !closed {
+			if dur > 0 {
+				t.Reset(dur)
+				// If the head of the list is removed in processAck, we should
+				// be notified through pubAckTimeoutCh and will get back to
+				// compute the new duration.
+				select {
+				case <-sc.pubAckTimeoutCh:
+					continue
+				case <-t.C:
+					// Nothing to do, go back to top of loop to refresh list.
+					if list == nil {
+						continue
+					}
+				}
+			}
+			// We have expired pub acks at this point.
+			sc.Lock()
+			var a *ack
+			now := time.Now().UnixNano()
+			for a = list; a != nil; a = a.next {
+				if a.expire-now > int64(time.Millisecond) {
+					// This element expires in more than 1ms from now,
+					// so stop and end the list prior to this element.
+					// If it is the head, the list is left untouched.
+					if a != sc.pubAckHead {
+						a.prev.next = nil
+						a.prev = nil
+						sc.pubAckHead = a
+					}
+					break
+				}
+			}
+			// If all elements are expired, reset the connection's list.
+			if a == nil {
+				sc.pubAckHead, sc.pubAckTail = nil, nil
+			}
+			sc.Unlock()
+		}
+
+		var next *ack
+		for a := list; a != nil; {
+			// Remove the ack from the map.
+			sc.Lock()
+			removed := sc.removeAck(a.guid, false) != nil
+			next = a.next
+			sc.Unlock()
+			// If processAck has already processed the ack, we would not
+			// have been able to remove it from the map, so move to the next.
+			if !removed {
+				a = next
+				continue
+			}
+			if a.ah != nil {
+				a.ah(a.guid, errToReport)
+			} else if a.ch != nil {
+				a.ch <- errToReport
+			}
+			a = next
+		}
+
+		if closed {
+			return
+		}
+	}
+}
diff --git a/stan_test.go b/stan_test.go
index 9dc4b40..076a8e4 100644
--- a/stan_test.go
+++ b/stan_test.go
@@ -164,7 +164,7 @@ func TestConnClosedOnConnectFailure(t *testing.T) {
 	buf := make([]byte, 10000)
 	n := runtime.Stack(buf, true)
 	if strings.Contains(string(buf[:n]), "doReconnect") {
-		t.Fatal("NATS Connection suspected to not have been closed")
+		t.Fatalf("NATS Connection suspected to not have been closed\n%s", buf[:n])
 	}
 }
 
@@ -250,14 +250,13 @@ func TestTimeoutPublish(t *testing.T) {
 	opts.ID = clusterName
 	s := runServerWithOpts(opts)
 	defer s.Shutdown()
-	sc, err := Connect(clusterName, clientName, PubAckWait(50*time.Millisecond))
-
+	sc, err := Connect(clusterName, clientName,
+		ConnectWait(250*time.Millisecond),
+		PubAckWait(50*time.Millisecond))
 	if err != nil {
 		t.Fatalf("Expected to connect correctly, got err %v\n", err)
 	}
-	// Do not defer the connection close because we are going to
-	// shutdown the server before the client connection is closed,
-	// which would cause a 2 seconds delay on test exit.
+	defer sc.Close()
 	ch := make(chan bool)
 	var glock sync.Mutex
@@ -725,6 +724,7 @@ func TestSubscriptionStartAtTime(t *testing.T) {
 	// Now test Ago helper
 	delta := time.Since(startTime)
+	atomic.StoreInt32(&received, 0)
 	sub, err = sc.Subscribe("foo", mcb, StartAtTimeDelta(delta))
 	if err != nil {
 		t.Fatalf("Expected no error on Subscribe, got %v\n", err)
 	}
@@ -1026,7 +1026,10 @@ func TestManualAck(t *testing.T) {
 		if err := m.Ack(); err != ErrManualAck {
 			t.Fatalf("Expected an error trying to ack an auto-ack subscription")
 		}
-		fch <- true
+		select {
+		case fch <- true:
+		default:
+		}
 	}, DeliverAllAvailable())
 	if err != nil {
 		t.Fatalf("Unexpected error on Subscribe, got %v", err)
 	}
@@ -1337,12 +1340,11 @@ func TestNoDuplicatesOnSubscriberStart(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Expected to connect correctly, got err %v\n", err)
 	}
-	defer sc.Close()
 
 	batch := int32(100)
 	ch := make(chan bool)
-	pch := make(chan bool)
+	pch := make(chan bool, 1)
 	received := int32(0)
 	sent := int32(0)
@@ -1363,7 +1365,10 @@ func TestNoDuplicatesOnSubscriberStart(t *testing.T) {
 			sc.PublishAsync("foo", []byte("hello"), nil)
 		}
 		// signal that we've published a batch.
-		pch <- true
+		select {
+		case pch <- true:
+		default:
+		}
 	}
 }
 
@@ -1463,7 +1468,11 @@ func TestRaceAckOnClose(t *testing.T) {
 func TestNatsConn(t *testing.T) {
 	s := RunServer(clusterName)
 	defer s.Shutdown()
-	sc := NewDefaultConnection(t)
+	sc, err := Connect(clusterName, clientName,
+		ConnectWait(250*time.Millisecond))
+	if err != nil {
+		t.Fatalf("Error on connect: %v", err)
+	}
 	defer sc.Close()
 
 	// Make sure we can get the STAN-created Conn.
@@ -1490,11 +1499,12 @@ func TestNatsConn(t *testing.T) {
 
 	// Allow custom conn only if already connected
 	opts := nats.GetDefaultOptions()
-	nc, err := opts.Connect()
+	nc, err = opts.Connect()
 	if err != nil {
 		t.Fatalf("Expected to connect correctly, got err %v", err)
 	}
-	sc, err = Connect(clusterName, clientName, NatsConn(nc))
+	sc, err = Connect(clusterName, clientName+"2",
+		NatsConn(nc), ConnectWait(250*time.Millisecond))
 	if err != nil {
 		t.Fatalf("Expected to connect correctly, got err %v", err)
 	}
@@ -1511,7 +1521,8 @@ func TestNatsConn(t *testing.T) {
 		t.Fatalf("Expected to connect correctly, got err %v", err)
 	}
 	defer nc.Close()
-	sc, err = Connect(clusterName, clientName, NatsConn(nc))
+	sc, err = Connect(clusterName, clientName+"3",
+		NatsConn(nc), ConnectWait(250*time.Millisecond))
 	if err != nil {
 		t.Fatalf("Expected to connect correctly, got err %v", err)
 	}
@@ -1532,14 +1543,14 @@ func TestMaxPubAcksInflight(t *testing.T) {
 	defer nc.Close()
 
 	sc, err := Connect(clusterName, clientName,
+		ConnectWait(250*time.Millisecond),
 		MaxPubAcksInflight(1),
 		PubAckWait(time.Second),
 		NatsConn(nc))
 	if err != nil {
 		t.Fatalf("Expected to connect correctly, got err %v", err)
 	}
-	// Don't defer the close of connection since the server is stopped,
-	// the close would delay the test.
+	defer sc.Close()
 
 	// Cause the ACK to not come by shutdown the server now
 	s.Shutdown()
@@ -2170,6 +2181,7 @@ func TestPings(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Error on connect: %v", err)
 	}
+	defer nc.Close()
 	count := 0
 	ch := make(chan bool, 1)
 	nc.Subscribe(DefaultDiscoverPrefix+"."+clusterName+".pings", func(m *nats.Msg) {
@@ -2267,33 +2279,29 @@ func TestPingsCloseUnlockPubCalls(t *testing.T) {
 	s.Shutdown()
 
 	total := 100
-	ch := make(chan bool, 1)
-	ec := int32(0)
+	ch := make(chan bool, total)
 	ah := func(g string, e error) {
-		if c := atomic.AddInt32(&ec, 1); c == int32(total) {
-			ch <- true
-		}
+		// Invoke a function that requires the connection's lock
+		sc.NatsConn()
 	}
-	wg := sync.WaitGroup{}
-	wg.Add(total)
 	for i := 0; i < total/2; i++ {
 		go func() {
 			sc.PublishAsync("foo", []byte("hello"), ah)
-			wg.Done()
+			ch <- true
 		}()
 		go func() {
-			if err := sc.Publish("foo", []byte("hello")); err != nil {
-				if c := atomic.AddInt32(&ec, 1); c == int32(total) {
-					ch <- true
-				}
-			}
-			wg.Done()
+			sc.Publish("foo", []byte("hello"))
+			ch <- true
 		}()
 	}
-	if err := Wait(ch); err != nil {
-		t.Fatal("Did not get all the expected failures")
+	tm := time.NewTimer(2 * time.Second)
+	for i := 0; i < total; i++ {
+		select {
+		case <-ch:
+		case <-tm.C:
+			t.Fatalf("Only %v/%v publish calls returned", i, total)
+		}
 	}
-	wg.Wait()
 }
 
 func TestConnErrHandlerNotCalledOnNormalClose(t *testing.T) {
@@ -2377,6 +2385,7 @@ func TestPubFailsOnClientReplaced(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Error on connect: %v", err)
 	}
+	defer sc2.Close()
 	// Verify that this client can publish
 	if err := sc2.Publish("foo", []byte("hello")); err != nil {
 		t.Fatalf("Error on publish: %v", err)
 	}
@@ -2430,6 +2439,7 @@ func TestPingsResponseError(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Error on connect: %v", err)
 	}
+	defer sc.Close()
 	// Send a message and ensure it is ok.
 	if err := sc.Publish("foo", []byte("hello")); err != nil {
 		t.Fatalf("Error on publish: %v", err)
 	}
@@ -2446,6 +2456,7 @@ func TestPingsResponseError(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Error on connect: %v", err)
 	}
+	defer sc2.Close()
 	// Verify that this client can publish
 	if err := sc2.Publish("foo", []byte("hello")); err != nil {
 		t.Fatalf("Error on publish: %v", err)
 	}
@@ -2611,3 +2622,42 @@ func TestNoMemoryLeak(t *testing.T) {
 	}
 	t.Fatalf("Heap in use seem high: old=%vMB - new=%vMB", oldInUse/oneMB, newInUse/oneMB)
 }
+
+func TestPublishAsyncTimeout(t *testing.T) {
+	ns := natsd.RunDefaultServer()
+	defer ns.Shutdown()
+
+	opts := server.GetDefaultOptions()
+	opts.NATSServerURL = nats.DefaultURL
+	opts.ID = clusterName
+	s := runServerWithOpts(opts)
+	defer s.Shutdown()
+
+	sc, err := Connect(clusterName, clientName,
+		ConnectWait(250*time.Millisecond),
+		PubAckWait(50*time.Millisecond))
+	if err != nil {
+		t.Fatalf("Error on connect: %v", err)
+	}
+	defer sc.Close()
+
+	s.Shutdown()
+
+	total := 1000
+	ch := make(chan bool, 1)
+	count := int32(0)
+	ah := func(g string, e error) {
+		// Invoke a function that requires the connection's lock
+		sc.NatsConn()
+		if c := atomic.AddInt32(&count, 1); c == int32(total) {
+			ch <- true
+		}
+	}
+	for i := 0; i < total; i++ {
+		sc.PublishAsync("foo", []byte("hello"), ah)
+	}
+	if err := Wait(ch); err != nil {
+		c := atomic.LoadInt32(&count)
+		t.Fatalf("Ack handler was invoked only %v out of %v", c, total)
+	}
+}
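
For reviewers who want the gist of the stan.go change without reading the whole diff: the PR replaces the per-publish `time.AfterFunc` with a single long-lived goroutine that watches a list of pending pub acks kept in expiration order (possible because every ack gets the same `AckTimeout`). Below is a minimal, self-contained sketch of that pattern, not the stan.go implementation itself; the names (`expirer`, `pending`, `onExpire`) are made up for illustration, and the sketch only pops expired entries from the head, whereas the PR uses a doubly-linked list so that acks answered in time can also be unlinked from the middle.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// pending is one operation waiting for an ack.
type pending struct {
	guid   string
	expire int64 // deadline, UnixNano
	next   *pending
}

// expirer stores pendings in arrival (hence expiration) order and uses a
// single goroutine plus a single timer to fail them, instead of one
// time.AfterFunc per operation.
type expirer struct {
	mu       sync.Mutex
	head     *pending
	tail     *pending
	signalCh chan struct{} // capacity 1; wakes run() when the head changes
	onExpire func(guid string)
}

func (e *expirer) add(p *pending) {
	e.mu.Lock()
	defer e.mu.Unlock()
	if e.tail != nil {
		e.tail.next = p
		e.tail = p
	} else {
		e.head, e.tail = p, p
		e.signal() // list was empty: run() must re-arm its timer
	}
}

// signal never blocks: the 1-slot buffer coalesces multiple wake-ups.
func (e *expirer) signal() {
	select {
	case e.signalCh <- struct{}{}:
	default:
	}
}

// run is the long-lived goroutine: sleep until the head expires, then
// report every expired entry from the front of the list.
func (e *expirer) run() {
	t := time.NewTimer(time.Hour)
	for {
		e.mu.Lock()
		dur := time.Hour // any large value will do while the list is empty
		if e.head != nil {
			if dur = time.Duration(e.head.expire - time.Now().UnixNano()); dur < 0 {
				dur = 0
			}
		}
		e.mu.Unlock()
		if dur > 0 {
			if !t.Stop() { // drain before Reset, per time.Timer's contract
				select {
				case <-t.C:
				default:
				}
			}
			t.Reset(dur)
			select {
			case <-e.signalCh: // head changed: recompute the wait
				continue
			case <-t.C:
			}
		}
		for {
			e.mu.Lock()
			p := e.head
			if p == nil || p.expire > time.Now().UnixNano() {
				e.mu.Unlock()
				break
			}
			if e.head = p.next; e.head == nil {
				e.tail = nil
			}
			e.mu.Unlock()
			e.onExpire(p.guid) // callback invoked without holding the lock
		}
	}
}

func main() {
	e := &expirer{
		signalCh: make(chan struct{}, 1),
		onExpire: func(guid string) { fmt.Println("timed out:", guid) },
	}
	go e.run()
	e.add(&pending{guid: "pub-1", expire: time.Now().Add(50 * time.Millisecond).UnixNano()})
	time.Sleep(100 * time.Millisecond) // prints "timed out: pub-1"
}
```

The design point the diff relies on is the same one sketched here: because timeouts are uniform, the list never needs sorting, expiration checks only ever touch the head, and one timer replaces a timer per message, which is what lets a test like TestPublishAsyncTimeout above drive 1000 outstanding publishes without allocating 1000 timers.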