Skip to content

Commit

Permalink
[FIXED] Validation in jetstream and KV (#1613)
Browse files Browse the repository at this point in the history
Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
  • Loading branch information
piotrpio committed Apr 22, 2024
1 parent 9d4b227 commit 7bdb629
Show file tree
Hide file tree
Showing 11 changed files with 226 additions and 15 deletions.
10 changes: 8 additions & 2 deletions jetstream/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,9 @@ func upsertConsumer(ctx context.Context, js *jetStream, stream string, cfg Consu

var ccSubj string
if cfg.FilterSubject != "" && len(cfg.FilterSubjects) == 0 {
if err := validateSubject(cfg.FilterSubject); err != nil {
return nil, err
}
ccSubj = apiSubj(js.apiPrefix, fmt.Sprintf(apiConsumerCreateWithFilterSubjectT, stream, consumerName, cfg.FilterSubject))
} else {
ccSubj = apiSubj(js.apiPrefix, fmt.Sprintf(apiConsumerCreateT, stream, consumerName))
Expand Down Expand Up @@ -318,8 +321,11 @@ func deleteConsumer(ctx context.Context, js *jetStream, stream, consumer string)
}

func validateConsumerName(dur string) error {
if strings.Contains(dur, ".") {
return fmt.Errorf("%w: %q", ErrInvalidConsumerName, dur)
if dur == "" {
return fmt.Errorf("%w: '%s'", ErrInvalidConsumerName, "name is required")
}
if strings.ContainsAny(dur, ">*. /\\") {
return fmt.Errorf("%w: '%s'", ErrInvalidConsumerName, dur)
}
return nil
}
4 changes: 2 additions & 2 deletions jetstream/jetstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ func validateStreamName(stream string) error {
if stream == "" {
return ErrStreamNameRequired
}
if strings.Contains(stream, ".") {
if strings.ContainsAny(stream, ">*. /\\") {
return fmt.Errorf("%w: '%s'", ErrInvalidStreamName, stream)
}
return nil
Expand All @@ -783,7 +783,7 @@ func validateSubject(subject string) error {
if subject == "" {
return fmt.Errorf("%w: %s", ErrInvalidSubject, "subject cannot be empty")
}
if !subjectRegexp.MatchString(subject) {
if subject[0] == '.' || subject[len(subject)-1] == '.' || !subjectRegexp.MatchString(subject) {
return fmt.Errorf("%w: %s", ErrInvalidSubject, subject)
}
return nil
Expand Down
93 changes: 93 additions & 0 deletions jetstream/jetstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,3 +456,96 @@ func TestPullConsumer_checkPending(t *testing.T) {
})
}
}

func TestKV_keyValid(t *testing.T) {
tests := []struct {
key string
ok bool
}{
{key: "foo123", ok: true},
{key: "foo.bar", ok: true},
{key: "Foo.123=bar_baz-abc", ok: true},
{key: "foo.*.bar", ok: false},
{key: "foo.>", ok: false},
{key: ">", ok: false},
{key: "*", ok: false},
{key: "foo!", ok: false},
{key: "foo bar", ok: false},
{key: "", ok: false},
{key: " ", ok: false},
{key: ".", ok: false},
{key: ".foo", ok: false},
{key: "foo.", ok: false},
}

for _, test := range tests {
t.Run(test.key, func(t *testing.T) {
res := keyValid(test.key)
if res != test.ok {
t.Fatalf("Invalid result; want: %v; got: %v", test.ok, res)
}
})
}
}

func TestKV_searchKeyValid(t *testing.T) {
tests := []struct {
key string
ok bool
}{
{key: "foo123", ok: true},
{key: "foo.bar", ok: true},
{key: "Foo.123=bar_baz-abc", ok: true},
{key: "foo.*.bar", ok: true},
{key: "foo.>", ok: true},
{key: ">", ok: true},
{key: "*", ok: true},
{key: "foo!", ok: false},
{key: "foo bar", ok: false},
{key: "", ok: false},
{key: " ", ok: false},
{key: ".", ok: false},
{key: ".foo", ok: false},
{key: "foo.", ok: false},
}

for _, test := range tests {
t.Run(test.key, func(t *testing.T) {
res := searchKeyValid(test.key)
if res != test.ok {
t.Fatalf("Invalid result; want: %v; got: %v", test.ok, res)
}
})
}
}

func TestKV_bucketValid(t *testing.T) {
tests := []struct {
key string
ok bool
}{
{key: "foo123", ok: true},
{key: "Foo123-bar_baz", ok: true},
{key: "foo.bar", ok: false},
{key: "foo.*.bar", ok: false},
{key: "foo.>", ok: false},
{key: ">", ok: false},
{key: "*", ok: false},
{key: "foo!", ok: false},
{key: "foo bar", ok: false},
{key: "", ok: false},
{key: " ", ok: false},
{key: ".", ok: false},
{key: ".foo", ok: false},
{key: "foo.", ok: false},
}

for _, test := range tests {
t.Run(test.key, func(t *testing.T) {
res := bucketValid(test.key)
if res != test.ok {
t.Fatalf("Invalid result; want: %v; got: %v", test.ok, res)
}
})
}
}
28 changes: 23 additions & 5 deletions jetstream/kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,12 +448,13 @@ const (

// Regex for valid keys and buckets.
var (
validBucketRe = regexp.MustCompile(`\A[a-zA-Z0-9_-]+\z`)
validKeyRe = regexp.MustCompile(`\A[-/_=\.a-zA-Z0-9]+\z`)
validBucketRe = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
validKeyRe = regexp.MustCompile(`^[-/_=\.a-zA-Z0-9]+$`)
validSearchKeyRe = regexp.MustCompile(`^[-/_=\.a-zA-Z0-9*]*[>]?$`)
)

func (js *jetStream) KeyValue(ctx context.Context, bucket string) (KeyValue, error) {
if !validBucketRe.MatchString(bucket) {
if !bucketValid(bucket) {
return nil, ErrInvalidBucketName
}
streamName := fmt.Sprintf(kvBucketNameTmpl, bucket)
Expand Down Expand Up @@ -558,7 +559,7 @@ func (js *jetStream) CreateOrUpdateKeyValue(ctx context.Context, cfg KeyValueCon
}

func (js *jetStream) prepareKeyValueConfig(ctx context.Context, cfg KeyValueConfig) (StreamConfig, error) {
if !validBucketRe.MatchString(cfg.Bucket) {
if !bucketValid(cfg.Bucket) {
return StreamConfig{}, ErrInvalidBucketName
}
if _, err := js.AccountInfo(ctx); err != nil {
Expand Down Expand Up @@ -656,7 +657,7 @@ func (js *jetStream) prepareKeyValueConfig(ctx context.Context, cfg KeyValueConf

// DeleteKeyValue will delete this KeyValue store (JetStream stream).
func (js *jetStream) DeleteKeyValue(ctx context.Context, bucket string) error {
if !validBucketRe.MatchString(bucket) {
if !bucketValid(bucket) {
return ErrInvalidBucketName
}
stream := fmt.Sprintf(kvBucketNameTmpl, bucket)
Expand Down Expand Up @@ -793,13 +794,27 @@ func (js *jetStream) legacyJetStream() (nats.JetStreamContext, error) {
return js.conn.JetStream(opts...)
}

func bucketValid(bucket string) bool {
if len(bucket) == 0 {
return false
}
return validBucketRe.MatchString(bucket)
}

func keyValid(key string) bool {
if len(key) == 0 || key[0] == '.' || key[len(key)-1] == '.' {
return false
}
return validKeyRe.MatchString(key)
}

func searchKeyValid(key string) bool {
if len(key) == 0 || key[0] == '.' || key[len(key)-1] == '.' {
return false
}
return validSearchKeyRe.MatchString(key)
}

func (kv *kvs) get(ctx context.Context, key string, revision uint64) (KeyValueEntry, error) {
if !keyValid(key) {
return nil, ErrInvalidKey
Expand Down Expand Up @@ -1056,6 +1071,9 @@ func (w *watcher) Stop() error {
// Watch for any updates to keys that match the keys argument which could include wildcards.
// Watch will send a nil entry when it has received all initial values.
func (kv *kvs) Watch(ctx context.Context, keys string, opts ...WatchOpt) (KeyWatcher, error) {
if !searchKeyValid(keys) {
return nil, fmt.Errorf("%w: %s", ErrInvalidKey, "keys cannot be empty and must be a valid NATS subject")
}
var o watchOpts
for _, opt := range opts {
if opt != nil {
Expand Down
28 changes: 28 additions & 0 deletions jetstream/test/kv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,34 @@ func TestKeyValueWatch(t *testing.T) {
expectUpdate("age", "22", 3)
expectUpdate("name2", "ik", 4)
})

t.Run("invalid watchers", func(t *testing.T) {
s := RunBasicJetStreamServer()
defer shutdownJSServerAndRemoveStorage(t, s)

nc, js := jsClient(t, s)
defer nc.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

kv, err := js.CreateKeyValue(ctx, jetstream.KeyValueConfig{Bucket: "WATCH"})
expectOk(t, err)

// empty keys
_, err = kv.Watch(ctx, "")
expectErr(t, err, jetstream.ErrInvalidKey)

// invalid key
_, err = kv.Watch(ctx, "a.>.b")
expectErr(t, err, jetstream.ErrInvalidKey)

_, err = kv.Watch(ctx, "foo.")
expectErr(t, err, jetstream.ErrInvalidKey)

// conflicting options
_, err = kv.Watch(ctx, "foo", jetstream.IncludeHistory(), jetstream.UpdatesOnly())
expectErr(t, err, jetstream.ErrInvalidOption)
})
}

func TestKeyValueWatchContext(t *testing.T) {
Expand Down
10 changes: 10 additions & 0 deletions jetstream/test/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,16 @@ func TestCreateConsumer(t *testing.T) {
consumerConfig: jetstream.ConsumerConfig{FilterSubjects: []string{"FOO.A", ""}},
withError: jetstream.ErrEmptyFilter,
},
{
name: "with invalid filter subject, leading dot",
consumerConfig: jetstream.ConsumerConfig{FilterSubject: ".foo"},
withError: jetstream.ErrInvalidSubject,
},
{
name: "with invalid filter subject, trailing dot",
consumerConfig: jetstream.ConsumerConfig{FilterSubject: "foo."},
withError: jetstream.ErrInvalidSubject,
},
{
name: "consumer already exists, error",
consumerConfig: jetstream.ConsumerConfig{Durable: "dur", Description: "test consumer"},
Expand Down
3 changes: 3 additions & 0 deletions jserrors.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ var (
// ErrInvalidConsumerName is returned when the provided consumer name is invalid (contains '.' or ' ').
ErrInvalidConsumerName JetStreamError = &jsError{message: "invalid consumer name"}

// ErrInvalidFilterSubject is returned when the provided filter subject is invalid.
ErrInvalidFilterSubject JetStreamError = &jsError{message: "invalid filter subject"}

// ErrNoMatchingStream is returned when stream lookup by subject is unsuccessful.
ErrNoMatchingStream JetStreamError = &jsError{message: "no stream matches subject"}

Expand Down
4 changes: 4 additions & 0 deletions jsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,10 @@ func (js *js) upsertConsumer(stream, consumerName string, cfg *ConsumerConfig, o
// if filter subject is empty or ">", use the endpoint without filter subject
ccSubj = fmt.Sprintf(apiConsumerCreateT, stream, consumerName)
} else {
// safeguard against passing invalid filter subject in request subject
if cfg.FilterSubject[0] == '.' || cfg.FilterSubject[len(cfg.FilterSubject)-1] == '.' {
return nil, fmt.Errorf("%w: %q", ErrInvalidFilterSubject, cfg.FilterSubject)
}
// if filter subject is not empty, use the endpoint with filter subject
ccSubj = fmt.Sprintf(apiConsumerCreateWithFilterSubjectT, stream, consumerName, cfg.FilterSubject)
}
Expand Down
28 changes: 23 additions & 5 deletions kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,16 +344,17 @@ const (

// Regex for valid keys and buckets.
var (
validBucketRe = regexp.MustCompile(`\A[a-zA-Z0-9_-]+\z`)
validKeyRe = regexp.MustCompile(`\A[-/_=\.a-zA-Z0-9]+\z`)
validBucketRe = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
validKeyRe = regexp.MustCompile(`^[-/_=\.a-zA-Z0-9]+$`)
validSearchKeyRe = regexp.MustCompile(`^[-/_=\.a-zA-Z0-9*]*[>]?$`)
)

// KeyValue will lookup and bind to an existing KeyValue store.
func (js *js) KeyValue(bucket string) (KeyValue, error) {
if !js.nc.serverMinVersion(2, 6, 2) {
return nil, errors.New("nats: key-value requires at least server version 2.6.2")
}
if !validBucketRe.MatchString(bucket) {
if !bucketValid(bucket) {
return nil, ErrInvalidBucketName
}
stream := fmt.Sprintf(kvBucketNameTmpl, bucket)
Expand Down Expand Up @@ -381,7 +382,7 @@ func (js *js) CreateKeyValue(cfg *KeyValueConfig) (KeyValue, error) {
if cfg == nil {
return nil, ErrKeyValueConfigRequired
}
if !validBucketRe.MatchString(cfg.Bucket) {
if !bucketValid(cfg.Bucket) {
return nil, ErrInvalidBucketName
}
if _, err := js.AccountInfo(); err != nil {
Expand Down Expand Up @@ -507,7 +508,7 @@ func (js *js) CreateKeyValue(cfg *KeyValueConfig) (KeyValue, error) {

// DeleteKeyValue will delete this KeyValue store (JetStream stream).
func (js *js) DeleteKeyValue(bucket string) error {
if !validBucketRe.MatchString(bucket) {
if !bucketValid(bucket) {
return ErrInvalidBucketName
}
stream := fmt.Sprintf(kvBucketNameTmpl, bucket)
Expand Down Expand Up @@ -547,13 +548,27 @@ func (e *kve) Created() time.Time { return e.created }
func (e *kve) Delta() uint64 { return e.delta }
func (e *kve) Operation() KeyValueOp { return e.op }

func bucketValid(bucket string) bool {
if len(bucket) == 0 {
return false
}
return validBucketRe.MatchString(bucket)
}

func keyValid(key string) bool {
if len(key) == 0 || key[0] == '.' || key[len(key)-1] == '.' {
return false
}
return validKeyRe.MatchString(key)
}

func searchKeyValid(key string) bool {
if len(key) == 0 || key[0] == '.' || key[len(key)-1] == '.' {
return false
}
return validSearchKeyRe.MatchString(key)
}

// Get returns the latest value for the key.
func (kv *kvs) Get(key string) (KeyValueEntry, error) {
e, err := kv.get(key, kvLatestRevision)
Expand Down Expand Up @@ -951,6 +966,9 @@ func (kv *kvs) WatchAll(opts ...WatchOpt) (KeyWatcher, error) {
// Watch will fire the callback when a key that matches the keys pattern is updated.
// keys needs to be a valid NATS subject.
func (kv *kvs) Watch(keys string, opts ...WatchOpt) (KeyWatcher, error) {
if !searchKeyValid(keys) {
return nil, fmt.Errorf("%w: %s", ErrInvalidKey, "keys cannot be empty and must be a valid NATS subject")
}
var o watchOpts
for _, opt := range opts {
if opt != nil {
Expand Down
9 changes: 9 additions & 0 deletions test/js_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2604,6 +2604,15 @@ func TestJetStreamManagement(t *testing.T) {
}
})

t.Run("with invalid filter subject", func(t *testing.T) {
if _, err = js.AddConsumer("foo", &nats.ConsumerConfig{Name: "tc", FilterSubject: ".foo"}); !errors.Is(err, nats.ErrInvalidFilterSubject) {
t.Fatalf("Expected: %v; got: %v", nats.ErrInvalidFilterSubject, err)
}
if _, err = js.AddConsumer("foo", &nats.ConsumerConfig{Name: "tc", FilterSubject: "foo."}); !errors.Is(err, nats.ErrInvalidFilterSubject) {
t.Fatalf("Expected: %v; got: %v", nats.ErrInvalidFilterSubject, err)
}
})

t.Run("with invalid consumer name", func(t *testing.T) {
if _, err = js.AddConsumer("foo", &nats.ConsumerConfig{Durable: "test.durable"}); err != nats.ErrInvalidConsumerName {
t.Fatalf("Expected: %v; got: %v", nats.ErrInvalidConsumerName, err)
Expand Down

0 comments on commit 7bdb629

Please sign in to comment.