Skip to content

Commit

Permalink
statedb: Allow non-terminated keys
Browse files Browse the repository at this point in the history
Previously statedb required the indexers to either terminate or return fixed-size
keys in order to exactly match with just prefix searching the trees. This is a
fairly large footgun.

This commit adds checks to the iterators to make sure the key matches exactly and is
not a longer key that shares the prefix. For indexers that don't terminate the key,
this may result in wasting a bit of time by traversing non-matching nodes.

This was tested by "breaking" index.String by removing termination and validating
that the regression test passes.

Signed-off-by: Jussi Maki <jussi@isovalent.com>
  • Loading branch information
joamaki committed Nov 30, 2023
1 parent d0d4d46 commit aa15ba7
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 13 deletions.
6 changes: 3 additions & 3 deletions pkg/statedb/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (t *RemoteTable[Obj]) query(ctx context.Context, lowerBound bool, q Query[O
errChan <- err
}()

return &getIterator[Obj]{gob.NewDecoder(r)}, errChan
return &remoteGetIterator[Obj]{gob.NewDecoder(r)}, errChan
}
func (t *RemoteTable[Obj]) Get(ctx context.Context, q Query[Obj]) (Iterator[Obj], <-chan error) {
return t.query(ctx, false, q)
Expand All @@ -81,11 +81,11 @@ func (t *RemoteTable[Obj]) LowerBound(ctx context.Context, q Query[Obj]) (Iterat
return t.query(ctx, true, q)
}

type getIterator[Obj any] struct {
type remoteGetIterator[Obj any] struct {
decoder *gob.Decoder
}

func (it *getIterator[Obj]) Next() (obj Obj, revision Revision, ok bool) {
func (it *remoteGetIterator[Obj]) Next() (obj Obj, revision Revision, ok bool) {
err := it.decoder.Decode(&revision)
if err != nil {
return
Expand Down
22 changes: 18 additions & 4 deletions pkg/statedb/api_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ type queryHandler struct {

// /statedb/query
func (h *queryHandler) Handle(params GetStatedbQueryTableParams) middleware.Responder {
key, err := base64.StdEncoding.DecodeString(params.Key)
queryKey, err := base64.StdEncoding.DecodeString(params.Key)
if err != nil {
return api.Error(GetStatedbQueryTableBadRequestCode, fmt.Errorf("Invalid key: %w", err))
}
Expand All @@ -58,15 +58,29 @@ func (h *queryHandler) Handle(params GetStatedbQueryTableParams) middleware.Resp

iter := indexTxn.txn.Root().Iterator()
if params.Lowerbound {
iter.SeekLowerBound(key)
iter.SeekLowerBound(queryKey)
} else {
iter.SeekPrefixWatch(key)
iter.SeekPrefixWatch(queryKey)
}

return middleware.ResponderFunc(func(w http.ResponseWriter, _ runtime.Producer) {
w.WriteHeader(GetStatedbDumpOKCode)
enc := gob.NewEncoder(w)
for _, obj, ok := iter.Next(); ok; _, obj, ok = iter.Next() {

var match func([]byte) bool
if indexTxn.entry.unique {
match = func(k []byte) bool { return len(k) == len(queryKey) }
} else {
match = func(k []byte) bool {
_, secondary := decodeNonUniqueKey(k)
return len(secondary) == len(queryKey)
}
}

for key, obj, ok := iter.Next(); ok; _, obj, ok = iter.Next() {
if !match(key) {
continue
}
if err := enc.Encode(obj.revision); err != nil {
return
}
Expand Down
27 changes: 27 additions & 0 deletions pkg/statedb/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,7 @@ func TestWriteJSON(t *testing.T) {
}
txn.Commit()
}

func Test_callerPackage(t *testing.T) {
t.Parallel()

Expand All @@ -685,6 +686,32 @@ func Test_callerPackage(t *testing.T) {
require.Equal(t, "pkg/statedb", pkg)
}

// Test_nonUniqueKey checks that encodeNonUniqueKey and decodeNonUniqueKey
// round-trip correctly, including the empty-key edge cases.
func Test_nonUniqueKey(t *testing.T) {
	testCases := []struct {
		name      string
		primary   []byte
		secondary []byte
	}{
		{"empty keys", nil, nil},
		{"empty primary", nil, []byte("foo")},
		{"empty secondary", []byte("quux"), []byte{}},
		{"non-empty", []byte("foo"), []byte("quux")},
	}
	for _, tc := range testCases {
		encoded := encodeNonUniqueKey(tc.primary, tc.secondary)
		gotPrimary, gotSecondary := decodeNonUniqueKey(encoded)
		// Compare via string() so nil and empty slices are treated alike.
		assert.Equal(t, string(tc.primary), string(gotPrimary), tc.name)
		assert.Equal(t, string(tc.secondary), string(gotSecondary), tc.name)
	}
}

func eventuallyGraveyardIsEmpty(t testing.TB, db *DB) {
require.Eventually(t,
db.graveyardIsEmpty,
Expand Down
59 changes: 59 additions & 0 deletions pkg/statedb/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package statedb

import (
"bytes"
"fmt"

"k8s.io/apimachinery/pkg/util/sets"
Expand Down Expand Up @@ -54,6 +55,64 @@ func (it *iterator[Obj]) Next() (obj Obj, revision uint64, ok bool) {
return
}

// uniqueIterator iterates over objects in a unique index. Since the
// underlying node was found via prefix search, the iterator may also
// visit keys that merely share the search prefix but are longer; those
// are skipped so that only the exact key match is yielded.
type uniqueIterator[Obj any] struct {
	iter interface{ Next() ([]byte, object, bool) }
	key  []byte
}

// Next returns the next exactly-matching object, or ok=false when the
// iterator is exhausted.
func (it *uniqueIterator[Obj]) Next() (obj Obj, revision uint64, ok bool) {
	for {
		key, iobj, more := it.iter.Next()
		if !more {
			// Exhausted without finding (another) exact match.
			return
		}
		if bytes.Equal(key, it.key) {
			return iobj.data.(Obj), iobj.revision, true
		}
		// Longer key sharing the search prefix; not an exact match, skip.
	}
}

// nonUniqueIterator iterates over a non-unique index. Since we seek by
// prefix and indexers are not required to terminate their keys, each
// visited node's secondary key must be checked for the right length to
// rule out longer keys that merely share the search prefix.
type nonUniqueIterator[Obj any] struct {
	iter interface{ Next() ([]byte, object, bool) }
	key  []byte
}

// Next returns the next object whose secondary key matches the query
// key exactly, or ok=false when the iterator is exhausted.
func (it *nonUniqueIterator[Obj]) Next() (obj Obj, revision uint64, ok bool) {
	for {
		key, iobj, more := it.iter.Next()
		if !more {
			return
		}
		// Equal length implies an equal key: we arrived here via prefix
		// search, so every visited node already shares the search prefix.
		if _, secondary := decodeNonUniqueKey(key); len(secondary) == len(it.key) {
			return iobj.data.(Obj), iobj.revision, true
		}
		// Longer secondary key sharing the search prefix; skip it.
	}
}

func NewDualIterator[Obj any](left, right Iterator[Obj]) *DualIterator[Obj] {
return &DualIterator[Obj]{
left: iterState[Obj]{iter: left},
Expand Down
52 changes: 49 additions & 3 deletions pkg/statedb/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,28 @@ func (t *genTable[Obj]) FirstWatch(txn ReadTxn, q Query[Obj]) (obj Obj, revision
indexTxn := txn.getTxn().mustIndexReadTxn(t.table, q.index)
iter := indexTxn.txn.Root().Iterator()
watch = iter.SeekPrefixWatch(q.key)
_, iobj, ok := iter.Next()

var iobj object
for {
var key []byte
key, iobj, ok = iter.Next()
if !ok {
break
}

// Check that we have a full match on the key
var match bool
if indexTxn.entry.unique {
match = len(key) == len(q.key)
} else {
_, secondary := decodeNonUniqueKey(key)
match = len(secondary) == len(q.key)
}
if match {
break
}
}

if ok {
obj = iobj.data.(Obj)
revision = iobj.revision
Expand All @@ -142,7 +163,28 @@ func (t *genTable[Obj]) LastWatch(txn ReadTxn, q Query[Obj]) (obj Obj, revision
indexTxn := txn.getTxn().mustIndexReadTxn(t.table, q.index)
iter := indexTxn.txn.Root().ReverseIterator()
watch = iter.SeekPrefixWatch(q.key)
_, iobj, ok := iter.Previous()

var iobj object
for {
var key []byte
key, iobj, ok = iter.Previous()
if !ok {
break
}

// Check that we have a full match on the key
var match bool
if indexTxn.entry.unique {
match = len(key) == len(q.key)
} else {
_, secondary := decodeNonUniqueKey(key)
match = len(secondary) == len(q.key)
}
if match {
break
}
}

if ok {
obj = iobj.data.(Obj)
revision = iobj.revision
Expand Down Expand Up @@ -175,7 +217,11 @@ func (t *genTable[Obj]) Get(txn ReadTxn, q Query[Obj]) (Iterator[Obj], <-chan st
indexTxn := txn.getTxn().mustIndexReadTxn(t.table, q.index)
iter := indexTxn.txn.Root().Iterator()
watchCh := iter.SeekPrefixWatch(q.key)
return &iterator[Obj]{iter}, watchCh

if indexTxn.entry.unique {
return &uniqueIterator[Obj]{iter, q.key}, watchCh
}
return &nonUniqueIterator[Obj]{iter, q.key}, watchCh
}

func (t *genTable[Obj]) Insert(txn WriteTxn, obj Obj) (oldObj Obj, hadOld bool, err error) {
Expand Down
32 changes: 29 additions & 3 deletions pkg/statedb/txn.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ func (txn *txn) Insert(meta TableMeta, guardRevision Revision, data any) (any, b
// if the new key is different delete the old entry.
indexer.fromObject(oldObj).Foreach(func(oldKey index.Key) {
if !indexer.unique {
oldKey = append(oldKey, idKey...)
oldKey = encodeNonUniqueKey(idKey, oldKey)
}
if !newKeys.Exists(oldKey) {
indexTxn.txn.Delete(oldKey)
Expand All @@ -248,7 +248,7 @@ func (txn *txn) Insert(meta TableMeta, guardRevision Revision, data any) (any, b
// Non-unique secondary indexes are formed by concatenating them
// with the primary key.
if !indexer.unique {
newKey = append(newKey, idKey...)
newKey = encodeNonUniqueKey(idKey, newKey)
}
indexTxn.txn.Insert(newKey, obj)
})
Expand Down Expand Up @@ -334,7 +334,7 @@ func (txn *txn) Delete(meta TableMeta, guardRevision Revision, data any) (any, b
for idx, indexer := range meta.secondaryIndexers() {
indexer.fromObject(obj).Foreach(func(key index.Key) {
if !indexer.unique {
key = append(key, idKey...)
key = encodeNonUniqueKey(idKey, key)
}
txn.mustIndexWriteTxn(tableName, idx).txn.Delete(key)
})
Expand All @@ -353,6 +353,32 @@ func (txn *txn) Delete(meta TableMeta, guardRevision Revision, data any) (any, b
return obj.data, true, nil
}

// encodeNonUniqueKey constructs the internal key to use with non-unique indexes.
// It concatenates the secondary key with the primary key and the length of the secondary key.
// The length is stored as unsigned 16-bit big endian.
// This allows looking up from the non-unique index with the secondary key by doing a prefix
// search. The length is used to safe-guard against indexers that don't terminate the key
// properly (e.g. if secondary key is "foo", then we don't want "foobar" to match).
func encodeNonUniqueKey(primary, secondary []byte) []byte {
	// Layout: [<secondary...>, <primary...>, <secondary length (uint16 BE)>]
	total := len(secondary) + len(primary) + 2
	out := make([]byte, total)
	n := copy(out, secondary)
	n += copy(out[n:], primary)
	// KeySet limits size of key to 16 bits.
	binary.BigEndian.PutUint16(out[n:], uint16(len(secondary)))
	return out
}

// decodeNonUniqueKey splits an internal non-unique index key back into
// its primary and secondary parts.
//
// The key layout is [<secondary...>, <primary...>, <secondary length (uint16 BE)>]
// as produced by encodeNonUniqueKey. Malformed keys — too short to hold the
// length suffix, or carrying a length that does not fit the key — decode to
// (nil, nil) instead of panicking.
func decodeNonUniqueKey(key []byte) (primary []byte, secondary []byte) {
	// Multi-index key is [<secondary...>, <primary...>, <secondary length>]
	if len(key) < 2 {
		return nil, nil
	}
	secondaryLength := int(binary.BigEndian.Uint16(key[len(key)-2:]))
	// The secondary key and the 2-byte length suffix must both fit inside
	// the key. The previous guard (len(key) < secondaryLength) allowed
	// secondaryLength in (len(key)-2, len(key)], which made the slice
	// expression below panic with an inverted range.
	if len(key)-2 < secondaryLength {
		return nil, nil
	}
	return key[secondaryLength : len(key)-2], key[:secondaryLength]
}

func (txn *txn) Abort() {
// If writeTxns is nil, this transaction has already been committed or aborted, and
// thus there is nothing to do. We allow this without failure to allow for defer
Expand Down

0 comments on commit aa15ba7

Please sign in to comment.