Skip to content

Commit

Permalink
Added sampling for account's trace destination (#5121)
Browse files Browse the repository at this point in the history
If an account has a trace destination and an incoming message has the
`traceparent` header with the proper sampled flag, a trace message would
be triggered. The `sampling` field makes it possible to trace only a
certain percentage of the traffic.

The field `trace_dest`, or now `msg_trace`, can be a simple string
representing the destination, in which case sampling is 100%, or it
can be a structure with the `dest` and `sampling` fields. Sampling
values that are negative or above 100 will trigger an error on
configuration parsing. A value of 0 is converted to 100. If the sampling
is specified without an account trace destination, it is set to 0 and a
warning is issued when parsing configuration.

There is similar support for the property set in a JWT account claim.

Relates to #5014

Signed-off-by: Ivan Kozlovic <ivan@synadia.com>
  • Loading branch information
derekcollison committed Feb 23, 2024
2 parents 6dde0fc + c8c6c62 commit eedaef4
Show file tree
Hide file tree
Showing 10 changed files with 432 additions and 26 deletions.
2 changes: 1 addition & 1 deletion go.mod
Expand Up @@ -5,7 +5,7 @@ go 1.20
require (
github.com/klauspost/compress v1.17.6
github.com/minio/highwayhash v1.0.2
github.com/nats-io/jwt/v2 v2.5.4
github.com/nats-io/jwt/v2 v2.5.5
github.com/nats-io/nats.go v1.33.1
github.com/nats-io/nkeys v0.4.7
github.com/nats-io/nuid v1.0.1
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Expand Up @@ -3,8 +3,8 @@ github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2e
github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g=
github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY=
github.com/nats-io/jwt/v2 v2.5.4 h1:Bz+drKl2GbE30fxTOtb0NYl1BQ5RwZ+Zcqkg3mR5bbI=
github.com/nats-io/jwt/v2 v2.5.4/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A=
github.com/nats-io/jwt/v2 v2.5.5 h1:ROfXb50elFq5c9+1ztaUbdlrArNFl2+fQWP6B8HGEq4=
github.com/nats-io/jwt/v2 v2.5.5/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A=
github.com/nats-io/nats.go v1.33.1 h1:8TxLZZ/seeEfR97qV0/Bl939tpDnt2Z2fK3HkPypj70=
github.com/nats-io/nats.go v1.33.1/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8=
github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI=
Expand Down
28 changes: 20 additions & 8 deletions server/accounts.go
Expand Up @@ -96,7 +96,12 @@ type Account struct {
nameTag string
lastLimErr int64
routePoolIdx int
traceDest string
// If the trace destination is specified and a message with a traceParentHdr
// is received, and has the least significant bit of the last token set to 1,
// then if traceDestSampling is > 0 and < 100, a random value will be selected
// and if it falls between 0 and that value, message tracing will be triggered.
traceDest string
traceDestSampling int
}

const (
Expand Down Expand Up @@ -263,11 +268,12 @@ func (a *Account) setTraceDest(dest string) {
a.mu.Unlock()
}

func (a *Account) getTraceDest() string {
func (a *Account) getTraceDestAndSampling() (string, int) {
a.mu.RLock()
dest := a.traceDest
sampling := a.traceDestSampling
a.mu.RUnlock()
return dest
return dest, sampling
}

// Used to create shallow copies of accounts for transfer
Expand All @@ -278,7 +284,7 @@ func (a *Account) getTraceDest() string {
func (a *Account) shallowCopy(na *Account) {
na.Nkey = a.Nkey
na.Issuer = a.Issuer
na.traceDest = a.traceDest
na.traceDest, na.traceDestSampling = a.traceDest, a.traceDestSampling

if a.imports.streams != nil {
na.imports.streams = make([]*streamImport, 0, len(a.imports.streams))
Expand Down Expand Up @@ -3239,12 +3245,18 @@ func (s *Server) updateAccountClaimsWithRefresh(a *Account, ac *jwt.AccountClaim
a.nameTag = ac.Name
a.tags = ac.Tags

var td string
var tds int
if ac.Trace != nil {
// Update TraceDest
a.traceDest = string(ac.Trace.Destination)
} else {
a.traceDest = _EMPTY_
// Update trace destination and sampling
td, tds = string(ac.Trace.Destination), ac.Trace.Sampling
if !IsValidPublishSubject(td) {
td, tds = _EMPTY_, 0
} else if tds <= 0 || tds > 100 {
tds = 100
}
}
a.traceDest, a.traceDestSampling = td, tds

// Check for external authorization.
if ac.HasExternalAuthorization() {
Expand Down
9 changes: 6 additions & 3 deletions server/accounts_test.go
Expand Up @@ -510,6 +510,7 @@ func TestAccountSimpleConfig(t *testing.T) {

func TestAccountParseConfig(t *testing.T) {
traceDest := "my.trace.dest"
traceSampling := 50
confFileName := createConfFile(t, []byte(fmt.Sprintf(`
accounts {
synadia {
Expand All @@ -519,14 +520,14 @@ func TestAccountParseConfig(t *testing.T) {
]
}
nats.io {
trace_dest: %q
msg_trace: {dest: %q, sampling: %d%%}
users = [
{user: derek, password: foo}
{user: ivan, password: bar}
]
}
}
`, traceDest)))
`, traceDest, traceSampling)))
opts, err := ProcessConfigFile(confFileName)
if err != nil {
t.Fatalf("Received an error processing config file: %v", err)
Expand All @@ -550,7 +551,9 @@ func TestAccountParseConfig(t *testing.T) {
if natsAcc == nil {
t.Fatalf("Error retrieving account for 'nats.io'")
}
require_Equal[string](t, natsAcc.getTraceDest(), traceDest)
td, tds := natsAcc.getTraceDestAndSampling()
require_Equal[string](t, td, traceDest)
require_Equal[int](t, tds, traceSampling)

for _, u := range opts.Users {
if u.Username == "derek" {
Expand Down
4 changes: 2 additions & 2 deletions server/client.go
Expand Up @@ -4356,8 +4356,8 @@ func (c *client) processServiceImport(si *serviceImport, acc *Account, msg []byt
// We also need to disable the message trace headers so that
// if the message is routed, it does not initialize tracing in the
// remote.
positions := mt.disableTraceHeaders(c, msg)
defer mt.enableTraceHeaders(c, msg, positions)
positions := disableTraceHeaders(c, msg)
defer enableTraceHeaders(c, msg, positions)
}
}
}
Expand Down
90 changes: 89 additions & 1 deletion server/config_check_test.go
Expand Up @@ -920,7 +920,7 @@ func TestConfigCheck(t *testing.T) {
A { trace_dest: 123 }
}
`,
err: errors.New(`interface conversion: interface {} is int64, not string`),
err: errors.New(`Expected account message trace "trace_dest" to be a string or a map/struct, got int64`),
errorLine: 3,
errorPos: 23,
},
Expand All @@ -946,6 +946,94 @@ func TestConfigCheck(t *testing.T) {
errorLine: 3,
errorPos: 23,
},
{
name: "when account message trace dest is wrong type",
config: `
accounts {
A { msg_trace: {dest: 123} }
}
`,
err: errors.New(`Field "dest" should be a string, got int64`),
errorLine: 3,
errorPos: 35,
},
{
name: "when account message trace dest is invalid",
config: `
accounts {
A { msg_trace: {dest: "invalid..dest"} }
}
`,
err: errors.New(`Trace destination "invalid..dest" is not valid`),
errorLine: 3,
errorPos: 35,
},
{
name: "when account message trace sampling is wrong type",
config: `
accounts {
A { msg_trace: {dest: "acc.dest", sampling: {wront: "type"}} }
}
`,
err: errors.New(`Trace destination sampling field "sampling" should be an integer or a percentage, got map[string]interface {}`),
errorLine: 3,
errorPos: 53,
},
{
name: "when account message trace sampling is wrong string",
config: `
accounts {
A { msg_trace: {dest: "acc.dest", sampling: abc%} }
}
`,
err: errors.New(`Invalid trace destination sampling value "abc%"`),
errorLine: 3,
errorPos: 53,
},
{
name: "when account message trace sampling is negative",
config: `
accounts {
A { msg_trace: {dest: "acc.dest", sampling: -1} }
}
`,
err: errors.New(`Ttrace destination sampling value -1 is invalid, needs to be [1..100]`),
errorLine: 3,
errorPos: 53,
},
{
name: "when account message trace sampling is zero",
config: `
accounts {
A { msg_trace: {dest: "acc.dest", sampling: 0} }
}
`,
err: errors.New(`Ttrace destination sampling value 0 is invalid, needs to be [1..100]`),
errorLine: 3,
errorPos: 53,
},
{
name: "when account message trace sampling is more than 100",
config: `
accounts {
A { msg_trace: {dest: "acc.dest", sampling: 101} }
}
`,
err: errors.New(`Ttrace destination sampling value 101 is invalid, needs to be [1..100]`),
errorLine: 3,
errorPos: 53,
},
{
name: "when account message trace has unknown field",
config: `
accounts {
A { msg_trace: {wrong: "field"} }
}
`,
err: errors.New(`Unknown field "wrong" parsing account message trace map/struct "msg_trace"`),
errorLine: 3,
errorPos: 35,
},
{
name: "when user authorization config has both token and users",
config: `
Expand Down
24 changes: 21 additions & 3 deletions server/msgtrace.go
Expand Up @@ -17,6 +17,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"math/rand"
"strconv"
"strings"
"sync/atomic"
Expand Down Expand Up @@ -435,13 +436,21 @@ func (c *client) initMsgTrace() *msgTrace {
// If external, we need to have the account's trace destination set,
// otherwise, we are not enabling tracing.
if external {
var sampling int
if acc != nil {
dest = acc.getTraceDest()
dest, sampling = acc.getTraceDestAndSampling()
}
if dest == _EMPTY_ {
// No account destination, no tracing for external trace headers.
return nil
}
// Check sampling, but only from origin server.
if c.kind == CLIENT && !sample(sampling) {
// Need to deactivate the traceParentHdr so that if the message
// is routed, it does not possibly trigger a trace there.
disableTraceHeaders(c, hdr)
return nil
}
}
c.pa.trace = &msgTrace{
srv: c.srv,
Expand Down Expand Up @@ -472,6 +481,15 @@ func (c *client) initMsgTrace() *msgTrace {
return c.pa.trace
}

// sample reports whether a message should be traced given a sampling
// percentage.
//
// Option parsing should ensure that sampling is [1..100], but any value
// outside of this range is considered to be 100%. A value of 100 (or any
// out-of-range value) always returns true without drawing a random number.
// For values in [1..99], a random integer in [0..99] is drawn and the
// function returns true only if it is strictly less than sampling, so that
// exactly `sampling` of the 100 equally likely outcomes trigger a trace.
func sample(sampling int) bool {
	if sampling <= 0 || sampling >= 100 {
		return true
	}
	// rand.Int31n(100) returns a value in [0, 100). The comparison must be
	// strict: using `<=` would accept sampling+1 outcomes and over-sample by
	// one percentage point (e.g. 1% would behave as 2%).
	return rand.Int31n(100) < int32(sampling)
}

// This function will return the header as a map (instead of http.Header because
// we want to preserve the header names' case) and a boolean that indicates if
// the headers have been lifted due to the presence of the external trace header
Expand Down Expand Up @@ -637,7 +655,7 @@ func (t *msgTrace) setHopHeader(c *client, msg []byte) []byte {
// Note that `msg` can be either the header alone or the full message
// (header and payload). This function will use c.pa.hdr to limit the
// search to the header section alone.
func (t *msgTrace) disableTraceHeaders(c *client, msg []byte) []int {
func disableTraceHeaders(c *client, msg []byte) []int {
// Code largely copied from getHeader(), except that we don't need the value
if c.pa.hdr <= 0 {
return []int{-1, -1}
Expand Down Expand Up @@ -672,7 +690,7 @@ func (t *msgTrace) disableTraceHeaders(c *client, msg []byte) []int {

// Changes back the character at the given position `pos` in the `msg`
// byte slice to the first character of the MsgTraceSendTo header.
func (t *msgTrace) enableTraceHeaders(c *client, msg []byte, positions []int) {
func enableTraceHeaders(c *client, msg []byte, positions []int) {
firstChar := [2]byte{MsgTraceDest[0], traceParentHdr[0]}
for i, pos := range positions {
if pos == -1 {
Expand Down

0 comments on commit eedaef4

Please sign in to comment.