From 8e7331e95ddd39823121d7500b4667723c10e084 Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Tue, 22 Apr 2025 16:45:13 -0400
Subject: [PATCH 1/3] New bulk scorer for binary quantized vectors via
optimized scalar quantization
---
.../benchmark/vector/OSQScorerBenchmark.java | 204 ++++++++
.../DefaultESVectorizationProvider.java | 9 +
.../vectorization/ES91OSQVectorsScorer.java | 166 +++++++
.../ESVectorizationProvider.java | 6 +
.../ESVectorizationProvider.java | 5 +
.../MemorySegmentES91OSQVectorsScorer.java | 450 ++++++++++++++++++
.../PanamaESVectorizationProvider.java | 17 +
.../ES91OSQVectorScorerTests.java | 103 ++++
8 files changed, 960 insertions(+)
create mode 100644 benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java
create mode 100644 libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java
create mode 100644 libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java
create mode 100644 libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java
diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java
new file mode 100644
index 0000000000000..486919f10bcf5
--- /dev/null
+++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java
@@ -0,0 +1,204 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+package org.elasticsearch.benchmark.vector;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.store.MMapDirectory;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+import org.elasticsearch.common.logging.LogConfigurator;
+import org.elasticsearch.simdvec.internal.vectorization.ES91OSQVectorsScorer;
+import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.infra.Blackhole;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.Random;
+import java.util.concurrent.TimeUnit;
+
+@BenchmarkMode(Mode.Throughput)
+@OutputTimeUnit(TimeUnit.MILLISECONDS)
+@State(Scope.Benchmark)
+// first iteration is complete garbage, so make sure we really warmup
+@Warmup(iterations = 4, time = 1)
+// real iterations. not useful to spend tons of time here, better to fork more
+@Measurement(iterations = 5, time = 1)
+// engage some noise reduction
+@Fork(value = 1)
+public class OSQScorerBenchmark {
+
+ static {
+ LogConfigurator.configureESLogging(); // native access requires logging to be initialized
+ }
+
+ @Param({ "1024" })
+ int dims;
+
+ int length;
+
+ int numVectors = ES91OSQVectorsScorer.BULK_SIZE * 10;
+ int numQueries = 10;
+
+ byte[][] binaryVectors;
+ byte[][] binaryQueries;
+ OptimizedScalarQuantizer.QuantizationResult result;
+ float centroidDp;
+
+ byte[] scratch;
+ ES91OSQVectorsScorer scorer;
+
+ IndexInput in;
+
+ float[] scratchScores;
+ float[] corrections;
+
+ @Setup
+ public void setup() throws IOException {
+ Random random = new Random(123);
+
+ this.length = OptimizedScalarQuantizer.discretize(dims, 64) / 8;
+
+ binaryVectors = new byte[numVectors][length];
+ for (byte[] binaryVector : binaryVectors) {
+ random.nextBytes(binaryVector);
+ }
+
+ Directory dir = new MMapDirectory(Files.createTempDirectory("vectorData"));
+ IndexOutput out = dir.createOutput("vectors", IOContext.DEFAULT);
+ byte[] correctionBytes = new byte[14 * ES91OSQVectorsScorer.BULK_SIZE];
+ for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) {
+ for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
+ out.writeBytes(binaryVectors[i + j], 0, binaryVectors[i + j].length);
+ }
+ random.nextBytes(correctionBytes);
+ out.writeBytes(correctionBytes, 0, correctionBytes.length);
+ }
+ out.close();
+ in = dir.openInput("vectors", IOContext.DEFAULT);
+
+ binaryQueries = new byte[numVectors][4 * length];
+ for (byte[] binaryVector : binaryVectors) {
+ random.nextBytes(binaryVector);
+ }
+ result = new OptimizedScalarQuantizer.QuantizationResult(
+ random.nextFloat(),
+ random.nextFloat(),
+ random.nextFloat(),
+ Short.toUnsignedInt((short) random.nextInt())
+ );
+ centroidDp = random.nextFloat();
+
+ scratch = new byte[length];
+ scorer = ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(in, dims);
+ scratchScores = new float[16];
+ corrections = new float[3];
+ }
+
+ @Benchmark
+ @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+ public void scoreFromArray(Blackhole bh) throws IOException {
+ for (int j = 0; j < numQueries; j++) {
+ in.seek(0);
+ for (int i = 0; i < numVectors; i++) {
+ in.readBytes(scratch, 0, length);
+ float qDist = VectorUtil.int4BitDotProduct(binaryQueries[j], scratch);
+ in.readFloats(corrections, 0, corrections.length);
+ int addition = Short.toUnsignedInt(in.readShort());
+ float score = scorer.score(
+ result,
+ VectorSimilarityFunction.EUCLIDEAN,
+ centroidDp,
+ corrections[0],
+ corrections[1],
+ addition,
+ corrections[2],
+ qDist
+ );
+ bh.consume(score);
+ }
+ }
+ }
+
+ @Benchmark
+ @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+ public void scoreFromMemorySegmentOnlyVector(Blackhole bh) throws IOException {
+ for (int j = 0; j < numQueries; j++) {
+ in.seek(0);
+ for (int i = 0; i < numVectors; i++) {
+ float qDist = scorer.quantizeScore(binaryQueries[j]);
+ in.readFloats(corrections, 0, corrections.length);
+ int addition = Short.toUnsignedInt(in.readShort());
+ float score = scorer.score(
+ result,
+ VectorSimilarityFunction.EUCLIDEAN,
+ centroidDp,
+ corrections[0],
+ corrections[1],
+ addition,
+ corrections[2],
+ qDist
+ );
+ bh.consume(score);
+ }
+ }
+ }
+
+ @Benchmark
+ @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+ public void scoreFromMemorySegmentOnlyVectorBulk(Blackhole bh) throws IOException {
+ for (int j = 0; j < numQueries; j++) {
+ in.seek(0);
+ for (int i = 0; i < numVectors; i += 16) {
+ scorer.quantizeScoreBulk(binaryQueries[j], ES91OSQVectorsScorer.BULK_SIZE, scratchScores);
+ for (int k = 0; k < ES91OSQVectorsScorer.BULK_SIZE; k++) {
+ in.readFloats(corrections, 0, corrections.length);
+ int addition = Short.toUnsignedInt(in.readShort());
+ float score = scorer.score(
+ result,
+ VectorSimilarityFunction.EUCLIDEAN,
+ centroidDp,
+ corrections[0],
+ corrections[1],
+ addition,
+ corrections[2],
+ scratchScores[k]
+ );
+ bh.consume(score);
+ }
+ }
+ }
+ }
+
+ @Benchmark
+ @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
+ public void scoreFromMemorySegmentAllBulk(Blackhole bh) throws IOException {
+ for (int j = 0; j < numQueries; j++) {
+ in.seek(0);
+ for (int i = 0; i < numVectors; i += 16) {
+ scorer.scoreBulk(binaryQueries[j], result, VectorSimilarityFunction.EUCLIDEAN, centroidDp, scratchScores);
+ bh.consume(scratchScores);
+ }
+ }
+ }
+}
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java
index 6c0f7ed146b86..e8ff6f83f2172 100644
--- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java
+++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java
@@ -9,6 +9,10 @@
package org.elasticsearch.simdvec.internal.vectorization;
+import org.apache.lucene.store.IndexInput;
+
+import java.io.IOException;
+
final class DefaultESVectorizationProvider extends ESVectorizationProvider {
private final ESVectorUtilSupport vectorUtilSupport;
@@ -20,4 +24,9 @@ final class DefaultESVectorizationProvider extends ESVectorizationProvider {
public ESVectorUtilSupport getVectorUtilSupport() {
return vectorUtilSupport;
}
+
+ @Override
+ public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
+ return new ES91OSQVectorsScorer(input, dimension);
+ }
}
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java
new file mode 100644
index 0000000000000..66a8e5b89855d
--- /dev/null
+++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java
@@ -0,0 +1,166 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.simdvec.internal.vectorization;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.BitUtil;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+import java.io.IOException;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
+
+/** Scorer for quantized vectors stored as an {@link IndexInput}. */
+public class ES91OSQVectorsScorer {
+
+ public static final int BULK_SIZE = 16;
+
+ protected static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1);
+
+ /** The wrapper {@link IndexInput}. */
+ protected final IndexInput in;
+
+ protected final int length;
+ protected final int dimensions;
+
+ protected final float[] lowerIntervals = new float[BULK_SIZE];
+ protected final float[] upperIntervals = new float[BULK_SIZE];
+ protected final int[] targetComponentSums = new int[BULK_SIZE];
+ protected final float[] additionalCorrections = new float[BULK_SIZE];
+
+ /** Sole constructor, called by sub-classes. */
+ public ES91OSQVectorsScorer(IndexInput in, int dimensions) {
+ this.in = in;
+ this.dimensions = dimensions;
+ this.length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;
+ }
+
+ /**
+ * compute the quantize distance between the provided quantized query and the quantized vector
+ * that is read from the wrapped {@link IndexInput}.
+ */
+ public long quantizeScore(byte[] q) throws IOException {
+ assert q.length == length * 4;
+ final int size = length;
+ long subRet0 = 0;
+ long subRet1 = 0;
+ long subRet2 = 0;
+ long subRet3 = 0;
+ int r = 0;
+ for (final int upperBound = size & -Long.BYTES; r < upperBound; r += Long.BYTES) {
+ final long value = in.readLong();
+ subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r) & value);
+ subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + size) & value);
+ subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 2 * size) & value);
+ subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, r + 3 * size) & value);
+ }
+ for (final int upperBound = size & -Integer.BYTES; r < upperBound; r += Integer.BYTES) {
+ final int value = in.readInt();
+ subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r) & value);
+ subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + size) & value);
+ subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 2 * size) & value);
+ subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, r + 3 * size) & value);
+ }
+ for (; r < size; r++) {
+ final byte value = in.readByte();
+ subRet0 += Integer.bitCount((q[r] & value) & 0xFF);
+ subRet1 += Integer.bitCount((q[r + size] & value) & 0xFF);
+ subRet2 += Integer.bitCount((q[r + 2 * size] & value) & 0xFF);
+ subRet3 += Integer.bitCount((q[r + 3 * size] & value) & 0xFF);
+ }
+ return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
+ }
+
+ /**
+ * compute the quantize distance between the provided quantized query and the quantized vectors
+ * that are read from the wrapped {@link IndexInput}. The number of quantized vectors to read is
+ * determined by {code count} and the results are stored in the provided {@code scores} array.
+ */
+ public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException {
+ for (int i = 0; i < count; i++) {
+ scores[i] = quantizeScore(q);
+ }
+ }
+
+ /**
+ * Computes the score by applying the necessary corrections to the provided quantized distance.
+ */
+ public float score(
+ OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float lowerInterval,
+ float upperInterval,
+ int targetComponentSum,
+ float additionalCorrection,
+ float qcDist
+ ) {
+ float ax = lowerInterval;
+ // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
+ float lx = upperInterval - ax;
+ float ay = queryCorrections.lowerInterval();
+ float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
+ float y1 = queryCorrections.quantizedComponentSum();
+ float score = ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
+ // For euclidean, we need to invert the score and apply the additional correction, which is
+ // assumed to be the squared l2norm of the centroid centered vectors.
+ if (similarityFunction == EUCLIDEAN) {
+ score = queryCorrections.additionalCorrection() + additionalCorrection - 2 * score;
+ return Math.max(1 / (1f + score), 0);
+ } else {
+ // For cosine and max inner product, we need to apply the additional correction, which is
+ // assumed to be the non-centered dot-product between the vector and the centroid
+ score += queryCorrections.additionalCorrection() + additionalCorrection - centroidDp;
+ if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+ return VectorUtil.scaleMaxInnerProductScore(score);
+ }
+ return Math.max((1f + score) / 2f, 0);
+ }
+ }
+
+ /**
+ * compute the distance between the provided quantized query and the quantized vectors that are
+ * read from the wrapped {@link IndexInput}.
+ *
+ *
The number of vectors to score is defined by {@link #BULK_SIZE}. The expected format of the
+ * input is as follows: First the quantized vectors are read from the input,then all the lower
+ * intervals as floats, then all the upper intervals as floats, then all the target component sums
+ * as shorts, and finally all the additional corrections as floats.
+ *
+ *
The results are stored in the provided scores array.
+ */
+ public void scoreBulk(
+ byte[] q,
+ OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float[] scores
+ ) throws IOException {
+ quantizeScoreBulk(q, BULK_SIZE, scores);
+ in.readFloats(lowerIntervals, 0, BULK_SIZE);
+ in.readFloats(upperIntervals, 0, BULK_SIZE);
+ for (int i = 0; i < BULK_SIZE; i++) {
+ targetComponentSums[i] = Short.toUnsignedInt(in.readShort());
+ }
+ in.readFloats(additionalCorrections, 0, BULK_SIZE);
+ for (int i = 0; i < BULK_SIZE; i++) {
+ scores[i] = score(
+ queryCorrections,
+ similarityFunction,
+ centroidDp,
+ lowerIntervals[i],
+ upperIntervals[i],
+ targetComponentSums[i],
+ additionalCorrections[i],
+ scores[i]
+ );
+ }
+ }
+}
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
index e541c10e145bf..7f4e62f156a36 100644
--- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
+++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
@@ -9,6 +9,9 @@
package org.elasticsearch.simdvec.internal.vectorization;
+import org.apache.lucene.store.IndexInput;
+
+import java.io.IOException;
import java.util.Objects;
public abstract class ESVectorizationProvider {
@@ -24,6 +27,9 @@ public static ESVectorizationProvider getInstance() {
public abstract ESVectorUtilSupport getVectorUtilSupport();
+ /** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */
+ public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException;
+
// visible for tests
static ESVectorizationProvider lookup(boolean testMode) {
return new DefaultESVectorizationProvider();
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
index 758116300aa0f..af5df094659e5 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java
@@ -9,10 +9,12 @@
package org.elasticsearch.simdvec.internal.vectorization;
+import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Constants;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
+import java.io.IOException;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
@@ -32,6 +34,9 @@ public static ESVectorizationProvider getInstance() {
public abstract ESVectorUtilSupport getVectorUtilSupport();
+ /** Create a new {@link ES91OSQVectorsScorer} for the given {@link IndexInput}. */
+ public abstract ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException;
+
// visible for tests
static ESVectorizationProvider lookup(boolean testMode) {
final int runtimeVersion = Runtime.version().feature();
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java
new file mode 100644
index 0000000000000..3b9c8ae5c5212
--- /dev/null
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java
@@ -0,0 +1,450 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.simdvec.internal.vectorization;
+
+import jdk.incubator.vector.ByteVector;
+import jdk.incubator.vector.FloatVector;
+import jdk.incubator.vector.IntVector;
+import jdk.incubator.vector.LongVector;
+import jdk.incubator.vector.ShortVector;
+import jdk.incubator.vector.VectorOperators;
+import jdk.incubator.vector.VectorSpecies;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+import java.io.IOException;
+import java.lang.foreign.MemorySegment;
+import java.nio.ByteOrder;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
+import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
+
+/** Panamized scorer for quantized vectors stored as an {@link IndexInput}. */
+public final class MemorySegmentES91OSQVectorsScorer extends ES91OSQVectorsScorer {
+
+ private static final VectorSpecies INT_SPECIES_128 = IntVector.SPECIES_128;
+
+ private static final VectorSpecies LONG_SPECIES_128 = LongVector.SPECIES_128;
+ private static final VectorSpecies LONG_SPECIES_256 = LongVector.SPECIES_256;
+
+ private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128;
+ private static final VectorSpecies BYTE_SPECIES_256 = ByteVector.SPECIES_256;
+
+ private static final VectorSpecies SHORT_SPECIES_128 = ShortVector.SPECIES_128;
+ private static final VectorSpecies SHORT_SPECIES_256 = ShortVector.SPECIES_256;
+
+ private static final VectorSpecies FLOAT_SPECIES_128 = FloatVector.SPECIES_128;
+ private static final VectorSpecies FLOAT_SPECIES_256 = FloatVector.SPECIES_256;
+
+ private final MemorySegment memorySegment;
+
+ public MemorySegmentES91OSQVectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) {
+ super(in, dimensions);
+ this.memorySegment = memorySegment;
+ }
+
+ @Override
+ public long quantizeScore(byte[] q) throws IOException {
+ assert q.length == length * 4;
+ // 128 / 8 == 16
+ if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
+ if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
+ return quantizeScore256(q);
+ } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
+ return quantizeScore128(q);
+ }
+ }
+ return super.quantizeScore(q);
+ }
+
+ private long quantizeScore256(byte[] q) throws IOException {
+ long subRet0 = 0;
+ long subRet1 = 0;
+ long subRet2 = 0;
+ long subRet3 = 0;
+ int i = 0;
+ long offset = in.getFilePointer();
+ if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) {
+ int limit = ByteVector.SPECIES_256.loopBound(length);
+ var sum0 = LongVector.zero(LONG_SPECIES_256);
+ var sum1 = LongVector.zero(LONG_SPECIES_256);
+ var sum2 = LongVector.zero(LONG_SPECIES_256);
+ var sum3 = LongVector.zero(LONG_SPECIES_256);
+ for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) {
+ var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs();
+ var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs();
+ var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs();
+ var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs();
+ var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ }
+ subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+ subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+ subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+ subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+ }
+
+ if (length - i >= ByteVector.SPECIES_128.vectorByteSize()) {
+ var sum0 = LongVector.zero(LONG_SPECIES_128);
+ var sum1 = LongVector.zero(LONG_SPECIES_128);
+ var sum2 = LongVector.zero(LONG_SPECIES_128);
+ var sum3 = LongVector.zero(LONG_SPECIES_128);
+ int limit = ByteVector.SPECIES_128.loopBound(length);
+ for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) {
+ var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs();
+ var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs();
+ var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs();
+ var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs();
+ var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ }
+ subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+ subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+ subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+ subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+ }
+ // tail as bytes
+ in.seek(offset);
+ for (; i < length; i++) {
+ int dValue = in.readByte() & 0xFF;
+ subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
+ subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
+ subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
+ subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF);
+ }
+ return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
+ }
+
+ private long quantizeScore128(byte[] q) throws IOException {
+ long subRet0 = 0;
+ long subRet1 = 0;
+ long subRet2 = 0;
+ long subRet3 = 0;
+ int i = 0;
+ long offset = in.getFilePointer();
+
+ var sum0 = IntVector.zero(INT_SPECIES_128);
+ var sum1 = IntVector.zero(INT_SPECIES_128);
+ var sum2 = IntVector.zero(INT_SPECIES_128);
+ var sum3 = IntVector.zero(INT_SPECIES_128);
+ int limit = ByteVector.SPECIES_128.loopBound(length);
+ for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) {
+ var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts();
+ var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts();
+ var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts();
+ var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsInts();
+ sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT));
+ sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT));
+ sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT));
+ sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT));
+ }
+ subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+ subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+ subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+ subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+ // tail as bytes
+ in.seek(offset);
+ for (; i < length; i++) {
+ int dValue = in.readByte() & 0xFF;
+ subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF);
+ subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF);
+ subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF);
+ subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF);
+ }
+ return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
+ }
+
+ @Override
+ public void quantizeScoreBulk(byte[] q, int count, float[] scores) throws IOException {
+ assert q.length == length * 4;
+ // 128 / 8 == 16
+ if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
+ if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
+ quantizeScore256Bulk(q, count, scores);
+ return;
+ } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
+ quantizeScore128Bulk(q, count, scores);
+ return;
+ }
+ }
+ super.quantizeScoreBulk(q, count, scores);
+ }
+
+ private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IOException {
+ for (int iter = 0; iter < count; iter++) {
+ long subRet0 = 0;
+ long subRet1 = 0;
+ long subRet2 = 0;
+ long subRet3 = 0;
+ int i = 0;
+ long offset = in.getFilePointer();
+
+ var sum0 = IntVector.zero(INT_SPECIES_128);
+ var sum1 = IntVector.zero(INT_SPECIES_128);
+ var sum2 = IntVector.zero(INT_SPECIES_128);
+ var sum3 = IntVector.zero(INT_SPECIES_128);
+ int limit = ByteVector.SPECIES_128.loopBound(length);
+ for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += INT_SPECIES_128.vectorByteSize()) {
+ var vd = IntVector.fromMemorySegment(INT_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsInts();
+ var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsInts();
+ var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsInts();
+ var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsInts();
+ sum0 = sum0.add(vd.and(vq0).lanewise(VectorOperators.BIT_COUNT));
+ sum1 = sum1.add(vd.and(vq1).lanewise(VectorOperators.BIT_COUNT));
+ sum2 = sum2.add(vd.and(vq2).lanewise(VectorOperators.BIT_COUNT));
+ sum3 = sum3.add(vd.and(vq3).lanewise(VectorOperators.BIT_COUNT));
+ }
+ subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+ subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+ subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+ subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+ // tail as bytes
+ in.seek(offset);
+ for (; i < length; i++) {
+ int dValue = in.readByte() & 0xFF;
+ subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF);
+ subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF);
+ subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF);
+ subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF);
+ }
+ scores[iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
+ }
+ }
+
+ private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IOException {
+ for (int iter = 0; iter < count; iter++) {
+ long subRet0 = 0;
+ long subRet1 = 0;
+ long subRet2 = 0;
+ long subRet3 = 0;
+ int i = 0;
+ long offset = in.getFilePointer();
+ if (length >= ByteVector.SPECIES_256.vectorByteSize() * 2) {
+ int limit = ByteVector.SPECIES_256.loopBound(length);
+ var sum0 = LongVector.zero(LONG_SPECIES_256);
+ var sum1 = LongVector.zero(LONG_SPECIES_256);
+ var sum2 = LongVector.zero(LONG_SPECIES_256);
+ var sum3 = LongVector.zero(LONG_SPECIES_256);
+ for (; i < limit; i += ByteVector.SPECIES_256.length(), offset += LONG_SPECIES_256.vectorByteSize()) {
+ var vq0 = ByteVector.fromArray(BYTE_SPECIES_256, q, i).reinterpretAsLongs();
+ var vq1 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length).reinterpretAsLongs();
+ var vq2 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 2).reinterpretAsLongs();
+ var vq3 = ByteVector.fromArray(BYTE_SPECIES_256, q, i + length * 3).reinterpretAsLongs();
+ var vd = LongVector.fromMemorySegment(LONG_SPECIES_256, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ }
+ subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+ subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+ subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+ subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+ }
+
+ if (length - i >= ByteVector.SPECIES_128.vectorByteSize()) {
+ var sum0 = LongVector.zero(LONG_SPECIES_128);
+ var sum1 = LongVector.zero(LONG_SPECIES_128);
+ var sum2 = LongVector.zero(LONG_SPECIES_128);
+ var sum3 = LongVector.zero(LONG_SPECIES_128);
+ int limit = ByteVector.SPECIES_128.loopBound(length);
+ for (; i < limit; i += ByteVector.SPECIES_128.length(), offset += LONG_SPECIES_128.vectorByteSize()) {
+ var vq0 = ByteVector.fromArray(BYTE_SPECIES_128, q, i).reinterpretAsLongs();
+ var vq1 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length).reinterpretAsLongs();
+ var vq2 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 2).reinterpretAsLongs();
+ var vq3 = ByteVector.fromArray(BYTE_SPECIES_128, q, i + length * 3).reinterpretAsLongs();
+ var vd = LongVector.fromMemorySegment(LONG_SPECIES_128, memorySegment, offset, ByteOrder.LITTLE_ENDIAN);
+ sum0 = sum0.add(vq0.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum1 = sum1.add(vq1.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum2 = sum2.add(vq2.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ sum3 = sum3.add(vq3.and(vd).lanewise(VectorOperators.BIT_COUNT));
+ }
+ subRet0 += sum0.reduceLanes(VectorOperators.ADD);
+ subRet1 += sum1.reduceLanes(VectorOperators.ADD);
+ subRet2 += sum2.reduceLanes(VectorOperators.ADD);
+ subRet3 += sum3.reduceLanes(VectorOperators.ADD);
+ }
+ // tail as bytes
+ in.seek(offset);
+ for (; i < length; i++) {
+ int dValue = in.readByte() & 0xFF;
+ subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
+ subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
+ subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
+ subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF);
+ }
+ scores[iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
+ }
+ }
+
+ @Override
+ public void scoreBulk(
+ byte[] q,
+ OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float[] scores
+ ) throws IOException {
+ assert q.length == length * 4;
+ // 128 / 8 == 16
+ if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
+ if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
+ score256Bulk(q, queryCorrections, similarityFunction, centroidDp, scores);
+ return;
+ } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
+ score128Bulk(q, queryCorrections, similarityFunction, centroidDp, scores);
+ return;
+ }
+ }
+ super.scoreBulk(q, queryCorrections, similarityFunction, centroidDp, scores);
+ }
+
+ private void score128Bulk(
+ byte[] q,
+ OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float[] scores
+ ) throws IOException {
+ quantizeScore128Bulk(q, BULK_SIZE, scores);
+ int limit = FLOAT_SPECIES_128.loopBound(BULK_SIZE);
+ int i = 0;
+ long offset = in.getFilePointer();
+ float ay = queryCorrections.lowerInterval();
+ float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
+ float y1 = queryCorrections.quantizedComponentSum();
+ for (; i < limit; i += FLOAT_SPECIES_128.length()) {
+ var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
+ var lx = FloatVector.fromMemorySegment(
+ FLOAT_SPECIES_128,
+ memorySegment,
+ offset + 4 * BULK_SIZE + i * Float.BYTES,
+ ByteOrder.LITTLE_ENDIAN
+ ).sub(ax);
+ var targetComponentSums = ShortVector.fromMemorySegment(
+ SHORT_SPECIES_128,
+ memorySegment,
+ offset + 8 * BULK_SIZE + i * Short.BYTES,
+ ByteOrder.LITTLE_ENDIAN
+ ).convert(VectorOperators.S2I, 0).reinterpretAsInts().and(0xffff).convert(VectorOperators.I2F, 0);
+ var additionalCorrections = FloatVector.fromMemorySegment(
+ FLOAT_SPECIES_128,
+ memorySegment,
+ offset + 10 * BULK_SIZE + i * Float.BYTES,
+ ByteOrder.LITTLE_ENDIAN
+ );
+ var qcDist = FloatVector.fromArray(FLOAT_SPECIES_128, scores, i);
+ // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly *
+ // qcDist;
+ var res1 = ax.mul(ay).mul(dimensions);
+ var res2 = lx.mul(ay).mul(targetComponentSums);
+ var res3 = ax.mul(ly).mul(y1);
+ var res4 = lx.mul(ly).mul(qcDist);
+ var res = res1.add(res2).add(res3).add(res4);
+ // For euclidean, we need to invert the score and apply the additional correction, which is
+ // assumed to be the squared l2norm of the centroid centered vectors.
+ if (similarityFunction == EUCLIDEAN) {
+ res = res.mul(-2).add(additionalCorrections).add(queryCorrections.additionalCorrection()).add(1f);
+ res = FloatVector.broadcast(FLOAT_SPECIES_128, 1).div(res).max(0);
+ res.intoArray(scores, i);
+ } else {
+ // For cosine and max inner product, we need to apply the additional correction, which is
+ // assumed to be the non-centered dot-product between the vector and the centroid
+ res = res.add(queryCorrections.additionalCorrection()).add(additionalCorrections).sub(centroidDp);
+ if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+ res.intoArray(scores, i);
+ // not sure how to do it better
+ for (int j = 0; j < FLOAT_SPECIES_128.length(); j++) {
+ scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
+ }
+ } else {
+ res = res.add(1f).mul(0.5f).max(0);
+ res.intoArray(scores, i);
+ }
+ }
+ }
+ in.seek(offset + 14L * BULK_SIZE);
+ }
+
+ private void score256Bulk(
+ byte[] q,
+ OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+ VectorSimilarityFunction similarityFunction,
+ float centroidDp,
+ float[] scores
+ ) throws IOException {
+ quantizeScore256Bulk(q, BULK_SIZE, scores);
+ int limit = FLOAT_SPECIES_256.loopBound(BULK_SIZE);
+ int i = 0;
+ long offset = in.getFilePointer();
+ float ay = queryCorrections.lowerInterval();
+ float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
+ float y1 = queryCorrections.quantizedComponentSum();
+ for (; i < limit; i += FLOAT_SPECIES_256.length()) {
+ var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
+ var lx = FloatVector.fromMemorySegment(
+ FLOAT_SPECIES_256,
+ memorySegment,
+ offset + 4 * BULK_SIZE + i * Float.BYTES,
+ ByteOrder.LITTLE_ENDIAN
+ ).sub(ax);
+ var targetComponentSums = ShortVector.fromMemorySegment(
+ SHORT_SPECIES_256,
+ memorySegment,
+ offset + 8 * BULK_SIZE + i * Short.BYTES,
+ ByteOrder.LITTLE_ENDIAN
+ ).convert(VectorOperators.S2I, 0).reinterpretAsInts().and(0xffff).convert(VectorOperators.I2F, 0);
+ var additionalCorrections = FloatVector.fromMemorySegment(
+ FLOAT_SPECIES_256,
+ memorySegment,
+ offset + 10 * BULK_SIZE + i * Float.BYTES,
+ ByteOrder.LITTLE_ENDIAN
+ );
+ var qcDist = FloatVector.fromArray(FLOAT_SPECIES_256, scores, i);
+ // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly *
+ // qcDist;
+ var res1 = ax.mul(ay).mul(dimensions);
+ var res2 = lx.mul(ay).mul(targetComponentSums);
+ var res3 = ax.mul(ly).mul(y1);
+ var res4 = lx.mul(ly).mul(qcDist);
+ var res = res1.add(res2).add(res3).add(res4);
+ // For euclidean, we need to invert the score and apply the additional correction, which is
+ // assumed to be the squared l2norm of the centroid centered vectors.
+ if (similarityFunction == EUCLIDEAN) {
+ res = res.mul(-2).add(additionalCorrections).add(queryCorrections.additionalCorrection()).add(1f);
+ res = FloatVector.broadcast(FLOAT_SPECIES_256, 1).div(res).max(0);
+ res.intoArray(scores, i);
+ } else {
+ // For cosine and max inner product, we need to apply the additional correction, which is
+ // assumed to be the non-centered dot-product between the vector and the centroid
+ res = res.add(queryCorrections.additionalCorrection()).add(additionalCorrections).sub(centroidDp);
+ if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+ res.intoArray(scores, i);
+ // not sure how to do it better
+ for (int j = 0; j < FLOAT_SPECIES_256.length(); j++) {
+ scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
+ }
+ } else {
+ res = res.add(1f).mul(0.5f).max(0);
+ res.intoArray(scores, i);
+ }
+ }
+ }
+ in.seek(offset + 14L * BULK_SIZE);
+ }
+}
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java
index 62d25d79487ed..c409be4fb37d8 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java
@@ -9,6 +9,12 @@
package org.elasticsearch.simdvec.internal.vectorization;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.MemorySegmentAccessInput;
+
+import java.io.IOException;
+import java.lang.foreign.MemorySegment;
+
final class PanamaESVectorizationProvider extends ESVectorizationProvider {
private final ESVectorUtilSupport vectorUtilSupport;
@@ -21,4 +27,15 @@ final class PanamaESVectorizationProvider extends ESVectorizationProvider {
public ESVectorUtilSupport getVectorUtilSupport() {
return vectorUtilSupport;
}
+
+ @Override
+ public ES91OSQVectorsScorer newES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException {
+ if (PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS && input instanceof MemorySegmentAccessInput msai) {
+ MemorySegment ms = msai.segmentSliceOrNull(0, input.length());
+ if (ms != null) {
+ return new MemorySegmentES91OSQVectorsScorer(input, dimension, ms);
+ }
+ }
+ return new ES91OSQVectorsScorer(input, dimension);
+ }
}
diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java
new file mode 100644
index 0000000000000..81172872d47d8
--- /dev/null
+++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java
@@ -0,0 +1,103 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.simdvec.internal.vectorization;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.store.MMapDirectory;
+import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+
+public class ES91OSQVectorScorerTests extends BaseVectorizationTests {
+ public void testQuantizeScore() throws Exception {
+ final int dimensions = random().nextInt(1, 2000);
+ final int length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;
+ final int numVectors = random().nextInt(1, 100);
+ final byte[] vector = new byte[length];
+ try (Directory dir = new MMapDirectory(createTempDir())) {
+ try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) {
+ for (int i = 0; i < numVectors; i++) {
+ random().nextBytes(vector);
+ out.writeBytes(vector, 0, length);
+ }
+ }
+ final byte[] query = new byte[4 * length];
+ random().nextBytes(query);
+ try (IndexInput in = dir.openInput("tests.bin", IOContext.DEFAULT)) {
+ // Work on a slice that has just the right number of bytes to make the test fail with an
+ // index-out-of-bounds in case the implementation reads more than the allowed number of
+ // padding bytes.
+ final IndexInput slice = in.slice("test", 0, (long) length * numVectors);
+ final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions);
+ final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions);
+ for (int i = 0; i < numVectors; i++) {
+ assertEquals(defaultScorer.quantizeScore(query), panamaScorer.quantizeScore(query));
+ assertEquals(in.getFilePointer(), slice.getFilePointer());
+ }
+ assertEquals((long) length * numVectors, slice.getFilePointer());
+ }
+ }
+ }
+
+ public void testScore() throws Exception {
+ final int dimensions = random().nextInt(1, 2000);
+ final int length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;
+ final int numVectors = ES91OSQVectorsScorer.BULK_SIZE * random().nextInt(1, 10);
+ final byte[] vector = new byte[length + 14];
+ int padding = random().nextInt(100);
+ try (Directory dir = new MMapDirectory(createTempDir())) {
+ try (IndexOutput out = dir.createOutput("testScore.bin", IOContext.DEFAULT)) {
+ for (int i = 0; i < padding; i++) {
+ out.writeByte((byte) random().nextInt());
+ }
+ for (int i = 0; i < numVectors; i++) {
+ random().nextBytes(vector);
+ out.writeBytes(vector, 0, length + 14);
+ }
+ }
+ final byte[] query = new byte[4 * length];
+ random().nextBytes(query);
+ OptimizedScalarQuantizer.QuantizationResult result = new OptimizedScalarQuantizer.QuantizationResult(
+ random().nextFloat(),
+ random().nextFloat(),
+ random().nextFloat(),
+ Short.toUnsignedInt((short) random().nextInt())
+ );
+ final float centroidDp = random().nextFloat();
+ final float[] scores1 = new float[ES91OSQVectorsScorer.BULK_SIZE];
+ final float[] scores2 = new float[ES91OSQVectorsScorer.BULK_SIZE];
+ for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) {
+ try (IndexInput in = dir.openInput("testScore.bin", IOContext.DEFAULT)) {
+ in.seek(padding);
+ assertEquals(in.length(), padding + (long) numVectors * (length + 14));
+ // Work on a slice that has just the right number of bytes to make the test fail with an
+ // index-out-of-bounds in case the implementation reads more than the allowed number of
+ // padding bytes.
+ for (int i = 0; i < numVectors; i += ES91OSQVectorsScorer.BULK_SIZE) {
+ final IndexInput slice = in.slice(
+ "test",
+ in.getFilePointer(),
+ (long) (length + 14) * ES91OSQVectorsScorer.BULK_SIZE
+ );
+ final ES91OSQVectorsScorer defaultScorer = defaultProvider().newES91OSQVectorsScorer(slice, dimensions);
+ final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions);
+ defaultScorer.scoreBulk(query, result, similarityFunction, centroidDp, scores1);
+ panamaScorer.scoreBulk(query, result, similarityFunction, centroidDp, scores2);
+ assertArrayEquals(scores1, scores2, 1e-2f);
+ assertEquals(((long) (ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), slice.getFilePointer());
+ assertEquals(padding + ((long) (i + ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), in.getFilePointer());
+ }
+ }
+ }
+ }
+ }
+}
From f134ad3c7a8759e3b8d622a1bfbfa77f64314a13 Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Wed, 23 Apr 2025 10:42:27 -0400
Subject: [PATCH 2/3] fixing headers
---
.../internal/vectorization/ES91OSQVectorsScorer.java | 8 +++++---
.../vectorization/MemorySegmentES91OSQVectorsScorer.java | 8 +++++---
2 files changed, 10 insertions(+), 6 deletions(-)
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java
index 66a8e5b89855d..839e5f29a1148 100644
--- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java
+++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java
@@ -1,8 +1,10 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.simdvec.internal.vectorization;
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java
index 3b9c8ae5c5212..0bf3b8da22d2c 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java
@@ -1,8 +1,10 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.simdvec.internal.vectorization;
From 9353495d4781f9d6c0fb00ef6940c0eab8aa18ab Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Wed, 23 Apr 2025 16:07:20 -0400
Subject: [PATCH 3/3] fixing tests
---
.../ES91OSQVectorScorerTests.java | 43 ++++++++++++++-----
1 file changed, 33 insertions(+), 10 deletions(-)
diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java
index 81172872d47d8..53b14ae4910c0 100644
--- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java
+++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java
@@ -17,7 +17,10 @@
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
+import static org.hamcrest.Matchers.lessThan;
+
public class ES91OSQVectorScorerTests extends BaseVectorizationTests {
+
public void testQuantizeScore() throws Exception {
final int dimensions = random().nextInt(1, 2000);
final int length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;
@@ -49,28 +52,38 @@ public void testQuantizeScore() throws Exception {
}
public void testScore() throws Exception {
- final int dimensions = random().nextInt(1, 2000);
+ final int maxDims = 512;
+ final int dimensions = random().nextInt(1, maxDims);
final int length = OptimizedScalarQuantizer.discretize(dimensions, 64) / 8;
final int numVectors = ES91OSQVectorsScorer.BULK_SIZE * random().nextInt(1, 10);
- final byte[] vector = new byte[length + 14];
+ final byte[] vector = new byte[length];
int padding = random().nextInt(100);
+ byte[] paddingBytes = new byte[padding];
try (Directory dir = new MMapDirectory(createTempDir())) {
try (IndexOutput out = dir.createOutput("testScore.bin", IOContext.DEFAULT)) {
- for (int i = 0; i < padding; i++) {
- out.writeByte((byte) random().nextInt());
- }
+ random().nextBytes(paddingBytes);
+ out.writeBytes(paddingBytes, 0, padding);
for (int i = 0; i < numVectors; i++) {
random().nextBytes(vector);
- out.writeBytes(vector, 0, length + 14);
+ out.writeBytes(vector, 0, length);
+ float lower = random().nextFloat();
+ float upper = random().nextFloat() + lower / 2;
+ float additionalCorrection = random().nextFloat();
+ int targetComponentSum = randomIntBetween(0, dimensions / 2);
+ out.writeInt(Float.floatToIntBits(lower));
+ out.writeInt(Float.floatToIntBits(upper));
+ out.writeShort((short) targetComponentSum);
+ out.writeInt(Float.floatToIntBits(additionalCorrection));
}
}
final byte[] query = new byte[4 * length];
random().nextBytes(query);
+ float lower = random().nextFloat();
OptimizedScalarQuantizer.QuantizationResult result = new OptimizedScalarQuantizer.QuantizationResult(
+ lower,
+ random().nextFloat() + lower / 2,
random().nextFloat(),
- random().nextFloat(),
- random().nextFloat(),
- Short.toUnsignedInt((short) random().nextInt())
+ randomIntBetween(0, dimensions * 2)
);
final float centroidDp = random().nextFloat();
final float[] scores1 = new float[ES91OSQVectorsScorer.BULK_SIZE];
@@ -92,7 +105,17 @@ public void testScore() throws Exception {
final ES91OSQVectorsScorer panamaScorer = maybePanamaProvider().newES91OSQVectorsScorer(in, dimensions);
defaultScorer.scoreBulk(query, result, similarityFunction, centroidDp, scores1);
panamaScorer.scoreBulk(query, result, similarityFunction, centroidDp, scores2);
- assertArrayEquals(scores1, scores2, 1e-2f);
+ for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
+ if (scores1[j] > (maxDims * Short.MAX_VALUE)) {
+ int diff = (int) (scores1[j] - scores2[j]);
+ assertThat("defaultScores: " + scores1[j] + " bulkScores: " + scores2[j], Math.abs(diff), lessThan(65));
+ } else if (scores1[j] > (maxDims * Byte.MAX_VALUE)) {
+ int diff = (int) (scores1[j] - scores2[j]);
+ assertThat("defaultScores: " + scores1[j] + " bulkScores: " + scores2[j], Math.abs(diff), lessThan(9));
+ } else {
+ assertEquals(scores1[j], scores2[j], 1e-2f);
+ }
+ }
assertEquals(((long) (ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), slice.getFilePointer());
assertEquals(padding + ((long) (i + ES91OSQVectorsScorer.BULK_SIZE) * (length + 14)), in.getFilePointer());
}