Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update handling of account authentication expired error #695

Merged
merged 1 commit into from
Mar 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions nats.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ const (

// AUTHENTICATION_REVOKED_ERR is for when user authorization has been revoked.
AUTHENTICATION_REVOKED_ERR = "user authentication revoked"

// ACCOUNT_AUTHENTICATION_EXPIRED_ERR is for when nats server account authorization has expired.
ACCOUNT_AUTHENTICATION_EXPIRED_ERR = "account authentication expired"
)

// Errors
Expand All @@ -98,6 +101,7 @@ var (
ErrAuthorization = errors.New("nats: authorization violation")
ErrAuthExpired = errors.New("nats: authentication expired")
ErrAuthRevoked = errors.New("nats: authentication revoked")
ErrAccountAuthExpired = errors.New("nats: account authentication expired")
ErrNoServers = errors.New("nats: no servers available for connection")
ErrJsonParse = errors.New("nats: connect message, json parse error")
ErrChanArg = errors.New("nats: argument needs to be a channel type")
Expand Down Expand Up @@ -2766,6 +2770,9 @@ func checkAuthError(e string) error {
if strings.HasPrefix(e, AUTHENTICATION_REVOKED_ERR) {
return ErrAuthRevoked
}
if strings.HasPrefix(e, ACCOUNT_AUTHENTICATION_EXPIRED_ERR) {
return ErrAccountAuthExpired
}
return nil
}

Expand Down
277 changes: 95 additions & 182 deletions nats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1436,8 +1436,8 @@ func TestUserCredentialsChainedFile(t *testing.T) {
nc.Close()
}

func TestExpiredUserCredentials(t *testing.T) {
// The goal of this test was to check how a client with an expiring JWT
func TestExpiredAuthentication(t *testing.T) {
// The goal of these tests was to check how a client with an expiring JWT
// behaves. It should receive an async -ERR indicating that the auth
// has expired, which will trigger reconnects. There, the lib should
// received -ERR for auth violation in response to the CONNECT (instead
Expand All @@ -1451,204 +1451,117 @@ func TestExpiredUserCredentials(t *testing.T) {
// So for a deterministic test, we won't use an actual NATS Server.
// Instead, we will use a mock that simply returns appropriate -ERR and
// ensure the client behaves as expected.
l, e := net.Listen("tcp", "127.0.0.1:0")
if e != nil {
t.Fatal("Could not listen on an ephemeral port")
}
tl := l.(*net.TCPListener)
defer tl.Close()

addr := tl.Addr().(*net.TCPAddr)

wg := sync.WaitGroup{}
wg.Add(1)

go func() {
defer wg.Done()
connect := 0
for {
conn, err := l.Accept()
if err != nil {
return
for _, test := range []struct {
name string
expectedProto string
expectedErr error
}{
{"expired users credentials", AUTHENTICATION_EXPIRED_ERR, ErrAuthExpired},
{"revoked users credentials", AUTHENTICATION_REVOKED_ERR, ErrAuthRevoked},
{"expired account", ACCOUNT_AUTHENTICATION_EXPIRED_ERR, ErrAccountAuthExpired},
} {
t.Run(test.name, func(t *testing.T) {
l, e := net.Listen("tcp", "127.0.0.1:0")
if e != nil {
t.Fatal("Could not listen on an ephemeral port")
}
defer conn.Close()
tl := l.(*net.TCPListener)
defer tl.Close()

info := "INFO {\"server_id\":\"foobar\",\"nonce\":\"anonce\"}\r\n"
conn.Write([]byte(info))

// Read connect and ping commands sent from the client
br := bufio.NewReaderSize(conn, 10*1024)
br.ReadLine()
br.ReadLine()

if connect++; connect == 1 {
conn.Write([]byte(fmt.Sprintf("%s%s", _PONG_OP_, _CRLF_)))
time.Sleep(300 * time.Millisecond)
conn.Write([]byte(fmt.Sprintf("-ERR '%s'\r\n", AUTHENTICATION_EXPIRED_ERR)))
} else {
conn.Write([]byte(fmt.Sprintf("-ERR '%s'\r\n", AUTHORIZATION_ERR)))
}
conn.Close()
}
}()
addr := tl.Addr().(*net.TCPAddr)

ch := make(chan bool)
errCh := make(chan error, 10)
wg := sync.WaitGroup{}
wg.Add(1)

url := fmt.Sprintf("nats://127.0.0.1:%d", addr.Port)
nc, err := Connect(url,
ReconnectWait(25*time.Millisecond),
ReconnectJitter(0, 0),
MaxReconnects(-1),
ErrorHandler(func(_ *Conn, _ *Subscription, e error) {
select {
case errCh <- e:
default:
}
}),
ClosedHandler(func(nc *Conn) {
ch <- true
}),
)
if err != nil {
t.Fatalf("Expected to connect, got %v", err)
}
defer nc.Close()
go func() {
defer wg.Done()
connect := 0
for {
conn, err := l.Accept()
if err != nil {
return
}
defer conn.Close()

// We should give up since we get the same error on both tries.
if err := WaitTime(ch, 2*time.Second); err != nil {
t.Fatal("Should have closed after multiple failed attempts.")
}
if stats := nc.Stats(); stats.Reconnects > 2 {
t.Fatalf("Expected at most 2 reconnects, got %d", stats.Reconnects)
}
// We expect 3 errors, an AUTHENTICATION_EXPIRED_ERR, then 2 AUTHORIZATION_ERR
// before the connection is closed.
for i := 0; i < 3; i++ {
select {
case e := <-errCh:
if i == 0 && e != ErrAuthExpired {
t.Fatalf("Expected error %q, got %q", ErrAuthExpired, e)
} else if i > 0 && e != ErrAuthorization {
t.Fatalf("Expected error %q, got %q", ErrAuthorization, e)
}
default:
if i == 0 {
t.Fatalf("Missing %q error", ErrAuthExpired)
} else {
t.Fatalf("Missing %q error", ErrAuthorization)
}
}
}
// We should not have any more error
select {
case e := <-errCh:
t.Fatalf("Extra error: %v", e)
default:
}
// Close the listener and wait for go routine to end.
l.Close()
wg.Wait()
}
info := "INFO {\"server_id\":\"foobar\",\"nonce\":\"anonce\"}\r\n"
conn.Write([]byte(info))

func TestRevokedUserCredentials(t *testing.T) {
// Mock that the client connects and then is revoked.
l, e := net.Listen("tcp", "127.0.0.1:0")
if e != nil {
t.Fatal("Could not listen on an ephemeral port")
}
tl := l.(*net.TCPListener)
defer tl.Close()
// Read connect and ping commands sent from the client
br := bufio.NewReaderSize(conn, 10*1024)
br.ReadLine()
br.ReadLine()

addr := tl.Addr().(*net.TCPAddr)
if connect++; connect == 1 {
conn.Write([]byte(fmt.Sprintf("%s%s", _PONG_OP_, _CRLF_)))
time.Sleep(300 * time.Millisecond)
conn.Write([]byte(fmt.Sprintf("-ERR '%s'\r\n", test.expectedProto)))
} else {
conn.Write([]byte(fmt.Sprintf("-ERR '%s'\r\n", AUTHORIZATION_ERR)))
}
conn.Close()
}
}()

wg := sync.WaitGroup{}
wg.Add(1)
ch := make(chan bool)
errCh := make(chan error, 10)

go func() {
defer wg.Done()
connect := 0
for {
conn, err := l.Accept()
url := fmt.Sprintf("nats://127.0.0.1:%d", addr.Port)
nc, err := Connect(url,
ReconnectWait(25*time.Millisecond),
ReconnectJitter(0, 0),
MaxReconnects(-1),
ErrorHandler(func(_ *Conn, _ *Subscription, e error) {
select {
case errCh <- e:
default:
}
}),
ClosedHandler(func(nc *Conn) {
ch <- true
}),
)
if err != nil {
return
t.Fatalf("Expected to connect, got %v", err)
}
defer conn.Close()

info := "INFO {\"server_id\":\"foobar\",\"nonce\":\"anonce\"}\r\n"
conn.Write([]byte(info))

// Read connect and ping commands sent from the client
br := bufio.NewReaderSize(conn, 10*1024)
br.ReadLine()
br.ReadLine()
defer nc.Close()

if connect++; connect == 1 {
conn.Write([]byte(fmt.Sprintf("%s%s", _PONG_OP_, _CRLF_)))
time.Sleep(300 * time.Millisecond)
conn.Write([]byte(fmt.Sprintf("-ERR '%s'\r\n", AUTHENTICATION_REVOKED_ERR)))
} else {
conn.Write([]byte(fmt.Sprintf("-ERR '%s'\r\n", AUTHORIZATION_ERR)))
// We should give up since we get the same error on both tries.
if err := WaitTime(ch, 2*time.Second); err != nil {
t.Fatal("Should have closed after multiple failed attempts.")
}
conn.Close()
}
}()

ch := make(chan bool)
errCh := make(chan error, 10)

url := fmt.Sprintf("nats://127.0.0.1:%d", addr.Port)
nc, err := Connect(url,
ReconnectWait(25*time.Millisecond),
ReconnectJitter(0, 0),
MaxReconnects(-1),
ErrorHandler(func(_ *Conn, _ *Subscription, e error) {
select {
case errCh <- e:
default:
if stats := nc.Stats(); stats.Reconnects > 2 {
t.Fatalf("Expected at most 2 reconnects, got %d", stats.Reconnects)
}
}),
ClosedHandler(func(nc *Conn) {
ch <- true
}),
)
if err != nil {
t.Fatalf("Expected to connect, got %v", err)
}
defer nc.Close()

// We should give up since we get the same error on both tries.
if err := WaitTime(ch, 2*time.Second); err != nil {
t.Fatal("Should have closed after multiple failed attempts.")
}
if stats := nc.Stats(); stats.Reconnects > 2 {
t.Fatalf("Expected at most 2 reconnects, got %d", stats.Reconnects)
}
for i := 0; i < 3; i++ {
select {
case e := <-errCh:
if i == 0 && e != ErrAuthRevoked {
t.Fatalf("Expected error %q, got %q", ErrAuthRevoked, e)
} else if i > 0 && e != ErrAuthorization {
t.Fatalf("Expected error %q, got %q", ErrAuthorization, e)
// We expect 3 errors, the expired auth/revoke error, then 2 AUTHORIZATION_ERR
// before the connection is closed.
for i := 0; i < 3; i++ {
select {
case e := <-errCh:
if i == 0 && e != test.expectedErr {
t.Fatalf("Expected error %q, got %q", test.expectedErr, e)
} else if i > 0 && e != ErrAuthorization {
t.Fatalf("Expected error %q, got %q", ErrAuthorization, e)
}
default:
if i == 0 {
t.Fatalf("Missing %q error", test.expectedErr)
} else {
t.Fatalf("Missing %q error", ErrAuthorization)
}
}
}
default:
if i == 0 {
t.Fatalf("Missing %q error", ErrAuthRevoked)
} else {
t.Fatalf("Missing %q error", ErrAuthorization)
// We should not have any more error
select {
case e := <-errCh:
t.Fatalf("Extra error: %v", e)
default:
}
}
}
// We should not have any more error
select {
case e := <-errCh:
t.Fatalf("Extra error: %v", e)
default:
// Close the listener and wait for go routine to end.
l.Close()
wg.Wait()
})
}
// Close the listener and wait for go routine to end.
l.Close()
wg.Wait()
}

// If we are using TLS and have multiple servers we try to match the IP
Expand Down