From d8dd91ff9be75bfef0e49d0f55900b29975041c6 Mon Sep 17 00:00:00 2001 From: KevBurnsJr Date: Sun, 26 Sep 2021 15:08:35 -0700 Subject: [PATCH] Thread safety --- skipfilter.go | 72 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 25 deletions(-) diff --git a/skipfilter.go b/skipfilter.go index 1b40a9a..754d52f 100644 --- a/skipfilter.go +++ b/skipfilter.go @@ -2,11 +2,13 @@ package skipfilter import ( "fmt" + "runtime" "sync" + "sync/atomic" "github.com/MauriceGit/skiplist" "github.com/RoaringBitmap/roaring/roaring64" - "github.com/hashicorp/golang-lru/simplelru" + "github.com/hashicorp/golang-lru" ) type SkipFilter struct { @@ -14,7 +16,7 @@ type SkipFilter struct { idx map[interface{}]uint64 set *roaring64.Bitmap list skiplist.SkipList - cache *simplelru.LRU + cache *lru.Cache test func(interface{}, interface{}) bool mutex sync.RWMutex } @@ -27,7 +29,7 @@ func New(test func(value interface{}, filter interface{}) bool, size int) *SkipF if size <= 0 { size = 1e5 } - cache, _ := simplelru.NewLRU(size, nil) + cache, _ := lru.New(size) return &SkipFilter{ idx: make(map[interface{}]uint64), set: roaring64.New(), @@ -39,6 +41,8 @@ func New(test func(value interface{}, filter interface{}) bool, size int) *SkipF // Add adds a value to the set func (sf *SkipFilter) Add(value interface{}) { + sf.mutex.Lock() + defer sf.mutex.Unlock() el := &entry{sf.i, value} sf.list.Insert(el) sf.set.Add(sf.i) @@ -48,6 +52,8 @@ func (sf *SkipFilter) Add(value interface{}) { // Remove removes a value from the set func (sf *SkipFilter) Remove(value interface{}) { + sf.mutex.Lock() + defer sf.mutex.Unlock() if id, ok := sf.idx[value]; ok { sf.list.Delete(&entry{id: id}) sf.set.Remove(id) @@ -57,22 +63,31 @@ func (sf *SkipFilter) Remove(value interface{}) { // Len returns the number of values in the set func (sf *SkipFilter) Len() int { + sf.mutex.RLock() + defer sf.mutex.RUnlock() return sf.list.GetNodeCount() } // MatchAny returns a slice of values in the set matching any of the provided filters -func (sf *SkipFilter) MatchAny(filters ...interface{}) []interface{} { - var set = roaring64.New() - var fs = make([]*filter, len(filters)) - for i, k := range filters { - fs[i] = sf.getFilter(k) - set.Or(fs[i].set) +func (sf *SkipFilter) MatchAny(filterKeys ...interface{}) []interface{} { + sf.mutex.RLock() + defer sf.mutex.RUnlock() + var sets = make([]*roaring64.Bitmap, len(filterKeys)) + var filters = make([]*filter, len(filterKeys)) + for i, k := range filterKeys { + filters[i] = sf.getFilter(k) + sets[i] = filters[i].set } + var set = roaring64.ParOr(runtime.NumCPU(), sets...) values, notfound := sf.getValues(set) - for _, id := range notfound { + if len(notfound) > 0 { // Clean up references to removed values - for _, f := range fs { - f.set.Remove(id) + for _, f := range filters { + f.mutex.Lock() + for _, id := range notfound { + f.set.Remove(id) + } + f.mutex.Unlock() } } return values @@ -82,6 +97,8 @@ func (sf *SkipFilter) MatchAny(filters ...interface{}) []interface{} { // Return true in callback to continue iterating, false to stop. // Returned uint64 is index of `next` element (send as `start` to continue iterating) func (sf *SkipFilter) Walk(start uint64, callback func(val interface{}) bool) uint64 { + sf.mutex.RLock() + defer sf.mutex.RUnlock() var i uint64 var id = start var prev uint64 @@ -111,24 +128,28 @@ func (sf *SkipFilter) getFilter(k interface{}) *filter { if ok { f = val.(*filter) } else { - f = &filter{0, roaring64.New()} + f = &filter{i: 0, set: roaring64.New()} sf.cache.Add(k, f) } var id uint64 var prev uint64 var first = true - for el, ok := sf.list.FindGreaterOrEqual(&entry{id: f.i}); ok && el != nil; el = sf.list.Next(el) { - if id = el.GetValue().(*entry).id; !first && id <= prev { - // skiplist loops back to first element so we have to detect loop and break manually - break - } - if sf.test(el.GetValue().(*entry).val, k) { - f.set.Add(id) + if atomic.LoadUint64(&f.i) < sf.i { + f.mutex.Lock() + defer f.mutex.Unlock() + for el, ok := sf.list.FindGreaterOrEqual(&entry{id: f.i}); ok && el != nil; el = sf.list.Next(el) { + if id = el.GetValue().(*entry).id; !first && id <= prev { + // skiplist loops back to first element so we have to detect loop and break manually + break + } + if sf.test(el.GetValue().(*entry).val, k) { + f.set.Add(id) + } + prev = id + first = false } - prev = id - first = false + f.i = sf.i } - f.i = sf.i return f } @@ -166,6 +187,7 @@ func (e *entry) String() string { } type filter struct { - i uint64 - set *roaring64.Bitmap + i uint64 + mutex sync.RWMutex + set *roaring64.Bitmap }