Skip to content

Commit

Permalink
MB-61029: Caching Vec To DocID Map
Browse files Browse the repository at this point in the history
 - Generalized some of the cache functions
 - The cache now also stores the vec-to-docID mapping, along with
the structure used to compute the excluded vectors and the
resulting list of excluded vector IDs
 - Added back the cache mutexes because cache reads may also
write back some of these structures, depending on the except bitmap
  • Loading branch information
Likith101 committed Apr 10, 2024
1 parent 97f15cb commit df4f403
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 52 deletions.
139 changes: 113 additions & 26 deletions faiss_vector_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
package zap

import (
"encoding/binary"
"sync"
"sync/atomic"
"time"

"github.com/RoaringBitmap/roaring"
faiss "github.com/blevesearch/go-faiss"
)

Expand All @@ -39,8 +41,13 @@ type ewma struct {

// cacheEntry holds the cached, deserialized vector index for a single field,
// together with the docID bookkeeping derived from it and the lifetime /
// eviction accounting.
type cacheEntry struct {
	tracker *ewma
	// m guards index, vecDocIDMap, docIDExcluded and vecExcluded: cache
	// reads may write back updated structures depending on the except
	// bitmap passed to load().
	m    sync.RWMutex
	refs int64

	index         *faiss.IndexImpl
	vecDocIDMap   map[int64]uint32
	docIDExcluded map[uint32]struct{}
	vecExcluded   []int64
}

func newVectorIndexCache() *vecIndexCache {
Expand All @@ -57,21 +64,65 @@ func (vc *vecIndexCache) Clear() {
vc.m.Unlock()
}

func (vc *vecIndexCache) loadVectorIndex(fieldID uint16,
indexBytes []byte) (vecIndex *faiss.IndexImpl, err error) {
cachedIndex, present := vc.isIndexCached(fieldID)
func (vc *vecIndexCache) loadFromCache(fieldID uint16, mem []byte,
except *roaring.Bitmap) (index *faiss.IndexImpl,
vecDocIDMap map[int64]uint32, vecExcluded []int64, err error) {
entry, present := vc.checkEntry(fieldID)

if present {
vecIndex = cachedIndex
vc.incHit(fieldID)
index, vecDocIDMap, vecExcluded = entry.load(except)
} else {
vecIndex, err = vc.createAndCacheVectorIndex(fieldID, indexBytes)
index, vecDocIDMap, vecExcluded, err =
vc.createAndCacheEntry(fieldID, mem, except)
}

vc.addRef(fieldID)
return vecIndex, err
return index, vecDocIDMap, vecExcluded, err
}

func (vc *vecIndexCache) createAndCacheVectorIndex(fieldID uint16,
indexBytes []byte) (*faiss.IndexImpl, error) {
// load returns the cached faiss index, the vec->docID mapping and the list
// of excluded vector IDs for this entry. When a non-nil `except` bitmap is
// supplied, any docIDs not yet marked excluded are folded back into the
// cached structures (docIDExcluded / vecDocIDMap / vecExcluded) under the
// write lock, so later readers observe the updated view.
func (vc *cacheEntry) load(except *roaring.Bitmap) (*faiss.IndexImpl,
	map[int64]uint32, []int64) {

	// Snapshot the cached structures under the read lock.
	vc.m.RLock()
	vecIndex := vc.index
	vecDocIDMap := vc.vecDocIDMap
	docIDExcluded := vc.docIDExcluded
	vecExcluded := vc.vecExcluded
	vc.m.RUnlock()

	if except != nil {
		newExcluded := false
		vc.m.Lock()
		// Record any docIDs from `except` that aren't already excluded.
		it := except.Iterator()
		for it.HasNext() {
			docID := it.Next()
			if _, exists := docIDExcluded[docID]; !exists {
				docIDExcluded[docID] = struct{}{}
				newExcluded = true
			}
		}

		if newExcluded {
			// Drop newly-excluded docIDs from the mapping and move their
			// vector IDs onto the excluded list.
			for vecID, docID := range vecDocIDMap {
				if _, exists := docIDExcluded[docID]; exists {
					delete(vecDocIDMap, vecID)
					vecExcluded = append(vecExcluded, vecID)
				}
			}
			// Write the updated structures back into the entry.
			vc.vecDocIDMap = vecDocIDMap
			vc.docIDExcluded = docIDExcluded
			vc.vecExcluded = vecExcluded
		}
		vc.m.Unlock()
	}

	// NOTE(review): the map/slice returned here are the same objects a
	// concurrent load() may mutate above (delete/append under vc.m), while
	// callers read the returned values without holding the lock — this looks
	// like a data race; confirm, or return copies / guard caller reads.
	vc.incHit()
	return vecIndex, vecDocIDMap, vecExcluded
}

// createAndCacheEntry deserializes the per-field vector section from `mem`
// (varint-encoded: numVecs, then (vecID, docID) pairs, then the serialized
// faiss index preceded by its byte size), builds the vec->docID mapping and
// the exclusion structures from the `except` bitmap, and caches the result.
//
// NOTE(review): this span is a rendered diff and still contains both the
// removed (old) and added (new) lines of the hunk, plus a collapsed-context
// marker — it is not compilable as rendered.
func (vc *vecIndexCache) createAndCacheEntry(fieldID uint16,
	mem []byte, except *roaring.Bitmap) (index *faiss.IndexImpl,
	vecDocIDMap map[int64]uint32, vecExcluded []int64, err error) {
	vc.m.Lock()
	defer vc.m.Unlock()

Expand All @@ -80,18 +131,46 @@ func (vc *vecIndexCache) createAndCacheVectorIndex(fieldID uint16,
	// cached.
	// Another goroutine may have populated the entry while this one waited
	// for the write lock; reuse it instead of re-deserializing.
	entry, present := vc.cache[fieldID]
	if present {
	rv := entry.index
	entry.incHit()
	return rv, nil
		index, vecDocIDMap, vecExcluded = entry.load(except)
		return index, vecDocIDMap, vecExcluded, nil
	}

	// Decode the vector count, then each (vecID, docID) pair.
	pos := 0
	numVecs, n := binary.Uvarint(mem[pos : pos+binary.MaxVarintLen64])
	pos += n

	docIDExcluded := make(map[uint32]struct{})

	// NOTE(review): vecDocIDMap (a named return, zero value nil) does not
	// appear to be initialized with make() before the assignment below —
	// writing to a nil map panics. Confirm a make(map[int64]uint32, numVecs)
	// isn't hiding in the collapsed context; otherwise this is a bug.
	for i := 0; i < int(numVecs); i++ {
		vecID, n := binary.Varint(mem[pos : pos+binary.MaxVarintLen64])
		pos += n
		docID, n := binary.Uvarint(mem[pos : pos+binary.MaxVarintLen64])
		pos += n

		docIDUint32 := uint32(docID)
		if except != nil && except.Contains(docIDUint32) {
			// populate the list of vector IDs to be ignored on search
			vecExcluded = append(vecExcluded, vecID)
			docIDExcluded[docIDUint32] = struct{}{}
			// also, skip adding entry to vecDocIDMap
			continue
		}
		vecDocIDMap[vecID] = docIDUint32
	}

	// if the cache doesn't have vector index, just construct it out of the
	// index bytes and update the cache under lock.
	vecIndex, err := faiss.ReadIndexFromBuffer(indexBytes, faiss.IOFlagReadOnly)
	vc.updateLOCKED(fieldID, vecIndex)
	return vecIndex, err
	indexSize, n := binary.Uvarint(mem[pos : pos+binary.MaxVarintLen64])
	pos += n

	index, err = faiss.ReadIndexFromBuffer(mem[pos:pos+int(indexSize)], faiss.IOFlagReadOnly)
	if err != nil {
		return nil, nil, nil, err
	}

	vc.updateLOCKED(fieldID, index, vecDocIDMap, docIDExcluded, vecExcluded)
	return index, vecDocIDMap, vecExcluded, nil
}
func (vc *vecIndexCache) updateLOCKED(fieldIDPlus1 uint16, index *faiss.IndexImpl) {
// updateLOCKED inserts (or overwrites) the cache entry for fieldIDPlus1.
// The caller must hold vc.m for writing.
//
// NOTE(review): this span is a rendered diff — the old one-argument
// initCacheEntry call line is still present alongside the new call, and a
// collapsed-context marker hides part of the body.
func (vc *vecIndexCache) updateLOCKED(fieldIDPlus1 uint16, index *faiss.IndexImpl,
	vecDocIDMap map[int64]uint32, docIDExcluded map[uint32]struct{}, vecExcluded []int64) {
	// the first time we've hit the cache, try to spawn a monitoring routine
	// which will reconcile the moving averages for all the fields being hit
	if len(vc.cache) == 0 {
Expand All @@ -105,20 +184,19 @@ func (vc *vecIndexCache) updateLOCKED(fieldIDPlus1 uint16, index *faiss.IndexImp
	// this makes the average to be kept above the threshold value for a
	// longer time and thereby the index to be resident in the cache
	// for longer time.
	vc.cache[fieldIDPlus1] = initCacheEntry(index, 0.4)
	vc.cache[fieldIDPlus1] = initCacheEntry(index, vecDocIDMap, docIDExcluded, vecExcluded, 0.4)
	}
}

func (vc *vecIndexCache) isIndexCached(fieldID uint16) (*faiss.IndexImpl, bool) {
// checkEntry returns the cache entry for fieldID along with a flag
// indicating whether a usable (non-closed) entry is present.
func (vc *vecIndexCache) checkEntry(fieldID uint16) (*cacheEntry, bool) {
	vc.m.RLock()
	entry, present := vc.cache[fieldID]
	vc.m.RUnlock()
	if entry == nil {
		return nil, false
	}

	// Read entry.index under the entry's own lock: closeIndex() nils it out
	// while holding the same lock, so an unguarded read here would race.
	entry.m.RLock()
	usable := entry.index != nil
	entry.m.RUnlock()
	return entry, present && usable
}

func (vc *vecIndexCache) incHit(fieldIDPlus1 uint16) {
Expand Down Expand Up @@ -196,10 +274,14 @@ func (e *ewma) add(val uint64) {
}
}

func initCacheEntry(index *faiss.IndexImpl, alpha float64) *cacheEntry {
// initCacheEntry builds a cacheEntry wrapping the deserialized index and its
// docID bookkeeping; alpha is the smoothing factor for the entry's EWMA
// usage tracker. (The tail of this function is hidden by the collapsed diff
// context below this span.)
func initCacheEntry(vecIndex *faiss.IndexImpl, vecDocIDMap map[int64]uint32,
	docIDExcluded map[uint32]struct{}, vecExcluded []int64, alpha float64) *cacheEntry {
	vc := &cacheEntry{
		index:         vecIndex,
		vecDocIDMap:   vecDocIDMap,
		docIDExcluded: docIDExcluded,
		vecExcluded:   vecExcluded,
		tracker:       &ewma{},
	}
	vc.tracker.alpha = alpha

Expand All @@ -223,6 +305,11 @@ func (vc *cacheEntry) decRef() {
}

// closeIndex releases the faiss index and drops the cached docID
// bookkeeping so the memory can be reclaimed. Safe to call more than once:
// an already-nil index is skipped rather than dereferenced.
func (vc *cacheEntry) closeIndex() {
	vc.m.Lock()
	// Guard against double close: calling Close on a nil *faiss.IndexImpl
	// (or closing the underlying native object twice) would crash.
	if vc.index != nil {
		vc.index.Close()
		vc.index = nil
	}
	vc.docIDExcluded = nil
	vc.vecDocIDMap = nil
	vc.vecExcluded = nil
	vc.m.Unlock()
}
28 changes: 2 additions & 26 deletions faiss_vector_posting.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap
segment.VectorIndex, error) {
// Params needed for the closures
var vecIndex *faiss.IndexImpl
vecDocIDMap := make(map[int64]uint32)
var vecDocIDMap map[int64]uint32
var vectorIDsToExclude []int64
var fieldIDPlus1 uint16

Expand Down Expand Up @@ -373,31 +373,7 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap
pos += n
}

// read the number vectors indexed for this field and load the vector to docID mapping.
// todo: cache the vecID to docIDs mapping for a fieldID
numVecs, n := binary.Uvarint(sb.mem[pos : pos+binary.MaxVarintLen64])
pos += n
for i := 0; i < int(numVecs); i++ {
vecID, n := binary.Varint(sb.mem[pos : pos+binary.MaxVarintLen64])
pos += n
docID, n := binary.Uvarint(sb.mem[pos : pos+binary.MaxVarintLen64])
pos += n

docIDUint32 := uint32(docID)
if except != nil && except.Contains(docIDUint32) {
// populate the list of vector IDs to be ignored on search
vectorIDsToExclude = append(vectorIDsToExclude, vecID)
// also, skip adding entry to vecDocIDMap
continue
}
vecDocIDMap[vecID] = docIDUint32
}

indexSize, n := binary.Uvarint(sb.mem[pos : pos+binary.MaxVarintLen64])
pos += n

vecIndex, err = sb.vectorCache.loadVectorIndex(fieldIDPlus1, sb.mem[pos:pos+int(indexSize)])
pos += int(indexSize)
vecIndex, vecDocIDMap, vectorIDsToExclude, err = sb.vectorCache.loadFromCache(fieldIDPlus1, sb.mem[pos:], except)

return wrapVecIndex, err
}
Expand Down

0 comments on commit df4f403

Please sign in to comment.