Skip to content

Commit

Permalink
Consistent calls to getInvalidVecs(..)* -> getVecIDsToExclude(..)
Browse files Browse the repository at this point in the history
+ Increase ref counts within locking (read or write) to avoid any
  possibility of raciness. This includes invoking cacheEntry.load().
+ Also refactors getInvalidVecs to getVecIDsToExclude.
  • Loading branch information
abhinavdangeti committed Apr 17, 2024
1 parent 5e3882f commit f756ea7
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions faiss_vector_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,25 +54,28 @@ func (vc *vectorIndexCache) Clear() {

func (vc *vectorIndexCache) loadOrCreate(fieldID uint16, mem []byte, except *roaring.Bitmap) (
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, vecIDsToExclude []int64, err error) {
entry := vc.fetch(fieldID)
if entry != nil {
index, vecDocIDMap = entry.load()
vecIDsToExclude = getInvalidVecs(vecDocIDMap, except)
} else {
var found bool
index, vecDocIDMap, vecIDsToExclude, found = vc.fetchFromCache(fieldID, except)
if !found {
index, vecDocIDMap, vecIDsToExclude, err = vc.createAndCache(fieldID, mem, except)
}
return index, vecDocIDMap, vecIDsToExclude, err
}

func (vc *vectorIndexCache) fetch(fieldID uint16) *cacheEntry {
func (vc *vectorIndexCache) fetchFromCache(fieldID uint16, except *roaring.Bitmap) (
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, vecIDsToExclude []int64, found bool) {
vc.m.RLock()
defer vc.m.RUnlock()

entry, ok := vc.cache[fieldID]
if !ok {
return nil
return nil, nil, nil, false
}
return entry

index, vecDocIDMap = entry.load()
vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except)

return index, vecDocIDMap, vecIDsToExclude, true
}

func (vc *vectorIndexCache) createAndCache(fieldID uint16, mem []byte, except *roaring.Bitmap) (
Expand All @@ -86,7 +89,7 @@ func (vc *vectorIndexCache) createAndCache(fieldID uint16, mem []byte, except *r
entry, ok := vc.cache[fieldID]
if ok {
index, vecDocIDMap = entry.load()
vecIDsToExclude = getInvalidVecs(vecDocIDMap, except)
vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except)
return index, vecDocIDMap, vecIDsToExclude, nil
}

Expand Down Expand Up @@ -281,7 +284,9 @@ func (ce *cacheEntry) close() {
}()
}

func getInvalidVecs(vecDocIDMap map[int64]uint32, except *roaring.Bitmap) (vecIDsToExclude []int64) {
// -----------------------------------------------------------------------------

func getVecIDsToExclude(vecDocIDMap map[int64]uint32, except *roaring.Bitmap) (vecIDsToExclude []int64) {
if except != nil && !except.IsEmpty() {
for vecID, docID := range vecDocIDMap {
if except.Contains(docID) {
Expand Down

0 comments on commit f756ea7

Please sign in to comment.