diff --git a/server/accounts.go b/server/accounts.go index 8f62104b38a..5d37d5f89cf 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -21,6 +21,7 @@ import ( "hash/fnv" "hash/maphash" "io" + "io/fs" "math" "math/rand" "net/http" @@ -3993,17 +3994,19 @@ func (dr *DirAccResolver) Start(s *Server) error { dr.DirJWTStore.changed = func(pubKey string) { if v, ok := s.accounts.Load(pubKey); ok { if theJwt, err := dr.LoadAcc(pubKey); err != nil { - s.Errorf("update got error on load: %v", err) + s.Errorf("DirResolver - Update got error on load: %v", err) } else { acc := v.(*Account) if err = s.updateAccountWithClaimJWT(acc, theJwt); err != nil { - s.Errorf("update resulted in error %v", err) + s.Errorf("DirResolver - Update for account %q resulted in error %v", pubKey, err) } else { if _, jsa, err := acc.checkForJetStream(); err != nil { - s.Warnf("error checking for JetStream enabled error %v", err) + if !IsNatsErr(err, JSNotEnabledForAccountErr) { + s.Warnf("DirResolver - Error checking for JetStream support for account %q: %v", pubKey, err) + } } else if jsa == nil { if err = s.configJetStream(acc); err != nil { - s.Errorf("updated resulted in error when configuring JetStream %v", err) + s.Errorf("DirResolver - Error configuring JetStream for account %q: %v", pubKey, err) } } } @@ -4024,7 +4027,7 @@ func (dr *DirAccResolver) Start(s *Server) error { } else if len(tk) == accUpdateTokensOld { pubKey = tk[accUpdateAccIdxOld] } else { - s.Debugf("jwt update skipped due to bad subject %q", subj) + s.Debugf("DirResolver - jwt update skipped due to bad subject %q", subj) return } if claim, err := jwt.DecodeAccountClaims(string(msg)); err != nil { @@ -4074,8 +4077,15 @@ func (dr *DirAccResolver) Start(s *Server) error { if len(tk) != accLookupReqTokens { return } - if theJWT, err := dr.DirJWTStore.LoadAcc(tk[accReqAccIndex]); err != nil { - s.Errorf("Merging resulted in error: %v", err) + accName := tk[accReqAccIndex] + if theJWT, err := dr.DirJWTStore.LoadAcc(accName); err != nil { + if errors.Is(err, fs.ErrNotExist) { + s.Debugf("DirResolver - Could not find account %q", accName) + // Reply with empty response to signal absence of JWT to others. + s.sendInternalMsgLocked(reply, _EMPTY_, nil, nil) + } else { + s.Errorf("DirResolver - Error looking up account %q: %v", accName, err) + } } else { s.sendInternalMsgLocked(reply, _EMPTY_, nil, []byte(theJWT)) } @@ -4083,7 +4093,7 @@ func (dr *DirAccResolver) Start(s *Server) error { return fmt.Errorf("error setting up lookup request handling: %v", err) } // respond to pack requests with one or more pack messages - // an empty message signifies the end of the response responder + // an empty message signifies the end of the response responder. if _, err := s.sysSubscribeQ(accPackReqSubj, "responder", func(_ *subscription, _ *client, _ *Account, _, reply string, theirHash []byte) { if reply == _EMPTY_ { return @@ -4091,14 +4101,14 @@ func (dr *DirAccResolver) Start(s *Server) error { ourHash := dr.DirJWTStore.Hash() if bytes.Equal(theirHash, ourHash[:]) { s.sendInternalMsgLocked(reply, _EMPTY_, nil, []byte{}) - s.Debugf("pack request matches hash %x", ourHash[:]) + s.Debugf("DirResolver - Pack request matches hash %x", ourHash[:]) } else if err := dr.DirJWTStore.PackWalk(1, func(partialPackMsg string) { s.sendInternalMsgLocked(reply, _EMPTY_, nil, []byte(partialPackMsg)) }); err != nil { // let them timeout - s.Errorf("pack request error: %v", err) + s.Errorf("DirResolver - Pack request error: %v", err) } else { - s.Debugf("pack request hash %x - finished responding with hash %x", theirHash, ourHash) + s.Debugf("DirResolver - Pack request hash %x - finished responding with hash %x", theirHash, ourHash) s.sendInternalMsgLocked(reply, _EMPTY_, nil, []byte{}) } }); err != nil { @@ -4119,12 +4129,12 @@ func (dr *DirAccResolver) Start(s *Server) error { if _, err := s.sysSubscribe(packRespIb, func(_ *subscription, _ *client, _ *Account, _, _ string, msg []byte) { hash := dr.DirJWTStore.Hash() if len(msg) == 0 { // end of response stream - s.Debugf("Merging Finished and resulting in: %x", dr.DirJWTStore.Hash()) + s.Debugf("DirResolver - Merging finished and resulting in: %x", dr.DirJWTStore.Hash()) return } else if err := dr.DirJWTStore.Merge(string(msg)); err != nil { - s.Errorf("Merging resulted in error: %v", err) + s.Errorf("DirResolver - Merging resulted in error: %v", err) } else { - s.Debugf("Merging succeeded and changed %x to %x", hash, dr.DirJWTStore.Hash()) + s.Debugf("DirResolver - Merging succeeded and changed %x to %x", hash, dr.DirJWTStore.Hash()) } }); err != nil { return fmt.Errorf("error setting up pack response handling: %v", err) @@ -4142,7 +4152,7 @@ func (dr *DirAccResolver) Start(s *Server) error { case <-ticker.C: } ourHash := dr.DirJWTStore.Hash() - s.Debugf("Checking store state: %x", ourHash) + s.Debugf("DirResolver - Checking store state: %x", ourHash) s.sendInternalMsgLocked(accPackReqSubj, packRespIb, nil, ourHash[:]) } }) @@ -4227,20 +4237,35 @@ func (s *Server) fetch(res AccountResolver, name string, timeout time.Duration) s.mu.Unlock() return _EMPTY_, fmt.Errorf("eventing shut down") } + // Resolver will wait for detected active servers to reply + // before serving an error in case there weren't any found. + expectedServers := len(s.sys.servers) replySubj := s.newRespInbox() replies := s.sys.replies + // Store our handler. replies[replySubj] = func(sub *subscription, _ *client, _ *Account, subject, _ string, msg []byte) { - clone := make([]byte, len(msg)) - copy(clone, msg) + var clone []byte + isEmpty := len(msg) == 0 + if !isEmpty { + clone = make([]byte, len(msg)) + copy(clone, msg) + } s.mu.Lock() + defer s.mu.Unlock() + expectedServers-- + // Skip empty responses until getting all the available servers. + if isEmpty && expectedServers > 0 { + return + } + // Use the first valid response if there is still interest or + // one of the empty responses to signal that it was not found. if _, ok := replies[replySubj]; ok { select { - case respC <- clone: // only use first response and only if there is still interest + case respC <- clone: default: } } - s.mu.Unlock() } s.sendInternalMsg(accountLookupRequest, replySubj, nil, []byte{}) quit := s.quitCh @@ -4253,7 +4278,9 @@ func (s *Server) fetch(res AccountResolver, name string, timeout time.Duration) case <-time.After(timeout): err = errors.New("fetching jwt timed out") case m := <-respC: - if err = res.Store(name, string(m)); err == nil { + if len(m) == 0 { + err = errors.New("account jwt not found") + } else if err = res.Store(name, string(m)); err == nil { theJWT = string(m) } } @@ -4291,9 +4318,9 @@ func (dr *CacheDirAccResolver) Start(s *Server) error { dr.DirJWTStore.changed = func(pubKey string) { if v, ok := s.accounts.Load(pubKey); !ok { } else if theJwt, err := dr.LoadAcc(pubKey); err != nil { - s.Errorf("update got error on load: %v", err) + s.Errorf("DirResolver - Update got error on load: %v", err) } else if err := s.updateAccountWithClaimJWT(v.(*Account), theJwt); err != nil { - s.Errorf("update resulted in error %v", err) + s.Errorf("DirResolver - Update resulted in error %v", err) } } dr.DirJWTStore.deleted = func(pubKey string) { @@ -4309,7 +4336,7 @@ func (dr *CacheDirAccResolver) Start(s *Server) error { } else if len(tk) == accUpdateTokensOld { pubKey = tk[accUpdateAccIdxOld] } else { - s.Debugf("jwt update cache skipped due to bad subject %q", subj) + s.Debugf("DirResolver - jwt update cache skipped due to bad subject %q", subj) return } if claim, err := jwt.DecodeAccountClaims(string(msg)); err != nil { diff --git a/server/dirstore.go b/server/dirstore.go index cabd4a1997a..b39ab9ae082 100644 --- a/server/dirstore.go +++ b/server/dirstore.go @@ -288,6 +288,10 @@ func (store *DirJWTStore) PackWalk(maxJWTs int, cb func(partialPackMsg string)) if err != nil { return err } + if len(jwtBytes) == 0 { + // Skip if no contents in the JWT. + return nil + } if exp != nil { claim, err := jwt.DecodeGeneric(string(jwtBytes)) if err == nil && claim.Expires > 0 && claim.Expires < time.Now().Unix() { @@ -406,6 +410,9 @@ func (store *DirJWTStore) load(publicKey string) (string, error) { // write that keeps hash of all jwt in sync // Assumes the lock is held. Does return true or an error never both. func (store *DirJWTStore) write(path string, publicKey string, theJWT string) (bool, error) { + if len(theJWT) == 0 { + return false, fmt.Errorf("invalid JWT") + } var newHash *[sha256.Size]byte if store.expiration != nil { h := sha256.Sum256([]byte(theJWT)) diff --git a/server/jwt_test.go b/server/jwt_test.go index 1085892434c..58884b97c51 100644 --- a/server/jwt_test.go +++ b/server/jwt_test.go @@ -3692,7 +3692,7 @@ func TestJWTAccountNATSResolverCrossClusterFetch(t *testing.T) { listen: 127.0.0.1:-1 no_advertise: true } - `, ojwt, syspub, dirAA))) + `, ojwt, syspub, dirAA))) sAA, _ := RunServerWithConfig(confAA) defer sAA.Shutdown() // Create Server B (using no_advertise to prevent fail over) @@ -3718,7 +3718,7 @@ func TestJWTAccountNATSResolverCrossClusterFetch(t *testing.T) { nats-route://127.0.0.1:%d ] } - `, ojwt, syspub, dirAB, sAA.opts.Cluster.Port))) + `, ojwt, syspub, dirAB, sAA.opts.Cluster.Port))) sAB, _ := RunServerWithConfig(confAB) defer sAB.Shutdown() // Create Server C (using no_advertise to prevent fail over) @@ -3744,10 +3744,10 @@ func TestJWTAccountNATSResolverCrossClusterFetch(t *testing.T) { listen: 127.0.0.1:-1 no_advertise: true } - `, ojwt, syspub, dirBA, sAA.opts.Gateway.Port))) + `, ojwt, syspub, dirBA, sAA.opts.Gateway.Port))) sBA, _ := RunServerWithConfig(confBA) defer sBA.Shutdown() - // Create Sever BA (using no_advertise to prevent fail over) + // Create Server BA (using no_advertise to prevent fail over) confBB := createConfFile(t, []byte(fmt.Sprintf(` listen: 127.0.0.1:-1 server_name: srv-B-B @@ -3773,7 +3773,7 @@ func TestJWTAccountNATSResolverCrossClusterFetch(t *testing.T) { {name: "clust-A", url: "nats://127.0.0.1:%d"}, ] } - `, ojwt, syspub, dirBB, sBA.opts.Cluster.Port, sAA.opts.Cluster.Port))) + `, ojwt, syspub, dirBB, sBA.opts.Cluster.Port, sAA.opts.Cluster.Port))) sBB, _ := RunServerWithConfig(confBB) defer sBB.Shutdown() // Assert topology @@ -6592,3 +6592,190 @@ func TestServerOperatorModeNoAuthRequired(t *testing.T) { require_True(t, nc.AuthRequired()) } + +func TestJWTAccountNATSResolverWrongCreds(t *testing.T) { + require_NoLocalOrRemoteConnections := func(account string, srvs ...*Server) { + t.Helper() + for _, srv := range srvs { + if acc, ok := srv.accounts.Load(account); ok { + checkAccClientsCount(t, acc.(*Account), 0) + } + } + } + connect := func(url string, credsfile string, acc string, srvs ...*Server) { + t.Helper() + nc := natsConnect(t, url, nats.UserCredentials(credsfile), nats.Timeout(5*time.Second)) + nc.Close() + require_NoLocalOrRemoteConnections(acc, srvs...) + } + createAccountAndUser := func(limit bool, done chan struct{}, pubKey, jwt1, jwt2, creds *string) { + t.Helper() + kp, _ := nkeys.CreateAccount() + *pubKey, _ = kp.PublicKey() + claim := jwt.NewAccountClaims(*pubKey) + var err error + *jwt1, err = claim.Encode(oKp) + require_NoError(t, err) + *jwt2, err = claim.Encode(oKp) + require_NoError(t, err) + ukp, _ := nkeys.CreateUser() + seed, _ := ukp.Seed() + upub, _ := ukp.PublicKey() + uclaim := newJWTTestUserClaims() + uclaim.Subject = upub + ujwt, err := uclaim.Encode(kp) + require_NoError(t, err) + *creds = genCredsFile(t, ujwt, seed) + done <- struct{}{} + } + // Create Accounts and corresponding user creds. + doneChan := make(chan struct{}, 4) + defer close(doneChan) + var syspub, sysjwt, dummy1, sysCreds string + createAccountAndUser(false, doneChan, &syspub, &sysjwt, &dummy1, &sysCreds) + + var apub, ajwt1, ajwt2, aCreds string + createAccountAndUser(true, doneChan, &apub, &ajwt1, &ajwt2, &aCreds) + + var bpub, bjwt1, bjwt2, bCreds string + createAccountAndUser(true, doneChan, &bpub, &bjwt1, &bjwt2, &bCreds) + + // The one that is going to be missing. + var cpub, cjwt1, cjwt2, cCreds string + createAccountAndUser(true, doneChan, &cpub, &cjwt1, &cjwt2, &cCreds) + for i := 0; i < cap(doneChan); i++ { + <-doneChan + } + // Create one directory for each server + dirA := t.TempDir() + dirB := t.TempDir() + dirC := t.TempDir() + + // Store accounts on servers A and B, then let C sync on its own. + writeJWT(t, dirA, apub, ajwt1) + writeJWT(t, dirB, bpub, bjwt1) + + ///////////////////////////////////////// + // // + // Server A: has creds from client A // + // // + ///////////////////////////////////////// + confA := createConfFile(t, []byte(fmt.Sprintf(` + listen: 127.0.0.1:-1 + server_name: srv-A + operator: %s + system_account: %s + debug: true + resolver: { + type: full + dir: '%s' + allow_delete: true + timeout: "1.5s" + interval: "200ms" + } + resolver_preload: { + %s: %s + } + cluster { + name: clust + listen: 127.0.0.1:-1 + no_advertise: true + } + `, ojwt, syspub, dirA, apub, ajwt1))) + sA, _ := RunServerWithConfig(confA) + defer sA.Shutdown() + require_JWTPresent(t, dirA, apub) + + ///////////////////////////////////////// + // // + // Server B: has creds from client B // + // // + ///////////////////////////////////////// + confB := createConfFile(t, []byte(fmt.Sprintf(` + listen: 127.0.0.1:-1 + server_name: srv-B + operator: %s + system_account: %s + resolver: { + type: full + dir: '%s' + allow_delete: true + timeout: "1.5s" + interval: "200ms" + } + cluster { + name: clust + listen: 127.0.0.1:-1 + no_advertise: true + routes [ + nats-route://127.0.0.1:%d + ] + } + `, ojwt, syspub, dirB, sA.opts.Cluster.Port))) + sB, _ := RunServerWithConfig(confB) + defer sB.Shutdown() + + ///////////////////////////////////////// + // // + // Server C: has no creds // + // // + ///////////////////////////////////////// + fmtC := ` + listen: 127.0.0.1:-1 + server_name: srv-C + operator: %s + system_account: %s + resolver: { + type: full + dir: '%s' + allow_delete: true + timeout: "1.5s" + interval: "200ms" + } + cluster { + name: clust + listen: 127.0.0.1:-1 + no_advertise: true + routes [ + nats-route://127.0.0.1:%d + ] + } + ` + confClongTTL := createConfFile(t, []byte(fmt.Sprintf(fmtC, ojwt, syspub, dirC, sA.opts.Cluster.Port))) + sC, _ := RunServerWithConfig(confClongTTL) // use long ttl to assure it is not kicking + defer sC.Shutdown() + + // startup cluster + checkClusterFormed(t, sA, sB, sC) + time.Sleep(1 * time.Second) // wait for the protocol to converge + // // Check all accounts + require_JWTPresent(t, dirA, apub) // was already present on startup + require_JWTPresent(t, dirB, apub) // was copied from server A + require_JWTPresent(t, dirA, bpub) // was copied from server B + require_JWTPresent(t, dirB, bpub) // was already present on startup + + // There should be no state about the missing account. + require_JWTAbsent(t, dirA, cpub) + require_JWTAbsent(t, dirB, cpub) + require_JWTAbsent(t, dirC, cpub) + + // system account client can connect to every server + connect(sA.ClientURL(), sysCreds, "") + connect(sB.ClientURL(), sysCreds, "") + connect(sC.ClientURL(), sysCreds, "") + + // A and B clients can connect to any server. + connect(sA.ClientURL(), aCreds, "") + connect(sB.ClientURL(), aCreds, "") + connect(sC.ClientURL(), aCreds, "") + connect(sA.ClientURL(), bCreds, "") + connect(sB.ClientURL(), bCreds, "") + connect(sC.ClientURL(), bCreds, "") + + // Check that trying to connect with bad credentials should not hang until the fetch timeout + // and instead return a faster response when an account is not found. + _, err := nats.Connect(sC.ClientURL(), nats.UserCredentials(cCreds), nats.Timeout(500*time.Second)) + if err != nil && !errors.Is(err, nats.ErrAuthorization) { + t.Fatalf("Expected auth error: %v", err) + } +}