MB-61029: Caching Vec To DocID Map (#231)
 - Generalised some of the cache function names to be inclusive of the map
 - Added the map to the cache, where it behaves the same way as the index,
   except that the bitmap logic is not part of the cache and the excluded
   vec IDs are calculated outside of the map

---------

Co-authored-by: Abhinav Dangeti <abhinav@couchbase.com>
Likith101 and abhinavdangeti committed Apr 18, 2024
1 parent b2384fc commit eeb2336
Showing 2 changed files with 139 additions and 103 deletions.
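
In short: the per-field cache entry now holds both the faiss index and the
vecID-to-docID map, and a single call hands back the index, the map, and the
vector IDs excluded by the deleted-docs bitmap. A minimal sketch of the new
call shape, using identifiers as they appear in this diff:

	index, vecDocIDMap, vecIDsToExclude, err :=
		sb.vecIndexCache.loadOrCreate(fieldIDPlus1, sb.mem[pos:], except)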
212 changes: 135 additions & 77 deletions faiss_vector_cache.go
@@ -18,92 +18,117 @@
package zap

import (
"encoding/binary"
"sync"
"sync/atomic"
"time"

"github.com/RoaringBitmap/roaring"
faiss "github.com/blevesearch/go-faiss"
)

func newVectorIndexCache() *vectorIndexCache {
return &vectorIndexCache{
cache: make(map[uint16]*cacheEntry),
closeCh: make(chan struct{}),
}
}

type vectorIndexCache struct {
closeCh chan struct{}
m sync.RWMutex
cache map[uint16]*cacheEntry
}

func (vc *vectorIndexCache) Clear() {
vc.m.Lock()
close(vc.closeCh)

// forcing a close on all indexes to avoid memory leaks.
for _, entry := range vc.cache {
entry.close()
}
vc.cache = nil
vc.m.Unlock()
}

func (vc *vectorIndexCache) loadOrCreate(fieldID uint16, mem []byte, except *roaring.Bitmap) (
	index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, vecIDsToExclude []int64, err error) {
	var found bool
	index, vecDocIDMap, vecIDsToExclude, found = vc.loadFromCache(fieldID, except)
	if !found {
		index, vecDocIDMap, vecIDsToExclude, err = vc.createAndCache(fieldID, mem, except)
	}
	return index, vecDocIDMap, vecIDsToExclude, err
}
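
loadOrCreate above is a classic read-mostly double check: a shared-lock fast
path (loadFromCache) and an exclusive-lock slow path that re-checks before
building (createAndCache). A self-contained sketch of just that locking
pattern, with a toy string value standing in for the index and map; not code
from this commit:

	package main

	import (
		"fmt"
		"sync"
	)

	type cache struct {
		m sync.RWMutex
		c map[uint16]string
	}

	func (vc *cache) loadOrCreate(fieldID uint16, build func() string) string {
		// fast path: shared lock, as in loadFromCache.
		vc.m.RLock()
		v, ok := vc.c[fieldID]
		vc.m.RUnlock()
		if ok {
			return v
		}
		// slow path: exclusive lock with a re-check, as in createAndCache.
		vc.m.Lock()
		defer vc.m.Unlock()
		if v, ok := vc.c[fieldID]; ok {
			return v // another goroutine built it while we waited
		}
		v = build()
		vc.c[fieldID] = v
		return v
	}

	func main() {
		vc := &cache{c: map[uint16]string{}}
		fmt.Println(vc.loadOrCreate(1, func() string { return "expensive index" }))
		fmt.Println(vc.loadOrCreate(1, func() string { panic("never rebuilt") }))
	}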

func (vc *vectorIndexCache) loadFromCache(fieldID uint16, except *roaring.Bitmap) (
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, vecIDsToExclude []int64, found bool) {
vc.m.RLock()
defer vc.m.RUnlock()

entry, ok := vc.cache[fieldID]
if !ok {
return nil, nil, nil, false
}

index, vecDocIDMap = entry.load()
vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except)

return index, vecDocIDMap, vecIDsToExclude, true
}

func (vc *vectorIndexCache) createAndCache(fieldID uint16, mem []byte, except *roaring.Bitmap) (
	index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, vecIDsToExclude []int64, err error) {
	vc.m.Lock()
	defer vc.m.Unlock()

	// when multiple threads try to build the index, guard against redundant
	// index creation by double checking, and return early if it has already
	// been created and cached.
	entry, ok := vc.cache[fieldID]
	if ok {
		index, vecDocIDMap = entry.load()
		vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except)
		return index, vecDocIDMap, vecIDsToExclude, nil
}

	// if the cache doesn't have an entry, construct the vector-to-docID map
	// and the vector index out of the mem bytes, and update the cache under lock.
pos := 0
numVecs, n := binary.Uvarint(mem[pos : pos+binary.MaxVarintLen64])
pos += n

vecDocIDMap = make(map[int64]uint32)
isExceptNotEmpty := except != nil && !except.IsEmpty()
for i := 0; i < int(numVecs); i++ {
vecID, n := binary.Varint(mem[pos : pos+binary.MaxVarintLen64])
pos += n
docID, n := binary.Uvarint(mem[pos : pos+binary.MaxVarintLen64])
pos += n

docIDUint32 := uint32(docID)
if isExceptNotEmpty && except.Contains(docIDUint32) {
vecIDsToExclude = append(vecIDsToExclude, vecID)
continue
}
vecDocIDMap[vecID] = docIDUint32
}

indexSize, n := binary.Uvarint(mem[pos : pos+binary.MaxVarintLen64])
pos += n

index, err = faiss.ReadIndexFromBuffer(mem[pos:pos+int(indexSize)], faiss.IOFlagReadOnly)
if err != nil {
return nil, nil, nil, err
}

vc.insertLOCKED(fieldID, index, vecDocIDMap)
return index, vecDocIDMap, vecIDsToExclude, nil
}
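
For concreteness, a self-contained round trip of the section layout that
createAndCache parses above (numVecs, then vecID/docID varint pairs, then the
length of the serialized index), with a stand-in byte slice in place of a
real faiss index; illustrative only, not code from this commit:

	package main

	import (
		"encoding/binary"
		"fmt"
	)

	func main() {
		// encode: numVecs, then (vecID varint, docID uvarint) pairs,
		// then the length of the (fake) serialized index and its bytes.
		pairs := [][2]int64{{100, 1}, {101, 2}, {102, 3}}
		fakeIndex := []byte{0xCA, 0xFE}

		var buf []byte
		buf = binary.AppendUvarint(buf, uint64(len(pairs)))
		for _, p := range pairs {
			buf = binary.AppendVarint(buf, p[0])          // vecID
			buf = binary.AppendUvarint(buf, uint64(p[1])) // docID
		}
		buf = binary.AppendUvarint(buf, uint64(len(fakeIndex)))
		buf = append(buf, fakeIndex...)

		// decode: mirrors the loop in createAndCache.
		pos := 0
		numVecs, n := binary.Uvarint(buf[pos:])
		pos += n
		vecDocIDMap := make(map[int64]uint32, numVecs)
		for i := 0; i < int(numVecs); i++ {
			vecID, n := binary.Varint(buf[pos:])
			pos += n
			docID, n := binary.Uvarint(buf[pos:])
			pos += n
			vecDocIDMap[vecID] = uint32(docID)
		}
		indexSize, n := binary.Uvarint(buf[pos:])
		pos += n
		fmt.Println(numVecs, vecDocIDMap, buf[pos:pos+int(indexSize)])
	}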

func (vc *vectorIndexCache) insertLOCKED(fieldIDPlus1 uint16,
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32) {
// the first time we've hit the cache, try to spawn a monitoring routine
// which will reconcile the moving averages for all the fields being hit
if len(vc.cache) == 0 {
@@ -117,40 +142,25 @@ func (vc *vectorIndexCache) updateLOCKED(fieldIDPlus1 uint16, index *faiss.IndexImpl) {
	// this keeps the average above the threshold value for a longer
	// time, and thereby keeps the index resident in the cache longer.
vc.cache[fieldIDPlus1] = createCacheEntry(index, vecDocIDMap, 0.4)
}
}

func (vc *vectorIndexCache) addRef(fieldIDPlus1 uint16) {
vc.m.RLock()
	entry, ok := vc.cache[fieldIDPlus1]
	if ok {
		entry.addRef()
	}
vc.m.RUnlock()
}

func (vc *vectorIndexCache) decRef(fieldIDPlus1 uint16) {
vc.m.RLock()
entry, ok := vc.cache[fieldIDPlus1]
if ok {
entry.decRef()
}
vc.m.RUnlock()
}

@@ -170,8 +180,8 @@ func (vc *vectorIndexCache) cleanup() bool {
// this index.
if entry.tracker.avg <= (1-entry.tracker.alpha) && refCount <= 0 {
atomic.StoreUint64(&entry.tracker.sample, 0)
delete(vc.cache, fieldIDPlus1)
entry.close()
continue
}
atomic.StoreUint64(&entry.tracker.sample, 0)
@@ -200,6 +210,18 @@ func (vc *vectorIndexCache) monitor() {
}
}

// -----------------------------------------------------------------------------

type ewma struct {
alpha float64
avg float64
	// every hit to the cache entry is recorded as part of a sample, which is
	// used to calculate the average in the next cycle of average computation
	// (the average traffic for the field so far). this is used to track the
	// per-second hits to the cache entries.
sample uint64
}

func (e *ewma) add(val uint64) {
if e.avg == 0.0 {
e.avg = float64(val)
@@ -210,16 +232,32 @@ func (e *ewma) add(val uint64) {
}
}
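
To see why cleanup() treats avg <= (1 - alpha) as the eviction bar: with the
alpha of 0.4 that insertLOCKED primes entries with, an entry whose average
sits at 1 decays to exactly 0.6 (that is, 1 - alpha) after a single idle
cycle. A standalone sketch of that decay, assuming one sample is folded in
per monitoring cycle; not code from this commit:

	package main

	import "fmt"

	func main() {
		const alpha = 0.4 // the priming value used by insertLOCKED above
		avg := 1.0        // entry created with an initial sample of 1

		// no hits: each cycle folds a sample of 0 into the moving average.
		for cycle := 1; cycle <= 3; cycle++ {
			avg = alpha*0 + (1-alpha)*avg
			fmt.Printf("cycle %d: avg=%.3f evictable=%v\n",
				cycle, avg, avg <= 1-alpha)
		}
		// prints avg 0.600, 0.360, 0.216: evictable from the first idle
		// cycle onward, subject to the live-refs check in cleanup().
	}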

// -----------------------------------------------------------------------------

func createCacheEntry(index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, alpha float64) *cacheEntry {
return &cacheEntry{
index: index,
vecDocIDMap: vecDocIDMap,
tracker: &ewma{
alpha: alpha,
sample: 1,
},
refs: 1,
}
}

type cacheEntry struct {
tracker *ewma

	// tracks the live references to the cache entry, so that when cleanup()
	// finds the avg below the threshold, it closes the entry only if the
	// live refs to it have dropped to 0.
refs int64

index *faiss.IndexImpl
vecDocIDMap map[int64]uint32
}

func (ce *cacheEntry) incHit() {
atomic.AddUint64(&ce.tracker.sample, 1)
}
Expand All @@ -232,9 +270,29 @@ func (ce *cacheEntry) decRef() {
atomic.AddInt64(&ce.refs, -1)
}

func (ce *cacheEntry) load() (*faiss.IndexImpl, map[int64]uint32) {
ce.incHit()
ce.addRef()
return ce.index, ce.vecDocIDMap
}

func (ce *cacheEntry) close() {
go func() {
ce.index.Close()
ce.index = nil
ce.vecDocIDMap = nil
}()
}
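
The refs field follows a plain acquire/release discipline: an entry is
created with one reference, load() takes one per reader, decRef releases it,
and cleanup() closes an entry only once refs is back at zero. A toy
illustration with the same atomics (standalone, hypothetical types; not from
this commit):

	package main

	import (
		"fmt"
		"sync/atomic"
	)

	type entry struct{ refs int64 }

	func (e *entry) addRef() { atomic.AddInt64(&e.refs, 1) }
	func (e *entry) decRef() { atomic.AddInt64(&e.refs, -1) }

	func main() {
		e := &entry{refs: 1} // created with one ref, as in createCacheEntry
		e.addRef()           // a reader loads the entry
		fmt.Println(atomic.LoadInt64(&e.refs) <= 0) // false: not closable yet
		e.decRef() // reader is done
		e.decRef() // the owner releases its ref
		fmt.Println(atomic.LoadInt64(&e.refs) <= 0) // true: cleanup() may close it
	}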

// -----------------------------------------------------------------------------

func getVecIDsToExclude(vecDocIDMap map[int64]uint32, except *roaring.Bitmap) (vecIDsToExclude []int64) {
if except != nil && !except.IsEmpty() {
for vecID, docID := range vecDocIDMap {
if except.Contains(docID) {
vecIDsToExclude = append(vecIDsToExclude, vecID)
}
}
}
return vecIDsToExclude
}
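
And a runnable usage example of getVecIDsToExclude's logic against a real
roaring bitmap (illustrative data only):

	package main

	import (
		"fmt"

		"github.com/RoaringBitmap/roaring"
	)

	func main() {
		vecDocIDMap := map[int64]uint32{100: 1, 101: 2, 102: 3}

		except := roaring.NewBitmap()
		except.Add(2) // docID 2 is deleted, so vecID 101 must be excluded

		var vecIDsToExclude []int64
		if except != nil && !except.IsEmpty() {
			for vecID, docID := range vecDocIDMap {
				if except.Contains(docID) {
					vecIDsToExclude = append(vecIDsToExclude, vecID)
				}
			}
		}
		fmt.Println(vecIDsToExclude) // [101]
	}
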
30 changes: 4 additions & 26 deletions faiss_vector_posting.go
@@ -293,7 +293,7 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap
segment.VectorIndex, error) {
// Params needed for the closures
var vecIndex *faiss.IndexImpl
var vecDocIDMap map[int64]uint32
var vectorIDsToExclude []int64
var fieldIDPlus1 uint16
var vecIndexSize uint64
@@ -371,35 +371,13 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap
pos += n
}

vecIndex, vecDocIDMap, vectorIDsToExclude, err =
sb.vecIndexCache.loadOrCreate(fieldIDPlus1, sb.mem[pos:], except)

if vecIndex != nil {
vecIndexSize = vecIndex.Size()
}

return wrapVecIndex, err
}
