From 3608bcfd1c2fa1299af1d98c100557bf148cefdc Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Wed, 14 Feb 2024 10:13:22 -0700 Subject: [PATCH] Made updates to message tracing code to support JWT updates Related to #5014 Signed-off-by: Ivan Kozlovic --- go.mod | 2 +- go.sum | 10 +- server/accounts.go | 38 +++- server/client.go | 2 +- server/msgtrace_test.go | 491 ++++++++++++++++++++++++++++++++-------- 5 files changed, 428 insertions(+), 115 deletions(-) diff --git a/go.mod b/go.mod index 996548dcacb..0b07a34f690 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.20 require ( github.com/klauspost/compress v1.17.6 github.com/minio/highwayhash v1.0.2 - github.com/nats-io/jwt/v2 v2.5.3 + github.com/nats-io/jwt/v2 v2.5.4-0.20240214164243-40f01dce329c github.com/nats-io/nats.go v1.32.0 github.com/nats-io/nkeys v0.4.7 github.com/nats-io/nuid v1.0.1 diff --git a/go.sum b/go.sum index 010a129d499..24248295128 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/klauspost/compress v1.17.5 h1:d4vBd+7CHydUqpFBgUEKkSdtSugf9YFmSkvUYPquI5E= -github.com/klauspost/compress v1.17.5/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI= github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= -github.com/nats-io/jwt/v2 v2.5.3 h1:/9SWvzc6hTfamcgXJ3uYRpgj+QuY2aLNqRiqrKcrpEo= -github.com/nats-io/jwt/v2 v2.5.3/go.mod h1:iysuPemFcc7p4IoYots3IuELSI4EDe9Y0bQMe+I3Bf4= +github.com/nats-io/jwt/v2 v2.5.4-0.20240214164243-40f01dce329c h1:MxTE8IDAMTeHunOmrm4DojsMiZfOO+5ovBD+Ca3Guy0= +github.com/nats-io/jwt/v2 v2.5.4-0.20240214164243-40f01dce329c/go.mod h1:9d5GwImcMyYc5qCEt6N3ebkyviwwVBssCnHz9yRqPCM= github.com/nats-io/nats.go v1.32.0 h1:Bx9BZS+aXYlxW08k8Gd3yR2s73pV5XSoAQUyp1Kwvp0= github.com/nats-io/nats.go v1.32.0/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8= github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI= @@ -18,13 +16,9 @@ github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4 github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8= go.uber.org/automaxprocs v1.5.3/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0= -golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= -golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/sys v0.0.0-20190130150945-aca44879d564/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= -golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= diff --git a/server/accounts.go b/server/accounts.go index ddcb87f74ee..19b0e247cb1 100644 --- a/server/accounts.go +++ b/server/accounts.go @@ -157,6 +157,7 @@ type serviceImport struct { share bool tracking bool didDeliver bool + atrc bool // allow trace (got from service export) trackingHdr http.Header // header from request } @@ -1908,11 +1909,13 @@ func (a *Account) addServiceImport(dest *Account, from, to string, claim *jwt.Im return nil, ErrMissingAccount } + var atrc bool dest.mu.RLock() se := dest.getServiceExport(to) if se != nil { rt = se.respType lat = se.latency + atrc = se.atrc } dest.mu.RUnlock() @@ -1956,7 +1959,7 @@ func (a *Account) addServiceImport(dest *Account, from, to string, claim *jwt.Im if claim != nil { share = claim.Share } - si := &serviceImport{dest, claim, se, nil, from, to, tr, 0, rt, lat, nil, nil, usePub, false, false, share, false, false, nil} + si := &serviceImport{dest, claim, se, nil, from, to, tr, 0, rt, lat, nil, nil, usePub, false, false, share, false, false, atrc, nil} a.imports.services[from] = si a.mu.Unlock() @@ -2421,7 +2424,7 @@ func (a *Account) addRespServiceImport(dest *Account, to string, osi *serviceImp // dest is the requestor's account. a is the service responder with the export. // Marked as internal here, that is how we distinguish. - si := &serviceImport{dest, nil, osi.se, nil, nrr, to, nil, 0, rt, nil, nil, nil, false, true, false, osi.share, false, false, nil} + si := &serviceImport{dest, nil, osi.se, nil, nrr, to, nil, 0, rt, nil, nil, nil, false, true, false, osi.share, false, false, false, nil} if a.exports.responses == nil { a.exports.responses = make(map[string]*serviceImport) @@ -2531,10 +2534,9 @@ func (a *Account) addMappedStreamImportWithClaim(account *Account, from, to stri a.mu.Unlock() return ErrStreamImportDuplicate } - // TODO(ik): When AllowTrace is added to JWT, uncomment those lines: - // if imClaim != nil { - // allowTrace = imClaim.AllowTrace - // } + if imClaim != nil { + allowTrace = imClaim.AllowTrace + } a.imports.streams = append(a.imports.streams, &streamImport{account, from, to, tr, nil, imClaim, usePub, false, allowTrace}) a.mu.Unlock() return nil @@ -2869,7 +2871,9 @@ func (a *Account) checkStreamImportsEqual(b *Account) bool { bm[bim.acc.Name+bim.from+bim.to] = bim } for _, aim := range a.imports.streams { - if _, ok := bm[aim.acc.Name+aim.from+aim.to]; !ok { + if bim, ok := bm[aim.acc.Name+aim.from+aim.to]; !ok { + return false + } else if aim.atrc != bim.atrc { return false } } @@ -2955,6 +2959,9 @@ func isServiceExportEqual(a, b *serviceExport) bool { return false } } + if a.atrc != b.atrc { + return false + } return true } @@ -3232,6 +3239,9 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim a.nameTag = ac.Name a.tags = ac.Tags + // Update TraceDest + a.traceDest = string(ac.TraceDest) + // Check for external authorization. if ac.HasExternalAuthorization() { a.extAuth = &jwt.ExternalAuthorization{} @@ -3360,6 +3370,9 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim s.Debugf("Error adding service export response threshold for [%s]: %v", a.traceLabel(), err) } } + if err := a.SetServiceExportAllowTrace(sub, e.AllowTrace); err != nil { + s.Debugf("Error adding allow_trace for %q: %v", sub, err) + } } var revocationChanged *bool @@ -3500,10 +3513,15 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim if si != nil && si.acc.Name == a.Name { // Check for if we are still authorized for an import. si.invalid = !a.checkServiceImportAuthorized(acc, si.to, si.claim) - if si.latency != nil && !si.response { - // Make sure we should still be tracking latency. + // Make sure we should still be tracking latency and if we + // are allowed to trace. + if !si.response { if se := a.getServiceExport(si.to); se != nil { - si.latency = se.latency + if si.latency != nil { + si.latency = se.latency + } + // Update allow trace. + si.atrc = se.atrc } } } diff --git a/server/client.go b/server/client.go index 1938096e668..09bacceaadc 100644 --- a/server/client.go +++ b/server/client.go @@ -4198,7 +4198,7 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt } } siAcc := si.acc - allowTrace := si.se != nil && si.se.atrc + allowTrace := si.atrc acc.mu.RUnlock() // We have a special case where JetStream pulls in all service imports through one export. diff --git a/server/msgtrace_test.go b/server/msgtrace_test.go index b6ced07e0f0..a06d8da560d 100644 --- a/server/msgtrace_test.go +++ b/server/msgtrace_test.go @@ -26,7 +26,9 @@ import ( "time" "github.com/klauspost/compress/s2" + "github.com/nats-io/jwt/v2" "github.com/nats-io/nats.go" + "github.com/nats-io/nkeys" ) func init() { @@ -1697,71 +1699,74 @@ func TestMsgTraceWithGatewayToOldServer(t *testing.T) { } func TestMsgTraceServiceImport(t *testing.T) { - for _, mainTest := range []struct { - name string - allow bool - }{ - {"allowed", true}, - {"not allowed", false}, - } { - t.Run(mainTest.name, func(t *testing.T) { - conf := createConfFile(t, []byte(fmt.Sprintf(` - listen: 127.0.0.1:-1 - accounts { - A { - users: [{user: a, password: pwd}] - exports: [ { service: ">", allow_trace: %v} ] - mappings = { - bar: bozo - } - } - B { - users: [{user: b, password: pwd}] - imports: [ { service: {account: "A", subject:">"} } ] - exports: [ { service: ">", allow_trace: %v} ] - } - C { - users: [{user: c, password: pwd}] - exports: [ { service: ">", allow_trace: %v } ] + tmpl := ` + listen: 127.0.0.1:-1 + accounts { + A { + users: [{user: a, password: pwd}] + exports: [ { service: ">", allow_trace: %v} ] + mappings = { + bar: bozo } - D { - users: [{user: d, password: pwd}] - imports: [ - { service: {account: "B", subject:"bar"}, to: baz } - { service: {account: "C", subject:">"} } - ] - mappings = { - bat: baz - } + } + B { + users: [{user: b, password: pwd}] + imports: [ { service: {account: "A", subject:">"} } ] + exports: [ { service: ">", allow_trace: %v} ] + } + C { + users: [{user: c, password: pwd}] + exports: [ { service: ">", allow_trace: %v } ] + } + D { + users: [{user: d, password: pwd}] + imports: [ + { service: {account: "B", subject:"bar"}, to: baz } + { service: {account: "C", subject:">"} } + ] + mappings = { + bat: baz } } - `, mainTest.allow, mainTest.allow, mainTest.allow))) - s, _ := RunServerWithConfig(conf) - defer s.Shutdown() + } + ` + conf := createConfFile(t, []byte(fmt.Sprintf(tmpl, false, false, false))) + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() - nc := natsConnect(t, s.ClientURL(), nats.UserInfo("d", "pwd"), nats.Name("Requestor")) - defer nc.Close() + nc := natsConnect(t, s.ClientURL(), nats.UserInfo("d", "pwd"), nats.Name("Requestor")) + defer nc.Close() - traceSub := natsSubSync(t, nc, "my.trace.subj") - sub := natsSubSync(t, nc, "my.service.response.inbox") + traceSub := natsSubSync(t, nc, "my.trace.subj") + sub := natsSubSync(t, nc, "my.service.response.inbox") - nc2 := natsConnect(t, s.ClientURL(), nats.UserInfo("a", "pwd"), nats.Name("ServiceA")) - defer nc2.Close() - recv := int32(0) - natsQueueSub(t, nc2, "*", "my_queue", func(m *nats.Msg) { - atomic.AddInt32(&recv, 1) - m.Respond(m.Data) - }) - natsFlush(t, nc2) + nc2 := natsConnect(t, s.ClientURL(), nats.UserInfo("a", "pwd"), nats.Name("ServiceA")) + defer nc2.Close() + recv := int32(0) + natsQueueSub(t, nc2, "*", "my_queue", func(m *nats.Msg) { + atomic.AddInt32(&recv, 1) + m.Respond(m.Data) + }) + natsFlush(t, nc2) - nc3 := natsConnect(t, s.ClientURL(), nats.UserInfo("c", "pwd"), nats.Name("ServiceC")) - defer nc3.Close() - natsSub(t, nc3, "baz", func(m *nats.Msg) { - atomic.AddInt32(&recv, 1) - m.Respond(m.Data) - }) - natsFlush(t, nc3) + nc3 := natsConnect(t, s.ClientURL(), nats.UserInfo("c", "pwd"), nats.Name("ServiceC")) + defer nc3.Close() + natsSub(t, nc3, "baz", func(m *nats.Msg) { + atomic.AddInt32(&recv, 1) + m.Respond(m.Data) + }) + natsFlush(t, nc3) + for mainIter, mainTest := range []struct { + name string + allow bool + }{ + {"not allowed", false}, + {"allowed", true}, + {"not allowed again", false}, + } { + atomic.StoreInt32(&recv, 0) + t.Run(mainTest.name, func(t *testing.T) { for _, test := range []struct { name string deliverMsg bool @@ -1890,6 +1895,12 @@ func TestMsgTraceServiceImport(t *testing.T) { } }) } + switch mainIter { + case 0: + reloadUpdateConfig(t, s, conf, fmt.Sprintf(tmpl, true, true, true)) + case 1: + reloadUpdateConfig(t, s, conf, fmt.Sprintf(tmpl, false, false, false)) + } }) } } @@ -2596,50 +2607,52 @@ func TestMsgTraceServiceImportWithLeafNodeLeaf(t *testing.T) { } func TestMsgTraceStreamExport(t *testing.T) { - for _, mainTest := range []struct { + tmpl := ` + listen: 127.0.0.1:-1 + accounts { + A { + users: [{user: a, password: pwd}] + exports: [ + { stream: "info.*.*.>"} + ] + } + B { + users: [{user: b, password: pwd}] + imports: [ { stream: {account: "A", subject:"info.*.*.>"}, to: "B.info.$2.$1.>", allow_trace: %v } ] + } + C { + users: [{user: c, password: pwd}] + imports: [ { stream: {account: "A", subject:"info.*.*.>"}, to: "C.info.$1.$2.>", allow_trace: %v } ] + } + } + ` + conf := createConfFile(t, []byte(fmt.Sprintf(tmpl, false, false))) + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + nc := natsConnect(t, s.ClientURL(), nats.UserInfo("a", "pwd"), nats.Name("Tracer")) + defer nc.Close() + traceSub := natsSubSync(t, nc, "my.trace.subj") + + nc2 := natsConnect(t, s.ClientURL(), nats.UserInfo("b", "pwd"), nats.Name("sub1")) + defer nc2.Close() + sub1 := natsSubSync(t, nc2, "B.info.*.*.>") + natsFlush(t, nc2) + + nc3 := natsConnect(t, s.ClientURL(), nats.UserInfo("c", "pwd"), nats.Name("sub2")) + defer nc3.Close() + sub2 := natsQueueSubSync(t, nc3, "C.info.>", "my_queue") + natsFlush(t, nc3) + + for mainIter, mainTest := range []struct { name string allow bool }{ - {"allowed", true}, {"not allowed", false}, + {"allowed", true}, + {"not allowed again", false}, } { t.Run(mainTest.name, func(t *testing.T) { - conf := createConfFile(t, []byte(fmt.Sprintf(` - listen: 127.0.0.1:-1 - accounts { - A { - users: [{user: a, password: pwd}] - exports: [ - { stream: "info.*.*.>"} - ] - } - B { - users: [{user: b, password: pwd}] - imports: [ { stream: {account: "A", subject:"info.*.*.>"}, to: "B.info.$2.$1.>", allow_trace: %v } ] - } - C { - users: [{user: c, password: pwd}] - imports: [ { stream: {account: "A", subject:"info.*.*.>"}, to: "C.info.$1.$2.>", allow_trace: %v } ] - } - } - `, mainTest.allow, mainTest.allow))) - s, _ := RunServerWithConfig(conf) - defer s.Shutdown() - - nc := natsConnect(t, s.ClientURL(), nats.UserInfo("a", "pwd"), nats.Name("Tracer")) - defer nc.Close() - traceSub := natsSubSync(t, nc, "my.trace.subj") - - nc2 := natsConnect(t, s.ClientURL(), nats.UserInfo("b", "pwd"), nats.Name("sub1")) - defer nc2.Close() - sub1 := natsSubSync(t, nc2, "B.info.*.*.>") - natsFlush(t, nc2) - - nc3 := natsConnect(t, s.ClientURL(), nats.UserInfo("c", "pwd"), nats.Name("sub2")) - defer nc3.Close() - sub2 := natsQueueSubSync(t, nc3, "C.info.>", "my_queue") - natsFlush(t, nc3) - for _, test := range []struct { name string deliverMsg bool @@ -2720,6 +2733,12 @@ func TestMsgTraceStreamExport(t *testing.T) { } }) } + switch mainIter { + case 0: + reloadUpdateConfig(t, s, conf, fmt.Sprintf(tmpl, true, true)) + case 1: + reloadUpdateConfig(t, s, conf, fmt.Sprintf(tmpl, false, false)) + } }) } } @@ -4597,3 +4616,285 @@ func TestMsgTraceTriggeredByExternalHeader(t *testing.T) { }) } } + +func TestMsgTraceAccountTraceDestJWTUpdate(t *testing.T) { + // create system account + sysKp, _ := nkeys.CreateAccount() + sysPub, _ := sysKp.PublicKey() + sysCreds := newUser(t, sysKp) + // create account A + akp, _ := nkeys.CreateAccount() + aPub, _ := akp.PublicKey() + claim := jwt.NewAccountClaims(aPub) + aJwt, err := claim.Encode(oKp) + require_NoError(t, err) + + dir := t.TempDir() + conf := createConfFile(t, []byte(fmt.Sprintf(` + listen: -1 + operator: %s + resolver: { + type: full + dir: '%s' + } + system_account: %s + `, ojwt, dir, sysPub))) + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + updateJwt(t, s.ClientURL(), sysCreds, aJwt, 1) + + nc := natsConnect(t, s.ClientURL(), createUserCreds(t, nil, akp)) + defer nc.Close() + + sub := natsSubSync(t, nc, "acc.trace.dest") + natsFlush(t, nc) + + for i, test := range []struct { + name string + traceTriggered bool + }{ + {"no acc dest", false}, + {"adding trace dest", true}, + {"removing trace dest", false}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("foo") + msg.Header.Set(traceParentHdr, "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01") + msg.Data = []byte("hello") + err = nc.PublishMsg(msg) + require_NoError(t, err) + + if test.traceTriggered { + tm := natsNexMsg(t, sub, time.Second) + var e MsgTraceEvent + err = json.Unmarshal(tm.Data, &e) + require_NoError(t, err) + // Simple check + require_Equal[string](t, e.Server.Name, s.Name()) + } + // No (more) trace message expected. + tm, err := sub.NextMsg(250 * time.Millisecond) + if err != nats.ErrTimeout { + t.Fatalf("Expected no trace message, got %s", tm.Data) + } + if i < 2 { + if i == 0 { + claim.TraceDest = "acc.trace.dest" + } else { + claim.TraceDest = _EMPTY_ + } + aJwt, err = claim.Encode(oKp) + require_NoError(t, err) + updateJwt(t, s.ClientURL(), sysCreds, aJwt, 1) + } + }) + } +} + +func TestMsgTraceServiceJWTUpdate(t *testing.T) { + // create system account + sysKp, _ := nkeys.CreateAccount() + sysPub, _ := sysKp.PublicKey() + sysCreds := newUser(t, sysKp) + // create account A + akp, _ := nkeys.CreateAccount() + aPub, _ := akp.PublicKey() + aClaim := jwt.NewAccountClaims(aPub) + serviceExport := &jwt.Export{Subject: "req", Type: jwt.Service} + aClaim.Exports.Add(serviceExport) + aJwt, err := aClaim.Encode(oKp) + require_NoError(t, err) + // create account B + bkp, _ := nkeys.CreateAccount() + bPub, _ := bkp.PublicKey() + bClaim := jwt.NewAccountClaims(bPub) + serviceImport := &jwt.Import{Account: aPub, Subject: "req", Type: jwt.Service} + bClaim.Imports.Add(serviceImport) + bJwt, err := bClaim.Encode(oKp) + require_NoError(t, err) + + dir := t.TempDir() + conf := createConfFile(t, []byte(fmt.Sprintf(` + listen: -1 + operator: %s + resolver: { + type: full + dir: '%s' + } + system_account: %s + `, ojwt, dir, sysPub))) + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + updateJwt(t, s.ClientURL(), sysCreds, aJwt, 1) + updateJwt(t, s.ClientURL(), sysCreds, bJwt, 1) + + ncA := natsConnect(t, s.ClientURL(), createUserCreds(t, nil, akp), nats.Name("Service")) + defer ncA.Close() + + natsSub(t, ncA, "req", func(m *nats.Msg) { + m.Respond([]byte("resp")) + }) + natsFlush(t, ncA) + + ncB := natsConnect(t, s.ClientURL(), createUserCreds(t, nil, bkp)) + defer ncB.Close() + + sub := natsSubSync(t, ncB, "trace.dest") + natsFlush(t, ncB) + + for i, test := range []struct { + name string + allowTrace bool + }{ + {"trace not allowed", false}, + {"trace allowed", true}, + {"trace not allowed again", false}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("req") + msg.Header.Set(MsgTraceDest, sub.Subject) + msg.Data = []byte("request") + reply, err := ncB.RequestMsg(msg, time.Second) + require_NoError(t, err) + require_Equal[string](t, string(reply.Data), "resp") + + tm := natsNexMsg(t, sub, time.Second) + var e MsgTraceEvent + err = json.Unmarshal(tm.Data, &e) + require_NoError(t, err) + require_Equal[string](t, e.Server.Name, s.Name()) + require_Equal[string](t, e.Ingress().Account, bPub) + sis := e.ServiceImports() + require_Equal[int](t, len(sis), 1) + si := sis[0] + require_Equal[string](t, si.Account, aPub) + egresses := e.Egresses() + if !test.allowTrace { + require_Equal[int](t, len(egresses), 0) + } else { + require_Equal[int](t, len(egresses), 1) + eg := egresses[0] + require_Equal[string](t, eg.Name, "Service") + require_Equal[string](t, eg.Account, aPub) + require_Equal[string](t, eg.Subscription, "req") + } + // No (more) trace message expected. + tm, err = sub.NextMsg(250 * time.Millisecond) + if err != nats.ErrTimeout { + t.Fatalf("Expected no trace message, got %s", tm.Data) + } + if i < 2 { + // Set AllowTrace to true at the first iteration, then + // false at the second. + aClaim.Exports[0].AllowTrace = (i == 0) + aJwt, err = aClaim.Encode(oKp) + require_NoError(t, err) + updateJwt(t, s.ClientURL(), sysCreds, aJwt, 1) + } + }) + } +} + +func TestMsgTraceStreamJWTUpdate(t *testing.T) { + // create system account + sysKp, _ := nkeys.CreateAccount() + sysPub, _ := sysKp.PublicKey() + sysCreds := newUser(t, sysKp) + // create account A + akp, _ := nkeys.CreateAccount() + aPub, _ := akp.PublicKey() + aClaim := jwt.NewAccountClaims(aPub) + streamExport := &jwt.Export{Subject: "info", Type: jwt.Stream} + aClaim.Exports.Add(streamExport) + aJwt, err := aClaim.Encode(oKp) + require_NoError(t, err) + // create account B + bkp, _ := nkeys.CreateAccount() + bPub, _ := bkp.PublicKey() + bClaim := jwt.NewAccountClaims(bPub) + streamImport := &jwt.Import{Account: aPub, Subject: "info", To: "b", Type: jwt.Stream} + bClaim.Imports.Add(streamImport) + bJwt, err := bClaim.Encode(oKp) + require_NoError(t, err) + + dir := t.TempDir() + conf := createConfFile(t, []byte(fmt.Sprintf(` + listen: -1 + operator: %s + resolver: { + type: full + dir: '%s' + } + system_account: %s + `, ojwt, dir, sysPub))) + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + updateJwt(t, s.ClientURL(), sysCreds, aJwt, 1) + updateJwt(t, s.ClientURL(), sysCreds, bJwt, 1) + + ncA := natsConnect(t, s.ClientURL(), createUserCreds(t, nil, akp)) + defer ncA.Close() + + traceSub := natsSubSync(t, ncA, "trace.dest") + natsFlush(t, ncA) + + ncB := natsConnect(t, s.ClientURL(), createUserCreds(t, nil, bkp), nats.Name("BInfo")) + defer ncB.Close() + + appSub := natsSubSync(t, ncB, "b.info") + natsFlush(t, ncB) + + for i, test := range []struct { + name string + allowTrace bool + }{ + {"trace not allowed", false}, + {"trace allowed", true}, + {"trace not allowed again", false}, + } { + t.Run(test.name, func(t *testing.T) { + msg := nats.NewMsg("info") + msg.Header.Set(MsgTraceDest, traceSub.Subject) + msg.Data = []byte("some info") + err = ncA.PublishMsg(msg) + require_NoError(t, err) + + appMsg := natsNexMsg(t, appSub, time.Second) + require_Equal[string](t, string(appMsg.Data), "some info") + + tm := natsNexMsg(t, traceSub, time.Second) + var e MsgTraceEvent + err = json.Unmarshal(tm.Data, &e) + require_NoError(t, err) + require_Equal[string](t, e.Server.Name, s.Name()) + ses := e.StreamExports() + require_Equal[int](t, len(ses), 1) + se := ses[0] + require_Equal[string](t, se.Account, bPub) + require_Equal[string](t, se.To, "b.info") + egresses := e.Egresses() + if !test.allowTrace { + require_Equal[int](t, len(egresses), 0) + } else { + require_Equal[int](t, len(egresses), 1) + eg := egresses[0] + require_Equal[string](t, eg.Name, "BInfo") + require_Equal[string](t, eg.Account, bPub) + require_Equal[string](t, eg.Subscription, "info") + } + // No (more) trace message expected. + tm, err = traceSub.NextMsg(250 * time.Millisecond) + if err != nats.ErrTimeout { + t.Fatalf("Expected no trace message, got %s", tm.Data) + } + if i < 2 { + // Set AllowTrace to true at the first iteration, then + // false at the second. + bClaim.Imports[0].AllowTrace = (i == 0) + bJwt, err = bClaim.Encode(oKp) + require_NoError(t, err) + updateJwt(t, s.ClientURL(), sysCreds, bJwt, 1) + } + }) + } +}