diff --git a/stan.go b/stan.go index c1ab9b5..a315c8f 100644 --- a/stan.go +++ b/stan.go @@ -591,15 +591,29 @@ func (sc *conn) cleanupOnClose(err error) { } // Fail all pending pubs - for guid, pubAck := range sc.pubAckMap { - delete(sc.pubAckMap, guid) - if pubAck.t != nil { - pubAck.t.Stop() + if len(sc.pubAckMap) > 0 { + // Collect only the ones that have a timer that can be stopped. + // All others will be handled either in publishAsync() or their + // timer has already fired. + acks := map[string]*ack{} + for guid, pubAck := range sc.pubAckMap { + if pubAck.t != nil && pubAck.t.Stop() { + delete(sc.pubAckMap, guid) + acks[guid] = pubAck + } } - if pubAck.ah != nil { - pubAck.ah(guid, err) - } else if pubAck.ch != nil { - pubAck.ch <- err + // If we collected any, start a go routine that will do the job. + // We can't do it in place in case user's ackHandler uses the connection. + if len(acks) > 0 { + go func() { + for guid, a := range acks { + if a.ah != nil { + a.ah(guid, ErrConnectionClosed) + } else if a.ch != nil { + a.ch <- ErrConnectionClosed + } + } + }() } } // Prevent publish calls that have passed the connection close check but diff --git a/stan_test.go b/stan_test.go index 9dc4b40..ceab802 100644 --- a/stan_test.go +++ b/stan_test.go @@ -164,7 +164,7 @@ func TestConnClosedOnConnectFailure(t *testing.T) { buf := make([]byte, 10000) n := runtime.Stack(buf, true) if strings.Contains(string(buf[:n]), "doReconnect") { - t.Fatal("NATS Connection suspected to not have been closed") + t.Fatalf("NATS Connection suspected to not have been closed\n%s", buf[:n]) } } @@ -250,14 +250,13 @@ func TestTimeoutPublish(t *testing.T) { opts.ID = clusterName s := runServerWithOpts(opts) defer s.Shutdown() - sc, err := Connect(clusterName, clientName, PubAckWait(50*time.Millisecond)) - + sc, err := Connect(clusterName, clientName, + ConnectWait(250*time.Millisecond), + PubAckWait(50*time.Millisecond)) if err != nil { t.Fatalf("Expected to connect correctly, got err %v\n", err) } - // Do not defer the connection close because we are going to - // shutdown the server before the client connection is closed, - // which would cause a 2 seconds delay on test exit. + defer sc.Close() ch := make(chan bool) var glock sync.Mutex @@ -725,6 +724,7 @@ func TestSubscriptionStartAtTime(t *testing.T) { // Now test Ago helper delta := time.Since(startTime) + atomic.StoreInt32(&received, 0) sub, err = sc.Subscribe("foo", mcb, StartAtTimeDelta(delta)) if err != nil { t.Fatalf("Expected no error on Subscribe, got %v\n", err) @@ -1026,7 +1026,10 @@ func TestManualAck(t *testing.T) { if err := m.Ack(); err != ErrManualAck { t.Fatalf("Expected an error trying to ack an auto-ack subscription") } - fch <- true + select { + case fch <- true: + default: + } }, DeliverAllAvailable()) if err != nil { t.Fatalf("Unexpected error on Subscribe, got %v", err) @@ -1337,12 +1340,11 @@ func TestNoDuplicatesOnSubscriberStart(t *testing.T) { if err != nil { t.Fatalf("Expected to connect correctly, got err %v\n", err) } - defer sc.Close() batch := int32(100) ch := make(chan bool) - pch := make(chan bool) + pch := make(chan bool, 1) received := int32(0) sent := int32(0) @@ -1363,7 +1365,10 @@ func TestNoDuplicatesOnSubscriberStart(t *testing.T) { sc.PublishAsync("foo", []byte("hello"), nil) } // signal that we've published a batch. - pch <- true + select { + case pch <- true: + default: + } } } @@ -1463,7 +1468,11 @@ func TestRaceAckOnClose(t *testing.T) { func TestNatsConn(t *testing.T) { s := RunServer(clusterName) defer s.Shutdown() - sc := NewDefaultConnection(t) + sc, err := Connect(clusterName, clientName, + ConnectWait(250*time.Millisecond)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } defer sc.Close() // Make sure we can get the STAN-created Conn. @@ -1490,11 +1499,12 @@ func TestNatsConn(t *testing.T) { // Allow custom conn only if already connected opts := nats.GetDefaultOptions() - nc, err := opts.Connect() + nc, err = opts.Connect() if err != nil { t.Fatalf("Expected to connect correctly, got err %v", err) } - sc, err = Connect(clusterName, clientName, NatsConn(nc)) + sc, err = Connect(clusterName, clientName+"2", + NatsConn(nc), ConnectWait(250*time.Millisecond)) if err != nil { t.Fatalf("Expected to connect correctly, got err %v", err) } @@ -1511,7 +1521,8 @@ func TestNatsConn(t *testing.T) { t.Fatalf("Expected to connect correctly, got err %v", err) } defer nc.Close() - sc, err = Connect(clusterName, clientName, NatsConn(nc)) + sc, err = Connect(clusterName, clientName+"3", + NatsConn(nc), ConnectWait(250*time.Millisecond)) if err != nil { t.Fatalf("Expected to connect correctly, got err %v", err) } @@ -1532,14 +1543,14 @@ func TestMaxPubAcksInflight(t *testing.T) { defer nc.Close() sc, err := Connect(clusterName, clientName, + ConnectWait(250*time.Millisecond), MaxPubAcksInflight(1), PubAckWait(time.Second), NatsConn(nc)) if err != nil { t.Fatalf("Expected to connect correctly, got err %v", err) } - // Don't defer the close of connection since the server is stopped, - // the close would delay the test. + defer sc.Close() // Cause the ACK to not come by shutdown the server now s.Shutdown() @@ -2170,6 +2181,7 @@ func TestPings(t *testing.T) { if err != nil { t.Fatalf("Error on connect: %v", err) } + defer nc.Close() count := 0 ch := make(chan bool, 1) nc.Subscribe(DefaultDiscoverPrefix+"."+clusterName+".pings", func(m *nats.Msg) { @@ -2267,33 +2279,29 @@ func TestPingsCloseUnlockPubCalls(t *testing.T) { s.Shutdown() total := 100 - ch := make(chan bool, 1) - ec := int32(0) + ch := make(chan bool, total) ah := func(g string, e error) { - if c := atomic.AddInt32(&ec, 1); c == int32(total) { - ch <- true - } + // Invoke a function that requires connection's lock + sc.NatsConn() } - wg := sync.WaitGroup{} - wg.Add(total) for i := 0; i < total/2; i++ { go func() { sc.PublishAsync("foo", []byte("hello"), ah) - wg.Done() + ch <- true }() go func() { - if err := sc.Publish("foo", []byte("hello")); err != nil { - if c := atomic.AddInt32(&ec, 1); c == int32(total) { - ch <- true - } - } - wg.Done() + sc.Publish("foo", []byte("hello")) + ch <- true }() } - if err := Wait(ch); err != nil { - t.Fatal("Did not get all the expected failures") + tm := time.NewTimer(2 * time.Second) + for i := 0; i < total; i++ { + select { + case <-ch: + case <-tm.C: + t.Fatalf("%v/%v publish calls returned", i+1, total) + } } - wg.Wait() } func TestConnErrHandlerNotCalledOnNormalClose(t *testing.T) { @@ -2377,6 +2385,7 @@ func TestPubFailsOnClientReplaced(t *testing.T) { if err != nil { t.Fatalf("Error on connect: %v", err) } + defer sc2.Close() // Verify that this client can publish if err := sc2.Publish("foo", []byte("hello")); err != nil { t.Fatalf("Error on publish: %v", err) @@ -2430,6 +2439,7 @@ func TestPingsResponseError(t *testing.T) { if err != nil { t.Fatalf("Error on connect: %v", err) } + defer sc.Close() // Send a message and ensure it is ok. if err := sc.Publish("foo", []byte("hello")); err != nil { t.Fatalf("Error on publish: %v", err) @@ -2446,6 +2456,7 @@ func TestPingsResponseError(t *testing.T) { if err != nil { t.Fatalf("Error on connect: %v", err) } + defer sc2.Close() // Verify that this client can publish if err := sc2.Publish("foo", []byte("hello")); err != nil { t.Fatalf("Error on publish: %v", err) @@ -2611,3 +2622,46 @@ func TestNoMemoryLeak(t *testing.T) { } t.Fatalf("Heap in use seem high: old=%vMB - new=%vMB", oldInUse/oneMB, newInUse/oneMB) } + +func TestPublishAsyncTimeout(t *testing.T) { + ns := natsd.RunDefaultServer() + defer ns.Shutdown() + + opts := server.GetDefaultOptions() + opts.NATSServerURL = nats.DefaultURL + opts.ID = clusterName + s := runServerWithOpts(opts) + defer s.Shutdown() + + sc, err := Connect(clusterName, clientName, + ConnectWait(250*time.Millisecond), + PubAckWait(50*time.Millisecond)) + if err != nil { + t.Fatalf("Error on connect: %v", err) + } + defer sc.Close() + + s.Shutdown() + + total := 1000 + ch := make(chan bool, 1) + count := int32(0) + ah := func(g string, e error) { + // Invoke a function that requires connection's lock + sc.NatsConn() + if c := atomic.AddInt32(&count, 1); c == int32(total) { + ch <- true + } + } + for i := 0; i < total/2; i++ { + sc.PublishAsync("foo", []byte("hello"), ah) + } + time.Sleep(5 * time.Millisecond) + for i := 0; i < total/2; i++ { + sc.PublishAsync("foo", []byte("hello"), ah) + } + if err := Wait(ch); err != nil { + c := atomic.LoadInt32(&count) + t.Fatalf("Ack handler was invoked only %v out of %v", c, total) + } +}