From 98df5fc5e696bc1306b055dc736c4067942a2349 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ram=C3=B3n=20Berrutti?= Date: Thu, 23 Nov 2023 22:00:42 +0000 Subject: [PATCH] callout: try to renew jwt when expire --- server/auth_callout.go | 2 +- server/auth_callout_test.go | 80 ++++++++++++++++++++++++------------- server/client.go | 71 ++++++++++++++++++++++++++++++-- 3 files changed, 122 insertions(+), 31 deletions(-) diff --git a/server/auth_callout.go b/server/auth_callout.go index 359914c633..79a3947247 100644 --- a/server/auth_callout.go +++ b/server/auth_callout.go @@ -289,7 +289,7 @@ func (s *Server) processClientOrLeafCallout(c *client, opts *Options) (authorize } // Check if we need to set an auth timer if the user jwt expires. - c.setExpiration(arc.Claims(), expiration) + c.setRenewal(arc.Claims(), expiration) respCh <- _EMPTY_ } diff --git a/server/auth_callout_test.go b/server/auth_callout_test.go index ba77d8237d..5ee8ac8561 100644 --- a/server/auth_callout_test.go +++ b/server/auth_callout_test.go @@ -207,18 +207,23 @@ func TestAuthCalloutBasics(t *testing.T) { } ` callouts := uint32(0) + waitForCallout := make(chan struct{}) handler := func(m *nats.Msg) { - atomic.AddUint32(&callouts, 1) + calls := atomic.AddUint32(&callouts, 1) user, si, ci, opts, _ := decodeAuthRequest(t, m.Data) require_True(t, si.Name == "A") require_True(t, ci.Host == "127.0.0.1") // Allow dlc user. - if opts.Username == "dlc" && opts.Password == "zzz" { + if calls <= 3 && opts.Username == "dlc" && opts.Password == "zzz" { var j jwt.UserPermissionLimits j.Pub.Allow.Add("$SYS.>") + if calls == 3 { + j.Pub.Allow.Add("ramon.>") + } j.Payload = 1024 - ujwt := createAuthUser(t, user, _EMPTY_, globalAccountName, "", nil, 10*time.Minute, &j) + ujwt := createAuthUser(t, user, _EMPTY_, globalAccountName, "", nil, 10*time.Second, &j) m.Respond(serviceResponse(t, user, si.ID, ujwt, "", 0)) + waitForCallout <- struct{}{} } else { // Nil response signals no authentication. m.Respond(nil) @@ -233,33 +238,54 @@ func TestAuthCalloutBasics(t *testing.T) { // This one will use callout since not defined in server config. nc := at.Connect(nats.UserInfo("dlc", "zzz")) - resp, err := nc.Request(userDirectInfoSubj, nil, time.Second) - require_NoError(t, err) - response := ServerAPIResponse{Data: &UserInfo{}} - err = json.Unmarshal(resp.Data, &response) - require_NoError(t, err) - - userInfo := response.Data.(*UserInfo) + compareUserInfo := func(perm ...string) { + time.Sleep(100 * time.Millisecond) + resp, err := nc.Request(userDirectInfoSubj, nil, time.Second) + require_NoError(t, err) + response := ServerAPIResponse{Data: &UserInfo{}} + err = json.Unmarshal(resp.Data, &response) + require_NoError(t, err) - dlc := &UserInfo{ - UserID: "dlc", - Account: globalAccountName, - Permissions: &Permissions{ - Publish: &SubjectPermission{ - Allow: []string{"$SYS.>"}, - Deny: []string{AuthCalloutSubject}, // Will be auto-added since in auth account. + userInfo := response.Data.(*UserInfo) + + dlc := &UserInfo{ + UserID: "dlc", + Account: globalAccountName, + Permissions: &Permissions{ + Publish: &SubjectPermission{ + Allow: append([]string{"$SYS.>"}, perm...), + Deny: []string{AuthCalloutSubject}, // Will be auto-added since in auth account. + }, + Subscribe: &SubjectPermission{}, }, - Subscribe: &SubjectPermission{}, - }, - } - expires := userInfo.Expires - userInfo.Expires = 0 - if !reflect.DeepEqual(dlc, userInfo) { - t.Fatalf("User info for %q did not match", "dlc") - } - if expires > 10*time.Minute || expires < (10*time.Minute-5*time.Second) { - t.Fatalf("Expected expires of ~%v, got %v", 10*time.Minute, expires) + } + expires := userInfo.Expires + userInfo.Expires = 0 + if !reflect.DeepEqual(dlc, userInfo) { + dlcJson, _ := json.MarshalIndent(dlc, "", " ") + userInfoJson, _ := json.MarshalIndent(userInfo, "", " ") + t.Fatalf("User info for %q did not match %s %s", "dlc", dlcJson, userInfoJson) + } + if expires > 10*time.Second || expires < (10*time.Second-5*time.Second) { + t.Fatalf("Expected expires of ~%v, got %v", 10*time.Second, expires) + } } + + <-waitForCallout + compareUserInfo() + + // Wait for a second valid callout with a new permission. + <-waitForCallout + compareUserInfo("ramon.>") + + disconnected := make(chan struct{}) + nc.SetErrorHandler(func(_ *nats.Conn, _ *nats.Subscription, err error) { + if err != nats.ErrAuthExpired { + t.Fatalf("Expected %v, got %v", nats.ErrAuthExpired, err) + } + close(disconnected) + }) + <-disconnected } func TestAuthCalloutMultiAccounts(t *testing.T) { diff --git a/server/client.go b/server/client.go index 72346e6958..0188ef8bbd 100644 --- a/server/client.go +++ b/server/client.go @@ -254,6 +254,7 @@ type client struct { darray []string pcd map[*client]struct{} atmr *time.Timer + rtmr *time.Timer // renew timer. expires time.Time ping pinfo msgb [msgScratchSize]byte @@ -1144,9 +1145,17 @@ func (c *client) mergeDenyPermissionsLocked(what denyType, denyPubs []string) { // Check to see if we have an expiration for the user JWT via base claims. // FIXME(dlc) - Clear on connect with new JWT. func (c *client) setExpiration(claims *jwt.ClaimsData, validFor time.Duration) { + c.setTimer(claims, validFor, c.setExpirationTimer) +} + +func (c *client) setRenewal(claims *jwt.ClaimsData, validFor time.Duration) { + c.setTimer(claims, validFor, c.setRenewalTimer) +} + +func (c *client) setTimer(claims *jwt.ClaimsData, validFor time.Duration, f func(time.Duration)) { if claims.Expires == 0 { if validFor != 0 { - c.setExpirationTimer(validFor) + f(validFor) } return } @@ -1156,9 +1165,9 @@ func (c *client) setExpiration(claims *jwt.ClaimsData, validFor time.Duration) { expiresAt = time.Duration(claims.Expires-tn) * time.Second } if validFor != 0 && validFor < expiresAt { - c.setExpirationTimer(validFor) + f(validFor) } else { - c.setExpirationTimer(expiresAt) + f(expiresAt) } } @@ -4872,6 +4881,10 @@ func (c *client) clearTlsToTimer() { // Lock should be held func (c *client) setAuthTimer(d time.Duration) { + if c.atmr != nil { + c.atmr.Stop() + } + c.atmr = time.AfterFunc(d, c.authTimeout) } @@ -4885,6 +4898,16 @@ func (c *client) clearAuthTimer() bool { return stopped } +// Lock should be held +func (c *client) clearRenewTimer() bool { + if c.rtmr == nil { + return true + } + stopped := c.rtmr.Stop() + c.rtmr = nil + return stopped +} + // We may reuse atmr for expiring user jwts, // so check connectReceived. // Lock assume held on entry. @@ -4902,6 +4925,10 @@ func (c *client) setExpirationTimer(d time.Duration) { // This will set the atmr for the JWT expiration time. client lock should be held before call func (c *client) setExpirationTimerUnlocked(d time.Duration) { + // Stop any previous timer + if c.atmr != nil { + c.atmr.Stop() + } c.atmr = time.AfterFunc(d, c.authExpired) // This is an JWT expiration. if c.flags.isSet(connectReceived) { @@ -4909,6 +4936,43 @@ func (c *client) setExpirationTimerUnlocked(d time.Duration) { } } +func (c *client) setRenewalTimer(d time.Duration) { + c.mu.Lock() + c.setRenewalTimerUnlocked(d) + c.mu.Unlock() +} + +func (c *client) setRenewalTimerUnlocked(d time.Duration) { + // Stop any previous timer + if c.rtmr != nil { + c.rtmr.Stop() + } + c.rtmr = time.AfterFunc(d, c.renewCallout) + // This is an JWT expiration. + if c.flags.isSet(connectReceived) { + c.expires = time.Now().Add(d).Truncate(time.Second) + } +} + +func (c *client) renewCallout() { + c.mu.Lock() + srv := c.srv + c.mu.Unlock() + if srv == nil { + return + } + + authorized, _ := srv.processClientOrLeafCallout(c, srv.getOpts()) + // If we are authorized, we will set the renewal timer again. + // Deny the Callout subject. + if authorized { + c.mergeDenyPermissionsLocked(pub, []string{AuthCalloutSubject}) + } else { + // If we are not authorized, we will close the connection in the expiration handler. + c.authExpired() + } +} + // Return when this client expires via a claim, or 0 if not set. func (c *client) claimExpiration() time.Duration { c.mu.Lock() @@ -5075,6 +5139,7 @@ func (c *client) closeConnection(reason ClosedState) { c.rref++ c.flags.set(closeConnection) c.clearAuthTimer() + c.clearRenewTimer() c.clearPingTimer() c.clearTlsToTimer() c.markConnAsClosed(reason)