Skip to content

Commit

Permalink
LUCENE-9004: KNN vector search using NSW graphs (apache#2022)
Browse files Browse the repository at this point in the history
  • Loading branch information
msokolov authored and msfroh committed Nov 18, 2020
1 parent f3db918 commit d68f264
Show file tree
Hide file tree
Showing 20 changed files with 2,399 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,6 @@ public float[] vectorValue(int target) throws IOException {
public BytesRef binaryValue(int targetOrd) throws IOException {
throw new UnsupportedOperationException();
}

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,17 @@

/**
* Lucene 9.0 vector format, which encodes dense numeric vector values.
* TODO: add support for approximate KNN search.
*
* @lucene.experimental
*/
public final class Lucene90VectorFormat extends VectorFormat {

static final String META_CODEC_NAME = "Lucene90VectorFormatMeta";
static final String VECTOR_DATA_CODEC_NAME = "Lucene90VectorFormatData";

static final String VECTOR_INDEX_CODEC_NAME = "Lucene90VectorFormatIndex";
static final String META_EXTENSION = "vem";
static final String VECTOR_DATA_EXTENSION = "vec";
static final String VECTOR_INDEX_EXTENSION = "vex";

static final int VERSION_START = 0;
static final int VERSION_CURRENT = VERSION_START;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,40 +22,58 @@
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.VectorReader;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.KnnGraphValues;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.Neighbor;
import org.apache.lucene.util.hnsw.Neighbors;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

/**
* Reads vectors from the index segments.
* Reads vectors from the index segments along with index data structures supporting KNN search.
* @lucene.experimental
*/
public final class Lucene90VectorReader extends VectorReader {

private final FieldInfos fieldInfos;
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData;
private final int maxDoc;
private final IndexInput vectorIndex;
private final long checksumSeed;

Lucene90VectorReader(SegmentReadState state) throws IOException {
this.fieldInfos = state.fieldInfos;
this.maxDoc = state.segmentInfo.maxDoc();

String metaFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, Lucene90VectorFormat.META_EXTENSION);
int versionMeta = readMetadata(state, Lucene90VectorFormat.META_EXTENSION);
long[] checksumRef = new long[1];
vectorData = openDataInput(state, versionMeta, Lucene90VectorFormat.VECTOR_DATA_EXTENSION, Lucene90VectorFormat.VECTOR_DATA_CODEC_NAME, checksumRef);
vectorIndex = openDataInput(state, versionMeta, Lucene90VectorFormat.VECTOR_INDEX_EXTENSION, Lucene90VectorFormat.VECTOR_INDEX_CODEC_NAME, checksumRef);
checksumSeed = checksumRef[0];
}

private int readMetadata(SegmentReadState state, String fileExtension) throws IOException {
String metaFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
int versionMeta = -1;
try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName, state.context)) {
Throwable priorE = null;
Expand All @@ -73,29 +91,32 @@ public final class Lucene90VectorReader extends VectorReader {
CodecUtil.checkFooter(meta, priorE);
}
}
return versionMeta;
}

private static IndexInput openDataInput(SegmentReadState state, int versionMeta, String fileExtension, String codecName, long[] checksumRef) throws IOException {
boolean success = false;

String vectorDataFileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, Lucene90VectorFormat.VECTOR_DATA_EXTENSION);
this.vectorData = state.directory.openInput(vectorDataFileName, state.context);
String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
IndexInput in = state.directory.openInput(fileName, state.context);
try {
int versionVectorData = CodecUtil.checkIndexHeader(vectorData,
Lucene90VectorFormat.VECTOR_DATA_CODEC_NAME,
int versionVectorData = CodecUtil.checkIndexHeader(in,
codecName,
Lucene90VectorFormat.VERSION_START,
Lucene90VectorFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
if (versionMeta != versionVectorData) {
throw new CorruptIndexException("Format versions mismatch: meta=" + versionMeta + ", vector data=" + versionVectorData, vectorData);
throw new CorruptIndexException("Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData, in);
}
CodecUtil.retrieveChecksum(vectorData);

checksumRef[0] = CodecUtil.retrieveChecksum(in);
success = true;
} finally {
if (!success) {
IOUtils.closeWhileHandlingException(this.vectorData);
IOUtils.closeWhileHandlingException(in);
}
}
return in;
}

private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException {
Expand All @@ -104,23 +125,28 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce
if (info == null) {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
}
int searchStrategyId = meta.readInt();
if (searchStrategyId < 0 || searchStrategyId >= VectorValues.SearchStrategy.values().length) {
throw new CorruptIndexException("Invalid search strategy id: " + searchStrategyId, meta);
}
VectorValues.SearchStrategy searchStrategy = VectorValues.SearchStrategy.values()[searchStrategyId];
long vectorDataOffset = meta.readVLong();
long vectorDataLength = meta.readVLong();
int dimension = meta.readInt();
int size = meta.readInt();
int[] ordToDoc = new int[size];
for (int i = 0; i < size; i++) {
int doc = meta.readVInt();
ordToDoc[i] = doc;
}
FieldEntry fieldEntry = new FieldEntry(dimension, searchStrategy, maxDoc, vectorDataOffset, vectorDataLength,
ordToDoc);
fields.put(info.name, fieldEntry);
fields.put(info.name, readField(meta));
}
}

private VectorValues.SearchStrategy readSearchStrategy(DataInput input) throws IOException {
int searchStrategyId = input.readInt();
if (searchStrategyId < 0 || searchStrategyId >= VectorValues.SearchStrategy.values().length) {
throw new CorruptIndexException("Invalid search strategy id: " + searchStrategyId, input);
}
return VectorValues.SearchStrategy.values()[searchStrategyId];
}

private FieldEntry readField(DataInput input) throws IOException {
VectorValues.SearchStrategy searchStrategy = readSearchStrategy(input);
switch(searchStrategy) {
case NONE:
return new FieldEntry(input, searchStrategy);
case DOT_PRODUCT_HNSW:
case EUCLIDEAN_HNSW:
return new HnswGraphFieldEntry(input, searchStrategy);
default:
throw new CorruptIndexException("Unknown vector search strategy: " + searchStrategy, input);
}
}

Expand All @@ -137,6 +163,7 @@ public long ramBytesUsed() {
@Override
public void checkIntegrity() throws IOException {
CodecUtil.checksumEntireFile(vectorData);
CodecUtil.checksumEntireFile(vectorIndex);
}

@Override
Expand Down Expand Up @@ -167,36 +194,80 @@ public VectorValues getVectorValues(String field) throws IOException {
return new OffHeapVectorValues(fieldEntry, bytesSlice);
}

public KnnGraphValues getGraphValues(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) {
throw new IllegalArgumentException("No such field '" + field + "'");
}
FieldEntry entry = fields.get(field);
if (entry != null && entry.indexDataLength > 0) {
return getGraphValues(entry);
} else {
return KnnGraphValues.EMPTY;
}
}

private KnnGraphValues getGraphValues(FieldEntry entry) throws IOException {
if (entry.searchStrategy.isHnsw()) {
HnswGraphFieldEntry graphEntry = (HnswGraphFieldEntry) entry;
IndexInput bytesSlice = vectorIndex.slice("graph-data", entry.indexDataOffset, entry.indexDataLength);
return new IndexedKnnGraphReader(graphEntry, bytesSlice);
} else {
return KnnGraphValues.EMPTY;
}
}

@Override
public void close() throws IOException {
vectorData.close();
IOUtils.close(vectorData, vectorIndex);
}

private static class FieldEntry {

final int dimension;
final VectorValues.SearchStrategy searchStrategy;
final int maxDoc;

final long vectorDataOffset;
final long vectorDataLength;
final long indexDataOffset;
final long indexDataLength;
final int[] ordToDoc;

FieldEntry(int dimension, VectorValues.SearchStrategy searchStrategy, int maxDoc,
long vectorDataOffset, long vectorDataLength, int[] ordToDoc) {
this.dimension = dimension;
FieldEntry(DataInput input, VectorValues.SearchStrategy searchStrategy) throws IOException {
this.searchStrategy = searchStrategy;
this.maxDoc = maxDoc;
this.vectorDataOffset = vectorDataOffset;
this.vectorDataLength = vectorDataLength;
this.ordToDoc = ordToDoc;
vectorDataOffset = input.readVLong();
vectorDataLength = input.readVLong();
indexDataOffset = input.readVLong();
indexDataLength = input.readVLong();
dimension = input.readInt();
int size = input.readInt();
ordToDoc = new int[size];
for (int i = 0; i < size; i++) {
int doc = input.readVInt();
ordToDoc[i] = doc;
}
}

int size() {
return ordToDoc.length;
}
}

private static class HnswGraphFieldEntry extends FieldEntry {

final long[] ordOffsets;

HnswGraphFieldEntry(DataInput input, VectorValues.SearchStrategy searchStrategy) throws IOException {
super(input, searchStrategy);
ordOffsets = new long[size()];
long offset = 0;
for (int i = 0; i < ordOffsets.length; i++) {
offset += input.readVLong();
ordOffsets[i] = offset;
}
}
}

/** Read the vector values from the index input. This supports both iterated and random access. */
private final class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValuesProducer {

Expand Down Expand Up @@ -252,11 +323,6 @@ public BytesRef binaryValue() throws IOException {
return binaryValue;
}

@Override
public TopDocs search(float[] target, int k, int fanout) {
throw new UnsupportedOperationException();
}

@Override
public int docID() {
return doc;
Expand Down Expand Up @@ -288,6 +354,30 @@ public RandomAccessVectorValues randomAccess() {
return new OffHeapRandomAccess(dataIn.clone());
}

@Override
public TopDocs search(float[] vector, int topK, int fanout) throws IOException {
// use a seed that is fixed for the index so we get reproducible results for the same query
final Random random = new Random(checksumSeed);
Neighbors results = HnswGraph.search(vector, topK + fanout, topK + fanout, randomAccess(), getGraphValues(fieldEntry), random);
while (results.size() > topK) {
results.pop();
}
int i = 0;
ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), topK)];
boolean reversed = searchStrategy().reversed;
while (results.size() > 0) {
Neighbor n = results.pop();
float score;
if (reversed) {
score = (float) Math.exp(- n.score() / vector.length);
} else {
score = n.score();
}
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[n.node()], score);
}
// always return >= the case where we can assert == is only when there are fewer than topK vectors in the index
return new TopDocs(new TotalHits(results.visitedCount(), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), scoreDocs);
}

class OffHeapRandomAccess implements RandomAccessVectorValues {

Expand All @@ -296,12 +386,10 @@ class OffHeapRandomAccess implements RandomAccessVectorValues {
final BytesRef binaryValue;
final ByteBuffer byteBuffer;
final FloatBuffer floatBuffer;
final int byteSize;
final float[] value;

OffHeapRandomAccess(IndexInput dataIn) {
this.dataIn = dataIn;
byteSize = Float.BYTES * dimension();
byteBuffer = ByteBuffer.allocate(byteSize);
floatBuffer = byteBuffer.asFloatBuffer();
value = new float[dimension()];
Expand Down Expand Up @@ -342,7 +430,41 @@ private void readValue(int targetOrd) throws IOException {
dataIn.seek(offset);
dataIn.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
}
}
}

/** Read the nearest-neighbors graph from the index input */
private final class IndexedKnnGraphReader extends KnnGraphValues {

final HnswGraphFieldEntry entry;
final IndexInput dataIn;

int arcCount;
int arcUpTo;
int arc;

IndexedKnnGraphReader(HnswGraphFieldEntry entry, IndexInput dataIn) {
this.entry = entry;
this.dataIn = dataIn;
}

@Override
public void seek(int targetOrd) throws IOException {
// unsafe; no bounds checking
dataIn.seek(entry.ordOffsets[targetOrd]);
arcCount = dataIn.readInt();
arc = -1;
arcUpTo = 0;
}

@Override
public int nextNeighbor() throws IOException {
if (arcUpTo >= arcCount) {
return NO_MORE_DOCS;
}
++arcUpTo;
arc += dataIn.readVInt();
return arc;
}
}
}

0 comments on commit d68f264

Please sign in to comment.