From d6678a871cecf1b3b4bc3f8c034f0461f98af53e Mon Sep 17 00:00:00 2001 From: Simone Magnani Date: Mon, 19 Feb 2024 11:59:30 +0100 Subject: [PATCH] introducing Map.Drain API to traverse a map while also deleting entries This commit introduces the `Map.Drain` API to traverse the map while also removing its entries. It leverages the same `MapIterator` structure, with the introduction of a new unexported method to handle the map draining. The tests make sure that the behavior is as expected, and that this API returns an error while invoked on the wrong map, such as arrays, for which `Map.Iterate` should be used instead. The `LookupAndDelete` system call support has been introduced in: 1. 5.14 for BPF_MAP_TYPE_HASH, BPF_MAP_TYPE_PERCPU_HASH, BPF_MAP_TYPE_LRU_HASH and BPF_MAP_TYPE_LRU_PERCPU_HASH. 2. 4.20 for BPF_MAP_TYPE_QUEUE, BPF_MAP_TYPE_STACK Do not expect the `Map.Drain` API to work on prior versions, according to the target map type. From the user perspective, the usage should be similar to `Map.Iterate`, as shown as follows: ```go m, err := NewMap(&MapSpec{ Type: Hash, KeySize: 4, ValueSize: 8, MaxEntries: 10, }) // populate here the map and defer close it := m.Drain() for it.Next(keyPtr, &value) { // here the entry doesn't exist anymore in the underlying map. ... } ``` Signed-off-by: Simone Magnani --- map.go | 127 ++++++++++++++++++++++++++++++++------- map_test.go | 167 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 272 insertions(+), 22 deletions(-) diff --git a/map.go b/map.go index 36fe27418..3fc51922f 100644 --- a/map.go +++ b/map.go @@ -1287,10 +1287,31 @@ func batchCount(keys, values any) (int, error) { // // It's not possible to guarantee that all keys in a map will be // returned if there are concurrent modifications to the map. +// +// Iterating a hash map from which keys are being deleted is not +// safe. You may see the same key multiple times. Iteration may +// also abort with an error, see IsIterationAborted. +// +// Iterating a queue/stack map returns an error (NextKey) as the +// Map.Drain API should be used instead. func (m *Map) Iterate() *MapIterator { return newMapIterator(m) } +// Drain traverses a map while also removing entries. +// +// It's safe to create multiple drainers at the same time, +// but their respective outputs will differ. +// +// Draining a map that does not support entry removal such as +// an array return an error (Delete/LookupAndDelete) as the +// Map.Iterate API should be used instead. +func (m *Map) Drain() *MapIterator { + it := newMapIterator(m) + it.drain = true + return it +} + // Close the Map's underlying file descriptor, which could unload the // Map from the kernel if it is not pinned or in use by a loaded Program. func (m *Map) Close() error { @@ -1549,7 +1570,7 @@ type MapIterator struct { // of []byte to avoid allocations. cursor any count, maxEntries uint32 - done bool + done, drain bool err error } @@ -1562,10 +1583,6 @@ func newMapIterator(target *Map) *MapIterator { // Next decodes the next key and value. // -// Iterating a hash map from which keys are being deleted is not -// safe. You may see the same key multiple times. Iteration may -// also abort with an error, see IsIterationAborted. -// // Returns false if there are no more entries. You must check // the result of Err afterwards. // @@ -1574,26 +1591,28 @@ func (mi *MapIterator) Next(keyOut, valueOut interface{}) bool { if mi.err != nil || mi.done { return false } + if mi.drain { + return mi.nextDrain(keyOut, valueOut) + } + return mi.nextIterate(keyOut, valueOut) +} + +func (mi *MapIterator) nextIterate(keyOut, valueOut interface{}) bool { + var key interface{} - // For array-like maps NextKey returns nil only after maxEntries - // iterations. + // For array-like maps NextKey returns nil only after maxEntries iterations. for mi.count <= mi.maxEntries { if mi.cursor == nil { // Pass nil interface to NextKey to make sure the Map's first key // is returned. If we pass an uninitialized []byte instead, it'll see a // non-nil interface and try to marshal it. mi.cursor = make([]byte, mi.target.keySize) - mi.err = mi.target.NextKey(nil, mi.cursor) + key = nil } else { - mi.err = mi.target.NextKey(mi.cursor, mi.cursor) + key = mi.cursor } - if errors.Is(mi.err, ErrKeyNotExist) { - mi.done = true - mi.err = nil - return false - } else if mi.err != nil { - mi.err = fmt.Errorf("get next key: %w", mi.err) + if !mi.fetchNextKey(key) { return false } @@ -1615,20 +1634,84 @@ func (mi *MapIterator) Next(keyOut, valueOut interface{}) bool { return false } - buf := mi.cursor.([]byte) - if ptr, ok := keyOut.(unsafe.Pointer); ok { - copy(unsafe.Slice((*byte)(ptr), len(buf)), buf) - } else { - mi.err = sysenc.Unmarshal(keyOut, buf) + return mi.copyCursorToKeyOut(keyOut) + } + + mi.err = fmt.Errorf("%w", ErrIterationAborted) + return false +} + +func (mi *MapIterator) nextDrain(keyOut, valueOut interface{}) bool { + // Handke keyless map, for which mi.cursor (key used in lookupAndDelete) should be nil + if mi.isKeylessMap() { + if keyOut != nil { + mi.err = fmt.Errorf("non-nil keyOut provided for map without a key, must be nil instead") + return false } + return mi.drainMapEntry(valueOut) + } - return mi.err == nil + // Allocate only once data for retrieving the next key in the map. + if mi.cursor == nil { + mi.cursor = make([]byte, mi.target.keySize) + } + + // Always retrieve first key in the map. This should ensure that the whole map + // is traversed, despite concurrent operations (ordering of items might differ). + for mi.err == nil && mi.fetchNextKey(nil) { + if mi.drainMapEntry(valueOut) { + return mi.copyCursorToKeyOut(keyOut) + } + } + return false +} + +func (mi *MapIterator) isKeylessMap() bool { + return mi.target.keySize == 0 +} + +func (mi *MapIterator) drainMapEntry(valueOut interface{}) bool { + mi.err = mi.target.LookupAndDelete(mi.cursor, valueOut) + if mi.err == nil { + mi.count++ + return true + } + + if errors.Is(mi.err, ErrKeyNotExist) { + mi.err = nil + } else { + mi.err = fmt.Errorf("lookup_and_delete key: %w", mi.err) } - mi.err = fmt.Errorf("%w", ErrIterationAborted) return false } +func (mi *MapIterator) fetchNextKey(key interface{}) bool { + mi.err = mi.target.NextKey(key, mi.cursor) + if mi.err == nil { + return true + } + + if errors.Is(mi.err, ErrKeyNotExist) { + mi.done = true + mi.err = nil + } else { + mi.err = fmt.Errorf("get next key: %w", mi.err) + } + + return false +} + +func (mi *MapIterator) copyCursorToKeyOut(keyOut interface{}) bool { + buf := mi.cursor.([]byte) + if ptr, ok := keyOut.(unsafe.Pointer); ok { + copy(unsafe.Slice((*byte)(ptr), len(buf)), buf) + } else { + mi.err = sysenc.Unmarshal(keyOut, buf) + } + return mi.err == nil +} + // Err returns any encountered error. // // The method must be called after Next returns nil. diff --git a/map_test.go b/map_test.go index 93dd764d7..d9621242a 100644 --- a/map_test.go +++ b/map_test.go @@ -1149,6 +1149,173 @@ func TestMapIteratorAllocations(t *testing.T) { qt.Assert(t, qt.Equals(allocs, float64(0))) } +func TestDrainEmptyMap(t *testing.T) { + for _, mapType := range []MapType{ + Hash, + Queue, + } { + t.Run(mapType.String(), func(t *testing.T) { + var ( + keySize = uint32(4) + key string + value uint64 + keyPtr interface{} = &key + ) + + if mapType == Queue { + testutils.SkipOnOldKernel(t, "4.20", "map type queue") + keySize = 0 + keyPtr = nil + } + + if mapType == Hash { + testutils.SkipOnOldKernel(t, "5.14", "map type hash") + } + + m, err := NewMap(&MapSpec{ + Type: mapType, + KeySize: keySize, + ValueSize: 8, + MaxEntries: 2, + }) + qt.Assert(t, qt.IsNil(err)) + defer m.Close() + + entries := m.Drain() + if entries.Next(keyPtr, &value) { + t.Errorf("Empty %v should not be drainable", mapType) + } + + qt.Assert(t, qt.IsNil(entries.Err())) + }) + } +} + +func TestMapDrain(t *testing.T) { + for _, mapType := range []MapType{ + Hash, + Queue, + } { + t.Run(Hash.String(), func(t *testing.T) { + var ( + key, value uint32 + values []uint32 + anyKey interface{} + keyPtr interface{} = &key + keySize uint32 = 4 + data = []uint32{0, 1} + ) + + if mapType == Queue { + testutils.SkipOnOldKernel(t, "4.20", "map type queue") + keySize = 0 + keyPtr = nil + } + + if mapType == Hash { + testutils.SkipOnOldKernel(t, "5.14", "map type hash") + } + + m, err := NewMap(&MapSpec{ + Type: mapType, + KeySize: keySize, + ValueSize: 4, + MaxEntries: 2, + }) + qt.Assert(t, qt.IsNil(err)) + defer m.Close() + + for _, v := range data { + if keySize != 0 { + anyKey = uint32(v) + } + err := m.Put(anyKey, uint32(v)) + qt.Assert(t, qt.IsNil(err)) + } + + entries := m.Drain() + for entries.Next(keyPtr, &value) { + values = append(values, value) + } + qt.Assert(t, qt.IsNil(entries.Err())) + + sort.Slice(values, func(i, j int) bool { return values[i] < values[j] }) + qt.Assert(t, qt.DeepEquals(values, data)) + }) + } +} + +func TestDrainWrongMap(t *testing.T) { + arr, err := NewMap(&MapSpec{ + Type: Array, + KeySize: 4, + ValueSize: 4, + MaxEntries: 10, + }) + qt.Assert(t, qt.IsNil(err)) + defer arr.Close() + + var key, value uint32 + entries := arr.Drain() + + qt.Assert(t, qt.IsFalse(entries.Next(&key, &value))) + qt.Assert(t, qt.IsNotNil(entries.Err())) +} + +func TestMapDrainerAllocations(t *testing.T) { + for _, mapType := range []MapType{ + Hash, + Queue, + } { + t.Run(mapType.String(), func(t *testing.T) { + var ( + key, value uint32 + anyKey interface{} + keyPtr interface{} = &key + keySize uint32 = 4 + ) + + if mapType == Queue { + testutils.SkipOnOldKernel(t, "4.20", "map type queue") + keySize = 0 + keyPtr = nil + } + + if mapType == Hash { + testutils.SkipOnOldKernel(t, "5.14", "map type hash") + } + + m, err := NewMap(&MapSpec{ + Type: mapType, + KeySize: keySize, + ValueSize: 4, + MaxEntries: 10, + }) + qt.Assert(t, qt.ErrorIs(err, nil)) + defer m.Close() + + for i := 0; i < int(m.MaxEntries()); i++ { + if keySize != 0 { + anyKey = uint32(i) + } + if err := m.Put(anyKey, uint32(i)); err != nil { + t.Fatal(err) + } + } + + iter := m.Drain() + + allocs := testing.AllocsPerRun(int(m.MaxEntries()-1), func() { + if !iter.Next(keyPtr, &value) { + t.Fatal("Next failed while draining: %w", iter.Err()) + } + }) + + qt.Assert(t, qt.Equals(allocs, float64(0))) + }) + } +} + func TestMapBatchLookupAllocations(t *testing.T) { testutils.SkipIfNotSupported(t, haveBatchAPI())