Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MB-61029: Caching Vec To DocID Map #231

Merged
merged 6 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 103 additions & 52 deletions faiss_vector_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
package zap

import (
"encoding/binary"
"sync"
"sync/atomic"
"time"

"github.com/RoaringBitmap/roaring"
faiss "github.com/blevesearch/go-faiss"
)

Expand All @@ -41,15 +43,6 @@ type ewma struct {
sample uint64
}

type cacheEntry struct {
tracker *ewma
// this is used to track the live references to the cache entry,
// such that while we do a cleanup() and we see that the avg is below a
// threshold we close/cleanup only if the live refs to the cache entry is 0.
refs int64
index *faiss.IndexImpl
}

func newVectorIndexCache() *vectorIndexCache {
return &vectorIndexCache{
cache: make(map[uint16]*cacheEntry),
Expand All @@ -63,47 +56,89 @@ func (vc *vectorIndexCache) Clear() {

// forcing a close on all indexes to avoid memory leaks.
for _, entry := range vc.cache {
entry.closeIndex()
entry.close()
}
vc.cache = nil
vc.m.Unlock()
}

func (vc *vectorIndexCache) loadVectorIndex(fieldID uint16,
indexBytes []byte) (vecIndex *faiss.IndexImpl, err error) {
cachedIndex, present := vc.isIndexCached(fieldID)
if present {
vecIndex = cachedIndex
vc.incHit(fieldID)
func (vc *vectorIndexCache) loadOrCreate(fieldID uint16, mem []byte, except *roaring.Bitmap) (
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, vecIDsToExclude []int64, err error) {
entry := vc.fetch(fieldID)
if entry != nil {
index, vecDocIDMap = entry.load()
if except != nil && !except.IsEmpty() {
for vecID, docID := range vecDocIDMap {
Likith101 marked this conversation as resolved.
Show resolved Hide resolved
if except.Contains(docID) {
vecIDsToExclude = append(vecIDsToExclude, vecID)
}
}
}
} else {
vecIndex, err = vc.createAndCacheVectorIndex(fieldID, indexBytes)
index, vecDocIDMap, vecIDsToExclude, err = vc.createAndCache(fieldID, mem, except)
}
vc.addRef(fieldID)
return vecIndex, err
return index, vecDocIDMap, vecIDsToExclude, err
}

func (vc *vectorIndexCache) createAndCacheVectorIndex(fieldID uint16,
indexBytes []byte) (*faiss.IndexImpl, error) {
func (vc *vectorIndexCache) createAndCache(fieldID uint16, mem []byte, except *roaring.Bitmap) (
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, vecIDsToExclude []int64, err error) {
vc.m.Lock()
defer vc.m.Unlock()

// when there are multiple threads trying to build the index, guard redundant
// index creation by doing a double check and return if already created and
// cached.
entry, present := vc.cache[fieldID]
if present {
rv := entry.index
entry.incHit()
return rv, nil
entry, exists := vc.cache[fieldID]
if exists {
entry.addRef()
index, vecDocIDMap = entry.load()
Likith101 marked this conversation as resolved.
Show resolved Hide resolved
if except != nil && !except.IsEmpty() {
for vecID, docID := range vecDocIDMap {
if except.Contains(docID) {
vecIDsToExclude = append(vecIDsToExclude, vecID)
}
}
}

return index, vecDocIDMap, vecIDsToExclude, nil
}

// if the cache doesn't have entry, construct the vector to doc id map and the
// vector index out of the mem bytes and update the cache under lock.
pos := 0
numVecs, n := binary.Uvarint(mem[pos : pos+binary.MaxVarintLen64])
pos += n

vecDocIDMap = make(map[int64]uint32)
isExceptNotEmpty := except != nil && !except.IsEmpty()
for i := 0; i < int(numVecs); i++ {
abhinavdangeti marked this conversation as resolved.
Show resolved Hide resolved
vecID, n := binary.Varint(mem[pos : pos+binary.MaxVarintLen64])
pos += n
docID, n := binary.Uvarint(mem[pos : pos+binary.MaxVarintLen64])
pos += n

docIDUint32 := uint32(docID)
if isExceptNotEmpty && except.Contains(docIDUint32) {
vecIDsToExclude = append(vecIDsToExclude, vecID)
continue
}
vecDocIDMap[vecID] = docIDUint32
}

// if the cache doesn't have vector index, just construct it out of the
// index bytes and update the cache under lock.
vecIndex, err := faiss.ReadIndexFromBuffer(indexBytes, faiss.IOFlagReadOnly)
vc.updateLOCKED(fieldID, vecIndex)
return vecIndex, err
indexSize, n := binary.Uvarint(mem[pos : pos+binary.MaxVarintLen64])
pos += n

index, err = faiss.ReadIndexFromBuffer(mem[pos:pos+int(indexSize)], faiss.IOFlagReadOnly)
if err != nil {
return nil, nil, nil, err
}

vc.upsertLOCKED(fieldID, index, vecDocIDMap)
return index, vecDocIDMap, vecIDsToExclude, nil
}
func (vc *vectorIndexCache) updateLOCKED(fieldIDPlus1 uint16, index *faiss.IndexImpl) {

func (vc *vectorIndexCache) upsertLOCKED(fieldIDPlus1 uint16,
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32) {
// the first time we've hit the cache, try to spawn a monitoring routine
// which will reconcile the moving averages for all the fields being hit
if len(vc.cache) == 0 {
Expand All @@ -117,20 +152,21 @@ func (vc *vectorIndexCache) updateLOCKED(fieldIDPlus1 uint16, index *faiss.Index
// this makes the average to be kept above the threshold value for a
// longer time and thereby the index to be resident in the cache
// for longer time.
vc.cache[fieldIDPlus1] = createCacheEntry(index, 0.4)
vc.cache[fieldIDPlus1] = createCacheEntry(index, vecDocIDMap, 0.4)
}
}

func (vc *vectorIndexCache) isIndexCached(fieldID uint16) (*faiss.IndexImpl, bool) {
func (vc *vectorIndexCache) fetch(fieldID uint16) *cacheEntry {
vc.m.RLock()
entry, present := vc.cache[fieldID]
vc.m.RUnlock()
if entry == nil {
return nil, false
defer vc.m.RUnlock()

entry, exists := vc.cache[fieldID]
if !exists || entry == nil || entry.index == nil {
return nil
}

rv := entry.index
return rv, present && (rv != nil)
entry.addRef()
return entry
}

func (vc *vectorIndexCache) incHit(fieldIDPlus1 uint16) {
Expand All @@ -140,17 +176,12 @@ func (vc *vectorIndexCache) incHit(fieldIDPlus1 uint16) {
vc.m.RUnlock()
}

// addRef bumps the live-reference count on the cache entry for the given
// field (fieldIDPlus1 is the 1-based field ID used as the cache key).
// Guard against a missing entry: a map miss yields nil and the unguarded
// entry.addRef() would panic — mirrors the nil check in decRef.
func (vc *vectorIndexCache) addRef(fieldIDPlus1 uint16) {
	vc.m.RLock()
	entry := vc.cache[fieldIDPlus1]
	if entry != nil {
		entry.addRef()
	}
	vc.m.RUnlock()
}

// decRef drops one live reference from the cache entry for the given
// field, if it exists; cleanup() only closes an entry once its refs reach
// zero.
func (vc *vectorIndexCache) decRef(fieldIDPlus1 uint16) {
	vc.m.RLock()
	entry := vc.cache[fieldIDPlus1]
	if entry != nil {
		entry.decRef()
	}
	vc.m.RUnlock()
}

Expand All @@ -170,7 +201,7 @@ func (vc *vectorIndexCache) cleanup() bool {
// this index.
if entry.tracker.avg <= (1-entry.tracker.alpha) && refCount <= 0 {
atomic.StoreUint64(&entry.tracker.sample, 0)
entry.closeIndex()
entry.close()
delete(vc.cache, fieldIDPlus1)
continue
}
Expand Down Expand Up @@ -210,16 +241,30 @@ func (e *ewma) add(val uint64) {
}
}

func createCacheEntry(index *faiss.IndexImpl, alpha float64) *cacheEntry {
func createCacheEntry(index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, alpha float64) *cacheEntry {
return &cacheEntry{
index: index,
index: index,
vecDocIDMap: vecDocIDMap,
tracker: &ewma{
alpha: alpha,
sample: 1,
},
refs: 1,
}
}

// cacheEntry is a per-field cache slot pairing a deserialized faiss index
// with its vecID -> docID mapping, plus the bookkeeping used for eviction.
type cacheEntry struct {
	// tracker holds the exponentially-weighted moving average of cache
	// hits, used by cleanup() to decide when the entry has gone cold.
	tracker *ewma

	// this is used to track the live references to the cache entry,
	// such that while we do a cleanup() and we see that the avg is below a
	// threshold we close/cleanup only if the live refs to the cache entry is 0.
	refs int64

	// index is the deserialized faiss vector index; nil after close().
	index *faiss.IndexImpl
	// vecDocIDMap maps vector IDs to docIDs; shared with all callers of
	// load(), so it should be treated as read-only after creation.
	vecDocIDMap map[int64]uint32
}

// incHit records a cache hit by atomically bumping the EWMA sample count,
// which the periodic cleanup routine folds into the moving average.
func (ce *cacheEntry) incHit() {
	atomic.AddUint64(&ce.tracker.sample, 1)
}
Expand All @@ -232,9 +277,15 @@ func (ce *cacheEntry) decRef() {
atomic.AddInt64(&ce.refs, -1)
}

func (ce *cacheEntry) closeIndex() {
func (ce *cacheEntry) load() (*faiss.IndexImpl, map[int64]uint32) {
ce.incHit()
return ce.index, ce.vecDocIDMap
}

// close releases the faiss index asynchronously so the caller (cleanup()
// or Clear(), holding the cache lock) isn't blocked on the teardown, and
// drops the references so the memory can be reclaimed.
// NOTE(review): the goroutine nils ce.index/ce.vecDocIDMap without any
// synchronization, while fetch() reads entry.index under the cache RLock —
// looks like a potential data race; confirm, and consider clearing the
// fields under the lock before spawning the goroutine.
func (ce *cacheEntry) close() {
	go func() {
		ce.index.Close()
		ce.index = nil
		ce.vecDocIDMap = nil
	}()
}
30 changes: 4 additions & 26 deletions faiss_vector_posting.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap
segment.VectorIndex, error) {
// Params needed for the closures
var vecIndex *faiss.IndexImpl
vecDocIDMap := make(map[int64]uint32)
var vecDocIDMap map[int64]uint32
var vectorIDsToExclude []int64
var fieldIDPlus1 uint16
var vecIndexSize uint64
Expand Down Expand Up @@ -371,35 +371,13 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap
pos += n
}

// read the number vectors indexed for this field and load the vector to docID mapping.
// todo: cache the vecID to docIDs mapping for a fieldID
numVecs, n := binary.Uvarint(sb.mem[pos : pos+binary.MaxVarintLen64])
pos += n
for i := 0; i < int(numVecs); i++ {
vecID, n := binary.Varint(sb.mem[pos : pos+binary.MaxVarintLen64])
pos += n
docID, n := binary.Uvarint(sb.mem[pos : pos+binary.MaxVarintLen64])
pos += n

docIDUint32 := uint32(docID)
if except != nil && except.Contains(docIDUint32) {
// populate the list of vector IDs to be ignored on search
vectorIDsToExclude = append(vectorIDsToExclude, vecID)
// also, skip adding entry to vecDocIDMap
continue
}
vecDocIDMap[vecID] = docIDUint32
}

indexSize, n := binary.Uvarint(sb.mem[pos : pos+binary.MaxVarintLen64])
pos += n

vecIndex, err = sb.vecIndexCache.loadVectorIndex(fieldIDPlus1, sb.mem[pos:pos+int(indexSize)])
pos += int(indexSize)
vecIndex, vecDocIDMap, vectorIDsToExclude, err =
sb.vecIndexCache.loadOrCreate(fieldIDPlus1, sb.mem[pos:], except)

if vecIndex != nil {
vecIndexSize = vecIndex.Size()
}

return wrapVecIndex, err
}

Expand Down
Loading