diff --git a/build.go b/build.go
index a545b07..2f3bcf9 100644
--- a/build.go
+++ b/build.go
@@ -175,6 +175,7 @@ func InitSegmentBase(mem []byte, memCRC uint32, chunkMode uint32,
 		docValueOffset: 0, // docValueOffsets identified automatically by the section
 		dictLocs:       dictLocs,
 		fieldFSTs:      make(map[uint16]*vellum.FST),
+		vecIndexCache:  newVectorIndexCache(),
 	}
 
 	sb.updateSize()
diff --git a/faiss_vector_cache.go b/faiss_vector_cache.go
new file mode 100644
index 0000000..9096d37
--- /dev/null
+++ b/faiss_vector_cache.go
@@ -0,0 +1,247 @@
+// Copyright (c) 2024 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build vectors
+// +build vectors
+
+package zap
+
+import (
+	"sync"
+	"sync/atomic"
+	"time"
+
+	faiss "github.com/blevesearch/go-faiss"
+)
+
+type vectorIndexCache struct {
+	closeCh chan struct{}
+	m       sync.RWMutex
+	cache   map[uint16]*cacheEntry
+}
+
+type ewma struct {
+	alpha float64
+	avg   float64
+	// every hit to the cache entry is recorded as part of a sample,
+	// which is used to compute the average in the next cycle of average
+	// computation (i.e., the average traffic for the field so far). this
+	// tracks the per-second hits to the cache entries.
+	sample uint64
+}
+
+type cacheEntry struct {
+	tracker *ewma
+	// tracks the live references to the cache entry, so that when cleanup()
+	// sees the avg fall below the threshold, it closes the index only if
+	// there are no live refs to the entry.
+	refs  int64
+	index *faiss.IndexImpl
+}
+
+func newVectorIndexCache() *vectorIndexCache {
+	return &vectorIndexCache{
+		cache:   make(map[uint16]*cacheEntry),
+		closeCh: make(chan struct{}),
+	}
+}
+
+func (vc *vectorIndexCache) Clear() {
+	vc.m.Lock()
+	close(vc.closeCh)
+
+	// force a close on all indexes to avoid memory leaks.
+	for _, entry := range vc.cache {
+		entry.closeIndex()
+	}
+	vc.cache = nil
+	vc.m.Unlock()
+}
+
+func (vc *vectorIndexCache) loadVectorIndex(fieldID uint16,
+	indexBytes []byte) (vecIndex *faiss.IndexImpl, err error) {
+	cachedIndex, present := vc.isIndexCached(fieldID)
+	if present {
+		vecIndex = cachedIndex
+		vc.incHit(fieldID)
+	} else {
+		vecIndex, err = vc.createAndCacheVectorIndex(fieldID, indexBytes)
+		if err != nil {
+			return nil, err
+		}
+	}
+	vc.addRef(fieldID)
+	return vecIndex, nil
+}
+
+func (vc *vectorIndexCache) createAndCacheVectorIndex(fieldID uint16,
+	indexBytes []byte) (*faiss.IndexImpl, error) {
+	vc.m.Lock()
+	defer vc.m.Unlock()
+
+	// when multiple goroutines race to build the index, guard against
+	// redundant index creation with a double check: return early if it's
+	// already created and cached.
+	entry, present := vc.cache[fieldID]
+	if present {
+		rv := entry.index
+		entry.incHit()
+		return rv, nil
+	}
+
+	// if the cache doesn't have the vector index, construct it from the
+	// index bytes and update the cache under lock.
+	vecIndex, err := faiss.ReadIndexFromBuffer(indexBytes, faiss.IOFlagReadOnly)
+	if err != nil {
+		return nil, err
+	}
+	vc.updateLOCKED(fieldID, vecIndex)
+	return vecIndex, nil
+}
+
+func (vc *vectorIndexCache) updateLOCKED(fieldIDPlus1 uint16, index *faiss.IndexImpl) {
+	// on the first entry into the cache, spawn a monitoring routine that
+	// reconciles the moving averages for all the fields being hit
+	if len(vc.cache) == 0 {
+		go vc.monitor()
+	}
+
+	_, ok := vc.cache[fieldIDPlus1]
+	if !ok {
+		// initializing the alpha to 0.4 essentially means that we favor
+		// the history a little more than the current sample value. this
+		// keeps the average above the threshold value for longer, and
+		// thereby keeps the index resident in the cache for a longer
+		// period of time.
+		vc.cache[fieldIDPlus1] = createCacheEntry(index, 0.4)
+	}
+}
+
+func (vc *vectorIndexCache) isIndexCached(fieldID uint16) (*faiss.IndexImpl, bool) {
+	vc.m.RLock()
+	entry, present := vc.cache[fieldID]
+	vc.m.RUnlock()
+	if entry == nil {
+		return nil, false
+	}
+
+	rv := entry.index
+	return rv, present && (rv != nil)
+}
+
+func (vc *vectorIndexCache) incHit(fieldIDPlus1 uint16) {
+	vc.m.RLock()
+	entry := vc.cache[fieldIDPlus1]
+	entry.incHit()
+	vc.m.RUnlock()
+}
+
+func (vc *vectorIndexCache) addRef(fieldIDPlus1 uint16) {
+	vc.m.RLock()
+	entry := vc.cache[fieldIDPlus1]
+	entry.addRef()
+	vc.m.RUnlock()
+}
+
+func (vc *vectorIndexCache) decRef(fieldIDPlus1 uint16) {
+	vc.m.RLock()
+	entry := vc.cache[fieldIDPlus1]
+	entry.decRef()
+	vc.m.RUnlock()
+}
+
+func (vc *vectorIndexCache) cleanup() bool {
+	vc.m.Lock()
+	cache := vc.cache
+
+	// for every field, reconcile the average with the current sample values
+	for fieldIDPlus1, entry := range cache {
+		sample := atomic.LoadUint64(&entry.tracker.sample)
+		entry.tracker.add(sample)
+
+		refCount := atomic.LoadInt64(&entry.refs)
+		// the comparison threshold as of now is (1 - alpha). mathematically,
+		// this means there has been only about one query per second on
+		// average as per the history, and in the current second no queries
+		// were performed against this index.
+		if entry.tracker.avg <= (1-entry.tracker.alpha) && refCount <= 0 {
+			atomic.StoreUint64(&entry.tracker.sample, 0)
+			entry.closeIndex()
+			delete(vc.cache, fieldIDPlus1)
+			continue
+		}
+		atomic.StoreUint64(&entry.tracker.sample, 0)
+	}
+
+	rv := len(vc.cache) == 0
+	vc.m.Unlock()
+	return rv
+}
+
+var monitorFreq = 1 * time.Second
+
+func (vc *vectorIndexCache) monitor() {
+	ticker := time.NewTicker(monitorFreq)
+	for {
+		select {
+		case <-vc.closeCh:
+			return
+		case <-ticker.C:
+			exit := vc.cleanup()
+			if exit {
+				// no entries left to monitor, exit
+				return
+			}
+		}
+	}
+}
+
+func (e *ewma) add(val uint64) {
+	if e.avg == 0.0 {
+		e.avg = float64(val)
+	} else {
+		// the exponentially weighted moving average:
+		// X(t) = a*v + (1 - a)*X(t-1)
+		e.avg = e.alpha*float64(val) + (1-e.alpha)*e.avg
+	}
+}
+
+func createCacheEntry(index *faiss.IndexImpl, alpha float64) *cacheEntry {
+	return &cacheEntry{
+		index: index,
+		tracker: &ewma{
+			alpha:  alpha,
+			sample: 1,
+		},
+	}
+}
+
+func (ce *cacheEntry) incHit() {
+	atomic.AddUint64(&ce.tracker.sample, 1)
+}
+
+func (ce *cacheEntry) addRef() {
+	atomic.AddInt64(&ce.refs, 1)
+}
+
+func (ce *cacheEntry) decRef() {
+	atomic.AddInt64(&ce.refs, -1)
+}
+
+func (ce *cacheEntry) closeIndex() {
+	go func() {
+		ce.index.Close()
+		ce.index = nil
+	}()
+}
diff --git a/faiss_vector_cache_nosup.go b/faiss_vector_cache_nosup.go
new file mode 100644
index 0000000..ff152f9
--- /dev/null
+++ b/faiss_vector_cache_nosup.go
@@ -0,0 +1,27 @@
+// Copyright (c) 2024 Couchbase, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build !vectors
+// +build !vectors
+
+package zap
+
+type vectorIndexCache struct {
+}
+
+func newVectorIndexCache() *vectorIndexCache {
+	return nil
+}
+
+func (v *vectorIndexCache) Clear() {}
diff --git a/faiss_vector_posting.go b/faiss_vector_posting.go
index adfbed2..7bc4b3e 100644
--- a/faiss_vector_posting.go
+++ b/faiss_vector_posting.go
@@ -295,6 +295,8 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap
 	var vecIndex *faiss.IndexImpl
 	vecDocIDMap := make(map[int64]uint32)
 	var vectorIDsToExclude []int64
+	var fieldIDPlus1 uint16
+	var vecIndexSize uint64
 
 	var (
 		wrapVecIndex = &vectorIndexWrapper{
@@ -335,22 +337,19 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap
 				return rv, nil
 			},
 			close: func() {
-				if vecIndex != nil {
-					vecIndex.Close()
-				}
+				// skip closing the index here: it is cached, and closing is
+				// deferred to a later point in time.
+				sb.vecIndexCache.decRef(fieldIDPlus1)
 			},
 			size: func() uint64 {
-				if vecIndex != nil {
-					return vecIndex.Size()
-				}
-				return 0
+				return vecIndexSize
 			},
 		}
 		err error
 	)
 
-	fieldIDPlus1 := sb.fieldsMap[field]
+	fieldIDPlus1 = sb.fieldsMap[field]
 	if fieldIDPlus1 <= 0 {
 		return wrapVecIndex, nil
 	}
 
@@ -392,13 +391,15 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap
 		vecDocIDMap[vecID] = docIDUint32
 	}
 
-	// todo: not a good idea to cache the vector index perhaps, since it could be quite huge.
 	indexSize, n := binary.Uvarint(sb.mem[pos : pos+binary.MaxVarintLen64])
 	pos += n
-	indexBytes := sb.mem[pos : pos+int(indexSize)]
+
+	vecIndex, err = sb.vecIndexCache.loadVectorIndex(fieldIDPlus1, sb.mem[pos:pos+int(indexSize)])
 	pos += int(indexSize)
 
-	vecIndex, err = faiss.ReadIndexFromBuffer(indexBytes, faiss.IOFlagReadOnly)
+	if vecIndex != nil {
+		vecIndexSize = vecIndex.Size()
+	}
 
 	return wrapVecIndex, err
 }
diff --git a/segment.go b/segment.go
index 062abf2..0c040a3 100644
--- a/segment.go
+++ b/segment.go
@@ -55,6 +55,7 @@ func (*ZapPlugin) Open(path string) (segment.Segment, error) {
 		SegmentBase: SegmentBase{
 			fieldsMap:      make(map[string]uint16),
 			fieldFSTs:      make(map[uint16]*vellum.FST),
+			vecIndexCache:  newVectorIndexCache(),
 			fieldDvReaders: make([]map[uint16]*docValueReader, len(segmentSections)),
 		},
 		f: f,
@@ -81,7 +82,6 @@ func (*ZapPlugin) Open(path string) (segment.Segment, error) {
 		_ = rv.Close()
 		return nil, err
 	}
-
 	return rv, nil
 }
 
@@ -110,6 +110,9 @@ type SegmentBase struct {
 
 	m         sync.Mutex
 	fieldFSTs map[uint16]*vellum.FST
+
+	// this cache is used in builds that support vectors.
+	vecIndexCache *vectorIndexCache
 }
 
 func (sb *SegmentBase) Size() int {
@@ -146,7 +149,7 @@ func (sb *SegmentBase) updateSize() {
 
 func (sb *SegmentBase) AddRef()             {}
 func (sb *SegmentBase) DecRef() (err error) { return nil }
-func (sb *SegmentBase) Close() (err error)  { return nil }
+func (sb *SegmentBase) Close() (err error)  { sb.vecIndexCache.Clear(); return nil }
 
 // Segment implements a persisted segment.Segment interface, by
 // embedding an mmap()'ed SegmentBase.
@@ -629,6 +632,9 @@ func (s *Segment) Close() (err error) {
 }
 
 func (s *Segment) closeActual() (err error) {
+	// clear the vector index cache contents before un-mmapping the segment
+	s.vecIndexCache.Clear()
+
 	if s.mm != nil {
 		err = s.mm.Unmap()
 	}
@@ -640,6 +646,7 @@ func (s *Segment) closeActual() (err error) {
 			err = err2
 		}
 	}
+
 	return
 }
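
Reviewer note (not part of the patch): the eviction policy in faiss_vector_cache.go hinges on the EWMA tracker and the (1 - alpha) threshold used by cleanup(). Below is a minimal, self-contained sketch of just that math, assuming the same alpha of 0.4 that updateLOCKED passes to createCacheEntry; the traffic pattern and the program itself are hypothetical, for intuition only.

// Illustrative only: simulates the patch's ewma math to show when an idle
// cache entry would become evictable. The ewma type mirrors the one added
// in faiss_vector_cache.go; the sample sequence is made up.
package main

import "fmt"

type ewma struct {
	alpha float64
	avg   float64
}

// add folds one second's sample into the average: X(t) = a*v + (1-a)*X(t-1).
func (e *ewma) add(val uint64) {
	if e.avg == 0.0 {
		e.avg = float64(val)
	} else {
		e.avg = e.alpha*float64(val) + (1-e.alpha)*e.avg
	}
}

func main() {
	e := &ewma{alpha: 0.4}   // same alpha the patch uses
	threshold := 1 - e.alpha // cleanup()'s eviction threshold: 0.6

	// 3 seconds of traffic (10 hits/sec), then idle. Each iteration stands
	// in for one tick of the monitor() goroutine calling cleanup().
	samples := []uint64{10, 10, 10, 0, 0, 0, 0, 0, 0}
	for sec, s := range samples {
		e.add(s)
		fmt.Printf("t=%ds sample=%d avg=%.3f evictable=%v\n",
			sec, s, e.avg, e.avg <= threshold)
	}
}

With alpha = 0.4, an idle entry's average decays by a factor of 0.6 per tick, so a previously busy index survives several idle seconds (here, it only becomes evictable at t=8) before cleanup() would close it. Note that in the patch, evictability additionally requires the entry's live ref count to be zero.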
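
A second sketch for the ref-counting side: the wrapper's close() in faiss_vector_posting.go no longer closes the faiss index directly; it only drops a reference via decRef, and cleanup() evicts only when the average is at or below the threshold AND no live refs remain. The sketch below models that contract; fakeIndex stands in for *faiss.IndexImpl and maybeEvict stands in for cleanup()'s eviction branch (both are hypothetical names, not part of the patch).

// Illustrative only: shows that an index in active use is never closed by
// the cleanup cycle, even when its moving average has already decayed.
package main

import (
	"fmt"
	"sync/atomic"
)

type fakeIndex struct{ closed bool }

func (f *fakeIndex) Close() { f.closed = true }

type entry struct {
	refs  int64
	avg   float64
	index *fakeIndex
}

// maybeEvict applies cleanup()'s rule: close and evict only when the average
// is at or below the threshold and there are no live references.
func maybeEvict(e *entry, threshold float64) bool {
	if e.avg <= threshold && atomic.LoadInt64(&e.refs) <= 0 {
		e.index.Close()
		return true
	}
	return false
}

func main() {
	e := &entry{avg: 0.2, index: &fakeIndex{}} // idle entry, avg below 0.6
	atomic.AddInt64(&e.refs, 1)                // loadVectorIndex's addRef on the read path

	fmt.Println("evicted while referenced:", maybeEvict(e, 0.6)) // false: refs > 0
	atomic.AddInt64(&e.refs, -1)                                 // the wrapper's close() -> decRef
	fmt.Println("evicted after release:", maybeEvict(e, 0.6))    // true
}

This is why InterpretVectorIndex records fieldIDPlus1 up front: the deferred close() path needs it to decrement the right entry's ref count long after the load completed.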