Skip to content

Commit

Permalink
Merge pull request #505 from nats-io/fix_last_err
Browse files Browse the repository at this point in the history
Trim "nats: " before invoking processErr in doReconnect
  • Loading branch information
kozlovic committed Aug 6, 2019
2 parents 1f35261 + abce817 commit 8f27558
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 37 deletions.
23 changes: 7 additions & 16 deletions nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -1852,9 +1852,10 @@ func (nc *Conn) doReconnect(err error) {
// If we have a lastErr recorded for this server
// do the normal processing here. We might get closed.
if nc.current.lastErr != nil {
err := nc.err
// Remove possible "nats: " prefix
errStr := strings.TrimPrefix(nc.err.Error(), "nats: ")
nc.mu.Unlock()
nc.processErr(err.Error())
nc.processErr(errStr)
nc.mu.Lock()
if nc.isClosed() {
break
Expand All @@ -1868,6 +1869,10 @@ func (nc *Conn) doReconnect(err error) {
continue
}

// Clear possible lastErr under the connection lock after
// a successful processConnectInit().
nc.current.lastErr = nil

// Clear out server stats for the server we connected to..
cur.didConnect = true
cur.reconnects = 0
Expand Down Expand Up @@ -1904,18 +1909,12 @@ func (nc *Conn) doReconnect(err error) {
nc.ach.push(func() { nc.Opts.ReconnectedCB(nc) })
}

lastErr := nc.current.lastErr

// Release lock here, we will return below.
nc.mu.Unlock()

// Make sure to flush everything
nc.Flush()

if lastErr != nil && !nc.IsClosed() {
nc.clearCurrentLastErr()
}

return
}

Expand Down Expand Up @@ -2297,14 +2296,6 @@ func (nc *Conn) processAuthError(err error) {
nc.mu.Unlock()
}

// clearCurrentLastErr will clear the last error when we know we have
// successfully connected after a flush.
func (nc *Conn) clearCurrentLastErr() {
nc.mu.Lock()
nc.current.lastErr = nil
nc.mu.Unlock()
}

// flusher is a separate Go routine that will process flush requests for the write
// bufio. This allows coalescing of writes to the underlying socket.
func (nc *Conn) flusher() {
Expand Down
108 changes: 87 additions & 21 deletions nats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1412,35 +1412,74 @@ func createNewUserKeys() (string, []byte) {
}

func TestExpiredUserCredentials(t *testing.T) {
if server.VERSION[0] == '1' {
t.Skip()
// The goal of this test was to check how a client with an expiring JWT
// behaves. It should receive an async -ERR indicating that the auth
// has expired, which will trigger reconnects. There, the lib should
// received -ERR for auth violation in response to the CONNECT (instead
// of the PONG). The library should close the connection after receiving
// twice the same auth error.
// If we use an actual JWT that expires, the way the JWT library expires
// a JWT cause the server to send the async -ERR first but then accepts
// the CONNECT (since JWT lib does not say that it has expired), but
// when the server sets up the expire callback, that callback fires right
// away and so client receives async -ERR again.
// So for a deterministic test, we won't use an actual NATS Server.
// Instead, we will use a mock that simply returns appropriate -ERR and
// ensure the client behaves as expected.
l, e := net.Listen("tcp", "127.0.0.1:0")
if e != nil {
t.Fatal("Could not listen on an ephemeral port")
}
ts := runTrustServer()
defer ts.Shutdown()
tl := l.(*net.TCPListener)
defer tl.Close()

// Create user credentials that will expire in a short timeframe.
pub, priv := createNewUserKeys()
nuc := jwt.NewUserClaims(pub)
nuc.Expires = time.Now().Add(time.Second).Unix()
akp, _ := nkeys.FromSeed(aSeed)
ujwt, err := nuc.Encode(akp)
if err != nil {
t.Fatalf("Error encoding user jwt: %v", err)
}
creds, err := jwt.FormatUserConfig(ujwt, priv)
if err != nil {
t.Fatalf("Error encoding credentials: %v", err)
}
chainedFile := createTmpFile(t, creds)
defer os.Remove(chainedFile)
addr := tl.Addr().(*net.TCPAddr)

wg := sync.WaitGroup{}
wg.Add(1)

go func() {
defer wg.Done()
connect := 0
for {
conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()

info := "INFO {\"server_id\":\"foobar\",\"nonce\":\"anonce\"}\r\n"
conn.Write([]byte(info))

// Read connect and ping commands sent from the client
br := bufio.NewReaderSize(conn, 10*1024)
br.ReadLine()
br.ReadLine()

if connect++; connect == 1 {
conn.Write([]byte(fmt.Sprintf("%s%s", _PONG_OP_, _CRLF_)))
time.Sleep(300 * time.Millisecond)
conn.Write([]byte(fmt.Sprintf("-ERR '%s'\r\n", AUTHENTICATION_EXPIRED_ERR)))
} else {
conn.Write([]byte(fmt.Sprintf("-ERR '%s'\r\n", AUTHORIZATION_ERR)))
}
conn.Close()
}
}()

ch := make(chan bool)
errCh := make(chan error, 10)

url := fmt.Sprintf("nats://127.0.0.1:%d", TEST_PORT)
url := fmt.Sprintf("nats://127.0.0.1:%d", addr.Port)
nc, err := Connect(url,
UserCredentials(chainedFile),
ReconnectWait(25*time.Millisecond),
MaxReconnects(-1),
ErrorHandler(func(_ *Conn, _ *Subscription, e error) {
select {
case errCh <- e:
default:
}
}),
ClosedHandler(func(nc *Conn) {
ch <- true
}),
Expand All @@ -1457,6 +1496,33 @@ func TestExpiredUserCredentials(t *testing.T) {
if stats := nc.Stats(); stats.Reconnects > 2 {
t.Fatalf("Expected at most 2 reconnects, got %d", stats.Reconnects)
}
// We expect 3 errors, an AUTHENTICATION_EXPIRED_ERR, then 2 AUTHORIZATION_ERR
// before the connection is closed.
for i := 0; i < 3; i++ {
select {
case e := <-errCh:
if i == 0 && e != ErrAuthExpired {
t.Fatalf("Expected error %q, got %q", ErrAuthExpired, e)
} else if i > 0 && e != ErrAuthorization {
t.Fatalf("Expected error %q, got %q", ErrAuthorization, e)
}
default:
if i == 0 {
t.Fatalf("Missing %q error", ErrAuthExpired)
} else {
t.Fatalf("Missing %q error", ErrAuthorization)
}
}
}
// We should not have any more error
select {
case e := <-errCh:
t.Fatalf("Extra error: %v", e)
default:
}
// Close the listener and wait for go routine to end.
l.Close()
wg.Wait()
}

func TestExpiredUserCredentialsRenewal(t *testing.T) {
Expand Down

0 comments on commit 8f27558

Please sign in to comment.