Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,10 @@ public CmdLineArgs build() {
if (docVectors == null) {
throw new IllegalArgumentException("Document vectors path must be provided");
}
if (dimensions <= 0) {
throw new IllegalArgumentException("dimensions must be a positive integer");
if (dimensions <= 0 && dimensions != -1) {
throw new IllegalArgumentException(
"dimensions must be a positive integer or -1 for when dimension is available in the vector file"
);
}
return new CmdLineArgs(
docVectors,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ public static void main(String[] args) throws Exception {
knnIndexer.numSegments(result);
}
}
if (cmdLineArgs.queryVectors() != null) {
if (cmdLineArgs.queryVectors() != null && cmdLineArgs.numQueries() > 0) {
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs);
knnSearcher.runSearch(result);
}
Expand Down
46 changes: 32 additions & 14 deletions qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class KnnIndexer {
private final Path docsPath;
private final Path indexPath;
private final VectorEncoding vectorEncoding;
private final int dim;
private int dim;
private final VectorSimilarityFunction similarityFunction;
private final Codec codec;
private final int numDocs;
Expand Down Expand Up @@ -106,10 +106,6 @@ void createIndex(KnnIndexTester.Results result) throws IOException, InterruptedE

iwc.setMaxFullFlushMergeWaitMillis(0);

FieldType fieldType = switch (vectorEncoding) {
case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction);
case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction);
};
iwc.setInfoStream(new PrintStreamInfoStream(System.out) {
@Override
public boolean isEnabled(String component) {
Expand Down Expand Up @@ -137,7 +133,26 @@ public boolean isEnabled(String component) {
FileChannel in = FileChannel.open(docsPath)
) {
long docsPathSizeInBytes = in.size();
if (docsPathSizeInBytes % ((long) dim * vectorEncoding.byteSize) != 0) {
int offsetByteSize = 0;
if (dim == -1) {
offsetByteSize = 4;
ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
int bytesRead = Channels.readFromFileChannel(in, 0, preamble);
if (bytesRead < 4) {
throw new IllegalArgumentException(
"docsPath \"" + docsPath + "\" does not contain a valid dims? size=" + docsPathSizeInBytes
);
}
dim = preamble.getInt(0);
if (dim <= 0) {
throw new IllegalArgumentException("docsPath \"" + docsPath + "\" has invalid dimension: " + dim);
}
}
FieldType fieldType = switch (vectorEncoding) {
case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction);
case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction);
};
if (docsPathSizeInBytes % (((long) dim * vectorEncoding.byteSize + offsetByteSize)) != 0) {
throw new IllegalArgumentException(
"docsPath \"" + docsPath + "\" does not contain a whole number of vectors? size=" + docsPathSizeInBytes
);
Expand All @@ -150,7 +165,7 @@ public boolean isEnabled(String component) {
vectorEncoding.byteSize
);

VectorReader inReader = VectorReader.create(in, dim, vectorEncoding);
VectorReader inReader = VectorReader.create(in, dim, vectorEncoding, offsetByteSize);
try (ExecutorService exec = Executors.newFixedThreadPool(numIndexThreads, r -> new Thread(r, "KnnIndexer-Thread"))) {
AtomicInteger numDocsIndexed = new AtomicInteger();
List<Future<?>> threads = new ArrayList<>();
Expand Down Expand Up @@ -271,36 +286,39 @@ private void _run() throws IOException {

static class VectorReader {
final float[] target;
final int offsetByteSize;
final ByteBuffer bytes;
final FileChannel input;
long position;

static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding) throws IOException {
static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding, int offsetByteSize) throws IOException {
// check if dim is set as preamble in the file:
int bufferSize = dim * vectorEncoding.byteSize;
if (input.size() % ((long) dim * vectorEncoding.byteSize) != 0) {
if (input.size() % ((long) dim * vectorEncoding.byteSize + offsetByteSize) != 0) {
throw new IllegalArgumentException(
"vectors file \"" + input + "\" does not contain a whole number of vectors? size=" + input.size()
);
}
return new VectorReader(input, dim, bufferSize);
return new VectorReader(input, dim, bufferSize, offsetByteSize);
}

VectorReader(FileChannel input, int dim, int bufferSize) throws IOException {
VectorReader(FileChannel input, int dim, int bufferSize, int offsetByteSize) throws IOException {
this.offsetByteSize = offsetByteSize;
this.bytes = ByteBuffer.wrap(new byte[bufferSize]).order(ByteOrder.LITTLE_ENDIAN);
this.input = input;
this.target = new float[dim];
reset();
}

void reset() throws IOException {
position = 0;
position = offsetByteSize;
input.position(position);
}

private void readNext() throws IOException {
int bytesRead = Channels.readFromFileChannel(this.input, position, bytes);
if (bytesRead < bytes.capacity()) {
position = 0;
position = offsetByteSize;
bytes.position(0);
// wrap around back to the start of the file if we hit the end:
logger.warn("VectorReader hit EOF when reading " + this.input + "; now wrapping around to start of file again");
Expand All @@ -312,7 +330,7 @@ private void readNext() throws IOException {
);
}
}
position += bytesRead;
position += bytesRead + offsetByteSize;
bytes.position(0);
}

Expand Down
39 changes: 29 additions & 10 deletions qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.MMapDirectory;
import org.elasticsearch.common.io.Channels;
import org.elasticsearch.core.PathUtils;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.search.profile.query.QueryProfiler;
Expand Down Expand Up @@ -87,7 +88,7 @@ class KnnSearcher {
private final int efSearch;
private final int nProbe;
private final KnnIndexTester.IndexType indexType;
private final int dim;
private int dim;
private final VectorSimilarityFunction similarityFunction;
private final VectorEncoding vectorEncoding;
private final float overSamplingFactor;
Expand Down Expand Up @@ -117,6 +118,7 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
TopDocs[] results = new TopDocs[numQueryVectors];
int[][] resultIds = new int[numQueryVectors][];
long elapsed, totalCpuTimeMS, totalVisited = 0;
int offsetByteSize = 0;
try (
FileChannel input = FileChannel.open(queryPath);
ExecutorService executorService = Executors.newFixedThreadPool(searchThreads, r -> new Thread(r, "KnnSearcher-Thread"))
Expand All @@ -128,7 +130,19 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
+ " bytes, assuming vector count is "
+ (queryPathSizeInBytes / ((long) dim * vectorEncoding.byteSize))
);
KnnIndexer.VectorReader targetReader = KnnIndexer.VectorReader.create(input, dim, vectorEncoding);
if (dim == -1) {
offsetByteSize = 4;
ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
int bytesRead = Channels.readFromFileChannel(input, 0, preamble);
if (bytesRead < 4) {
throw new IllegalArgumentException("queryPath \"" + queryPath + "\" does not contain a valid dims?");
}
dim = preamble.getInt(0);
if (dim <= 0) {
throw new IllegalArgumentException("queryPath \"" + queryPath + "\" has invalid dimension: " + dim);
}
}
KnnIndexer.VectorReader targetReader = KnnIndexer.VectorReader.create(input, dim, vectorEncoding, offsetByteSize);
long startNS;
try (MMapDirectory dir = new MMapDirectory(indexPath)) {
try (DirectoryReader reader = DirectoryReader.open(dir)) {
Expand Down Expand Up @@ -191,7 +205,7 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
}
}
logger.info("checking results");
int[][] nn = getOrCalculateExactNN();
int[][] nn = getOrCalculateExactNN(offsetByteSize);
finalResults.avgRecall = checkResults(resultIds, nn, topK);
finalResults.qps = (1000f * numQueryVectors) / elapsed;
finalResults.avgLatency = (float) elapsed / numQueryVectors;
Expand All @@ -200,7 +214,7 @@ void runSearch(KnnIndexTester.Results finalResults) throws IOException {
finalResults.avgCpuCount = (double) totalCpuTimeMS / elapsed;
}

private int[][] getOrCalculateExactNN() throws IOException {
private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes) throws IOException {
// look in working directory for cached nn file
String hash = Integer.toString(
Objects.hash(
Expand Down Expand Up @@ -228,9 +242,9 @@ private int[][] getOrCalculateExactNN() throws IOException {
// checking low-precision recall
int[][] nn;
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
nn = computeExactNNByte(queryPath);
nn = computeExactNNByte(queryPath, vectorFileOffsetBytes);
} else {
nn = computeExactNN(queryPath);
nn = computeExactNN(queryPath, vectorFileOffsetBytes);
}
writeExactNN(nn, nnPath);
long elapsedMS = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNS); // ns -> ms
Expand Down Expand Up @@ -356,12 +370,17 @@ private void writeExactNN(int[][] nn, Path nnPath) throws IOException {
}
}

private int[][] computeExactNN(Path queryPath) throws IOException {
private int[][] computeExactNN(Path queryPath, int vectorFileOffsetBytes) throws IOException {
int[][] result = new int[numQueryVectors][];
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
List<Callable<Void>> tasks = new ArrayList<>();
try (FileChannel qIn = FileChannel.open(queryPath)) {
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.FLOAT32);
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(
qIn,
dim,
VectorEncoding.FLOAT32,
vectorFileOffsetBytes
);
for (int i = 0; i < numQueryVectors; i++) {
float[] queryVector = new float[dim];
queryReader.next(queryVector);
Expand All @@ -373,12 +392,12 @@ private int[][] computeExactNN(Path queryPath) throws IOException {
}
}

private int[][] computeExactNNByte(Path queryPath) throws IOException {
private int[][] computeExactNNByte(Path queryPath, int vectorFileOffsetBytes) throws IOException {
int[][] result = new int[numQueryVectors][];
try (Directory dir = FSDirectory.open(indexPath); DirectoryReader reader = DirectoryReader.open(dir)) {
List<Callable<Void>> tasks = new ArrayList<>();
try (FileChannel qIn = FileChannel.open(queryPath)) {
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.BYTE);
KnnIndexer.VectorReader queryReader = KnnIndexer.VectorReader.create(qIn, dim, VectorEncoding.BYTE, vectorFileOffsetBytes);
for (int i = 0; i < numQueryVectors; i++) {
byte[] queryVector = new byte[dim];
queryReader.next(queryVector);
Expand Down