From 8a78c4e118d5c7ecea182397a69ad8ca86decb61 Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Tue, 7 Jul 2020 17:00:50 -0600 Subject: [PATCH] [ADDED] RetryOnFailedConnect option Normally, nats.Connect() would fail if there is no server available when this call is executed. With this new option, if no connection can be made, this call will return no error and will trigger code similar to the reconnect code. Therefore, MaxReconnect and ReconnectWait options are used as if the library had been disconnected and is trying to reconnect. Note that subscription and publish calls will also behave as if the library was in reconnection mode, which means that the calls are buffered and produce no error until the reconnect buffer size is full. Obviously, since the connection is not connected, Flush or Request/Reply calls would time out. If the ReconnectHandler is set, it will be invoked if the library connects asynchronously. Unrelated: fixed a test that had a t.Skip()... Resolves #195 Signed-off-by: Ivan Kozlovic --- README.md | 15 +++ nats.go | 56 +++++++-- norace_test.go | 2 + test/conn_test.go | 295 ++++++++++++++++++++++++++++++++++++++-------- 4 files changed, 314 insertions(+), 54 deletions(-) diff --git a/README.md b/README.md index 7b4d8ee88..8b6e6f08e 100644 --- a/README.md +++ b/README.md @@ -285,6 +285,21 @@ nc.QueueSubscribe("foo", "job_workers", func(_ *Msg) { ```go +// Normally, the library will return an error when trying to connect and +// there is no server running. The RetryOnFailedConnect option will set +// the connection in reconnecting state if it failed to connect right away. +nc, err := nats.Connect(nats.DefaultURL, + nats.RetryOnFailedConnect(true), + nats.MaxReconnects(10), + nats.ReconnectWait(time.Second), + nats.ReconnectHandler(func(_ *nats.Conn) { + // Note that this will be invoked for the first asynchronous connect. 
+ })) +if err != nil { + // Should not return an error even if it can't connect, but you still + // need to check in case there are some configuration errors. +} + // Flush connection to server, returns when all messages have been processed. nc.Flush() fmt.Println("All clear!") diff --git a/nats.go b/nats.go index 1e78bf0f4..255fcda7c 100644 --- a/nats.go +++ b/nats.go @@ -392,6 +392,15 @@ type Options struct { // gradually disconnect all its connections before shuting down. This is // often used in deployments when upgrading NATS Servers. LameDuckModeHandler ConnHandler + + // RetryOnFailedConnect sets the connection in reconnecting state right + // away if it can't connect to a server in the initial set. The + // MaxReconnect and ReconnectWait options are used for this process, + // similarly to when an established connection is disconnected. + // If a ReconnectHandler is set, it will be invoked when the connection + // is established, and if a ClosedHandler is set, it will be invoked if + // it fails to connect (after exhausting the MaxReconnect attempts). + RetryOnFailedConnect bool } const ( @@ -1000,6 +1009,16 @@ func LameDuckModeHandler(cb ConnHandler) Option { } } +// RetryOnFailedConnect sets the connection in reconnecting state right away +// if it can't connect to a server in the initial set. +// See RetryOnFailedConnect option for more details. +func RetryOnFailedConnect(retry bool) Option { + return func(o *Options) error { + o.RetryOnFailedConnect = retry + return nil + } +} + // Handler processing // SetDisconnectHandler will set the disconnect event handler. 
@@ -1588,11 +1607,25 @@ func (nc *Conn) connect() error { } } } - nc.initc = false + if returnedErr == nil && nc.status != CONNECTED { returnedErr = ErrNoServers } + if returnedErr == nil { + nc.initc = false + } else if nc.Opts.RetryOnFailedConnect { + nc.setup() + nc.status = RECONNECTING + nc.pending = new(bytes.Buffer) + if nc.bw == nil { + nc.bw = nc.newBuffer() + } + nc.bw.Reset(nc.pending) + go nc.doReconnect(ErrNoServers) + returnedErr = nil + } + return returnedErr } @@ -1912,10 +1945,12 @@ func (nc *Conn) doReconnect(err error) { nc.err = nil // Perform appropriate callback if needed for a disconnect. // DisconnectedErrCB has priority over deprecated DisconnectedCB - if nc.Opts.DisconnectedErrCB != nil { - nc.ach.push(func() { nc.Opts.DisconnectedErrCB(nc, err) }) - } else if nc.Opts.DisconnectedCB != nil { - nc.ach.push(func() { nc.Opts.DisconnectedCB(nc) }) + if !nc.initc { + if nc.Opts.DisconnectedErrCB != nil { + nc.ach.push(func() { nc.Opts.DisconnectedErrCB(nc, err) }) + } else if nc.Opts.DisconnectedCB != nil { + nc.ach.push(func() { nc.Opts.DisconnectedCB(nc) }) + } } // This is used to wait on go routines exit if we start them in the loop @@ -2056,6 +2091,10 @@ func (nc *Conn) doReconnect(err error) { // This is where we are truly connected. nc.status = CONNECTED + // If we are here with a retry on failed connect, indicate that the + // initial connect is now complete. + nc.initc = false + // Queue up the reconnect callback. if nc.Opts.ReconnectedCB != nil { nc.ach.push(func() { nc.Opts.ReconnectedCB(nc) }) @@ -2532,7 +2571,7 @@ func (nc *Conn) processInfo(info string) error { // did not include themselves in the async INFO protocol. // If empty, do not remove the implicit servers from the pool. 
if len(ncInfo.ConnectURLs) == 0 { - if ncInfo.LameDuckMode && nc.Opts.LameDuckModeHandler != nil { + if !nc.initc && ncInfo.LameDuckMode && nc.Opts.LameDuckModeHandler != nil { nc.ach.push(func() { nc.Opts.LameDuckModeHandler(nc) }) } return nil @@ -2595,7 +2634,7 @@ func (nc *Conn) processInfo(info string) error { nc.ach.push(func() { nc.Opts.DiscoveredServersCB(nc) }) } } - if ncInfo.LameDuckMode && nc.Opts.LameDuckModeHandler != nil { + if !nc.initc && ncInfo.LameDuckMode && nc.Opts.LameDuckModeHandler != nil { nc.ach.push(func() { nc.Opts.LameDuckModeHandler(nc) }) } return nil @@ -2776,7 +2815,8 @@ func (nc *Conn) publish(subj, reply string, hdr, data []byte) error { // Proactively reject payloads over the threshold set by server. msgSize := int64(len(data) + len(hdr)) - if msgSize > nc.info.MaxPayload { + // Skip this check if we are not yet connected (RetryOnFailedConnect) + if !nc.initc && msgSize > nc.info.MaxPayload { nc.mu.Unlock() return ErrMaxPayload } diff --git a/norace_test.go b/norace_test.go index 83071a2b9..247b39f53 100644 --- a/norace_test.go +++ b/norace_test.go @@ -33,6 +33,7 @@ func TestNoRaceParseStateReconnectFunctionality(t *testing.T) { opts.DisconnectedCB = func(_ *Conn) { dch <- true } + opts.NoCallbacksAfterClientClose = true nc, errc := opts.Connect() if errc != nil { @@ -94,4 +95,5 @@ func TestNoRaceParseStateReconnectFunctionality(t *testing.T) { t.Fatalf("Reconnect count incorrect: %d vs %d\n", reconnectedCount, expectedReconnectCount) } + nc.Close() } diff --git a/test/conn_test.go b/test/conn_test.go index 91db87ba9..7f8c22ccd 100644 --- a/test/conn_test.go +++ b/test/conn_test.go @@ -1470,41 +1470,48 @@ func TestDefaultOptionsDialer(t *testing.T) { } } -func TestCustomFlusherTimeout(t *testing.T) { - t.Skip("broken test") +type lowWriteBufferDialer struct{} + +func (d *lowWriteBufferDialer) Dial(network, address string) (net.Conn, error) { + c, err := net.Dial(network, address) + if err != nil { + return nil, err + } + 
c.(*net.TCPConn).SetWriteBuffer(100) + return c, nil +} +func TestCustomFlusherTimeout(t *testing.T) { s := RunDefaultServer() defer s.Shutdown() - opts := &nats.Options{ - Servers: []string{nats.DefaultURL}, - - // Reasonably large flusher timeout will not induce errors - // when we can flush fast - FlusherTimeout: 10 * time.Second, - } - nc1, err := opts.Connect() + // Reasonably large flusher timeout will not induce errors + // when we can flush fast + nc1, err := nats.Connect(nats.DefaultURL, nats.FlusherTimeout(10*time.Second)) if err != nil { t.Fatalf("Expected to be able to connect, got: %s", err) } - doneCh := make(chan struct{}) - payload := "" - for i := 0; i < 8192; i++ { - payload += "A" - } - payloadBytes := []byte(payload) + doneCh := make(chan struct{}, 1) + // We want to have a payload size that is big enough so that after + // a few publishes, the socket buffer will be full and produce the timeout. + // Since we try to produce the error in the flusher and not the publish + // call itself, use a size that is a bit less than the internal + // buffer used by the library. 
+ payloadBytes := make([]byte, 32*1024-200) + errCh := make(chan error, 1) + wg := sync.WaitGroup{} + wg.Add(2) go func() { + defer wg.Done() for { select { case <-time.After(200 * time.Millisecond): err := nc1.Publish("hello", payloadBytes) if err != nil { - t.Errorf("Error during publish: %s", err) + errCh <- err + return } - case <-time.After(5 * time.Second): - t.Errorf("Timeout publishing messages") - return case <-doneCh: return } @@ -1512,48 +1519,134 @@ func TestCustomFlusherTimeout(t *testing.T) { }() defer nc1.Close() - opts = &nats.Options{ - Servers: []string{nats.DefaultURL}, + l, e := net.Listen("tcp", "127.0.0.1:0") + if e != nil { + t.Fatal("Could not listen on an ephemeral port") + } + tl := l.(*net.TCPListener) + defer tl.Close() - // Use short flusher timeout to trigger the error - FlusherTimeout: 1 * time.Microsecond, + addr := tl.Addr().(*net.TCPAddr) - // Upon failure to be able to exercice ping pong interval - // then we will hit this timeout and disconnect - PingInterval: 500 * time.Millisecond, - } + fsDoneCh := make(chan struct{}, 1) + fsErrCh := make(chan error, 1) + go func() { + defer wg.Done() - opts.DisconnectedErrCB = func(nc *nats.Conn, _ error) { - // Ping loops that test is done - doneCh <- struct{}{} - } + serverInfo := "INFO {\"server_id\":\"foobar\",\"host\":\"%s\",\"port\":%d,\"auth_required\":false,\"tls_required\":false,\"max_payload\":%d}\r\n" + conn, err := l.Accept() + if err != nil { + fsErrCh <- err + return + } + defer conn.Close() + // Make it small on purpose + if err := conn.(*net.TCPConn).SetReadBuffer(1024); err != nil { + fsErrCh <- err + return + } + + info := fmt.Sprintf(serverInfo, addr.IP, addr.Port, 1024*1024) + conn.Write([]byte(info)) + + // Read connect and ping commands sent from the client + line := make([]byte, 100) + _, err = conn.Read(line) + if err != nil { + fsErrCh <- fmt.Errorf("Expected CONNECT and PING from client, got: %v", err) + return + } + conn.Write([]byte("PONG\r\n")) + + // Don't 
consume anything at this point and wait to be notified + // that we are done. + <-fsDoneCh + fsErrCh <- nil + }() - nc2, err := opts.Connect() + nc2, err := nats.Connect( + // URL to fake server + fmt.Sprintf("nats://127.0.0.1:%d", addr.Port), + // Use custom dialer so we can set write buffer to low value + nats.SetCustomDialer(&lowWriteBufferDialer{}), + // Use short flusher timeout to trigger the error + nats.FlusherTimeout(15*time.Millisecond), + // Make sure the library does not close connection due + // to pings for this test. + nats.PingInterval(20*time.Second), + // No reconnect + nats.NoReconnect(), + // Notify when connection lost + nats.ClosedHandler(func(_ *nats.Conn) { + doneCh <- struct{}{} + })) if err != nil { t.Fatalf("Expected to be able to connect, got: %s", err) } defer nc2.Close() - // Consume messages to make the reading loop work - _, err = nc2.Subscribe(">", func(_ *nats.Msg) {}) - if err != nil { - t.Fatalf("Expected to be able to create subscription, got: %s", err) - } + var ( + pubErr error + nc2Err error + tm = time.NewTimer(5 * time.Second) + ) +forLoop: for { select { case <-time.After(100 * time.Millisecond): - // Some of the publishes will succeed and others fail with i/o timeout error - // but eventually ping interval will fail and close the connection. - err = nc2.Publish("world", payloadBytes) - if err == nats.ErrConnectionClosed { - return + // We are trying to get the flusher to report the error, but it + // is possible that the Publish() call itself flushes and we don't + // want to fail the test for that. + pubErr = nc2.Publish("world", payloadBytes) + nc2Err = nc2.LastError() + if nc2Err != nil { + break forLoop } - case <-time.After(5 * time.Second): - t.Errorf("Timeout publishing messages") - return + case <-tm.C: + // We got an error, but not from flusher. Don't fail yet. Will check + // if this is a timeout error as expected. 
+ if pubErr != nil { + break forLoop + } + t.Fatalf("Timeout publishing messages") } } + + // Notify fake server that it can stop + close(fsDoneCh) + + // Wait for go routines to end + wg.Wait() + + // Make sure there was no error in the fake server + if err := <-fsErrCh; err != nil { + t.Fatalf("Fake server reported: %v", err) + } + + // One of those two is guaranteed to be set. + err = nc2Err + if err == nil { + err = pubErr + } + // Check that error is a timeout error as expected. + ope, ok := err.(*net.OpError) + if !ok { + t.Fatalf("expected a net.Error, got %v", err) + } + if !ope.Timeout() { + t.Fatalf("expected a timeout, got %v", err) + } + if ope.Op != "write" { + t.Fatalf("expected a write error, got %v", err) + } + + // Check that there is no error from nc1 + select { + case e := <-errCh: + t.Fatal(e) + default: + } } func TestNewServers(t *testing.T) { @@ -2287,3 +2380,113 @@ func TestTLSDontSkipVerify(t *testing.T) { } nc.Close() } + +func TestRetryOnFailedConnect(t *testing.T) { + nc, err := nats.Connect(nats.DefaultURL) + if err == nil { + nc.Close() + t.Fatal("Expected error, did not get one") + } + ch := make(chan bool, 1) + dch := make(chan bool, 1) + nc, err = nats.Connect(nats.DefaultURL, + nats.RetryOnFailedConnect(true), + nats.MaxReconnects(-1), + nats.ReconnectWait(15*time.Millisecond), + nats.DisconnectErrHandler(func(_ *nats.Conn, _ error) { + dch <- true + }), + nats.ReconnectHandler(func(_ *nats.Conn) { + ch <- true + }), + nats.NoCallbacksAfterClientClose()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + sub, err := nc.SubscribeSync("foo") + if err != nil { + t.Fatalf("Error on subscribe: %v", err) + } + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Error on publish: %v", err) + } + for i := 0; i < 2; i++ { + // Start server now + s := RunDefaultServer() + defer s.Shutdown() + + var action string + switch i { + case 0: + action = "connected" + case 1: + action = "reconnected" + } + + // Wait 
for the reconnect CB which in this context means that we connected ok + select { + case <-ch: + case <-time.After(2 * time.Second): + t.Fatalf("Should have %s", action) + } + + // Now make sure that the pub worked and sub worked. + // We should receive the message we have published. + if _, err := sub.NextMsg(time.Second); err != nil { + t.Fatalf("Iter=%v - did not receive message: %v", i, err) + } + + // Check that normal disconnect/reconnect works as expected + s.Shutdown() + + select { + case <-dch: + case <-time.After(time.Second): + t.Fatal("Should have been disconnected") + } + + if i == 0 { + if err := nc.Publish("foo", []byte("msg")); err != nil { + t.Fatalf("Iter=%v - error on publish: %v", i, err) + } + } + } + nc.Close() + + // Try again but this time we will restart a server with u/p and auth should fail. + closedCh := make(chan bool, 1) + nc, err = nats.Connect(nats.DefaultURL, + nats.RetryOnFailedConnect(true), + nats.MaxReconnects(-1), + nats.ReconnectWait(15*time.Millisecond), + nats.ReconnectHandler(func(_ *nats.Conn) { + ch <- true + }), + nats.ClosedHandler(func(_ *nats.Conn) { + closedCh <- true + })) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + defer nc.Close() + + o := test.DefaultTestOptions + o.Host = "127.0.0.1" + o.Port = 4222 + o.Username = "user" + o.Password = "password" + s := RunServerWithOptions(o) + defer s.Shutdown() + + select { + case <-closedCh: + case <-time.After(2 * time.Second): + t.Fatal("Should have stopped trying to connect due to auth failure") + } + // Make sure that we did not get the (re)connected CB + select { + case <-ch: + t.Fatal("(re)connected callback should not have been invoked") + default: + } +}